Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,20 +2,39 @@ from fastapi import FastAPI, UploadFile, Form
|
|
| 2 |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
-
import numpy as np
|
| 6 |
-
from word_syllables import word_syllables
|
| 7 |
import librosa
|
| 8 |
|
| 9 |
app = FastAPI(title="Cognia Wav2Vec2 Speech API")
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
MODEL_DIR = "Artyomorax/cognia-wav2vec"
|
| 12 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 13 |
-
SR = 16000
|
| 14 |
|
| 15 |
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
|
| 16 |
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR).to(DEVICE)
|
| 17 |
model.eval()
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def get_embedding(syllables):
|
| 20 |
token_ids = []
|
| 21 |
for s in syllables:
|
|
@@ -27,6 +46,9 @@ def get_embedding(syllables):
|
|
| 27 |
embeds = model.lm_head.weight[tokens]
|
| 28 |
return embeds.mean(dim=0)
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
@app.post("/analyze")
|
| 31 |
async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
|
| 32 |
# load audio
|
|
@@ -40,15 +62,18 @@ async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
|
|
| 40 |
transcription = processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
|
| 41 |
|
| 42 |
# cosine similarity
|
| 43 |
-
|
|
|
|
| 44 |
pred_syllables = transcription.replace("-", " ").split()
|
|
|
|
| 45 |
pred_emb = get_embedding(pred_syllables)
|
| 46 |
ref_emb = get_embedding(target_sylls)
|
|
|
|
| 47 |
cosine_sim = F.cosine_similarity(pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)).item() if pred_emb is not None and ref_emb is not None else 0.0
|
| 48 |
|
| 49 |
return {
|
| 50 |
"transcription": transcription,
|
| 51 |
-
"target_word":
|
| 52 |
"cosine_similarity": round(float(cosine_sim), 2)
|
| 53 |
}
|
| 54 |
|
|
|
|
| 2 |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
|
|
|
|
|
|
| 5 |
import librosa
|
| 6 |
|
| 7 |
app = FastAPI(title="Cognia Wav2Vec2 Speech API")
|
| 8 |
|
| 9 |
+
# -------------------------------
|
| 10 |
+
# MODEL SETUP
|
| 11 |
+
# -------------------------------
|
| 12 |
MODEL_DIR = "Artyomorax/cognia-wav2vec"
|
| 13 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
+
SR = 16000 # sampling rate
|
| 15 |
|
| 16 |
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
|
| 17 |
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR).to(DEVICE)
|
| 18 |
model.eval()
|
| 19 |
|
| 20 |
+
# -------------------------------
|
| 21 |
+
# WORD SYLLABLES MAPPING
|
| 22 |
+
# -------------------------------
|
| 23 |
+
word_syllables = {
|
| 24 |
+
"otso": ["ot", "so"],
|
| 25 |
+
"ulap": ["u", "lap"],
|
| 26 |
+
"ubo": ["u", "bo"],
|
| 27 |
+
"anak": ["a", "nak"],
|
| 28 |
+
"aso": ["a", "so"],
|
| 29 |
+
"aklat": ["ak", "lat"],
|
| 30 |
+
"bahay": ["ba", "hay"],
|
| 31 |
+
"bata": ["ba", "ta"],
|
| 32 |
+
# ... continue filling your full dictionary ...
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
# -------------------------------
|
| 36 |
+
# HELPERS
|
| 37 |
+
# -------------------------------
|
| 38 |
def get_embedding(syllables):
|
| 39 |
token_ids = []
|
| 40 |
for s in syllables:
|
|
|
|
| 46 |
embeds = model.lm_head.weight[tokens]
|
| 47 |
return embeds.mean(dim=0)
|
| 48 |
|
| 49 |
+
# -------------------------------
|
| 50 |
+
# API ENDPOINT
|
| 51 |
+
# -------------------------------
|
| 52 |
@app.post("/analyze")
|
| 53 |
async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
|
| 54 |
# load audio
|
|
|
|
| 62 |
transcription = processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
|
| 63 |
|
| 64 |
# cosine similarity
|
| 65 |
+
target_word_clean = target_word.lower().strip()
|
| 66 |
+
target_sylls = word_syllables.get(target_word_clean, [])
|
| 67 |
pred_syllables = transcription.replace("-", " ").split()
|
| 68 |
+
|
| 69 |
pred_emb = get_embedding(pred_syllables)
|
| 70 |
ref_emb = get_embedding(target_sylls)
|
| 71 |
+
|
| 72 |
cosine_sim = F.cosine_similarity(pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)).item() if pred_emb is not None and ref_emb is not None else 0.0
|
| 73 |
|
| 74 |
return {
|
| 75 |
"transcription": transcription,
|
| 76 |
+
"target_word": target_word_clean,
|
| 77 |
"cosine_similarity": round(float(cosine_sim), 2)
|
| 78 |
}
|
| 79 |
|