pravinuxd commited on
Commit
7684f3c
·
verified ·
1 Parent(s): 8686966

add server.py for self-bootstrap

Browse files
Files changed (1) hide show
  1. server.py +779 -0
server.py ADDED
@@ -0,0 +1,779 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OmniVoice FastAPI server for Linux + NVIDIA (e.g. RunPod).
3
+
4
+ Exposes the full OmniVoice surface (voice clone, voice design, auto voice) + all
5
+ generation parameters from `omnivoice.OmniVoiceGenerationConfig`.
6
+
7
+ Endpoints
8
+ ---------
9
+ - GET /health Liveness/readiness + startup error.
10
+ - GET /v1/models List served model.
11
+ - GET /v1/languages All language display names supported by the model.
12
+ - GET /v1/voice-design/attributes Attribute groups for the Voice Design composer.
13
+ - POST /v1/audio/speech Unified TTS endpoint (multipart):
14
+ text=<str> (required)
15
+ mode=clone|design|auto (default: clone if ref_audio else design if instruct else auto)
16
+ ref_audio=<file> (clone)
17
+ ref_text=<str> (clone, optional — Whisper auto-transcribes if omitted)
18
+ instruct=<str> (design or clone overlay)
19
+ language=<str> (display name, e.g. "English", "Hindi"; "Auto" for auto-detect)
20
+ speed=<float> (default 1.0)
21
+ duration=<float> (seconds; if set overrides speed)
22
+ num_step=<int> (default 32)
23
+ guidance_scale=<float> (default 2.0)
24
+ denoise=<bool> (default true)
25
+ preprocess_prompt=<bool> (default true)
26
+ postprocess_output=<bool> (default true)
27
+ - POST /v1/audio/speech/multi Multi-character story generation (JSON body).
28
+ Splits text on [Sn]…[/Sn] markers, runs OmniVoice once per chunk with
29
+ each character's mode/instruct/ref_audio (clone or design), and stitches
30
+ the resulting WAVs (with a configurable inter-segment silence) into a
31
+ single WAV response. See `MultiSpeechRequest` schema below.
32
+ - POST /v1/audio/speech/clone Backward-compat shim (forwards to /v1/audio/speech with mode=clone).
33
+
34
+ Returns: 200 audio/wav (16-bit PCM mono @ 24 kHz).
35
+ """
36
+
37
+ import asyncio
38
+ import base64
39
+ import binascii
40
+ import io
41
+ import logging
42
+ import os
43
+ import platform
44
+ import re
45
+ import tempfile
46
+ import wave
47
+ from typing import Any
48
+
49
+ import numpy as np
50
+ import torch
51
+ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
52
+ from fastapi.responses import JSONResponse, StreamingResponse
53
+ from pydantic import BaseModel, Field
54
+ from omnivoice import OmniVoice, OmniVoiceGenerationConfig
55
+ from omnivoice.utils.lang_map import LANG_NAMES, lang_display_name
56
+
57
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
58
+
59
+ MODEL_ID = os.getenv("OMNIVOICE_MODEL_ID", "pravinuxd/OmniVoice")
60
+ SAMPLE_RATE = 24000
61
+
62
+
63
+ def _resolve_device_and_dtype() -> tuple[str, torch.dtype]:
64
+ if torch.cuda.is_available():
65
+ return "cuda:0", torch.float16
66
+ return "cpu", torch.float32
67
+
68
+
69
+ DEVICE, DTYPE = _resolve_device_and_dtype()
70
+
71
+
72
+ def _enforce_cuda_if_runpod() -> None:
73
+ on_runpod = bool(os.getenv("RUNPOD_POD_ID"))
74
+ require = os.getenv("OMNIVOICE_REQUIRE_CUDA", "").strip().lower()
75
+ if require in ("1", "true", "yes"):
76
+ must_cuda = True
77
+ elif require in ("0", "false", "no"):
78
+ must_cuda = False
79
+ else:
80
+ must_cuda = on_runpod
81
+ if must_cuda and not torch.cuda.is_available():
82
+ raise RuntimeError(
83
+ "CUDA is not available. This image is built for NVIDIA GPUs on Linux (RunPod). "
84
+ "Attach a GPU to the pod or set OMNIVOICE_REQUIRE_CUDA=0 only for debugging."
85
+ )
86
+
87
+
88
+ _enforce_cuda_if_runpod()
89
+
90
+ if torch.cuda.is_available():
91
+ logging.info(
92
+ "Using GPU: %s | CUDA %s",
93
+ torch.cuda.get_device_name(0),
94
+ torch.version.cuda,
95
+ )
96
+ else:
97
+ logging.warning("Running on CPU — not recommended for production OmniVoice inference.")
98
+ logging.info(
99
+ "Platform: %s | MODEL_ID=%s | device=%s | dtype=%s",
100
+ platform.system(),
101
+ MODEL_ID,
102
+ DEVICE,
103
+ DTYPE,
104
+ )
105
+
106
+ # Pre-compute language list (sorted display names) once.
107
+ _LANGUAGES_SORTED = sorted({lang_display_name(n) for n in LANG_NAMES})
108
+
109
+ # Voice Design attributes (mirrors omnivoice.cli.demo categories so the UI stays in sync).
110
+ _VOICE_DESIGN_ATTRIBUTES = {
111
+ "gender": {
112
+ "label": "Gender",
113
+ "info": "Speaker gender.",
114
+ "options": [
115
+ {"value": "male", "label": "Male / 男"},
116
+ {"value": "female", "label": "Female / 女"},
117
+ ],
118
+ },
119
+ "age": {
120
+ "label": "Age",
121
+ "info": "Approximate speaker age.",
122
+ "options": [
123
+ {"value": "child", "label": "Child / 儿童"},
124
+ {"value": "teenager", "label": "Teenager / 少年"},
125
+ {"value": "young adult", "label": "Young Adult / 青年"},
126
+ {"value": "middle-aged", "label": "Middle-aged / 中年"},
127
+ {"value": "elderly", "label": "Elderly / 老年"},
128
+ ],
129
+ },
130
+ "pitch": {
131
+ "label": "Pitch",
132
+ "info": "Voice pitch register.",
133
+ "options": [
134
+ {"value": "very low pitch", "label": "Very Low / 极低音调"},
135
+ {"value": "low pitch", "label": "Low / ���音调"},
136
+ {"value": "moderate pitch", "label": "Moderate / 中音调"},
137
+ {"value": "high pitch", "label": "High / 高音调"},
138
+ {"value": "very high pitch", "label": "Very High / 极高音调"},
139
+ ],
140
+ },
141
+ "style": {
142
+ "label": "Style",
143
+ "info": "Speaking style.",
144
+ "options": [
145
+ {"value": "whisper", "label": "Whisper / 耳语"},
146
+ ],
147
+ },
148
+ "english_accent": {
149
+ "label": "English Accent",
150
+ "info": "Only effective when generating English speech.",
151
+ "options": [
152
+ {"value": "american accent", "label": "American"},
153
+ {"value": "australian accent", "label": "Australian"},
154
+ {"value": "british accent", "label": "British"},
155
+ {"value": "canadian accent", "label": "Canadian"},
156
+ {"value": "chinese accent", "label": "Chinese"},
157
+ {"value": "indian accent", "label": "Indian"},
158
+ {"value": "japanese accent", "label": "Japanese"},
159
+ {"value": "korean accent", "label": "Korean"},
160
+ {"value": "portuguese accent", "label": "Portuguese"},
161
+ {"value": "russian accent", "label": "Russian"},
162
+ ],
163
+ },
164
+ "chinese_dialect": {
165
+ "label": "Chinese Dialect",
166
+ "info": "Only effective when generating Chinese speech.",
167
+ "options": [
168
+ {"value": "河南话", "label": "Henan / 河南话"},
169
+ {"value": "陕西话", "label": "Shaanxi / 陕西话"},
170
+ {"value": "四川话", "label": "Sichuan / 四川话"},
171
+ {"value": "贵州话", "label": "Guizhou / 贵州话"},
172
+ {"value": "云南话", "label": "Yunnan / 云南话"},
173
+ {"value": "桂林话", "label": "Guilin / 桂林话"},
174
+ {"value": "济南话", "label": "Jinan / 济南话"},
175
+ {"value": "石家庄话", "label": "Shijiazhuang / 石家庄话"},
176
+ {"value": "甘肃话", "label": "Gansu / 甘肃话"},
177
+ {"value": "宁夏话", "label": "Ningxia / 宁夏话"},
178
+ {"value": "青岛话", "label": "Qingdao / 青岛话"},
179
+ {"value": "东北话", "label": "Northeast / 东北话"},
180
+ ],
181
+ },
182
+ }
183
+
184
+ app = FastAPI(title="OmniVoice Pod API", version="2")
185
+ model: OmniVoice | None = None
186
+ startup_error: str | None = None
187
+
188
+ # Dubbing Studio Lite — sibling module that registers /v1/dub/* routes.
189
+ # Imported here so the GPU_LOCK lives in a single shared instance and the
190
+ # OmniVoice model handle can be shared in-process (no HTTP round-trip per
191
+ # segment).
192
+ try:
193
+ from . import dub as _dub # type: ignore[import-not-found]
194
+ except ImportError:
195
+ import dub as _dub # type: ignore[no-redef]
196
+
197
+ app.include_router(_dub.router)
198
+
199
+
200
+ def _load_model() -> OmniVoice:
201
+ return OmniVoice.from_pretrained(
202
+ MODEL_ID,
203
+ device_map=DEVICE,
204
+ dtype=DTYPE,
205
+ load_asr=True,
206
+ )
207
+
208
+
209
+ @app.on_event("startup")
210
+ def startup_event() -> None:
211
+ global model, startup_error
212
+ try:
213
+ model = _load_model()
214
+ startup_error = None
215
+ # Hand the live model + generation-config class to the dub module so
216
+ # `/v1/dub/jobs` can synthesize segments without going over HTTP.
217
+ _dub.configure(model, OmniVoiceGenerationConfig)
218
+ except Exception as exc: # pragma: no cover
219
+ startup_error = f"{type(exc).__name__}: {exc}"
220
+ raise
221
+
222
+
223
+ def _wav_bytes(audio: np.ndarray) -> bytes:
224
+ clipped = np.clip(audio, -1.0, 1.0)
225
+ pcm16 = (clipped * 32767.0).astype(np.int16)
226
+ buf = io.BytesIO()
227
+ with wave.open(buf, "wb") as wav:
228
+ wav.setnchannels(1)
229
+ wav.setsampwidth(2)
230
+ wav.setframerate(SAMPLE_RATE)
231
+ wav.writeframes(pcm16.tobytes())
232
+ return buf.getvalue()
233
+
234
+
235
+ def _parse_bool(val: str | None, default: bool) -> bool:
236
+ if val is None:
237
+ return default
238
+ return val.strip().lower() in ("1", "true", "yes", "on")
239
+
240
+
241
+ def _normalize_language(lang: str | None) -> str | None:
242
+ if not lang:
243
+ return None
244
+ cleaned = lang.strip()
245
+ if not cleaned or cleaned.lower() == "auto":
246
+ return None
247
+ return cleaned
248
+
249
+
250
+ def _resolve_mode(mode: str | None, has_ref_audio: bool, has_instruct: bool) -> str:
251
+ if mode:
252
+ m = mode.strip().lower()
253
+ if m in ("clone", "design", "auto"):
254
+ return m
255
+ if has_ref_audio:
256
+ return "clone"
257
+ if has_instruct:
258
+ return "design"
259
+ return "auto"
260
+
261
+
262
+ @app.get("/health")
263
+ def health() -> JSONResponse:
264
+ ready = model is not None and startup_error is None
265
+ return JSONResponse(
266
+ {
267
+ "status": "healthy" if ready else "starting",
268
+ "ready": ready,
269
+ "model_loaded": ready,
270
+ "model_id": MODEL_ID,
271
+ "device": DEVICE,
272
+ "startup_error": startup_error,
273
+ }
274
+ )
275
+
276
+
277
+ @app.get("/v1/models")
278
+ def list_models() -> JSONResponse:
279
+ return JSONResponse(
280
+ {
281
+ "object": "list",
282
+ "data": [
283
+ {
284
+ "id": "omnivoice",
285
+ "object": "model",
286
+ "owned_by": "pravinuxd",
287
+ "root": MODEL_ID,
288
+ }
289
+ ],
290
+ }
291
+ )
292
+
293
+
294
+ @app.get("/v1/languages")
295
+ def list_languages() -> JSONResponse:
296
+ return JSONResponse({"languages": _LANGUAGES_SORTED, "count": len(_LANGUAGES_SORTED)})
297
+
298
+
299
+ @app.get("/v1/voice-design/attributes")
300
+ def list_voice_design_attributes() -> JSONResponse:
301
+ return JSONResponse({"attributes": _VOICE_DESIGN_ATTRIBUTES})
302
+
303
+
304
+ def _generate_audio(
305
+ *,
306
+ text: str,
307
+ mode: str,
308
+ ref_audio_path: str | None,
309
+ ref_text: str | None,
310
+ instruct: str | None,
311
+ language: str | None,
312
+ speed: float,
313
+ duration: float | None,
314
+ num_step: int,
315
+ guidance_scale: float,
316
+ denoise: bool,
317
+ preprocess_prompt: bool,
318
+ postprocess_output: bool,
319
+ ) -> bytes:
320
+ if model is None:
321
+ raise HTTPException(status_code=503, detail=startup_error or "Model not ready")
322
+
323
+ gen_config = OmniVoiceGenerationConfig(
324
+ num_step=num_step,
325
+ guidance_scale=guidance_scale,
326
+ denoise=denoise,
327
+ preprocess_prompt=preprocess_prompt,
328
+ postprocess_output=postprocess_output,
329
+ )
330
+
331
+ kw: dict[str, Any] = {
332
+ "text": text,
333
+ "language": language,
334
+ "generation_config": gen_config,
335
+ }
336
+ if speed != 1.0:
337
+ kw["speed"] = speed
338
+ if duration is not None and duration > 0:
339
+ kw["duration"] = duration
340
+
341
+ if mode == "clone":
342
+ if not ref_audio_path:
343
+ raise HTTPException(status_code=400, detail="mode=clone requires ref_audio")
344
+ kw["voice_clone_prompt"] = model.create_voice_clone_prompt(
345
+ ref_audio=ref_audio_path,
346
+ ref_text=ref_text or None,
347
+ )
348
+
349
+ if instruct and instruct.strip():
350
+ kw["instruct"] = instruct.strip()
351
+
352
+ try:
353
+ generated = model.generate(**kw)
354
+ except HTTPException:
355
+ raise
356
+ except Exception as exc:
357
+ logging.exception("OmniVoice generation failed")
358
+ raise HTTPException(
359
+ status_code=500, detail=f"{type(exc).__name__}: {exc}"
360
+ ) from exc
361
+
362
+ return _wav_bytes(generated[0])
363
+
364
+
365
+ @app.post("/v1/audio/speech")
366
+ async def synth_speech(
367
+ text: str = Form(...),
368
+ mode: str | None = Form(None),
369
+ ref_audio: UploadFile | None = File(None),
370
+ ref_text: str | None = Form(None),
371
+ instruct: str | None = Form(None),
372
+ language: str | None = Form(None),
373
+ speed: float = Form(1.0),
374
+ duration: float | None = Form(None),
375
+ num_step: int = Form(32),
376
+ guidance_scale: float = Form(2.0),
377
+ denoise: str | None = Form(None),
378
+ preprocess_prompt: str | None = Form(None),
379
+ postprocess_output: str | None = Form(None),
380
+ ) -> StreamingResponse:
381
+ text = (text or "").strip()
382
+ if not text:
383
+ raise HTTPException(status_code=400, detail="text is required")
384
+
385
+ has_ref = ref_audio is not None and (ref_audio.filename or "")
386
+ has_instruct = bool(instruct and instruct.strip())
387
+ resolved_mode = _resolve_mode(mode, bool(has_ref), has_instruct)
388
+
389
+ tmp_path: str | None = None
390
+ try:
391
+ if resolved_mode == "clone":
392
+ if not has_ref:
393
+ raise HTTPException(
394
+ status_code=400, detail="mode=clone requires ref_audio"
395
+ )
396
+ audio_bytes = await ref_audio.read()
397
+ suffix = (
398
+ os.path.splitext(ref_audio.filename or "reference.wav")[1] or ".wav"
399
+ )
400
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
401
+ tmp.write(audio_bytes)
402
+ tmp_path = tmp.name
403
+
404
+ # Share the dub module's GPU lock so a long-running dub job and a
405
+ # live TTS request never fight for the single A40 GPU.
406
+ async with _dub.GPU_LOCK:
407
+ wav_bytes = await asyncio.to_thread(
408
+ _generate_audio,
409
+ text=text,
410
+ mode=resolved_mode,
411
+ ref_audio_path=tmp_path,
412
+ ref_text=(ref_text or None),
413
+ instruct=(instruct or None),
414
+ language=_normalize_language(language),
415
+ speed=float(speed),
416
+ duration=(float(duration) if duration is not None else None),
417
+ num_step=int(num_step or 32),
418
+ guidance_scale=float(guidance_scale),
419
+ denoise=_parse_bool(denoise, True),
420
+ preprocess_prompt=_parse_bool(preprocess_prompt, True),
421
+ postprocess_output=_parse_bool(postprocess_output, True),
422
+ )
423
+
424
+ headers = {"X-OmniVoice-Mode": resolved_mode}
425
+ return StreamingResponse(
426
+ io.BytesIO(wav_bytes), media_type="audio/wav", headers=headers
427
+ )
428
+ finally:
429
+ if tmp_path and os.path.exists(tmp_path):
430
+ os.unlink(tmp_path)
431
+
432
+
433
+ # ---------------------------------------------------------------------------
434
+ # Multi-character story endpoint
435
+ # ---------------------------------------------------------------------------
436
+
437
+ # Matches [S1] … [/S1] (or [S2], …, up to [S8]). The slot number is captured
438
+ # in group 1, optional attributes (e.g. `duration=2.5s`) in group 2, the inner
439
+ # text in group 3. We use a non-greedy match so two consecutive blocks like
440
+ # "[S1]hi[/S1] [S2]bye[/S2]" parse correctly.
441
+ _SPEAKER_TAG_RE = re.compile(
442
+ r"\[\s*S([1-8])([^\]]*)\](.*?)\[\s*/\s*S\1\s*\]",
443
+ re.DOTALL,
444
+ )
445
+
446
+ # Pure-silence directive embedded in dialogue text. Recognised inside any
447
+ # [Sn]…[/Sn] block; when the entire inner text matches we emit `silence(s)`
448
+ # without invoking the model. Supports `[pause=2s]`, `[pause=500ms]`,
449
+ # `[pause=2]` (defaults to seconds), and case-insensitive.
450
+ _PAUSE_TAG_RE = re.compile(
451
+ r"^\s*\[\s*pause\s*=\s*([0-9]*\.?[0-9]+)\s*(s|ms)?\s*\]\s*$",
452
+ re.IGNORECASE,
453
+ )
454
+
455
+ # Per-block `duration=` attribute parsed out of the speaker tag's attribute
456
+ # string (group 2 above).
457
+ _DURATION_ATTR_RE = re.compile(
458
+ r"duration\s*=\s*([0-9]*\.?[0-9]+)\s*(s|ms)?",
459
+ re.IGNORECASE,
460
+ )
461
+
462
+
463
+ def _parse_pause_seconds(text: str) -> float | None:
464
+ """If `text` is a single `[pause=…]` directive return its length in seconds."""
465
+ m = _PAUSE_TAG_RE.match(text)
466
+ if not m:
467
+ return None
468
+ value = float(m.group(1))
469
+ unit = (m.group(2) or "s").lower()
470
+ return value / 1000.0 if unit == "ms" else value
471
+
472
+
473
+ def _parse_duration_attr(attrs: str) -> float | None:
474
+ """Parse `duration=Xs` / `duration=Xms` from the `[Sn …]` attribute string."""
475
+ m = _DURATION_ATTR_RE.search(attrs or "")
476
+ if not m:
477
+ return None
478
+ value = float(m.group(1))
479
+ unit = (m.group(2) or "s").lower()
480
+ return value / 1000.0 if unit == "ms" else value
481
+
482
+
483
+ class CharacterConfig(BaseModel):
484
+ """One character slot (S1..S8) used by /v1/audio/speech/multi."""
485
+
486
+ slot: int = Field(..., ge=1, le=8)
487
+ name: str | None = None
488
+ mode: str = Field(..., pattern="^(design|clone)$")
489
+ instruct: str | None = None
490
+ language: str | None = None
491
+ speed: float | None = None
492
+ # base64-encoded WAV/MP3/etc. — only used when mode == "clone".
493
+ ref_audio_b64: str | None = None
494
+ ref_text: str | None = None
495
+
496
+
497
+ class MultiSpeechRequest(BaseModel):
498
+ text: str
499
+ characters: list[CharacterConfig] = Field(default_factory=list)
500
+ # Default character used for any narrative text outside [Sn]…[/Sn] blocks.
501
+ # If omitted, narration is generated with mode=auto (random voice for the
502
+ # detected language).
503
+ narrator: CharacterConfig | None = None
504
+ # Common controls applied to every chunk unless the character overrides.
505
+ speed: float = 1.0
506
+ num_step: int = 32
507
+ guidance_scale: float = 2.0
508
+ denoise: bool = True
509
+ preprocess_prompt: bool = True
510
+ postprocess_output: bool = True
511
+ # Silence inserted between chunks for natural pacing.
512
+ inter_segment_silence_ms: int = 250
513
+
514
+
515
+ def _split_into_segments(
516
+ text: str,
517
+ ) -> list[tuple[int | None, str, dict[str, float]]]:
518
+ """Split `text` into (slot|None, chunk_text, attrs) segments.
519
+
520
+ Slot is None for narration (text outside any [Sn]…[/Sn] block).
521
+ Empty / whitespace-only chunks are dropped. `attrs` carries optional
522
+ metadata parsed from the speaker tag (currently `duration`).
523
+ """
524
+ segments: list[tuple[int | None, str, dict[str, float]]] = []
525
+ cursor = 0
526
+ for match in _SPEAKER_TAG_RE.finditer(text):
527
+ before = text[cursor : match.start()]
528
+ if before.strip():
529
+ segments.append((None, before.strip(), {}))
530
+ slot = int(match.group(1))
531
+ attrs_str = match.group(2) or ""
532
+ inner = match.group(3).strip()
533
+ attrs: dict[str, float] = {}
534
+ dur = _parse_duration_attr(attrs_str)
535
+ if dur is not None and dur > 0:
536
+ attrs["duration"] = dur
537
+ if inner:
538
+ segments.append((slot, inner, attrs))
539
+ cursor = match.end()
540
+ tail = text[cursor:]
541
+ if tail.strip():
542
+ segments.append((None, tail.strip(), {}))
543
+ return segments
544
+
545
+
546
+ def _decode_ref_audio(b64: str) -> bytes:
547
+ try:
548
+ return base64.b64decode(b64, validate=True)
549
+ except (binascii.Error, ValueError) as exc:
550
+ raise HTTPException(
551
+ status_code=400, detail=f"Invalid ref_audio_b64: {exc}"
552
+ ) from exc
553
+
554
+
555
+ def _silence_pcm(milliseconds: int) -> np.ndarray:
556
+ samples = max(0, int(SAMPLE_RATE * (milliseconds / 1000.0)))
557
+ return np.zeros(samples, dtype=np.float32)
558
+
559
+
560
+ def _wav_bytes_to_float_pcm(buf: bytes) -> np.ndarray:
561
+ """Read a 16-bit PCM mono WAV (any sample rate) and return float32 [-1, 1].
562
+
563
+ If the sample rate doesn't match SAMPLE_RATE we keep it as-is and rely on
564
+ the model emitting at SAMPLE_RATE; this is a defensive helper used only
565
+ when we round-trip WAVs (we never receive WAVs from the model — we get
566
+ raw float arrays — so this branch is mostly here for testing).
567
+ """
568
+ with wave.open(io.BytesIO(buf), "rb") as wav:
569
+ n = wav.getnframes()
570
+ raw = wav.readframes(n)
571
+ pcm16 = np.frombuffer(raw, dtype=np.int16)
572
+ return (pcm16.astype(np.float32) / 32767.0).copy()
573
+
574
+
575
+ def _generate_chunk(
576
+ *,
577
+ text: str,
578
+ character: CharacterConfig | None,
579
+ common_speed: float,
580
+ num_step: int,
581
+ guidance_scale: float,
582
+ denoise: bool,
583
+ preprocess_prompt: bool,
584
+ postprocess_output: bool,
585
+ block_duration: float | None = None,
586
+ ) -> np.ndarray:
587
+ """Generate a single chunk with the given character and return a float32 PCM array."""
588
+ if model is None:
589
+ raise HTTPException(status_code=503, detail=startup_error or "Model not ready")
590
+
591
+ mode = "auto"
592
+ instruct = None
593
+ language = None
594
+ speed = common_speed
595
+ ref_audio_path: str | None = None
596
+ ref_text: str | None = None
597
+ cleanup_paths: list[str] = []
598
+
599
+ try:
600
+ if character is not None:
601
+ mode = character.mode
602
+ instruct = (character.instruct or "").strip() or None
603
+ language = _normalize_language(character.language)
604
+ if character.speed is not None:
605
+ speed = character.speed
606
+ if mode == "clone":
607
+ if not character.ref_audio_b64:
608
+ raise HTTPException(
609
+ status_code=400,
610
+ detail=(
611
+ f"Character S{character.slot} (mode=clone) requires ref_audio_b64"
612
+ ),
613
+ )
614
+ audio_bytes = _decode_ref_audio(character.ref_audio_b64)
615
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
616
+ tmp.write(audio_bytes)
617
+ ref_audio_path = tmp.name
618
+ cleanup_paths.append(ref_audio_path)
619
+ ref_text = (character.ref_text or "").strip() or None
620
+
621
+ gen_config = OmniVoiceGenerationConfig(
622
+ num_step=num_step,
623
+ guidance_scale=guidance_scale,
624
+ denoise=denoise,
625
+ preprocess_prompt=preprocess_prompt,
626
+ postprocess_output=postprocess_output,
627
+ )
628
+ kw: dict[str, Any] = {
629
+ "text": text,
630
+ "language": language,
631
+ "generation_config": gen_config,
632
+ }
633
+ # `duration` overrides `speed` per the OmniVoice contract — only set
634
+ # one or the other.
635
+ if block_duration is not None and block_duration > 0:
636
+ kw["duration"] = float(block_duration)
637
+ elif speed != 1.0:
638
+ kw["speed"] = speed
639
+ if mode == "clone" and ref_audio_path:
640
+ kw["voice_clone_prompt"] = model.create_voice_clone_prompt(
641
+ ref_audio=ref_audio_path,
642
+ ref_text=ref_text,
643
+ )
644
+ if instruct:
645
+ kw["instruct"] = instruct
646
+
647
+ try:
648
+ generated = model.generate(**kw)
649
+ except HTTPException:
650
+ raise
651
+ except Exception as exc:
652
+ logging.exception("OmniVoice multi-character chunk failed")
653
+ raise HTTPException(
654
+ status_code=500,
655
+ detail=f"{type(exc).__name__}: {exc}",
656
+ ) from exc
657
+
658
+ # `generate` returns a tensor or numpy array shaped (T,) or (1, T).
659
+ chunk = generated[0]
660
+ if isinstance(chunk, torch.Tensor):
661
+ chunk = chunk.detach().cpu().float().numpy()
662
+ chunk = np.asarray(chunk, dtype=np.float32).reshape(-1)
663
+ return chunk
664
+ finally:
665
+ for path in cleanup_paths:
666
+ if path and os.path.exists(path):
667
+ try:
668
+ os.unlink(path)
669
+ except OSError:
670
+ pass
671
+
672
+
673
+ def _concat_pcm(arrays: list[np.ndarray], silence_samples: int) -> np.ndarray:
674
+ if not arrays:
675
+ return np.zeros(0, dtype=np.float32)
676
+ if len(arrays) == 1:
677
+ return arrays[0]
678
+ silence = np.zeros(silence_samples, dtype=np.float32)
679
+ out: list[np.ndarray] = []
680
+ for i, arr in enumerate(arrays):
681
+ if i > 0:
682
+ out.append(silence)
683
+ out.append(arr)
684
+ return np.concatenate(out, dtype=np.float32)
685
+
686
+
687
+ @app.post("/v1/audio/speech/multi")
688
+ async def synth_multi_speech(req: MultiSpeechRequest) -> StreamingResponse:
689
+ text = (req.text or "").strip()
690
+ if not text:
691
+ raise HTTPException(status_code=400, detail="text is required")
692
+
693
+ characters_by_slot: dict[int, CharacterConfig] = {c.slot: c for c in req.characters}
694
+
695
+ segments = _split_into_segments(text)
696
+ if not segments:
697
+ raise HTTPException(
698
+ status_code=400, detail="text has no generatable content after parsing"
699
+ )
700
+
701
+ referenced_slots = {slot for slot, _, _ in segments if slot is not None}
702
+ missing = sorted(referenced_slots - characters_by_slot.keys())
703
+ if missing:
704
+ raise HTTPException(
705
+ status_code=400,
706
+ detail=f"Missing character config for slot(s): {missing}",
707
+ )
708
+
709
+ chunks: list[np.ndarray] = []
710
+ # Hold the GPU for the entire multi-segment run so a competing dub job
711
+ # can't trample our KV cache mid-story. Pause segments don't touch the
712
+ # model so they execute instantly inside the lock.
713
+ async with _dub.GPU_LOCK:
714
+ for slot, chunk_text, attrs in segments:
715
+ pause_seconds = _parse_pause_seconds(chunk_text)
716
+ if pause_seconds is not None and pause_seconds > 0:
717
+ chunks.append(_silence_pcm(int(pause_seconds * 1000)))
718
+ continue
719
+
720
+ character = (
721
+ characters_by_slot.get(slot) if slot is not None else req.narrator
722
+ )
723
+ # Per-block target duration (Movie Dubbing): the speaker tag may carry
724
+ # `[Sn duration=2.5s]…[/Sn]`. We pass it down to OmniVoice so it can
725
+ # fit the chunk to the requested length.
726
+ block_duration = attrs.get("duration") if attrs else None
727
+ pcm = await asyncio.to_thread(
728
+ _generate_chunk,
729
+ text=chunk_text,
730
+ character=character,
731
+ common_speed=req.speed,
732
+ num_step=req.num_step,
733
+ guidance_scale=req.guidance_scale,
734
+ denoise=req.denoise,
735
+ preprocess_prompt=req.preprocess_prompt,
736
+ postprocess_output=req.postprocess_output,
737
+ block_duration=block_duration,
738
+ )
739
+ chunks.append(pcm)
740
+
741
+ silence_samples = max(0, int(SAMPLE_RATE * (req.inter_segment_silence_ms / 1000.0)))
742
+ combined = _concat_pcm(chunks, silence_samples=silence_samples)
743
+ wav_bytes = _wav_bytes(combined)
744
+ headers = {
745
+ "X-OmniVoice-Mode": "multi",
746
+ "X-OmniVoice-Segments": str(len(segments)),
747
+ }
748
+ return StreamingResponse(
749
+ io.BytesIO(wav_bytes), media_type="audio/wav", headers=headers
750
+ )
751
+
752
+
753
+ @app.post("/v1/audio/speech/clone")
754
+ async def clone_speech_compat(
755
+ text: str = Form(...),
756
+ ref_audio: UploadFile = File(...),
757
+ ref_text: str | None = Form(None),
758
+ instruct: str | None = Form(None),
759
+ language: str | None = Form(None),
760
+ speed: float = Form(1.0),
761
+ num_step: int = Form(32),
762
+ guidance_scale: float = Form(2.0),
763
+ ) -> StreamingResponse:
764
+ """Backward-compat shim — same as POST /v1/audio/speech with mode=clone."""
765
+ return await synth_speech( # type: ignore[return-value]
766
+ text=text,
767
+ mode="clone",
768
+ ref_audio=ref_audio,
769
+ ref_text=ref_text,
770
+ instruct=instruct,
771
+ language=language,
772
+ speed=speed,
773
+ duration=None,
774
+ num_step=num_step,
775
+ guidance_scale=guidance_scale,
776
+ denoise=None,
777
+ preprocess_prompt=None,
778
+ postprocess_output=None,
779
+ )