Spaces:
Sleeping
Sleeping
File size: 1,972 Bytes
3df3bf3 bac2e58 3df3bf3 bac2e58 de86168 3df3bf3 bac2e58 de86168 bac2e58 3df3bf3 bac2e58 3df3bf3 bac2e58 3df3bf3 76337aa bac2e58 76337aa bac2e58 76337aa bac2e58 3df3bf3 76337aa bac2e58 3df3bf3 bac2e58 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | from fastapi import FastAPI, UploadFile, Form
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import torch.nn.functional as F
import librosa
from word_syllables import word_syllables # <-- import dictionary
app = FastAPI(title="Cognia Wav2Vec2 Speech API")

# Checkpoint identifier passed to `from_pretrained` for both the processor and
# the CTC model (presumably a Hugging Face Hub repo id — confirm).
MODEL_DIR = "Artyomorax/cognia-wav2vec"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Sample rate (Hz) that uploads are resampled to and that the processor is told.
SR = 16000

# Loaded once at import time and shared by every request handler.
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR)
model.to(DEVICE)  # nn.Module.to moves parameters in place
model.eval()  # inference mode (e.g. disables dropout)
def get_embedding(syllables):
    """Average the ``lm_head`` weight rows selected by the tokens of *syllables*.

    Each syllable string is tokenized separately with the processor's
    tokenizer (without special tokens). All resulting token ids are
    concatenated, and the rows of ``model.lm_head.weight`` at those ids are
    averaged along dim 0 into a single vector.

    Args:
        syllables: iterable of syllable strings; ``None`` is treated as empty.

    Returns:
        A tensor on ``DEVICE`` holding the mean of the selected weight rows,
        with no autograd history. Returns ``None`` when no token ids were
        produced (empty input, or syllables that tokenize to nothing).
    """
    if syllables is None:
        return None

    token_ids = []
    for syllable in syllables:
        token_ids.extend(
            processor.tokenizer(syllable, add_special_tokens=False).input_ids
        )
    if not token_ids:
        return None

    # lm_head.weight is a trainable Parameter. Without no_grad, the indexing and
    # mean below would record an autograd graph on every call, even though the
    # result is only used for inference-time scoring.
    with torch.no_grad():
        tokens = torch.tensor(token_ids, device=DEVICE)
        return model.lm_head.weight[tokens].mean(dim=0)
def _score_audio(audio_file, target_word):
    """Blocking worker for ``/analyze``: transcribe *audio_file* and score it.

    Loads and resamples the audio to ``SR``, runs greedy CTC decoding, then
    compares averaged token embeddings of the predicted syllables with the
    dictionary syllables of *target_word*.

    Returns:
        The response dict with keys ``transcription``, ``target_word`` and
        ``cosine_similarity`` (rounded to 2 decimals; 0.0 when either side has
        no tokens, e.g. the target word is not in ``word_syllables``).
    """
    audio, _ = librosa.load(audio_file, sr=SR)
    input_values = processor(
        audio, sampling_rate=SR, return_tensors="pt"
    ).input_values.to(DEVICE)

    target_word_clean = target_word.lower().strip()
    target_sylls = word_syllables.get(target_word_clean, [])

    # Everything below is pure inference. no_grad also covers the embedding and
    # cosine step, so no autograd graph is built over lm_head.weight.
    with torch.no_grad():
        logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = (
            processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
        )

        # Syllables may be hyphen- or space-separated in the transcription.
        pred_syllables = transcription.replace("-", " ").split()
        pred_emb = get_embedding(pred_syllables)
        ref_emb = get_embedding(target_sylls)

        if pred_emb is None or ref_emb is None:
            cosine_sim = 0.0
        else:
            cosine_sim = F.cosine_similarity(
                pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)
            ).item()

    return {
        "transcription": transcription,
        "target_word": target_word_clean,
        "cosine_similarity": round(float(cosine_sim), 2),
    }


@app.post("/analyze")
async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
    """Transcribe an uploaded audio clip and score it against *target_word*.

    Form fields:
        file: the audio upload (decoded with ``librosa.load``).
        target_word: word to compare against; it is lower-cased and stripped
            before the ``word_syllables`` lookup.

    Returns:
        ``{"transcription", "target_word", "cosine_similarity"}``.
    """
    import asyncio

    # Audio decoding and model inference are blocking and CPU/GPU-bound.
    # Running them directly inside this coroutine would stall the event loop
    # and every other request, so they run in the default executor instead.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _score_audio, file.file, target_word)
@app.get("/")
def home():
    """Liveness endpoint: confirm the service is up."""
    status_text = "Cognia Wav2Vec2 Speech API running."
    payload = {"message": status_text}
    return payload
|