"""Cognia Wav2Vec2 speech API.

Transcribes an uploaded audio clip with a Wav2Vec2 CTC model, matches the
transcription against a target word using the ``word_syllables`` dictionary,
and returns a pronunciation score with textual feedback.
"""

import hashlib
import logging
import os
import tempfile
from datetime import datetime
from difflib import SequenceMatcher

import librosa
import noisereduce as nr
import torch
import torch.nn.functional as F
from fastapi import FastAPI, Form, HTTPException, UploadFile
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from word_syllables import word_syllables

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('speech_api.log'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Cognia Wav2Vec2 Speech API")

MODEL_DIR = "Artyomorax/cognia-wav2vec"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = 16000  # sampling rate (Hz) used both for loading audio and for the processor

logger.info(f"🚀 Starting API with DEVICE: {DEVICE}")
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()
logger.info("✅ Model loaded successfully")


# ==================== Utility functions ====================

def _strip_separators(text):
    """Return *text* lower-cased with spaces and hyphens removed."""
    return text.replace(" ", "").replace("-", "").lower()


def _feedback_for(pronunciation_score):
    """Map a pronunciation score to a ``(feedback, color)`` pair."""
    if pronunciation_score >= 0.85:
        return "Excellent pronunciation!", "green"
    if pronunciation_score >= 0.70:
        return "Good pronunciation, minor improvements possible", "blue"
    if pronunciation_score >= 0.55:
        return "Fair pronunciation, needs practice", "orange"
    return "Needs significant improvement", "red"


def syllables_to_phonetic(syllables):
    """Convert syllables list to phonetic string for comparison.

    The syllables are concatenated, lower-cased, and stripped of spaces and
    hyphens.
    """
    return _strip_separators("".join(syllables))


def phonetic_similarity(s1, s2):
    """Calculate phonetic similarity between two strings.

    Both strings are normalised (lower-case, no spaces or hyphens) and
    compared with ``SequenceMatcher.ratio()``, giving a value in [0, 1].
    """
    return SequenceMatcher(
        None, _strip_separators(s1), _strip_separators(s2)
    ).ratio()


def syllable_similarity(syll1, syll2):
    """Compare two syllable lists and return similarity score.

    The score is 80% string similarity of the joined syllables and 20%
    agreement in syllable count.
    """
    s1_phonetic = syllables_to_phonetic(syll1)
    s2_phonetic = syllables_to_phonetic(syll2)
    phonetic_score = SequenceMatcher(None, s1_phonetic, s2_phonetic).ratio()
    # The trailing 1 in max() avoids a division by zero when both lists are empty.
    count_match = 1.0 - abs(len(syll1) - len(syll2)) / max(len(syll1), len(syll2), 1)
    return (phonetic_score * 0.8) + (count_match * 0.2)


def find_closest_word_in_dictionary(transcription):
    """Search dictionary for closest matching word.

    Returns:
        ``(word, syllables, score)`` for the dictionary entry whose joined
        syllables are most similar to the normalised transcription, or
        ``(None, [], 0.0)`` when no entry scores above zero.
    """
    trans_phonetic = _strip_separators(transcription)
    best_word = None
    best_syllables = []
    best_score = 0.0

    for word, syllables in word_syllables.items():
        word_phonetic = syllables_to_phonetic(syllables)
        score = SequenceMatcher(None, trans_phonetic, word_phonetic).ratio()
        if score > best_score:
            best_score = score
            best_word = word
            best_syllables = syllables

    logger.info(f"🔍 Dictionary match: '{trans_phonetic}' → '{best_word}' ({best_score:.2f})")
    return best_word, best_syllables, best_score


def find_best_match(transcription, target_word):
    """Find best matching segment in transcription corresponding to target word.

    Matching proceeds in stages: exact substring, substring ignoring spaces,
    single transcribed words found in the dictionary, runs of up to three
    adjacent words, and finally (when nothing scores at least 0.6) the
    closest dictionary word to the whole transcription.

    Returns:
        ``(matched_segment, matched_syllables, score)`` where ``score``
        expresses how well the matched segment agrees with the target word.
        When the target is not in the dictionary, the whole transcription is
        returned with a fixed score of 0.5.
    """
    trans_clean = transcription.replace("-", " ").lower().strip()
    target_clean = target_word.lower().strip()
    target_sylls = word_syllables.get(target_clean, [])

    if not target_sylls:
        logger.warning(f"⚠️ Target word '{target_clean}' not found in dictionary")
        return trans_clean, trans_clean.split(), 0.5

    if target_clean in trans_clean:
        logger.info(f"✓ Direct match found: '{target_clean}'")
        return target_clean, target_sylls, 1.0

    trans_nospace = trans_clean.replace(" ", "")
    target_nospace = target_clean.replace(" ", "")
    if target_nospace in trans_nospace:
        logger.info(f"✓ No-space match found: '{target_nospace}'")
        return target_clean, target_sylls, 0.95

    words = trans_clean.split()
    best_match = ""
    best_syllables = []
    best_score = 0.0

    # Stage 1: individual transcribed words that exist in the dictionary.
    for word in words:
        if word in word_syllables:
            word_sylls = word_syllables[word]
            score = syllable_similarity(word_sylls, target_sylls)
            if score > best_score:
                best_score = score
                best_match = word
                best_syllables = word_sylls
                logger.info(f" Single word match: '{word}' (score: {score:.2f})")

    # Stage 2: runs of 1-3 adjacent words, joined without spaces.
    for i in range(len(words)):
        for j in range(i + 1, min(i + 4, len(words) + 1)):
            combo = "".join(words[i:j])
            combo_display = " ".join(words[i:j])
            if combo in word_syllables:
                combo_sylls = word_syllables[combo]
                score = syllable_similarity(combo_sylls, target_sylls)
                if score > best_score:
                    best_score = score
                    best_match = combo_display
                    best_syllables = combo_sylls
                    logger.info(f" Combo match: '{combo}' (score: {score:.2f})")
            else:
                score = phonetic_similarity(combo, target_clean)
                if score > best_score:
                    best_score = score
                    best_match = combo_display
                    best_syllables = combo_display.split()
                    logger.info(f" Phonetic match: '{combo}' (score: {score:.2f})")

    # Stage 3: fall back to the dictionary word closest to the transcription.
    if best_score < 0.6:
        logger.info(" Searching entire dictionary for closest match...")
        dict_word, dict_sylls, dict_score = find_closest_word_in_dictionary(trans_clean)
        if dict_word is not None:
            # dict_score only measures transcription -> dictionary word. Scale
            # it by how close that word is to the target so that a cleanly
            # spoken but wrong word cannot earn a high target-match score.
            # When the closest word is the target itself the factor is 1.0.
            candidate_score = dict_score * syllable_similarity(dict_sylls, target_sylls)
            if candidate_score > best_score:
                best_score = candidate_score
                best_match = dict_word
                best_syllables = dict_sylls
                logger.info(f" Dictionary match used: '{dict_word}' (score: {candidate_score:.2f})")

    logger.info(f"🎯 Best match for '{target_clean}': '{best_match}' (score: {best_score:.2f})")
    logger.info(f" Matched syllables: {best_syllables}")
    return best_match, best_syllables, best_score


def get_embedding(syllables):
    """Generate embeddings from syllables.

    Each syllable is tokenised with the processor's tokenizer; the rows of
    ``model.lm_head.weight`` selected by the resulting ids are averaged.

    Returns:
        A 1-D tensor on ``DEVICE``, or ``None`` when no token ids are produced.
    """
    token_ids = []
    for s in syllables:
        ids = processor.tokenizer(s, add_special_tokens=False).input_ids
        token_ids.extend(ids)

    if not token_ids:
        return None

    tokens = torch.tensor(token_ids, device=DEVICE)
    # Inference only: avoid building an autograd graph on the model weights.
    with torch.no_grad():
        embeds = model.lm_head.weight[tokens]
        return embeds.mean(dim=0)


def calculate_pronunciation_score(cosine_sim, match_score):
    """Combine embedding similarity and transcription match into final score.

    The base score is 70% cosine similarity and 30% match score; it is then
    scaled down by 0.7 when ``match_score < 0.5`` or by 0.85 when
    ``match_score < 0.7``.
    """
    combined_score = (cosine_sim * 0.7) + (match_score * 0.3)
    if match_score < 0.5:
        combined_score *= 0.7
    elif match_score < 0.7:
        combined_score *= 0.85
    return combined_score


# ==================== API Endpoints ====================

@app.post("/analyze")
async def analyze_audio(audio_file: UploadFile, target_word: str = Form(...)):
    """Transcribe an uploaded clip and score it against ``target_word``.

    Args:
        audio_file: Uploaded audio; written to a temporary file and decoded
            with ``librosa`` at ``SR`` Hz.
        target_word: Word the speaker was asked to pronounce.

    Returns:
        A dict with the transcription, matched segment, scores, feedback and
        request metadata.

    Raises:
        HTTPException: 400 when the upload is empty, cannot be decoded, or
            decodes to zero samples.
    """
    request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    start_time = datetime.now()

    logger.info(f"\n{'='*80}")
    logger.info(f"📥 NEW REQUEST [{request_id}]")
    logger.info(f"📁 File: {audio_file.filename} ({audio_file.content_type})")
    logger.info(f"🎯 Target word: '{target_word}'")

    # Read the upload before creating the temp file so a failed read cannot
    # leave a stray file behind.
    content = await audio_file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
    file_size = len(content)
    # MD5 is used here only as a content fingerprint for the response metadata.
    file_hash = hashlib.md5(content).hexdigest()

    tmp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(content)

        # Load audio
        logger.info(f"🎵 Loading audio from: {tmp_file_path}")
        try:
            audio, _ = librosa.load(tmp_file_path, sr=SR)
        except Exception as exc:
            # Request boundary: report undecodable uploads as a client error.
            logger.exception("Failed to decode uploaded audio")
            raise HTTPException(
                status_code=400, detail="Could not decode audio file"
            ) from exc
        if audio.size == 0:
            raise HTTPException(status_code=400, detail="Audio file contains no samples")
        logger.info(f"🎵 Audio shape: {audio.shape}, duration: {len(audio)/SR:.2f}s")

        # ===== Noise reduction =====
        # The first 0.3 s of the clip is passed as the noise sample.
        noise_sample = audio[:int(0.3*SR)]
        audio_denoised = nr.reduce_noise(y=audio, y_noise=noise_sample, sr=SR)
        logger.info(f"🔊 Noise reduction applied")

        # Process through Wav2Vec2
        input_values = processor(
            audio_denoised, sampling_rate=SR, return_tensors="pt"
        ).input_values.to(DEVICE)
        with torch.no_grad():
            logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
        logger.info(f"📝 RAW TRANSCRIPTION: '{transcription}'")

        # Match target word
        target_word_clean = target_word.lower().strip()
        matched_segment, matched_syllables, match_score = find_best_match(
            transcription, target_word_clean
        )

        # Embeddings & pronunciation score
        # Words missing from the dictionary fall back to whitespace splitting.
        target_syllables = word_syllables.get(target_word_clean, target_word_clean.split())
        pred_emb = get_embedding(matched_syllables)
        ref_emb = get_embedding(target_syllables)

        if pred_emb is not None and ref_emb is not None:
            cosine_sim = F.cosine_similarity(
                pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)
            ).item()
            pronunciation_score = calculate_pronunciation_score(cosine_sim, match_score)
        else:
            cosine_sim = 0.0
            pronunciation_score = 0.0

        # Feedback
        feedback, color = _feedback_for(pronunciation_score)

        result = {
            "request_id": request_id,
            "timestamp": start_time.isoformat(),
            "transcription": transcription,
            "matched_segment": matched_segment,
            "match_confidence": round(float(match_score), 2),
            "target_word": target_word_clean,
            "cosine_similarity": round(float(cosine_sim), 2),
            "pronunciation_score": round(float(pronunciation_score), 2),
            "feedback": feedback,
            "feedback_color": color,
            "metadata": {
                "file_size_bytes": file_size,
                "file_hash": file_hash,
                "audio_duration_seconds": round(len(audio)/SR, 2),
                "device": DEVICE,
                "target_syllables": target_syllables,
                "matched_syllables": matched_syllables,
                "is_in_dictionary": target_word_clean in word_syllables,
            },
        }

        logger.info(f"📤 RESPONSE: {result}")
        return result

    finally:
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
            logger.debug(f"🗑️ Cleaned up temporary file: {tmp_file_path}")


@app.get("/")
def home():
    """Return a short status message and the active device."""
    logger.info("📍 Health check endpoint called")
    return {"message": "Cognia Wav2Vec2 Speech API running.", "device": DEVICE}


@app.get("/health")
def health():
    """Return service status, device, model identifier and current time."""
    return {
        "status": "healthy",
        "device": DEVICE,
        "model": MODEL_DIR,
        "timestamp": datetime.now().isoformat(),
    }


@app.get("/dictionary/search")
def search_dictionary(query: str):
    """Return dictionary entries whose word or joined syllables contain *query*.

    The query is lower-cased before the substring comparison.
    """
    query_lower = query.lower()
    matches = {
        word: syllables
        for word, syllables in word_syllables.items()
        if query_lower in word or query_lower in "".join(syllables).lower()
    }
    return {"query": query, "matches": matches, "count": len(matches)}


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")