File size: 1,972 Bytes
3df3bf3
bac2e58
3df3bf3
bac2e58
 
de86168
3df3bf3
 
 
bac2e58
 
de86168
bac2e58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3df3bf3
 
 
bac2e58
 
 
3df3bf3
bac2e58
 
 
3df3bf3
76337aa
 
bac2e58
76337aa
bac2e58
 
76337aa
bac2e58
3df3bf3
 
 
76337aa
bac2e58
3df3bf3
bac2e58
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import librosa
import torch
import torch.nn.functional as F
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

from word_syllables import word_syllables  # <-- import dictionary

# ASGI application object; the route decorators further down register on it.
app = FastAPI(title="Cognia Wav2Vec2 Speech API")

# Checkpoint identifier handed to ``from_pretrained`` for both the processor
# and the model (looks like a Hub repo id rather than a local path -- confirm).
MODEL_DIR = "Artyomorax/cognia-wav2vec"
# Sample rate, in Hz, that audio is loaded at and declared to the processor.
SR = 16000
# Use the GPU when torch reports one as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Both objects are created once at import time and shared by all requests.
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
# ``eval()`` returns the module itself, so the chain leaves ``model`` bound to
# the loaded network, moved to DEVICE and switched to inference mode.
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR).to(DEVICE).eval()

def get_embedding(syllables):
    """Average the ``lm_head`` weight rows selected by the tokens of *syllables*.

    Each syllable is tokenized separately, without special tokens, and the
    resulting ids are concatenated in input order before the lookup.

    Args:
        syllables: Iterable of strings to tokenize.

    Returns:
        The mean over the selected rows of ``model.lm_head.weight``, or
        ``None`` when tokenization produces no ids at all (for example when
        *syllables* is empty).
    """
    tokenize = processor.tokenizer
    flat_ids = [
        tok_id
        for syllable in syllables
        for tok_id in tokenize(syllable, add_special_tokens=False).input_ids
    ]
    if not flat_ids:
        return None

    row_index = torch.tensor(flat_ids, device=DEVICE)
    selected_rows = model.lm_head.weight[row_index]
    return selected_rows.mean(dim=0)

def _score_pronunciation(audio_file, target_word):
    """Transcribe *audio_file* and compare it with *target_word* (blocking).

    Args:
        audio_file: File-like object handed straight to ``librosa.load``.
        target_word: Word whose syllables are looked up in ``word_syllables``.

    Returns:
        A dict with the lower-cased transcription, the normalised target word
        and the cosine similarity rounded to two decimals. The similarity is
        ``0.0`` when either side yields no embedding, e.g. the target word is
        missing from ``word_syllables`` or the transcription is empty.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded or
            decodes to zero samples.
    """
    try:
        audio, _ = librosa.load(audio_file, sr=SR)
    except Exception as err:
        # The exception type depends on the decoding backend, so anything
        # raised while decoding is reported as a client error, not a 500.
        raise HTTPException(
            status_code=400,
            detail="Could not decode the uploaded audio file.",
        ) from err
    if audio.size == 0:
        raise HTTPException(
            status_code=400,
            detail="Uploaded audio file contains no samples.",
        )

    input_values = processor(
        audio, sampling_rate=SR, return_tensors="pt"
    ).input_values.to(DEVICE)

    target_word_clean = target_word.lower().strip()
    target_sylls = word_syllables.get(target_word_clean, [])

    # Nothing below needs gradients, including the embedding lookups, so the
    # whole scoring section runs without autograd tracking.
    with torch.no_grad():
        logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = (
            processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
        )
        # Hyphens are treated as separators, same as whitespace.
        pred_syllables = transcription.replace("-", " ").split()

        pred_emb = get_embedding(pred_syllables)
        ref_emb = get_embedding(target_sylls)

        if pred_emb is None or ref_emb is None:
            cosine_sim = 0.0
        else:
            cosine_sim = F.cosine_similarity(pred_emb, ref_emb, dim=0).item()

    return {
        "transcription": transcription,
        "target_word": target_word_clean,
        "cosine_similarity": round(cosine_sim, 2)
    }


@app.post("/analyze")
async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
    """Score how closely the uploaded recording matches *target_word*.

    Audio decoding and model inference are blocking calls, so they run in
    the threadpool rather than on the event loop; other requests keep being
    served while one recording is processed.

    Args:
        file: Uploaded audio recording.
        target_word: Form field naming the word the speaker attempted.

    Returns:
        A dict with keys ``transcription``, ``target_word`` and
        ``cosine_similarity``.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded or
            contains no samples.
    """
    return await run_in_threadpool(_score_pronunciation, file.file, target_word)

@app.get("/")
def home():
    """Return a fixed status message for requests to the root path."""
    status = {"message": "Cognia Wav2Vec2 Speech API running."}
    return status