"""Gradio dashboard for the Echo-DSRN multilingual intent classifier.

Loads the base causal LM plus its PEFT adapter at import time, optionally
loads Whisper for speech input, and serves a Gradio UI that streams the
predicted intent alongside two matplotlib telemetry plots.
"""

import math
import time
from threading import Thread

import gradio as gr
import matplotlib
import numpy as np
import torch

try:
    import whisper

    HAS_WHISPER = True
except ImportError:
    HAS_WHISPER = False

from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Use non-interactive backend for matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# ================================================================
# CONFIGURATION
# ================================================================
BASE_MODEL = "ethicalabs/Echo-DSRN-114M-v0.1.2"
ADAPTER_PATH = "ethicalabs/Echo-SmolTools-114M-Intent-PEFT"

# ================================================================
# METADATA: INTENTS & EXAMPLES
# ================================================================
INTENTS = [
    "datetime_query",
    "iot_hue_lightchange",
    "transport_ticket",
    "takeaway_query",
    "qa_stock",
    "general_greet",
    "recommendation_events",
    "music_dislikeness",
    "iot_wemo_off",
    "cooking_recipe",
    "qa_currency",
    "transport_traffic",
    "general_quirky",
    "weather_query",
    "audio_volume_up",
    "email_addcontact",
    "takeaway_order",
    "email_querycontact",
    "iot_hue_lightup",
    "recommendation_locations",
    "play_audiobook",
    "lists_createoradd",
    "news_query",
    "alarm_query",
    "iot_wemo_on",
    "general_joke",
    "qa_definition",
    "social_query",
    "music_settings",
    "audio_volume_other",
    "calendar_remove",
    "iot_hue_lightdim",
    "calendar_query",
    "email_sendemail",
    "iot_cleaning",
    "audio_volume_down",
    "play_radio",
    "cooking_query",
    "datetime_convert",
    "qa_maths",
    "iot_hue_lightoff",
    "iot_hue_lighton",
    "transport_query",
    "music_likeness",
    "email_query",
    "play_music",
    "audio_volume_mute",
    "social_post",
    "alarm_set",
    "qa_factoid",
    "calendar_set",
    "play_game",
    "alarm_remove",
    "lists_remove",
    "transport_taxi",
    "recommendation_movies",
    "iot_coffee",
    "music_query",
    "play_podcasts",
    "lists_query",
]

# Sample utterances per locale, offered in the "Example Utterances" dropdown.
EXAMPLES = {
    "it-IT": [
        "spegni le luci per favore",
        "abbassa le luci dell' ingresso",
        "riproduci oro di mango",
        "quali sono le previsioni meteo della settimana",
        "riproduci malibu",
    ],
    "en-US": [
        "turn the lights off please",
        "dim the lights in the hall",
        "clean the flat",
        "cleaning is good dust is so bad do now your magic clean my carpet",
        "list most rated delivery options for chinese food",
    ],
    "es-ES": [
        "apaga las luces por favor",
        "atenua las luces en el pasillo",
        "oscurece la habitación",
        "me gustaría escuchar barcelona de queen",
        "ponme barcelona por queen",
    ],
    "pt-PT": [
        "desligar as luzes",
        "diminuir as luzes no salão",
        "diz qual é o stato da minha memória disponível",
        "mostra uma lista com entrega ao domicílio de comida chinesa com mais avaliações",
        "eu gostava de ouvir punksinatra corridinho à portuguesa",
    ],
    "fr-FR": [
        "éteigne les lumières s'il te plait",
        "tamiser les lumières dans la salle",
        "nettoyer la télévision",
        "trouve mes plats à emporter thaïlandais autour de la concorde",
        "j'aimerais écouter ne me quitte pas de jacques brel",
    ],
    "de-DE": [
        "schalte bitte die lichter aus",
        "dimme die lichter im eingangsbereich",
        "wie lautet der status meines verfügbaren speichers",
        "ich würde gerne queen's barcelona hören",
        "spiel barcelona von queen",
    ],
}

# ================================================================
# MODEL LOADING
# ================================================================
print("🚀 Initializing Echo-DSRN Dashboard...")
print(f" Base: {BASE_MODEL}")
print(f" Adapter: {ADAPTER_PATH}")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
model = PeftModel.from_pretrained(model, ADAPTER_PATH)
model.eval()

# Device of the first model parameter; inputs and Whisper are placed here too.
DEVICE = next(model.parameters()).device
print(f" ✅ Model loaded on {DEVICE}")

# Bind the name unconditionally so it exists even when loading fails below.
whisper_model = None
if HAS_WHISPER:
    print("🎙️ Loading Whisper 'tiny' engine...")
    try:
        whisper_model = whisper.load_model("tiny", device=DEVICE)
        print(" ✅ Whisper ready.")
    except Exception as e:
        # Whisper is optional: report the failure and disable the voice UI.
        print(f" ❌ Whisper loading failed: {e}")
        HAS_WHISPER = False
else:
    print(" ℹ️ Whisper not available (optional dependency).")


# ================================================================
# OBSERVABILITY HOOKS
# ================================================================
def get_dsrn_state_heatmap():
    """Generates a heatmap of the DSRN recurrent state (c_t) magnitude.

    Reads ``model._latest_c_states`` when the model exposes it and renders the
    absolute values of ``_latest_c_states[-1][0]`` as a near-square grid.
    Falls back to an all-zero grid when no state has been recorded.

    Returns:
        The matplotlib ``Figure`` holding the heatmap.
    """
    plt.close('all')  # drop figures from earlier refreshes so they do not pile up

    config = model.config
    state_dim = config.hidden_size * config.num_heads

    latest_states = getattr(model, "_latest_c_states", None)
    if latest_states is not None:
        c_vector = latest_states[-1][0].detach().cpu().float().numpy()
        # Flatten so the truncation below applies to elements, not to rows of
        # a multi-dimensional state.
        c_vector = np.abs(c_vector).ravel()
    else:
        c_vector = np.zeros(state_dim)

    # Near-square layout; elements beyond w * h are dropped.
    w = int(math.sqrt(state_dim))
    h = state_dim // w
    state_magnitudes = c_vector[: w * h].reshape((h, w))

    fig, ax = plt.subplots(figsize=(6, 3.5), dpi=100)
    fig.patch.set_facecolor("#0f0f23")
    ax.set_facecolor("#0f0f23")

    im = ax.imshow(state_magnitudes, cmap="magma", aspect="auto", interpolation="nearest")
    ax.set_title(
        "DSRN Slow State (c_t) Memory Density", color="#e0e0f0", fontsize=10, fontweight="bold"
    )
    ax.set_xticks([])
    ax.set_yticks([])

    cbar = plt.colorbar(im, ax=ax, fraction=0.02, pad=0.04)
    cbar.ax.tick_params(colors="#707080", labelsize=7)

    plt.tight_layout()
    return fig


def get_surprise_lambda_visual():
    """Generates a visualization of the Surprise Lambda activation.

    Reads ``model._latest_gate_stats[-1][0, -1]`` when available (0.0
    otherwise), clips it to [0, 1] and draws it as a single horizontal bar
    coloured by threshold (> 0.6 red, > 0.3 amber, else green).

    Returns:
        The matplotlib ``Figure`` holding the bar.
    """
    gate_stats = getattr(model, "_latest_gate_stats", None)
    if gate_stats is not None:
        surprise_val = gate_stats[-1][0, -1].item()
    else:
        surprise_val = 0.0
    surprise_val = np.clip(surprise_val, 0.0, 1.0)

    plt.close('all')
    fig, ax = plt.subplots(figsize=(6, 1.0), dpi=100)
    fig.patch.set_facecolor("#0f0f23")
    ax.set_facecolor("#0f0f23")

    color = '#ef4444' if surprise_val > 0.6 else '#f59e0b' if surprise_val > 0.3 else '#10b981'

    # Foreground bar is the value; the faint full-width bar is the track.
    ax.barh([0], [surprise_val], color=color, height=0.6, alpha=0.9, zorder=2)
    ax.barh([0], [1.0], color='white', height=0.6, alpha=0.1, zorder=0)

    ax.set_xlim(0, 1)
    ax.set_ylim(-0.5, 0.5)
    ax.set_yticks([])
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1.0])
    ax.tick_params(colors="#707080", labelsize=8)
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_title(
        f"DSRN Surprise Signal (λ_t): {surprise_val:.4f}",
        color="#e0e0f0",
        fontsize=9,
        fontweight="bold",
    )

    plt.tight_layout()
    return fig


# ================================================================
# INFERENCE LOGIC
# ================================================================
def transcribe_audio(audio_path):
    """Transcribe an audio file with Whisper.

    Args:
        audio_path: Path of the recorded audio file, or ``None``.

    Returns:
        The stripped transcript; ``""`` when Whisper is unavailable or no
        audio was supplied; an ``"Error transcribing: ..."`` message when
        transcription raises.
    """
    if not HAS_WHISPER or audio_path is None:
        return ""
    try:
        result = whisper_model.transcribe(audio_path)
        return result["text"].strip()
    except Exception as e:
        # Surface the failure in the textbox instead of crashing the handler.
        return f"Error transcribing: {e}"


def _format_tps(chunk_count, start_time):
    """Format the streaming throughput since ``start_time`` for display."""
    elapsed = time.time() - start_time
    tps = chunk_count / elapsed if elapsed > 0 else 0
    return f"⚡ {tps:.1f} TPS"


def classify_intent(utterance, locale):
    """Stream the predicted intent for ``utterance`` with live telemetry.

    Generation runs in a background thread feeding a ``TextIteratorStreamer``
    while this generator consumes the stream and yields UI updates.

    Args:
        utterance: The user request to classify. ``None``, empty or
            whitespace-only input yields a single blank result.
        locale: Selected locale; accepted for the UI wiring but not used by
            the prompt.

    Yields:
        Tuples of ``(response_text, heatmap_figure, surprise_figure,
        tps_text)`` matching the Gradio outputs.
    """
    if not utterance or not utterance.strip():
        yield "", None, None, "0.0 TPS"
        return

    messages = [
        {
            "role": "system",
            "content": "You are a helpful multilingual intent classification assistant.",
        },
        {"role": "user", "content": f"Classify the intent of the following request: {utterance}"},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {
        "input_ids": inputs.input_ids,
        "max_new_tokens": 15,
        "do_sample": False,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "output_dsrn_telemetry": True,
    }

    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    start_time = time.time()
    tokens = 0  # counts streamed text chunks, used as the token count for TPS
    full_response = ""

    for chunk in streamer:
        full_response += chunk
        tokens += 1

        # Plotting is heavy, refresh every few tokens or at the end
        if tokens % 2 == 0:
            yield (
                full_response.strip(),
                get_dsrn_state_heatmap(),
                get_surprise_lambda_visual(),
                _format_tps(tokens, start_time),
            )

    # The stream is exhausted, so generation is finishing; reap the worker.
    thread.join()

    # Final yield to ensure plot is up to date
    yield (
        full_response.strip(),
        get_dsrn_state_heatmap(),
        get_surprise_lambda_visual(),
        _format_tps(tokens, start_time),
    )


# ================================================================
# UI CONSTRUCTION
# ================================================================
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="Echo-DSRN 114M SmolTools: Multilingual Intent Classifier",
) as demo:
    gr.Markdown(
        f"""
        # 🎙️ Echo-DSRN 114M SmolTools: Multilingual Intent Classifier
        ### High-Fidelity 1-Shot Inference & Observability Cockpit

        This dashboard provides real-time intent classification across 60 categories.

        **🚀 Model Lineage**:
        - **Base**: `{BASE_MODEL}`
        - **Adapter**: `{ADAPTER_PATH}`

        **⚠️ Limitations**: While highly optimized for edge-routing, accuracy varies by locale. The 114M model may occasionally confuse overlapping semantic clusters (e.g., *calendar* vs. *alarm*) in low-context utterances.
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                locale = gr.Dropdown(
                    choices=list(EXAMPLES.keys()),
                    value="en-US",
                    label="Target Locale",
                    info="Select target language for intent examples.",
                )
                examples_dropdown = gr.Dropdown(
                    choices=EXAMPLES["en-US"],
                    value=EXAMPLES["en-US"][0],
                    label="Example Utterances",
                    info="Pre-compiled samples from the Amazon MASSIVE validation set.",
                )
                # Voice widgets are hidden when Whisper could not be loaded.
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="filepath",
                    label="Voice Command (Experimental)",
                    visible=HAS_WHISPER,
                )
                transcribe_btn = gr.Button(
                    "🎤 Transcribe Audio", variant="secondary", visible=HAS_WHISPER
                )
                input_text = gr.Textbox(
                    value=EXAMPLES["en-US"][0],
                    placeholder="Enter a request in any supported language...",
                    label="Utterance",
                    lines=2,
                )
                with gr.Row():
                    classify_btn = gr.Button("🚀 Classify Intent", variant="primary")
                    reset_btn = gr.Button("🔄 Reset")

            with gr.Group():
                gr.Markdown("### 🏷️ Predicted Intent")
                output_label = gr.Label(label="", show_label=False)
                tps_stats = gr.Markdown("**Telemetry:** 0.0 TPS")

        with gr.Column(scale=3):
            gr.Markdown("### 🧠 DSRN Core Observability")
            surprise_plot = gr.Plot(label="Surprise Bar")
            heatmap_plot = gr.Plot(label="State Heatmap")

            with gr.Accordion("📚 Reference: Intent Registry (60 Classes)", open=False):
                gr.Markdown(", ".join([f"`{i}`" for i in INTENTS]))

    # --- EVENT HANDLERS ---
    def update_examples(loc):
        """Swap the example dropdown to the utterances of locale ``loc``."""
        return gr.update(choices=EXAMPLES[loc], value=EXAMPLES[loc][0])

    def handle_reset():
        """Return default values for every output wired to the reset button."""
        return None, "", "en-US", EXAMPLES["en-US"][0], None, None, None, "**Telemetry:** 0.0 TPS"

    locale.change(update_examples, locale, examples_dropdown)
    # Picking an example copies it into the utterance textbox.
    examples_dropdown.change(lambda x: x, examples_dropdown, input_text)

    transcribe_btn.click(transcribe_audio, inputs=[audio_input], outputs=[input_text])

    classify_btn.click(
        classify_intent,
        inputs=[input_text, locale],
        outputs=[output_label, heatmap_plot, surprise_plot, tps_stats],
    )

    reset_btn.click(
        handle_reset,
        outputs=[
            audio_input,
            input_text,
            locale,
            examples_dropdown,
            output_label,
            heatmap_plot,
            surprise_plot,
            tps_stats,
        ],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)