File size: 5,814 Bytes
7b58314
 
 
 
 
 
 
 
 
 
09a0f63
 
7b58314
 
 
 
 
 
 
 
 
09a0f63
7b58314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09a0f63
7b58314
 
 
 
09a0f63
 
 
 
 
7b58314
 
 
 
 
 
 
 
09a0f63
7b58314
 
 
 
 
 
 
 
 
09a0f63
 
 
 
 
 
 
7b58314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09a0f63
 
 
7b58314
 
09a0f63
 
7b58314
09a0f63
 
7b58314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09a0f63
7b58314
 
 
 
09a0f63
 
 
7b58314
 
 
 
 
 
09a0f63
 
 
7b58314
 
 
 
09a0f63
7b58314
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Vocence engine for the merged Qwen3-TTS VoiceDesign checkpoint.

The Vocence Chutes wrapper instantiates ``Miner`` with the on-disk path of the HF
snapshot and then drives it through the contract:

    Miner(path_hf_repo: Path)
    warmup() -> None
    generate_wav(instruction: str, text: str) -> tuple[np.ndarray, int]

All weights, the audio codec, and the tokenizer ship together in the snapshot —
nothing is fetched at runtime. The HF cache is pre-populated by the wrapper, so
``from_pretrained(model_name)`` resolves from disk without hitting the network.
"""
from __future__ import annotations

import dataclasses
import threading
from pathlib import Path
from typing import Any

import numpy as np
import yaml


# File whose presence in the snapshot directory is checked by ``Miner.__init__``
# before anything is loaded; a missing file raises ``FileNotFoundError``.
_REPO_REQUIRED_FILE = "config.json"
# YAML file inside the snapshot that ``Miner.__init__`` parses with
# ``yaml.safe_load`` to obtain ``model_name`` and the runtime options.
_RUNTIME_CONFIG_FILE = "vocence_config.yaml"


@dataclasses.dataclass
class _RuntimeOpts:
    """Subset of vocence_config.yaml that the engine actually consumes."""

    language: str = "English"
    sample_rate: int = 24000
    device_pref: str = "cuda"
    dtype_pref: str = "bfloat16"
    flash_attention_2: bool = False

    @classmethod
    def from_config(cls, data: dict) -> "_RuntimeOpts":
        runtime = data.get("runtime") or {}
        generation = data.get("generation") or {}
        limits = data.get("limits") or {}
        return cls(
            language=str(
                limits.get("default_language")
                or runtime.get("default_language")
                or "English"
            ),
            sample_rate=int(generation.get("sample_rate", 24000)),
            device_pref=str(runtime.get("device_preference", "cuda")).lower(),
            dtype_pref=str(runtime.get("dtype", "bfloat16")).lower(),
            flash_attention_2=bool(runtime.get("use_flash_attention_2", False)),
        )


class Miner:
    """Loads merged Qwen3-TTS weights and serves the Vocence API.

    Construction reads ``vocence_config.yaml`` from the snapshot directory and
    loads the model named there; ``warmup`` and ``generate_wav`` then implement
    the contract described in the module docstring.
    """

    # Maximum time, in seconds, that ``warmup`` waits for the first synthesis.
    WARMUP_BUDGET_S = 180.0

    def __init__(self, path_hf_repo: Path) -> None:
        """Validate the snapshot, parse the runtime config and load the model.

        Args:
            path_hf_repo: Directory of the on-disk HF snapshot.

        Raises:
            FileNotFoundError: If the required snapshot file or the runtime
                config file is missing.
            TypeError: If the runtime config is not a YAML mapping.
            KeyError: If the runtime config has no ``model_name`` entry.
            RuntimeError: If the model cannot be loaded.
        """
        self.repo = Path(path_hf_repo).resolve()
        if not (self.repo / _REPO_REQUIRED_FILE).is_file():
            raise FileNotFoundError(
                f"Snapshot incomplete: {self.repo / _REPO_REQUIRED_FILE} not found"
            )

        config_path = self.repo / _RUNTIME_CONFIG_FILE
        with config_path.open("r", encoding="utf-8") as fh:
            cfg = yaml.safe_load(fh) or {}
        # Fail with a readable message instead of an opaque subscript error.
        if not isinstance(cfg, dict):
            raise TypeError(
                f"{config_path} must contain a YAML mapping, "
                f"got {type(cfg).__name__}"
            )
        if "model_name" not in cfg:
            raise KeyError(f"'model_name' is missing from {config_path}")
        model_name = cfg["model_name"]

        self.opts = _RuntimeOpts.from_config(cfg)
        self.model = self._build_model(model_name)

    def __repr__(self) -> str:
        return f"<Miner repo={self.repo.name} language={self.opts.language!r}>"

    # ------------------------------------------------------------------ #
    # Vocence contract                                                    #
    # ------------------------------------------------------------------ #

    def warmup(self) -> None:
        """Run one short synthesis, bounded by ``WARMUP_BUDGET_S`` seconds.

        The synthesis runs on a daemon thread so that a hung model cannot
        block the caller beyond the budget.

        Raises:
            RuntimeError: If the synthesis raised (the original exception is
                attached as ``__cause__``) or did not finish within the budget.
        """
        failures: list[BaseException] = []
        finished = threading.Event()

        def _heat() -> None:
            try:
                self.generate_wav(instruction="Calm neutral delivery.", text="Warmup.")
            except Exception as exc:  # noqa: BLE001 — surface to host
                failures.append(exc)
            else:
                finished.set()

        worker = threading.Thread(target=_heat, daemon=True)
        worker.start()
        worker.join(timeout=self.WARMUP_BUDGET_S)
        if finished.is_set():
            return

        cause = failures[0] if failures else None
        detail = repr(cause) if cause is not None else "timeout"
        # Chain the worker's exception so the host sees the real traceback.
        raise RuntimeError(f"Miner warmup did not complete: {detail}") from cause

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize ``text`` in the voice described by ``instruction``.

        Args:
            instruction: Voice-design prompt, forwarded unchanged as ``instruct``.
            text: Text to speak, forwarded unchanged.

        Returns:
            ``(waveform, sample_rate)`` where ``waveform`` is a float32 array
            and ``sample_rate`` is the rate reported by the model.

        Raises:
            ValueError: If the model returns no audio.
        """
        # The validator's `instruction` and `text` are passed verbatim to the model,
        # per MINER_GUIDE section 8b.C — no truncation / normalization / rewriting.
        wavs, sample_rate = self.model.generate_voice_design(
            text=text,
            instruct=instruction,
            language=self.opts.language,
        )
        # Explicit checks: ``not wavs`` raises on a multi-element ndarray batch.
        if wavs is None or len(wavs) == 0 or wavs[0] is None:
            raise ValueError("Qwen3-TTS returned no audio")

        wave = self._coerce_mono_float32(wavs[0])
        return wave, int(sample_rate)

    # ------------------------------------------------------------------ #
    # Internal                                                            #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _coerce_mono_float32(arr: Any) -> np.ndarray:
        """Convert ``arr`` to a 1-D float32 waveform.

        Two-dimensional input is down-mixed by averaging over the channel
        axis, taken to be the shorter one, so both ``(samples, channels)`` and
        ``(channels, samples)`` layouts work. Square input keeps the historic
        ``axis=1`` behaviour.

        Raises:
            ValueError: If the input has more than two non-singleton axes.
        """
        wave = np.asarray(arr, dtype=np.float32)
        if wave.ndim <= 1:
            return wave
        if wave.size == 0:
            return wave.reshape(-1)
        if wave.ndim > 2:
            # Drop singleton batch/channel axes before deciding the layout.
            wave = np.squeeze(wave)
            if wave.ndim > 2:
                raise ValueError(
                    f"Cannot down-mix audio with shape {wave.shape} to mono"
                )
            if wave.ndim <= 1:
                return wave.reshape(-1)
        rows, cols = wave.shape
        channel_axis = 0 if rows < cols else 1
        return wave.mean(axis=channel_axis)

    def _build_model(self, model_name: str) -> Any:
        """Load the Qwen3-TTS model, trying each attention backend in order.

        Args:
            model_name: Identifier passed to ``Qwen3TTSModel.from_pretrained``.

        Returns:
            The loaded model object.

        Raises:
            RuntimeError: If every attention backend fails to load; the last
                failure is attached as ``__cause__``.
        """
        # Imported lazily so the module can be imported without torch/qwen_tts.
        import torch
        from qwen_tts import Qwen3TTSModel

        cuda_available = bool(torch.cuda.is_available())
        device_map = (
            "cuda:0" if (self.opts.device_pref == "cuda" and cuda_available) else "cpu"
        )
        torch_dtype = (
            torch.bfloat16
            if (self.opts.dtype_pref == "bfloat16" and cuda_available)
            else torch.float32
        )
        # Report the dtype actually used, not the (possibly overridden) preference.
        dtype_label = str(torch_dtype).split(".")[-1]

        attempt_order = (
            ("flash_attention_2", "sdpa") if self.opts.flash_attention_2 else ("sdpa",)
        )
        last_error: BaseException | None = None
        for attn in attempt_order:
            try:
                model = Qwen3TTSModel.from_pretrained(
                    model_name,
                    device_map=device_map,
                    dtype=torch_dtype,
                    attn_implementation=attn,
                )
            except Exception as exc:  # noqa: BLE001 — try next attn variant
                last_error = exc
                print(f"[Miner] Qwen3-TTS load with attn={attn} failed: {exc!r}")
                continue
            print(
                f"[Miner] Qwen3-TTS ready on {device_map} "
                f"(dtype={dtype_label}, attn={attn})"
            )
            return model
        raise RuntimeError(
            f"Qwen3-TTS failed to load: {last_error!r}"
        ) from last_error