Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import librosa | |
| from transformers import AutoModelForCTC, Wav2Vec2Processor | |
MODEL_ID = "mohas8/wav2vec2-xlsr300m-shobdotori-ctc-lexicore"  # Hugging Face Hub repo id
TARGET_SR = 16_000  # sample rate (Hz) handed to the processor; transcribe() resamples to it

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---- Load processor & model a single time at import; every request reuses them ----
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = AutoModelForCTC.from_pretrained(MODEL_ID)
model.to(DEVICE)
model.eval()  # evaluation mode for inference
# ---- Optional simple normalization (lightweight) ----
def normalize_bangla_text(s: str) -> str:
    """Apply lightweight cleanup to decoded transcription text.

    Every newline becomes a single space, then surrounding whitespace is
    stripped. Anything that is not a ``str`` yields ``""``.

    This is the place to add further Bangla-specific cleaning if needed.
    """
    if isinstance(s, str):
        return " ".join(s.split("\n")).strip()
    return ""
# Characters that already terminate a sentence. "\u0964" is the danda (U+0964),
# the Bangla full stop. It is written as an escape so the literal stays correct
# even if this file is re-encoded (a mis-decoded copy turned it into "ΰ₯€").
END_SET = {"\u0964", "!", "?", "."}


def fix_punct(s: str) -> str:
    """Ensure the text ends with sentence-final punctuation.

    Leading/trailing whitespace is stripped first. Empty or all-whitespace
    input is returned as ``""``. If the last character is already in
    ``END_SET``, the stripped text is returned unchanged. Otherwise a danda
    (U+0964) is appended as the default terminator.
    """
    s = s.strip()
    if not s or s[-1] in END_SET:
        return s
    # Very simple rule: default to the danda.
    return s + "\u0964"
# ---- Gradio callback ----
def _to_mono_float32(wav) -> np.ndarray:
    """Convert a Gradio waveform to a 1-D float32 signal (integer PCM scaled to [-1, 1))."""
    arr = np.asarray(wav)
    if np.issubdtype(arr.dtype, np.integer):
        # Integer PCM (e.g. int16): map the dtype's full range onto [-1, 1)
        # instead of feeding raw magnitudes like 32767 to the feature extractor.
        # For signed types the midpoint is 0; for unsigned it removes the offset.
        info = np.iinfo(arr.dtype)
        midpoint = (int(info.max) + int(info.min) + 1) / 2
        half_range = (int(info.max) - int(info.min) + 1) / 2
        arr = (arr.astype(np.float32) - midpoint) / half_range
    else:
        arr = arr.astype(np.float32, copy=False)
    # stereo -> mono: average over the last axis (assumes channels-last layout)
    if arr.ndim > 1:
        arr = arr.mean(axis=-1)
    return arr


def transcribe(audio):
    """Transcribe one Gradio audio input to standard Bangla text.

    Args:
        audio: ``None`` when nothing was provided, otherwise the
            ``(sample_rate, waveform)`` tuple Gradio passes for a numpy-typed
            ``gr.Audio`` input. The waveform may be integer PCM or float, and
            either mono or multi-channel (channels on the last axis).

    Returns:
        The greedy-CTC transcription after normalization and final-punctuation
        fixing, or ``""`` when the audio is missing or has no samples.
    """
    if audio is None:
        return ""
    sr, wav = audio
    wav = _to_mono_float32(wav)
    # A zero-length recording would fail inside the feature extractor / model.
    if wav.size == 0:
        return ""
    # resample to the rate the processor is told to expect
    if sr != TARGET_SR:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)
    inputs = processor(
        wav,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        logits = model(inputs.input_values.to(DEVICE)).logits
    # Greedy CTC decode: take the most likely token id per frame, then let the
    # processor turn the id sequence into text.
    pred_ids = torch.argmax(logits, dim=-1)
    text = processor.batch_decode(pred_ids)[0]
    text = normalize_bangla_text(text)
    return fix_punct(text)
# ---- Gradio UI ----
# gr.Audio is created without a sampling_rate argument. Recent Gradio releases'
# gr.Audio has no such parameter, so passing it is a likely cause of the Space's
# "Runtime error" (NOTE(review): confirm against the pinned Gradio version).
# It is also unnecessary: transcribe() resamples to TARGET_SR itself.
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="numpy", label="Upload or record Bangla dialect audio"),
    outputs=gr.Textbox(label="Standard Bangla transcription"),
    title="LexiCore — Shobdotori ASR (Wav2Vec2 XLS-R 300M CTC)",
    description=(
        "Dialectal Bangla → Standard Bangla transcription using "
        f"{MODEL_ID} (greedy CTC decode)."
    ),
)
def main() -> None:
    """Start the Gradio server for ``demo``."""
    demo.launch()


# Launch only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()