"""Vocence engine for the merged Qwen3-TTS VoiceDesign checkpoint. The Vocence Chutes wrapper instantiates ``Miner`` with the on-disk path of the HF snapshot and then drives it through the contract: Miner(path_hf_repo: Path) warmup() -> None generate_wav(instruction: str, text: str) -> tuple[np.ndarray, int] All weights, the audio codec, and the tokenizer ship together in the snapshot — nothing is fetched at runtime. The HF cache is pre-populated by the wrapper, so ``from_pretrained(model_name)`` resolves from disk without hitting the network. """ from __future__ import annotations import dataclasses import threading from pathlib import Path from typing import Any import numpy as np import yaml _REPO_REQUIRED_FILE = "config.json" _RUNTIME_CONFIG_FILE = "vocence_config.yaml" @dataclasses.dataclass class _RuntimeOpts: """Subset of vocence_config.yaml that the engine actually consumes.""" language: str = "English" sample_rate: int = 24000 device_pref: str = "cuda" dtype_pref: str = "bfloat16" flash_attention_2: bool = False @classmethod def from_config(cls, data: dict) -> "_RuntimeOpts": runtime = data.get("runtime") or {} generation = data.get("generation") or {} limits = data.get("limits") or {} return cls( language=str( limits.get("default_language") or runtime.get("default_language") or "English" ), sample_rate=int(generation.get("sample_rate", 24000)), device_pref=str(runtime.get("device_preference", "cuda")).lower(), dtype_pref=str(runtime.get("dtype", "bfloat16")).lower(), flash_attention_2=bool(runtime.get("use_flash_attention_2", False)), ) class Miner: """Loads merged Qwen3-TTS weights and serves the Vocence API.""" WARMUP_BUDGET_S = 180.0 def __init__(self, path_hf_repo: Path) -> None: self.repo = Path(path_hf_repo).resolve() if not (self.repo / _REPO_REQUIRED_FILE).is_file(): raise FileNotFoundError( f"Snapshot incomplete: {self.repo / _REPO_REQUIRED_FILE} not found" ) with (self.repo / _RUNTIME_CONFIG_FILE).open("r", encoding="utf-8") as fh: cfg = yaml.safe_load(fh) or {} model_name = cfg["model_name"] self.opts = _RuntimeOpts.from_config(cfg) self.model = self._build_model(model_name) def __repr__(self) -> str: return f"" # ------------------------------------------------------------------ # # Vocence contract # # ------------------------------------------------------------------ # def warmup(self) -> None: outcome: dict[str, Any] = {"ok": False, "err": None} def _heat() -> None: try: self.generate_wav(instruction="Calm neutral delivery.", text="Warmup.") outcome["ok"] = True except Exception as exc: # noqa: BLE001 — surface to host outcome["err"] = repr(exc) worker = threading.Thread(target=_heat, daemon=True) worker.start() worker.join(timeout=self.WARMUP_BUDGET_S) if not outcome["ok"]: raise RuntimeError( f"Miner warmup did not complete: {outcome['err'] or 'timeout'}" ) def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]: # The validator's `instruction` and `text` are passed verbatim to the model, # per MINER_GUIDE section 8b.C — no truncation / normalization / rewriting. wavs, sample_rate = self.model.generate_voice_design( text=text, instruct=instruction, language=self.opts.language, ) if not wavs or wavs[0] is None: raise ValueError("Qwen3-TTS returned no audio") wave = self._coerce_mono_float32(wavs[0]) return wave, int(sample_rate) # ------------------------------------------------------------------ # # Internal # # ------------------------------------------------------------------ # @staticmethod def _coerce_mono_float32(arr: Any) -> np.ndarray: wave = np.asarray(arr, dtype=np.float32) if wave.ndim > 1: wave = wave.mean(axis=1) return wave def _build_model(self, model_name): import torch from qwen_tts import Qwen3TTSModel cuda_available = bool(torch.cuda.is_available()) device_map = ( "cuda:0" if (self.opts.device_pref == "cuda" and cuda_available) else "cpu" ) torch_dtype = ( torch.bfloat16 if (self.opts.dtype_pref == "bfloat16" and cuda_available) else torch.float32 ) attempt_order = ( ("flash_attention_2", "sdpa") if self.opts.flash_attention_2 else ("sdpa",) ) last_error: BaseException | None = None for attn in attempt_order: try: model = Qwen3TTSModel.from_pretrained( model_name, device_map=device_map, dtype=torch_dtype, attn_implementation=attn, ) print( f"[Miner] Qwen3-TTS ready on {device_map} " f"(dtype={self.opts.dtype_pref}, attn={attn})" ) return model except Exception as exc: # noqa: BLE001 — try next attn variant last_error = exc raise RuntimeError(f"Qwen3-TTS failed to load: {last_error!r}")