""" Dubbing Studio Lite worker for the OmniVoice pod (long-form, scene-aware). This module is imported by `server.py` and registers the dub routes: - `POST /v1/dub/jobs` — accept a job spec, plan scene-aware batches, immediately return `{accepted: true}` and process batches in the background under `GPU_LOCK`. - `POST /v1/dub/jobs/{job_id}/batches/{i}/run` — re-run a single batch (idempotent; skips already-uploaded sub-artifacts). - `POST /v1/dub/jobs/{job_id}/finalize` — concat per-batch mixes, optionally mux back into the original video container, then `/claim`. - `GET /v1/dub/jobs/{job_id}` — pod-local progress (jobs + per-batch state) for diagnostics. Pipeline (per batch — outer iterator runs them serially under `GPU_LOCK`): ingest (once per job) -> Demucs (full vocals/music) -> VAD batch planner -> for each batch: WhisperX -> pyannote (within batch) -> M2M-100 -> OmniVoice clone TTS (per segment) -> per-batch mix -> R2 PUT. Cross-batch speaker reconciliation: the first batch establishes a registry of `{label, embedding}` pairs (Resemblyzer 256-d). Subsequent batches map their local pyannote labels to the registry by cosine similarity (>= 0.78 = same speaker; otherwise a new label is appended). This keeps "Speaker 1" the same person across the whole video. Auth: every callback to the Next.js app uses `Authorization: Bearer ${DUB_WORKER_TOKEN}` (read from process env at request time). """ from __future__ import annotations import asyncio import glob import json import logging import os import shutil import subprocess import sys import tempfile import time import types import wave from dataclasses import dataclass, field from typing import Any # -------------------------------------------------------------------------- # torchcodec import-time stub. # # `transformers>=5.x` does an unconditional `import torchcodec` at module load # time inside `transformers/audio_utils.py`. The pre-built `torchcodec` wheel # (>=0.7) requires either FFmpeg 5+ system libs OR a torch build that ships # the `float4_e2m1fn_x2` dtype (torch>=2.8). Our pods run torch 2.7.1 + the # Ubuntu 22.04 default FFmpeg 4 libs, so every torchcodec.dlopen path fails: # # OSError: libavutil.so.{60,59,58,57}: cannot open shared object file # OSError: undefined symbol: torch_dtype_float4_e2m1fn_x2 # # That import error then bubbles up through transformers -> pyannote and kills # the dub pipeline at runtime ("Could not load libtorchcodec ..."). # # We never decode audio/video through torchcodec — the dub pipeline uses the # `ffmpeg` CLI + `soundfile` + `librosa` for all I/O. So we register a stub # package in `sys.modules` *before* any 3rd-party import, satisfying the # `import torchcodec` line without dlopening the broken shared library. If # something ever does try to *call* torchcodec, we surface a clear error. if "torchcodec" not in sys.modules: _stub = types.ModuleType("torchcodec") _stub.__version__ = "0.0.0-voicelab-stub" _stub.__doc__ = ( "VoiceLab stub: torchcodec is intentionally not loaded on this pod. " "Use ffmpeg/soundfile for audio I/O instead." ) # NOTE: deliberately NO __getattr__ hook here. torch internals run things # like `hasattr(some_module, "torchcodec")` and `getattr(..., None)` during # import — a raising __getattr__ would poison those probes and break torch # itself. Plain `AttributeError` (the default for missing names) is fine # because that's what `hasattr` expects. sys.modules["torchcodec"] = _stub for _sub in ( "torchcodec.decoders", "torchcodec.encoders", "torchcodec.samplers", "torchcodec._core", ): sys.modules[_sub] = types.ModuleType(_sub) import numpy as np from fastapi import APIRouter, Depends, Header, HTTPException from pydantic import BaseModel, Field # -------------------------------------------------------------------------- # PyTorch >=2.6 compatibility shim for pyannote / whisperx checkpoint loading. # # torch.load now defaults to weights_only=True, which blocks pickled non-tensor # globals like omegaconf.ListConfig found inside pyannote's segmentation / # diarization checkpoints. WhisperX loads its bundled pyannote VAD via # torch.load() and crashes with "Weights only load failed ... Unsupported # global: GLOBAL omegaconf.listconfig.ListConfig was not an allowed global by # default." # # We trust these checkpoints (gated HuggingFace repos, authed via HF_TOKEN), so # we register the omegaconf containers as safe globals at import time. This is # the officially-recommended fix from the PyTorch 2.6 release notes. try: import torch from torch.serialization import add_safe_globals _safe = [] try: from omegaconf.listconfig import ListConfig _safe.append(ListConfig) except Exception: pass try: from omegaconf.dictconfig import DictConfig _safe.append(DictConfig) except Exception: pass try: from omegaconf.base import ContainerMetadata, Metadata _safe.extend([ContainerMetadata, Metadata]) except Exception: pass try: from omegaconf.nodes import AnyNode _safe.append(AnyNode) except Exception: pass if _safe: add_safe_globals(_safe) # Final fallback: force weights_only=False for the libraries that ignore # the safe-globals registry (whisperx bundles its own pyannote loader). # Same trust assumption as above. _orig_torch_load = torch.load def _patched_torch_load(*args, **kwargs): # type: ignore[no-untyped-def] # Force weights_only=False — pyannote/lightning checkpoints embed # rich Python objects (omegaconf, typing.Any, etc.) that the # weights_only=True allowlist doesn't cover. Lightning's pl_load # explicitly passes weights_only=True, so we must override, not # just setdefault. Safe: checkpoints come from gated HF repos. kwargs["weights_only"] = False return _orig_torch_load(*args, **kwargs) torch.load = _patched_torch_load # type: ignore[assignment] except Exception: # noqa: BLE001 -- never fail import; falls back to runtime error pass # -------------------------------------------------------------------------- # huggingface_hub >=1.0 compatibility shim for pyannote.audio 3.3.x. # # pyannote.audio 3.3.2 calls `hf_hub_download(use_auth_token=...)` internally, # but huggingface_hub removed that deprecated kwarg in 1.0.0 (renamed to # `token`). When the venv ends up with hub>=1.0 (because OmniVoice/transformers # pull it in), pyannote crashes with: # TypeError: hf_hub_download() got an unexpected keyword argument 'use_auth_token' # # Rather than fight the dep tree, we wrap the relevant hub functions to accept # `use_auth_token` and forward it as `token`. Idempotent and harmless on older # hub versions that already accept the kwarg. try: import huggingface_hub as _hf_hub def _make_compat(_orig): def _wrapped(*args, **kwargs): # type: ignore[no-untyped-def] if "use_auth_token" in kwargs: tok = kwargs.pop("use_auth_token") kwargs.setdefault("token", tok) return _orig(*args, **kwargs) return _wrapped for _name in ("hf_hub_download", "snapshot_download", "cached_download"): _orig = getattr(_hf_hub, _name, None) if _orig is not None and not getattr(_orig, "_voicelab_compat", False): _patched = _make_compat(_orig) _patched._voicelab_compat = True # type: ignore[attr-defined] setattr(_hf_hub, _name, _patched) # pyannote imports the symbol directly; patch the file_download # submodule too so already-imported references see the new fn. try: from huggingface_hub import file_download as _fd if hasattr(_fd, _name): setattr(_fd, _name, _patched) except Exception: pass except Exception: # noqa: BLE001 pass # -------------------------------------------------------------------------- # Module-level state # -------------------------------------------------------------------------- logger = logging.getLogger("dub") # Single-GPU pod: only one heavy stage may run at a time. The OmniVoice TTS # routes (single + multi) are wrapped at the call sites in server.py so a dub # job also blocks live TTS, and vice versa. Keeps VRAM predictable on the A40 # (~24 GB peak when WhisperX large-v3 + pyannote + OmniVoice are warm). GPU_LOCK = asyncio.Semaphore(1) # In-memory job + batch registry (pod-local). The canonical store is the # Next.js Postgres DB; this dict only exists so `GET /v1/dub/jobs/{id}` can # answer diagnostic queries without a DB round-trip. JOBS: dict[str, "DubJobState"] = {} # Lazy-loaded model handles. First use of each pays the load cost; afterwards # the handle stays warm in memory until the pod restarts. _DEMUCS = None _WHISPERX = None _WHISPERX_ALIGN_CACHE: dict[str, tuple[Any, dict[str, Any]]] = {} _DIARIZER = None _M2M_TOKENIZER = None _M2M_MODEL = None _VAD_MODEL = None _VAD_UTILS: tuple[Any, Any, Any, Any, Any] | None = None _VOICE_ENCODER = None # `runpod/omnivoice/server.py` injects this at import time so we can call # OmniVoice in-process without going through HTTP. We intentionally avoid # a top-level import to prevent a circular import. _OMNIVOICE_MODEL = None _OMNIVOICE_GEN_CONFIG_CLS = None SAMPLE_RATE = 24000 # OmniVoice native; we resample to 16k internally for ASR/diarise. # Scene-aware batching defaults. `_plan_batches` greedily fills batches up to # TARGET_BATCH_S, snapping each cut to the longest VAD silence within the # +/- SNAP_WINDOW_S window. Hard ceiling at MAX_BATCH_S so a continuous # monologue still terminates a batch. TARGET_BATCH_S = 300.0 MAX_BATCH_S = 420.0 SNAP_WINDOW_S = 20.0 # Cosine similarity threshold for reusing a registry speaker across batches. SPEAKER_MATCH_THRESHOLD = 0.78 # Whisper / M2M language code mapping. M2M-100 expects ISO 639-1 codes; we # accept either ISO codes or display names from the wizard and map on the way # in. Falls back to English when unknown. LANG_DISPLAY_TO_ISO = { "english": "en", "spanish": "es", "french": "fr", "german": "de", "italian": "it", "portuguese": "pt", "russian": "ru", "japanese": "ja", "korean": "ko", "chinese": "zh", "hindi": "hi", "bengali": "bn", "tamil": "ta", "telugu": "te", "marathi": "mr", "gujarati": "gu", "punjabi": "pa", "urdu": "ur", "arabic": "ar", "turkish": "tr", "vietnamese": "vi", "thai": "th", "indonesian": "id", "polish": "pl", "dutch": "nl", "swedish": "sv", "ukrainian": "uk", "greek": "el", "hebrew": "he", "czech": "cs", "romanian": "ro", "hungarian": "hu", } # Reverse map for the OmniVoice `language=` field, which expects display names # like "English" / "Hindi". Falls back to None (auto-detect) if unknown. _DISPLAY_FOR_ISO = {iso: name.title() for name, iso in LANG_DISPLAY_TO_ISO.items()} def _to_iso(lang: str | None) -> str | None: if not lang: return None cleaned = lang.strip().lower() if not cleaned or cleaned == "auto": return None if len(cleaned) <= 3 and cleaned.isalpha(): return cleaned return LANG_DISPLAY_TO_ISO.get(cleaned, cleaned[:2]) def configure(model: Any, gen_config_cls: Any) -> None: """Register the live OmniVoice model + generation-config class. Called once from `server.py` after the model finishes loading. We keep the references at module scope so `_synthesize_segment()` can call OmniVoice without re-loading the model or going over HTTP. """ global _OMNIVOICE_MODEL, _OMNIVOICE_GEN_CONFIG_CLS _OMNIVOICE_MODEL = model _OMNIVOICE_GEN_CONFIG_CLS = gen_config_cls logger.info("dub: OmniVoice handle registered") # -------------------------------------------------------------------------- # Request / response schemas # -------------------------------------------------------------------------- class DubSpeakerOverride(BaseModel): """Optional: pin a diarised speaker to a specific clone reference. Sent by the editor when the user picks a voice from their library for a given speaker label. The Next.js app uploads the chosen voice's reference audio to R2 and passes the presigned GET URL here. """ label: str ref_url: str class DubBatchPutUrls(BaseModel): """Pre-provisioned PUT URLs for a single batch's artifacts. The app pre-signs URLs for all planned batches (or "enough" for the expected count) and the worker uses whichever it needs. Unused URLs simply expire. """ index: int vocals_put_url: str | None = None music_put_url: str | None = None mix_put_url: str segments_put_url: str | None = None # `segment_put_urls[j]` is for the j-th segment within this batch. segment_put_urls: list[str] = Field(default_factory=list) class DubJobRequest(BaseModel): job_id: str # Source: provide exactly one. source_url: str | None = None source_get_url: str | None = None # presigned GET to R2 for uploaded files source_lang: str | None = None # ISO or display name; null = auto-detect target_lang: str num_speakers: int | None = None drop_background: bool = False duration_cap_s: int = 600 container_kind: str = "audio" # "audio" | "video" # Where the worker reports progress + segments back to the app. Should # be like `https://app.example.com/api/dub`. We append `/progress`, # `/segments`, `/batches`, `/claim`. callback_base: str # Job-level presigned PUTs. final_put_url: str vocals_put_url: str music_put_url: str | None = None # Original video container kept on R2 so finalize() can mux without # re-downloading. Only used when container_kind == "video". original_video_put_url: str | None = None final_video_put_url: str | None = None # Per-batch PUT URL pools. The app over-provisions; the worker picks # the entry whose `index` matches the batch number. batch_put_urls: list[DubBatchPutUrls] = Field(default_factory=list) # One presigned PUT per (potentially) discovered speaker. The app # over-provisions (e.g. 8 URLs) and the worker uses only the ones it # needs; unused URLs simply expire. speaker_ref_put_urls: list[str] = Field(default_factory=list) speaker_overrides: list[DubSpeakerOverride] = Field(default_factory=list) class DubBatchRunRequest(BaseModel): """Body for `POST /v1/dub/jobs/{id}/batches/{i}/run`. Provides the PUT URL pool the worker should use for retried artifacts. The original URLs are likely expired by retry time, so the app must mint fresh ones. """ callback_base: str batch_put_urls: DubBatchPutUrls speaker_overrides: list[DubSpeakerOverride] = Field(default_factory=list) class DubFinalizeRequest(BaseModel): callback_base: str final_put_url: str final_video_put_url: str | None = None class DubJobResponse(BaseModel): accepted: bool = True job_id: str # -------------------------------------------------------------------------- # In-process job state (diagnostic only; canonical state lives in the app DB) # -------------------------------------------------------------------------- @dataclass class BatchSpec: index: int start_s: float end_s: float @dataclass class DubJobState: job_id: str status: str = "PENDING" progress: int = 0 error: str | None = None started_at: float = field(default_factory=time.time) finished_at: float | None = None work_dir: str | None = None duration_s: float | None = None container_kind: str = "audio" target_iso: str = "en" detected_lang: str = "en" drop_background: bool = False callback_base: str = "" batches: list[BatchSpec] = field(default_factory=list) batch_status: dict[int, str] = field(default_factory=dict) # `speaker_registry[label] = (1-d numpy array embedding)`. speaker_registry: dict[str, np.ndarray] = field(default_factory=dict) speaker_ref_paths: dict[str, str] = field(default_factory=dict) extra: dict[str, Any] = field(default_factory=dict) # -------------------------------------------------------------------------- # Lazy model loaders # -------------------------------------------------------------------------- def _device() -> str: import torch # local import to keep module import cheap on CPU-only boxes return "cuda" if torch.cuda.is_available() else "cpu" def _load_demucs(): global _DEMUCS if _DEMUCS is not None: return _DEMUCS logger.info("dub: loading Demucs htdemucs") import torch from demucs.pretrained import get_model model = get_model("htdemucs") model.to(_device()) # FP16 on CUDA: ~1.4x faster separation, ~half the VRAM. htdemucs is # a hybrid CNN/transformer and tolerates fp16 weights cleanly with no # audible artefacts on speech material. Opt out via # OMNIVOICE_DEMUCS_FP16=0 if a future model rev regresses. if _device() == "cuda" and os.environ.get("OMNIVOICE_DEMUCS_FP16", "1") != "0": model = model.to(torch.float16) model.eval() _DEMUCS = model return _DEMUCS def _load_whisperx(): """Load faster-whisper large-v3. NOTE: name is kept for call-site compatibility; we no longer use the `whisperx` package because it pins `transformers>=4.48` which conflicts with `omnivoice` (`transformers>=5.3`). faster-whisper has no such dependency. We lose word-level forced alignment but keep sentence-level timestamps, which is sufficient for dubbing chunking. """ global _WHISPERX if _WHISPERX is not None: return _WHISPERX logger.info("dub: loading faster-whisper large-v3") from faster_whisper import WhisperModel # int8_float16 = INT8 weights + FP16 activations on CUDA. ~1.5x faster # than pure float16 with no measurable WER regression on librispeech # (CTranslate2 docs). Allow override via OMNIVOICE_WHISPER_COMPUTE_TYPE # in case a pod's CTranslate2 build doesn't support it (very rare on # >=4.5.0). if _device() == "cuda": compute_type = os.environ.get( "OMNIVOICE_WHISPER_COMPUTE_TYPE", "int8_float16" ) else: compute_type = "int8" _WHISPERX = WhisperModel( "large-v3", device=_device(), compute_type=compute_type, ) return _WHISPERX def _load_align(language_code: str): """No-op: word-level alignment removed when we dropped whisperx. Kept as a stub so older callers don't NameError. Returns (None, None). """ if language_code in _WHISPERX_ALIGN_CACHE: return _WHISPERX_ALIGN_CACHE[language_code] _WHISPERX_ALIGN_CACHE[language_code] = (None, None) return None, None def _load_diarizer(): global _DIARIZER if _DIARIZER is not None: return _DIARIZER logger.info("dub: loading pyannote/speaker-diarization-3.1") import torch from pyannote.audio import Pipeline hf_token = os.getenv("HF_TOKEN") or os.getenv("HF_ACCESS_TOKEN") if not hf_token: raise RuntimeError( "HF_TOKEN is required to download pyannote/speaker-diarization-3.1" ) # pyannote.audio renamed the auth kwarg multiple times across releases: # <=3.3.x: use_auth_token=... # 3.4.x: token kwarg removed from Pipeline.from_pretrained; reads HF_TOKEN # from env (Model.from_pretrained still accepts token=...) # Try the modern signatures, fall back gracefully so we don't break on the # next pip-install drift. os.environ.setdefault("HF_TOKEN", hf_token) pipe = None last_err: Exception | None = None for kwargs in ( {"token": hf_token}, {"use_auth_token": hf_token}, {}, # pyannote 3.4: reads HF_TOKEN from env ): try: pipe = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1", **kwargs ) break except TypeError as exc: last_err = exc continue if pipe is None: raise RuntimeError( f"pyannote Pipeline.from_pretrained rejected all auth kwargs: {last_err}" ) pipe.to(torch.device(_device())) _DIARIZER = pipe return _DIARIZER def _load_m2m(): """M2M-100 1.2B (MIT) — commercial-safe translation across 100 languages. On CUDA we load with bitsandbytes int8 quantisation by default: ~2 GB VRAM (vs ~5 GB fp16) and ~10-15 % faster per-batch translate due to smaller weight movement. Quality on M2M-100 1.2B with int8 is indistinguishable from fp16 for short-segment dubbing (BLEU delta < 0.3 on FLORES-200 per the bnb paper). Fallbacks (controlled by OMNIVOICE_M2M_QUANT): "int8" (default on CUDA) — bitsandbytes 8-bit "fp16" — float16, no quantisation "fp32" — full precision (CPU default) """ global _M2M_TOKENIZER, _M2M_MODEL if _M2M_MODEL is not None: return _M2M_TOKENIZER, _M2M_MODEL logger.info("dub: loading facebook/m2m100_1.2B") import torch from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer _M2M_TOKENIZER = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B") if _device() == "cuda": quant = os.environ.get("OMNIVOICE_M2M_QUANT", "int8").lower() else: quant = "fp32" if quant == "int8": try: from transformers import BitsAndBytesConfig bnb_config = BitsAndBytesConfig(load_in_8bit=True) _M2M_MODEL = M2M100ForConditionalGeneration.from_pretrained( "facebook/m2m100_1.2B", quantization_config=bnb_config, device_map={"": 0}, ) logger.info("dub: m2m100 loaded as bnb int8") except Exception as exc: logger.warning( "dub: bnb int8 load failed (%s); falling back to fp16", exc ) quant = "fp16" if quant == "fp16": _M2M_MODEL = M2M100ForConditionalGeneration.from_pretrained( "facebook/m2m100_1.2B", torch_dtype=torch.float16, ).to(_device()) elif quant == "fp32": _M2M_MODEL = M2M100ForConditionalGeneration.from_pretrained( "facebook/m2m100_1.2B", torch_dtype=torch.float32, ).to(_device()) _M2M_MODEL.eval() return _M2M_TOKENIZER, _M2M_MODEL def _load_vad(): """Silero VAD on CPU (lightweight; CPU is fine and frees GPU for ASR).""" global _VAD_MODEL, _VAD_UTILS if _VAD_MODEL is not None and _VAD_UTILS is not None: return _VAD_MODEL, _VAD_UTILS logger.info("dub: loading silero-vad") import torch model, utils = torch.hub.load( repo_or_dir="snakers4/silero-vad", model="silero_vad", trust_repo=True, onnx=False, ) _VAD_MODEL = model _VAD_UTILS = utils return _VAD_MODEL, _VAD_UTILS def _load_voice_encoder(): """Resemblyzer voice encoder for cross-batch speaker matching (CPU).""" global _VOICE_ENCODER if _VOICE_ENCODER is not None: return _VOICE_ENCODER logger.info("dub: loading resemblyzer VoiceEncoder") from resemblyzer import VoiceEncoder _VOICE_ENCODER = VoiceEncoder() return _VOICE_ENCODER # -------------------------------------------------------------------------- # Pipeline stages # -------------------------------------------------------------------------- def _run(cmd: list[str], **kwargs) -> subprocess.CompletedProcess: logger.info("dub: $ %s", " ".join(cmd)) try: return subprocess.run( cmd, check=True, capture_output=True, text=True, **kwargs ) except FileNotFoundError as exc: # subprocess raises this when argv[0] isn't on PATH. The default # message ("[Errno 2] No such file or directory: 'yt-dlp'") doesn't # tell anyone what to do — surface a clearer one that points at the # real fix (binary not in PATH on this pod). The dub job will fail # with this string, which the editor shows in the collapsible error. raise RuntimeError( f"required binary not found on this pod: {cmd[0]!r}. " f"PATH={os.environ.get('PATH', '')!r}. " "If this is a venv-installed tool (yt-dlp, demucs), invoke it " "as `sys.executable -m ` instead." ) from exc except subprocess.CalledProcessError as exc: # Same logic for non-zero exits — fold stderr into the exception # so the editor's error card shows the real failure (e.g. yt-dlp # geo-block, format unavailable) instead of "Command failed". stderr = (exc.stderr or "").strip() stdout = (exc.stdout or "").strip() tail = stderr or stdout or "(no output)" # Keep the tail bounded — some ffmpeg invocations dump megabytes. if len(tail) > 2000: tail = tail[-2000:] raise RuntimeError( f"command failed (exit {exc.returncode}): {' '.join(cmd[:3])}…\n{tail}" ) from exc def _ingest(req: DubJobRequest, work_dir: str) -> tuple[str, str | None, float, bool]: """Download or fetch the source media. Returns `(wav16k_path, raw_video_path_or_None, duration_s, has_video)`. Always emits 16 kHz mono WAV at `/source.wav`. When the source contains a video stream we also keep the original container untouched at `/raw.media` so finalize() can mux into it. """ raw_path = os.path.join(work_dir, "raw.media") if req.source_url: # yt-dlp handles YouTube/Vimeo/TikTok/X/etc. We always grab the # bestvideo+bestaudio merged container so the user can later get # a muxed MP4 if they want one; ingest cost is the same either way. # # Invoke as `sys.executable -m yt_dlp` rather than `yt-dlp` so we # don't depend on the venv's bin/ being on PATH — uvicorn was # launched directly via `/workspace/dub_venv/bin/uvicorn` without # `source activate`, so PATH inherits from the system shell where # `yt-dlp` doesn't exist. Using sys.executable guarantees we hit # the same Python that imports the rest of this file. _run( [ sys.executable, "-m", "yt_dlp", "-f", "bv*+ba/b" if req.container_kind == "video" else "bestaudio/best", "--no-playlist", "--no-warnings", "--merge-output-format", "mp4", "-o", raw_path, req.source_url, ] ) # yt-dlp ALWAYS appends the actual container extension to the -o # template ("raw.media" → "raw.media.mp4" / ".mkv" / ".webm" / # ".m4a" depending on what -f resolved to). Without this rename, # the next ffmpeg call dies with "raw.media: No such file or # directory" — which is exactly the failure we hit on the live # YouTube → Hindi job. Glob the directory and rename whatever it # produced to the canonical raw_path so the rest of the pipeline # doesn't have to care about the extension. if not os.path.exists(raw_path): candidates = sorted( p for p in glob.glob(os.path.join(work_dir, "raw.media*")) if os.path.isfile(p) ) if not candidates: raise RuntimeError( f"yt-dlp produced no file in {work_dir}; " f"directory contents: {os.listdir(work_dir)}" ) # Prefer the largest file — when yt-dlp downloads bestvideo # + bestaudio separately and merges, intermediate fragments # may linger on disk; the merged output is the biggest one. chosen = max(candidates, key=lambda p: os.path.getsize(p)) os.rename(chosen, raw_path) # Clean up any leftover fragments so we don't ship stale # bytes if a retry hits the same work_dir. for leftover in candidates: if leftover != chosen and os.path.exists(leftover): try: os.remove(leftover) except OSError: pass elif req.source_get_url: _run(["curl", "-sSL", "-o", raw_path, req.source_get_url]) else: raise HTTPException( status_code=400, detail="Either source_url or source_get_url is required" ) wav_path = os.path.join(work_dir, "source.wav") _run( [ "ffmpeg", "-y", "-i", raw_path, "-ac", "1", "-ar", "16000", "-vn", wav_path, ] ) probe = _run( [ "ffprobe", "-v", "error", "-show_entries", "format=duration:stream=codec_type", "-of", "json", raw_path, ] ) info = json.loads(probe.stdout) duration = float(info["format"]["duration"]) has_video = any( s.get("codec_type") == "video" for s in info.get("streams", []) ) return wav_path, (raw_path if has_video else None), duration, has_video def _separate(wav_path: str, work_dir: str) -> tuple[str, str]: """Demucs separation -> (vocals_path, accompaniment_path), 24 kHz stereo WAV.""" import torch import torchaudio from demucs.apply import apply_model model = _load_demucs() waveform, sr = torchaudio.load(wav_path) if waveform.shape[0] == 1: waveform = waveform.repeat(2, 1) if sr != model.samplerate: waveform = torchaudio.functional.resample(waveform, sr, model.samplerate) waveform = waveform.unsqueeze(0).to(_device()) # Match input dtype to the model — _load_demucs() may have cast weights # to float16 on CUDA. Mismatched dtypes raise inside apply_model's # underlying conv/attention ops. model_dtype = next(model.parameters()).dtype if waveform.dtype != model_dtype: waveform = waveform.to(model_dtype) with torch.no_grad(): sources = apply_model(model, waveform, split=True, overlap=0.1) # Cast back to fp32 for torchaudio.save (it doesn't accept half-precision # tensors and would silently produce a corrupt WAV). sources = sources.squeeze(0).to(torch.float32).cpu() # htdemucs source order: drums, bass, other, vocals vocals = sources[3] accompaniment = sources[0] + sources[1] + sources[2] vocals_path = os.path.join(work_dir, "vocals.wav") music_path = os.path.join(work_dir, "music.wav") torchaudio.save(vocals_path, vocals, model.samplerate) torchaudio.save(music_path, accompaniment, model.samplerate) return vocals_path, music_path def _plan_batches( vocals_path: str, duration_s: float, *, target_s: float = TARGET_BATCH_S, max_s: float = MAX_BATCH_S, snap_s: float = SNAP_WINDOW_S, ) -> list[BatchSpec]: """Greedy scene-aware batch planner. For each candidate cut at `current + target_s` we look for the longest VAD silence within the +/- snap_s window and snap to its midpoint. If no silence is found inside the window we use the hard cap at `current + max_s`. Returns a list of (index, start_s, end_s) tuples covering [0, duration_s]. """ if duration_s <= max_s: return [BatchSpec(index=0, start_s=0.0, end_s=duration_s)] import torch import torchaudio model, utils = _load_vad() get_speech_timestamps = utils[0] waveform, sr = torchaudio.load(vocals_path) if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) if sr != 16000: waveform = torchaudio.functional.resample(waveform, sr, 16000) sr = 16000 audio = waveform.squeeze(0) with torch.no_grad(): speech = get_speech_timestamps( audio, model, sampling_rate=sr, threshold=0.5, min_silence_duration_ms=400 ) # Convert speech ranges (samples) -> silences (seconds). silences: list[tuple[float, float]] = [] cursor_sample = 0 for s in speech: start_sample = int(s["start"]) if start_sample > cursor_sample: silences.append((cursor_sample / sr, start_sample / sr)) cursor_sample = int(s["end"]) if cursor_sample / sr < duration_s: silences.append((cursor_sample / sr, duration_s)) batches: list[BatchSpec] = [] cursor = 0.0 idx = 0 while cursor < duration_s - 1.0: target_cut = cursor + target_s hard_cut = cursor + max_s if target_cut >= duration_s: batches.append(BatchSpec(index=idx, start_s=cursor, end_s=duration_s)) break # Pick the longest silence whose midpoint is in [target_cut - snap, target_cut + snap]. best_mid: float | None = None best_len = 0.0 for s_start, s_end in silences: mid = (s_start + s_end) / 2.0 if abs(mid - target_cut) <= snap_s and (s_end - s_start) > best_len: best_len = s_end - s_start best_mid = mid cut = best_mid if best_mid is not None else min(hard_cut, duration_s) cut = max(cursor + 30.0, min(cut, duration_s)) # sanity floor at 30s batches.append(BatchSpec(index=idx, start_s=cursor, end_s=cut)) cursor = cut idx += 1 if not batches: batches.append(BatchSpec(index=0, start_s=0.0, end_s=duration_s)) return batches def _slice_wav( src_path: str, dst_path: str, start_s: float, end_s: float, *, sample_rate: int = 16000 ) -> None: """Cut [start_s, end_s] out of `src_path` into `dst_path` at `sample_rate`.""" _run( [ "ffmpeg", "-y", "-ss", f"{start_s:.3f}", "-to", f"{end_s:.3f}", "-i", src_path, "-ac", "1", "-ar", str(sample_rate), "-vn", dst_path, ] ) def _transcribe(wav_path: str, source_lang: str | None) -> dict[str, Any]: """Run faster-whisper ASR. Returns dict with same shape WhisperX produced. Output: {"segments": [{"start": float, "end": float, "text": str}, ...], "language": str} """ model = _load_whisperx() transcribe_kwargs: dict[str, Any] = { "beam_size": 5, # VAD filter is faster-whisper's built-in silero pass; skips silence # without us having to do a separate VAD step. Big speedup on long # files with pauses. "vad_filter": True, "vad_parameters": {"min_silence_duration_ms": 500}, } if source_lang: transcribe_kwargs["language"] = source_lang segments_iter, info = model.transcribe(wav_path, **transcribe_kwargs) detected_lang = info.language or source_lang or "en" segments: list[dict[str, Any]] = [] for seg in segments_iter: segments.append( { "start": float(seg.start), "end": float(seg.end), "text": (seg.text or "").strip(), } ) return {"segments": segments, "language": detected_lang} def _diarize(wav_path: str, num_speakers: int | None) -> Any: """Run pyannote diarisation; returns a `pyannote.core.Annotation`.""" pipe = _load_diarizer() kwargs: dict[str, Any] = {} if num_speakers and num_speakers > 0: kwargs["num_speakers"] = num_speakers return pipe(wav_path, **kwargs) def _assign_speakers(asr_result: dict[str, Any], diarisation: Any) -> list[dict[str, Any]]: """Attach `speaker` to each ASR segment via overlap-max.""" segments: list[dict[str, Any]] = [] for seg in asr_result.get("segments", []): s, e = float(seg["start"]), float(seg["end"]) best_label = "SPEAKER_00" best_overlap = 0.0 for turn, _, label in diarisation.itertracks(yield_label=True): overlap = max(0.0, min(e, turn.end) - max(s, turn.start)) if overlap > best_overlap: best_overlap = overlap best_label = label segments.append( { "start": s, "end": e, "text": (seg.get("text") or "").strip(), "speaker": best_label, } ) return segments def _translate( segments: list[dict[str, Any]], source_iso: str, target_iso: str ) -> list[dict[str, Any]]: """Translate `segment.text` -> `segment.target_text` via M2M-100.""" import torch if source_iso == target_iso: for s in segments: s["target_text"] = s["text"] return segments tok, model = _load_m2m() tok.src_lang = source_iso target_token = tok.get_lang_id(target_iso) out: list[dict[str, Any]] = [] BATCH = 8 with torch.no_grad(): for i in range(0, len(segments), BATCH): batch = segments[i : i + BATCH] texts = [(s["text"] or "").strip() or "." for s in batch] enc = tok(texts, return_tensors="pt", padding=True, truncation=True).to( _device() ) generated = model.generate( **enc, forced_bos_token_id=target_token, max_new_tokens=256, ) decoded = tok.batch_decode(generated, skip_special_tokens=True) for seg, txt in zip(batch, decoded): seg["target_text"] = txt.strip() out.append(seg) return out def _extract_reference_clip( vocals_path: str, speaker_segments: list[dict[str, Any]], out_path: str, *, target_seconds: float = 8.0, offset_s: float = 0.0, ) -> bool: """Pick the longest single-speaker span (capped at `target_seconds`) and write it to `out_path`. `offset_s` shifts segment timestamps because we pass a batch-local vocals slice. Returns False if no usable span exists. """ if not speaker_segments: return False longest = max(speaker_segments, key=lambda s: s["end"] - s["start"]) duration = longest["end"] - longest["start"] if duration < 3.0: return False take = min(duration, target_seconds) _run( [ "ffmpeg", "-y", "-ss", f"{max(0.0, longest['start'] - offset_s):.2f}", "-t", f"{take:.2f}", "-i", vocals_path, "-ac", "1", "-ar", "24000", out_path, ] ) return True def _embed_clip(path: str) -> np.ndarray | None: """Resemblyzer 256-d embedding for a wav clip, or None on failure.""" try: from resemblyzer import VoiceEncoder, preprocess_wav except Exception: # noqa: BLE001 return None enc: VoiceEncoder = _load_voice_encoder() try: wav = preprocess_wav(path) return enc.embed_utterance(wav) except Exception as exc: # noqa: BLE001 logger.warning("dub: embed failed for %s: %s", path, exc) return None def _reconcile_speakers( state: DubJobState, local_segments: list[dict[str, Any]], vocals_path: str, work_dir: str, offset_s: float, ) -> dict[str, str]: """Map a batch's local pyannote labels -> stable job-wide labels. Returns `{local_label: global_label}`. Side-effect: extends `state.speaker_registry` and `state.speaker_ref_paths` with any newly-discovered speakers. """ by_label: dict[str, list[dict[str, Any]]] = {} for s in local_segments: by_label.setdefault(s["speaker"], []).append(s) mapping: dict[str, str] = {} for local_label, segs in by_label.items(): ref_local = os.path.join(work_dir, f"ref_{local_label}.wav") ok = _extract_reference_clip(vocals_path, segs, ref_local, offset_s=offset_s) if not ok: mapping[local_label] = local_label # leave it; mix will skip if no ref continue emb = _embed_clip(ref_local) # Try to match against existing registry by cosine similarity. best: tuple[str, float] | None = None if emb is not None: for global_label, registered in state.speaker_registry.items(): num = float(np.dot(emb, registered)) den = float(np.linalg.norm(emb) * np.linalg.norm(registered)) or 1e-9 cos = num / den if cos >= SPEAKER_MATCH_THRESHOLD and (best is None or cos > best[1]): best = (global_label, cos) if best is not None: mapping[local_label] = best[0] else: # New global speaker — name them sequentially across the job. new_label = f"SPEAKER_{len(state.speaker_registry):02d}" mapping[local_label] = new_label if emb is not None: state.speaker_registry[new_label] = emb state.speaker_ref_paths[new_label] = ref_local return mapping def _synthesize_segment( text: str, ref_audio_path: str, target_duration: float, target_lang_iso: str, ) -> np.ndarray: """Call OmniVoice in-process for a single segment, returning float32 PCM @ 24k.""" if _OMNIVOICE_MODEL is None or _OMNIVOICE_GEN_CONFIG_CLS is None: raise RuntimeError("OmniVoice model is not registered with the dub module") # Diffusion step count is linear in TTS cost. 32 is the model default; # 24 is ~25 % faster with subtle quality loss on long-tail phonemes. # Lower than 20 starts producing audible buzz on sibilants. Tune via # OMNIVOICE_NUM_STEP. try: num_step = int(os.environ.get("OMNIVOICE_NUM_STEP", "32")) except ValueError: num_step = 32 gen_config = _OMNIVOICE_GEN_CONFIG_CLS( num_step=num_step, guidance_scale=2.0, denoise=True, preprocess_prompt=True, postprocess_output=True, ) voice_clone_prompt = _OMNIVOICE_MODEL.create_voice_clone_prompt( ref_audio=ref_audio_path, ref_text=None, ) kwargs: dict[str, Any] = { "text": text, "language": _DISPLAY_FOR_ISO.get(target_lang_iso, None), "generation_config": gen_config, "voice_clone_prompt": voice_clone_prompt, } if target_duration > 0.3: kwargs["duration"] = float(target_duration) generated = _OMNIVOICE_MODEL.generate(**kwargs) chunk = generated[0] try: import torch if isinstance(chunk, torch.Tensor): chunk = chunk.detach().cpu().float().numpy() except ImportError: pass return np.asarray(chunk, dtype=np.float32).reshape(-1) def _write_wav(path: str, pcm: np.ndarray, sr: int = SAMPLE_RATE) -> None: clipped = np.clip(pcm, -1.0, 1.0) pcm16 = (clipped * 32767.0).astype(np.int16) with wave.open(path, "wb") as w: w.setnchannels(1) w.setsampwidth(2) w.setframerate(sr) w.writeframes(pcm16.tobytes()) def _mix_batch( segments: list[dict[str, Any]], segment_pcms: dict[int, np.ndarray], music_path: str | None, out_path: str, *, batch_start_s: float, batch_end_s: float, drop_background: bool, ) -> None: """Place each segment WAV at its *batch-local* timestamp on a silent canvas of (batch_end - batch_start) seconds. Then ffmpeg-mix with the matching music slice if present and not dropped.""" duration = max(0.5, batch_end_s - batch_start_s) canvas_samples = int(duration * SAMPLE_RATE) + SAMPLE_RATE # +1s tail canvas = np.zeros(canvas_samples, dtype=np.float32) for idx, seg in enumerate(segments): pcm = segment_pcms.get(idx) if pcm is None: continue local_start = max(0.0, float(seg["start"]) - batch_start_s) start_sample = int(local_start * SAMPLE_RATE) end_sample = min(start_sample + len(pcm), canvas_samples) take = end_sample - start_sample if take > 0: canvas[start_sample : start_sample + take] = pcm[:take] voice_path = out_path + ".voice.wav" _write_wav(voice_path, canvas, sr=SAMPLE_RATE) if not drop_background and music_path and os.path.exists(music_path): _run( [ "ffmpeg", "-y", "-i", voice_path, "-i", music_path, "-filter_complex", "[1:a]volume=0.5[a1];[0:a][a1]amix=inputs=2:duration=longest:dropout_transition=0", "-ac", "2", "-ar", "44100", out_path, ] ) else: _run( [ "ffmpeg", "-y", "-i", voice_path, "-ac", "2", "-ar", "44100", out_path, ] ) # -------------------------------------------------------------------------- # Callbacks back into the Next.js app # -------------------------------------------------------------------------- def _auth_header() -> dict[str, str]: token = os.getenv("DUB_WORKER_TOKEN", "").strip() if not token: raise RuntimeError("DUB_WORKER_TOKEN is not set on the pod") return {"Authorization": f"Bearer {token}"} async def _post_json(url: str, payload: dict[str, Any]) -> None: """Best-effort POST. We log + swallow non-2xx so a flaky callback doesn't crash the pipeline; the app polls its own DB anyway.""" import httpx try: async with httpx.AsyncClient(timeout=20) as client: r = await client.post(url, json=payload, headers=_auth_header()) if r.status_code >= 300: logger.warning("dub: callback %s -> %s %s", url, r.status_code, r.text) except Exception as exc: # noqa: BLE001 logger.warning("dub: callback %s failed: %s", url, exc) async def _put_bytes(url: str, body: bytes, content_type: str = "audio/wav") -> None: import httpx async with httpx.AsyncClient(timeout=300) as client: r = await client.put(url, content=body, headers={"Content-Type": content_type}) if r.status_code >= 300: raise RuntimeError(f"presigned PUT failed: {r.status_code} {r.text}") async def _put_file(url: str, path: str, content_type: str = "audio/wav") -> None: with open(path, "rb") as f: await _put_bytes(url, f.read(), content_type=content_type) async def _set_job_status( state: DubJobState, status: str, progress: int, **extra: Any ) -> None: state.status = status state.progress = progress if extra: state.extra.update(extra) await _post_json( f"{state.callback_base}/progress", {"jobId": state.job_id, "status": status, "progress": progress, **extra}, ) async def _set_batch_status( state: DubJobState, batch_index: int, status: str, progress: int, **extra: Any, ) -> None: state.batch_status[batch_index] = status await _post_json( f"{state.callback_base}/batches", { "jobId": state.job_id, "batchIndex": batch_index, "status": status, "progress": progress, **extra, }, ) # -------------------------------------------------------------------------- # Worker coroutines # -------------------------------------------------------------------------- async def _process_batch( state: DubJobState, spec: BatchSpec, batch_urls: DubBatchPutUrls | None, speaker_overrides: list[DubSpeakerOverride], ) -> None: """End-to-end processing of one batch under the GPU lock.""" work_dir = state.work_dir or tempfile.mkdtemp(prefix=f"dub_{state.job_id}_") state.work_dir = work_dir batch_dir = os.path.join(work_dir, f"batch_{spec.index}") os.makedirs(batch_dir, exist_ok=True) src_wav = os.path.join(work_dir, "source.wav") full_vocals = os.path.join(work_dir, "vocals.wav") full_music = os.path.join(work_dir, "music.wav") if not os.path.exists(src_wav): raise RuntimeError("source.wav missing — re-run the job from /v1/dub/jobs") await _set_batch_status(state, spec.index, "TRANSCRIBING", 5) # Slice this batch's audio out of the full mix and isolated vocals. batch_wav = os.path.join(batch_dir, "audio.wav") batch_vocals = os.path.join(batch_dir, "vocals.wav") batch_music = os.path.join(batch_dir, "music.wav") await asyncio.to_thread( _slice_wav, src_wav, batch_wav, spec.start_s, spec.end_s, sample_rate=16000 ) await asyncio.to_thread( _slice_wav, full_vocals, batch_vocals, spec.start_s, spec.end_s, sample_rate=24000 ) if os.path.exists(full_music): await asyncio.to_thread( _slice_wav, full_music, batch_music, spec.start_s, spec.end_s, sample_rate=24000 ) else: batch_music = "" # type: ignore[assignment] asr = await asyncio.to_thread( _transcribe, batch_vocals, _to_iso(state.detected_lang) or None ) if not state.detected_lang or state.detected_lang == "en": state.detected_lang = asr.get("language") or "en" await _set_batch_status(state, spec.index, "DIARIZING", 25) diar = await asyncio.to_thread(_diarize, batch_vocals, None) local_segments = _assign_speakers(asr, diar) if not local_segments: await _set_batch_status(state, spec.index, "READY", 100, segments=[]) return # Re-anchor the timestamps to job-global so the editor / final mix stay # correct across batches. for s in local_segments: s["start"] = float(s["start"]) + spec.start_s s["end"] = float(s["end"]) + spec.start_s # Speaker reconciliation. After this step, segment["speaker"] is a stable # job-global label. label_map = await asyncio.to_thread( _reconcile_speakers, state, local_segments, batch_vocals, batch_dir, spec.start_s, ) for s in local_segments: s["speaker"] = label_map.get(s["speaker"], s["speaker"]) await _set_batch_status(state, spec.index, "TRANSLATING", 45) target_iso = state.target_iso local_segments = await asyncio.to_thread( _translate, local_segments, state.detected_lang, target_iso ) await _set_batch_status(state, spec.index, "SYNTHESIZING", 60) # Resolve label -> reference audio path. User overrides win. override_map = {o.label: o.ref_url for o in speaker_overrides} label_to_ref: dict[str, str] = {} for label in {s["speaker"] for s in local_segments}: if label in override_map: override_path = os.path.join(batch_dir, f"override_{label}.wav") _run(["curl", "-sSL", "-o", override_path, override_map[label]]) label_to_ref[label] = override_path elif label in state.speaker_ref_paths and os.path.exists( state.speaker_ref_paths[label] ): label_to_ref[label] = state.speaker_ref_paths[label] segment_pcms: dict[int, np.ndarray] = {} seg_records: list[dict[str, Any]] = [] # Speedup: don't block the GPU on per-segment R2 uploads. We kick each # upload as a background task and only await them all at the end of the # batch, before mixing — by then most of them are already done. This # cut ~8-12s off a 4-batch dub in production telemetry. upload_tasks: list[asyncio.Task[None]] = [] total = len(local_segments) for j, seg in enumerate(local_segments): ref = label_to_ref.get(seg["speaker"]) if ref: target_dur = max(0.3, float(seg["end"]) - float(seg["start"])) pcm = await asyncio.to_thread( _synthesize_segment, seg.get("target_text") or seg["text"], ref, target_dur, target_iso, ) segment_pcms[j] = pcm if ( batch_urls and j < len(batch_urls.segment_put_urls) and batch_urls.segment_put_urls[j] ): seg_path = os.path.join(batch_dir, f"seg_{j}.wav") _write_wav(seg_path, pcm, sr=SAMPLE_RATE) upload_tasks.append( asyncio.create_task( _put_file(batch_urls.segment_put_urls[j], seg_path) ) ) seg_records.append( { "index": j, "batchIndex": spec.index, "startMs": int(seg["start"] * 1000), "endMs": int(seg["end"] * 1000), "speakerLabel": seg["speaker"], "sourceText": seg["text"], "targetText": seg.get("target_text", seg["text"]), } ) if j % 5 == 0: pct = 60 + int(25 * (j + 1) / max(1, total)) await _set_batch_status( state, spec.index, "SYNTHESIZING", min(89, pct) ) # Push the segment manifest for this batch (UI shows them grouped under # the batch strip). speaker_records = [ {"label": label, "ref_put_url": None} for label in {s["speaker"] for s in local_segments} ] await _post_json( f"{state.callback_base}/segments", { "jobId": state.job_id, "batchIndex": spec.index, "speakers": speaker_records, "segments": seg_records, }, ) # Drain backgrounded segment uploads. Done before mixing so any 5xx from # R2 surfaces as a batch failure rather than a silent missing-asset later. if upload_tasks: await asyncio.gather(*upload_tasks) await _set_batch_status(state, spec.index, "MIXING", 92) mix_local = os.path.join(batch_dir, "mix.wav") await asyncio.to_thread( _mix_batch, local_segments, segment_pcms, batch_music if batch_music else None, mix_local, batch_start_s=spec.start_s, batch_end_s=spec.end_s, drop_background=state.drop_background, ) if batch_urls and batch_urls.mix_put_url: await _put_file(batch_urls.mix_put_url, mix_local) if batch_urls and batch_urls.vocals_put_url: await _put_file(batch_urls.vocals_put_url, batch_vocals) if batch_urls and batch_urls.music_put_url and batch_music and os.path.exists(batch_music): await _put_file(batch_urls.music_put_url, batch_music) await _set_batch_status(state, spec.index, "READY", 100) async def _run_job(req: DubJobRequest) -> None: state = JOBS[req.job_id] state.callback_base = req.callback_base state.target_iso = _to_iso(req.target_lang) or "en" state.drop_background = req.drop_background state.container_kind = req.container_kind work_dir = tempfile.mkdtemp(prefix=f"dub_{req.job_id}_") state.work_dir = work_dir try: async with GPU_LOCK: await _set_job_status(state, "INGESTING", 5) wav_path, raw_video_path, duration, has_video = await asyncio.to_thread( _ingest, req, work_dir ) state.duration_s = duration if duration > req.duration_cap_s: raise HTTPException( status_code=400, detail=( f"Source media is {duration:.1f}s; cap is " f"{req.duration_cap_s}s for this plan." ), ) # Speedup: kick the original-video upload (network-bound) in the # background and let demucs (GPU-bound) start immediately. The # upload is awaited later — typically it finishes long before # separation does, so the await is effectively free, but if it # ever lags we still surface the error before claiming the job. video_upload_task: asyncio.Task[None] | None = None if has_video and req.original_video_put_url and raw_video_path: video_upload_task = asyncio.create_task( _put_file( req.original_video_put_url, raw_video_path, content_type="video/mp4", ) ) await _set_job_status(state, "SEPARATING", 20) await asyncio.to_thread(_separate, wav_path, work_dir) if video_upload_task is not None: # Cheap if it already finished; raises if the PUT failed. await video_upload_task specs = await asyncio.to_thread(_plan_batches, os.path.join(work_dir, "vocals.wav"), duration) state.batches = specs await _post_json( f"{req.callback_base}/progress", { "jobId": state.job_id, "durationSec": duration, "totalBatches": len(specs), "containerKind": "video" if has_video else "audio", }, ) # Tell the app to seed batch rows with start/end timestamps before # processing begins so the BatchStrip can render right away. await _post_json( f"{req.callback_base}/batches", { "jobId": state.job_id, "plan": [ { "batchIndex": s.index, "startMs": int(s.start_s * 1000), "endMs": int(s.end_s * 1000), "status": "PENDING", "progress": 0, } for s in specs ], }, ) for spec in specs: batch_urls = next( (b for b in req.batch_put_urls if b.index == spec.index), None ) try: await _process_batch(state, spec, batch_urls, req.speaker_overrides) except Exception as exc: # noqa: BLE001 logger.exception( "dub: batch %d/%d failed", spec.index, len(specs) ) await _set_batch_status( state, spec.index, "FAILED", 0, errorMessage=str(exc)[:500], ) # Keep going so other batches still complete and the user # can retry only the failed one. # Auto-finalize when every batch is READY. statuses = [state.batch_status.get(s.index, "PENDING") for s in specs] if all(s == "READY" for s in statuses): await _finalize_job( state, final_put_url=req.final_put_url, final_video_put_url=req.final_video_put_url, ) else: await _set_job_status( state, "MIXING", state.progress, errorMessage="Some batches failed — retry them, then call /finalize", ) except Exception as exc: # noqa: BLE001 logger.exception("dub: job %s failed", req.job_id) state.status = "FAILED" state.error = str(exc) await _post_json( f"{req.callback_base}/progress", { "jobId": req.job_id, "status": "FAILED", "progress": state.progress, "errorMessage": str(exc)[:500], }, ) finally: state.finished_at = time.time() async def _finalize_job( state: DubJobState, *, final_put_url: str, final_video_put_url: str | None, ) -> None: """Concat per-batch mixes into a single WAV (and mux into video if any).""" if not state.work_dir: raise RuntimeError("finalize called without work_dir") await _set_job_status(state, "MIXING", 95) list_path = os.path.join(state.work_dir, "concat.txt") with open(list_path, "w") as f: for spec in sorted(state.batches, key=lambda s: s.index): mix_path = os.path.join(state.work_dir, f"batch_{spec.index}", "mix.wav") if os.path.exists(mix_path): f.write(f"file '{mix_path}'\n") final_local = os.path.join(state.work_dir, "final.wav") _run( [ "ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_path, "-ac", "2", "-ar", "44100", final_local, ] ) await _put_file(final_put_url, final_local) raw_video = os.path.join(state.work_dir, "raw.media") if ( state.container_kind == "video" and final_video_put_url and os.path.exists(raw_video) ): muxed = os.path.join(state.work_dir, "final.mp4") _run( [ "ffmpeg", "-y", "-i", raw_video, "-i", final_local, "-map", "0:v", "-map", "1:a", "-c:v", "copy", "-c:a", "aac", "-shortest", muxed, ] ) await _put_file(final_video_put_url, muxed, content_type="video/mp4") await _set_job_status(state, "READY", 100) await _post_json( f"{state.callback_base}/claim", {"jobId": state.job_id, "ok": True}, ) # -------------------------------------------------------------------------- # Router # -------------------------------------------------------------------------- def _check_token(authorization: str | None) -> None: expected = os.getenv("DUB_WORKER_TOKEN", "").strip() if not expected: raise HTTPException( status_code=503, detail="DUB_WORKER_TOKEN not configured on pod" ) if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="missing bearer token") provided = authorization.removeprefix("Bearer ").strip() import hmac if not hmac.compare_digest(provided, expected): raise HTTPException(status_code=403, detail="invalid bearer token") def _require_auth(authorization: str | None = Header(None)) -> None: """Router-level dependency. FastAPI runs `dependencies=[...]` BEFORE body/path validation, so this 401s on missing/bad bearer even when the payload is malformed. Without this, an attacker could probe routes with `{}` bodies and only get a 422 schema dump (info leak).""" _check_token(authorization) # All /v1/dub/* routes require the bearer. Keep `_check_token` calls inside the # handlers as a defense-in-depth no-op (they re-validate the same header). router = APIRouter( prefix="/v1/dub", tags=["dub"], dependencies=[Depends(_require_auth)] ) @router.post("/jobs", response_model=DubJobResponse) async def create_job( req: DubJobRequest, authorization: str | None = Header(None), ) -> DubJobResponse: _check_token(authorization) if req.job_id in JOBS and JOBS[req.job_id].status not in ("READY", "FAILED"): raise HTTPException(status_code=409, detail="job already running") JOBS[req.job_id] = DubJobState(job_id=req.job_id, status="PENDING") asyncio.create_task(_run_job(req)) return DubJobResponse(accepted=True, job_id=req.job_id) @router.post("/jobs/{job_id}/batches/{batch_index}/run") async def run_batch( job_id: str, batch_index: int, body: DubBatchRunRequest, authorization: str | None = Header(None), ) -> dict[str, Any]: """Re-run a single batch. Used by the editor's per-batch Retry button.""" _check_token(authorization) state = JOBS.get(job_id) if not state: raise HTTPException( status_code=404, detail="job not in pod cache — pod may have restarted; please re-create", ) spec = next((s for s in state.batches if s.index == batch_index), None) if not spec: raise HTTPException(status_code=404, detail="batch not found") state.callback_base = body.callback_base async def _retry() -> None: async with GPU_LOCK: try: await _process_batch( state, spec, body.batch_put_urls, body.speaker_overrides ) except Exception as exc: # noqa: BLE001 logger.exception("dub: retry batch %d failed", batch_index) await _set_batch_status( state, batch_index, "FAILED", 0, errorMessage=str(exc)[:500], ) asyncio.create_task(_retry()) return {"accepted": True, "batch_index": batch_index} @router.post("/jobs/{job_id}/finalize") async def finalize_endpoint( job_id: str, body: DubFinalizeRequest, authorization: str | None = Header(None), ) -> dict[str, Any]: _check_token(authorization) state = JOBS.get(job_id) if not state: raise HTTPException( status_code=404, detail="job not in pod cache — pod may have restarted; please re-create", ) state.callback_base = body.callback_base async def _do() -> None: async with GPU_LOCK: try: await _finalize_job( state, final_put_url=body.final_put_url, final_video_put_url=body.final_video_put_url, ) except Exception as exc: # noqa: BLE001 logger.exception("dub: finalize %s failed", job_id) await _set_job_status( state, "FAILED", state.progress, errorMessage=str(exc)[:500] ) asyncio.create_task(_do()) return {"accepted": True, "job_id": job_id} @router.get("/jobs/{job_id}") async def get_job( job_id: str, authorization: str | None = Header(None) ) -> dict[str, Any]: _check_token(authorization) state = JOBS.get(job_id) if not state: raise HTTPException(status_code=404, detail="job not found") return { "job_id": state.job_id, "status": state.status, "progress": state.progress, "error": state.error, "started_at": state.started_at, "finished_at": state.finished_at, "duration_s": state.duration_s, "batches": [ { "index": s.index, "start_s": s.start_s, "end_s": s.end_s, "status": state.batch_status.get(s.index, "PENDING"), } for s in state.batches ], "extra": state.extra, } @router.get("/logs/tail") async def tail_logs( lines: int = 200, gpu: int = 0, authorization: str | None = Header(None), ) -> dict[str, Any]: """Return the last N lines of this worker's uvicorn stdout log. Used by the editor's "View pod logs" drawer in the Next.js app — it lets operators see what the dub pipeline is actually doing without SSHing into the pod. Bounded to 1000 lines to avoid streaming megabytes back per poll. """ _check_token(authorization) n = max(1, min(int(lines), 1000)) g = max(0, min(int(gpu), 7)) log_path = f"/var/log/uvicorn-gpu{g}.log" if not os.path.exists(log_path): return {"path": log_path, "lines": [], "truncated": False, "error": "no log file"} try: # Read the tail without slurping the whole file. 64 KiB is enough for # ~200 lines of typical uvicorn output and bounds memory. with open(log_path, "rb") as f: f.seek(0, os.SEEK_END) size = f.tell() chunk = min(size, 64 * 1024) f.seek(size - chunk) data = f.read().decode("utf-8", errors="replace") all_lines = data.splitlines() tail = all_lines[-n:] return { "path": log_path, "lines": tail, "truncated": len(all_lines) > len(tail) or chunk < size, } except Exception as exc: # noqa: BLE001 return {"path": log_path, "lines": [], "truncated": False, "error": str(exc)[:200]} # -------------------------------------------------------------------------- # Periodic cleanup (called from server startup) # -------------------------------------------------------------------------- def cleanup_stale_workdirs(*, older_than_s: int = 24 * 3600) -> int: """Remove /tmp/dub_* dirs older than `older_than_s`. Best-effort; logs but does not raise. Returns the count of dirs removed.""" cutoff = time.time() - older_than_s removed = 0 for entry in os.listdir(tempfile.gettempdir()): if not entry.startswith("dub_"): continue path = os.path.join(tempfile.gettempdir(), entry) try: if os.path.getmtime(path) < cutoff: shutil.rmtree(path, ignore_errors=True) removed += 1 except OSError: continue if removed: logger.info("dub: cleaned %d stale workdir(s)", removed) return removed