File size: 2,354 Bytes
f896e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import AutoModelForCTC, Wav2Vec2Processor

# Hub checkpoint served by this app and the sample rate (Hz) handed to its processor.
MODEL_ID = "mohas8/wav2vec2-xlsr300m-shobdotori-ctc-lexicore"
TARGET_SR = 16_000

# Use the GPU when CUDA is available, otherwise stay on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# ---- Load processor & model once ----
# Both objects are created at import time and reused by every request.
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = AutoModelForCTC.from_pretrained(MODEL_ID)
model = model.to(DEVICE)
model.eval()  # evaluation mode: this app only runs inference

# ---- Optional simple normalization (lightweight) ----
def normalize_bangla_text(s: str) -> str:
    """Return *s* with newlines turned into spaces and outer whitespace removed.

    Any value that is not a ``str`` yields an empty string.
    """
    if isinstance(s, str):
        # Hook point: further Bangla-specific cleaning can be added here if needed.
        flattened = s.replace("\n", " ")
        return flattened.strip()
    return ""

# Sentence-final marks that fix_punct leaves untouched.
# "\u0964" is the danda (।), the Bangla full stop; it is written as an escape
# so the literal cannot be damaged by a wrong file encoding.
END_SET = {"\u0964", "!", "?", "."}

def fix_punct(s: str) -> str:
    """Ensure *s* ends with sentence-final punctuation.

    Args:
        s: Text to finish; surrounding whitespace is stripped first.

    Returns:
        The stripped text unchanged when it is empty or already ends with a
        character from ``END_SET``; otherwise the stripped text with a danda
        (``"\u0964"``) appended.
    """
    s = s.strip()
    if not s or s[-1] in END_SET:
        return s
    # Very simple rule: default to the danda.
    return s + "\u0964"

# ---- Gradio callback ----
def transcribe(audio):
    """Transcribe one Gradio audio payload using greedy CTC decoding.

    Args:
        audio: ``(sample_rate, samples)`` tuple as supplied by the Gradio
            audio input, or ``None`` when nothing was uploaded or recorded.

    Returns:
        The decoded text after ``normalize_bangla_text`` and ``fix_punct``,
        or ``""`` when there is no audio or the audio holds no samples.
    """
    if audio is None:
        return ""
    sr, wav = audio
    wav = np.array(wav, dtype=np.float32)

    # Multi-channel input: average over the last axis to obtain mono.
    if wav.ndim > 1:
        wav = wav.mean(axis=-1)

    # A zero-length clip has nothing to decode; return early instead of
    # handing an empty array to the resampler and the model.
    if wav.size == 0:
        return ""

    # Resample to the target rate if the input uses a different one.
    if sr != TARGET_SR:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)

    inputs = processor(
        wav,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(inputs.input_values.to(DEVICE)).logits

    # Greedy decode: highest-scoring token id at each position.
    pred_ids = torch.argmax(logits, dim=-1)
    text = processor.batch_decode(pred_ids)[0]
    text = normalize_bangla_text(text)
    text = fix_punct(text)
    return text

# ---- Gradio UI ----
# Single-function web UI: one audio input, one text output.
# NOTE(review): gr.Audio is not given a `sampling_rate` keyword here; that name
# does not appear among its documented constructor parameters (confirm against
# the installed Gradio version). The callback receives the clip's own rate and
# resamples to TARGET_SR itself.
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="numpy", label="Upload or record Bangla dialect audio"),
    outputs=gr.Textbox(label="Standard Bangla transcription"),
    title="LexiCore – Shobdotori ASR (Wav2Vec2 XLS-R 300M CTC)",
    description=(
        "Dialectal Bangla → Standard Bangla transcription using "
        "mohas8/wav2vec2-xlsr300m-shobdotori-ctc-lexicore (greedy CTC decode)."
    ),
)

# Start the Gradio app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()