"""Pocket-TTS FastAPI server. Clean API, no Gradio.""" import io import os import wave import tempfile import subprocess from pathlib import Path import numpy as np from fastapi import FastAPI, Query, HTTPException from fastapi.responses import Response, HTMLResponse try: import torch from pocket_tts import TTSModel except ImportError: torch = None TTSModel = None app = FastAPI(title="Pocket-TTS API") # Global model state _state = { "initialized": False, "model": None, "sample_rate": 24000, } # Built-in voices from kyutai/pocket-tts model repo BUILTIN_VOICES = [ "alba", "azelma", "cosette", "eponine", "fantine", "javert", "jean", "marius", ] def _get_embeddings_dir() -> Path: """Find the embeddings directory (inside the pocket-tts package or HF cache).""" # pocket-tts downloads model to HF cache; embeddings are in the model dir # Check common locations from huggingface_hub import hf_hub_download cache_dir = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")) # The model stores embeddings alongside the model weights # pocket-tts >= 2.0 ships embeddings inside the package data try: from pocket_tts import data pkg_dir = Path(data.__file__).parent emb_dir = pkg_dir / "embeddings" if emb_dir.exists(): return emb_dir except Exception: pass # Fallback: check HF cache for kyutai/pocket-tts for emb_version in ["embeddings_v3", "embeddings_v2", "embeddings"]: try: path = hf_hub_download("kyutai/pocket-tts", f"{emb_version}/alba.safetensors") return Path(path).parent except Exception: continue return Path("/nonexistent") def _init_model(): """Initialize the TTS model (lazy, called on first request).""" if _state["initialized"]: return if TTSModel is None: raise RuntimeError("pocket-tts not installed") print("Initializing Pocket TTS model...") model = TTSModel.load_model() _state["model"] = model _state["sample_rate"] = getattr(model, "sample_rate", 24000) _state["initialized"] = True print(f"Pocket TTS initialized. Sample rate: {_state['sample_rate']} Hz") def _get_voice_state(voice: str): """Get voice state from pre-computed embedding.""" import safetensors.torch from pocket_tts.modules.stateful_module import init_states model = _state["model"] # Try loading from the model's embeddings directory embeddings_dir = _get_embeddings_dir() embedding_path = embeddings_dir / f"{voice}.safetensors" if not embedding_path.exists(): # Try downloading from HF try: from huggingface_hub import hf_hub_download for emb_version in ["embeddings_v3", "embeddings_v2", "embeddings"]: try: path = hf_hub_download("kyutai/pocket-tts", f"{emb_version}/{voice}.safetensors") embedding_path = Path(path) break except Exception: continue except Exception: pass if not embedding_path.exists(): raise ValueError(f"No embedding found for voice '{voice}'. Available: {BUILTIN_VOICES}") print(f"Loading embedding for '{voice}' from {embedding_path}") state_dict = safetensors.torch.load_file(str(embedding_path)) audio_prompt = state_dict["audio_prompt"].to(model.device) voice_state = init_states(model.flow_lm, batch_size=1, sequence_length=1000) model._run_flow_lm_and_increment_step(model_state=voice_state, audio_conditioning=audio_prompt) # Detach tensors for deepcopy compatibility def detach_tensors(obj): if isinstance(obj, torch.Tensor): return obj.detach().clone() elif isinstance(obj, dict): return {k: detach_tensors(v) for k, v in obj.items()} else: return obj return detach_tensors(voice_state) def _generate_audio(text: str, voice: str, temperature: float = 0.7) -> tuple[np.ndarray, int]: """Generate audio as numpy array. Returns (audio_int16, sample_rate).""" _init_model() model = _state["model"] sample_rate = _state["sample_rate"] voice_state = _get_voice_state(voice) # Generate audio for the full text audio = model.generate_audio( voice_state, text, frames_after_eos=2, copy_state=True, ) audio_np = audio.cpu().numpy() if hasattr(audio, 'cpu') else audio # Normalize volume to 95% peak max_val = np.max(np.abs(audio_np)) if max_val > 0: audio_np = audio_np / max_val * 0.95 # Convert to int16 audio_int16 = np.clip(audio_np * 32767, -32767, 32767).astype(np.int16) return audio_int16, sample_rate def _wav_bytes(audio_int16: np.ndarray, sample_rate: int) -> bytes: """Create WAV bytes from int16 audio.""" buf = io.BytesIO() with wave.open(buf, "wb") as wf: wf.setnchannels(1) wf.setsampwidth(2) wf.setframerate(sample_rate) wf.writeframes(audio_int16.tobytes()) return buf.getvalue() def _ogg_bytes(audio_int16: np.ndarray, sample_rate: int) -> bytes: """Convert int16 audio to OGG/Opus via ffmpeg.""" wav_data = _wav_bytes(audio_int16, sample_rate) proc = subprocess.run( ["ffmpeg", "-y", "-f", "wav", "-i", "pipe:0", "-c:a", "libopus", "-b:a", "64k", "-ar", "48000", "-ac", "1", "-f", "ogg", "pipe:1"], input=wav_data, capture_output=True, timeout=30, ) if proc.returncode != 0: raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode()[:200]}") return proc.stdout @app.get("/tts") async def tts( text: str = Query(..., description="Text to synthesize"), voice: str = Query("alba", description="Voice name"), temperature: float = Query(0.7, ge=0.1, le=1.5), format: str = Query("ogg", description="Output format: wav or ogg"), ): """Generate TTS audio. Returns WAV or OGG/Opus.""" try: audio_int16, sample_rate = _generate_audio(text, voice, temperature) except ValueError as e: raise HTTPException(400, str(e)) except Exception as e: raise HTTPException(500, str(e)[:200]) if format == "ogg": try: data = _ogg_bytes(audio_int16, sample_rate) return Response(content=data, media_type="audio/ogg", headers={"Content-Disposition": "attachment; filename=tts.ogg"}) except Exception as e: raise HTTPException(500, f"OGG encoding failed: {str(e)[:200]}") data = _wav_bytes(audio_int16, sample_rate) return Response(content=data, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=tts.wav"}) @app.get("/voices") async def voices(): """List available voices.""" return {"voices": BUILTIN_VOICES} @app.get("/health") async def health(): return {"status": "ok", "initialized": _state["initialized"]} @app.get("/", response_class=HTMLResponse) async def index(): """Simple landing page with usage info.""" return """
FastAPI server running kyutai/pocket-tts
GET /tts?text=Hello&voice=alba&format=ogg — Generate speech (wav/ogg)GET /voices — List available voicesGET /health — Health checkalba, azelma, cosette, eponine, fantine, javert, jean, marius
curl "https://YOUR_SPACE.hf.space/tts?text=Hello+world&voice=alba&format=ogg" -o tts.ogg"""