Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, Form | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
| import torch | |
| import torch.nn.functional as F | |
| import librosa | |
| from word_syllables import word_syllables # <-- import dictionary | |
# FastAPI application instance for this service.
app = FastAPI(title="Cognia Wav2Vec2 Speech API")

# Hugging Face Hub id of the Wav2Vec2 CTC checkpoint (processor + model).
MODEL_DIR = "Artyomorax/cognia-wav2vec"
# Use the GPU when torch can see one, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Sampling rate (Hz) that uploaded audio is resampled to before inference.
SR = 16000

# Load the checkpoint once at import time so every request reuses it.
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR)
model.to(DEVICE)  # nn.Module.to moves parameters in place
model.eval()      # inference mode (e.g. disables dropout)
def get_embedding(syllables):
    """Return the mean ``lm_head`` weight row over the tokens of ``syllables``.

    Each syllable is tokenized with ``processor.tokenizer`` (no special
    tokens) and all resulting token ids are concatenated. The matching rows
    of ``model.lm_head.weight`` are gathered and averaged into one vector,
    which ``analyze_audio`` compares with cosine similarity.

    Args:
        syllables: Iterable of syllable strings; may be empty.

    Returns:
        A 1-D tensor on the weight's device (one entry per column of
        ``model.lm_head.weight``), or ``None`` if no token ids were produced.
    """
    token_ids = []
    for s in syllables:
        token_ids.extend(processor.tokenizer(s, add_special_tokens=False).input_ids)
    if not token_ids:
        return None
    # lm_head.weight is a trainable parameter: gathering from it outside
    # no_grad would attach an autograd graph to the result on every request.
    with torch.no_grad():
        tokens = torch.tensor(token_ids, dtype=torch.long, device=DEVICE)
        return model.lm_head.weight[tokens].mean(dim=0)
async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
    """Transcribe an uploaded clip and score it against a target word.

    NOTE(review): no route decorator is visible above this function; if it
    is meant to be an HTTP endpoint it needs e.g. ``@app.post(...)`` -
    confirm against the deployed file.

    Args:
        file: Uploaded audio; decoded with ``librosa.load`` and resampled
            to ``SR``.
        target_word: Word the speaker was expected to say (form field).

    Returns:
        dict with ``transcription`` (lower-cased greedy CTC decode),
        ``target_word`` (lower-cased, stripped) and ``cosine_similarity``
        (rounded to 2 decimals; 0.0 when either side yields no tokens, e.g.
        when the target word is not a key of ``word_syllables``).
    """
    # Audio decoding and model inference are blocking work. Running them
    # directly inside ``async def`` would stall the event loop for every
    # other request, so hand them to FastAPI's worker thread pool instead.
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(_analyze_audio_sync, file.file, target_word)


def _transcribe(audio):
    """Greedy-CTC decode ``audio`` (sampled at ``SR``) to lower-cased text."""
    input_values = processor(audio, sampling_rate=SR, return_tensors="pt").input_values.to(DEVICE)
    with torch.no_grad():
        logits = model(input_values).logits
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()


def _analyze_audio_sync(fileobj, target_word):
    """Blocking body of ``analyze_audio``: decode, transcribe and score."""
    audio, _ = librosa.load(fileobj, sr=SR)
    transcription = _transcribe(audio)

    target_word_clean = target_word.lower().strip()
    # Unknown words fall back to no syllables, which yields a 0.0 score below.
    target_sylls = word_syllables.get(target_word_clean, [])
    # Treat hyphens in the transcription as syllable separators.
    pred_syllables = transcription.replace("-", " ").split()

    pred_emb = get_embedding(pred_syllables)
    ref_emb = get_embedding(target_sylls)
    if pred_emb is None or ref_emb is None:
        cosine_sim = 0.0
    else:
        cosine_sim = F.cosine_similarity(pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)).item()

    return {
        "transcription": transcription,
        "target_word": target_word_clean,
        "cosine_similarity": round(float(cosine_sim), 2),
    }
def home():
    """Return a static status payload saying the API is running.

    NOTE(review): no route decorator is visible here; if this should answer
    ``GET /`` it needs ``@app.get("/")`` - confirm against the deployed file.
    """
    status_text = "Cognia Wav2Vec2 Speech API running."
    return dict(message=status_text)