import gc
import io
import threading

import librosa
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from transformers import (
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)

app = FastAPI()

# --- Configuration & model loading ---

HF_MODEL_ID = "devrahulbanjara/whisper-small-nepali"
# Sample rate the model is fed; other rates are resampled in transcribe().
SAMPLE_RATE = 16_000
# 480_000 samples at 16 kHz = 30 s per window, matching the
# WhisperFeatureExtractor default chunk length.
CHUNK_SAMPLES = 480_000
# 16_000 samples at 16 kHz = 1 s shared by consecutive windows so a word on a
# boundary is not cut in half. Because the overlap is transcribed by both
# windows, a word near a boundary may appear twice in the joined transcript.
OVERLAP_SAMPLES = 16_000

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU inference stays in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

print(f"Loading model on {device}...")
feature_extractor = WhisperFeatureExtractor.from_pretrained(HF_MODEL_ID)
# NOTE(review): the tokenizer is loaded from the base openai/whisper-small
# checkpoint, not HF_MODEL_ID — presumably the fine-tune reuses the base
# vocabulary; confirm.
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small", language="nepali", task="transcribe"
)
processor = WhisperProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = (
    WhisperForConditionalGeneration.from_pretrained(HF_MODEL_ID, torch_dtype=dtype)
    .to(device)
    .eval()
)

# api_transcribe runs inference on worker threads; this lock allows only one
# generate() call on the shared model at a time, so concurrent requests queue
# for the model instead of running simultaneous generations.
_INFERENCE_LOCK = threading.Lock()


# --- Transcription ---

def _transcribe_chunk(chunk: np.ndarray) -> str:
    """Transcribe one window of at most CHUNK_SAMPLES samples at SAMPLE_RATE.

    Args:
        chunk: 1-D float32 waveform already resampled to SAMPLE_RATE.

    Returns:
        The decoded text for this window, special tokens removed (not stripped).
    """
    inputs = processor(chunk, sampling_rate=SAMPLE_RATE, return_tensors="pt")
    features = inputs.input_features.to(device=device, dtype=dtype)
    with _INFERENCE_LOCK, torch.no_grad():
        ids = model.generate(features, language="nepali", task="transcribe")
    return processor.batch_decode(ids, skip_special_tokens=True)[0]


def transcribe(audio_array: np.ndarray, sr: int) -> str:
    """Transcribe a mono waveform of any length.

    Audio longer than CHUNK_SAMPLES (after resampling) is split into windows
    of CHUNK_SAMPLES that overlap by OVERLAP_SAMPLES; each window is
    transcribed separately and the non-empty results are joined with a single
    space.

    Args:
        audio_array: 1-D waveform samples (e.g. as returned by ``librosa.load``).
        sr: Sample rate of ``audio_array`` in Hz; resampled to SAMPLE_RATE if
            different.

    Returns:
        The transcript, or "" for empty audio.
    """
    audio = np.asarray(audio_array, dtype=np.float32)
    if sr != SAMPLE_RATE:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
        audio = np.asarray(audio, dtype=np.float32)

    total = len(audio)
    if total == 0:
        # Nothing to transcribe; feeding pure padding to Whisper can hallucinate.
        return ""

    step = CHUNK_SAMPLES - OVERLAP_SAMPLES
    parts = []
    for start in range(0, total, step):
        end = min(start + CHUNK_SAMPLES, total)
        # Decoded Whisper text often carries leading/trailing spaces; strip so
        # the join below does not produce doubled spaces.
        text = _transcribe_chunk(audio[start:end]).strip()
        if text:
            parts.append(text)
        if end == total:
            # This window already reached the end of the audio; stepping again
            # would only re-transcribe a tail fully contained in it.
            break
    return " ".join(parts)


# --- API ---

@app.post("/transcribe")
async def api_transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return ``{"text": <transcript>}``.

    Decoding and inference are blocking CPU/GPU work, so they run in FastAPI's
    thread pool rather than on the event loop; otherwise one long
    transcription would stall every other request, including the health check.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as audio.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # sr=None keeps the native sample rate (transcribe() resamples);
        # librosa.load downmixes to mono by default. Which container formats
        # decode from an in-memory buffer depends on the installed
        # soundfile/libsndfile build.
        audio_array, sr = await run_in_threadpool(
            librosa.load, io.BytesIO(audio_bytes), sr=None
        )
    except Exception as err:  # decoder error types vary by backend
        raise HTTPException(
            status_code=400, detail="Could not decode audio file."
        ) from err

    try:
        result = await run_in_threadpool(transcribe, audio_array, sr)
    finally:
        # Free intermediate arrays/tensors and cached CUDA blocks after every
        # transcription, including failed ones. Drop our own reference first so
        # the waveform is actually collectable.
        del audio_array
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

    return {"text": result}


@app.get("/")
def health():
    """Liveness probe: always reports the service as online."""
    return {"status": "online"}