"""
TTS module - Orpheus TTS with vLLM for fast inference.

Supports voice cloning (pretrained model) and named voices (finetuned model).
Uses vLLM's AsyncLLMEngine directly for 2-4x speedup over transformers.

Public entry points:
    initialize()      -- load SNAC, tokenizer and the vLLM engine (idempotent).
    generate_audio()  -- synthesize text to a WAV file, returning its path.
"""

import os
import re
import time
import wave
import logging
import subprocess
import asyncio
import threading
from typing import List, Optional, Tuple

import numpy as np
import torch

logger = logging.getLogger(__name__)

TEMP_DIR = "/tmp/tts_output"
VOICE_REF_DIR = "/app/assets"
SAMPLE_RATE = 24000
FALLBACK_VOICE = os.environ.get("ORPHEUS_VOICE", "leo")

# Special token IDs for Orpheus
TOKEN_START_OF_HUMAN = 128259
TOKEN_END_OF_TEXT = 128009
TOKEN_END_OF_HUMAN = 128260
TOKEN_START_OF_AI = 128261
TOKEN_AUDIO_START = 128257
TOKEN_AUDIO_END = 128258
TOKEN_END_OF_AI = 128262
TOKEN_PAD = 128263
AUDIO_CODE_OFFSET = 128266

# Audio tokens come in frames of 7; slot k of a frame carries a code in
# [0, 4096) shifted by AUDIO_CODE_OFFSET + k * 4096.
_CODEBOOK_SIZE = 4096
_FRAME_SIZE = 7

# Token sequences that delimit a turn in the prompt.
_HUMAN_TURN_START = [TOKEN_START_OF_HUMAN]
_HUMAN_TURN_END = [
    TOKEN_END_OF_TEXT,
    TOKEN_END_OF_HUMAN,
    TOKEN_START_OF_AI,
    TOKEN_AUDIO_START,
]
_AI_TURN_END = [TOKEN_AUDIO_END, TOKEN_END_OF_AI]

# Patterns stripped from text before synthesis.
_TAG_RE = re.compile(r'<[^>]+>')
_ACTION_RE = re.compile(r'\*[^*]+\*')
_WHITESPACE_RE = re.compile(r'\s+')

# Singleton state
_engine = None  # vLLM AsyncLLMEngine
_tokenizer = None
_snac_model = None
_ref_tokens = None
_ref_transcript = None
_use_voice_cloning = False
_initialized = False
_request_counter = 0
_event_loop = None
_loop_thread = None
_loop_ready = threading.Event()  # set once _event_loop has been assigned


def _ts():
    """Return the current UTC time as ``HH:MM:SS.mmm`` for log prefixes."""
    # Read the clock once so seconds and milliseconds describe the same instant.
    now = time.time()
    millis = int(now * 1000) % 1000
    return time.strftime("%H:%M:%S", time.gmtime(now)) + f".{millis:03d}"


def ensure_temp_dir():
    """Create TEMP_DIR if it does not exist and return its path."""
    os.makedirs(TEMP_DIR, exist_ok=True)
    return TEMP_DIR


def _start_event_loop():
    """Start a background event loop for vLLM async operations."""
    global _event_loop
    _event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(_event_loop)
    _loop_ready.set()
    _event_loop.run_forever()


def _ensure_event_loop(timeout=10.0):
    """
    Make sure the background event-loop thread is running.

    Reuses a live thread (e.g. when initialize() is retried after a failure)
    instead of spawning a second loop. Raises RuntimeError if the loop is not
    ready within `timeout` seconds.
    """
    global _loop_thread
    if _loop_thread is None or not _loop_thread.is_alive():
        _loop_ready.clear()
        _loop_thread = threading.Thread(target=_start_event_loop, daemon=True)
        _loop_thread.start()
    if not _loop_ready.wait(timeout):
        raise RuntimeError("TTS background event loop failed to start")


def initialize():
    """
    Load vLLM engine, tokenizer, and SNAC decoder.

    Voice cloning uses pretrained model; fallback uses finetuned model.
    Voice cloning is selected only when a reference audio file, a transcript
    and HF_TOKEN are all available AND the reference audio encodes
    successfully. Calling this again after a successful run is a no-op.
    """
    global _engine, _tokenizer, _snac_model, _ref_tokens, _ref_transcript
    global _use_voice_cloning, _initialized

    if _initialized:
        return

    t0 = time.time()
    logger.info(f"[{_ts()}] [TTS] Initializing Orpheus TTS with vLLM backend...")

    from transformers import AutoTokenizer as HFTokenizer
    from snac import SNAC

    # 1. Load SNAC decoder
    t1 = time.time()
    _snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        _snac_model = _snac_model.cuda()
    logger.info(f"[{_ts()}] [TTS] SNAC decoder loaded: {time.time()-t1:.2f}s")

    # 2. Look for reference audio + transcript and encode the reference.
    #    Encoding needs only SNAC, so it runs BEFORE the model is chosen:
    #    if it fails we load the finetuned model that the named-voice prompt
    #    format expects, not the pretrained one.
    _ref_transcript = _load_transcript()
    ref_audio_path = _find_reference_audio()
    hf_token = os.environ.get("HF_TOKEN", None)
    _ref_tokens = None

    if ref_audio_path and _ref_transcript and hf_token:
        logger.info(f"[{_ts()}] [TTS] Voice cloning mode: ref audio + transcript found")
        logger.info(f"[{_ts()}] [TTS] Transcript: \"{_ref_transcript[:80]}...\"")
        _ref_tokens = _encode_reference_audio(ref_audio_path)
        if _ref_tokens is None:
            logger.warning(f"[{_ts()}] [TTS] Reference encoding failed, falling back to named voice")
    else:
        if not ref_audio_path:
            logger.info(f"[{_ts()}] [TTS] No reference audio found")
        if not _ref_transcript:
            logger.info(f"[{_ts()}] [TTS] No transcript found")
        if not hf_token:
            logger.info(f"[{_ts()}] [TTS] No HF_TOKEN")

    # 3. Determine model
    _use_voice_cloning = _ref_tokens is not None
    if _use_voice_cloning:
        model_name = "canopylabs/orpheus-3b-0.1-pretrained"
    else:
        model_name = "canopylabs/orpheus-3b-0.1-ft"
        logger.info(f"[{_ts()}] [TTS] Using finetuned model with voice: {FALLBACK_VOICE}")

    # 4. Load tokenizer (HuggingFace tokenizer for prompt building)
    t2 = time.time()
    token_kwargs = {"token": hf_token} if hf_token else {}
    _tokenizer = HFTokenizer.from_pretrained(model_name, **token_kwargs)
    logger.info(f"[{_ts()}] [TTS] Tokenizer loaded: {time.time()-t2:.2f}s")

    # 5. Start background event loop for vLLM
    _ensure_event_loop()

    # 6. Initialize vLLM engine
    t3 = time.time()
    _init_vllm_engine(model_name, hf_token)
    logger.info(f"[{_ts()}] [TTS] vLLM engine loaded: {time.time()-t3:.2f}s")

    # 7. Pre-build the cached reference prefix for voice cloning
    _build_ref_prefix()

    # Mark ready only after every step succeeded, so a failed init can be retried.
    _initialized = True

    if _use_voice_cloning:
        mode = f"voice cloning ({len(_ref_tokens)} ref tokens)"
    else:
        mode = f"named voice ({FALLBACK_VOICE})"
    logger.info(f"[{_ts()}] [TTS] ✓ Orpheus+vLLM ready in {time.time()-t0:.2f}s | mode: {mode}")


def _init_vllm_engine(model_name, hf_token):
    """
    Initialize the vLLM AsyncLLMEngine on the background event loop.

    Blocks until the engine is created; the wait is capped at 300 seconds.
    Requires the background event loop to be running.
    """
    global _engine
    from vllm import AsyncEngineArgs

    # Determine GPU memory and dtype: bfloat16 unless the GPU's major compute
    # capability is below 8, in which case fall back to float16.
    dtype = "bfloat16"
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        if capability[0] < 8:
            dtype = "float16"

    # SNAC takes ~0.5GB, FLOAT takes ~2-3GB, leave room
    gpu_mem_util = float(os.environ.get("VLLM_GPU_MEM_UTIL", "0.65"))

    engine_args = AsyncEngineArgs(
        model=model_name,
        dtype=dtype,
        gpu_memory_utilization=gpu_mem_util,
        max_model_len=4096,  # Need room for ref audio tokens + generation
        enforce_eager=True,  # Avoid CUDA graph issues on HF Spaces
        enable_prefix_caching=True,  # Cache KV for ref audio prefix (same every request)
        disable_log_requests=True,
    )

    # Pass HF token for gated models
    if hf_token:
        os.environ["HF_TOKEN"] = hf_token

    # Create engine on the background event loop
    future = asyncio.run_coroutine_threadsafe(
        _create_engine(engine_args), _event_loop
    )
    _engine = future.result(timeout=300)  # 5 min timeout for model download


async def _create_engine(engine_args):
    """Async engine creation (a coroutine so it runs on the loop thread)."""
    from vllm import AsyncLLMEngine
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    return engine


def _load_transcript():
    """
    Return the first non-empty reference transcript found, or None.

    Looks for ``voice_transcript.txt`` then ``transcript.txt`` in
    VOICE_REF_DIR and then ``/app``. Unreadable files are skipped with a
    warning rather than aborting initialization.
    """
    for directory in (VOICE_REF_DIR, "/app"):
        for name in ("voice_transcript.txt", "transcript.txt"):
            path = os.path.join(directory, name)
            if not os.path.exists(path):
                continue
            try:
                # utf-8-sig also accepts plain UTF-8 and drops a leading BOM.
                with open(path, encoding="utf-8-sig") as fh:
                    text = fh.read().strip()
            except (OSError, UnicodeDecodeError) as e:
                logger.warning(f"[{_ts()}] [TTS] Could not read transcript {path}: {e}")
                continue
            if text:
                return text
    return None


def _find_reference_audio():
    """
    Return the path of the first usable reference audio file, or None.

    Looks for ``voice.mp4``, ``voice.wav``, ``voice.mp3`` in VOICE_REF_DIR and
    then ``/app``; files of 1024 bytes or less are ignored.
    """
    for directory in (VOICE_REF_DIR, "/app"):
        for name in ("voice.mp4", "voice.wav", "voice.mp3"):
            path = os.path.join(directory, name)
            if os.path.exists(path) and os.path.getsize(path) > 1024:
                return path
    return None


def _interleave_snac_codes(layer_1, layer_2, layer_3):
    """
    Flatten three SNAC code layers into the 7-token-per-frame token stream.

    For frame i the slot order is:
    layer_1[i], layer_2[2i], layer_3[4i], layer_3[4i+1],
    layer_2[2i+1], layer_3[4i+2], layer_3[4i+3];
    slot k is shifted by AUDIO_CODE_OFFSET + k * 4096.

    Args:
        layer_1, layer_2, layer_3: sequences of int codes; layer_2 and layer_3
            must hold 2 and 4 entries per layer_1 entry respectively.

    Returns:
        A flat list of token IDs, 7 per entry in layer_1.
    """
    tokens = []
    for i, coarse in enumerate(layer_1):
        frame = (
            coarse,
            layer_2[2 * i],
            layer_3[4 * i],
            layer_3[4 * i + 1],
            layer_2[2 * i + 1],
            layer_3[4 * i + 2],
            layer_3[4 * i + 3],
        )
        tokens.extend(
            code + AUDIO_CODE_OFFSET + slot * _CODEBOOK_SIZE
            for slot, code in enumerate(frame)
        )
    return tokens


def _encode_reference_audio(ref_path):
    """
    Encode reference audio to SNAC tokens for voice cloning.

    Non-WAV input is first converted with ffmpeg to mono WAV at SAMPLE_RATE;
    the audio is trimmed to at most 15 seconds.

    Returns:
        A list of audio token IDs, or None if conversion or encoding fails or
        produces no tokens.
    """
    import librosa

    logger.info(f"[{_ts()}] [TTS] Encoding reference: {ref_path}")
    ensure_temp_dir()

    wav_path = ref_path
    if not ref_path.lower().endswith(".wav"):
        wav_path = os.path.join(TEMP_DIR, "voice_ref.wav")
        try:
            subprocess.run([
                "ffmpeg", "-y", "-i", ref_path,
                "-ar", str(SAMPLE_RATE), "-ac", "1",
                "-f", "wav", wav_path
            ], capture_output=True, check=True)
        except subprocess.CalledProcessError as e:
            # The exception text only has the exit code; the cause is in stderr.
            stderr = e.stderr.decode("utf-8", errors="replace").strip() if e.stderr else ""
            logger.error(f"[{_ts()}] [TTS] ffmpeg failed: {e} | {stderr[-500:]}")
            return None
        except (OSError, subprocess.SubprocessError) as e:
            # e.g. the ffmpeg binary is not installed
            logger.error(f"[{_ts()}] [TTS] ffmpeg failed: {e}")
            return None

    try:
        t0 = time.time()
        audio_array, sr = librosa.load(wav_path, sr=SAMPLE_RATE)

        max_samples = 15 * SAMPLE_RATE
        if len(audio_array) > max_samples:
            audio_array = audio_array[:max_samples]
            logger.info(f"[{_ts()}] [TTS] Trimmed reference to 15s")

        duration = len(audio_array) / SAMPLE_RATE
        logger.info(f"[{_ts()}] [TTS] Reference audio: {duration:.1f}s")

        # Add two leading axes of size 1 before handing the samples to SNAC.
        waveform = torch.from_numpy(audio_array).unsqueeze(0).to(dtype=torch.float32).unsqueeze(0)
        if torch.cuda.is_available():
            waveform = waveform.cuda()

        with torch.inference_mode():
            codes = _snac_model.encode(waveform)

        # One host transfer per layer instead of one .item() per code.
        all_codes = _interleave_snac_codes(
            codes[0][0].tolist(),
            codes[1][0].tolist(),
            codes[2][0].tolist(),
        )
        if not all_codes:
            logger.error(f"[{_ts()}] [TTS] Failed to encode reference: no audio tokens produced")
            return None

        logger.info(f"[{_ts()}] [TTS] Reference encoded: {len(all_codes)} tokens ({time.time()-t0:.2f}s)")
        return all_codes

    except Exception as e:
        logger.error(f"[{_ts()}] [TTS] Failed to encode reference: {e}", exc_info=True)
        return None


# Cached reference prefix (built once, reused every request)
_ref_prefix_ids = None


def _tokenize(text):
    """Return the tokenizer's input IDs for `text` as a plain list of ints."""
    return _tokenizer(text, return_tensors="pt").input_ids[0].tolist()


def _human_turn(text):
    """Wrap tokenized `text` as [SOH] text [EOT EOH SOA AUDIO_START]."""
    return _HUMAN_TURN_START + _tokenize(text) + _HUMAN_TURN_END


def _build_ref_prefix():
    """
    Pre-build the fixed reference prefix token IDs (called once at init).

    Sets the module-level `_ref_prefix_ids` to
    [SOH][transcript][EOT EOH SOA AUDIO_START][ref_audio][AUDIO_END EOAI],
    or to None when voice cloning is off or no reference tokens exist.
    """
    global _ref_prefix_ids

    if not _use_voice_cloning or _ref_tokens is None:
        _ref_prefix_ids = None
        return

    _ref_prefix_ids = _human_turn(_ref_transcript) + _ref_tokens + _AI_TURN_END
    logger.info(f"[{_ts()}] [TTS] Reference prefix cached: {len(_ref_prefix_ids)} tokens")


def _build_prompt_token_ids(text):
    """
    Build input token IDs list for vLLM (not tensors — vLLM takes lists).

    Voice cloning (pretrained):
        [SOH][transcript][EOT EOH SOA AUDIO_START][ref_audio][AUDIO_END EOAI]  ← cached prefix
        [SOH][text][EOT EOH SOA AUDIO_START]                                   ← varies per request

    Named voice (finetuned):
        [SOH]{voice}: {text}[EOT EOH SOA AUDIO_START]

    Returns a new list on every call; the cached prefix is never mutated.
    """
    if _use_voice_cloning and _ref_prefix_ids is not None:
        return _ref_prefix_ids + _human_turn(text)
    return _human_turn(f"{FALLBACK_VOICE}: {text}")


def _extract_snac_layers(token_ids) -> Tuple[List[int], List[int], List[int]]:
    """
    Split generated token IDs into the three SNAC code layers.

    Scans for windows of 7 tokens in which slot k decodes to a code in
    [0, 4096) after removing AUDIO_CODE_OFFSET + k * 4096. Anything else
    (control tokens, text tokens, a token in the wrong slot, a truncated
    final frame) is skipped one token at a time, so the scanner resyncs on
    the next valid frame instead of producing out-of-range codebook indices.

    Returns:
        (layer_1, layer_2, layer_3) holding 1, 2 and 4 codes per valid frame.
    """
    layer_1, layer_2, layer_3 = [], [], []
    total = len(token_ids)
    pos = 0
    while pos + _FRAME_SIZE <= total:
        frame = [
            token_ids[pos + slot] - AUDIO_CODE_OFFSET - slot * _CODEBOOK_SIZE
            for slot in range(_FRAME_SIZE)
        ]
        if all(0 <= code < _CODEBOOK_SIZE for code in frame):
            layer_1.append(frame[0])
            layer_2.extend((frame[1], frame[4]))
            layer_3.extend((frame[2], frame[3], frame[5], frame[6]))
            pos += _FRAME_SIZE
        else:
            pos += 1
    return layer_1, layer_2, layer_3


def _decode_audio_tokens_from_list(token_ids):
    """
    Decode a list of generated token IDs into audio.

    Expects only the NEW tokens (not the prompt).

    Returns:
        A numpy array holding the squeezed output of the SNAC decoder, or
        None if the tokens contain no complete, valid 7-token frame.
    """
    layer_1, layer_2, layer_3 = _extract_snac_layers(token_ids)
    if not layer_1:
        logger.warning(f"[{_ts()}] [TTS] No valid audio tokens")
        return None

    skipped = len(token_ids) - _FRAME_SIZE * len(layer_1)
    if skipped:
        logger.debug(f"[{_ts()}] [TTS] Skipped {skipped} non-frame tokens")

    codes = [
        torch.tensor(layer).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]
    if torch.cuda.is_available():
        codes = [c.cuda() for c in codes]

    with torch.inference_mode():
        audio_hat = _snac_model.decode(codes)

    return audio_hat.detach().squeeze().cpu().numpy()


async def _generate_tokens_async(prompt_token_ids, max_tokens, temperature, top_p, rep_penalty):
    """
    Run vLLM generation asynchronously and collect all output tokens.

    Generation stops at TOKEN_AUDIO_END or after `max_tokens` tokens.
    Returns the list of generated token IDs (empty if nothing was produced).
    """
    global _request_counter
    from vllm import SamplingParams

    _request_counter += 1
    request_id = f"tts-{_request_counter}"

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=rep_penalty,
        stop_token_ids=[TOKEN_AUDIO_END],
    )

    # Use TokensPrompt dict for vLLM 0.7.3 compatibility
    prompt_input = {"prompt_token_ids": prompt_token_ids}

    all_token_ids = []
    async for output in _engine.generate(
        prompt=prompt_input,
        sampling_params=sampling_params,
        request_id=request_id,
    ):
        # vLLM gives cumulative output at each step
        if output.outputs:
            all_token_ids = list(output.outputs[0].token_ids)

    return all_token_ids


def _generate_tokens_sync(prompt_token_ids, max_tokens, temperature, top_p, rep_penalty):
    """
    Synchronous wrapper for async vLLM generation.

    Submits the coroutine to the background event loop and waits up to 120
    seconds. If the wait ends without a result (timeout or interruption) the
    pending work is cancelled so it does not keep running behind the caller.

    Raises:
        Whatever `future.result` raises (timeout or the coroutine's exception).
    """
    future = asyncio.run_coroutine_threadsafe(
        _generate_tokens_async(prompt_token_ids, max_tokens, temperature, top_p, rep_penalty),
        _event_loop
    )
    try:
        return future.result(timeout=120)
    finally:
        if not future.done():
            # NOTE(review): cancellation is delivered to the coroutine on the
            # loop thread; presumably vLLM aborts the request when its
            # generator is cancelled -- confirm for the pinned vLLM version.
            future.cancel()


def _clean_text_for_tts(text):
    """Remove tags and asterisk actions that confuse voice cloning."""
    text = _TAG_RE.sub('', text)
    text = _ACTION_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


def _write_wav(output_path, audio_np):
    """Write float samples to a mono 16-bit WAV at SAMPLE_RATE.

    Samples are scaled by 32767 and clipped to the int16 range.
    """
    audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)
    with wave.open(output_path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(audio_int16.tobytes())


def generate_audio(text: str, output_filename: Optional[str] = None) -> Optional[str]:
    """
    Generate speech audio from text using Orpheus + vLLM.

    Args:
        text: Text to speak. Tags (``<...>``) and asterisk actions (``*...*``)
            are stripped first.
        output_filename: File name inside TEMP_DIR; ``.wav`` is appended when
            missing. Defaults to ``tts_<epoch-millis>.wav``.

    Returns:
        Path to wav file, or None on failure (empty text, module not
        initialized, no decodable audio, or any error during generation).
    """
    if not text or not text.strip():
        return None

    text = _clean_text_for_tts(text)
    if not text:
        return None

    if not _initialized:
        logger.error(f"[{_ts()}] [TTS] Not initialized!")
        return None

    temp_dir = ensure_temp_dir()

    if output_filename is None:
        timestamp = int(time.time() * 1000)
        output_filename = f"tts_{timestamp}"

    if not output_filename.endswith('.wav'):
        output_path = os.path.join(temp_dir, f"{output_filename}.wav")
    else:
        output_path = os.path.join(temp_dir, output_filename)

    try:
        t0 = time.time()
        logger.info(f"[{_ts()}] [TTS] Generating: {text[:60]}...")

        # Build prompt
        prompt_ids = _build_prompt_token_ids(text)
        logger.info(f"[{_ts()}] [TTS] Prompt: {len(prompt_ids)} tokens")

        # Scale max_tokens to text length to prevent runaway generation
        # Observed: model typically uses 35-43 tokens/word; cap at 45 for headroom
        # This prevents runaway (50+ tok/word) while allowing natural stops
        word_count = len(text.split())
        scaled_max = min(990, max(200, word_count * 45))
        logger.info(f"[{_ts()}] [TTS] Words: {word_count} → max_tokens: {scaled_max}")

        # Generation settings matched to mode
        # rep_penalty=1.2 gives slight speed boost without hurting voice clone quality
        # temperature=0.5 keeps voice stable; top_p=0.9 for natural variation
        if _use_voice_cloning:
            gen_kwargs = dict(max_tokens=scaled_max, temperature=0.5, top_p=0.9, rep_penalty=1.2)
        else:
            gen_kwargs = dict(max_tokens=min(1200, word_count * 35), temperature=0.7, top_p=0.9, rep_penalty=1.2)

        # Generate via vLLM
        t1 = time.time()
        output_ids = _generate_tokens_sync(prompt_ids, **gen_kwargs)
        t2 = time.time()
        # Guard the rate computation against a zero-length interval.
        elapsed = max(t2 - t1, 1e-6)
        logger.info(f"[{_ts()}] [TTS] vLLM: {len(output_ids)} tokens in {t2-t1:.2f}s ({len(output_ids)/elapsed:.0f} tok/s)")

        # Decode audio
        t3 = time.time()
        audio_np = _decode_audio_tokens_from_list(output_ids)
        if audio_np is None:
            return None
        t4 = time.time()
        logger.info(f"[{_ts()}] [TTS] SNAC decode: {t4-t3:.2f}s | audio: {len(audio_np)/SAMPLE_RATE:.1f}s")

        # Save WAV
        _write_wav(output_path, audio_np)

        file_size = os.path.getsize(output_path)
        total = time.time() - t0
        logger.info(f"[{_ts()}] [TTS] Saved: {output_path} ({file_size/1024:.0f}KB) | total: {total:.2f}s")

        return output_path

    except Exception as e:
        logger.error(f"[{_ts()}] [TTS] Error: {e}", exc_info=True)
        return None