"""Cognia Wav2Vec2 speech API.

Exposes a FastAPI application that transcribes an uploaded audio clip with a
Wav2Vec2 CTC model, matches the transcription against a target word using the
``word_syllables`` dictionary, and returns a pronunciation score.
"""

import hashlib
import logging
import os
import re
import tempfile
import time
from datetime import datetime
from difflib import SequenceMatcher

import librosa
import torch
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile, Form, HTTPException
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

from word_syllables import word_syllables

# Configure logging: every record goes both to a log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('speech_api.log'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Cognia Wav2Vec2 Speech API")

MODEL_DIR = "Artyomorax/cognia-wav2vec"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = 16000  # sampling rate passed to both librosa.load and the processor

# Below this match score the whole dictionary is scanned for a closer word.
DICTIONARY_FALLBACK_THRESHOLD = 0.6
# Longest run of consecutive transcription words that is glued together and
# tried as a single candidate.
MAX_COMBO_WORDS = 3

logger.info(f"🚀 Starting API with DEVICE: {DEVICE}")
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()
logger.info("✅ Model loaded successfully")


def _normalize_phonetic(text):
    """Return *text* lower-cased with all spaces and hyphens removed."""
    return text.replace(" ", "").replace("-", "").lower()


def syllables_to_phonetic(syllables):
    """Convert a list of syllables to one normalized string for comparison."""
    return _normalize_phonetic("".join(syllables))


def find_closest_word_in_dictionary(transcription):
    """Find the dictionary word whose syllables best match *transcription*.

    Every entry of ``word_syllables`` is compared against the normalized
    transcription with ``SequenceMatcher``; the first entry with the highest
    ratio wins.

    Returns:
        A ``(word, syllables, confidence)`` tuple. When nothing scores above
        zero (for example an empty transcription) the result is
        ``(None, [], 0.0)``.
    """
    trans_phonetic = _normalize_phonetic(transcription)

    best_word = None
    best_syllables = []
    best_score = 0.0

    for word, syllables in word_syllables.items():
        word_phonetic = syllables_to_phonetic(syllables)
        # Argument order is kept as (transcription, word): ratio() is not
        # guaranteed to be symmetric.
        score = SequenceMatcher(None, trans_phonetic, word_phonetic).ratio()
        if score > best_score:
            best_score = score
            best_word = word
            best_syllables = syllables

    logger.info(f"🔍 Dictionary match: '{trans_phonetic}' → '{best_word}' ({best_score:.2f})")
    return best_word, best_syllables, best_score


def phonetic_similarity(s1, s2):
    """Return the fuzzy similarity (0.0-1.0) of two strings.

    Spaces, hyphens and letter case are ignored.
    """
    return SequenceMatcher(
        None, _normalize_phonetic(s1), _normalize_phonetic(s2)
    ).ratio()


def syllable_similarity(syll1, syll2):
    """Compare two syllable lists and return a similarity score.

    The score is 80% string similarity of the joined syllables plus 20%
    agreement on the number of syllables.
    """
    s1_phonetic = syllables_to_phonetic(syll1)
    s2_phonetic = syllables_to_phonetic(syll2)

    # Overall phonetic match.
    phonetic_score = SequenceMatcher(None, s1_phonetic, s2_phonetic).ratio()

    # Bonus for the same number of syllables; the trailing 1 in max() avoids
    # a division by zero when both lists are empty.
    count_match = 1.0 - abs(len(syll1) - len(syll2)) / max(len(syll1), len(syll2), 1)

    return (phonetic_score * 0.8) + (count_match * 0.2)


def find_best_match(transcription, target_word):
    """Find the part of *transcription* that best corresponds to *target_word*.

    Strategy, in order:

    1. If the target is not in ``word_syllables``, give up early and return
       the whole transcription with a fixed confidence of 0.5.
    2. Substring match of the target in the transcription (confidence 1.0).
    3. Same substring match with all spaces removed (confidence 0.95).
    4. Score every single word and every run of up to ``MAX_COMBO_WORDS``
       consecutive words, via the dictionary when the candidate is a known
       word and via plain string similarity otherwise.
    5. If the best score is still below ``DICTIONARY_FALLBACK_THRESHOLD``,
       scan the whole dictionary for the closest word.

    Returns:
        A ``(matched_segment, syllables, confidence)`` tuple.
    """
    trans_clean = transcription.replace("-", " ").lower().strip()
    target_clean = target_word.lower().strip()

    # Get target syllables from dictionary.
    target_sylls = word_syllables.get(target_clean, [])
    if not target_sylls:
        logger.warning(f"⚠️ Target word '{target_clean}' not found in dictionary")
        return trans_clean, trans_clean.split(), 0.5

    # Direct match.
    if target_clean in trans_clean:
        logger.info(f"✓ Direct match found: '{target_clean}'")
        return target_clean, target_sylls, 1.0

    # Try to find the word with spaces removed.
    trans_nospace = trans_clean.replace(" ", "")
    target_nospace = target_clean.replace(" ", "")
    if target_nospace in trans_nospace:
        logger.info(f"✓ No-space match found: '{target_nospace}'")
        return target_clean, target_sylls, 0.95

    words = trans_clean.split()
    best_match = ""
    best_syllables = []
    best_score = 0.0

    # Check single words against the dictionary.
    for word in words:
        if word in word_syllables:
            word_sylls = word_syllables[word]
            score = syllable_similarity(word_sylls, target_sylls)
            if score > best_score:
                best_score = score
                best_match = word
                best_syllables = word_sylls
                logger.info(f" Single word match: '{word}' (score: {score:.2f})")

    # Check consecutive word combinations (runs of 1..MAX_COMBO_WORDS words).
    for i in range(len(words)):
        for j in range(i + 1, min(i + MAX_COMBO_WORDS + 1, len(words) + 1)):
            combo = "".join(words[i:j])
            combo_display = " ".join(words[i:j])

            if combo in word_syllables:
                # The glued-together words form a known dictionary word.
                combo_sylls = word_syllables[combo]
                score = syllable_similarity(combo_sylls, target_sylls)
                if score > best_score:
                    best_score = score
                    best_match = combo_display
                    best_syllables = combo_sylls
                    logger.info(f" Combo match: '{combo}' (score: {score:.2f})")
            else:
                # Phonetic comparison even if not in dictionary; the words
                # themselves stand in for syllables.
                score = phonetic_similarity(combo, target_clean)
                if score > best_score:
                    best_score = score
                    best_match = combo_display
                    best_syllables = combo_display.split()
                    logger.info(f" Phonetic match: '{combo}' (score: {score:.2f})")

    # If no good match found, search the entire dictionary.
    if best_score < DICTIONARY_FALLBACK_THRESHOLD:
        logger.info(" Searching entire dictionary for closest match...")
        dict_word, dict_sylls, dict_score = find_closest_word_in_dictionary(trans_clean)
        if dict_score > best_score:
            best_score = dict_score
            best_match = dict_word
            best_syllables = dict_sylls
            logger.info(f" Dictionary match used: '{dict_word}' (score: {dict_score:.2f})")

    logger.info(f"🎯 Best match for '{target_clean}': '{best_match}' (score: {best_score:.2f})")
    logger.info(f" Matched syllables: {best_syllables}")

    return best_match, best_syllables, best_score


def get_embedding(syllables):
    """Return the mean ``lm_head`` weight row for the tokens of *syllables*.

    Each syllable is tokenized without special tokens; the rows of
    ``model.lm_head.weight`` selected by the resulting token ids are averaged
    into a single vector.

    Returns:
        A 1-D tensor on ``DEVICE``, or ``None`` when no token ids were
        produced (for example for an empty syllable list).
    """
    logger.debug(f" 📊 Computing embedding for syllables: {syllables}")

    token_ids = []
    for syllable in syllables:
        token_ids.extend(
            processor.tokenizer(syllable, add_special_tokens=False).input_ids
        )

    if not token_ids:
        logger.warning(" ⚠️ No token IDs generated")
        return None

    # Inference only: without no_grad() indexing a parameter would record an
    # autograd graph on every request.
    with torch.no_grad():
        tokens = torch.tensor(token_ids, device=DEVICE)
        embeds = model.lm_head.weight[tokens]
        mean_embed = embeds.mean(dim=0)

    logger.debug(f" 📊 Embedding shape: {mean_embed.shape}, mean value: {mean_embed.mean().item():.4f}")
    return mean_embed


def calculate_pronunciation_score(cosine_sim, match_score):
    """Combine embedding similarity and transcription match into one score.

    The cosine similarity is weighted 70% and the match score 30%; the result
    is then scaled down when the transcription match is poor.

    Returns:
        A float clamped to the range [0.0, 1.0].
    """
    # Weight the cosine similarity more heavily.
    combined_score = (cosine_sim * 0.7) + (match_score * 0.3)

    # Apply penalties for very poor matches.
    if match_score < 0.5:
        combined_score *= 0.7  # Heavy penalty
    elif match_score < 0.7:
        combined_score *= 0.85  # Moderate penalty

    # Cosine similarity can be negative, which would otherwise leak a
    # negative "score" into the response.
    return max(0.0, min(1.0, combined_score))


def _feedback_for_score(pronunciation_score):
    """Map a pronunciation score to a ``(feedback, color)`` pair."""
    if pronunciation_score >= 0.85:
        return "Excellent pronunciation!", "green"
    if pronunciation_score >= 0.70:
        return "Good pronunciation, minor improvements possible", "blue"
    if pronunciation_score >= 0.55:
        return "Fair pronunciation, needs practice", "orange"
    return "Needs significant improvement", "red"


@app.post("/analyze")
async def analyze_audio(audio_file: UploadFile, target_word: str = Form(...)):
    """Score how well the uploaded audio pronounces *target_word*.

    The upload is written to a temporary file, loaded at ``SR`` Hz,
    transcribed by the model, matched against the target word and scored.

    Returns:
        A JSON-serializable dict with the transcription, match details,
        scores, feedback, timings and metadata.

    Raises:
        HTTPException: 400 when the uploaded file is empty.
        Exception: any failure while loading or processing the audio is
            logged and re-raised.
    """
    # NOTE: audio loading and model inference below are synchronous calls
    # made from an async handler, so they run on the event loop.
    request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    start_time = datetime.now()
    # Durations use a monotonic clock; start_time is only for display.
    started = time.perf_counter()

    logger.info(f"\n{'='*80}")
    logger.info(f"📥 NEW REQUEST [{request_id}]")
    logger.info(f"⏰ Timestamp: {start_time.isoformat()}")
    logger.info(f"📁 File: {audio_file.filename} ({audio_file.content_type})")
    logger.info(f"🎯 Target word: '{target_word}'")

    # Read the upload before creating the temp file so that a failed read
    # cannot leave a stray file behind.
    content = await audio_file.read()
    file_size = len(content)
    if not content:
        logger.warning(f"⚠️ Empty upload in request [{request_id}]")
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    file_hash = hashlib.md5(content).hexdigest()
    logger.info(f"📦 File size: {file_size} bytes")
    logger.info(f"🔐 File MD5: {file_hash}")

    tmp_file_path = None
    try:
        # Save uploaded file to a temporary location. delete=False because
        # the file is reopened by path after this handle is closed; the
        # finally block below removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(content)

        # Load audio.
        logger.info(f"🎵 Loading audio from: {tmp_file_path}")
        load_start = time.perf_counter()
        audio, sr = librosa.load(tmp_file_path, sr=SR)
        load_duration = time.perf_counter() - load_start
        logger.info(f"✅ Audio loaded in {load_duration:.3f}s")
        logger.info(f"🎵 Audio shape: {audio.shape}, duration: {len(audio)/SR:.2f}s")

        # Process audio.
        logger.info("🔄 Processing audio through Wav2Vec2...")
        input_values = processor(
            audio, sampling_rate=SR, return_tensors="pt"
        ).input_values.to(DEVICE)
        logger.info(f"📊 Input values shape: {input_values.shape}")

        # Model inference.
        logger.info("🧠 Running model inference...")
        inference_start = time.perf_counter()
        with torch.no_grad():
            logits = model(input_values).logits
        inference_duration = time.perf_counter() - inference_start
        logger.info(f"✅ Inference completed in {inference_duration:.3f}s")
        logger.info(f"📊 Logits shape: {logits.shape}")

        # Decode transcription.
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
        logger.info(f"📝 RAW TRANSCRIPTION: '{transcription}'")

        # Find best match in transcription using dictionary.
        target_word_clean = target_word.lower().strip()
        matched_segment, matched_syllables, match_score = find_best_match(
            transcription, target_word_clean
        )
        logger.info(f"🎯 MATCHED SEGMENT: '{matched_segment}' (confidence: {match_score:.2f})")
        logger.info(f"🔤 MATCHED SYLLABLES: {matched_syllables}")

        # Get target syllables from dictionary, falling back to the
        # whitespace-separated words of the target itself.
        target_sylls = word_syllables.get(target_word_clean, [])
        if not target_sylls:
            logger.warning(f"⚠️ Target word '{target_word_clean}' not found in dictionary")
            target_sylls = target_word_clean.split()

        logger.info(f"🔤 Target syllables: {target_sylls}")
        logger.info(f"🔤 Predicted syllables (from match): {matched_syllables}")

        # Calculate embeddings.
        logger.info("🧮 Computing embeddings...")
        embed_start = time.perf_counter()
        pred_emb = get_embedding(matched_syllables)
        ref_emb = get_embedding(target_sylls)

        if pred_emb is not None and ref_emb is not None:
            logger.info(f"📊 Predicted embedding stats: mean={pred_emb.mean().item():.4f}, std={pred_emb.std().item():.4f}")
            logger.info(f"📊 Reference embedding stats: mean={ref_emb.mean().item():.4f}, std={ref_emb.std().item():.4f}")

            cosine_sim = F.cosine_similarity(
                pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)
            ).item()
            logger.info(f"✨ COSINE SIMILARITY (raw): {cosine_sim}")

            # Calculate combined pronunciation score.
            pronunciation_score = calculate_pronunciation_score(cosine_sim, match_score)
            logger.info(f"🎯 COMBINED PRONUNCIATION SCORE: {pronunciation_score:.2f}")
        else:
            cosine_sim = 0.0
            pronunciation_score = 0.0
            logger.warning("⚠️ Could not compute similarity (embeddings are None)")

        embed_duration = time.perf_counter() - embed_start
        logger.info(f"✅ Embedding computation completed in {embed_duration:.3f}s")

        # Calculate total time.
        total_duration = time.perf_counter() - started

        # Determine feedback based on pronunciation score.
        feedback, color = _feedback_for_score(pronunciation_score)

        result = {
            "request_id": request_id,
            "timestamp": start_time.isoformat(),
            "transcription": transcription,
            "matched_segment": matched_segment,
            "match_confidence": round(float(match_score), 2),
            "target_word": target_word_clean,
            "cosine_similarity": round(float(cosine_sim), 2),
            "pronunciation_score": round(float(pronunciation_score), 2),
            "feedback": feedback,
            "feedback_color": color,
            "processing_time": {
                "audio_load_seconds": round(load_duration, 3),
                "model_inference_seconds": round(inference_duration, 3),
                "embedding_computation_seconds": round(embed_duration, 3),
                "total_seconds": round(total_duration, 3),
            },
            "metadata": {
                "file_size_bytes": file_size,
                "file_hash": file_hash,
                "audio_duration_seconds": round(len(audio)/SR, 2),
                "device": DEVICE,
                "target_syllables": target_sylls,
                "matched_syllables": matched_syllables,
                "is_in_dictionary": target_word_clean in word_syllables,
            },
        }

        logger.info("📤 RESPONSE:")
        logger.info(f" ├─ Transcription: '{transcription}'")
        logger.info(f" ├─ Matched Segment: '{matched_segment}'")
        logger.info(f" ├─ Match Confidence: {result['match_confidence']}")
        logger.info(f" ├─ Cosine Similarity: {result['cosine_similarity']}")
        logger.info(f" ├─ Pronunciation Score: {result['pronunciation_score']}")
        logger.info(f" ├─ Feedback: {feedback}")
        logger.info(f" └─ Total Processing Time: {total_duration:.3f}s")
        logger.info(f"{'='*80}\n")

        return result

    except Exception as e:
        logger.error(f"❌ ERROR in request [{request_id}]: {str(e)}", exc_info=True)
        raise

    finally:
        # Clean up temporary file (tmp_file_path stays None if creating the
        # temp file itself failed).
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
            logger.debug(f"🗑️ Cleaned up temporary file: {tmp_file_path}")


@app.get("/")
def home():
    """Basic status endpoint reporting device and dictionary size."""
    logger.info("📍 Health check endpoint called")
    return {
        "message": "Cognia Wav2Vec2 Speech API running.",
        "device": DEVICE,
        "status": "ready",
        "dictionary_size": len(word_syllables),
    }


@app.get("/health")
def health():
    """Detailed health check endpoint."""
    return {
        "status": "healthy",
        "device": DEVICE,
        "model": MODEL_DIR,
        "dictionary_size": len(word_syllables),
        "timestamp": datetime.now().isoformat(),
    }


@app.get("/dictionary/search")
def search_dictionary(query: str):
    """Return dictionary entries whose word or joined syllables contain *query*.

    The comparison lower-cases the query and the joined syllables; the
    dictionary keys are used as they are.
    """
    query_lower = query.lower()
    matches = {
        word: syllables
        for word, syllables in word_syllables.items()
        if query_lower in word or query_lower in "".join(syllables).lower()
    }
    return {
        "query": query,
        "matches": matches,
        "count": len(matches),
    }


if __name__ == "__main__":
    import uvicorn

    # For local testing - cloud platforms like HF Spaces ignore this.
    port = int(os.getenv("PORT", 7860))  # HF Spaces uses 7860 by default
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")