# Source: Hugging Face Space "Indie_cuisine", file app.py (13.3 kB)
# Last commit: dcf0cb1 (verified) by kxrthik05 - "Update app.py"
# app.py
# Accent Detection + Regional Cuisine Recommendation (MFCC + HuBERT)
import os
from pathlib import Path
import pickle
import numpy as np
import librosa
import soundfile as sf # noqa: F401
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.linear_model import LogisticRegression # noqa: F401
import gradio as gr
# HuBERT support is optional: if torch / transformers cannot be imported the
# app still starts, and HF_AVAILABLE records that HuBERT must not be used.
try:
    import torch
    from transformers import Wav2Vec2FeatureExtractor, HubertModel
except Exception:
    # Any import-time failure (not only ImportError) disables HuBERT.
    HF_AVAILABLE = False
else:
    HF_AVAILABLE = True

# Sample rate (Hz) passed to librosa.load when reading audio.
DEFAULT_SR = 16000
# Directory the pickled model files are read from.
CHECKPOINT_DIR = Path("checkpoints")
# ---------------------------------------------------------------------
# 1. CUISINE MAP (based on your exact dishes)
# ---------------------------------------------------------------------
# Menu data keyed by canonical region key. Every entry has the same schema:
# a display "region" name followed by five lists of dish names.
CUISINE_MAP = {
    "andhra_pradesh": {
        "region": "Andhra Pradesh",
        "starters_veg": ["Pesarattu (green gram dosa)"],
        "starters_nonveg": [
            "Kodi Vepudu (Andhra chicken fry)",
            "Royyala Vepudu (prawn fry)",
        ],
        "main_veg": ["Pulihora", "Ulava Charu with Rice"],
        "main_nonveg": ["Gongura Mutton", "Andhra Chicken Curry / Kodi Kura"],
        "desserts": [
            "Bobbatlu (Puran Poli)",
            "Pootharekulu",
            "Paramannam (rice kheer)",
        ],
    },
    "kerala": {
        "region": "Kerala",
        "starters_veg": ["Banana Chips"],
        "starters_nonveg": ["Erachi Fry (Beef Fry)", "Fish Cutlets"],
        "main_veg": ["Puttu & Kadala Curry", "Appam with Veg Stew", "Kerala Sadya"],
        "main_nonveg": ["Karimeen Pollichathu", "Kerala Fish Curry"],
        "desserts": ["Palada Payasam", "Ada Pradhaman", "Elaneer Payasam"],
    },
    "gujarat": {
        "region": "Gujarat",
        "starters_veg": ["Dhokla", "Khandvi", "Sev Khamani"],
        "starters_nonveg": ["Local non‑veg starters (rare in traditional cuisine)"],
        "main_veg": ["Undhiyu", "Thepla", "Dal Dhokli", "Sev Tameta Nu Shaak"],
        "main_nonveg": ["Local non‑veg mains (if available)"],
        "desserts": ["Basundi", "Shrikhand", "Mohanthal"],
    },
    "jharkhand": {
        "region": "Jharkhand",
        "starters_veg": ["Dhuska", "Chilka Roti", "Rugra Fry (mushroom fry)"],
        "starters_nonveg": ["Local non‑veg starters"],
        "main_veg": ["Bamboo Shoot Curry", "Kadho (local dal curry)", "Litti‑Chokha"],
        "main_nonveg": ["Local non‑veg curries", "Handia (served with meals)"],
        "desserts": ["Thekua", "Tilkut", "Malpua (Jharkhand style)"],
    },
    "tamil_nadu": {
        "region": "Tamil Nadu",
        "starters_veg": ["Medu Vada", "Masala Vadai"],
        "starters_nonveg": ["Chicken 65"],
        "main_veg": ["Sambar Rice", "Pongal"],
        "main_nonveg": ["Chettinad Chicken Curry", "Kothu Parotta"],
        "desserts": ["Payasam", "Kesari", "Jigarthanda (Madurai)"],
    },
    "karnataka": {
        "region": "Karnataka",
        "starters_veg": ["Maddur Vada", "Goli Baje"],
        "starters_nonveg": ["Mangalore Chicken Ghee Roast"],
        "main_veg": ["Bisi Bele Bath", "Neer Dosa", "Ragi Mudde with Sambar"],
        "main_nonveg": ["Coorg Pandi Curry"],
        "desserts": ["Mysore Pak", "Kesari Bath", "Obbattu / Holige"],
    },
    # Generic placeholder entry used when a label has no region of its own.
    "default": {
        "region": "Unknown / Other",
        "starters_veg": ["Local vegetarian starters"],
        "starters_nonveg": ["Local non‑veg starters"],
        "main_veg": ["Local vegetarian mains"],
        "main_nonveg": ["Local non‑veg mains"],
        "desserts": ["Local desserts"],
    },
}
# ---------------------------------------------------------------------
# 2. LABEL ALIAS MAP (so HuBERT/MFCC labels map to these keys)
# ---------------------------------------------------------------------
# Maps each accepted (lower-case) classifier label to its CUISINE_MAP key.
# Built from (cuisine key, aliases) groups so each region is listed once.
LABEL_ALIAS = {
    alias: cuisine_key
    for cuisine_key, aliases in (
        # Andhra Pradesh / Telugu-style labels
        (
            "andhra_pradesh",
            ("andhra", "andhra_pradesh", "ap", "telugu", "telugu_andhra", "telugu_india"),
        ),
        # Kerala / Malayalam
        ("kerala", ("kerala", "ml", "malayalam", "kerala_malayalam")),
        # Gujarat
        ("gujarat", ("gujarat", "gj", "gujarati")),
        # Jharkhand
        ("jharkhand", ("jharkhand", "jh")),
        # Tamil Nadu
        ("tamil_nadu", ("tamil_nadu", "tamil", "tn")),
        # Karnataka
        ("karnataka", ("karnataka", "ka", "kannada")),
    )
    for alias in aliases
}
def format_cuisine_output(pred_label: str) -> str:
    """Render the dish recommendations for a predicted accent label.

    The label is converted to text, trimmed and lower-cased, and runs of
    whitespace or hyphens are collapsed to single underscores, so that
    "Tamil Nadu" and "andhra-pradesh" resolve the same way as "tamil_nadu"
    and "andhra_pradesh". The result is translated through LABEL_ALIAS and
    looked up in CUISINE_MAP; anything unmapped (including None or an empty
    label) uses CUISINE_MAP["default"].

    Args:
        pred_label: Raw label produced by the classifier's label encoder.
            Non-string values are accepted and stringified.

    Returns:
        A multi-line string: the inferred region, then starters, main
        courses and desserts, with blank lines between the sections.
    """
    raw = "" if pred_label is None else str(pred_label).strip().lower()
    # Alias keys use underscores only; unify the other common separators.
    raw = "_".join(raw.replace("-", " ").split())
    key = LABEL_ALIAS.get(raw, raw)
    info = CUISINE_MAP.get(key, CUISINE_MAP["default"])
    lines = [
        # The trailing newline leaves a blank line under the region header.
        f"Inferred region: {info['region']}\n",
        "Starters (Veg): " + ", ".join(info["starters_veg"]),
        "Starters (Non‑Veg): " + ", ".join(info["starters_nonveg"]),
        "",
        "Main Course (Veg): " + ", ".join(info["main_veg"]),
        "Main Course (Non‑Veg): " + ", ".join(info["main_nonveg"]),
        "",
        "Desserts: " + ", ".join(info["desserts"]),
    ]
    return "\n".join(lines)
# ---------------------------------------------------------------------
# 3. Audio & Feature helpers
# ---------------------------------------------------------------------
def normalize_audio_path(audio):
    """Reduce the value received from the audio input to a file path.

    Accepted inputs:
      * None                      -> None
      * str                       -> returned unchanged
      * os.PathLike (e.g. Path)   -> its full filesystem path
      * object with a ``name``    -> that attribute (file-like wrappers)
      * tuple / list              -> the normalised first element, or None
                                     when the sequence is empty
      * anything else             -> ``str(audio)``

    Returns:
        The path, or None when no audio was supplied.
    """
    if audio is None:
        return None
    if isinstance(audio, str):
        return audio
    # Must precede the ``name`` check: pathlib.Path also has ``.name``, but
    # that is only the final component and would drop the directory.
    if isinstance(audio, os.PathLike):
        return os.fspath(audio)
    if hasattr(audio, "name"):
        return audio.name
    if isinstance(audio, (tuple, list)):
        # An empty sequence carries no audio; avoid IndexError on audio[0].
        return normalize_audio_path(audio[0]) if audio else None
    return str(audio)
def extract_mfcc_pooled(path: str, sr: int = DEFAULT_SR, n_mfcc: int = 40) -> np.ndarray:
    """Return pooled MFCC statistics for one audio file.

    The file is loaded as mono audio resampled to ``sr``; ``n_mfcc``
    coefficients are computed and summarised along axis 1 by their mean and
    standard deviation. The means come first, followed by the standard
    deviations.
    """
    waveform, _loaded_sr = librosa.load(path, sr=sr, mono=True)
    coeffs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=n_mfcc)
    # Axis 1 is presumably the frame (time) axis of librosa's MFCC matrix.
    pooled_stats = (coeffs.mean(axis=1), coeffs.std(axis=1))
    return np.concatenate(pooled_stats)
# Lazily created feature extractor / model, shared across calls.
_hf_feat = None
_hf_model = None


def get_hubert_layer_embedding(path: str, layer_idx: int = 11) -> np.ndarray:
    """
    Embed an audio file with facebook/hubert-large-ll60k.

    The clip is loaded as mono audio at DEFAULT_SR and passed through the
    model; the hidden state selected by ``layer_idx`` is mean-pooled over
    time. Per the original author's note, the hidden size is 1024, matching
    a 1024-dim scaler.

    The feature extractor and model are created on the first call and kept
    in the module-level ``_hf_feat`` / ``_hf_model`` for later calls.

    Args:
        path: Audio file to embed.
        layer_idx: Index into the model's tuple of hidden states. It is
            coerced with ``int()`` so whole-number floats (e.g. from a UI
            slider) are accepted; negative indices count from the end.

    Returns:
        1-D NumPy array holding the pooled embedding.

    Raises:
        RuntimeError: torch / transformers could not be imported.
        IndexError: ``layer_idx`` does not address an available hidden state.
    """
    global _hf_feat, _hf_model
    if not HF_AVAILABLE:
        raise RuntimeError("Transformers / torch are not available on this Space.")
    # Tuple indexing rejects floats, so normalise the index up front.
    layer_idx = int(layer_idx)
    if _hf_feat is None or _hf_model is None:
        _hf_feat = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/hubert-large-ll60k"
        )
        _hf_model = HubertModel.from_pretrained(
            "facebook/hubert-large-ll60k", output_hidden_states=True
        )
    y, _ = librosa.load(path, sr=DEFAULT_SR, mono=True)
    inputs = _hf_feat(y, sampling_rate=DEFAULT_SR, return_tensors="pt", padding=True)
    with torch.no_grad():
        out = _hf_model(**inputs)
    hidden_states = out.hidden_states  # sequence of Tensor(batch, time, dim)
    n_states = len(hidden_states)
    if not -n_states <= layer_idx < n_states:
        # Same exception type as plain indexing, but with an actionable message.
        raise IndexError(
            f"HuBERT layer index {layer_idx} is out of range; "
            f"the model exposes {n_states} hidden states."
        )
    hs = hidden_states[layer_idx]  # (batch, time, dim)
    if hs.ndim == 3:
        # Average over time, then drop the single-item batch dimension.
        vec = hs.mean(dim=1).squeeze(0).cpu().numpy()  # (dim,)
    else:
        vec = hs.mean(dim=0).cpu().numpy()
    return vec
# ---------------------------------------------------------------------
# 4. Load models from checkpoints
# ---------------------------------------------------------------------
# Loaded model bundles, keyed by feature type ("mfcc" / "hubert").
app_state = {"models": {}}


def _load_model_bundle(feature, checkpoint_dir=None):
    """Unpickle the classifier, scaler and label encoder for one feature type.

    Reads ``clf_<feature>.pkl``, ``scaler_<feature>.pkl`` and
    ``le_<feature>.pkl`` from ``checkpoint_dir`` (CHECKPOINT_DIR when None).

    Returns:
        ``(directory, bundle)`` where ``bundle`` has keys "clf", "scaler", "le".
    """
    directory = Path(CHECKPOINT_DIR if checkpoint_dir is None else checkpoint_dir)
    bundle = {}
    for part in ("clf", "scaler", "le"):
        # NOTE(security): pickle.load can execute arbitrary code. These files
        # must only ever come from a trusted source, never from user uploads.
        with open(directory / f"{part}_{feature}.pkl", "rb") as fh:
            bundle[part] = pickle.load(fh)
    return directory, bundle


def load_models(checkpoint_dir=None):
    """Populate ``app_state["models"]`` from pickled checkpoints.

    The MFCC and HuBERT bundles are loaded independently and on a best-effort
    basis: a failure is printed and leaves that entry unset, so the app can
    still start with whichever model is available. An entry is stored only
    once all three of its files have loaded.

    For HuBERT, the layer index is read from ``hubert_layer.txt`` in the same
    directory when that file exists, and defaults to 11 otherwise.

    Args:
        checkpoint_dir: Directory to read from; None means CHECKPOINT_DIR.
    """
    models = app_state["models"]

    # MFCC model
    try:
        _, models["mfcc"] = _load_model_bundle("mfcc", checkpoint_dir)
        print("Loaded MFCC model.")
    except Exception as e:
        print("Could not load MFCC model:", e)

    # HuBERT model
    try:
        directory, bundle = _load_model_bundle("hubert", checkpoint_dir)
        layer_file = directory / "hubert_layer.txt"
        if layer_file.exists():
            bundle["layer"] = int(layer_file.read_text().strip())
        else:
            bundle["layer"] = 11
        models["hubert"] = bundle
        print("Loaded HuBERT model (layer", bundle["layer"], ").")
    except Exception as e:
        print("Could not load HuBERT model:", e)


# Load whatever checkpoints are available once, at import time.
load_models()
# ---------------------------------------------------------------------
# 5. Prediction logic
# ---------------------------------------------------------------------
def predict_accent_and_cuisine(audio_file, feature_choice, hubert_layer_idx, use_trained):
    """Predict the accent label for a clip and build the cuisine text.

    Args:
        audio_file: Value from the audio input; reduced to a path with
            normalize_audio_path.
        feature_choice: "HuBERT" selects the HuBERT model; any other value
            selects MFCC.
        hubert_layer_idx: Layer to use for HuBERT, consulted only when the
            loaded model bundle carries no "layer" entry of its own.
        use_trained: Must be truthy; on-the-fly training is not supported.

    Returns:
        ``(label_or_message, cuisine_text)``. On any failure the first item
        is a human-readable message and the second is an empty string; no
        exception is raised for extraction or prediction errors.
    """
    audio_path = normalize_audio_path(audio_file)
    if not audio_path:
        # Covers both None and an empty path string.
        return "No audio provided.", ""
    feat = "hubert" if feature_choice == "HuBERT" else "mfcc"
    if not use_trained:
        return (
            "On‑the‑fly training is disabled on this Space. "
            "Please keep 'Use trained model' checked.",
            "",
        )
    model_info = app_state["models"].get(feat)
    if not model_info:
        return f"No trained {feat.upper()} model found on server.", ""
    clf = model_info["clf"]
    scaler = model_info["scaler"]
    le = model_info["le"]

    # Feature extraction
    try:
        if feat == "mfcc":
            x = extract_mfcc_pooled(audio_path)
        else:
            # Use stored trained layer if available; otherwise the slider
            # value. int() guards against a float coming from the slider.
            trained_layer = int(model_info.get("layer", hubert_layer_idx))
            x = get_hubert_layer_embedding(audio_path, layer_idx=trained_layer)
    except Exception as e:
        return f"Feature extraction failed: {e}", ""

    # Prediction
    try:
        Xs = scaler.transform(x.reshape(1, -1))
        pred_idx = clf.predict(Xs)[0]
        # str(): the encoder may hand back a NumPy scalar or a non-string
        # class; the text outputs and the cuisine lookup both expect text.
        pred_label = str(le.inverse_transform([pred_idx])[0])
    except Exception as e:
        return f"Prediction error: {e}", ""

    cuisine_text = format_cuisine_output(pred_label)
    return pred_label, cuisine_text
# ---------------------------------------------------------------------
# 6. Gradio UI
# ---------------------------------------------------------------------
# Gradio interface: an audio input beside the model options, then the
# predict button and two text outputs.
# NOTE(review): the nesting (Column inside Row; button and outputs directly
# under Blocks) is reconstructed from flattened source - confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Accent Detection & Cuisine Recommendation")
    gr.Markdown(
        "Upload a short speech clip. The app predicts the speaker's regional accent "
        "and suggests popular veg / non‑veg starters, main course, and desserts "
        "from that region.\n\n"
        "**Tip:** Start with MFCC. Use HuBERT once the model has fully loaded."
    )

    with gr.Row():
        # type="filepath" makes the handler receive a path string.
        audio_in = gr.Audio(label="Upload audio (.wav / .mp3 / .flac)", type="filepath")
        with gr.Column():
            feature_choice = gr.Radio(
                label="Feature Type", choices=["MFCC", "HuBERT"], value="MFCC"
            )
            hubert_layer_idx = gr.Slider(
                label="HuBERT Layer (for HuBERT mode)",
                minimum=0,
                maximum=23,
                value=11,
                step=1,
            )
            use_trained = gr.Checkbox(
                label="Use trained model (required on this Space)", value=True
            )

    btn = gr.Button("Predict Accent & Recommend Dishes")
    out_label = gr.Textbox(label="Predicted Accent")
    out_cuisine = gr.Textbox(label="Recommended Cuisines", lines=10)

    # Input order must match predict_accent_and_cuisine's parameter order.
    btn.click(
        fn=predict_accent_and_cuisine,
        inputs=[audio_in, feature_choice, hubert_layer_idx, use_trained],
        outputs=[out_label, out_cuisine],
    )

if __name__ == "__main__":
    # Bind on all interfaces; the port comes from $PORT, defaulting to 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))