"""
Held-out evaluation CLI — spec 1.11.

Runs a deterministic held-out evaluation over one of three tiers:

    Tier-1 : 250 Tsinghua test patients (in-distribution)
    Tier-2 : 17 Open-Full-Jaw patients (cohort shift)
    Tier-3 : 40 Bits2Bites patients (label shift, learned-occlusion only)

Currently runs the SLERP baseline as the policy. Once spec 1.1 lands, the
`--policy` flag will accept a checkpoint path and the CLI will load the
trained model.

Output:
    results/eval_{policy}_tier{tier}_{timestamp}.jsonl — one line per (patient, seed)
    results/eval_summary.csv — per-tier means, append-only

Examples:
    python eval.py --policy slerp --tier 1 --seeds 1
    python eval.py --policy slerp --tier 1 --max-patients 5
    python eval.py --policy slerp --tier 2

Self-contained: imports only the env + numpy + stdlib.
"""

from __future__ import annotations

import argparse
import csv
import json
import os
import sys
import time
import zlib
from typing import Any, Dict, List, Optional

import numpy as np


def _ensure_results_dir() -> str:
    """Create (if needed) and return the ``results`` directory next to this file."""
    here = os.path.dirname(os.path.abspath(__file__))
    out = os.path.join(here, "results")
    os.makedirs(out, exist_ok=True)
    return out


def _bootstrap_ci(
    values: List[float],
    iters: int = 1000,
    alpha: float = 0.05,
    rng: Optional[np.random.Generator] = None,
) -> Dict[str, float]:
    """Return the sample mean and a percentile-bootstrap confidence interval.

    Args:
        values: Observations to summarise.
        iters: Number of bootstrap resamples (must be >= 1).
        alpha: Two-sided significance level; the interval spans the
            ``alpha / 2`` and ``1 - alpha / 2`` quantiles of the resampled means.
        rng: Random generator to draw resampling indices from. Defaults to
            ``np.random.default_rng(0)`` so repeated calls give the same result.

    Returns:
        Dict with keys ``mean``, ``lo``, ``hi`` and ``n``. An empty input
        yields all zeros.

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    # Length-based check so array-like inputs do not hit ambiguous truthiness.
    if len(values) == 0:
        return {"mean": 0.0, "lo": 0.0, "hi": 0.0, "n": 0}
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    if rng is None:
        rng = np.random.default_rng(0)

    arr = np.asarray(values, dtype=np.float64)
    n = len(arr)
    means = np.empty(iters, dtype=np.float64)
    for k in range(iters):
        # Resample with replacement, same size as the original sample.
        idx = rng.integers(0, n, size=n)
        means[k] = arr[idx].mean()

    # np.quantile does not need pre-sorted input.
    lo = float(np.quantile(means, alpha / 2))
    hi = float(np.quantile(means, 1 - alpha / 2))
    return {"mean": float(arr.mean()), "lo": lo, "hi": hi, "n": n}


def _slerp_rollout(env, eid: str, initial: np.ndarray, target: np.ndarray) -> Dict[str, Any]:
    """Run a 24-stage SLERP rollout and return the final observation.

    For each stage, rows 0..27 of ``initial`` and ``target`` are blended:
    columns 0-3 via normalised quaternion SLERP, the remaining columns by
    linear interpolation. The resulting (28, 7) pose table is passed to
    ``env.step`` as a nested list.

    Returns:
        Whatever the last ``env.step`` call returned, or ``{}`` if that
        value is falsy.
    """
    from server.quaternion_utils import quaternion_slerp, quaternion_normalize

    final = None
    for stage in range(1, 25):
        # NOTE(review): alpha tops out at 24/25, so the last submitted stage is
        # not exactly `target` — presumably the env treats stage 25 as the
        # target itself; confirm against the environment spec.
        alpha = stage / 25.0
        poses = np.zeros((28, 7))
        for i in range(28):
            poses[i, :4] = quaternion_normalize(
                quaternion_slerp(initial[i, :4], target[i, :4], alpha)
            )
            poses[i, 4:] = (1 - alpha) * initial[i, 4:] + alpha * target[i, 4:]
        final = env.step(eid, poses.tolist())
    return final or {}


def _episode_seed(pid: Any, seed_idx: int) -> int:
    """Derive a reproducible 31-bit seed for one (patient, seed index) episode.

    Uses CRC32 rather than the builtin ``hash``: string hashing is salted per
    interpreter process, which would make the evaluation non-reproducible
    across runs.
    """
    return zlib.crc32(f"{pid}/{seed_idx}".encode("utf-8")) & 0x7FFFFFFF


def _task_id(tier: int, pid: Any) -> str:
    """Build the environment task id for a patient in the given tier."""
    if tier == 1:
        return f"tsinghua/{pid}"
    if tier == 2:
        return f"ofj/{pid}"
    return f"bits2bites/{pid}"


def evaluate(
    policy: str = "slerp",
    tier: int = 1,
    seeds: int = 1,
    max_patients: Optional[int] = None,
    out_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Run an eval pass and write per-patient JSONL + an updated summary CSV.

    Args:
        policy: Policy name; only ``"slerp"`` is supported at present.
        tier: Evaluation tier (1, 2 or 3).
        seeds: Number of episodes to run per patient.
        max_patients: If given, only the first ``max_patients`` ids of the
            tier are evaluated.
        out_dir: Directory for the output files; created if missing. Defaults
            to the ``results`` directory next to this file.

    Returns:
        The summary dict for the run (the row appended to ``eval_summary.csv``).

    Raises:
        NotImplementedError: If ``policy`` is not ``"slerp"``.
    """
    # Fail fast: reject unsupported policies before building the environment
    # or touching the filesystem.
    if policy != "slerp":
        raise NotImplementedError(
            f"policy={policy!r} not yet supported. Spec 1.1 adds checkpoint loading."
        )

    from server.dental_environment import StepwiseDentalEnvironment
    from server.eval_split import EvalRegistry

    if out_dir:
        # A caller-supplied directory may not exist yet.
        os.makedirs(out_dir, exist_ok=True)
    else:
        out_dir = _ensure_results_dir()

    env = StepwiseDentalEnvironment()
    reg = EvalRegistry()
    ids = reg.list_tier(tier)
    if max_patients is not None:
        ids = ids[:max_patients]

    timestamp = time.strftime("%Y%m%dT%H%M%S")
    jsonl_path = os.path.join(out_dir, f"eval_{policy}_tier{tier}_{timestamp}.jsonl")

    rewards: List[float] = []
    occlusion: List[float] = []

    print(f"[eval] policy={policy} tier={tier} n_patients={len(ids)} seeds={seeds}")
    print(f"[eval] writing per-patient lines to {jsonl_path}")

    with open(jsonl_path, "w", encoding="utf-8") as f:
        # enumerate gives each patient's position directly (no O(n) index scan,
        # and correct even if an id appears twice).
        for eval_idx, pid in enumerate(ids):
            for s in range(seeds):
                # Each (patient, seed) is a separate eval episode.
                if tier != 1:
                    # Tiers 2 + 3 require dataset loaders that ship with
                    # spec 1.7 (OFJ landmarks → SE(3)) and spec 3.8
                    # (Bits2Bites occlusion classifier). Skip gracefully
                    # for now and log per-patient rather than fail the run.
                    # The env is not reset here: its observation would be
                    # unused for a pending tier.
                    f.write(
                        json.dumps(
                            {
                                "tier": tier,
                                "patient_id": pid,
                                "seed": s,
                                "status": "pending_loader",
                                "note": "Tier 2/3 awaiting spec 1.7 / 3.8",
                            }
                        )
                        + "\n"
                    )
                    continue

                obs = env.reset(
                    task_id=_task_id(tier, pid),
                    mode="eval",
                    tier=tier,
                    eval_idx=eval_idx,
                    seed=_episode_seed(pid, s),
                )

                eid = obs["episode_id"]
                initial = np.array(obs["current_config"])
                target = np.array(obs["target_config"])
                final = _slerp_rollout(env, eid, initial, target)

                # Missing or None metrics are recorded as 0.0.
                rew = float(final.get("terminal_reward") or 0.0)
                bd = final.get("reward_breakdown") or {}
                occ = float(bd.get("occlusion_composite") or 0.0)
                rewards.append(rew)
                occlusion.append(occ)

                f.write(
                    json.dumps(
                        {
                            "tier": tier,
                            "patient_id": pid,
                            "seed": s,
                            "terminal_reward": rew,
                            "occlusion_composite": occ,
                            "mode": obs.get("mode"),
                            "coordinate_frame": obs.get("coordinate_frame"),
                        }
                    )
                    + "\n"
                )

    summary_terminal = _bootstrap_ci(rewards)
    summary_occ = _bootstrap_ci(occlusion)

    summary = {
        "timestamp": timestamp,
        "policy": policy,
        "tier": tier,
        "n_patients": len(ids),
        "seeds": seeds,
        "terminal_reward_mean": summary_terminal["mean"],
        "terminal_reward_lo": summary_terminal["lo"],
        "terminal_reward_hi": summary_terminal["hi"],
        "occlusion_mean": summary_occ["mean"],
        "occlusion_lo": summary_occ["lo"],
        "occlusion_hi": summary_occ["hi"],
        "jsonl": os.path.relpath(jsonl_path),
    }

    print(
        f"[eval] tier-{tier}: terminal_reward = {summary_terminal['mean']:.4f} "
        f"95% CI = [{summary_terminal['lo']:.4f}, {summary_terminal['hi']:.4f}] "
        f"(N={summary_terminal['n']})"
    )

    # Append to eval_summary.csv (append-only autolog); the header is written
    # only when the file is first created.
    csv_path = os.path.join(out_dir, "eval_summary.csv")
    new_file = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(summary.keys()))
        if new_file:
            writer.writeheader()
        writer.writerow(summary)
    print(f"[eval] appended to {csv_path}")

    return summary


def main():
    """Parse command-line arguments, run `evaluate`, and return its summary."""
    parser = argparse.ArgumentParser(description="Held-out evaluation CLI (spec 1.11)")
    parser.add_argument(
        "--policy", default="slerp", help="Policy to evaluate (slerp; future: checkpoint paths)"
    )
    parser.add_argument(
        "--tier",
        type=int,
        choices=[1, 2, 3],
        default=1,
        help="1=Tsinghua test, 2=OFJ, 3=Bits2Bites",
    )
    parser.add_argument("--seeds", type=int, default=1, help="Eval seeds per patient")
    parser.add_argument(
        "--max-patients",
        type=int,
        default=None,
        help="Cap patients per tier (for quick smoke runs)",
    )
    args = parser.parse_args()

    summary = evaluate(
        policy=args.policy,
        tier=args.tier,
        seeds=args.seeds,
        max_patients=args.max_patients,
    )
    return summary


if __name__ == "__main__":
    main()