""" demo_viz.py — Spec 2.6: Demo Orchestration — Visualization Scripts Generates ALL visual artifacts required for the hackathon pitch and README. Works in two modes: - Demo mode (--demo): uses synthetic/mock data; produces all plots immediately - Training mode: reads from actual training logs / checkpoints Usage: # Generate all plots with synthetic demo data (no GPU required) python demo_viz.py --demo # Generate from training logs TSV python demo_viz.py --results research/results.tsv # Generate specific plots python demo_viz.py --demo --plots reward_curves before_after safeguards Output: results/reward_curves.png results/case_type_reward.png results/exam_curve.png results/difficulty_radar.png results/tool_heatmap.png results/reward_breakdown.png results/safeguards.png results/before_after.png results/diff_view_classII.gif (delegates to diff_view.py) """ from __future__ import annotations import argparse import os from typing import Any, Dict, List, Optional, Tuple import matplotlib matplotlib.use("Agg") import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np # --------------------------------------------------------------------------- # Output directory # --------------------------------------------------------------------------- RESULTS_DIR = "results" def _ensure_results() -> None: os.makedirs(RESULTS_DIR, exist_ok=True) # --------------------------------------------------------------------------- # Plot 1: Reward Training Curves # --------------------------------------------------------------------------- def plot_reward_curves( steps: List[int], terminal: List[float], occlusion: List[float], strategy: List[float], fmt: List[float], slerp_baseline: float = 0.87, output_path: str = "results/reward_curves.png", ) -> str: """ Plot 5-line reward training curves. Parameters ---------- steps : list of int — training step indices terminal : list of float — terminal reward per step (in [0,1] after normalisation) occlusion : list of float — occlusion composite score per step strategy : list of float — strategy match reward per step fmt : list of float — format compliance per step slerp_baseline : float — constant SLERP reference line output_path : str Returns ------- str — absolute path of saved PNG """ _ensure_results() fig, ax = plt.subplots(figsize=(12, 6)) ax.plot(steps, terminal, color="#2196F3", linewidth=2.5, label="Terminal Reward") ax.plot(steps, occlusion, color="#4CAF50", linewidth=2.5, label="Occlusion Quality") ax.plot(steps, strategy, color="#FF9800", linewidth=2.5, label="Strategy Match") ax.plot(steps, fmt, color="#9C27B0", linewidth=2.5, label="Format Compliance") ax.axhline( y=slerp_baseline, color="gray", linestyle="--", linewidth=1.5, alpha=0.7, label=f"SLERP Baseline ({slerp_baseline:.2f})", ) ax.set_xlabel("Training Step", fontsize=14) ax.set_ylabel("Reward Score", fontsize=14) ax.set_title("OrthoRL: GRPO Training Progress", fontsize=16, fontweight="bold") ax.legend(fontsize=12, loc="upper left") ax.grid(True, alpha=0.3) ax.set_ylim(-0.05, 1.05) plt.tight_layout() out = os.path.abspath(output_path) plt.savefig(out, dpi=150, bbox_inches="tight") plt.close(fig) return out # --------------------------------------------------------------------------- # Plot 2: Per-Case-Type Performance # --------------------------------------------------------------------------- def plot_case_type_comparison( slerp_scores: Dict[str, float], trained_scores: Dict[str, float], output_path: str = "results/case_type_reward.png", ) -> str: """ Grouped bar chart: SLERP vs trained agent by malocclusion class. Parameters ---------- slerp_scores : dict mapping class name → float trained_scores : dict mapping class name → float Returns ------- str — absolute path of saved PNG """ _ensure_results() classes = list(slerp_scores.keys()) x = np.arange(len(classes)) width = 0.35 fig, ax = plt.subplots(figsize=(10, 6)) bars_s = ax.bar( x - width / 2, [slerp_scores[c] for c in classes], width, label="SLERP Baseline", color="#BDBDBD", edgecolor="white", ) bars_t = ax.bar( x + width / 2, [trained_scores[c] for c in classes], width, label="Trained Agent", color="#2196F3", edgecolor="white", ) # Annotate bars for bar in [*bars_s, *bars_t]: ax.annotate( f"{bar.get_height():.2f}", xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()), xytext=(0, 4), textcoords="offset points", ha="center", fontsize=10, ) ax.set_xlabel("Malocclusion Type", fontsize=14) ax.set_ylabel("Terminal Reward", fontsize=14) ax.set_title("Performance by Case Type: SLERP vs Trained Agent", fontsize=14, fontweight="bold") ax.set_xticks(x) ax.set_xticklabels(classes, fontsize=13) ax.legend(fontsize=12) ax.grid(True, alpha=0.3, axis="y") ax.set_ylim(0, 1.05) plt.tight_layout() out = os.path.abspath(output_path) plt.savefig(out, dpi=150, bbox_inches="tight") plt.close(fig) return out # --------------------------------------------------------------------------- # Plot 3: Clinical Exam Progression # --------------------------------------------------------------------------- def plot_exam_curve( checkpoints: List[int], scores: List[int], output_path: str = "results/exam_curve.png", ) -> str: """ Line plot of exam score vs training checkpoint. Parameters ---------- checkpoints : list of int — training steps at which exam was run scores : list of int — exam scores (0-10) Returns ------- str — absolute path of saved PNG """ _ensure_results() fig, ax = plt.subplots(figsize=(10, 6)) ax.plot( checkpoints, scores, "o-", color="#E91E63", linewidth=2.5, markersize=10, markerfacecolor="white", markeredgewidth=2.5, ) ax.axhline(y=2.5, color="gray", linestyle="--", alpha=0.5, label="Random Baseline (2.5/10)") # Annotate before/after if len(scores) >= 2: ax.annotate( f"Before: {scores[0]}/10", xy=(checkpoints[0], scores[0]), xytext=(30, 15), textcoords="offset points", fontsize=13, color="#c62828", fontweight="bold", arrowprops=dict(arrowstyle="->", color="#c62828", lw=1.5), ) ax.annotate( f"After: {scores[-1]}/10", xy=(checkpoints[-1], scores[-1]), xytext=(-90, -30), textcoords="offset points", fontsize=13, color="#1b5e20", fontweight="bold", arrowprops=dict(arrowstyle="->", color="#1b5e20", lw=1.5), ) ax.set_xlabel("Training Step", fontsize=14) ax.set_ylabel("Clinical Exam Score (/ 10)", fontsize=14) ax.set_title( "Orthodontic Knowledge Acquisition During RL Training", fontsize=14, fontweight="bold" ) ax.set_ylim(0, 11) ax.legend(fontsize=12) ax.grid(True, alpha=0.3) plt.tight_layout() out = os.path.abspath(output_path) plt.savefig(out, dpi=150, bbox_inches="tight") plt.close(fig) return out # --------------------------------------------------------------------------- # Plot 4: Difficulty Progression Radar Chart # --------------------------------------------------------------------------- def plot_difficulty_radar( start_params: Dict[str, float], end_params: Dict[str, float], output_path: str = "results/difficulty_radar.png", ) -> str: """ Spider/radar chart showing difficulty axis progression. Parameters ---------- start_params : dict — axis name → normalised value [0, 1] at training start end_params : dict — axis name → normalised value [0, 1] at training end Returns ------- str — absolute path of saved PNG """ _ensure_results() categories = list(start_params.keys()) N = len(categories) values_start = [start_params[c] for c in categories] values_end = [end_params[c] for c in categories] angles = [n / float(N) * 2 * np.pi for n in range(N)] angles += angles[:1] values_start = values_start + values_start[:1] values_end = values_end + values_end[:1] fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True)) ax.plot(angles, values_start, "o-", linewidth=2, color="#9E9E9E", label="Start") ax.fill(angles, values_start, alpha=0.10, color="gray") ax.plot(angles, values_end, "o-", linewidth=2, color="#2196F3", label="After Training") ax.fill(angles, values_end, alpha=0.15, color="#2196F3") ax.set_xticks(angles[:-1]) ax.set_xticklabels( [c.replace("_", "\n") for c in categories], fontsize=11, ) ax.set_ylim(0, 1.05) ax.set_yticks([0.25, 0.5, 0.75, 1.0]) ax.set_yticklabels(["25%", "50%", "75%", "100%"], fontsize=8) ax.set_title( "Adaptive Difficulty Progression\n(8 Axes)", fontsize=14, fontweight="bold", pad=20 ) ax.legend(loc="upper right", bbox_to_anchor=(1.25, 1.1), fontsize=12) plt.tight_layout() out = os.path.abspath(output_path) plt.savefig(out, dpi=150, bbox_inches="tight") plt.close(fig) return out # --------------------------------------------------------------------------- # Plot 5: Tool Usage Heatmap # --------------------------------------------------------------------------- def plot_tool_heatmap( usage_data: Dict[str, Dict[str, float]], output_path: str = "results/tool_heatmap.png", ) -> str: """ Heatmap of average tool calls per episode, by case type. Parameters ---------- usage_data : dict[case_type][tool_name] = avg_calls_per_episode Returns ------- str — absolute path of saved PNG """ _ensure_results() case_types = list(usage_data.keys()) # Use union of all tool names, ordered all_tools: List[str] = [] for tools in usage_data.values(): for t in tools: if t not in all_tools: all_tools.append(t) data = np.array([[usage_data[ct].get(t, 0.0) for t in all_tools] for ct in case_types]) fig, ax = plt.subplots(figsize=(12, 4)) im = ax.imshow(data, cmap="YlOrRd", aspect="auto", vmin=0) ax.set_xticks(range(len(all_tools))) ax.set_xticklabels( [t.replace("_", "\n") for t in all_tools], rotation=0, ha="center", fontsize=11, ) ax.set_yticks(range(len(case_types))) ax.set_yticklabels(case_types, fontsize=13) ax.set_title( "Tool Usage by Case Type — Emergent Diagnostic Behaviour", fontsize=13, fontweight="bold" ) for i in range(len(case_types)): for j in range(len(all_tools)): ax.text( j, i, f"{data[i, j]:.1f}", ha="center", va="center", fontsize=11, color="black" if data[i, j] < 3 else "white", ) cbar = plt.colorbar(im, ax=ax, fraction=0.03, pad=0.02) cbar.set_label("Avg calls / episode", fontsize=10) plt.tight_layout() out = os.path.abspath(output_path) plt.savefig(out, dpi=150, bbox_inches="tight") plt.close(fig) return out # --------------------------------------------------------------------------- # Plot 6: Per-Stage Reward Breakdown (Stacked Area) # --------------------------------------------------------------------------- def plot_reward_breakdown( stage_rewards: Dict[str, List[float]], output_path: str = "results/reward_breakdown.png", ) -> str: """ Stacked area chart of reward components over 24 aligner stages. Parameters ---------- stage_rewards : dict mapping component name → list of 24 floats Returns ------- str — absolute path of saved PNG """ _ensure_results() stages = list(range(1, 25)) components = list(stage_rewards.keys()) colors = [ "#2196F3", "#4CAF50", "#FF9800", "#F44336", "#9C27B0", "#00BCD4", "#795548", "#607D8B", ] fig, ax = plt.subplots(figsize=(14, 6)) data = np.array([stage_rewards[c] for c in components]) ax.stackplot(stages, data, labels=components, colors=colors[: len(components)], alpha=0.8) ax.set_xlabel("Aligner Stage", fontsize=14) ax.set_ylabel("Reward Component Value", fontsize=14) ax.set_title("Per-Stage Reward Breakdown — Class II Episode", fontsize=14, fontweight="bold") ax.set_xlim(1, 24) ax.legend(loc="upper left", fontsize=10, ncol=2) ax.grid(True, alpha=0.3, axis="y") plt.tight_layout() out = os.path.abspath(output_path) plt.savefig(out, dpi=150, bbox_inches="tight") plt.close(fig) return out # --------------------------------------------------------------------------- # Plot 7: Anti-Hacking Safeguards Diagram # --------------------------------------------------------------------------- def plot_safeguards( output_path: str = "results/safeguards.png", ) -> str: """ Visual diagram of the 6 anti-reward-hacking safeguards. Returns ------- str — absolute path of saved PNG """ _ensure_results() safeguards = [ ("Format Guard", "Garbage JSON → 0 reward\n(no SLERP fallback for bad output)"), ("Progress Guard", "<10% progress by stage 6\n→ progressive penalty"), ("Tool Budget", "Max 5 tool calls / episode\n(no infinite diagnostic spam)"), ("Timeout", "30 steps / 60s wall clock\n(episode terminates early)"), ("Balanced Sampling", "Equal Class I / II / III\n(no strategy memorisation)"), ("Generation Logging", "Sample outputs every 10 steps\n(human oversight loop)"), ] fig, axes = plt.subplots(2, 3, figsize=(14, 8)) fig.patch.set_facecolor("#f5f5f5") colors = ["#1976D2", "#388E3C", "#F57C00", "#D32F2F", "#7B1FA2", "#00796B"] for idx, (ax, (title, desc), color) in enumerate(zip(axes.flatten(), safeguards, colors)): ax.set_facecolor(color + "22") # light tint ax.add_patch( mpatches.FancyBboxPatch( (0.05, 0.05), 0.90, 0.90, boxstyle="round,pad=0.05", facecolor=color + "33", edgecolor=color, linewidth=2, transform=ax.transAxes, ) ) ax.text( 0.5, 0.72, f"[GUARD] {title}", transform=ax.transAxes, ha="center", va="center", fontsize=13, fontweight="bold", color=color, ) ax.text( 0.5, 0.38, desc, transform=ax.transAxes, ha="center", va="center", fontsize=10, color="#333333", multialignment="center", ) ax.set_xticks([]) ax.set_yticks([]) for spine in ax.spines.values(): spine.set_visible(False) fig.suptitle("OrthoRL Anti-Reward-Hacking Safeguards", fontsize=16, fontweight="bold", y=0.98) plt.tight_layout(rect=[0, 0, 1, 0.96]) out = os.path.abspath(output_path) plt.savefig(out, dpi=150, bbox_inches="tight") plt.close(fig) return out # --------------------------------------------------------------------------- # Plot 9: Before/After Dental Arch # --------------------------------------------------------------------------- def plot_before_after( initial: np.ndarray, final: np.ndarray, target: np.ndarray, output_path: str = "results/before_after.png", ) -> str: """ Side-by-side top-down dental arch view: initial vs final stage. Parameters ---------- initial : np.ndarray (28, 7) — initial tooth poses final : np.ndarray (28, 7) — final tooth poses after 24 stages target : np.ndarray (28, 7) — target (ideal) positions Returns ------- str — absolute path of saved PNG """ from server.dental_constants import TOOTH_IDS, TOOTH_TYPES _ensure_results() _RADIUS = { "central_incisor": 3.5, "lateral_incisor": 2.8, "canine": 3.0, "premolar_1": 3.8, "premolar_2": 3.8, "molar_1": 5.5, "molar_2": 5.0, } fig, (ax_l, ax_r) = plt.subplots(1, 2, figsize=(14, 7)) fig.patch.set_facecolor("#f9f9f9") for ax, poses, title in [ (ax_l, initial, "Initial Malocclusion"), (ax_r, final, "After 24 GRPO Stages"), ]: ax.set_facecolor("#f0f4f8") ax.set_aspect("equal") for i, tid in enumerate(TOOTH_IDS): ttype = TOOTH_TYPES[tid] r = _RADIUS[ttype] / 2.0 # Target (gray dotted outline) ax.add_patch( plt.Circle( (target[i, 4], target[i, 5]), r + 0.4, facecolor="none", edgecolor="#AAAAAA", linestyle=":", linewidth=1.5, alpha=0.7, zorder=1, ) ) # Current position (colored fill) color = "#f44336" if ax is ax_l else "#4CAF50" ec = "#b71c1c" if ax is ax_l else "#1b5e20" ax.add_patch( plt.Circle( (poses[i, 4], poses[i, 5]), r, facecolor=color, edgecolor=ec, linewidth=1.5, alpha=0.7, zorder=2, ) ) ax.set_xlim(-30, 30) ax.set_ylim(-25, 15) ax.set_title( title, fontsize=14, fontweight="bold", color="#c62828" if ax is ax_l else "#1b5e20" ) ax.set_xlabel("Transverse (mm)", fontsize=10) ax.set_ylabel("Sagittal (mm)", fontsize=10) ax.grid(True, alpha=0.2) # Legend on right panel legend_elements = [ mpatches.Patch( facecolor="#4CAF50", edgecolor="#1b5e20", alpha=0.7, label="Agent placement" ), mpatches.Patch( facecolor="none", edgecolor="#AAAAAA", linestyle=":", label="Target (ideal)" ), ] ax_r.legend(handles=legend_elements, loc="upper right", fontsize=9) fig.suptitle( "OrthoRL: Tooth Alignment Before vs After Training", fontsize=15, fontweight="bold", y=1.01 ) plt.tight_layout() out = os.path.abspath(output_path) plt.savefig(out, dpi=150, bbox_inches="tight") plt.close(fig) return out # --------------------------------------------------------------------------- # Demo data generators # --------------------------------------------------------------------------- def _demo_training_data(n_steps: int = 300, rng_seed: int = 42) -> Dict[str, Any]: """Generate realistic-looking synthetic training curves.""" rng = np.random.default_rng(rng_seed) steps = list(range(0, n_steps + 1, 10)) n = len(steps) def sigmoid_curve(start: float, end: float, noise: float = 0.03) -> List[float]: xs = np.linspace(-4, 2, n) curve = start + (end - start) / (1 + np.exp(-xs)) curve += rng.normal(0, noise, n) return np.clip(curve, 0.0, 1.0).tolist() return { "steps": steps, "terminal": sigmoid_curve(0.38, 0.76, noise=0.025), "occlusion": sigmoid_curve(0.45, 0.82, noise=0.020), "strategy": sigmoid_curve(0.33, 0.78, noise=0.030), "format": sigmoid_curve(0.60, 0.97, noise=0.015), "slerp_baseline": 0.87, } def _demo_case_scores() -> Tuple[Dict[str, float], Dict[str, float]]: slerp = {"Class I": 0.88, "Class II": 0.55, "Class III": 0.60} trained = {"Class I": 0.87, "Class II": 0.75, "Class III": 0.72} return slerp, trained def _demo_exam_data() -> Tuple[List[int], List[int]]: return [0, 50, 100, 150, 200, 300], [3, 4, 5, 6, 7, 8] def _demo_difficulty_params() -> Tuple[Dict[str, float], Dict[str, float]]: axes = [ "n_perturbed\nteeth", "translation\nmagnitude", "rotation\nmagnitude", "constraint\ntightness", "compliance\nrate", "scan\nnoise", "missing\nteeth", "crowding\nseverity", ] start = {a: 0.20 for a in axes} end = { axes[0]: 0.70, axes[1]: 0.65, axes[2]: 0.55, axes[3]: 0.60, axes[4]: 0.45, axes[5]: 0.40, axes[6]: 0.25, axes[7]: 0.75, } return start, end def _demo_tool_usage() -> Dict[str, Dict[str, float]]: tools = [ "diagnose\nangle_class", "measure\ncrowding", "measure\noverbite", "inspect\ntooth", "simulate\nstep", "check\ncollisions", ] return { "Class I": {t: v for t, v in zip(tools, [0.8, 0.6, 0.5, 2.1, 1.9, 0.7])}, "Class II": {t: v for t, v in zip(tools, [2.9, 1.4, 1.2, 2.4, 2.3, 1.1])}, "Class III": {t: v for t, v in zip(tools, [2.7, 1.1, 1.5, 2.2, 2.1, 1.3])}, } def _demo_stage_rewards(n_stages: int = 24) -> Dict[str, List[float]]: rng = np.random.default_rng(0) stages = np.linspace(0, 1, n_stages) def wave(base: float, amp: float, phase: float) -> List[float]: v = base + amp * np.sin(stages * np.pi + phase) + rng.normal(0, 0.01, n_stages) return np.clip(v, 0, None).tolist() return { "Progress": wave(0.25, 0.08, 0.0), "Compliance": wave(0.20, 0.05, 0.5), "Smoothness": wave(0.15, 0.04, 1.0), "Staging": wave(0.12, 0.03, 1.5), "Occlusion": wave(0.10, 0.04, 2.0), } def _demo_poses() -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Generate synthetic initial/final/target dental arch poses.""" from server.dental_environment import StepwiseDentalEnvironment env = StepwiseDentalEnvironment() obs = env.reset(task_id="task_easy", seed=7) initial = np.array(obs["current_config"], dtype=np.float64) target = np.array(obs["target_config"], dtype=np.float64) # Simulate a "final" by SLERP-ing 85% of the way to target alpha = 0.85 final = initial.copy() for i in range(len(initial)): q0, q1 = initial[i, :4], target[i, :4] q = q0 * (1 - alpha) + q1 * alpha q /= np.linalg.norm(q) + 1e-10 final[i, :4] = q final[i, 4:] = initial[i, 4:] * (1 - alpha) + target[i, 4:] * alpha return initial, final, target # --------------------------------------------------------------------------- # Main demo runner # --------------------------------------------------------------------------- ALL_PLOTS = [ "reward_curves", "case_type_reward", "exam_curve", "difficulty_radar", "tool_heatmap", "reward_breakdown", "safeguards", "before_after", "diff_view_gif", ] def run_demo(plots: Optional[List[str]] = None) -> Dict[str, str]: """ Generate all visual artifacts using synthetic demo data. Returns ------- dict mapping plot_name → saved_path """ if plots is None: plots = ALL_PLOTS saved: Dict[str, str] = {} if "reward_curves" in plots: d = _demo_training_data() path = plot_reward_curves( d["steps"], d["terminal"], d["occlusion"], d["strategy"], d["format"], slerp_baseline=d["slerp_baseline"], ) saved["reward_curves"] = path print(f" [✓] {path}") if "case_type_reward" in plots: slerp, trained = _demo_case_scores() path = plot_case_type_comparison(slerp, trained) saved["case_type_reward"] = path print(f" [✓] {path}") if "exam_curve" in plots: checkpoints, scores = _demo_exam_data() path = plot_exam_curve(checkpoints, scores) saved["exam_curve"] = path print(f" [✓] {path}") if "difficulty_radar" in plots: start, end = _demo_difficulty_params() path = plot_difficulty_radar(start, end) saved["difficulty_radar"] = path print(f" [✓] {path}") if "tool_heatmap" in plots: usage = _demo_tool_usage() path = plot_tool_heatmap(usage) saved["tool_heatmap"] = path print(f" [✓] {path}") if "reward_breakdown" in plots: stage_data = _demo_stage_rewards() path = plot_reward_breakdown(stage_data) saved["reward_breakdown"] = path print(f" [✓] {path}") if "safeguards" in plots: path = plot_safeguards() saved["safeguards"] = path print(f" [✓] {path}") if "before_after" in plots: initial, final, target = _demo_poses() path = plot_before_after(initial, final, target) saved["before_after"] = path print(f" [✓] {path}") if "diff_view_gif" in plots: # Delegate to diff_view.py demo runner from diff_view import _run_demo gif_path = os.path.join(RESULTS_DIR, "diff_view_classII.gif") _run_demo(gif_path) saved["diff_view_gif"] = os.path.abspath(gif_path) print(f" [✓] {saved['diff_view_gif']}") return saved # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- if __name__ == "__main__": parser = argparse.ArgumentParser(description="OrthoRL Demo Visualizations (spec 2.6)") parser.add_argument( "--demo", action="store_true", help="Generate all plots with synthetic demo data (no GPU required)", ) parser.add_argument( "--plots", nargs="+", default=None, choices=ALL_PLOTS, help="Subset of plots to generate (default: all)", ) parser.add_argument( "--results", default=None, help="Path to training results TSV (if available)" ) args = parser.parse_args() if args.demo or args.results is None: print("[demo_viz] Generating demo artifacts ...") saved = run_demo(plots=args.plots) print(f"\n[demo_viz] Generated {len(saved)} artifacts:") for name, path in saved.items(): size_kb = os.path.getsize(path) / 1024 print(f" {name:25s}: {path} ({size_kb:.0f} KB)") else: print("Training-mode generation from results TSV is not yet implemented.") print("Run with --demo for immediate output.")