""" diff_view.py — Spec 2.3: Diff-View GIF (Trained vs SLERP Baseline Overlay) Renders an animated top-down dental arch showing: - Green circles: trained agent tooth positions - Red circles: SLERP baseline positions - Blue arrows: deviation vectors (only when deviation > ARROW_THRESHOLD mm) Usage: # Verify without a trained model (uses SLERP baseline + synthetic perturbation) python diff_view.py --demo # With a trained model checkpoint python diff_view.py --case class_ii --model ./checkpoints/final # From two pre-saved trajectory JSON files python diff_view.py --trained trained.json --slerp slerp.json --output results/diff_view_classII.gif Output: results/diff_view_classII.gif (or --output path) """ from __future__ import annotations import argparse import io import json import math import os from typing import List, Optional, Tuple import matplotlib matplotlib.use("Agg") # headless rendering import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np from PIL import Image from server.dental_constants import N_STAGES, N_TEETH, TOOTH_IDS, TOOTH_TYPES # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- ARROW_THRESHOLD_MM = 0.05 # min deviation to draw an arrow (filter noise) FRAME_DURATION_MS = 500 # ms per frame in the GIF FIG_SIZE = (7, 7) # Tooth colours by type for arch outline _TYPE_COLOR = { "central_incisor": "#cccccc", "lateral_incisor": "#dddddd", "canine": "#bbbbbb", "premolar_1": "#aaaaaa", "premolar_2": "#aaaaaa", "molar_1": "#999999", "molar_2": "#888888", } # Per-tooth approximate width (mm) for circle radius _TYPE_RADIUS = { "central_incisor": 3.5, "lateral_incisor": 2.8, "canine": 3.0, "premolar_1": 3.8, "premolar_2": 3.8, "molar_1": 5.5, "molar_2": 5.0, } # --------------------------------------------------------------------------- # Core rendering # --------------------------------------------------------------------------- def render_diff_frame( stage: int, trained: np.ndarray, slerp: np.ndarray, ) -> plt.Figure: """ Render a single frame of the diff view. Parameters ---------- stage : int Stage index (0-based, shown as "Stage N" label). trained : np.ndarray Shape (28, 7): trained-agent tooth poses [qw, qx, qy, qz, tx, ty, tz]. slerp : np.ndarray Shape (28, 7): SLERP baseline tooth poses [qw, qx, qy, qz, tx, ty, tz]. Returns ------- matplotlib.figure.Figure Rendered figure (caller is responsible for closing it). """ assert trained.shape == (N_TEETH, 7), f"Expected (28,7), got {trained.shape}" assert slerp.shape == (N_TEETH, 7), f"Expected (28,7), got {slerp.shape}" fig, ax = plt.subplots(figsize=FIG_SIZE) ax.set_aspect("equal") ax.set_facecolor("#f8f8f8") fig.patch.set_facecolor("#f8f8f8") # Tooth translations: tx=X (left-right), ty=Y (anterior-posterior) for i, tid in enumerate(TOOTH_IDS): ttype = TOOTH_TYPES[tid] r = _TYPE_RADIUS[ttype] / 2.0 # convert width → radius tx_t, ty_t = trained[i, 4], trained[i, 5] tx_s, ty_s = slerp[i, 4], slerp[i, 5] # Arch outline circle (grey, behind everything) ax.add_patch( plt.Circle( (tx_s, ty_s), r + 0.3, color=_TYPE_COLOR[ttype], alpha=0.3, zorder=1, ) ) # SLERP baseline (red, semi-transparent) ax.add_patch( plt.Circle( (tx_s, ty_s), r, facecolor="none", edgecolor="#cc3333", linewidth=1.5, alpha=0.6, zorder=2, ) ) # Trained agent (green, semi-transparent fill) ax.add_patch( plt.Circle( (tx_t, ty_t), r, facecolor="#33cc33", edgecolor="#007700", linewidth=1.5, alpha=0.4, zorder=3, ) ) # Deviation arrow (blue) — only if deviation > threshold dx = tx_t - tx_s dy = ty_t - ty_s deviation_mm = math.sqrt(dx * dx + dy * dy) if deviation_mm > ARROW_THRESHOLD_MM: scale = min(deviation_mm * 2.0, 4.0) # cap arrow length for readability ax.annotate( "", xy=(tx_s + dx * scale / deviation_mm, ty_s + dy * scale / deviation_mm), xytext=(tx_s, ty_s), arrowprops=dict( arrowstyle="-|>", color="#0066cc", lw=max(0.8, deviation_mm * 3.0), mutation_scale=8, ), zorder=4, ) # Labels and legend ax.set_title( f"OrthoRL Diff View — Stage {stage + 1}/{N_STAGES}", fontsize=12, fontweight="bold", ) ax.set_xlabel("Transverse (mm)", fontsize=9) ax.set_ylabel("Sagittal (mm)", fontsize=9) legend_handles = [ mpatches.Patch(facecolor="#33cc33", edgecolor="#007700", alpha=0.7, label="Trained agent"), mpatches.Patch(facecolor="none", edgecolor="#cc3333", label="SLERP baseline"), mpatches.Patch( facecolor="#0066cc", edgecolor="#0066cc", label=f"Deviation >{ARROW_THRESHOLD_MM:.2f}mm" ), ] ax.legend(handles=legend_handles, loc="upper right", fontsize=8) ax.set_xlim(-30, 30) ax.set_ylim(-25, 15) ax.grid(True, alpha=0.2) plt.tight_layout() return fig def render_diff_gif( trained_traj: List[np.ndarray], slerp_traj: List[np.ndarray], output_path: str, n_frames: int = N_STAGES, ) -> str: """ Render n_frames and save as an animated GIF. Parameters ---------- trained_traj : list of (28, 7) arrays Per-stage tooth poses for the trained agent (length >= n_frames). slerp_traj : list of (28, 7) arrays Per-stage tooth poses for SLERP baseline (length >= n_frames). output_path : str Where to save the GIF. n_frames : int Number of frames to render (default: N_STAGES = 24). Returns ------- str Absolute path to the saved GIF. """ os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) pil_frames: List[Image.Image] = [] for stage in range(n_frames): fig = render_diff_frame(stage, trained_traj[stage], slerp_traj[stage]) buf = io.BytesIO() fig.savefig(buf, format="png", dpi=80, bbox_inches="tight") plt.close(fig) buf.seek(0) pil_frames.append(Image.open(buf).copy()) buf.close() # Save as GIF pil_frames[0].save( output_path, save_all=True, append_images=pil_frames[1:], duration=FRAME_DURATION_MS, loop=0, # loop forever optimize=True, ) return os.path.abspath(output_path) # --------------------------------------------------------------------------- # Trajectory generation helpers # --------------------------------------------------------------------------- def slerp_trajectory( initial: np.ndarray, target: np.ndarray, n_stages: int = N_STAGES, ) -> List[np.ndarray]: """Generate a linear-SLERP trajectory from initial to target (n_stages frames).""" traj: List[np.ndarray] = [] for s in range(n_stages): alpha = (s + 1) / n_stages frame = np.empty_like(initial) for i in range(N_TEETH): q0, q1 = initial[i, :4], target[i, :4] t0, t1 = initial[i, 4:], target[i, 4:] # Normalised LERP quaternion q = q0 * (1 - alpha) + q1 * alpha q /= np.linalg.norm(q) + 1e-10 t = t0 * (1 - alpha) + t1 * alpha frame[i, :4] = q frame[i, 4:] = t traj.append(frame) return traj def perturbed_trajectory( slerp_traj: List[np.ndarray], rng: np.random.Generator, perturbation_mm: float = 0.3, ) -> List[np.ndarray]: """ Simulate a 'trained' trajectory by adding stage-progressive perturbations. Used for demo mode (no actual trained model). Perturbation increases linearly over stages, mimicking strategy-specific front-loading behaviour. """ traj: List[np.ndarray] = [] n = len(slerp_traj) for s, frame in enumerate(slerp_traj): f = frame.copy() # Front-load perturbation: anterior teeth move more in early stages alpha_stage = (s + 1) / n for i in range(N_TEETH): tid = TOOTH_IDS[i] ttype = TOOTH_TYPES[tid] # Incisors get larger perturbation in early stages (simulate anterior_first) if "incisor" in ttype or "canine" in ttype: scale = perturbation_mm * alpha_stage else: scale = perturbation_mm * 0.3 * alpha_stage f[i, 4] += rng.uniform(-scale, scale) # tx f[i, 5] += rng.uniform(-scale * 0.5, scale * 0.5) # ty traj.append(f) return traj # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def _run_demo(output_path: str) -> None: """Demo mode: generate SLERP + perturbed trajectories, save GIF.""" from server.dental_environment import StepwiseDentalEnvironment print("[diff_view] Demo mode — generating synthetic trajectories ...") env = StepwiseDentalEnvironment() obs = env.reset(task_id="task_easy", seed=42) initial = np.array(obs["current_config"], dtype=np.float64) target = np.array(obs["target_config"], dtype=np.float64) slerp_traj = slerp_trajectory(initial, target) rng = np.random.default_rng(seed=42) trained_traj = perturbed_trajectory(slerp_traj, rng, perturbation_mm=0.4) saved = render_diff_gif(trained_traj, slerp_traj, output_path) size_mb = os.path.getsize(saved) / 1024 / 1024 print(f"[diff_view] Saved: {saved} ({size_mb:.1f} MB, {N_STAGES} frames)") def _load_trajectory(path: str) -> List[np.ndarray]: with open(path) as fh: raw = json.load(fh) return [np.array(frame, dtype=np.float64) for frame in raw] if __name__ == "__main__": parser = argparse.ArgumentParser(description="OrthoRL Diff-View GIF (spec 2.3)") parser.add_argument( "--demo", action="store_true", help="Demo mode: use SLERP + synthetic perturbation (no trained model)", ) parser.add_argument( "--case", default="class_ii", choices=["class_i", "class_ii", "class_iii"], help="Malocclusion type to visualise", ) parser.add_argument("--model", default=None, help="Path to trained GRPO checkpoint") parser.add_argument( "--trained", default=None, help="Path to pre-saved trained trajectory JSON (28×7 list per stage)", ) parser.add_argument("--slerp", default=None, help="Path to pre-saved SLERP trajectory JSON") parser.add_argument("--output", default="results/diff_view_classII.gif", help="Output GIF path") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() if args.demo or (args.trained is None and args.model is None): _run_demo(args.output) elif args.trained and args.slerp: trained_traj = _load_trajectory(args.trained) slerp_traj = _load_trajectory(args.slerp) saved = render_diff_gif(trained_traj, slerp_traj, args.output) print(f"[diff_view] Saved: {saved}") else: print("Use --demo for a quick test, or provide --trained and --slerp trajectory files.") print("Model-based generation (--model) requires a trained checkpoint.")