import os

import gradio as gr
import librosa
import numpy as np
import torch
from encodec import EncodecModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ----------------------------------------------------------------------
# Constants (must match training configuration)
# ----------------------------------------------------------------------
SAMPLE_RATE = 24000
NUM_CODEBOOKS = 8
CODEBOOK_SIZE = 1024
# Total audio tokens added: 1 start token + NUM_CODEBOOKS * CODEBOOK_SIZE
NUM_AUDIO_TOKENS_ADDED = 1 + NUM_CODEBOOKS * CODEBOOK_SIZE  # 8193


# ----------------------------------------------------------------------
# Load EnCodec once
# ----------------------------------------------------------------------
def load_encodec(device):
    """Load the pretrained 24 kHz EnCodec model onto *device* for inference.

    The target bandwidth (6.0 kbps) determines how many codebooks EnCodec
    emits per frame; it must match the setting used when the training data
    was tokenized so that NUM_CODEBOOKS codebooks are produced.
    """
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)
    model.to(device)
    # Inference only: nn.Module defaults to training mode.
    model.eval()
    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encodec_model = load_encodec(device)

# ----------------------------------------------------------------------
# Load fine-tuned Qwen model from Hugging Face Hub
# ----------------------------------------------------------------------
MODEL_ID = "michsethowusu/twi-symptoms-predict"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,  # more recent API, avoids deprecation warning
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

# The audio tokens sit at the end of the tokenizer's vocabulary, so
# old_vocab_size = total vocab - added audio tokens.
# NOTE(review): assumes the training script appended exactly
# NUM_AUDIO_TOKENS_ADDED tokens, start marker first — keep in sync.
old_vocab_size = len(tokenizer) - NUM_AUDIO_TOKENS_ADDED
audio_start_id = old_vocab_size  # first added token is the start marker


def audio_token_id(cb, val):
    """Return the global token ID for value *val* of codebook *cb*.

    Added-token layout: ``[start marker][codebook 0 values][codebook 1 values]...``,
    with each codebook occupying a contiguous run of CODEBOOK_SIZE IDs.

    Raises:
        ValueError: if *cb* is not in ``[0, NUM_CODEBOOKS)`` or *val* is not in
            ``[0, CODEBOOK_SIZE)``. Without this check such inputs would silently
            map into another codebook's (or the text vocabulary's) ID range.
    """
    if not 0 <= cb < NUM_CODEBOOKS:
        raise ValueError(f"codebook index {cb} out of range [0, {NUM_CODEBOOKS})")
    if not 0 <= val < CODEBOOK_SIZE:
        raise ValueError(f"codebook value {val} out of range [0, {CODEBOOK_SIZE})")
    return old_vocab_size + 1 + cb * CODEBOOK_SIZE + val
# ----------------------------------------------------------------------
# Audio preprocessing
# ----------------------------------------------------------------------
def audio_to_tokens(audio_path):
    """Convert an audio file to a flat, time-major list of EnCodec codes.

    The file is loaded as mono at SAMPLE_RATE and encoded with the global
    ``encodec_model``. Codes are interleaved frame by frame:
    ``[t0_cb0, t0_cb1, ..., t0_cbN, t1_cb0, ...]`` (length ``T' * NUM_CODEBOOKS``).

    Raises:
        ValueError: if EnCodec yields a codebook count other than NUM_CODEBOOKS
            (e.g. a bandwidth mismatch with the training setup).
    """
    wav, _ = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
    wav = torch.tensor(wav, device=device).unsqueeze(0).unsqueeze(0)  # (1, 1, T)
    with torch.no_grad():
        encoded_frames = encodec_model.encode(wav)
    # `encode` returns a list of (codes, scale) frames with codes shaped
    # (batch, n_q, T'). Concatenate along time so nothing is dropped if the
    # input is ever split into several segments, then drop the batch dim.
    codes = torch.cat([frame[0] for frame in encoded_frames], dim=-1)[0]  # (n_q, T')
    if codes.shape[0] != NUM_CODEBOOKS:
        raise ValueError(
            f"expected {NUM_CODEBOOKS} codebooks from EnCodec, got {codes.shape[0]}"
        )
    # (T', n_q) flattened row-major == time-major interleaving.
    return codes.T.reshape(-1).cpu().tolist()


# ----------------------------------------------------------------------
# Classification function
# ----------------------------------------------------------------------
_BODY_PART_MARKER = "Body Part:"
_SUB_ISSUE_MARKER = "Sub-Issue:"
_MESSAGE_STYLE = "padding:1rem; border-radius:8px; text-align:center; background:#fff8e6;"
_CARD_STYLE = (
    "padding:1.25rem; border-radius:12px; background:#f4f8ff; "
    "border:1px solid #d6e4ff;"
)


def _message_html(text):
    """Wrap an already-HTML-safe status message in a centred box."""
    return f"<div style='{_MESSAGE_STYLE}'>{text}</div>"


def _build_input_ids(tokens):
    """Map interleaved EnCodec codes to model token IDs, prefixed with the start marker.

    A trailing partial frame (fewer than NUM_CODEBOOKS codes) is dropped.
    """
    n_frames = len(tokens) // NUM_CODEBOOKS
    usable = tokens[: n_frames * NUM_CODEBOOKS]
    # Position i in the interleaved stream belongs to codebook i % NUM_CODEBOOKS.
    return [audio_start_id] + [
        audio_token_id(i % NUM_CODEBOOKS, val) for i, val in enumerate(usable)
    ]


def _parse_label(prediction):
    """Split ``"Body Part: X, Sub-Issue: Y"`` into ``(X, Y)``.

    Tolerates a missing comma or extra whitespace between the two fields.
    Returns None when the text does not contain both markers in that order.
    """
    head, sep, sub_issue = prediction.partition(_SUB_ISSUE_MARKER)
    if not sep or _BODY_PART_MARKER not in head:
        return None
    body_part = head.split(_BODY_PART_MARKER, 1)[1].strip().rstrip(",").strip()
    return body_part, sub_issue.strip()


def classify_audio(audio_filepath):
    """Gradio callback: classify a recorded Twi symptom clip.

    Args:
        audio_filepath: Path of the recorded audio file, or None when nothing
            was recorded.

    Returns:
        An HTML fragment for the result panel. It holds the parsed body part and
        sub-issue, or the raw model output when it does not match the expected
        label format, or a warning/error message. Model output and exception
        text are HTML-escaped before being embedded.
    """
    from html import escape

    if audio_filepath is None:
        return _message_html("⚠️ No audio recorded. Please try again.")

    # 1. Tokenize audio
    try:
        tokens = audio_to_tokens(audio_filepath)
    except Exception as e:  # UI boundary: surface any decode/encode failure to the user
        return _message_html(f"❌ Error processing audio: {escape(str(e))}")

    if not tokens:
        return _message_html("❌ Could not extract audio codes.")

    # 2. Build input sequence
    input_ids = torch.tensor([_build_input_ids(tokens)], device=model.device)
    # A single unpadded sequence -> all-ones mask. Passed explicitly because
    # pad_token == eos_token prevents generate() from inferring it.
    attention_mask = torch.ones_like(input_ids)

    # 3. Generate output
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    output_ids = generated[0][input_ids.shape[1]:]
    prediction = tokenizer.decode(output_ids, skip_special_tokens=True).strip()

    # 4. Parse the label string
    # Expected format: "Body Part: Eyes, Sub-Issue: Pain & Pressure"
    parsed = _parse_label(prediction)
    if parsed is not None:
        body_part, sub_issue = parsed
        return (
            f"<div style='{_CARD_STYLE}'>"
            "<h3 style='margin-top:0;'>🩺 Matched Symptom</h3>"
            f"<p><b>Body Part:</b> {escape(body_part)}</p>"
            f"<p><b>Sub‑Issue:</b> {escape(sub_issue)}</p>"
            "</div>"
        )

    # Fallback: just show raw prediction
    return (
        f"<div style='{_CARD_STYLE}'>"
        "<h3 style='margin-top:0;'>❓ Prediction</h3>"
        f"<p>{escape(prediction)}</p>"
        "</div>"
    )


# ----------------------------------------------------------------------
# Gradio UI
# ----------------------------------------------------------------------
custom_css = """
.gradio-container { font-family: 'Segoe UI', system-ui, sans-serif; }
#app-title { text-align: center; font-size: 2.5rem; font-weight: 700; color: #1e3c72; }
#app-subtitle { text-align: center; font-size: 1.1rem; color: #555; }
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 id='app-title'>🇬🇭 Twi Symptom Classifier</h1>")
    gr.HTML(
        "<p id='app-subtitle'>"
        "Record your symptom in Twi – the AI identifies the body part and issue"
        "</p>"
    )

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                sources=["microphone"], type="filepath", label="🎤 Record your symptom"
            )
            submit_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")
        output_display = gr.HTML(label="Result")

    submit_btn.click(fn=classify_audio, inputs=audio_input, outputs=output_display)
    # None (not "") is the value that resets an Audio component; "" would be
    # treated as a file path.
    clear_btn.click(fn=lambda: (None, ""), inputs=[], outputs=[audio_input, output_display])

    gr.Markdown(
        "<div style='text-align:center; color:#888; font-size:0.9rem; margin-top:1rem;'>"
        "Powered by Qwen2.5‑0.5B + EnCodec<br>Twi Symptom Dataset"
        "</div>"
    )

demo.launch()