Artyomorax commited on
Commit
76337aa
·
verified ·
1 Parent(s): bac2e58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -5
app.py CHANGED
@@ -2,20 +2,39 @@ from fastapi import FastAPI, UploadFile, Form
2
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
3
  import torch
4
  import torch.nn.functional as F
5
- import numpy as np
6
- from word_syllables import word_syllables
7
  import librosa
8
 
9
  app = FastAPI(title="Cognia Wav2Vec2 Speech API")
10
 
 
 
 
11
  MODEL_DIR = "Artyomorax/cognia-wav2vec"
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
- SR = 16000
14
 
15
  processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
16
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR).to(DEVICE)
17
  model.eval()
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def get_embedding(syllables):
20
  token_ids = []
21
  for s in syllables:
@@ -27,6 +46,9 @@ def get_embedding(syllables):
27
  embeds = model.lm_head.weight[tokens]
28
  return embeds.mean(dim=0)
29
 
 
 
 
30
  @app.post("/analyze")
31
  async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
32
  # load audio
@@ -40,15 +62,18 @@ async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
40
  transcription = processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
41
 
42
  # cosine similarity
43
- target_sylls = word_syllables.get(target_word.lower(), [])
 
44
  pred_syllables = transcription.replace("-", " ").split()
 
45
  pred_emb = get_embedding(pred_syllables)
46
  ref_emb = get_embedding(target_sylls)
 
47
  cosine_sim = F.cosine_similarity(pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)).item() if pred_emb is not None and ref_emb is not None else 0.0
48
 
49
  return {
50
  "transcription": transcription,
51
- "target_word": target_word,
52
  "cosine_similarity": round(float(cosine_sim), 2)
53
  }
54
 
 
2
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
3
  import torch
4
  import torch.nn.functional as F
 
 
5
  import librosa
6
 
7
  app = FastAPI(title="Cognia Wav2Vec2 Speech API")
8
 
9
+ # -------------------------------
10
+ # MODEL SETUP
11
+ # -------------------------------
12
  MODEL_DIR = "Artyomorax/cognia-wav2vec"
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+ SR = 16000 # sampling rate
15
 
16
  processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
17
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR).to(DEVICE)
18
  model.eval()
19
 
20
+ # -------------------------------
21
+ # WORD SYLLABLES MAPPING
22
+ # -------------------------------
23
+ word_syllables = {
24
+ "otso": ["ot", "so"],
25
+ "ulap": ["u", "lap"],
26
+ "ubo": ["u", "bo"],
27
+ "anak": ["a", "nak"],
28
+ "aso": ["a", "so"],
29
+ "aklat": ["ak", "lat"],
30
+ "bahay": ["ba", "hay"],
31
+ "bata": ["ba", "ta"],
32
+ # ... continue filling your full dictionary ...
33
+ }
34
+
35
+ # -------------------------------
36
+ # HELPERS
37
+ # -------------------------------
38
  def get_embedding(syllables):
39
  token_ids = []
40
  for s in syllables:
 
46
  embeds = model.lm_head.weight[tokens]
47
  return embeds.mean(dim=0)
48
 
49
+ # -------------------------------
50
+ # API ENDPOINT
51
+ # -------------------------------
52
  @app.post("/analyze")
53
  async def analyze_audio(file: UploadFile, target_word: str = Form(...)):
54
  # load audio
 
62
  transcription = processor.batch_decode(pred_ids, group_tokens=True)[0].lower().strip()
63
 
64
  # cosine similarity
65
+ target_word_clean = target_word.lower().strip()
66
+ target_sylls = word_syllables.get(target_word_clean, [])
67
  pred_syllables = transcription.replace("-", " ").split()
68
+
69
  pred_emb = get_embedding(pred_syllables)
70
  ref_emb = get_embedding(target_sylls)
71
+
72
  cosine_sim = F.cosine_similarity(pred_emb.unsqueeze(0), ref_emb.unsqueeze(0)).item() if pred_emb is not None and ref_emb is not None else 0.0
73
 
74
  return {
75
  "transcription": transcription,
76
+ "target_word": target_word_clean,
77
  "cosine_similarity": round(float(cosine_sim), 2)
78
  }
79