""" Qwen2-Audio captioning utilities for music annotation workflows. This module supports: 1) Local inference with Qwen2-Audio models via transformers. 2) Remote inference via a Hugging Face Endpoint with a simple JSON contract. 3) Segment-based analysis with timestamped aggregation. 4) Export helpers for ACE-Step LoRA sidecars and manifest files. """ from __future__ import annotations import base64 import io import json import os import re import shutil import subprocess import tempfile import urllib.request from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np import soundfile as sf import torchaudio AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"} DEFAULT_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct" DEFAULT_ANALYSIS_PROMPT = ( "Analyze and detail the musical elements, tones, instruments, genre and effects. " "Describe the effects and mix of instruments and vocals. Vocals may use modern production " "techniques such as pitch correction and tuning effects. Explain how musical elements interact " "throughout the song with timestamps. Go in depth on vocal performance and musical writing. " "Be concise but detail-rich." ) DEFAULT_LONG_ANALYSIS_PROMPT = ( "Analyze the full song and return a concise but detailed timestamped prose breakdown. " "Use sections every 10 to 20 seconds (or major arrangement changes). For each section, " "describe vocals, instrumentation, genre cues, effects, mix/energy changes, and how elements " "interact. End with a short overall summary paragraph." ) SEGMENT_JSON_SCHEMA_HINT = ( 'Return JSON only with keys: "segment_summary" (string), "section_label" (string), ' '"genre" (array of strings), "instruments" (array of strings), "effects" (array of strings), ' '"vocal_characteristics" (array of strings), "mix_notes" (array of strings), ' '"interaction_notes" (string), "bpm_guess" (number or null), "key_guess" (string or ""), ' '"notable_moments" (array of objects with "timestamp_sec" and "note").' ) @dataclass class SegmentResult: index: int start_sec: float end_sec: float prompt: str raw_response: str parsed: Dict[str, Any] def list_audio_files(folder: str) -> List[str]: root = Path(folder) if not root.is_dir(): return [] files: List[str] = [] for path in sorted(root.rglob("*")): if path.suffix.lower() in AUDIO_EXTENSIONS: files.append(str(path)) return files def _load_audio_with_fallback(path: str) -> Tuple[np.ndarray, int]: """Load audio to mono float32 numpy array with fallback decode path.""" try: wav, sr = torchaudio.load(path) wav = wav.float().numpy() if wav.ndim == 1: mono = wav else: mono = wav.mean(axis=0) return mono.astype(np.float32), int(sr) except Exception as torchaudio_exc: try: audio_np, sr = sf.read(path, dtype="float32", always_2d=True) mono = audio_np.mean(axis=1) return mono.astype(np.float32), int(sr) except Exception as sf_exc: # Last fallback: ffmpeg decode (works when local libsndfile lacks mp3 codec). try: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: tmp_wav = tmp.name cmd = [ "ffmpeg", "-y", "-i", str(path), "-vn", "-ac", "1", "-ar", "16000", tmp_wav, ] proc = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, ) if proc.returncode != 0: tail = (proc.stderr or "")[-800:] raise RuntimeError(f"ffmpeg decode failed: {tail}") audio_np, sr = sf.read(tmp_wav, dtype="float32", always_2d=True) mono = audio_np.mean(axis=1) return mono.astype(np.float32), int(sr) except Exception as ffmpeg_exc: raise RuntimeError( f"Audio decode failed for '{path}'. " f"torchaudio_error={torchaudio_exc}; " f"soundfile_error={sf_exc}; " f"ffmpeg_error={ffmpeg_exc}" ) from ffmpeg_exc finally: try: if "tmp_wav" in locals(): Path(tmp_wav).unlink(missing_ok=True) except Exception: pass def load_audio_mono(path: str, target_sr: int = 16000) -> Tuple[np.ndarray, int]: audio, sr = _load_audio_with_fallback(path) if sr == target_sr: return audio, sr wav = torch_audio_from_numpy(audio) resampled = torchaudio.functional.resample(wav, sr, target_sr) return resampled.squeeze(0).cpu().numpy().astype(np.float32), target_sr def torch_audio_from_numpy(audio: np.ndarray): import torch if audio.ndim != 1: raise ValueError(f"Expected mono waveform [T], got shape={audio.shape}") return torch.from_numpy(audio).unsqueeze(0) def split_audio_segments( audio: np.ndarray, sample_rate: int, segment_seconds: float, overlap_seconds: float, ) -> List[Tuple[float, float, np.ndarray]]: if segment_seconds <= 0: raise ValueError("segment_seconds must be > 0") if overlap_seconds < 0: raise ValueError("overlap_seconds must be >= 0") if overlap_seconds >= segment_seconds: raise ValueError("overlap_seconds must be smaller than segment_seconds") total_samples = int(audio.shape[0]) segment_samples = max(1, int(round(segment_seconds * sample_rate))) step_samples = max(1, int(round((segment_seconds - overlap_seconds) * sample_rate))) segments: List[Tuple[float, float, np.ndarray]] = [] start = 0 idx = 0 while start < total_samples: end = min(total_samples, start + segment_samples) seg_audio = audio[start:end] start_sec = start / sample_rate end_sec = end / sample_rate segments.append((start_sec, end_sec, seg_audio)) idx += 1 if end >= total_samples: break start = idx * step_samples return segments def _extract_json_from_text(text: str) -> Optional[Dict[str, Any]]: text = (text or "").strip() if not text: return None # Direct parse first. try: obj = json.loads(text) if isinstance(obj, dict): return obj except Exception: pass # Parse markdown code fence if present. fence_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.S | re.I) if fence_match: block = fence_match.group(1) try: obj = json.loads(block) if isinstance(obj, dict): return obj except Exception: pass # Fallback: first brace-balanced object. start = text.find("{") if start < 0: return None depth = 0 for i in range(start, len(text)): ch = text[i] if ch == "{": depth += 1 elif ch == "}": depth -= 1 if depth == 0: candidate = text[start : i + 1] try: obj = json.loads(candidate) if isinstance(obj, dict): return obj except Exception: return None return None def _ensure_string_list(value: Any) -> List[str]: if value is None: return [] if isinstance(value, str): v = value.strip() return [v] if v else [] out: List[str] = [] if isinstance(value, Sequence): for item in value: if item is None: continue s = str(item).strip() if s: out.append(s) deduped: List[str] = [] seen = set() for item in out: key = item.lower() if key in seen: continue seen.add(key) deduped.append(item) return deduped def _float_or_none(value: Any) -> Optional[float]: if value is None: return None try: return float(value) except Exception: return None _GENRE_KEYWORDS = [ "pop", "rock", "hip-hop", "hip hop", "rap", "r&b", "rnb", "electronic", "edm", "trap", "house", "techno", "ambient", "indie", "soul", "jazz", "metal", "punk", "country", "lo-fi", "lofi", "drill", ] _INSTRUMENT_KEYWORDS = [ "drums", "kick", "snare", "hihat", "hi-hat", "808", "bass", "synth", "piano", "guitar", "electric guitar", "acoustic guitar", "strings", "pad", "lead", "pluck", "vocal", "choir", ] _EFFECT_KEYWORDS = [ "reverb", "delay", "distortion", "saturation", "autotune", "auto tune", "pitch correction", "compression", "eq", "sidechain", "chorus", "flanger", "phaser", "stereo widening", ] _VOCAL_KEYWORDS = [ "autotune", "auto tune", "pitch correction", "harmonies", "ad-libs", "ad libs", "falsetto", "breathy", "raspy", "processed vocals", ] def _clean_model_text(text: str) -> str: s = (text or "").strip() if not s: return "" # Remove repetitive leading boilerplate often produced when JSON is requested. s = re.sub(r"^\s*The output should be a JSON object with these fields\.?\s*", "", s, flags=re.I) s = re.sub(r"^\s*This is the requested information for the given song segment:?\s*", "", s, flags=re.I) s = re.sub(r"^\s*From\s+\d+(\.\d+)?s\s+to\s+\d+(\.\d+)?s\s*", "", s, flags=re.I) return s.strip() def _extract_bpm_guess(text: str) -> Optional[float]: for pat in [r"\b(\d{2,3}(?:\.\d+)?)\s*bpm\b", r"\btempo\s*(?:of|is|:)?\s*(\d{2,3}(?:\.\d+)?)\b"]: m = re.search(pat, text, flags=re.I) if m: try: val = float(m.group(1)) if 30 <= val <= 300: return val except Exception: continue return None def _extract_key_guess(text: str) -> str: patterns = [ r"\b([A-G](?:#|b)?\s*(?:major|minor))\b", r"\b([A-G](?:#|b)?m)\b", ] for pat in patterns: m = re.search(pat, text, flags=re.I) if m: key = m.group(1).strip() return key[0].upper() + key[1:] return "" def _extract_keyword_hits(text: str, keywords: List[str]) -> List[str]: lower = text.lower() found: List[str] = [] for kw in keywords: if kw.lower() in lower: label = kw.replace("rnb", "R&B").replace("hip-hop", "hip-hop") if label.lower() not in {x.lower() for x in found}: found.append(label) return found class BaseCaptioner: backend_name = "base" model_id = DEFAULT_MODEL_ID def generate( self, audio: np.ndarray, sample_rate: int, prompt: str, max_new_tokens: int, temperature: float, ) -> str: raise NotImplementedError class LocalQwen2AudioCaptioner(BaseCaptioner): backend_name = "local" def __init__( self, model_id: str = DEFAULT_MODEL_ID, device: str = "auto", torch_dtype: str = "auto", trust_remote_code: bool = True, ): self.model_id = model_id self.device = device self.torch_dtype = torch_dtype self.trust_remote_code = trust_remote_code self._processor = None self._model = None def _load(self): if self._processor is not None and self._model is not None: return import torch try: from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration except Exception as exc: raise RuntimeError( "Qwen2-Audio classes are unavailable. Install a recent transformers build " "(for example transformers>=4.53.0) and retry." ) from exc if self.torch_dtype == "auto": dtype = torch.float16 if torch.cuda.is_available() else torch.float32 elif self.torch_dtype == "bfloat16": dtype = torch.bfloat16 elif self.torch_dtype == "float16": dtype = torch.float16 else: dtype = torch.float32 device_map = "auto" if self.device == "auto" else None self._processor = AutoProcessor.from_pretrained( self.model_id, trust_remote_code=self.trust_remote_code, ) self._model = Qwen2AudioForConditionalGeneration.from_pretrained( self.model_id, torch_dtype=dtype, device_map=device_map, trust_remote_code=self.trust_remote_code, ) if device_map is None: if self.device == "auto": target_device = "cuda" if torch.cuda.is_available() else "cpu" else: target_device = self.device self._model.to(target_device) def _model_device(self): import torch if self._model is None: return torch.device("cpu") return next(self._model.parameters()).device def generate( self, audio: np.ndarray, sample_rate: int, prompt: str, max_new_tokens: int, temperature: float, ) -> str: self._load() import torch conversation = [ {"role": "system", "content": "You are a precise music analysis assistant."}, { "role": "user", "content": [ {"type": "audio", "audio_url": "local://segment.wav"}, {"type": "text", "text": prompt}, ], }, ] text = self._processor.apply_chat_template( conversation, add_generation_prompt=True, tokenize=False, ) inputs = self._processor( text=text, audio=[audio], sampling_rate=sample_rate, return_tensors="pt", padding=True, ) device = self._model_device() for key, value in list(inputs.items()): if hasattr(value, "to"): inputs[key] = value.to(device) do_sample = bool(temperature and temperature > 0) gen_kwargs = { "max_new_tokens": int(max_new_tokens), "do_sample": do_sample, } if do_sample: gen_kwargs["temperature"] = max(float(temperature), 1e-5) with torch.no_grad(): generated = self._model.generate(**inputs, **gen_kwargs) prompt_tokens = inputs["input_ids"].size(1) generated_new = generated[:, prompt_tokens:] text_out = self._processor.batch_decode( generated_new, skip_special_tokens=True, clean_up_tokenization_spaces=False, )[0] if not text_out.strip(): text_out = self._processor.batch_decode( generated, skip_special_tokens=True, clean_up_tokenization_spaces=False, )[0] return text_out.strip() class HFEndpointCaptioner(BaseCaptioner): backend_name = "hf_endpoint" def __init__( self, endpoint_url: str, token: Optional[str] = None, model_id: str = DEFAULT_MODEL_ID, timeout_seconds: int = 180, ): if not endpoint_url: raise ValueError("endpoint_url is required for HFEndpointCaptioner") self.endpoint_url = endpoint_url.strip() self.token = token or os.getenv("HF_TOKEN", "") self.model_id = model_id self.timeout_seconds = timeout_seconds def generate( self, audio: np.ndarray, sample_rate: int, prompt: str, max_new_tokens: int, temperature: float, ) -> str: # Serialize to wav bytes for endpoint transport. buffer = io.BytesIO() sf.write(buffer, audio, sample_rate, format="WAV") wav_bytes = buffer.getvalue() audio_b64 = base64.b64encode(wav_bytes).decode("utf-8") payload = { "inputs": { "prompt": prompt, "audio_base64": audio_b64, "sample_rate": sample_rate, "max_new_tokens": int(max_new_tokens), "temperature": float(temperature), "model_id": self.model_id, } } req = urllib.request.Request( self.endpoint_url, data=json.dumps(payload).encode("utf-8"), headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {self.token}"} if self.token else {}), }, method="POST", ) with urllib.request.urlopen(req, timeout=self.timeout_seconds) as resp: body = resp.read().decode("utf-8") data = json.loads(body) # Accept common endpoint output shapes. if isinstance(data, dict): if isinstance(data.get("generated_text"), str): return data["generated_text"].strip() if isinstance(data.get("text"), str): return data["text"].strip() if isinstance(data.get("output_text"), str): return data["output_text"].strip() if isinstance(data, list) and data: first = data[0] if isinstance(first, dict) and isinstance(first.get("generated_text"), str): return first["generated_text"].strip() return str(data).strip() def build_segment_prompt( base_prompt: str, start_sec: float, end_sec: float, ) -> str: return ( f"{base_prompt}\n\n" f"Analyze only the song segment from {start_sec:.2f}s to {end_sec:.2f}s.\n" "Use timestamp references in absolute song seconds.\n" f"{SEGMENT_JSON_SCHEMA_HINT}" ) def _make_fallback_segment_dict(raw_text: str) -> Dict[str, Any]: summary = _clean_model_text(raw_text) if not summary: summary = "No analysis generated." bpm_guess = _extract_bpm_guess(summary) key_guess = _extract_key_guess(summary) genres = _extract_keyword_hits(summary, _GENRE_KEYWORDS) instruments = _extract_keyword_hits(summary, _INSTRUMENT_KEYWORDS) effects = _extract_keyword_hits(summary, _EFFECT_KEYWORDS) vocal_chars = _extract_keyword_hits(summary, _VOCAL_KEYWORDS) return { "segment_summary": summary, "section_label": "", "genre": genres, "instruments": instruments, "effects": effects, "vocal_characteristics": vocal_chars, "mix_notes": [], "interaction_notes": summary, "bpm_guess": bpm_guess, "key_guess": key_guess, "notable_moments": [], } def _parse_segment_output(raw_text: str) -> Dict[str, Any]: parsed = _extract_json_from_text(raw_text) if not parsed: return _make_fallback_segment_dict(raw_text) out = dict(parsed) out["segment_summary"] = str(out.get("segment_summary", "")).strip() out["section_label"] = str(out.get("section_label", "")).strip() out["genre"] = _ensure_string_list(out.get("genre")) out["instruments"] = _ensure_string_list(out.get("instruments")) out["effects"] = _ensure_string_list(out.get("effects")) out["vocal_characteristics"] = _ensure_string_list(out.get("vocal_characteristics")) out["mix_notes"] = _ensure_string_list(out.get("mix_notes")) out["interaction_notes"] = str(out.get("interaction_notes", "")).strip() out["bpm_guess"] = _float_or_none(out.get("bpm_guess")) out["key_guess"] = str(out.get("key_guess", "")).strip() nm = out.get("notable_moments") cleaned_nm: List[Dict[str, Any]] = [] if isinstance(nm, Sequence): for item in nm: if not isinstance(item, dict): continue ts = _float_or_none(item.get("timestamp_sec")) note = str(item.get("note", "")).strip() if ts is None and not note: continue cleaned_nm.append({"timestamp_sec": ts, "note": note}) out["notable_moments"] = cleaned_nm return out def _pick_common_key(values: List[str]) -> str: counts: Dict[str, int] = {} first_original: Dict[str, str] = {} for v in values: s = (v or "").strip() if not s: continue k = s.lower() counts[k] = counts.get(k, 0) + 1 if k not in first_original: first_original[k] = s if not counts: return "" best = sorted(counts.items(), key=lambda x: (-x[1], x[0]))[0][0] return first_original[best] def _collect_unique(items: List[List[str]], limit: int = 12) -> List[str]: out: List[str] = [] seen = set() for group in items: for item in group: key = item.strip().lower() if not key or key in seen: continue seen.add(key) out.append(item.strip()) if len(out) >= limit: return out return out def _derive_caption(genres: List[str], instruments: List[str], vocals: List[str]) -> str: parts: List[str] = [] if genres: parts.append(", ".join(genres[:2])) if instruments: parts.append("with " + ", ".join(instruments[:3])) if vocals: parts.append("and modern processed vocals") if not parts: return "music track with detailed arrangement and production dynamics" return " ".join(parts) def generate_track_annotation( audio_path: str, captioner: BaseCaptioner, prompt: str = DEFAULT_ANALYSIS_PROMPT, segment_seconds: float = 30.0, overlap_seconds: float = 2.0, max_new_tokens: int = 384, temperature: float = 0.1, keep_raw_outputs: bool = True, include_long_analysis: bool = False, long_analysis_prompt: str = DEFAULT_LONG_ANALYSIS_PROMPT, long_analysis_max_new_tokens: int = 1200, long_analysis_temperature: float = 0.1, ) -> Dict[str, Any]: audio, sr = load_audio_mono(audio_path, target_sr=16000) duration_sec = float(audio.shape[0]) / float(sr) if sr > 0 else 0.0 segments = split_audio_segments( audio=audio, sample_rate=sr, segment_seconds=segment_seconds, overlap_seconds=overlap_seconds, ) results: List[SegmentResult] = [] for idx, (start_sec, end_sec, seg_audio) in enumerate(segments): seg_prompt = build_segment_prompt(prompt, start_sec=start_sec, end_sec=end_sec) raw = captioner.generate( audio=seg_audio, sample_rate=sr, prompt=seg_prompt, max_new_tokens=max_new_tokens, temperature=temperature, ) parsed = _parse_segment_output(raw) results.append( SegmentResult( index=idx, start_sec=start_sec, end_sec=end_sec, prompt=seg_prompt, raw_response=raw, parsed=parsed, ) ) timeline: List[Dict[str, Any]] = [] all_genres: List[List[str]] = [] all_instruments: List[List[str]] = [] all_effects: List[List[str]] = [] all_vocals: List[List[str]] = [] all_mix_notes: List[List[str]] = [] bpm_values: List[float] = [] keys: List[str] = [] interaction_summary: List[str] = [] for seg in results: p = seg.parsed all_genres.append(_ensure_string_list(p.get("genre"))) all_instruments.append(_ensure_string_list(p.get("instruments"))) all_effects.append(_ensure_string_list(p.get("effects"))) all_vocals.append(_ensure_string_list(p.get("vocal_characteristics"))) all_mix_notes.append(_ensure_string_list(p.get("mix_notes"))) bpm = _float_or_none(p.get("bpm_guess")) if bpm is not None and bpm > 0: bpm_values.append(bpm) key_guess = str(p.get("key_guess", "")).strip() if key_guess: keys.append(key_guess) if p.get("interaction_notes"): interaction_summary.append(str(p["interaction_notes"]).strip()) timeline_entry = { "segment_index": seg.index, "start_sec": round(seg.start_sec, 3), "end_sec": round(seg.end_sec, 3), "section_label": str(p.get("section_label", "")).strip(), "segment_summary": str(p.get("segment_summary", "")).strip(), "instruments": _ensure_string_list(p.get("instruments")), "effects": _ensure_string_list(p.get("effects")), "vocal_characteristics": _ensure_string_list(p.get("vocal_characteristics")), "interaction_notes": str(p.get("interaction_notes", "")).strip(), "mix_notes": _ensure_string_list(p.get("mix_notes")), "notable_moments": p.get("notable_moments", []), } if keep_raw_outputs: timeline_entry["raw_response"] = seg.raw_response timeline.append(timeline_entry) genres = _collect_unique(all_genres, limit=10) instruments = _collect_unique(all_instruments, limit=16) effects = _collect_unique(all_effects, limit=16) vocal_traits = _collect_unique(all_vocals, limit=12) mix_notes = _collect_unique(all_mix_notes, limit=24) keyscale = _pick_common_key(keys) bpm = int(round(sum(bpm_values) / len(bpm_values))) if bpm_values else None caption = _derive_caption(genres=genres, instruments=instruments, vocals=vocal_traits) sidecar: Dict[str, Any] = { "caption": caption, "lyrics": "", "bpm": bpm, "keyscale": keyscale, "timesignature": "4/4", "vocal_language": "unknown", "duration": round(duration_sec, 3), "annotation_version": "qwen2_audio_music_v1", "source_audio": str(audio_path), "analysis_prompt": prompt, "analysis_backend": captioner.backend_name, "analysis_model": captioner.model_id, "analysis_generated_at": datetime.now(timezone.utc).isoformat(), "music_analysis": { "genres": genres, "instruments": instruments, "effects": effects, "vocal_characteristics": vocal_traits, "mix_notes": mix_notes, "interaction_summary": interaction_summary, "timeline": timeline, "segment_seconds": segment_seconds, "overlap_seconds": overlap_seconds, "segment_count": len(timeline), }, } if include_long_analysis: long_prompt = (long_analysis_prompt or "").strip() or DEFAULT_LONG_ANALYSIS_PROMPT try: long_raw = captioner.generate( audio=audio, sample_rate=sr, prompt=long_prompt, max_new_tokens=int(long_analysis_max_new_tokens), temperature=float(long_analysis_temperature), ) long_text = _clean_model_text(long_raw) sidecar["analysis_long_prompt"] = long_prompt sidecar["analysis_long"] = long_text sidecar["music_analysis"]["summary_long"] = long_text except Exception as exc: sidecar["analysis_long_prompt"] = long_prompt sidecar["analysis_long"] = "" sidecar["analysis_long_error"] = str(exc) return sidecar def build_captioner( backend: str, model_id: str = DEFAULT_MODEL_ID, endpoint_url: str = "", token: str = "", device: str = "auto", torch_dtype: str = "auto", ) -> BaseCaptioner: backend = (backend or "").strip().lower() if backend in {"local", "hf_space_local"}: return LocalQwen2AudioCaptioner( model_id=model_id or DEFAULT_MODEL_ID, device=device, torch_dtype=torch_dtype, ) if backend in {"endpoint", "hf_endpoint"}: return HFEndpointCaptioner( endpoint_url=endpoint_url, token=token, model_id=model_id or DEFAULT_MODEL_ID, ) raise ValueError(f"Unsupported backend: {backend}") def export_annotation_records( records: List[Dict[str, Any]], output_dir: str, copy_audio: bool = True, write_inplace_sidecars: bool = True, ) -> Dict[str, Any]: """ Export analyzed tracks as LoRA-ready sidecars + manifest. records item schema: { "audio_path": "...", "sidecar": {...annotation json...} } """ out_root = Path(output_dir) out_root.mkdir(parents=True, exist_ok=True) dataset_root = out_root / "dataset" if copy_audio: dataset_root.mkdir(parents=True, exist_ok=True) manifest_path = out_root / "annotations_manifest.jsonl" index_path = out_root / "annotations_index.json" manifest_lines: List[str] = [] index_items: List[Dict[str, Any]] = [] written_count = 0 for rec in records: src_audio = Path(rec["audio_path"]) sidecar = dict(rec["sidecar"]) if not src_audio.exists(): continue if copy_audio: dst_audio = dataset_root / src_audio.name if src_audio.resolve() != dst_audio.resolve(): shutil.copy2(src_audio, dst_audio) dst_sidecar = dst_audio.with_suffix(".json") else: dst_sidecar = (out_root / src_audio.name).with_suffix(".json") dst_sidecar.write_text(json.dumps(sidecar, indent=2, ensure_ascii=False), encoding="utf-8") written_count += 1 if write_inplace_sidecars: inplace_sidecar = src_audio.with_suffix(".json") inplace_sidecar.write_text( json.dumps(sidecar, indent=2, ensure_ascii=False), encoding="utf-8", ) manifest_row = { "audio_path": str(dst_sidecar.with_suffix(src_audio.suffix).as_posix()) if copy_audio else str(src_audio), "sidecar_path": str(dst_sidecar), "caption": sidecar.get("caption", ""), "duration": sidecar.get("duration"), "bpm": sidecar.get("bpm"), "keyscale": sidecar.get("keyscale", ""), } manifest_lines.append(json.dumps(manifest_row, ensure_ascii=False)) index_items.append( { "source_audio": str(src_audio), "exported_sidecar": str(dst_sidecar), "caption": sidecar.get("caption", ""), } ) manifest_path.write_text("\n".join(manifest_lines), encoding="utf-8") index_path.write_text( json.dumps( { "generated_at": datetime.now(timezone.utc).isoformat(), "records": index_items, }, indent=2, ensure_ascii=False, ), encoding="utf-8", ) return { "written_count": written_count, "manifest_path": str(manifest_path), "index_path": str(index_path), "dataset_root": str(dataset_root) if copy_audio else "", } def read_prompt_file(prompt_file: str) -> str: path = Path(prompt_file) if not path.is_file(): raise FileNotFoundError(f"Prompt file not found: {prompt_file}") text = path.read_text(encoding="utf-8").strip() if not text: raise ValueError(f"Prompt file is empty: {prompt_file}") return text