michsethowusu's picture
Update app.py
24f5c54 verified
Raw
History Blame Contribute Delete
6.43 kB
import html
import os

import gradio as gr
import librosa
import numpy as np
import torch
from encodec import EncodecModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# ----------------------------------------------------------------------
# Constants (must match training configuration)
# ----------------------------------------------------------------------
SAMPLE_RATE: int = 24_000
NUM_CODEBOOKS: int = 8
CODEBOOK_SIZE: int = 1024
# Audio tokens appended to the base vocabulary during training:
# one start marker, then NUM_CODEBOOKS contiguous blocks of CODEBOOK_SIZE ids.
NUM_AUDIO_TOKENS_ADDED: int = NUM_CODEBOOKS * CODEBOOK_SIZE + 1  # == 8193
# ----------------------------------------------------------------------
# Load EnCodec once
# ----------------------------------------------------------------------
def load_encodec(device):
    """Build the 24 kHz EnCodec model, set its bandwidth, and place it on *device*.

    The 6.0 kbps target bandwidth is presumably chosen so the codec emits
    NUM_CODEBOOKS codebooks per frame, matching training -- confirm if either
    value changes.
    """
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(6.0)
    return codec.to(device)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
encodec_model = load_encodec(device)
# ----------------------------------------------------------------------
# Load fine-tuned Qwen model from Hugging Face Hub
# ----------------------------------------------------------------------
MODEL_ID = "michsethowusu/twi-symptoms-predict"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Reuse EOS as the padding token when the checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    dtype=torch.bfloat16,  # newer keyword name; avoids a deprecation warning
).eval()

# The audio tokens were appended after the base vocabulary, so the size of the
# base vocabulary is also the id of the first added token (the start marker).
old_vocab_size = len(tokenizer) - NUM_AUDIO_TOKENS_ADDED
audio_start_id = old_vocab_size
def audio_token_id(cb, val):
    """Return the global vocabulary id for code *val* of codebook *cb*.

    Audio ids start right after the start marker (``old_vocab_size``):
    codebook 0 owns the next CODEBOOK_SIZE ids, codebook 1 the block after
    that, and so on.
    """
    first_code_id = old_vocab_size + 1  # skip over the start marker
    return first_code_id + cb * CODEBOOK_SIZE + val
# ----------------------------------------------------------------------
# Audio preprocessing
# ----------------------------------------------------------------------
def audio_to_tokens(audio_path):
    """Convert an audio file to a frame-major, interleaved list of EnCodec codes.

    The file is loaded as mono at SAMPLE_RATE and encoded with the module-level
    ``encodec_model``. The returned list holds, for each time step, the code
    of every codebook in order: ``[cb0_t0, cb1_t0, ..., cb0_t1, cb1_t1, ...]``.

    Args:
        audio_path: Path to an audio file readable by librosa.

    Returns:
        A list of Python ints, or an empty list if the file has no samples.
    """
    wav, _sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
    if wav.size == 0:
        # Nothing to encode; let the caller report "no codes" instead of
        # failing inside the codec.
        return []
    wav = torch.tensor(wav, device=device).reshape(1, 1, -1)  # (batch, channels, samples)
    with torch.no_grad():
        encoded_frames = encodec_model.encode(wav)
    # Per the EnCodec API, encode() returns a list of (codes, scale) frames with
    # codes shaped (batch, n_codebooks, T). Join all frames along time so a
    # segmenting codec configuration does not silently drop audio.
    codes = torch.cat([frame_codes for frame_codes, _scale in encoded_frames], dim=-1)
    codes = codes[0].cpu().numpy()  # (n_codebooks, T') for the single batch item
    # Transpose to (T', n_codebooks) so flattening yields frame-major order.
    return codes.T.flatten().tolist()
# ----------------------------------------------------------------------
# Classification function
# ----------------------------------------------------------------------
def _build_input_ids(tokens):
    """Return the prompt ids: the audio start marker plus one id per code.

    ``tokens`` is frame-major and interleaved, so position ``i`` belongs to
    codebook ``i % NUM_CODEBOOKS``. A trailing partial frame is dropped.
    """
    usable = (len(tokens) // NUM_CODEBOOKS) * NUM_CODEBOOKS
    return [audio_start_id] + [
        audio_token_id(i % NUM_CODEBOOKS, val)
        for i, val in enumerate(tokens[:usable])
    ]


def _parse_prediction(prediction):
    """Split ``"Body Part: X, Sub-Issue: Y"`` into ``(X, Y)``.

    Returns None when either label is missing. Any separator between the two
    fields (", ", ",", newline, ...) is accepted rather than requiring exactly
    ", Sub-Issue:", which previously caused an IndexError.
    """
    head, found, sub_issue = prediction.partition("Sub-Issue:")
    if not found:
        return None
    _, found, body_part = head.partition("Body Part:")
    if not found:
        return None
    return body_part.strip().rstrip(",").strip(), sub_issue.strip()


def classify_audio(audio_filepath):
    """Classify a recorded Twi symptom and return an HTML result card.

    Args:
        audio_filepath: Path to the recorded audio, or None if nothing was recorded.

    Returns:
        An HTML string: an error banner, a "Matched Symptom" card when the model
        output contains both labels, or a fallback card showing the raw output.
    """
    if audio_filepath is None:
        return (
            "<div style='color:#e74c3c; font-size:18px;'>"
            "⚠️ No audio recorded. Please try again."
            "</div>"
        )

    # 1. Tokenize audio.
    try:
        tokens = audio_to_tokens(audio_filepath)
    except Exception as e:  # UI boundary: surface any decode/codec failure to the user
        return f"<div style='color:#e74c3c;'>❌ Error processing audio: {html.escape(str(e))}</div>"
    # At least one complete frame is required; otherwise the prompt would be
    # just the start marker.
    if len(tokens) < NUM_CODEBOOKS:
        return "<div style='color:#e74c3c;'>❌ Could not extract audio codes.</div>"

    # 2. Build input sequence.
    input_ids = torch.tensor([_build_input_ids(tokens)], device=model.device)

    # 3. Generate output (greedy).
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            # Single unpadded sequence: every position is real. Passing the mask
            # explicitly avoids ambiguity since pad == eos.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # The returned sequence starts with the prompt; keep only the new tokens.
    output_ids = generated[0][input_ids.shape[-1]:]
    prediction = tokenizer.decode(output_ids, skip_special_tokens=True).strip()

    # 4. Parse the label string and render it. Model text is escaped before it
    # is embedded in HTML.
    parsed = _parse_prediction(prediction)
    if parsed is not None:
        body_part, sub_issue = parsed
        return (
            f"<div style='background:#f0fdf4; border-left:6px solid #2ecc71; "
            f"padding:1.5rem; border-radius:12px; margin-top:1rem;'>"
            f"<h2>🩺 Matched Symptom</h2>"
            f"<p><strong>Body Part:</strong> {html.escape(body_part)}</p>"
            f"<p><strong>Sub‑Issue:</strong> {html.escape(sub_issue)}</p>"
            f"</div>"
        )

    # Fallback: just show raw prediction.
    return (
        f"<div style='background:#fef9e7; border-left:6px solid #f1c40f; "
        f"padding:1.5rem; border-radius:12px; margin-top:1rem;'>"
        f"<h2>❓ Prediction</h2>"
        f"<p>{html.escape(prediction)}</p>"
        f"</div>"
    )
# ----------------------------------------------------------------------
# Gradio UI
# ----------------------------------------------------------------------
custom_css = """
.gradio-container { font-family: 'Segoe UI', system-ui, sans-serif; }
#app-title { text-align: center; font-size: 2.5rem; font-weight: 700; color: #1e3c72; }
#app-subtitle { text-align: center; font-size: 1.1rem; color: #555; }
"""

# Single-page UI: record audio, analyze it with classify_audio, show HTML result.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("<div id='app-title'>🇬🇭 Twi Symptom Classifier</div>")
    gr.HTML("<div id='app-subtitle'>Record your symptom in Twi – the AI identifies the body part and issue</div>")
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="🎤 Record your symptom")
            submit_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")
            output_display = gr.HTML(label="Result")
    submit_btn.click(fn=classify_audio, inputs=audio_input, outputs=output_display)
    # Reset both components. None is the value that clears gr.Audio; an empty
    # string would be interpreted as a (nonexistent) file path.
    clear_btn.click(fn=lambda: (None, ""), inputs=[], outputs=[audio_input, output_display])
    gr.Markdown(
        "<div style='text-align:center; margin-top:2rem; color:#888;'>"
        "Powered by Qwen2.5‑0.5B + EnCodec<br>Twi Symptom Dataset"
        "</div>"
    )

demo.launch()