"""Cognia speech API.

Transcribes an uploaded audio clip with a Wav2Vec2 CTC model and scores the
transcription against a target word by comparing averaged token vectors.
"""

from typing import Iterable, Optional

import librosa
import torch
import torch.nn.functional as F
from fastapi import FastAPI, Form, HTTPException, UploadFile
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from word_syllables import word_syllables  # <-- import dictionary

app = FastAPI(title="Cognia Wav2Vec2 Speech API")

MODEL_DIR = "Artyomorax/cognia-wav2vec"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = 16000  # sample rate (Hz) the audio is resampled to and declared to the processor

# Loaded once at import time so every request reuses the same processor/model.
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()


def get_embedding(syllables: Iterable[str]) -> Optional[torch.Tensor]:
    """Return the mean ``lm_head`` weight row over the tokens of *syllables*.

    Each syllable is tokenized without special tokens; the resulting token ids
    index rows of ``model.lm_head.weight`` and those rows are averaged.

    Args:
        syllables: Strings to tokenize (may be empty).

    Returns:
        A 1-D tensor on ``DEVICE``, or ``None`` when tokenization yields no
        token ids at all (e.g. *syllables* is empty).
    """
    token_ids = []
    for syllable in syllables:
        ids = processor.tokenizer(syllable, add_special_tokens=False).input_ids
        token_ids.extend(ids)

    if not token_ids:
        return None

    tokens = torch.tensor(token_ids, device=DEVICE)
    embeds = model.lm_head.weight[tokens]
    return embeds.mean(dim=0)


@app.post("/analyze")
async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
    """Transcribe the uploaded audio and score it against *target_word*.

    The audio is loaded at ``SR`` Hz, run through the model, and decoded by
    taking the argmax token per frame. The transcription (split on spaces and
    hyphens) and the syllables registered for the target word in
    ``word_syllables`` are each reduced to one vector via ``get_embedding``
    and compared with cosine similarity.

    NOTE(review): audio decoding and model inference are synchronous calls
    inside an ``async`` handler, so they occupy the event loop while running.

    Args:
        file: Uploaded audio file.
        target_word: Word the speaker was expected to say (form field).

    Returns:
        A dict with ``transcription``, ``target_word`` (lower-cased, stripped)
        and ``cosine_similarity`` rounded to two decimals. The similarity is
        ``0.0`` when either side produces no tokens, including when the target
        word is not present in ``word_syllables``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as audio or
            contains no samples.
    """
    try:
        audio, _ = librosa.load(file.file, sr=SR)
    except Exception as exc:
        # Request boundary: any decoder failure means the client sent data we
        # cannot read, so report it as a client error instead of a bare 500.
        raise HTTPException(
            status_code=400,
            detail="Could not decode the uploaded audio file.",
        ) from exc

    if audio.size == 0:
        raise HTTPException(
            status_code=400,
            detail="Uploaded audio contains no samples.",
        )

    input_values = processor(
        audio, sampling_rate=SR, return_tensors="pt"
    ).input_values.to(DEVICE)

    target_word_clean = target_word.lower().strip()
    target_sylls = word_syllables.get(target_word_clean, [])

    # Everything touching model weights stays under no_grad: this is pure
    # inference, so no autograd graph should be built per request.
    with torch.no_grad():
        logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = (
            processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
        )

        pred_syllables = transcription.replace("-", " ").split()
        pred_emb = get_embedding(pred_syllables)
        ref_emb = get_embedding(target_sylls)

        if pred_emb is not None and ref_emb is not None:
            cosine_sim = F.cosine_similarity(
                pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)
            ).item()
        else:
            cosine_sim = 0.0

    return {
        "transcription": transcription,
        "target_word": target_word_clean,
        "cosine_similarity": round(float(cosine_sim), 2),
    }


@app.get("/")
def home():
    """Return a static message confirming the API is up."""
    return {"message": "Cognia Wav2Vec2 Speech API running."}