Spaces:
Sleeping
Sleeping
| """ | |
| demo_viz.py — Spec 2.6: Demo Orchestration — Visualization Scripts | |
| Generates ALL visual artifacts required for the hackathon pitch and README. | |
| Works in two modes: | |
| - Demo mode (--demo): uses synthetic/mock data; produces all plots immediately | |
| - Training mode (--results): meant to read from actual training logs / checkpoints (not yet implemented; the CLI currently only prints a notice) | |
| Usage: | |
| # Generate all plots with synthetic demo data (no GPU required) | |
| python demo_viz.py --demo | |
| # Generate from training logs TSV (not yet implemented; prints a notice) | |
| python demo_viz.py --results research/results.tsv | |
| # Generate specific plots | |
| python demo_viz.py --demo --plots reward_curves before_after safeguards | |
| Output: | |
| results/reward_curves.png | |
| results/case_type_reward.png | |
| results/exam_curve.png | |
| results/difficulty_radar.png | |
| results/tool_heatmap.png | |
| results/reward_breakdown.png | |
| results/safeguards.png | |
| results/before_after.png | |
| results/diff_view_classII.gif (delegates to diff_view.py) | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.patches as mpatches | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| # --------------------------------------------------------------------------- | |
| # Output directory | |
| # --------------------------------------------------------------------------- | |
# Directory (relative to the current working directory) that receives every
# artifact written with a default ``output_path``.
RESULTS_DIR = "results"


def _ensure_results() -> None:
    """Create the ``RESULTS_DIR`` directory if it does not already exist."""
    # exist_ok=True keeps repeated calls (one per plot function) harmless.
    os.makedirs(RESULTS_DIR, exist_ok=True)
| # --------------------------------------------------------------------------- | |
| # Plot 1: Reward Training Curves | |
| # --------------------------------------------------------------------------- | |
def plot_reward_curves(
    steps: List[int],
    terminal: List[float],
    occlusion: List[float],
    strategy: List[float],
    fmt: List[float],
    slerp_baseline: float = 0.87,
    output_path: str = "results/reward_curves.png",
) -> str:
    """
    Plot the four reward series plus a horizontal SLERP reference line.

    Parameters
    ----------
    steps : list of int — training step indices (shared x-axis)
    terminal : list of float — terminal reward per step
    occlusion : list of float — occlusion composite score per step
    strategy : list of float — strategy match reward per step
    fmt : list of float — format compliance per step
    slerp_baseline : float — y-value of the dashed SLERP reference line
    output_path : str — where the PNG is written; missing parent
        directories are created

    Returns
    -------
    str — absolute path of saved PNG

    Notes
    -----
    The y-axis is fixed to [-0.05, 1.05], so values outside that range are
    clipped from view.
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # A caller-supplied output_path may live outside RESULTS_DIR.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    series = (
        (terminal, "#2196F3", "Terminal Reward"),
        (occlusion, "#4CAF50", "Occlusion Quality"),
        (strategy, "#FF9800", "Strategy Match"),
        (fmt, "#9C27B0", "Format Compliance"),
    )

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for values, color, label in series:
            ax.plot(steps, values, color=color, linewidth=2.5, label=label)
        ax.axhline(
            y=slerp_baseline,
            color="gray",
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"SLERP Baseline ({slerp_baseline:.2f})",
        )
        ax.set_xlabel("Training Step", fontsize=14)
        ax.set_ylabel("Reward Score", fontsize=14)
        ax.set_title("OrthoRL: GRPO Training Progress", fontsize=16, fontweight="bold")
        ax.legend(fontsize=12, loc="upper left")
        ax.grid(True, alpha=0.3)
        ax.set_ylim(-0.05, 1.05)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return out
| # --------------------------------------------------------------------------- | |
| # Plot 2: Per-Case-Type Performance | |
| # --------------------------------------------------------------------------- | |
def plot_case_type_comparison(
    slerp_scores: Dict[str, float],
    trained_scores: Dict[str, float],
    output_path: str = "results/case_type_reward.png",
) -> str:
    """
    Grouped bar chart: SLERP vs trained agent by malocclusion class.

    Parameters
    ----------
    slerp_scores : dict mapping class name → float; its key order defines
        the bar groups
    trained_scores : dict mapping class name → float; must contain every
        key of ``slerp_scores``
    output_path : str — where the PNG is written; missing parent
        directories are created

    Returns
    -------
    str — absolute path of saved PNG

    Raises
    ------
    KeyError — if ``trained_scores`` lacks a class present in ``slerp_scores``
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # A caller-supplied output_path may live outside RESULTS_DIR.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    classes = list(slerp_scores)
    # Resolve both series before a figure exists so a missing key cannot
    # leave an open figure behind.
    baseline_heights = [slerp_scores[c] for c in classes]
    trained_heights = [trained_scores[c] for c in classes]
    x = np.arange(len(classes))
    width = 0.35

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        bars_s = ax.bar(
            x - width / 2,
            baseline_heights,
            width,
            label="SLERP Baseline",
            color="#BDBDBD",
            edgecolor="white",
        )
        bars_t = ax.bar(
            x + width / 2,
            trained_heights,
            width,
            label="Trained Agent",
            color="#2196F3",
            edgecolor="white",
        )
        # Print each bar's value just above its top edge.
        for bar in [*bars_s, *bars_t]:
            height = bar.get_height()
            ax.annotate(
                f"{height:.2f}",
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 4),
                textcoords="offset points",
                ha="center",
                fontsize=10,
            )
        ax.set_xlabel("Malocclusion Type", fontsize=14)
        ax.set_ylabel("Terminal Reward", fontsize=14)
        ax.set_title(
            "Performance by Case Type: SLERP vs Trained Agent", fontsize=14, fontweight="bold"
        )
        ax.set_xticks(x)
        ax.set_xticklabels(classes, fontsize=13)
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3, axis="y")
        ax.set_ylim(0, 1.05)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return out
| # --------------------------------------------------------------------------- | |
| # Plot 3: Clinical Exam Progression | |
| # --------------------------------------------------------------------------- | |
def plot_exam_curve(
    checkpoints: List[int],
    scores: List[int],
    output_path: str = "results/exam_curve.png",
) -> str:
    """
    Line plot of exam score vs training checkpoint.

    Parameters
    ----------
    checkpoints : list of int — training steps at which exam was run
    scores : list of int — exam scores (0-10), one per checkpoint
    output_path : str — where the PNG is written; missing parent
        directories are created

    Returns
    -------
    str — absolute path of saved PNG

    Notes
    -----
    "Before"/"After" call-outs are only drawn when at least two scores are
    supplied. The dashed reference line is fixed at 2.5.
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # A caller-supplied output_path may live outside RESULTS_DIR.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(
            checkpoints,
            scores,
            "o-",
            color="#E91E63",
            linewidth=2.5,
            markersize=10,
            markerfacecolor="white",
            markeredgewidth=2.5,
        )
        ax.axhline(
            y=2.5, color="gray", linestyle="--", alpha=0.5, label="Random Baseline (2.5/10)"
        )

        # Call out the first and last checkpoints.
        if len(scores) >= 2:
            ax.annotate(
                f"Before: {scores[0]}/10",
                xy=(checkpoints[0], scores[0]),
                xytext=(30, 15),
                textcoords="offset points",
                fontsize=13,
                color="#c62828",
                fontweight="bold",
                arrowprops=dict(arrowstyle="->", color="#c62828", lw=1.5),
            )
            ax.annotate(
                f"After: {scores[-1]}/10",
                xy=(checkpoints[-1], scores[-1]),
                xytext=(-90, -30),
                textcoords="offset points",
                fontsize=13,
                color="#1b5e20",
                fontweight="bold",
                arrowprops=dict(arrowstyle="->", color="#1b5e20", lw=1.5),
            )

        ax.set_xlabel("Training Step", fontsize=14)
        ax.set_ylabel("Clinical Exam Score (/ 10)", fontsize=14)
        ax.set_title(
            "Orthodontic Knowledge Acquisition During RL Training", fontsize=14, fontweight="bold"
        )
        ax.set_ylim(0, 11)
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return out
| # --------------------------------------------------------------------------- | |
| # Plot 4: Difficulty Progression Radar Chart | |
| # --------------------------------------------------------------------------- | |
def plot_difficulty_radar(
    start_params: Dict[str, float],
    end_params: Dict[str, float],
    output_path: str = "results/difficulty_radar.png",
) -> str:
    """
    Spider/radar chart showing difficulty axis progression.

    Parameters
    ----------
    start_params : dict — axis name → normalised value [0, 1] at training
        start; its key order defines the spokes
    end_params : dict — axis name → normalised value [0, 1] at training end;
        must contain every key of ``start_params``
    output_path : str — where the PNG is written; missing parent
        directories are created

    Returns
    -------
    str — absolute path of saved PNG

    Raises
    ------
    KeyError — if ``end_params`` lacks an axis present in ``start_params``
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # A caller-supplied output_path may live outside RESULTS_DIR.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    categories = list(start_params)
    n_axes = len(categories)
    values_start = [start_params[c] for c in categories]
    values_end = [end_params[c] for c in categories]

    # One spoke per category, evenly spaced around the circle.
    angles = [n / float(n_axes) * 2 * np.pi for n in range(n_axes)]
    # Repeat the first point at the end so each outline closes on itself.
    angles += angles[:1]
    values_start = values_start + values_start[:1]
    values_end = values_end + values_end[:1]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values_start, "o-", linewidth=2, color="#9E9E9E", label="Start")
        ax.fill(angles, values_start, alpha=0.10, color="gray")
        ax.plot(angles, values_end, "o-", linewidth=2, color="#2196F3", label="After Training")
        ax.fill(angles, values_end, alpha=0.15, color="#2196F3")

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(
            [c.replace("_", "\n") for c in categories],
            fontsize=11,
        )
        ax.set_ylim(0, 1.05)
        ax.set_yticks([0.25, 0.5, 0.75, 1.0])
        ax.set_yticklabels(["25%", "50%", "75%", "100%"], fontsize=8)
        # Report the actual number of axes instead of a hard-coded 8.
        ax.set_title(
            f"Adaptive Difficulty Progression\n({n_axes} Axes)",
            fontsize=14,
            fontweight="bold",
            pad=20,
        )
        ax.legend(loc="upper right", bbox_to_anchor=(1.25, 1.1), fontsize=12)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return out
| # --------------------------------------------------------------------------- | |
| # Plot 5: Tool Usage Heatmap | |
| # --------------------------------------------------------------------------- | |
def plot_tool_heatmap(
    usage_data: Dict[str, Dict[str, float]],
    output_path: str = "results/tool_heatmap.png",
) -> str:
    """
    Heatmap of average tool calls per episode, by case type.

    Parameters
    ----------
    usage_data : dict[case_type][tool_name] = avg_calls_per_episode.
        Rows follow the outer key order; columns are the union of all tool
        names in first-seen order. A tool missing for a case type counts
        as 0.0.
    output_path : str — where the PNG is written; missing parent
        directories are created

    Returns
    -------
    str — absolute path of saved PNG
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # A caller-supplied output_path may live outside RESULTS_DIR.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    case_types = list(usage_data)
    # Ordered union of tool names; dict keys preserve first-seen order and
    # de-duplicate in one pass.
    all_tools: List[str] = list(
        dict.fromkeys(tool for tools in usage_data.values() for tool in tools)
    )
    data = np.array([[usage_data[ct].get(t, 0.0) for t in all_tools] for ct in case_types])

    fig, ax = plt.subplots(figsize=(12, 4))
    try:
        im = ax.imshow(data, cmap="YlOrRd", aspect="auto", vmin=0)
        ax.set_xticks(range(len(all_tools)))
        ax.set_xticklabels(
            [t.replace("_", "\n") for t in all_tools],
            rotation=0,
            ha="center",
            fontsize=11,
        )
        ax.set_yticks(range(len(case_types)))
        ax.set_yticklabels(case_types, fontsize=13)
        ax.set_title(
            "Tool Usage by Case Type — Emergent Diagnostic Behaviour",
            fontsize=13,
            fontweight="bold",
        )

        # Write each cell's value on top of it; switch to white text on the
        # darker cells (value >= 3).
        for (row, col), value in np.ndenumerate(data):
            ax.text(
                col,
                row,
                f"{value:.1f}",
                ha="center",
                va="center",
                fontsize=11,
                color="black" if value < 3 else "white",
            )

        cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
        cbar.set_label("Avg calls / episode", fontsize=10)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return out
| # --------------------------------------------------------------------------- | |
| # Plot 6: Per-Stage Reward Breakdown (Stacked Area) | |
| # --------------------------------------------------------------------------- | |
def plot_reward_breakdown(
    stage_rewards: Dict[str, List[float]],
    output_path: str = "results/reward_breakdown.png",
) -> str:
    """
    Stacked area chart of reward components over the aligner stages.

    Parameters
    ----------
    stage_rewards : dict mapping component name → list with one float per
        aligner stage. All lists must have the same, non-zero length; the
        stage axis runs from 1 to that length.
    output_path : str — where the PNG is written; missing parent
        directories are created

    Returns
    -------
    str — absolute path of saved PNG

    Raises
    ------
    ValueError — if ``stage_rewards`` is empty, or its lists are empty or of
        unequal length
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # A caller-supplied output_path may live outside RESULTS_DIR.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    components = list(stage_rewards)
    # dtype=float makes NumPy reject ragged input instead of silently
    # building an object array.
    data = np.array([stage_rewards[c] for c in components], dtype=float)
    if data.ndim != 2 or data.shape[1] == 0:
        raise ValueError(
            "stage_rewards must map each component to a non-empty list of "
            "per-stage values, all of the same length"
        )
    # Derive the stage axis from the data rather than assuming 24 stages.
    n_stages = data.shape[1]
    stages = list(range(1, n_stages + 1))

    colors = [
        "#2196F3",
        "#4CAF50",
        "#FF9800",
        "#F44336",
        "#9C27B0",
        "#00BCD4",
        "#795548",
        "#607D8B",
    ]

    fig, ax = plt.subplots(figsize=(14, 6))
    try:
        ax.stackplot(
            stages, data, labels=components, colors=colors[: len(components)], alpha=0.8
        )
        ax.set_xlabel("Aligner Stage", fontsize=14)
        ax.set_ylabel("Reward Component Value", fontsize=14)
        ax.set_title(
            "Per-Stage Reward Breakdown — Class II Episode", fontsize=14, fontweight="bold"
        )
        if n_stages > 1:
            # Identical lower/upper limits would be rejected with a warning.
            ax.set_xlim(1, n_stages)
        ax.legend(loc="upper left", fontsize=10, ncol=2)
        ax.grid(True, alpha=0.3, axis="y")
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return out
| # --------------------------------------------------------------------------- | |
| # Plot 7: Anti-Hacking Safeguards Diagram | |
| # --------------------------------------------------------------------------- | |
def plot_safeguards(
    output_path: str = "results/safeguards.png",
) -> str:
    """
    Visual diagram of the 6 anti-reward-hacking safeguards.

    Draws a 2x3 grid of tinted cards, each with a bold title and a short
    description. The card contents are fixed inside this function.

    Parameters
    ----------
    output_path : str — where the PNG is written; missing parent
        directories are created

    Returns
    -------
    str — absolute path of saved PNG
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # A caller-supplied output_path may live outside RESULTS_DIR.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    safeguards = [
        ("Format Guard", "Garbage JSON → 0 reward\n(no SLERP fallback for bad output)"),
        ("Progress Guard", "<10% progress by stage 6\n→ progressive penalty"),
        ("Tool Budget", "Max 5 tool calls / episode\n(no infinite diagnostic spam)"),
        ("Timeout", "30 steps / 60s wall clock\n(episode terminates early)"),
        ("Balanced Sampling", "Equal Class I / II / III\n(no strategy memorisation)"),
        ("Generation Logging", "Sample outputs every 10 steps\n(human oversight loop)"),
    ]
    colors = ["#1976D2", "#388E3C", "#F57C00", "#D32F2F", "#7B1FA2", "#00796B"]

    fig, axes = plt.subplots(2, 3, figsize=(14, 8))
    try:
        fig.patch.set_facecolor("#f5f5f5")
        for ax, (title, desc), color in zip(axes.flatten(), safeguards, colors):
            # Appending two hex digits to "#RRGGBB" sets the alpha channel.
            ax.set_facecolor(color + "22")  # light tint
            ax.add_patch(
                mpatches.FancyBboxPatch(
                    (0.05, 0.05),
                    0.90,
                    0.90,
                    boxstyle="round,pad=0.05",
                    facecolor=color + "33",
                    edgecolor=color,
                    linewidth=2,
                    transform=ax.transAxes,
                )
            )
            ax.text(
                0.5,
                0.72,
                f"[GUARD] {title}",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=13,
                fontweight="bold",
                color=color,
            )
            ax.text(
                0.5,
                0.38,
                desc,
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
                color="#333333",
                multialignment="center",
            )
            # Cards are pure decoration: hide ticks and frame.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

        fig.suptitle(
            "OrthoRL Anti-Reward-Hacking Safeguards", fontsize=16, fontweight="bold", y=0.98
        )
        # Leave the top 4% of the canvas for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return out
| # --------------------------------------------------------------------------- | |
| # Plot 9: Before/After Dental Arch | |
| # --------------------------------------------------------------------------- | |
def plot_before_after(
    initial: np.ndarray,
    final: np.ndarray,
    target: np.ndarray,
    output_path: str = "results/before_after.png",
) -> str:
    """
    Side-by-side top-down dental arch view: initial vs final stage.

    Each tooth is drawn as a filled circle at columns 4 and 5 of its pose
    row, over a dotted outline at the matching target position.

    Parameters
    ----------
    initial : np.ndarray (28, 7) — initial tooth poses
    final : np.ndarray (28, 7) — final tooth poses after 24 stages
    target : np.ndarray (28, 7) — target (ideal) positions
    output_path : str — where the PNG is written; missing parent
        directories are created

    Returns
    -------
    str — absolute path of saved PNG
    """
    # Imported lazily so the other plots work without the server package.
    from server.dental_constants import TOOTH_IDS, TOOTH_TYPES

    _ensure_results()
    out = os.path.abspath(output_path)
    # A caller-supplied output_path may live outside RESULTS_DIR.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    # Per-type size used for drawing; halved below to obtain the circle
    # radius. NOTE(review): presumably tooth widths in mm — confirm.
    _RADIUS = {
        "central_incisor": 3.5,
        "lateral_incisor": 2.8,
        "canine": 3.0,
        "premolar_1": 3.8,
        "premolar_2": 3.8,
        "molar_1": 5.5,
        "molar_2": 5.0,
    }

    fig, (ax_l, ax_r) = plt.subplots(1, 2, figsize=(14, 7))
    try:
        fig.patch.set_facecolor("#f9f9f9")

        # (axes, poses, title, fill colour, edge colour, title colour)
        panels = (
            (ax_l, initial, "Initial Malocclusion", "#f44336", "#b71c1c", "#c62828"),
            (ax_r, final, "After 24 GRPO Stages", "#4CAF50", "#1b5e20", "#1b5e20"),
        )
        for ax, poses, title, fill, edge, title_color in panels:
            ax.set_facecolor("#f0f4f8")
            ax.set_aspect("equal")
            for i, tid in enumerate(TOOTH_IDS):
                r = _RADIUS[TOOTH_TYPES[tid]] / 2.0
                # Target (gray dotted outline), slightly larger so it stays
                # visible around a perfectly placed tooth.
                ax.add_patch(
                    plt.Circle(
                        (target[i, 4], target[i, 5]),
                        r + 0.4,
                        facecolor="none",
                        edgecolor="#AAAAAA",
                        linestyle=":",
                        linewidth=1.5,
                        alpha=0.7,
                        zorder=1,
                    )
                )
                # Current position (colored fill)
                ax.add_patch(
                    plt.Circle(
                        (poses[i, 4], poses[i, 5]),
                        r,
                        facecolor=fill,
                        edgecolor=edge,
                        linewidth=1.5,
                        alpha=0.7,
                        zorder=2,
                    )
                )
            ax.set_xlim(-30, 30)
            ax.set_ylim(-25, 15)
            ax.set_title(title, fontsize=14, fontweight="bold", color=title_color)
            ax.set_xlabel("Transverse (mm)", fontsize=10)
            ax.set_ylabel("Sagittal (mm)", fontsize=10)
            ax.grid(True, alpha=0.2)

        # Legend on right panel
        legend_elements = [
            mpatches.Patch(
                facecolor="#4CAF50", edgecolor="#1b5e20", alpha=0.7, label="Agent placement"
            ),
            mpatches.Patch(
                facecolor="none", edgecolor="#AAAAAA", linestyle=":", label="Target (ideal)"
            ),
        ]
        ax_r.legend(handles=legend_elements, loc="upper right", fontsize=9)
        fig.suptitle(
            "OrthoRL: Tooth Alignment Before vs After Training",
            fontsize=15,
            fontweight="bold",
            y=1.01,
        )
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return out
| # --------------------------------------------------------------------------- | |
| # Demo data generators | |
| # --------------------------------------------------------------------------- | |
def _demo_training_data(n_steps: int = 300, rng_seed: int = 42) -> Dict[str, Any]:
    """
    Generate realistic-looking synthetic training curves.

    Parameters
    ----------
    n_steps : int — last training step; samples are taken every 10 steps
        starting at 0
    rng_seed : int — seed for the noise generator, so output is reproducible

    Returns
    -------
    dict with keys ``steps`` (list of int), ``terminal``, ``occlusion``,
    ``strategy``, ``format`` (lists of float clipped to [0, 1], one value
    per step) and ``slerp_baseline`` (float)
    """
    rng = np.random.default_rng(rng_seed)
    steps = list(range(0, n_steps + 1, 10))
    n = len(steps)
    # Every curve shares the same logistic x-range, so build it once.
    xs = np.linspace(-4, 2, n)

    def sigmoid_curve(start: float, end: float, noise: float = 0.03) -> List[float]:
        """Return a noisy logistic ramp from about ``start`` towards ``end``."""
        curve = start + (end - start) / (1 + np.exp(-xs))
        curve += rng.normal(0, noise, n)
        return np.clip(curve, 0.0, 1.0).tolist()

    # The curves draw from one shared generator, so the order of these
    # entries determines the noise each series receives.
    return {
        "steps": steps,
        "terminal": sigmoid_curve(0.38, 0.76, noise=0.025),
        "occlusion": sigmoid_curve(0.45, 0.82, noise=0.020),
        "strategy": sigmoid_curve(0.33, 0.78, noise=0.030),
        "format": sigmoid_curve(0.60, 0.97, noise=0.015),
        "slerp_baseline": 0.87,
    }
def _demo_case_scores() -> Tuple[Dict[str, float], Dict[str, float]]:
    """Return fixed demo scores per malocclusion class as ``(slerp, trained)``."""
    slerp = {"Class I": 0.88, "Class II": 0.55, "Class III": 0.60}
    trained = {"Class I": 0.87, "Class II": 0.75, "Class III": 0.72}
    return slerp, trained
def _demo_exam_data() -> Tuple[List[int], List[int]]:
    """Return fixed demo exam results as ``(checkpoints, scores)`` of equal length."""
    checkpoints = [0, 50, 100, 150, 200, 300]
    scores = [3, 4, 5, 6, 7, 8]
    return checkpoints, scores
def _demo_difficulty_params() -> Tuple[Dict[str, float], Dict[str, float]]:
    """
    Return fixed demo difficulty values as ``(start, end)``.

    Both dicts share the same eight axis names, in the same order. The names
    already contain a newline so they wrap when used as tick labels. Every
    ``start`` value is 0.20.
    """
    axes = [
        "n_perturbed\nteeth",
        "translation\nmagnitude",
        "rotation\nmagnitude",
        "constraint\ntightness",
        "compliance\nrate",
        "scan\nnoise",
        "missing\nteeth",
        "crowding\nseverity",
    ]
    end_values = [0.70, 0.65, 0.55, 0.60, 0.45, 0.40, 0.25, 0.75]
    start = dict.fromkeys(axes, 0.20)
    end = dict(zip(axes, end_values))
    return start, end
def _demo_tool_usage() -> Dict[str, Dict[str, float]]:
    """
    Return fixed demo tool-usage averages as ``dict[case_type][tool_name]``.

    Tool names already contain a newline so they wrap when used as tick
    labels.
    """
    tools = [
        "diagnose\nangle_class",
        "measure\ncrowding",
        "measure\noverbite",
        "inspect\ntooth",
        "simulate\nstep",
        "check\ncollisions",
    ]
    # One row of per-tool averages per case type, in ``tools`` order.
    averages = {
        "Class I": [0.8, 0.6, 0.5, 2.1, 1.9, 0.7],
        "Class II": [2.9, 1.4, 1.2, 2.4, 2.3, 1.1],
        "Class III": [2.7, 1.1, 1.5, 2.2, 2.1, 1.3],
    }
    return {case: dict(zip(tools, row)) for case, row in averages.items()}
def _demo_stage_rewards(n_stages: int = 24) -> Dict[str, List[float]]:
    """
    Generate synthetic per-stage reward components.

    Parameters
    ----------
    n_stages : int — number of aligner stages (values per component)

    Returns
    -------
    dict mapping component name → list of ``n_stages`` non-negative floats.
    The noise generator is seeded with 0, so output is reproducible.
    """
    rng = np.random.default_rng(0)
    stages = np.linspace(0, 1, n_stages)

    def wave(base: float, amp: float, phase: float) -> List[float]:
        """Return a noisy half-period sine around ``base``, floored at zero."""
        v = base + amp * np.sin(stages * np.pi + phase) + rng.normal(0, 0.01, n_stages)
        return np.clip(v, 0, None).tolist()

    # The components draw from one shared generator, so the order of these
    # entries determines the noise each one receives.
    return {
        "Progress": wave(0.25, 0.08, 0.0),
        "Compliance": wave(0.20, 0.05, 0.5),
        "Smoothness": wave(0.15, 0.04, 1.0),
        "Staging": wave(0.12, 0.03, 1.5),
        "Occlusion": wave(0.10, 0.04, 2.0),
    }
def _demo_poses() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate synthetic initial/final/target dental arch poses.

    Resets a ``StepwiseDentalEnvironment`` on ``task_easy`` with seed 7 and
    reads ``current_config`` and ``target_config`` from the observation. The
    "final" pose is made by moving every tooth 85% of the way from initial
    to target.

    Returns
    -------
    (initial, final, target) — float64 arrays of identical shape. Columns
    0-3 of each row are treated as a quaternion and the rest as a
    translation.
    """
    # Imported lazily so the other plots work without the server package.
    from server.dental_environment import StepwiseDentalEnvironment

    env = StepwiseDentalEnvironment()
    obs = env.reset(task_id="task_easy", seed=7)
    initial = np.array(obs["current_config"], dtype=np.float64)
    target = np.array(obs["target_config"], dtype=np.float64)

    # Blend 85% of the way to target. Rotations use a normalised linear
    # blend (a cheap stand-in for SLERP); translations blend linearly.
    alpha = 0.85
    final = initial.copy()

    q0 = initial[:, :4]
    q1 = target[:, :4]
    # q and -q encode the same rotation. Flip the target where the pair lies
    # in opposite hemispheres so the blend takes the short arc and cannot
    # collapse towards a zero-length quaternion.
    same_hemisphere = np.sum(q0 * q1, axis=1, keepdims=True) >= 0
    q1 = np.where(same_hemisphere, q1, -q1)

    blended = q0 * (1 - alpha) + q1 * alpha
    # The epsilon guards against division by zero for a degenerate blend.
    blended /= np.linalg.norm(blended, axis=1, keepdims=True) + 1e-10
    final[:, :4] = blended
    final[:, 4:] = initial[:, 4:] * (1 - alpha) + target[:, 4:] * alpha
    return initial, final, target
| # --------------------------------------------------------------------------- | |
| # Main demo runner | |
| # --------------------------------------------------------------------------- | |
# Every artifact name accepted by ``run_demo`` and by the ``--plots`` CLI
# option, in generation order.
ALL_PLOTS = [
    "reward_curves",
    "case_type_reward",
    "exam_curve",
    "difficulty_radar",
    "tool_heatmap",
    "reward_breakdown",
    "safeguards",
    "before_after",
    "diff_view_gif",
]
def run_demo(plots: Optional[List[str]] = None) -> Dict[str, str]:
    """
    Generate all visual artifacts using synthetic demo data.

    Parameters
    ----------
    plots : list of str, optional — names of the artifacts to generate
        (see ``ALL_PLOTS``). ``None`` generates everything. Unknown names
        are ignored.

    Returns
    -------
    dict mapping plot_name → saved_path. Artifacts are generated in a fixed
    order regardless of the order of ``plots``.
    """
    if plots is None:
        plots = ALL_PLOTS

    def _reward_curves() -> str:
        d = _demo_training_data()
        return plot_reward_curves(
            d["steps"],
            d["terminal"],
            d["occlusion"],
            d["strategy"],
            d["format"],
            slerp_baseline=d["slerp_baseline"],
        )

    def _case_type_reward() -> str:
        slerp, trained = _demo_case_scores()
        return plot_case_type_comparison(slerp, trained)

    def _exam_curve() -> str:
        checkpoints, scores = _demo_exam_data()
        return plot_exam_curve(checkpoints, scores)

    def _difficulty_radar() -> str:
        start, end = _demo_difficulty_params()
        return plot_difficulty_radar(start, end)

    def _tool_heatmap() -> str:
        return plot_tool_heatmap(_demo_tool_usage())

    def _reward_breakdown() -> str:
        return plot_reward_breakdown(_demo_stage_rewards())

    def _before_after() -> str:
        initial, final, target = _demo_poses()
        return plot_before_after(initial, final, target)

    def _diff_view_gif() -> str:
        # Delegate to diff_view.py demo runner
        from diff_view import _run_demo

        # No plot_* function runs on this path, so create the output
        # directory here rather than relying on the delegate to do it.
        _ensure_results()
        gif_path = os.path.join(RESULTS_DIR, "diff_view_classII.gif")
        _run_demo(gif_path)
        return os.path.abspath(gif_path)

    # Insertion order is the generation order.
    renderers = {
        "reward_curves": _reward_curves,
        "case_type_reward": _case_type_reward,
        "exam_curve": _exam_curve,
        "difficulty_radar": _difficulty_radar,
        "tool_heatmap": _tool_heatmap,
        "reward_breakdown": _reward_breakdown,
        "safeguards": plot_safeguards,
        "before_after": _before_after,
        "diff_view_gif": _diff_view_gif,
    }

    saved: Dict[str, str] = {}
    for name, render in renderers.items():
        if name not in plots:
            continue
        path = render()
        saved[name] = path
        print(f" [✓] {path}")
    return saved
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point for the demo visualizations.

    Parameters
    ----------
    argv : list of str, optional — arguments to parse; ``None`` makes
        argparse read ``sys.argv[1:]``

    Demo artifacts are generated when ``--demo`` is given or ``--results``
    is omitted. With ``--results`` alone, only a "not yet implemented"
    notice is printed.
    """
    parser = argparse.ArgumentParser(description="OrthoRL Demo Visualizations (spec 2.6)")
    parser.add_argument(
        "--demo",
        action="store_true",
        help="Generate all plots with synthetic demo data (no GPU required)",
    )
    parser.add_argument(
        "--plots",
        nargs="+",
        default=None,
        choices=ALL_PLOTS,
        help="Subset of plots to generate (default: all)",
    )
    parser.add_argument(
        "--results", default=None, help="Path to training results TSV (if available)"
    )
    args = parser.parse_args(argv)

    if args.demo or args.results is None:
        print("[demo_viz] Generating demo artifacts ...")
        saved = run_demo(plots=args.plots)
        print(f"\n[demo_viz] Generated {len(saved)} artifacts:")
        for name, path in saved.items():
            try:
                size = f"{os.path.getsize(path) / 1024:.0f} KB"
            except OSError:
                # A reported path that was never written should not abort
                # the summary of the remaining artifacts.
                size = "missing"
            print(f" {name:25s}: {path} ({size})")
    else:
        print("Training-mode generation from results TSV is not yet implemented.")
        print("Run with --demo for immediate output.")


if __name__ == "__main__":
    main()