orthorl / demo_viz.py
sri-manikanta's picture
Initial deploy: spec 1.5
cc2303a verified
Raw
History Blame Contribute Delete
27.1 kB
"""
demo_viz.py — Spec 2.6: Demo Orchestration — Visualization Scripts
Generates ALL visual artifacts required for the hackathon pitch and README.
Works in two modes:
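- Demo mode (--demo, also used when --results is not given): uses synthetic/mock
  data; produces all plots immediately
- Training mode (--results): reads from actual training logs / checkpoints
  (not yet implemented — the CLI currently only reports this and stops)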
Usage:
# Generate all plots with synthetic demo data (no GPU required)
python demo_viz.py --demo
# Generate from training logs TSV
python demo_viz.py --results research/results.tsv
# Generate specific plots
python demo_viz.py --demo --plots reward_curves before_after safeguards
Output:
results/reward_curves.png
results/case_type_reward.png
results/exam_curve.png
results/difficulty_radar.png
results/tool_heatmap.png
results/reward_breakdown.png
results/safeguards.png
results/before_after.png
results/diff_view_classII.gif (delegates to diff_view.py)
"""
from __future__ import annotations
import argparse
import os
from typing import Any, Dict, List, Optional, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
# ---------------------------------------------------------------------------
# Output directory
# ---------------------------------------------------------------------------
RESULTS_DIR = "results"
def _ensure_results() -> None:
os.makedirs(RESULTS_DIR, exist_ok=True)
# ---------------------------------------------------------------------------
# Plot 1: Reward Training Curves
# ---------------------------------------------------------------------------
def plot_reward_curves(
    steps: List[int],
    terminal: List[float],
    occlusion: List[float],
    strategy: List[float],
    fmt: List[float],
    slerp_baseline: float = 0.87,
    output_path: str = "results/reward_curves.png",
) -> str:
    """
    Plot reward training curves: four reward series plus a SLERP reference line.

    Parameters
    ----------
    steps : list of int — training step indices (shared x axis)
    terminal : list of float — terminal reward per step (in [0,1] after normalisation)
    occlusion : list of float — occlusion composite score per step
    strategy : list of float — strategy match reward per step
    fmt : list of float — format compliance per step
    slerp_baseline : float — y value of the dashed, constant SLERP reference line
    output_path : str — destination PNG; its parent directory is created if missing

    Returns
    -------
    str — absolute path of saved PNG
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # ``_ensure_results`` only creates the default directory; honour custom paths too.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    # (values, colour, legend label) for each training curve, drawn in this order.
    curves = (
        (terminal, "#2196F3", "Terminal Reward"),
        (occlusion, "#4CAF50", "Occlusion Quality"),
        (strategy, "#FF9800", "Strategy Match"),
        (fmt, "#9C27B0", "Format Compliance"),
    )
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for values, color, label in curves:
            ax.plot(steps, values, color=color, linewidth=2.5, label=label)
        ax.axhline(
            y=slerp_baseline,
            color="gray",
            linestyle="--",
            linewidth=1.5,
            alpha=0.7,
            label=f"SLERP Baseline ({slerp_baseline:.2f})",
        )
        ax.set_xlabel("Training Step", fontsize=14)
        ax.set_ylabel("Reward Score", fontsize=14)
        ax.set_title("OrthoRL: GRPO Training Progress", fontsize=16, fontweight="bold")
        ax.legend(fontsize=12, loc="upper left")
        ax.grid(True, alpha=0.3)
        ax.set_ylim(-0.05, 1.05)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises, so pyplot does
        # not accumulate open figures across calls.
        plt.close(fig)
    return out
# ---------------------------------------------------------------------------
# Plot 2: Per-Case-Type Performance
# ---------------------------------------------------------------------------
def plot_case_type_comparison(
    slerp_scores: Dict[str, float],
    trained_scores: Dict[str, float],
    output_path: str = "results/case_type_reward.png",
) -> str:
    """
    Grouped bar chart: SLERP vs trained agent by malocclusion class.

    Classes are drawn in the iteration order of ``slerp_scores``; each of them
    must also appear in ``trained_scores``. Every bar is annotated with its
    value to two decimals.

    Parameters
    ----------
    slerp_scores : dict mapping class name → float (SLERP baseline)
    trained_scores : dict mapping class name → float (trained agent)
    output_path : str — destination PNG; its parent directory is created if missing

    Returns
    -------
    str — absolute path of saved PNG

    Raises
    ------
    KeyError — if ``trained_scores`` lacks a class present in ``slerp_scores``
    """
    classes = list(slerp_scores.keys())
    # Validate before creating the figure so a bad input cannot leak it.
    missing = [c for c in classes if c not in trained_scores]
    if missing:
        raise KeyError(f"trained_scores is missing class(es): {missing}")

    _ensure_results()
    out = os.path.abspath(output_path)
    # ``_ensure_results`` only creates the default directory; honour custom paths too.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    x = np.arange(len(classes))
    width = 0.35
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        bars_s = ax.bar(
            x - width / 2,
            [slerp_scores[c] for c in classes],
            width,
            label="SLERP Baseline",
            color="#BDBDBD",
            edgecolor="white",
        )
        bars_t = ax.bar(
            x + width / 2,
            [trained_scores[c] for c in classes],
            width,
            label="Trained Agent",
            color="#2196F3",
            edgecolor="white",
        )
        # Write each bar's value just above its top edge.
        for bar in (*bars_s, *bars_t):
            ax.annotate(
                f"{bar.get_height():.2f}",
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                xytext=(0, 4),
                textcoords="offset points",
                ha="center",
                fontsize=10,
            )
        ax.set_xlabel("Malocclusion Type", fontsize=14)
        ax.set_ylabel("Terminal Reward", fontsize=14)
        ax.set_title(
            "Performance by Case Type: SLERP vs Trained Agent", fontsize=14, fontweight="bold"
        )
        ax.set_xticks(x)
        ax.set_xticklabels(classes, fontsize=13)
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3, axis="y")
        ax.set_ylim(0, 1.05)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
    return out
# ---------------------------------------------------------------------------
# Plot 3: Clinical Exam Progression
# ---------------------------------------------------------------------------
def plot_exam_curve(
    checkpoints: List[int],
    scores: List[int],
    output_path: str = "results/exam_curve.png",
) -> str:
    """
    Line plot of exam score vs training checkpoint.

    A dashed reference line is drawn at 2.5/10 (the only legend entry). When at
    least two points are given, the first and last scores are called out with
    "Before"/"After" annotations.

    Parameters
    ----------
    checkpoints : list of int — training steps at which exam was run
    scores : list of int — exam scores (0-10)
    output_path : str — destination PNG; its parent directory is created if missing

    Returns
    -------
    str — absolute path of saved PNG
    """
    _ensure_results()
    out = os.path.abspath(output_path)
    # ``_ensure_results`` only creates the default directory; honour custom paths too.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(
            checkpoints,
            scores,
            "o-",
            color="#E91E63",
            linewidth=2.5,
            markersize=10,
            markerfacecolor="white",
            markeredgewidth=2.5,
        )
        ax.axhline(
            y=2.5, color="gray", linestyle="--", alpha=0.5, label="Random Baseline (2.5/10)"
        )
        # Annotate before/after only when there is a distinct first and last point.
        if len(scores) >= 2:
            ax.annotate(
                f"Before: {scores[0]}/10",
                xy=(checkpoints[0], scores[0]),
                xytext=(30, 15),
                textcoords="offset points",
                fontsize=13,
                color="#c62828",
                fontweight="bold",
                arrowprops=dict(arrowstyle="->", color="#c62828", lw=1.5),
            )
            ax.annotate(
                f"After: {scores[-1]}/10",
                xy=(checkpoints[-1], scores[-1]),
                xytext=(-90, -30),
                textcoords="offset points",
                fontsize=13,
                color="#1b5e20",
                fontweight="bold",
                arrowprops=dict(arrowstyle="->", color="#1b5e20", lw=1.5),
            )
        ax.set_xlabel("Training Step", fontsize=14)
        ax.set_ylabel("Clinical Exam Score (/ 10)", fontsize=14)
        ax.set_title(
            "Orthodontic Knowledge Acquisition During RL Training", fontsize=14, fontweight="bold"
        )
        ax.set_ylim(0, 11)
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
    return out
# ---------------------------------------------------------------------------
# Plot 4: Difficulty Progression Radar Chart
# ---------------------------------------------------------------------------
def plot_difficulty_radar(
    start_params: Dict[str, float],
    end_params: Dict[str, float],
    output_path: str = "results/difficulty_radar.png",
) -> str:
    """
    Spider/radar chart showing difficulty axis progression.

    Axes are laid out evenly around the circle in the key order of
    ``start_params``; ``end_params`` must provide a value for every one of those
    keys. Underscores in axis names are rendered as line breaks. The title
    reports the actual number of axes.

    Parameters
    ----------
    start_params : dict — axis name → normalised value [0, 1] at training start
    end_params : dict — axis name → normalised value [0, 1] at training end
    output_path : str — destination PNG; its parent directory is created if missing

    Returns
    -------
    str — absolute path of saved PNG

    Raises
    ------
    KeyError — if ``end_params`` lacks an axis present in ``start_params``
    """
    categories = list(start_params.keys())
    n_axes = len(categories)
    # Look values up before creating the figure so a missing key cannot leak it.
    values_start = [start_params[c] for c in categories]
    values_end = [end_params[c] for c in categories]

    _ensure_results()
    out = os.path.abspath(output_path)
    # ``_ensure_results`` only creates the default directory; honour custom paths too.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    angles = [k / float(n_axes) * 2 * np.pi for k in range(n_axes)]
    # Repeat the first vertex at the end so each polygon closes.
    angles += angles[:1]
    values_start += values_start[:1]
    values_end += values_end[:1]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    try:
        ax.plot(angles, values_start, "o-", linewidth=2, color="#9E9E9E", label="Start")
        ax.fill(angles, values_start, alpha=0.10, color="gray")
        ax.plot(angles, values_end, "o-", linewidth=2, color="#2196F3", label="After Training")
        ax.fill(angles, values_end, alpha=0.15, color="#2196F3")
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(
            [c.replace("_", "\n") for c in categories],
            fontsize=11,
        )
        ax.set_ylim(0, 1.05)
        ax.set_yticks([0.25, 0.5, 0.75, 1.0])
        ax.set_yticklabels(["25%", "50%", "75%", "100%"], fontsize=8)
        # Derive the axis count from the data instead of hard-coding "8".
        ax.set_title(
            f"Adaptive Difficulty Progression\n({n_axes} Axes)",
            fontsize=14,
            fontweight="bold",
            pad=20,
        )
        ax.legend(loc="upper right", bbox_to_anchor=(1.25, 1.1), fontsize=12)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
    return out
# ---------------------------------------------------------------------------
# Plot 5: Tool Usage Heatmap
# ---------------------------------------------------------------------------
def plot_tool_heatmap(
    usage_data: Dict[str, Dict[str, float]],
    output_path: str = "results/tool_heatmap.png",
) -> str:
    """
    Heatmap of average tool calls per episode, by case type.

    Rows follow the key order of ``usage_data``. Columns are the union of all
    tool names in first-seen order, and a (case, tool) pair that is absent is
    shown as 0. Each cell is labelled with its value to one decimal.

    Parameters
    ----------
    usage_data : dict[case_type][tool_name] = avg_calls_per_episode
    output_path : str — destination PNG; its parent directory is created if missing

    Returns
    -------
    str — absolute path of saved PNG

    Raises
    ------
    ValueError — if ``usage_data`` has no case types or no tool names at all
    """
    case_types = list(usage_data.keys())
    # dict.fromkeys de-duplicates while preserving first-seen order.
    all_tools = list(dict.fromkeys(tool for tools in usage_data.values() for tool in tools))
    if not case_types or not all_tools:
        raise ValueError("usage_data must contain at least one case type and one tool")
    data = np.array(
        [[usage_data[ct].get(t, 0.0) for t in all_tools] for ct in case_types],
        dtype=float,
    )

    _ensure_results()
    out = os.path.abspath(output_path)
    # ``_ensure_results`` only creates the default directory; honour custom paths too.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    # The colour scale spans [0, data.max()] (vmin=0, vmax autoscaled). Switch
    # labels to white on the darker upper part of that scale, instead of using a
    # fixed call count that only suits one data range.
    vmax = float(data.max())
    white_above = 0.6 * vmax

    fig, ax = plt.subplots(figsize=(12, 4))
    try:
        im = ax.imshow(data, cmap="YlOrRd", aspect="auto", vmin=0)
        ax.set_xticks(range(len(all_tools)))
        ax.set_xticklabels(
            [t.replace("_", "\n") for t in all_tools],
            rotation=0,
            ha="center",
            fontsize=11,
        )
        ax.set_yticks(range(len(case_types)))
        ax.set_yticklabels(case_types, fontsize=13)
        ax.set_title(
            "Tool Usage by Case Type — Emergent Diagnostic Behaviour",
            fontsize=13,
            fontweight="bold",
        )
        for i, j in np.ndindex(data.shape):
            value = data[i, j]
            ax.text(
                j,
                i,
                f"{value:.1f}",
                ha="center",
                va="center",
                fontsize=11,
                color="white" if vmax > 0 and value > white_above else "black",
            )
        cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
        cbar.set_label("Avg calls / episode", fontsize=10)
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
    return out
# ---------------------------------------------------------------------------
# Plot 6: Per-Stage Reward Breakdown (Stacked Area)
# ---------------------------------------------------------------------------
def plot_reward_breakdown(
    stage_rewards: Dict[str, List[float]],
    output_path: str = "results/reward_breakdown.png",
) -> str:
    """
    Stacked area chart of reward components over the aligner stages.

    The number of stages is taken from the data (24 in the demo), so every
    component must provide the same number of per-stage values. Stages are
    numbered from 1 on the x axis. Components are stacked in key order and
    coloured from a fixed 8-colour palette.

    Parameters
    ----------
    stage_rewards : dict mapping component name → list of per-stage floats
    output_path : str — destination PNG; its parent directory is created if missing

    Returns
    -------
    str — absolute path of saved PNG

    Raises
    ------
    ValueError — if there are no components, or the per-stage lists are empty
        or of unequal length
    """
    components = list(stage_rewards.keys())
    if not components:
        raise ValueError("stage_rewards must contain at least one component")
    # Ragged lists make numpy itself raise ValueError here.
    data = np.array([stage_rewards[c] for c in components], dtype=float)
    if data.ndim != 2 or data.shape[1] == 0:
        raise ValueError("each reward component must be a non-empty list of per-stage values")
    n_stages = data.shape[1]
    stages = list(range(1, n_stages + 1))

    _ensure_results()
    out = os.path.abspath(output_path)
    # ``_ensure_results`` only creates the default directory; honour custom paths too.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    palette = [
        "#2196F3",
        "#4CAF50",
        "#FF9800",
        "#F44336",
        "#9C27B0",
        "#00BCD4",
        "#795548",
        "#607D8B",
    ]
    fig, ax = plt.subplots(figsize=(14, 6))
    try:
        ax.stackplot(
            stages, data, labels=components, colors=palette[: len(components)], alpha=0.8
        )
        ax.set_xlabel("Aligner Stage", fontsize=14)
        ax.set_ylabel("Reward Component Value", fontsize=14)
        ax.set_title(
            "Per-Stage Reward Breakdown — Class II Episode", fontsize=14, fontweight="bold"
        )
        # Identical left/right limits would only trigger a matplotlib warning.
        if n_stages > 1:
            ax.set_xlim(1, n_stages)
        ax.legend(loc="upper left", fontsize=10, ncol=2)
        ax.grid(True, alpha=0.3, axis="y")
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
    return out
# ---------------------------------------------------------------------------
# Plot 7: Anti-Hacking Safeguards Diagram
# ---------------------------------------------------------------------------
def plot_safeguards(
    output_path: str = "results/safeguards.png",
) -> str:
    """
    Visual diagram of the 6 anti-reward-hacking safeguards.

    Draws a 2 x 3 grid of tinted, rounded cards, one per safeguard, each
    showing the safeguard's name and a short description. The card texts are
    fixed in this function.

    Parameters
    ----------
    output_path : str — destination PNG; its parent directory is created if missing

    Returns
    -------
    str — absolute path of saved PNG
    """
    safeguards = [
        ("Format Guard", "Garbage JSON → 0 reward\n(no SLERP fallback for bad output)"),
        ("Progress Guard", "<10% progress by stage 6\n→ progressive penalty"),
        ("Tool Budget", "Max 5 tool calls / episode\n(no infinite diagnostic spam)"),
        ("Timeout", "30 steps / 60s wall clock\n(episode terminates early)"),
        ("Balanced Sampling", "Equal Class I / II / III\n(no strategy memorisation)"),
        ("Generation Logging", "Sample outputs every 10 steps\n(human oversight loop)"),
    ]
    colors = ["#1976D2", "#388E3C", "#F57C00", "#D32F2F", "#7B1FA2", "#00796B"]

    _ensure_results()
    out = os.path.abspath(output_path)
    # ``_ensure_results`` only creates the default directory; honour custom paths too.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    fig, axes = plt.subplots(2, 3, figsize=(14, 8))
    try:
        fig.patch.set_facecolor("#f5f5f5")
        for ax, (title, desc), color in zip(axes.flat, safeguards, colors):
            # Appending a hex alpha byte to the colour gives a translucent tint.
            ax.set_facecolor(color + "22")
            ax.add_patch(
                mpatches.FancyBboxPatch(
                    (0.05, 0.05),
                    0.90,
                    0.90,
                    boxstyle="round,pad=0.05",
                    facecolor=color + "33",
                    edgecolor=color,
                    linewidth=2,
                    transform=ax.transAxes,
                )
            )
            ax.text(
                0.5,
                0.72,
                f"[GUARD] {title}",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=13,
                fontweight="bold",
                color=color,
            )
            ax.text(
                0.5,
                0.38,
                desc,
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
                color="#333333",
                multialignment="center",
            )
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        fig.suptitle(
            "OrthoRL Anti-Reward-Hacking Safeguards", fontsize=16, fontweight="bold", y=0.98
        )
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
    return out
# ---------------------------------------------------------------------------
# Plot 9: Before/After Dental Arch
# ---------------------------------------------------------------------------
def plot_before_after(
    initial: np.ndarray,
    final: np.ndarray,
    target: np.ndarray,
    output_path: str = "results/before_after.png",
) -> str:
    """
    Side-by-side top-down dental arch view: initial vs final stage.

    Each tooth is drawn as a filled circle centred at columns 4 and 5 of its
    pose row, sized by tooth type. The target position is overlaid as a dotted
    gray outline in both panels. Rows are matched to teeth in ``TOOTH_IDS``
    order. The left panel (initial) is drawn in red and the right panel
    (final) in green, with the legend on the right panel.

    Parameters
    ----------
    initial : np.ndarray (28, 7) — initial tooth poses
    final : np.ndarray (28, 7) — final tooth poses after 24 stages
    target : np.ndarray (28, 7) — target (ideal) positions
    output_path : str — destination PNG; its parent directory is created if missing

    Returns
    -------
    str — absolute path of saved PNG
    """
    from server.dental_constants import TOOTH_IDS, TOOTH_TYPES

    _ensure_results()
    out = os.path.abspath(output_path)
    # ``_ensure_results`` only creates the default directory; honour custom paths too.
    os.makedirs(os.path.dirname(out), exist_ok=True)

    # Per-type size, halved below to get the drawn radius
    # (presumably a crown width in mm — TODO confirm against dental_constants).
    size_by_type = {
        "central_incisor": 3.5,
        "lateral_incisor": 2.8,
        "canine": 3.0,
        "premolar_1": 3.8,
        "premolar_2": 3.8,
        "molar_1": 5.5,
        "molar_2": 5.0,
    }
    # (poses, title, fill colour, edge colour, title colour) per panel; hoisted
    # out of the per-tooth loop because they only depend on the panel.
    panels = (
        (initial, "Initial Malocclusion", "#f44336", "#b71c1c", "#c62828"),
        (final, "After 24 GRPO Stages", "#4CAF50", "#1b5e20", "#1b5e20"),
    )

    fig, (ax_l, ax_r) = plt.subplots(1, 2, figsize=(14, 7))
    try:
        fig.patch.set_facecolor("#f9f9f9")
        for ax, (poses, title, fill, edge, title_color) in zip((ax_l, ax_r), panels):
            ax.set_facecolor("#f0f4f8")
            ax.set_aspect("equal")
            for i, tid in enumerate(TOOTH_IDS):
                r = size_by_type[TOOTH_TYPES[tid]] / 2.0
                # Target (gray dotted outline), drawn underneath.
                ax.add_patch(
                    plt.Circle(
                        (target[i, 4], target[i, 5]),
                        r + 0.4,
                        facecolor="none",
                        edgecolor="#AAAAAA",
                        linestyle=":",
                        linewidth=1.5,
                        alpha=0.7,
                        zorder=1,
                    )
                )
                # Current position (coloured fill), drawn on top.
                ax.add_patch(
                    plt.Circle(
                        (poses[i, 4], poses[i, 5]),
                        r,
                        facecolor=fill,
                        edgecolor=edge,
                        linewidth=1.5,
                        alpha=0.7,
                        zorder=2,
                    )
                )
            ax.set_xlim(-30, 30)
            ax.set_ylim(-25, 15)
            ax.set_title(title, fontsize=14, fontweight="bold", color=title_color)
            ax.set_xlabel("Transverse (mm)", fontsize=10)
            ax.set_ylabel("Sagittal (mm)", fontsize=10)
            ax.grid(True, alpha=0.2)
        legend_elements = [
            mpatches.Patch(
                facecolor="#4CAF50", edgecolor="#1b5e20", alpha=0.7, label="Agent placement"
            ),
            mpatches.Patch(
                facecolor="none", edgecolor="#AAAAAA", linestyle=":", label="Target (ideal)"
            ),
        ]
        ax_r.legend(handles=legend_elements, loc="upper right", fontsize=9)
        fig.suptitle(
            "OrthoRL: Tooth Alignment Before vs After Training",
            fontsize=15,
            fontweight="bold",
            y=1.01,
        )
        fig.tight_layout()
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
    return out
# ---------------------------------------------------------------------------
# Demo data generators
# ---------------------------------------------------------------------------
def _demo_training_data(n_steps: int = 300, rng_seed: int = 42) -> Dict[str, Any]:
"""Generate realistic-looking synthetic training curves."""
rng = np.random.default_rng(rng_seed)
steps = list(range(0, n_steps + 1, 10))
n = len(steps)
def sigmoid_curve(start: float, end: float, noise: float = 0.03) -> List[float]:
xs = np.linspace(-4, 2, n)
curve = start + (end - start) / (1 + np.exp(-xs))
curve += rng.normal(0, noise, n)
return np.clip(curve, 0.0, 1.0).tolist()
return {
"steps": steps,
"terminal": sigmoid_curve(0.38, 0.76, noise=0.025),
"occlusion": sigmoid_curve(0.45, 0.82, noise=0.020),
"strategy": sigmoid_curve(0.33, 0.78, noise=0.030),
"format": sigmoid_curve(0.60, 0.97, noise=0.015),
"slerp_baseline": 0.87,
}
def _demo_case_scores() -> Tuple[Dict[str, float], Dict[str, float]]:
slerp = {"Class I": 0.88, "Class II": 0.55, "Class III": 0.60}
trained = {"Class I": 0.87, "Class II": 0.75, "Class III": 0.72}
return slerp, trained
def _demo_exam_data() -> Tuple[List[int], List[int]]:
return [0, 50, 100, 150, 200, 300], [3, 4, 5, 6, 7, 8]
def _demo_difficulty_params() -> Tuple[Dict[str, float], Dict[str, float]]:
axes = [
"n_perturbed\nteeth",
"translation\nmagnitude",
"rotation\nmagnitude",
"constraint\ntightness",
"compliance\nrate",
"scan\nnoise",
"missing\nteeth",
"crowding\nseverity",
]
start = {a: 0.20 for a in axes}
end = {
axes[0]: 0.70,
axes[1]: 0.65,
axes[2]: 0.55,
axes[3]: 0.60,
axes[4]: 0.45,
axes[5]: 0.40,
axes[6]: 0.25,
axes[7]: 0.75,
}
return start, end
def _demo_tool_usage() -> Dict[str, Dict[str, float]]:
tools = [
"diagnose\nangle_class",
"measure\ncrowding",
"measure\noverbite",
"inspect\ntooth",
"simulate\nstep",
"check\ncollisions",
]
return {
"Class I": {t: v for t, v in zip(tools, [0.8, 0.6, 0.5, 2.1, 1.9, 0.7])},
"Class II": {t: v for t, v in zip(tools, [2.9, 1.4, 1.2, 2.4, 2.3, 1.1])},
"Class III": {t: v for t, v in zip(tools, [2.7, 1.1, 1.5, 2.2, 2.1, 1.3])},
}
def _demo_stage_rewards(n_stages: int = 24) -> Dict[str, List[float]]:
rng = np.random.default_rng(0)
stages = np.linspace(0, 1, n_stages)
def wave(base: float, amp: float, phase: float) -> List[float]:
v = base + amp * np.sin(stages * np.pi + phase) + rng.normal(0, 0.01, n_stages)
return np.clip(v, 0, None).tolist()
return {
"Progress": wave(0.25, 0.08, 0.0),
"Compliance": wave(0.20, 0.05, 0.5),
"Smoothness": wave(0.15, 0.04, 1.0),
"Staging": wave(0.12, 0.03, 1.5),
"Occlusion": wave(0.10, 0.04, 2.0),
}
def _demo_poses() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate synthetic initial/final/target dental arch poses.

    Resets a ``StepwiseDentalEnvironment`` on ``task_easy`` (seed 7) to obtain
    the initial and target configurations. It then fabricates a "final"
    configuration 85% of the way from initial to target.

    Each pose row is treated as ``[q (4 values), translation (remaining values)]``.
    Rotations are blended by normalised linear interpolation (NLERP) along the
    shorter arc, and translations by plain linear interpolation.

    Returns
    -------
    (initial, final, target) : tuple of np.ndarray (float64); ``final`` has the
        same shape as ``initial``
    """
    from server.dental_environment import StepwiseDentalEnvironment

    env = StepwiseDentalEnvironment()
    obs = env.reset(task_id="task_easy", seed=7)
    initial = np.array(obs["current_config"], dtype=np.float64)
    target = np.array(obs["target_config"], dtype=np.float64)

    alpha = 0.85
    q0 = initial[:, :4]
    q1 = target[:, :4]
    # q and -q encode the same rotation. Flip the target wherever the dot product
    # is negative so the blend follows the short arc instead of the long way round.
    dots = np.sum(q0 * q1, axis=1, keepdims=True)
    q1 = np.where(dots < 0.0, -q1, q1)
    blended = q0 * (1 - alpha) + q1 * alpha
    # Renormalise back to unit length; the epsilon guards against division by zero.
    blended /= np.linalg.norm(blended, axis=1, keepdims=True) + 1e-10

    final = initial.copy()
    final[:, :4] = blended
    final[:, 4:] = initial[:, 4:] * (1 - alpha) + target[:, 4:] * alpha
    return initial, final, target
# ---------------------------------------------------------------------------
# Main demo runner
# ---------------------------------------------------------------------------
# Every artifact name understood by ``run_demo`` and accepted by ``--plots``,
# listed in the order ``run_demo`` produces them.
ALL_PLOTS: List[str] = [
    # static charts
    "reward_curves",
    "case_type_reward",
    "exam_curve",
    "difficulty_radar",
    "tool_heatmap",
    "reward_breakdown",
    "safeguards",
    "before_after",
    # animation delegated to diff_view.py
    "diff_view_gif",
]
def run_demo(plots: Optional[List[str]] = None) -> Dict[str, str]:
    """
    Generate visual artifacts using synthetic demo data.

    Parameters
    ----------
    plots : list of str, optional
        Names from ``ALL_PLOTS`` to generate. A single name given as a plain
        string is also accepted, and ``None`` generates everything. Artifacts
        are produced in ``ALL_PLOTS`` order regardless of the order given.
        Unrecognised names are reported on stdout and skipped.

    Returns
    -------
    dict mapping plot_name → saved_path (absolute), in generation order
    """
    if plots is None:
        selected: List[str] = list(ALL_PLOTS)
    elif isinstance(plots, str):
        # A bare string would otherwise be matched by substring in the checks below.
        selected = [plots]
    else:
        selected = list(plots)
    unknown = [name for name in selected if name not in ALL_PLOTS]
    if unknown:
        print(f" [!] Ignoring unknown plot name(s): {', '.join(map(str, unknown))}")

    saved: Dict[str, str] = {}

    def _record(name: str, path: str) -> None:
        # Remember where an artifact was written and echo its path.
        saved[name] = path
        print(f" [✓] {path}")

    if "reward_curves" in selected:
        d = _demo_training_data()
        _record(
            "reward_curves",
            plot_reward_curves(
                d["steps"],
                d["terminal"],
                d["occlusion"],
                d["strategy"],
                d["format"],
                slerp_baseline=d["slerp_baseline"],
            ),
        )
    if "case_type_reward" in selected:
        slerp, trained = _demo_case_scores()
        _record("case_type_reward", plot_case_type_comparison(slerp, trained))
    if "exam_curve" in selected:
        checkpoints, scores = _demo_exam_data()
        _record("exam_curve", plot_exam_curve(checkpoints, scores))
    if "difficulty_radar" in selected:
        start, end = _demo_difficulty_params()
        _record("difficulty_radar", plot_difficulty_radar(start, end))
    if "tool_heatmap" in selected:
        _record("tool_heatmap", plot_tool_heatmap(_demo_tool_usage()))
    if "reward_breakdown" in selected:
        _record("reward_breakdown", plot_reward_breakdown(_demo_stage_rewards()))
    if "safeguards" in selected:
        _record("safeguards", plot_safeguards())
    if "before_after" in selected:
        initial, final, target = _demo_poses()
        _record("before_after", plot_before_after(initial, final, target))
    if "diff_view_gif" in selected:
        # Delegate to diff_view.py demo runner
        from diff_view import _run_demo

        # When the GIF is the only artifact requested, no plot_* helper has run
        # yet to create the output directory, so create it here.
        _ensure_results()
        gif_path = os.path.join(RESULTS_DIR, "diff_view_classII.gif")
        _run_demo(gif_path)
        _record("diff_view_gif", os.path.abspath(gif_path))
    return saved
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="OrthoRL Demo Visualizations (spec 2.6)")
    parser.add_argument(
        "--demo",
        action="store_true",
        help="Generate all plots with synthetic demo data (no GPU required)",
    )
    parser.add_argument(
        "--plots",
        nargs="+",
        default=None,
        choices=ALL_PLOTS,
        help="Subset of plots to generate (default: all)",
    )
    parser.add_argument(
        "--results", default=None, help="Path to training results TSV (if available)"
    )
    args = parser.parse_args()
    # Demo mode is also the default when no training results are supplied.
    if args.demo or args.results is None:
        print("[demo_viz] Generating demo artifacts ...")
        saved = run_demo(plots=args.plots)
        print(f"\n[demo_viz] Generated {len(saved)} artifacts:")
        for name, path in saved.items():
            # The GIF is written by diff_view.py; report rather than crash if the
            # expected file is not there.
            if os.path.exists(path):
                size_kb = os.path.getsize(path) / 1024
                print(f" {name:25s}: {path} ({size_kb:.0f} KB)")
            else:
                print(f" {name:25s}: {path} (missing)")
    else:
        # Training mode is not implemented yet: exit non-zero (message on stderr)
        # so calling scripts do not mistake this for success.
        parser.exit(
            status=1,
            message=(
                "Training-mode generation from results TSV is not yet implemented.\n"
                "Run with --demo for immediate output.\n"
            ),
        )