"""Gradio app that scores a spoken word against reference text at the phoneme level.

Audio is transcribed with wav2vec2 (greedy CTC), snapped to the nearest WORD_MAP
word, converted to IPA, and compared with the reference text's IPA by edit distance.
"""
import logging
import re
import string
from itertools import zip_longest

import editdistance
import eng_to_ipa as ipa
import epitran
import gradio as gr
import librosa
import numpy as np
import orjson
import torch
from jiwer import wer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

logger = logging.getLogger(__name__)

# --- Device setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Audio / decoding constants ---
SAMPLE_RATE = 16000          # every audio path is converted to this rate before the processor
MIN_SPEECH_SAMPLES = 2400    # 0.15 s at 16 kHz; shorter trimmed clips count as silence
MAX_SPEECH_SAMPLES = 48000   # analyse at most the first 3 s of trimmed speech
MIN_MEAN_CONFIDENCE = 0.65   # mean per-frame max softmax prob below this => "unclear"

# --- WordMap ---
WORD_MAP = {
    'A': {'word': 'Apple', 'phonetic': 'ˈæpəl'},
    'B': {'word': 'Ball', 'phonetic': 'bɔːl'},
    'C': {'word': 'Cat', 'phonetic': 'kæt'},
    'D': {'word': 'Dog', 'phonetic': 'dɒɡ'},
    'E': {'word': 'Egg', 'phonetic': 'ɛɡ'},
    'F': {'word': 'Fish', 'phonetic': 'fɪʃ'},
    'G': {'word': 'Goat', 'phonetic': 'ɡoʊt'},
    'H': {'word': 'Hat', 'phonetic': 'hæt'},
    'I': {'word': 'Ice', 'phonetic': 'aɪs'},
    'J': {'word': 'Jar', 'phonetic': 'dʒɑːr'},
    'K': {'word': 'Kite', 'phonetic': 'kaɪt'},
    'L': {'word': 'Lion', 'phonetic': 'ˈlaɪən'},
    'M': {'word': 'Moon', 'phonetic': 'muːn'},
    'N': {'word': 'Nest', 'phonetic': 'nɛst'},
    'O': {'word': 'Orange', 'phonetic': 'ˈɔːrɪndʒ'},
    'P': {'word': 'Pen', 'phonetic': 'pɛn'},
    'Q': {'word': 'Queen', 'phonetic': 'kwiːn'},
    'R': {'word': 'Rabbit', 'phonetic': 'ˈræbɪt'},
    'S': {'word': 'Sun', 'phonetic': 'sʌn'},
    'T': {'word': 'Tree', 'phonetic': 'triː'},
    'U': {'word': 'Umbrella', 'phonetic': 'ʌmˈbrɛlə'},
    'V': {'word': 'Van', 'phonetic': 'væn'},
    'W': {'word': 'Watch', 'phonetic': 'wɒtʃ'},
    'X': {'word': 'Xylophone', 'phonetic': 'ˈzaɪləfoʊn'},
    'Y': {'word': 'Yarn', 'phonetic': 'jɑːrn'},
    'Z': {'word': 'Zebra', 'phonetic': 'ˈziːbrə'},
}

# Lowercase word -> hand-written IPA, in WORD_MAP order (used for O(1) lookup
# and as the candidate list for snapping transcriptions).
_WORD_TO_PHONETIC = {entry['word'].lower(): entry['phonetic'] for entry in WORD_MAP.values()}

_PUNCT_TABLE = str.maketrans('', '', string.punctuation)
# Primary and secondary IPA stress marks (U+02C8, U+02CC).
_STRESS_TABLE = str.maketrans('', '', 'ˈˌ')
# Whole-word "the" only; a plain substring replace would also mangle "there", "mother", ...
_FILLER_WORD_RE = re.compile(r"\bthe\b")

# --- Load model once at startup ---
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()
# NOTE(review): `epi` is not used anywhere in this file; kept only in case another
# module imports it. Drop it (and the epitran import) if nothing does.
epi = epitran.Epitran("eng-Latn")


# --- Helper functions ---
def clean_phonemes(ipa_text):
    """Remove punctuation and IPA stress marks from an IPA string.

    The word-character regex class keeps the stress marks because Python treats
    them as Unicode letters (category Lm), so they are stripped explicitly; they
    would otherwise count as extra symbols in the phoneme edit distance.
    Length marks (ː) are kept because they are part of the vowel.
    """
    return re.sub(r'[^\w\s]', '', ipa_text).translate(_STRESS_TABLE)


def transliterate_english(word):
    """Return the cleaned IPA for a single English word, or "" if none is available.

    WORD_MAP words use their hand-written phonetic; any other word falls back to
    ``eng_to_ipa.convert``. Punctuation is stripped and case is ignored.
    Best-effort: conversion failures are logged and yield "" so callers can drop
    the word.
    """
    try:
        word = word.lower().translate(_PUNCT_TABLE)
        if not word:
            return ""
        phonetic = _WORD_TO_PHONETIC.get(word)
        if phonetic is not None:
            return clean_phonemes(phonetic)
        return clean_phonemes(ipa.convert(word))
    except Exception:
        logger.warning("IPA conversion failed for %r", word, exc_info=True)
        return ""


def find_closest_word(transcription, reference_word):
    """Snap *transcription* to the nearest WORD_MAP word by character edit distance.

    Returns ``(closest_word_lowercase, similarity_percent)``, where similarity is
    ``(1 - distance / max(len(transcription), len(closest_word))) * 100`` rounded
    to 2 decimals. Ties go to the earliest word in WORD_MAP order.

    If *transcription* is empty, returns ``(reference_word, 100.0)``. Callers must
    reject empty transcriptions first or they will report a perfect match for
    silence.
    """
    if not transcription:
        return reference_word, 100.0
    transcription = transcription.lower().strip()
    distances = {word: editdistance.eval(transcription, word) for word in _WORD_TO_PHONETIC}
    closest_word = min(distances, key=distances.get)
    max_len = max(len(transcription), len(closest_word))
    similarity = round((1 - distances[closest_word] / max(1, max_len)) * 100, 2)
    return closest_word, similarity


def _status_response(language, reference_text, transcription, message):
    """Serialise a result that has no alignment, only a status *message*."""
    return orjson.dumps({
        "language": language,
        "reference_text": reference_text,
        "transcription": transcription,
        "word_alignment": [],
        "metrics": {"message": message},
    }).decode()


def _load_audio(audio_input):
    """Return *audio_input* as mono float32 at SAMPLE_RATE, peak-normalised to [-1, 1].

    Accepts Gradio's ``type="numpy"`` value ``(sample_rate, data)``, where data is
    ``(samples,)`` or ``(samples, channels)`` and may be integer PCM. The
    ``(data, sample_rate)`` order is also accepted for backward compatibility.
    Anything else is passed to ``librosa.load`` (e.g. a file path).
    Returns None when no audio was supplied.
    """
    if audio_input is None:
        return None
    if isinstance(audio_input, (tuple, list)):
        first, second = audio_input[0], audio_input[1]
        sr, audio = (first, second) if np.isscalar(first) else (second, first)
        audio = np.asarray(audio, dtype=np.float32)
        if audio.ndim == 2:
            audio = audio.mean(axis=1)  # (samples, channels) -> mono
        if audio.size and sr != SAMPLE_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
    else:
        audio, _ = librosa.load(audio_input, sr=SAMPLE_RATE, mono=True)
    audio = np.asarray(audio, dtype=np.float32)
    if audio.size == 0:
        return audio
    return audio / max(1e-9, float(np.max(np.abs(audio))))


def _transcribe(audio):
    """Greedy-CTC decode 16 kHz *audio*; return (lowercase text, mean frame confidence).

    Confidence is the mean over frames of the highest softmax probability.
    """
    input_values = processor(
        audio, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True
    ).input_values.to(device)
    with torch.no_grad():
        logits = model(input_values).logits
    pred_ids = torch.argmax(logits, dim=-1)
    text = processor.batch_decode(pred_ids)[0].strip().lower()
    confidence = torch.softmax(logits, dim=-1).max(dim=-1).values.mean().item()
    return text, confidence


def _to_phoneme_lists(words):
    """Transliterate each word once, dropping words that yield no phonemes."""
    return [list(phonemes) for phonemes in map(transliterate_english, words) if phonemes]


def _align_words(ref_phonemes, obs_phonemes):
    """Pair reference/observed words by position and score each by phoneme edit distance.

    A missing observed word counts as deleting every reference phoneme. An extra
    observed word counts as inserting all of its phonemes (empty reference).
    Returns ``(alignment_entries, total_errors, total_ref_phonemes, correct_words)``.
    """
    alignment = []
    total_errors = total_length = correct_words = 0
    for i, (ref, obs) in enumerate(zip_longest(ref_phonemes, obs_phonemes, fillvalue=[])):
        edits = editdistance.eval(ref, obs)
        accuracy = round(max(0.0, 1 - edits / max(1, len(ref))) * 100, 2)
        alignment.append({
            "word_index": i,
            "reference_phonemes": ''.join(ref),
            "observed_phonemes": ''.join(obs),
            "edit_distance": edits,
            "accuracy": accuracy,
            "is_correct": edits == 0,
        })
        total_errors += edits
        total_length += len(ref)
        correct_words += int(edits == 0)
    return alignment, total_errors, total_length, correct_words


# --- Main analysis function ---
def analyze_phonemes(language, reference_text, audio_input):
    """Score *audio_input* against *reference_text* and return the result as a JSON string.

    Args:
        language: Echoed back in the result (only "English" is offered by the UI).
        reference_text: Expected word(s).
        audio_input: Gradio numpy audio ``(sample_rate, data)``, or a path for
            ``librosa.load``.

    Returns:
        JSON with ``language``, ``reference_text``, ``transcription``,
        ``word_alignment`` and ``metrics``. Problems (no audio, silence, low
        confidence, empty reference, unexpected errors) are reported through
        ``metrics.message`` rather than raised.
    """
    try:
        # Lowercase, drop punctuation and collapse whitespace so WER is not
        # inflated by casing ("Apple" vs "apple") or punctuation.
        reference_norm = " ".join((reference_text or "").lower().translate(_PUNCT_TABLE).split())
        if not reference_norm:
            return _status_response(language, reference_text, "", "Reference text is empty.")

        audio = _load_audio(audio_input)
        if audio is None:
            return _status_response(language, reference_text, "No speech detected",
                                    "No audio provided.")
        trimmed_audio = librosa.effects.trim(audio, top_db=25)[0] if audio.size else audio
        if len(trimmed_audio) < MIN_SPEECH_SAMPLES:
            return _status_response(language, reference_text, "No speech detected",
                                    "Audio too short or silent.")

        transcription, confidence = _transcribe(trimmed_audio[:MAX_SPEECH_SAMPLES])
        spoken = " ".join(_FILLER_WORD_RE.sub(" ", transcription).split())
        # Reject before snapping: find_closest_word("") would report a 100% match.
        if confidence < MIN_MEAN_CONFIDENCE or not spoken:
            return _status_response(language, reference_text, "No speech detected",
                                    "Unclear or noisy speech.")

        transcription_clean, similarity = find_closest_word(spoken, reference_norm)

        ref_phonemes = _to_phoneme_lists(reference_norm.split())
        obs_phonemes = _to_phoneme_lists(transcription_clean.split())
        alignment, total_errors, total_length, correct_words = _align_words(
            ref_phonemes, obs_phonemes)

        total_words = len(ref_phonemes)
        phoneme_er = round(total_errors / max(1, total_length) * 100, 2)
        phoneme_acc = round(max(0.0, 1 - total_errors / max(1, total_length)) * 100, 2)
        word_acc = round(correct_words / max(1, total_words) * 100, 2)
        word_er = round((total_words - correct_words) / max(1, total_words) * 100, 2)
        text_wer = round(wer(reference_norm, transcription_clean) * 100, 2)

        results = {
            "language": language,
            "reference_text": reference_text,
            "transcription": transcription_clean,
            "word_alignment": alignment,
            "metrics": {
                "similarity": similarity,
                "word_accuracy": word_acc,
                "word_error_rate": word_er,
                "phoneme_accuracy": phoneme_acc,
                "phoneme_error_rate": phoneme_er,
                "asr_word_error_rate": text_wer,
            },
        }
        return orjson.dumps(results).decode()
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback in logs.
        logger.exception("Phoneme analysis failed")
        return _status_response(language, reference_text, "Error processing audio",
                                f"Error: {str(e)}")


# --- Gradio UI ---
def get_default_text(language):
    """Return the default reference text for *language*: "A" for English, otherwise "".

    NOTE(review): "A" is a WORD_MAP key, not a WORD_MAP word. Transcriptions are
    snapped to words such as "apple", and "a" is transliterated via eng_to_ipa, so
    the default reference will rarely match. Confirm whether WORD_MAP['A']['word']
    was intended.
    """
    return "A" if language == "English" else ""


with gr.Blocks() as demo:
    gr.Markdown("# Multilingual Phoneme Alignment Analysis")
    gr.Markdown("Compare audio pronunciation with reference text at phoneme level.")
    with gr.Row():
        language = gr.Dropdown(["English"], label="Language", value="English")
        reference_text = gr.Textbox(label="Reference Text", value="A")
    audio_input = gr.Audio(label="Record Audio", type="numpy")
    submit_btn = gr.Button("Analyze")
    output = gr.JSON(label="Phoneme Alignment Results")
    language.change(fn=get_default_text, inputs=language, outputs=reference_text)
    submit_btn.click(fn=analyze_phonemes,
                     inputs=[language, reference_text, audio_input],
                     outputs=output)

if __name__ == "__main__":
    demo.launch()