"""Gradio demo: dialectal Bangla speech -> Standard Bangla text.

Loads a Wav2Vec2 CTC checkpoint once at import time and exposes a single
``transcribe`` callback that is wired into a ``gr.Interface``.
"""

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoModelForCTC, Wav2Vec2Processor

MODEL_ID = "mohas8/wav2vec2-xlsr300m-shobdotori-ctc-lexicore"
TARGET_SR = 16_000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Load processor & model once ----
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = AutoModelForCTC.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()


# ---- Optional simple normalization (lightweight) ----
def normalize_bangla_text(s: str) -> str:
    """Return *s* with newlines turned into spaces and outer whitespace removed.

    Any non-string input yields an empty string.
    """
    if not isinstance(s, str):
        return ""
    # Further Bangla-specific cleaning can be added here if needed.
    return s.replace("\n", " ").strip()


# Characters accepted as an existing sentence terminator.
END_SET = {"।", "!", "?", "."}


def fix_punct(s: str) -> str:
    """Ensure *s* ends with sentence-final punctuation.

    The text is stripped first; an empty result is returned unchanged. If the
    last character is not in ``END_SET``, a danda ('।') is appended.
    """
    s = s.strip()
    if not s:
        return s
    if s[-1] in END_SET:
        return s
    # Very simple rule: default to the danda.
    return s + "।"


def _to_mono_float32(wav) -> np.ndarray:
    """Convert a raw waveform to a 1-D float32 array.

    Signed-integer PCM is rescaled to roughly [-1, 1]; multi-channel input is
    averaged over its last axis.
    """
    samples = np.asarray(wav)
    if np.issubdtype(samples.dtype, np.signedinteger):
        # Integer PCM carries full-scale amplitudes (e.g. +/-32767 for int16);
        # bring it into the float range before any further processing.
        full_scale = float(np.iinfo(samples.dtype).max)
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32)
    # stereo -> mono (assumes channels are on the last axis — TODO confirm
    # against the Gradio version in use)
    if samples.ndim > 1:
        samples = samples.mean(axis=-1)
    return samples


# ---- Gradio callback ----
def transcribe(audio) -> str:
    """Transcribe one audio clip with greedy CTC decoding.

    Args:
        audio: ``(sample_rate, waveform)`` tuple as delivered by the audio
            input component, or ``None`` when nothing was provided.

    Returns:
        The decoded text after ``normalize_bangla_text`` and ``fix_punct``,
        or ``""`` when there is no audio to process.
    """
    if audio is None:
        return ""

    sr, wav = audio
    wav = _to_mono_float32(wav)

    # A zero-length recording has nothing to decode; skip the model call.
    if wav.size == 0:
        return ""

    # resample to 16k if needed
    if sr != TARGET_SR:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)

    inputs = processor(
        wav,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        logits = model(inputs.input_values.to(DEVICE)).logits

    # Greedy decode: most likely token per frame, collapsed by the processor.
    pred_ids = torch.argmax(logits, dim=-1)
    text = processor.batch_decode(pred_ids)[0]

    text = normalize_bangla_text(text)
    text = fix_punct(text)
    return text


# ---- Gradio UI ----
# No sample-rate argument is passed to the audio component: transcribe()
# resamples whatever rate it receives to TARGET_SR.
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(
        type="numpy",
        label="Upload or record Bangla dialect audio",
    ),
    outputs=gr.Textbox(label="Standard Bangla transcription"),
    title="LexiCore – Shobdotori ASR (Wav2Vec2 XLS-R 300M CTC)",
    description=(
        "Dialectal Bangla → Standard Bangla transcription using "
        "mohas8/wav2vec2-xlsr300m-shobdotori-ctc-lexicore (greedy CTC decode)."
    ),
)

if __name__ == "__main__":
    demo.launch()