Spaces:
Sleeping
Sleeping
| """ | |
| Adaptive-K Demo for HuggingFace Spaces | |
| Interactive demonstration of entropy-guided dynamic expert selection. | |
| This shows the CONCEPT without requiring the full SDK license. | |
| """ | |
import re

import gradio as gr
import matplotlib
import numpy as np

# Select the non-interactive Agg backend before pyplot is imported; figures are
# only rendered to images for the web UI.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
from matplotlib.figure import Figure  # noqa: E402
# Display names of the eight simulated experts. A name's position in this list
# is the expert index used for the router's score/probability vector
# (e.g. index 0 is boosted by math keywords, index 6 by translation keywords).
EXPERT_NAMES = [
    "Math & Logic",          # 0
    "Language & Grammar",    # 1
    "Code & Technical",      # 2
    "Creative Writing",      # 3
    "Facts & Knowledge",     # 4
    "Reasoning & Analysis",  # 5
    "Translation",           # 6
    "Summarization",         # 7
]
def calculate_entropy(probs):
    """Return the Shannon entropy, in bits, of a probability distribution.

    Entries that are not strictly positive are discarded before the logarithm
    is taken, so zero-probability experts contribute nothing (the usual
    0 * log2(0) = 0 convention) instead of producing NaN.
    """
    p = np.array(probs)
    support = p[p > 0]
    bits = np.log2(support)
    return -(support * bits).sum()
# Arithmetic symbols that signal a math prompt. "+", "*" and "=" count anywhere;
# "-" and "/" only between digits, so hyphenated words ("well-known"), "and/or"
# and URLs are not mistaken for arithmetic.
_MATH_SYMBOL_RE = re.compile(r"[+*=]|\d\s*[-/]\s*\d")
# "sum" as a whole word only, so "summarize"/"summary" do not also trigger the
# math expert.
_SUM_WORD_RE = re.compile(r"\bsum\b")


def simulate_router(text: str, model_type: str = "Mixtral 8x7B", *, sharpness: float = 3.0):
    """
    Simulate MoE router behavior based on input characteristics.

    In real deployment, this would use actual model hidden states. Here each of
    the eight experts (ordered as in EXPERT_NAMES) starts from a floor score of
    0.1 and receives a fixed boost when its trigger keywords appear in the
    lower-cased text. Every expert then gets uniform random noise whose upper
    bound grows with text length (capped at 0.3 for texts of 150+ characters).

    Scores are turned into probabilities as ``scores ** sharpness`` normalised
    to sum to 1. With plain linear normalisation (``sharpness=1``) the 0.1
    floor on the non-matching experts keeps the entropy above roughly 2.2 bits
    for every input, so the default K=1 / K=2 thresholds (0.6 / 1.2 bits)
    could never be reached. Cubing restores the full 0-3 bit range: one clear
    topic gives low entropy, two topics medium, and no topic high.

    Args:
        text: The prompt to route.
        model_type: Accepted for API compatibility; currently unused.
        sharpness: Exponent applied to the scores before normalising. Must be
            positive; 1.0 reproduces plain linear normalisation.

    Returns:
        A length-8 numpy array of routing probabilities summing to 1.

    Raises:
        ValueError: If ``sharpness`` is not positive.
    """
    if sharpness <= 0:
        raise ValueError(f"sharpness must be positive, got {sharpness!r}")
    text_lower = text.lower()

    def mentions(*keywords):
        # Plain substring test, as in the original heuristics.
        return any(word in text_lower for word in keywords)

    # Heuristic-based routing simulation
    scores = np.ones(8) * 0.1
    # Math/Logic patterns
    if (mentions('calculate', 'compute', 'math', 'equation', 'average')
            or _SUM_WORD_RE.search(text_lower)
            or _MATH_SYMBOL_RE.search(text_lower)):
        scores[0] += 0.6
    # Language/Grammar patterns
    if mentions('grammar', 'spelling', 'correct', 'rewrite', 'rephrase'):
        scores[1] += 0.5
    # Code/Technical patterns
    if mentions('code', 'python', 'function', 'debug', 'api', 'programming', 'def ', 'class '):
        scores[2] += 0.7
    # Creative patterns
    if mentions('write', 'story', 'poem', 'creative', 'imagine', 'fiction'):
        scores[3] += 0.5
    # Knowledge patterns
    if mentions('what is', 'who is', 'when did', 'explain', 'define', 'history'):
        scores[4] += 0.5
    # Reasoning patterns
    if mentions('why', 'how does', 'analyze', 'compare', 'evaluate', 'reason'):
        scores[5] += 0.6
    # Translation patterns
    if mentions('translate', 'italian', 'french', 'spanish', 'german', 'chinese'):
        scores[6] += 0.8
    # Summarization patterns
    if mentions('summarize', 'summary', 'tldr', 'brief', 'shorten'):
        scores[7] += 0.5
    # Add complexity based on length: longer prompts get more noise on every
    # expert, which makes the routing less decisive.
    complexity_bonus = min(len(text) / 500, 0.3)
    scores += np.random.uniform(0, complexity_bonus, 8)
    # Sharpen, then normalise to probabilities.
    sharpened = scores ** sharpness
    return sharpened / sharpened.sum()
def adaptive_k_select(probs: np.ndarray, thresholds=(0.6, 1.2), k_values=(1, 2, 4)):
    """
    Select K based on entropy thresholds.

    Low entropy → confident router → fewer experts needed
    High entropy → uncertain router → more experts needed

    Args:
        probs: Router probability distribution over the experts.
        thresholds: ``(low, high)`` entropy thresholds in bits (any indexable
            sequence; lists work too). Entropy below ``low`` selects
            ``k_values[0]``, below ``high`` selects ``k_values[1]``, anything
            else selects ``k_values[2]``. ``low`` is checked first, so if the
            two are crossed (low > high) the middle K is never chosen.
        k_values: Number of experts for the low / medium / high entropy bands.
            Immutable tuple defaults avoid the shared-mutable-default pitfall.

    Returns:
        ``(k, entropy)``: the selected K and the router entropy in bits.
    """
    entropy = calculate_entropy(probs)
    if entropy < thresholds[0]:
        k = k_values[0]
    elif entropy < thresholds[1]:
        k = k_values[1]
    else:
        k = k_values[2]
    return k, entropy
_PLOT_BACKGROUND = '#1a1a2e'


def _style_dark_axes(ax, show_left_spine=True):
    """Apply the demo's dark theme to one axes: dark face, white ticks/spines, no top/right spines."""
    ax.set_facecolor(_PLOT_BACKGROUND)
    ax.tick_params(colors='white')
    ax.spines['bottom'].set_color('white')
    if show_left_spine:
        ax.spines['left'].set_color('white')
    else:
        ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def create_visualization(probs, k, entropy, selected_indices,
                         thresholds=(0.6, 1.2), k_values=(1, 2, 4)):
    """Create visualization of expert selection.

    Three panels: routing probability per expert (selected experts
    highlighted), an entropy gauge with the current entropy and the K-selection
    thresholds, and the compute cost of adaptive K versus the largest fixed K.

    Args:
        probs: Routing probabilities, one per entry of EXPERT_NAMES.
        k: Number of experts selected for this input.
        entropy: Router entropy in bits.
        selected_indices: Indices of the selected experts.
        thresholds: ``(low, high)`` entropy thresholds in bits that were used to
            pick ``k``; drawn as dashed lines on the gauge. Defaults match the
            demo's default slider values.
        k_values: K used below, between and above the thresholds; the largest
            is the fixed-K baseline of the cost panel.

    Returns:
        A ``matplotlib.figure.Figure``. It is built with the object-oriented
        API instead of ``plt.subplots`` because pyplot keeps every figure it
        creates alive until ``plt.close``; in a long-running server that would
        leak one figure per request.
    """
    probs = np.asarray(probs)
    low, high = thresholds
    fixed_k = max(k_values)
    # Entropy of a uniform router: the top of the gauge (3 bits for 8 experts).
    max_entropy = float(np.log2(len(probs)))
    selected = {int(i) for i in selected_indices}

    fig = Figure(figsize=(14, 4))
    fig.patch.set_facecolor(_PLOT_BACKGROUND)
    ax1, ax2, ax3 = fig.subplots(1, 3)

    # Expert probabilities
    _style_dark_axes(ax1)
    colors = ['#6366f1' if i in selected else '#334155' for i in range(len(probs))]
    ax1.barh(EXPERT_NAMES, probs, color=colors)
    ax1.set_xlabel('Routing Probability', color='white')
    ax1.set_title('Expert Routing Weights', color='white', fontsize=12, fontweight='bold')
    ax1.set_xlim(0, max(probs) * 1.2)

    # Entropy gauge: a contiguous green-to-red strip (one segment per step,
    # width equal to the step so there are no gaps between segments).
    _style_dark_axes(ax2, show_left_spine=False)
    edges = np.linspace(0, max_entropy, 100)
    ax2.barh(np.zeros(len(edges) - 1), np.diff(edges), left=edges[:-1],
             color=plt.cm.RdYlGn_r(edges[:-1] / max_entropy), height=0.3)
    ax2.axvline(entropy, color='white', linewidth=3, label=f'Current: {entropy:.2f}')
    ax2.axvline(low, color='#22c55e', linestyle='--', linewidth=2, alpha=0.7)
    ax2.axvline(high, color='#f59e0b', linestyle='--', linewidth=2, alpha=0.7)
    ax2.set_xlim(0, max_entropy)
    ax2.set_ylim(-0.5, 0.5)
    ax2.set_yticks([])
    ax2.set_xlabel('Entropy (bits)', color='white')
    ax2.set_title('Router Entropy', color='white', fontsize=12, fontweight='bold')
    # Label each K band at its midpoint, following the actual thresholds.
    for x, band_k, band_color in (
        (low / 2, k_values[0], '#22c55e'),
        ((low + high) / 2, k_values[1], '#f59e0b'),
        ((high + max_entropy) / 2, k_values[2], '#ef4444'),
    ):
        ax2.text(x, 0.4, f'K={band_k}', color=band_color, fontsize=10, ha='center')

    # Cost comparison against always running the largest K.
    _style_dark_axes(ax3)
    savings = (1 - k / fixed_k) * 100
    costs = [100, 100 - savings]
    bars = ax3.bar([f'Fixed K={fixed_k}', f'Adaptive K={k}'], costs,
                   color=['#ef4444', '#22c55e'])
    ax3.set_ylabel('Relative Compute Cost (%)', color='white')
    ax3.set_title(f'Compute Savings: {savings:.0f}%', color='white', fontsize=12, fontweight='bold')
    ax3.set_ylim(0, 120)
    # Add value labels on bars
    for bar, val in zip(bars, costs):
        ax3.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 3,
                 f'{val:.0f}%', ha='center', color='white', fontsize=11)

    fig.tight_layout()
    return fig


def process_input(text: str, threshold_low: float, threshold_high: float):
    """Main processing function: route ``text`` and build every UI output.

    Returns:
        ``(figure, summary_markdown, interpretation_markdown, debug_text)``; for
        blank input the figure is ``None`` and the summary is a prompt to enter
        text.
    """
    if not text or not text.strip():
        return None, "Please enter some text to analyze.", "", ""

    # The two sliders are independent and can be crossed (low > high), which
    # would make K=2 unreachable; treat the smaller value as the lower bound.
    threshold_low, threshold_high = sorted((threshold_low, threshold_high))
    thresholds = [threshold_low, threshold_high]

    # Simulate routing
    probs = simulate_router(text)
    # Adaptive K selection
    k, entropy = adaptive_k_select(probs, thresholds)
    # Top-k experts, most probable first
    selected_indices = np.argsort(probs)[-k:][::-1]
    selected_experts = [EXPERT_NAMES[i] for i in selected_indices]

    fig = create_visualization(probs, k, entropy, selected_indices, thresholds=thresholds)

    savings = (1 - k / 4) * 100  # baseline: fixed K=4
    summary = "\n\n".join([
        f"**Entropy:** {entropy:.3f} bits",
        f"**Selected K:** {k} experts",
        f"**Active Experts:** {', '.join(selected_experts)}",
        f"**Compute Savings:** {savings:.0f}% vs fixed K=4",
    ])

    # Interpretation (same bands adaptive_k_select uses)
    if entropy < threshold_low:
        interpretation = "🟢 **Low Complexity** - Router is confident. Only 1 expert needed for this straightforward task."
    elif entropy < threshold_high:
        interpretation = "🟡 **Medium Complexity** - Some uncertainty. 2 experts provide good coverage."
    else:
        interpretation = "🔴 **High Complexity** - Router is uncertain. 4 experts ensure comprehensive handling."
    return fig, summary, interpretation, f"Router probabilities: {probs.round(3).tolist()}"
# Example inputs. Each prompt is paired with the default (low, high) entropy
# thresholds, because the examples feed the text box and both threshold sliders.
_EXAMPLE_THRESHOLDS = (0.6, 1.2)
_EXAMPLE_PROMPTS = (
    "What is 2+2?",
    "Translate 'Hello, how are you?' to Italian",
    "Write a Python function to sort a list using quicksort algorithm",
    "Explain the socioeconomic implications of artificial intelligence on global labor markets, considering both developed and developing nations",
    "Summarize the main points",
    "Why did the Roman Empire fall? Analyze the political, economic, and military factors.",
)
EXAMPLES = [[prompt, *_EXAMPLE_THRESHOLDS] for prompt in _EXAMPLE_PROMPTS]
# Gradio Interface
with gr.Blocks(
    title="Adaptive-K Demo",
    theme=gr.themes.Base(
        primary_hue="indigo",
        secondary_hue="purple",
        neutral_hue="slate",
    ),
    css="""
    .gradio-container { max-width: 1200px !important; }
    .main-title { text-align: center; margin-bottom: 1rem; }
    .subtitle { text-align: center; color: #9ca3af; margin-bottom: 2rem; }
    """,
) as demo:
    gr.Markdown(
        """
        # 🧠 Adaptive-K: Dynamic Expert Selection
        <p style="text-align: center; color: #9ca3af; font-size: 1.1rem;">
        Entropy-guided routing for Mixture-of-Experts models. Reduce inference costs by 30-50%.
        </p>
        """,
        elem_classes=["main-title"],
    )

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter a prompt to analyze routing behavior...",
                lines=3,
            )
            with gr.Row():
                threshold_low = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.6, step=0.1,
                    label="Low Entropy Threshold (K=1 → K=2)",
                )
                threshold_high = gr.Slider(
                    minimum=0.5, maximum=2.5, value=1.2, step=0.1,
                    label="High Entropy Threshold (K=2 → K=4)",
                )
            analyze_btn = gr.Button("🔍 Analyze Routing", variant="primary", size="lg")
        with gr.Column(scale=1):
            summary_output = gr.Markdown(label="Results")
            interpretation_output = gr.Markdown(label="Interpretation")

    plot_output = gr.Plot(label="Visualization")

    with gr.Accordion("🔧 Technical Details", open=False):
        debug_output = gr.Textbox(label="Router Probabilities", interactive=False)
        gr.Markdown(
            """
            ### How Adaptive-K Works
            1. **Router Forward Pass**: Input hidden states are passed through the MoE router
            2. **Entropy Calculation**: Shannon entropy H = -Σ pᵢ log₂(pᵢ) measures uncertainty
            3. **K Selection** (default thresholds; adjustable with the sliders above):
               - H < 0.6 → K=1 (confident, simple task)
               - 0.6 ≤ H < 1.2 → K=2 (moderate complexity)
               - H ≥ 1.2 → K=4 (uncertain, complex task)
            4. **Expert Execution**: Only top-K experts are computed

            ### Why It Works
            Router entropy correlates with input complexity:
            - **Simple inputs** (math, translation) → Router "knows" which expert to use → Low entropy
            - **Complex inputs** (multi-domain reasoning) → Router spreads probability → High entropy

            By using entropy as a complexity signal, Adaptive-K allocates compute where it's needed.
            """
        )

    with gr.Accordion("📝 Important Notes", open=True):
        gr.Markdown(
            """
            ### Understanding the Results

            **💡 About Compute Savings:**
            - Savings are calculated vs **fixed K=4** (the baseline)
            - When K=4 is selected → **0% savings** (this is expected! Complex queries need all experts)
            - When K=2 is selected → **50% savings**
            - When K=1 is selected → **75% savings** (simple queries save the most)
            - Try a single-topic prompt (e.g. a translation request) to see the largest savings!

            **🎯 About Expert Selection:**
            - This demo uses **keyword heuristics** plus a little random noise to simulate router behavior, so repeated runs can differ slightly
            - In production with real MoE models (Mixtral, DeepSeek), the neural router makes more accurate selections
            - The concept remains the same: entropy guides K selection

            **🧪 Try These Examples:**
            - Simple: a prompt about one clear topic → the router favours one expert → Low entropy → smaller K → larger savings
            - Mixed: a prompt touching two topics (e.g. "What is 2+2?" is both math and general knowledge) → more spread → Medium entropy
            - Complex: Multi-domain questions → High entropy → K=4 → **0% savings** (but accuracy preserved!)
            """
        )

    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_input, threshold_low, threshold_high],
        outputs=[plot_output, summary_output, interpretation_output, debug_output],
        fn=process_input,
        cache_examples=True,
    )

    analyze_btn.click(
        fn=process_input,
        inputs=[text_input, threshold_low, threshold_high],
        outputs=[plot_output, summary_output, interpretation_output, debug_output],
    )

    gr.Markdown(
        """
        ---
        <p style="text-align: center; color: #6b7280;">
        📄 <a href="https://github.com/Gabrobals/sbm-efficient/blob/master/Entropy_Guided_Dynamic_Expert_Selection_in_Mixture_of_Experts_Models.pdf">Paper</a> |
        📑 <a href="https://adaptive-k.vertexdata.it/whitepaper.html">Whitepaper</a> |
        💻 <a href="https://github.com/Gabrobals/sbm-efficient">GitHub</a> |
        📦 <a href="https://pypi.org/project/adaptive-k-routing/">PyPI</a>
        <br><br>
        © 2026 Vertex Data · MIT License
        </p>
        """
    )
if __name__ == "__main__":
    # Serve the UI only when the file is executed as a script; importing the
    # module just builds the `demo` Blocks object without launching it.
    demo.launch()