# adaptive-k-demo / app.py
# Tuo Nome
# Add Important Notes section to help users understand results
# 8075ff4
# Raw
# History Blame Contribute Delete
# 13.3 kB
"""
Adaptive-K Demo for HuggingFace Spaces
Interactive demonstration of entropy-guided dynamic expert selection.
This shows the CONCEPT without requiring the full SDK license.
"""
import re
import zlib

import gradio as gr
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are only rendered to images
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
# Display names of the eight simulated experts. Position i labels expert i,
# both in the router's probability vector and on the chart's y-axis.
EXPERT_NAMES = [
    "Math & Logic",           # 0
    "Language & Grammar",     # 1
    "Code & Technical",       # 2
    "Creative Writing",       # 3
    "Facts & Knowledge",      # 4
    "Reasoning & Analysis",   # 5
    "Translation",            # 6
    "Summarization",          # 7
]
def calculate_entropy(probs):
    """Calculate the Shannon entropy, in bits, of a probability distribution.

    Zero entries are dropped before taking logarithms, following the
    0 * log2(0) = 0 convention and avoiding log2(0).

    Args:
        probs: Sequence or array of probabilities, expected to sum to 1.

    Returns:
        The entropy as a plain float. A one-hot or empty input yields 0.0
        (never -0.0, which would be formatted as "-0.000").
    """
    p = np.asarray(probs, dtype=float).ravel()
    p = p[p > 0]
    entropy = float(-np.sum(p * np.log2(p)))
    # Negating a zero sum produces -0.0; normalise it to +0.0.
    return entropy if entropy != 0.0 else 0.0
# Keyword rules for the simulated router:
#   (expert index into EXPERT_NAMES, score boost, pattern searched in the
#    lower-cased prompt)
# Word keywords are anchored at the start of a word, so "api" does not fire on
# "capital"/"rapid" and "write" does not fire on "rewrite". "sum" must be a
# whole word, so "summarize"/"summary" are not routed to Math. Arithmetic
# operators only count next to a digit, so hyphenated words such as
# "state-of-the-art" do not look like math.
_ROUTER_RULES = (
    (0, 0.6, re.compile(r"\b(?:calculate|compute|math|equation|average)|\bsum\b|\d\s*[-+*/=]|[-+*/=]\s*\d")),
    (1, 0.5, re.compile(r"\b(?:grammar|spelling|correct|rewrite|rephrase)")),
    (2, 0.7, re.compile(r"\b(?:code|python|function|debug|api|programming)|\b(?:def|class)\s")),
    (3, 0.5, re.compile(r"\b(?:write|story|poem|creative|imagine|fiction)")),
    (4, 0.5, re.compile(r"\b(?:what is|who is|when did|explain|define|history)")),
    (5, 0.6, re.compile(r"\b(?:why|how does|analyze|compare|evaluate|reason)")),
    (6, 0.8, re.compile(r"\b(?:translate|italian|french|spanish|german|chinese)")),
    (7, 0.5, re.compile(r"\b(?:summarize|summary|tldr|brief|shorten)")),
)


def simulate_router(text: str, model_type: str = "Mixtral 8x7B", temperature: float = 0.1):
    """
    Simulate MoE router behavior based on input characteristics.

    In a real deployment this would use actual model hidden states. Here every
    expert starts at a score of 0.1. Each matching keyword rule adds its boost.
    Longer prompts get extra random jitter of up to +0.3 per expert.

    The scores are turned into probabilities with a temperature softmax, as
    MoE routers typically do with their logits. Plain linear normalisation is
    not used: the 0.1 floor on all eight experts would keep the entropy above
    about 2 bits for every prompt. That is above every K=1 threshold the UI
    allows, so K=1 could never be selected.

    Args:
        text: The prompt to route.
        model_type: Accepted for API compatibility; currently unused.
        temperature: Softmax temperature. Smaller values give a more confident
            (lower-entropy) router. Must be positive.

    Returns:
        np.ndarray of 8 probabilities (one per expert) summing to 1.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")

    text_lower = text.lower()
    scores = np.full(8, 0.1)
    for expert, boost, pattern in _ROUTER_RULES:
        if pattern.search(text_lower):
            scores[expert] += boost

    # Longer prompts get more jitter. The generator is seeded from the prompt
    # itself; crc32 is stable across processes, unlike the salted built-in
    # hash(). So the same prompt always routes the same way, and cached
    # examples match live runs.
    complexity_bonus = min(len(text) / 500, 0.3)
    rng = np.random.default_rng(zlib.crc32(text.encode("utf-8", "surrogatepass")))
    scores += rng.uniform(0, complexity_bonus, scores.size)

    # Numerically stable softmax: shifting by the max leaves the result unchanged.
    logits = (scores - scores.max()) / temperature
    weights = np.exp(logits)
    return weights / weights.sum()
def adaptive_k_select(probs: np.ndarray, thresholds=(0.6, 1.2), k_values=(1, 2, 4)):
    """
    Select K based on entropy thresholds.

    Low entropy → confident router → fewer experts needed.
    High entropy → uncertain router → more experts needed.

    Args:
        probs: Router probabilities over the experts.
        thresholds: (low, high) entropy cut-offs in bits. Entropy below
            ``thresholds[0]`` selects ``k_values[0]``. Entropy below
            ``thresholds[1]`` selects ``k_values[1]``. Anything else selects
            ``k_values[2]``. Any indexable sequence (list or tuple) works.
        k_values: The three candidate K values, smallest first.

    The defaults are immutable tuples rather than lists, so no caller can
    mutate a default shared across calls.

    Returns:
        (k, entropy): the chosen K and the router entropy in bits.
    """
    entropy = calculate_entropy(probs)
    low, high = thresholds[0], thresholds[1]
    if entropy < low:
        k = k_values[0]
    elif entropy < high:
        k = k_values[1]
    else:
        k = k_values[2]
    return k, entropy
# Baseline the compute savings are reported against: a router that always
# runs this fixed number of experts.
BASELINE_K = 4


def _style_dark_axes(ax, background, *, show_left=True):
    """Apply the dark chart theme: background, white ticks/spines, no top/right spines."""
    ax.set_facecolor(background)
    ax.tick_params(colors='white')
    ax.spines['bottom'].set_color('white')
    if show_left:
        ax.spines['left'].set_color('white')
    else:
        ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def create_visualization(probs, k, entropy, selected_indices, thresholds=(0.6, 1.2)):
    """Create a three-panel figure visualizing one routing decision.

    The panels are:
      1. routing weight per expert, with the selected experts highlighted;
      2. the router entropy on a 0-3 bit gauge, with the K thresholds;
      3. compute cost of adaptive K relative to a fixed K=BASELINE_K.

    The figure is built with ``matplotlib.figure.Figure`` rather than pyplot.
    Nothing is left open in pyplot's global figure registry after each call.
    Layout is applied to this figure explicitly, not to pyplot's "current"
    figure.

    Args:
        probs: Routing probability per expert, in EXPERT_NAMES order.
        k: Number of experts selected.
        entropy: Router entropy in bits.
        selected_indices: Indices of the selected experts.
        thresholds: (low, high) entropy thresholds drawn on the gauge.

    Returns:
        A matplotlib ``Figure``.
    """
    background = '#1a1a2e'
    fig = Figure(figsize=(14, 4))
    fig.patch.set_facecolor(background)
    ax1, ax2, ax3 = fig.subplots(1, 3)

    # Panel 1: expert routing weights, selected experts highlighted.
    selected = {int(i) for i in selected_indices}
    colors = ['#6366f1' if i in selected else '#334155' for i in range(len(probs))]
    _style_dark_axes(ax1, background)
    ax1.barh(EXPERT_NAMES, probs, color=colors)
    ax1.set_xlabel('Routing Probability', color='white')
    ax1.set_title('Expert Routing Weights', color='white', fontsize=12, fontweight='bold')
    ax1.set_xlim(0, max(probs) * 1.2)

    # Panel 2: entropy gauge from 0 to log2(8) = 3 bits (uniform over 8 experts).
    gauge_max = 3.0
    low, high = thresholds
    _style_dark_axes(ax2, background, show_left=False)
    levels = np.linspace(0, gauge_max, 100)
    step = levels[1] - levels[0]
    gradient = plt.cm.RdYlGn_r(levels / gauge_max)
    for left, color in zip(levels[:-1], gradient[:-1]):
        ax2.barh(0, step, left=left, color=color, height=0.3)
    ax2.axvline(entropy, color='white', linewidth=3, label=f'Current: {entropy:.2f}')
    ax2.axvline(low, color='#22c55e', linestyle='--', linewidth=2, alpha=0.7)
    ax2.axvline(high, color='#f59e0b', linestyle='--', linewidth=2, alpha=0.7)
    ax2.set_xlim(0, gauge_max)
    ax2.set_ylim(-0.5, 0.5)
    ax2.set_xlabel('Entropy (bits)', color='white')
    ax2.set_title('Router Entropy', color='white', fontsize=12, fontweight='bold')
    ax2.set_yticks([])
    # Centre each K label in its band so the labels follow the chosen thresholds.
    ax2.text(low / 2, 0.4, 'K=1', color='#22c55e', fontsize=10, ha='center')
    ax2.text((low + high) / 2, 0.4, 'K=2', color='#f59e0b', fontsize=10, ha='center')
    ax2.text((high + gauge_max) / 2, 0.4, 'K=4', color='#ef4444', fontsize=10, ha='center')

    # Panel 3: relative compute cost versus the fixed-K baseline.
    savings = (1 - k / BASELINE_K) * 100
    costs = [100, 100 - savings]
    _style_dark_axes(ax3, background)
    bars = ax3.bar([f'Fixed K={BASELINE_K}', f'Adaptive K={k}'], costs,
                   color=['#ef4444', '#22c55e'])
    ax3.set_ylabel('Relative Compute Cost (%)', color='white')
    ax3.set_title(f'Compute Savings: {savings:.0f}%', color='white', fontsize=12, fontweight='bold')
    ax3.set_ylim(0, 120)
    for bar, value in zip(bars, costs):
        ax3.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 3,
                 f'{value:.0f}%', ha='center', color='white', fontsize=11)

    fig.tight_layout()
    return fig


def process_input(text: str, threshold_low: float, threshold_high: float):
    """Gradio callback: route *text*, pick K and build the four UI outputs.

    Args:
        text: The user's prompt. Empty, whitespace-only or None prompts
            short-circuit with a hint message.
        threshold_low: Entropy threshold between K=1 and K=2 (bits).
        threshold_high: Entropy threshold between K=2 and K=4 (bits).

    Returns:
        (figure, summary markdown, interpretation markdown, debug string).
        For blank input: (None, hint message, "", "").
    """
    if not text or not text.strip():
        return None, "Please enter some text to analyze.", "", ""

    # The two sliders have overlapping ranges, so the "low" one can end up
    # above the "high" one. In that order the K=2 band would be unreachable.
    low, high = sorted((threshold_low, threshold_high))

    probs = simulate_router(text)
    k, entropy = adaptive_k_select(probs, [low, high])

    # Top-k experts, highest probability first.
    selected_indices = np.argsort(probs)[-k:][::-1]
    selected_experts = [EXPERT_NAMES[i] for i in selected_indices]

    fig = create_visualization(probs, k, entropy, selected_indices, (low, high))

    savings = (1 - k / BASELINE_K) * 100
    summary = "\n\n".join([
        f"**Entropy:** {entropy:.3f} bits",
        f"**Selected K:** {k} experts",
        f"**Active Experts:** {', '.join(selected_experts)}",
        f"**Compute Savings:** {savings:.0f}% vs fixed K={BASELINE_K}",
    ])

    # Same thresholds and comparisons as adaptive_k_select, so the text agrees with K.
    if entropy < low:
        interpretation = "🟢 **Low Complexity** - Router is confident. Only 1 expert needed for this straightforward task."
    elif entropy < high:
        interpretation = "🟡 **Medium Complexity** - Some uncertainty. 2 experts provide good coverage."
    else:
        interpretation = "🔴 **High Complexity** - Router is uncertain. 4 experts ensure comprehensive handling."

    return fig, summary, interpretation, f"Router probabilities: {probs.round(3).tolist()}"
# Example prompts for the gr.Examples gallery. Each prompt is paired with the
# default (low, high) entropy thresholds, so picking an example also fills in
# the two threshold sliders.
EXAMPLES = [
    [prompt, 0.6, 1.2]
    for prompt in (
        "What is 2+2?",
        "Translate 'Hello, how are you?' to Italian",
        "Write a Python function to sort a list using quicksort algorithm",
        "Explain the socioeconomic implications of artificial intelligence on global labor markets, considering both developed and developing nations",
        "Summarize the main points",
        "Why did the Roman Empire fall? Analyze the political, economic, and military factors.",
    )
]
# Gradio Interface
# User-visible labels and Markdown below use the intended Unicode characters
# (→, Σ, ᵢ, ₂, ≤, ≥, ©, ·, emoji); they had been mis-decoded into mojibake.
# Markdown sections are separated by blank lines. Without them, each bold
# heading or closing paragraph would render as a continuation of the
# preceding list item.
with gr.Blocks(
    title="Adaptive-K Demo",
    theme=gr.themes.Base(
        primary_hue="indigo",
        secondary_hue="purple",
        neutral_hue="slate",
    ),
    css="""
.gradio-container { max-width: 1200px !important; }
.main-title { text-align: center; margin-bottom: 1rem; }
.subtitle { text-align: center; color: #9ca3af; margin-bottom: 2rem; }
""",
) as demo:
    # Page header.
    gr.Markdown(
        """
# 🧠 Adaptive-K: Dynamic Expert Selection
<p style="text-align: center; color: #9ca3af; font-size: 1.1rem;">
Entropy-guided routing for Mixture-of-Experts models. Reduce inference costs by 30-50%.
</p>
""",
        elem_classes=["main-title"],
    )

    # Inputs (prompt, thresholds, button) on the left; text results on the right.
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter a prompt to analyze routing behavior...",
                lines=3,
            )
            with gr.Row():
                threshold_low = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.6, step=0.1,
                    label="Low Entropy Threshold (K=1 → K=2)"
                )
                threshold_high = gr.Slider(
                    minimum=0.5, maximum=2.5, value=1.2, step=0.1,
                    label="High Entropy Threshold (K=2 → K=4)"
                )
            analyze_btn = gr.Button("🔍 Analyze Routing", variant="primary", size="lg")
        with gr.Column(scale=1):
            summary_output = gr.Markdown(label="Results")
            interpretation_output = gr.Markdown(label="Interpretation")

    plot_output = gr.Plot(label="Visualization")

    with gr.Accordion("🔧 Technical Details", open=False):
        debug_output = gr.Textbox(label="Router Probabilities", interactive=False)
        gr.Markdown("""
### How Adaptive-K Works

1. **Router Forward Pass**: Input hidden states are passed through the MoE router
2. **Entropy Calculation**: Shannon entropy H = -Σ pᵢ log₂(pᵢ) measures uncertainty
3. **K Selection**:
   - H < 0.6 → K=1 (confident, simple task)
   - 0.6 ≤ H < 1.2 → K=2 (moderate complexity)
   - H ≥ 1.2 → K=4 (uncertain, complex task)
4. **Expert Execution**: Only top-K experts are computed

### Why It Works

Router entropy correlates with input complexity:

- **Simple inputs** (math, translation) → Router "knows" which expert to use → Low entropy
- **Complex inputs** (multi-domain reasoning) → Router spreads probability → High entropy

By using entropy as a complexity signal, Adaptive-K allocates compute where it's needed.
""")

    with gr.Accordion("📝 Important Notes", open=True):
        gr.Markdown("""
### Understanding the Results

**💡 About Compute Savings:**

- Savings are calculated vs **fixed K=4** (the baseline)
- When K=4 is selected → **0% savings** (this is expected! Complex queries need all experts)
- When K=1 is selected → **75% savings** (simple queries save the most)
- Try "What is 2+2?" to see maximum savings!

**🎯 About Expert Selection:**

- This demo uses **keyword heuristics** to simulate router behavior
- In production with real MoE models (Mixtral, DeepSeek), the neural router makes more accurate selections
- The concept remains the same: entropy guides K selection

**🧪 Try These Examples:**

- Simple: "What is 2+2?" → Low entropy → K=1 → **75% savings**
- Medium: "Translate hello to Italian" → Medium entropy → K=2 → **50% savings**
- Complex: Multi-domain questions → High entropy → K=4 → **0% savings** (but accuracy preserved!)
""")

    # Example outputs are precomputed with process_input at startup (cache_examples=True).
    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_input, threshold_low, threshold_high],
        outputs=[plot_output, summary_output, interpretation_output, debug_output],
        fn=process_input,
        cache_examples=True,
    )

    analyze_btn.click(
        fn=process_input,
        inputs=[text_input, threshold_low, threshold_high],
        outputs=[plot_output, summary_output, interpretation_output, debug_output],
    )

    # Footer with project links.
    gr.Markdown("""
---
<p style="text-align: center; color: #6b7280;">
📄 <a href="https://github.com/Gabrobals/sbm-efficient/blob/master/Entropy_Guided_Dynamic_Expert_Selection_in_Mixture_of_Experts_Models.pdf">Paper</a> |
📖 <a href="https://adaptive-k.vertexdata.it/whitepaper.html">Whitepaper</a> |
💻 <a href="https://github.com/Gabrobals/sbm-efficient">GitHub</a> |
📦 <a href="https://pypi.org/project/adaptive-k-routing/">PyPI</a>
<br><br>
© 2026 Vertex Data · MIT License
</p>
""")

if __name__ == "__main__":
    demo.launch()