Spaces:
Build error
Build error
| import gradio as gr | |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor | |
| import librosa | |
| import torch | |
| import epitran | |
| import re | |
| import difflib | |
| import editdistance | |
| from jiwer import wer | |
| import json | |
| import string | |
| import eng_to_ipa as ipa | |
# Checkpoint name and Epitran language code for every supported language.
# English uses the lighter base checkpoint to improve speed.
_MODEL_SPECS = {
    "Arabic": ("jonatasgrosman/wav2vec2-large-xlsr-53-arabic", "ara-Arab"),
    "English": ("facebook/wav2vec2-base-960h", "eng-Latn"),
}

# Everything is loaded eagerly at import time, one bundle per language:
#   "processor" -> Wav2Vec2Processor, "model" -> Wav2Vec2ForCTC,
#   "epitran"   -> Epitran transliterator.
MODELS = {
    _language: {
        "processor": Wav2Vec2Processor.from_pretrained(_checkpoint),
        "model": Wav2Vec2ForCTC.from_pretrained(_checkpoint),
        "epitran": epitran.Epitran(_epitran_code),
    }
    for _language, (_checkpoint, _epitran_code) in _MODEL_SPECS.items()
}

# NOTE(review): per the Transformers docs this setting only matters when a CTC
# loss is computed (labels passed to the model); confirm it is still needed.
for _bundle in MODELS.values():
    _bundle["model"].config.ctc_loss_reduction = "mean"
def clean_phonemes(ipa_text):
    """Return ``ipa_text`` without Arabic diacritics or the IPA length mark.

    Every code point in the range U+064B..U+0652 and every U+02D0 is
    removed; all other characters are kept in their original order.
    """
    marks_to_drop = re.compile(r'[\u064B-\u0652\u02D0]')
    return marks_to_drop.sub('', ipa_text)
def safe_transliterate_arabic(epi, word):
    """Transliterate one Arabic word to IPA, returning "" on any failure.

    Args:
        epi: Object exposing ``transliterate(str) -> str``.
        word: The word to convert; surrounding whitespace is stripped first.

    Returns:
        The transliteration passed through ``clean_phonemes``. If the
        transliteration is blank, or anything raises along the way, a
        warning is printed and an empty string is returned instead.
    """
    try:
        word = word.strip()
        transcribed = epi.transliterate(word)
        if transcribed.strip():
            return clean_phonemes(transcribed)
        # A blank result is routed through the same warning path as an error.
        raise ValueError("Empty IPA string")
    except Exception as e:
        print(f"[Warning] Arabic transliteration failed for '{word}': {e}")
        return ""
def transliterate_english(word):
    """Convert one English word to IPA, returning "" on any failure.

    The word is lower-cased and stripped of ASCII punctuation, converted
    with ``ipa.convert`` and then passed through ``clean_phonemes``. Any
    exception is reported on stdout and results in an empty string.
    """
    try:
        drop_punctuation = str.maketrans('', '', string.punctuation)
        word = word.lower().translate(drop_punctuation)
        return clean_phonemes(ipa.convert(word))
    except Exception as e:
        print(f"[Warning] English IPA conversion failed for '{word}': {e}")
        return ""
def _normalize_for_wer(text):
    """Case-fold, drop ASCII punctuation and collapse whitespace for WER."""
    without_punctuation = text.casefold().translate(
        str.maketrans('', '', string.punctuation)
    )
    return ' '.join(without_punctuation.split())


def _score_word(index, ref, obs):
    """Build the alignment record for one reference/observed phoneme pair."""
    edits = editdistance.eval(ref, obs)
    opcodes = difflib.SequenceMatcher(None, ref, obs).get_opcodes()
    errors = [
        {
            "type": tag.upper(),
            "reference": ''.join(ref[i1:i2]) or '-',
            "observed": ''.join(obs[j1:j2]) or '-',
        }
        for tag, i1, i2, j1, j2 in opcodes
        if tag != 'equal'
    ]
    return {
        "word_index": index,
        "reference_phonemes": ''.join(ref),
        "observed_phonemes": ''.join(obs),
        "edit_distance": edits,
        "accuracy": round((1 - edits / max(1, len(ref))) * 100, 2),
        "is_correct": edits == 0,
        "errors": errors,
    }


def analyze_phonemes(language, reference_text, audio_file, max_duration=1.5):
    """Compare a recording with a reference text, word by word, at phoneme level.

    The audio is transcribed with the language's Wav2Vec2 CTC model. Both the
    reference text and the transcription are split on whitespace and each
    word is converted to IPA. Words are then paired by position and scored
    with an edit distance over IPA characters.

    Args:
        language: Key into ``MODELS`` ("Arabic" or "English").
        reference_text: The text the speaker was supposed to say.
        audio_file: Path of the recording to analyse.
        max_duration: Maximum number of seconds of audio to analyse; longer
            recordings are cut to this length. ``None`` disables trimming.
            Defaults to 1.5, so longer utterances are only partially scored.

    Returns:
        A JSON string with ``language``, ``reference_text``,
        ``transcription``, a ``word_alignment`` list (one entry per reference
        word) and a ``metrics`` dict. ``asr_word_error_rate`` is ``null``
        when the reference contains no words.

    Raises:
        KeyError: If ``language`` is not a key of ``MODELS``.
        ValueError: If no audio file is given, ``max_duration`` is not
            positive, or the (trimmed) audio contains no samples.
    """
    if not audio_file:
        # e.g. the Analyze button was pressed before anything was uploaded.
        raise ValueError("No audio file was provided for analysis")
    if max_duration is not None and max_duration <= 0:
        raise ValueError("max_duration must be positive or None")

    reference_text = reference_text or ""
    lang_models = MODELS[language]
    processor = lang_models["processor"]
    model = lang_models["model"]

    if language == "Arabic":
        epi = lang_models["epitran"]

        # safe_transliterate_arabic needs the Epitran instance as well as the
        # word; bind it here so both languages share a one-argument callable.
        def transliterate_fn(word):
            return safe_transliterate_arabic(epi, word)
    else:
        transliterate_fn = transliterate_english

    ref_phonemes = [list(transliterate_fn(word)) for word in reference_text.split()]

    # Load at 16 kHz and optionally keep only the first max_duration seconds.
    audio, sr = librosa.load(audio_file, sr=16000)
    if max_duration is not None:
        audio = audio[:int(sr * max_duration)]
    if len(audio) == 0:
        raise ValueError("The audio file contains no samples")

    input_values = processor(audio, sampling_rate=sr, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    pred_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(pred_ids)[0].strip()
    obs_phonemes = [list(transliterate_fn(word)) for word in transcription.split()]

    # Pad the observation so every reference word is scored: a word the
    # recogniser never produced counts as all of its phonemes deleted instead
    # of silently vanishing from the phoneme totals. Observed words beyond the
    # reference length are still ignored here; the ASR WER accounts for them.
    missing_words = max(0, len(ref_phonemes) - len(obs_phonemes))
    padded_obs = obs_phonemes + [[]] * missing_words

    word_alignment = [
        _score_word(index, ref, obs)
        for index, (ref, obs) in enumerate(zip(ref_phonemes, padded_obs))
    ]

    total_word_length = len(ref_phonemes)
    total_phoneme_length = sum(len(ref) for ref in ref_phonemes)
    total_phoneme_errors = sum(entry["edit_distance"] for entry in word_alignment)
    correct_words = sum(1 for entry in word_alignment if entry["is_correct"])

    phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
    phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
    word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
    word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)

    # Compare text case- and punctuation-insensitively: the recogniser's
    # letter case may differ from the typed reference (NOTE(review): the
    # English checkpoint's model card shows upper-case output), which would
    # otherwise count every word as an error.
    wer_reference = _normalize_for_wer(reference_text)
    wer_hypothesis = _normalize_for_wer(transcription)
    if wer_reference:
        text_wer = round(wer(wer_reference, wer_hypothesis) * 100, 2)
    else:
        # WER is undefined without reference words; report null, not a crash.
        text_wer = None

    results = {
        "language": language,
        "reference_text": reference_text,
        "transcription": transcription,
        "word_alignment": word_alignment,
        "metrics": {
            "word_accuracy": word_acc,
            "word_error_rate": word_er,
            "phoneme_accuracy": phoneme_acc,
            "phoneme_error_rate": phoneme_er,
            "asr_word_error_rate": text_wer,
        },
    }
    return json.dumps(results, indent=2, ensure_ascii=False)
def get_default_text(language):
    """Return the sample reference sentence for ``language``.

    Args:
        language: Language name, "Arabic" or "English".

    Returns:
        The default reference text for that language, or "" when the
        language is not recognised.
    """
    defaults = {
        # Fully vocalised Arabic sample sentence, restored from a mis-decoded
        # (mojibake) literal; the file must be saved as UTF-8.
        "Arabic": "فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ",
        "English": "The quick brown fox jumps over the lazy dog",
    }
    return defaults.get(language, "")
# Gradio front end: pick a language, edit the reference text, upload a
# recording and view the phoneme alignment report as JSON.
with gr.Blocks() as demo:
    gr.Markdown("# Multilingual Phoneme Alignment Analysis")
    gr.Markdown("Compare audio pronunciation with reference text at phoneme level")

    # Language selector and reference text share one row.
    with gr.Row():
        language = gr.Dropdown(
            ["Arabic", "English"],
            label="Language",
            value="Arabic",
        )
        reference_text = gr.Textbox(
            label="Reference Text",
            value=get_default_text("Arabic"),
        )

    audio_input = gr.Audio(label="Upload Audio File", type="filepath")
    submit_btn = gr.Button("Analyze")
    output = gr.JSON(label="Phoneme Alignment Results")

    # Changing the language swaps in that language's sample sentence.
    language.change(
        get_default_text,
        inputs=[language],
        outputs=[reference_text],
        api_name="/get_default_text",
    )

    # The button runs the full analysis and shows its JSON report.
    submit_btn.click(
        analyze_phonemes,
        inputs=[language, reference_text, audio_input],
        outputs=[output],
        api_name="/analyze_phonemes",
    )

demo.launch()