"""GRPO Training Pipeline — Spec 1.1 (Stage 3 of the SFT→GRPO chain).

Pipeline:
    Stage 0 (1.12)  Format SFT    → checkpoints/sft0/ + gate_passed.json
    Stage 1 (1.13)  Tool-use SFT  → checkpoints/sft1/ (optional)
    Stage 2 (1.14)  BC SFT        → checkpoints/sft2/ (optional)
    Stage 3 (THIS)  GRPO          → checkpoints/grpo/ + gate_passed.json
    Stage 4 (1.15)  Rejection FT  → checkpoints/rsft/ (optional)

Five reward functions used here (composed by spec 2.5's wide-range scaler):
    reward_terminal   terminal score scaled to [-2, +8]
    reward_occlusion  Andrews' Six Keys composite
    reward_strategy   strategy-multiplier (0.6 / 1.0 / 1.2 → [0, 1])
    reward_format     JSON / shape / unit-quat / fraction-range gate
    reward_anchorage  empirical movement-realism (stub until spec 1.9)

The env is **embedded** (no HTTP) for training-loop throughput — we never
need the FastAPI hop while the LLM is generating completions. Episode
results are cached per `(completion, seed, task_id, force_decay)` so the
reward functions share one 24-stage rollout instead of paying for it five
times.

Pre-flight: when `--from-checkpoint <dir>` is supplied, this script refuses
to start unless `<dir>/gate_passed.json` exists (spec 1.12).

Usage:
    uv run python train_grpo.py --test
    uv run python train_grpo.py --steps 100 --from-checkpoint checkpoints/sft0
    uv run python train_grpo.py --steps 300 --from-checkpoint checkpoints/sft0 \\
        --use-vllm --wandb --task-id task_medium

Refs: GRPO (DeepSeekMath arXiv:2402.03300), Unsloth Qwen2.5-3B GRPO recipe.
"""

from __future__ import annotations

import argparse
import functools
import hashlib
import json
import math
import os
import sys
import time
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple

import numpy as np

# ---------------------------------------------------------------------------
# Embedded env import — paid once at module import, not per reward call.
# ---------------------------------------------------------------------------

_HERE = os.path.dirname(os.path.abspath(__file__))
if _HERE not in sys.path:
    sys.path.insert(0, _HERE)

# NOTE: do NOT import `_STEPWISE_SESSIONS`. Reaching into private module
# state was the regression flagged in the review on 422d8f1. The env
# observation now carries `episode_id`; that's the public contract.
from server.dental_environment import StepwiseDentalEnvironment  # noqa: E402
from server.dental_constants import N_STAGES, N_TEETH, TOOTH_IDS  # noqa: E402
from server.quaternion_utils import (  # noqa: E402
    quaternion_normalize,
    quaternion_slerp,
)
from server.clinical_profiles import STRATEGIES  # noqa: E402
from server.reward_scaler import (  # noqa: E402
    detect_collision,
    detect_pdl_stress,
    scale_reward,
)

# ---------------------------------------------------------------------------
# Defaults
# ---------------------------------------------------------------------------

DEFAULT_TASK_ID = "task_easy"

# Single shared env. Each rollout is a fresh episode keyed by its own
# episode_id (see `_cached_episode`).
_ENV = StepwiseDentalEnvironment()

# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------

# NOTE(review): the `"teeth": []` placeholder below is reproduced as found;
# it may originally have carried a hint about FDI ids — confirm against the
# SFT data builder, which must use the identical prompt text.
_PROMPT_TEMPLATE = """\
You are an orthodontic treatment planner. Plan aligner stage {stage} of {n_stages}.

CURRENT PER-TOOTH STATE (mm to target, top-12 by displacement):
{tooth_lines}

CONSTRAINTS:
- max 0.25 mm translation per tooth per stage
- max 2.0 deg rotation per tooth per stage

Return ONLY a JSON object with this shape (no prose):
{{
  "strategy": "anterior_first" | "distal_first" | "retraction_first" | "intrusion_first" | "expansion_first",
  "tooth_groups": [
    {{"teeth": [], "fraction": <0..1>, "priority": "high|medium|low"}},
    ...
  ]
}}

`fraction` is the SLERP fraction toward target for that group at THIS stage."""


def _format_tooth_lines(obs: Dict[str, Any]) -> str:
    """Render the 12 most-displaced teeth as compact one-line summaries.

    Reads `current_config`, `target_config` and `per_tooth_progress` from
    `obs`; pose columns 4..6 are used as the translation components. Teeth
    missing from either config are skipped, and a missing progress entry is
    reported as 0.
    """
    cur = obs.get("current_config") or []
    tgt = obs.get("target_config") or []
    progress = obs.get("per_tooth_progress") or []

    rows = []
    for i in range(N_TEETH):
        if i >= len(cur) or i >= len(tgt):
            continue
        ci = cur[i]
        ti = tgt[i]
        dx, dy, dz = ti[4] - ci[4], ti[5] - ci[5], ti[6] - ci[6]
        dist = math.sqrt(dx * dx + dy * dy + dz * dz)
        prog = progress[i] if i < len(progress) else 0.0
        rows.append((TOOTH_IDS[i], dist, dx, dy, dz, prog))

    # Largest remaining displacement first.
    rows.sort(key=lambda r: -r[1])

    return "\n".join(
        f" FDI {fdi_id:2d}: dist={dist:.2f}mm d=({dx:+.2f},{dy:+.2f},{dz:+.2f}) prog={prog:.0%}"
        for fdi_id, dist, dx, dy, dz, prog in rows[:12]
    )


def format_obs_as_prompt(obs: Dict[str, Any], stage: int = 1) -> str:
    """Format an env observation as the agent prompt for one stage.

    Single source of truth for prompt shape across:
      - GRPO training (this file)
      - SFT data builder (scripts/build_sft_format_data.py — uses the same
        tooth-line block)
      - Eval CLI
    """
    return _PROMPT_TEMPLATE.format(
        stage=stage,
        n_stages=N_STAGES,
        tooth_lines=_format_tooth_lines(obs),
    )


def generate_prompts(
    n: int = 50,
    seed_start: int = 0,
    task_id: str = DEFAULT_TASK_ID,
    force_decay: Optional[bool] = None,
) -> List[Dict[str, Any]]:
    """Build up to `n` (prompt, seed) pairs.

    Returns a list of dicts with keys `prompt` and `seed`. The seed is
    propagated through to the reward functions via TRL's per-prompt kwargs
    (`reward_fn(completions, seed=[...])`).

    Seeds whose env reset (or prompt formatting) fails are reported on
    stderr and skipped, so the result may hold fewer than `n` entries —
    callers must handle an empty list.
    """
    out: List[Dict[str, Any]] = []
    for i in range(n):
        seed = seed_start + i
        try:
            obs = _ENV.reset(
                task_id=task_id,
                seed=seed,
                force_decay=force_decay,
                episode_id=f"prompt_gen_{seed}",
            )
            prompt = format_obs_as_prompt(obs, stage=1)
            out.append({"prompt": prompt, "seed": seed})
        except Exception as exc:  # deliberate best-effort: keep going
            print(f"[grpo] prompt-gen seed={seed} failed: {exc}", file=sys.stderr)
    return out


# ---------------------------------------------------------------------------
# Completion parser
# ---------------------------------------------------------------------------

_JSON_DECODER = json.JSONDecoder()


def _extract_json(text: str) -> Optional[Dict[str, Any]]:
    """Decode the first balanced `{...}` object found in `text`.

    Decoding starts at the first `{` and stops at that object's matching
    `}`; anything after it (prose, stray braces, a second object) is
    ignored. Returns None on any failure mode (empty input, no brace,
    unbalanced or invalid JSON).
    """
    if not text:
        return None
    start = text.find("{")
    if start < 0:
        return None
    try:
        obj, _end = _JSON_DECODER.raw_decode(text, start)
    except (ValueError, RecursionError):
        return None
    # Decoding from a `{` can only yield a dict; the check is belt-and-braces.
    return obj if isinstance(obj, dict) else None


def _tooth_fractions(plan: Any, alpha_default: float) -> Dict[int, float]:
    """Map FDI tooth id → clamped SLERP fraction requested by `plan`.

    Malformed pieces (non-dict plan, non-list `tooth_groups`, non-dict
    groups, non-numeric or non-finite fractions, non-list `teeth`,
    non-integer ids) are skipped rather than raising. A group without a
    `fraction` key uses `alpha_default`.
    """
    fractions: Dict[int, float] = {}
    if not isinstance(plan, dict):
        return fractions
    groups = plan.get("tooth_groups")
    if not isinstance(groups, list):
        return fractions

    for group in groups:
        if not isinstance(group, dict):
            continue
        try:
            frac = float(group.get("fraction", alpha_default))
        except (TypeError, ValueError, OverflowError):
            continue
        if not math.isfinite(frac):
            # NaN/Infinity are accepted by json.loads; treat them as missing.
            continue
        frac = max(0.0, min(1.0, frac))

        teeth = group.get("teeth")
        if not isinstance(teeth, list):
            continue
        for tid in teeth:
            # bool is a subclass of int; JSON `true` is not a tooth id.
            if isinstance(tid, int) and not isinstance(tid, bool):
                fractions[tid] = frac
    return fractions


def parse_completion_to_poses(
    completion: str,
    initial: List[List[float]],
    target: List[List[float]],
    stage: int,
) -> List[List[float]]:
    """Convert a high-level plan completion into a 28×7 pose list for stage.

    The plan format:
        {"strategy": str,
         "tooth_groups": [{"teeth": [...], "fraction": float}, ...]}

    For each tooth, we look up its requested SLERP fraction. Teeth absent
    from any group default to a uniform stage-based alpha. Quaternions are
    normalized to satisfy the unit-quaternion contract; translations are
    linearly interpolated with the same fraction.

    Garbage / unparseable / missing-fraction completions fall back to
    uniform SLERP — ensures `parse_completion_to_poses` NEVER raises on a
    bad completion and `reward_terminal([garbage])` returns a finite number.

    Each pose row is `[q0, q1, q2, q3, tx, ty, tz]`, one row per entry of
    `TOOTH_IDS`, in that order.
    """
    alpha_default = max(0.0, min(1.0, (stage + 1) / 25.0))
    tooth_alpha = _tooth_fractions(_extract_json(completion), alpha_default)

    poses: List[List[float]] = []
    for i, tid in enumerate(TOOTH_IDS):
        frac = tooth_alpha.get(tid, alpha_default)

        q0 = np.asarray(initial[i][:4], dtype=np.float64)
        q1 = np.asarray(target[i][:4], dtype=np.float64)
        q = quaternion_normalize(quaternion_slerp(q0, q1, frac))

        t0 = np.asarray(initial[i][4:7], dtype=np.float64)
        t1 = np.asarray(target[i][4:7], dtype=np.float64)
        t = (1.0 - frac) * t0 + frac * t1

        poses.append([
            float(q[0]), float(q[1]), float(q[2]), float(q[3]),
            float(t[0]), float(t[1]), float(t[2]),
        ])
    return poses


# ---------------------------------------------------------------------------
# Episode runner — workhorse, results cached
# ---------------------------------------------------------------------------

@functools.lru_cache(maxsize=512)
def _cached_episode(
    completion: str,
    seed: int,
    task_id: str,
    force_decay: Optional[bool],
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], Optional[str]]:
    """Run one full episode and return (final_obs, parse_quality, error).

    Cached on all inputs so the reward functions (terminal, occlusion,
    strategy, anchorage) for one (completion, seed) all share a single
    24-stage rollout. Cache size 512 ≈ 128 prompts × 4 generations.

    The third tuple element is an error message string when the episode
    aborts (reset or step raised); reward functions should treat that as
    "minimum reward". Failed rollouts are cached like successful ones.
    """
    # Content digest instead of `hash(str)`: the builtin is salted per
    # process and the old 16-bit mask made id collisions likely.
    digest = hashlib.sha256(completion.encode("utf-8", "replace")).hexdigest()[:16]
    eid = f"grpo_{seed}_{digest}"

    try:
        obs = _ENV.reset(
            task_id=task_id,
            seed=seed,
            force_decay=force_decay,
            episode_id=eid,
        )
    except Exception as exc:
        return None, None, f"reset_failed: {exc}"

    initial = obs["current_config"]
    target = obs["target_config"]

    # Parse-quality pre-pass: gives callers full visibility into what
    # failed, even when the rollout itself succeeded with SLERP fallback.
    plan = _extract_json(completion)
    parse_quality = _format_quality(plan, completion)

    final_obs: Optional[Dict[str, Any]] = None
    for stage in range(N_STAGES):
        poses = parse_completion_to_poses(completion, initial, target, stage)
        try:
            final_obs = _ENV.step(eid, poses)
        except Exception as exc:
            return None, parse_quality, f"step_failed_at_{stage}: {exc}"
        if final_obs.get("done"):
            break

    return final_obs, parse_quality, None


def _is_unit_fraction(value: Any) -> bool:
    """True when `value` converts to a float inside [0, 1] (NaN fails)."""
    try:
        return 0.0 <= float(value) <= 1.0
    except (TypeError, ValueError, OverflowError):
        return False


def _format_quality(plan: Optional[Dict[str, Any]], raw: str) -> Dict[str, Any]:
    """Compute partial-credit format scores from the parsed plan.

    Returns a dict with:
        parse:        1.0 if json.loads succeeded, else 0.0
        shape:        1.0 if tooth_groups is a non-empty list, else 0.0
        teeth_ints:   1.0 if every group's `teeth` is a list of ints
        fraction_ok:  1.0 if every group's `fraction` is in [0, 1]
        strategy_ok:  1.0 if `strategy` is one of the known strategies
        total:        average of the five

    `teeth_ints` and `fraction_ok` are only evaluated when `shape` passes.
    `raw` is accepted for interface compatibility and is not inspected.
    """
    out: Dict[str, Any] = {
        "parse": 0.0,
        "shape": 0.0,
        "teeth_ints": 0.0,
        "fraction_ok": 0.0,
        "strategy_ok": 0.0,
    }
    if not isinstance(plan, dict):
        out["total"] = 0.0
        return out
    out["parse"] = 1.0

    groups = plan.get("tooth_groups")
    if isinstance(groups, list) and groups:
        out["shape"] = 1.0

        teeth_ok = all(
            isinstance(g, dict)
            and isinstance(g.get("teeth"), list)
            and all(
                isinstance(t, int) and not isinstance(t, bool)
                for t in g["teeth"]
            )
            for g in groups
        )
        out["teeth_ints"] = 1.0 if teeth_ok else 0.0

        fraction_ok = all(
            isinstance(g, dict)
            and "fraction" in g
            and _is_unit_fraction(g["fraction"])
            for g in groups
        )
        out["fraction_ok"] = 1.0 if fraction_ok else 0.0

    # Guard with isinstance: an unhashable strategy (list/dict) would raise
    # on membership if STRATEGIES is a set or dict.
    strategy = plan.get("strategy")
    if isinstance(strategy, str) and strategy in STRATEGIES:
        out["strategy_ok"] = 1.0

    out["total"] = (
        out["parse"]
        + out["shape"]
        + out["teeth_ints"]
        + out["fraction_ok"]
        + out["strategy_ok"]
    ) / 5.0
    return out


def run_episode(
    completion: str,
    seed: int,
    task_id: str = DEFAULT_TASK_ID,
    force_decay: Optional[bool] = None,
) -> Dict[str, Any]:
    """Public wrapper around the cached runner.

    Returns:
        {
            "obs":    final_obs (dict) or None,
            "format": format-quality dict,
            "error":  str or None,
        }
    """
    final_obs, parse_quality, err = _cached_episode(
        completion, seed, task_id, force_decay,
    )
    return {
        "obs": final_obs,
        "format": parse_quality or {"total": 0.0},
        "error": err,
    }


# ---------------------------------------------------------------------------
# Reward functions — TRL contract: list of completions + per-prompt kwargs
# ---------------------------------------------------------------------------

def _seed_for(idx: int, seed_kw: Optional[List[int]]) -> int:
    """TRL forwards each prompt's kwargs as a list. Pull the seed for
    completion `idx`; when absent, fall back to the deterministic value
    `idx + 12345`."""
    if seed_kw and idx < len(seed_kw):
        return int(seed_kw[idx])
    return idx + 12345  # deterministic fallback


def _pick(values: Optional[Sequence[Any]], idx: int, default: Any) -> Any:
    """Return `values[idx]` when the per-prompt list covers `idx`, else `default`."""
    if values and idx < len(values):
        return values[idx]
    return default


def _rollout_results(
    completions: List[str],
    seed: Optional[List[int]],
    task_id: Optional[List[str]],
    force_decay: Optional[List[bool]],
) -> Iterator[Dict[str, Any]]:
    """Yield `run_episode(...)` for each completion, in order, resolving
    the per-prompt seed / task_id / force_decay with their defaults."""
    for i, comp in enumerate(completions):
        yield run_episode(
            comp,
            _seed_for(i, seed),
            _pick(task_id, i, DEFAULT_TASK_ID),
            _pick(force_decay, i, None),
        )


def reward_terminal(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Terminal episode reward, scaled to [-2, +8] per spec 2.5.

    Hard-fail overrides: collision_free < 0.9 → −1.0,
    pdl_feasibility < 0.5 → −0.5.

    Garbage completions still produce a finite number because
    `parse_completion_to_poses` falls back to uniform SLERP. A rollout that
    aborted (no final observation) scores −2.0.
    """
    rewards: List[float] = []
    for result in _rollout_results(completions, seed, task_id, force_decay):
        obs = result["obs"]
        if obs is None:
            rewards.append(-2.0)
            continue
        raw = float(obs.get("terminal_reward") or 0.0)
        bd = obs.get("reward_breakdown") or {}
        coll = detect_collision(float(bd.get("collision_free", 1.0)))
        pdl = detect_pdl_stress(float(bd.get("pdl_feasibility", 1.0)))
        scaled, _ = scale_reward(raw, collision=coll, pdl_stress_exceeded=pdl)
        rewards.append(float(scaled))
    return rewards


def reward_occlusion(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Andrews' Six Keys composite at the final committed stage. [0, 1].

    Reads `reward_breakdown["occlusion_composite"]` from the final
    observation; aborted rollouts and missing values score 0.0.
    """
    rewards: List[float] = []
    for result in _rollout_results(completions, seed, task_id, force_decay):
        obs = result["obs"]
        if obs is None:
            rewards.append(0.0)
            continue
        bd = obs.get("reward_breakdown") or {}
        rewards.append(float(bd.get("occlusion_composite", 0.0)))
    return rewards


def reward_strategy(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Strategy multiplier mapped to [0, 1]:

        wrong   (0.6) → 0.0
        neutral (1.0) → 0.5   (a missing multiplier is treated as neutral)
        optimal (1.2) → 1.0
        → linear: (mult - 0.6) / 0.6, clamped to [0, 1]
    """
    rewards: List[float] = []
    for result in _rollout_results(completions, seed, task_id, force_decay):
        obs = result["obs"]
        if obs is None:
            rewards.append(0.0)
            continue
        bd = obs.get("reward_breakdown") or {}
        # Each step's strategy_multiplier is constant across stages within
        # the episode (we apply the same plan), so reading the last one is
        # sufficient and correct.
        mult = float(bd.get("strategy_multiplier", 1.0))
        rewards.append(max(0.0, min(1.0, (mult - 0.6) / 0.6)))
    return rewards


def reward_format(
    completions: List[str],
    seed: Optional[List[int]] = None,
    **kwargs: Any,
) -> List[float]:
    """Format-only reward.

    1.0 for JSON valid + correct shape + integer teeth + fraction in [0, 1]
    + recognised strategy. Partial credit otherwise. 0.0 for unparseable /
    empty.

    Does not run an episode — pure parse-time check. `seed` is accepted
    only to match the TRL reward-function contract.
    """
    return [
        float(_format_quality(_extract_json(comp), comp)["total"])
        for comp in completions
    ]


@functools.lru_cache(maxsize=1)
def _movement_priors_available() -> bool:
    """Probe for spec 1.9's prior-mining module.

    Cached so the trainer's reward-list builder can ask repeatedly without
    re-running a failing import. The answer is fixed for the life of the
    process once computed.
    """
    try:
        from server.movement_priors import RealismPrior, AnchoragePrior  # noqa: F401
        return True
    except ImportError:
        return False


@functools.lru_cache(maxsize=1)
def _get_combined_prior():
    """Cached singleton — loading the KDEs once costs ~50 ms."""
    from server.movement_priors import CombinedPrior
    return CombinedPrior()


def reward_anchorage(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Empirical movement-realism prior (spec 1.9).

    Composite of:
        AnchoragePrior — penalises molar displacement above the empirical
                         90th percentile (mined from 195 real patients).
        RealismPrior   — KDE log-likelihood per tooth class.

    What is actually scored: the final observation only exposes the state
    the agent ended at (`current_config`) and the constant `target_config`,
    so this calls `CombinedPrior.score(final, target)` — the reached state
    against the target — not an initial→final displacement.
    NOTE(review): confirm that argument order/meaning matches what
    `CombinedPrior.score` expects.

    Returns 0.5 per completion when the priors module is missing, and 0.0
    for aborted rollouts or configs that are not 28×7.
    """
    if not _movement_priors_available():
        # Should not happen — active_reward_funcs() filters this out.
        return [0.5] * len(completions)

    prior = _get_combined_prior()  # singleton
    rewards: List[float] = []
    for result in _rollout_results(completions, seed, task_id, force_decay):
        obs = result["obs"]
        if obs is None:
            rewards.append(0.0)
            continue
        final = np.asarray(obs.get("current_config") or [], dtype=np.float64)
        target = np.asarray(obs.get("target_config") or [], dtype=np.float64)
        if final.shape != (28, 7) or target.shape != (28, 7):
            rewards.append(0.0)
            continue
        rewards.append(float(prior.score(final, target)))
    return rewards


def active_reward_funcs() -> List:
    """Return the list of reward functions to register with GRPOTrainer.

    Spec 1.9's anchorage-realism reward is only included when
    `server/movement_priors.py` is importable. Otherwise we register four
    rewards, not five — that prevents a stub from silently distorting
    group-relative advantages. A notice is printed each time the four-reward
    list is returned.
    """
    funcs = [reward_terminal, reward_occlusion, reward_strategy, reward_format]
    if _movement_priors_available():
        funcs.append(reward_anchorage)
    else:
        print(
            '[grpo] NOTE: spec 1.9 (server.movement_priors) not on disk; '
            'training with 4 reward functions. reward_anchorage will be '
            'enabled automatically once 1.9 ships.',
            flush=True,
        )
    return funcs


# Backwards-compat aliases (the SF-winner naming convention used in the
# rest of the project). Kept so `accuracy_reward_func`-flavoured callers
# don't break during the spec 1.1 transition.
accuracy_reward_func = reward_terminal
occlusion_reward_func = reward_occlusion
compliance_reward_func = reward_format  # closest one-arg analogue
staging_reward_func = reward_strategy


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _check_sft_gate(checkpoint: Optional[str], skip: bool) -> None:
    """Spec 1.12 pre-flight: refuse to start GRPO unless the SFT-stage-0
    gate has been signed off. Bypassable via --skip-sft-gate (warns).

    Exits the process with status 1 when `checkpoint` is given, `skip` is
    false, and `<checkpoint>/gate_passed.json` is missing or explicitly
    records `"passed": false`. Does nothing when no checkpoint is given and
    `skip` is false.
    """
    if checkpoint and not skip:
        gate = os.path.join(checkpoint, "gate_passed.json")
        if not os.path.exists(gate):
            print(
                f"ERROR: spec 1.12 pre-flight failed — {gate} missing.\n"
                f" Run: python scripts/sft_stage0.py --out {checkpoint} && \\\n"
                f" python scripts/sft_gate_eval.py --checkpoint {checkpoint}\n"
                f" Or pass --skip-sft-gate to ignore (NOT recommended).",
                file=sys.stderr,
            )
            sys.exit(1)
        with open(gate, encoding="utf-8") as f:
            payload = json.load(f)
        if not isinstance(payload, dict):
            payload = {}
        # A missing `passed` key is accepted for backward compatibility;
        # only an explicit false is treated as a failed gate.
        if payload.get("passed") is False:
            print(
                f"ERROR: spec 1.12 pre-flight failed — {gate} records passed=false.",
                file=sys.stderr,
            )
            sys.exit(1)
        metrics = payload.get("metrics", {})
        print(f"[grpo] SFT gate passed: {metrics}")
    elif skip:
        print("[grpo] WARNING: skipping spec 1.12 SFT gate (format errors expected for ~150 steps).")


def _run_smoke_test(args: argparse.Namespace, force_decay: Optional[bool]) -> None:
    """`--test` path: exercise prompt generation and every reward function
    on a hand-written plan, without loading a model."""
    print("=== TEST MODE — no model load, no training step ===")
    prompts = generate_prompts(n=4, task_id=args.task_id, force_decay=force_decay)
    if not prompts:
        sys.exit("ERROR: prompt generation failed for every seed (see stderr above).")
    print(f"[grpo] generated {len(prompts)} prompts; first prompt = {len(prompts[0]['prompt'])} chars")

    # Smoke each reward function on a SLERP completion.
    slerp_completion = json.dumps({
        "strategy": "anterior_first",
        "tooth_groups": [
            {"teeth": [11, 12, 21, 22], "fraction": 0.6},
            {"teeth": [13, 23, 33, 43], "fraction": 0.45},
            {"teeth": [16, 17, 26, 27, 36, 37, 46, 47], "fraction": 0.2},
        ],
    })
    seeds = [p["seed"] for p in prompts[:2]]
    comps = [slerp_completion] * len(seeds)
    # Roll out on the same task/decay setting the prompts were built for.
    per_prompt = {
        "seed": seeds,
        "task_id": [args.task_id] * len(seeds),
        "force_decay": [force_decay] * len(seeds),
    }
    print(f"[grpo] reward_terminal: {reward_terminal(comps, **per_prompt)}")
    print(f"[grpo] reward_occlusion: {reward_occlusion(comps, **per_prompt)}")
    print(f"[grpo] reward_strategy: {reward_strategy(comps, **per_prompt)}")
    print(f"[grpo] reward_format: {reward_format(comps, **per_prompt)}")
    if _movement_priors_available():
        print(f"[grpo] reward_anchorage: {reward_anchorage(comps, **per_prompt)}")
    else:
        print("[grpo] reward_anchorage: SKIPPED — spec 1.9 not on disk yet")
    print(f"[grpo] active reward funcs: {[f.__name__ for f in active_reward_funcs()]}")
    print("[grpo] TEST OK")


def train(args: argparse.Namespace) -> None:
    """Run GRPO training using TRL + Unsloth on the embedded env.

    Steps: print the run banner, enforce the SFT gate, then either run the
    `--test` smoke path or load the model (Unsloth when importable, plain
    transformers + peft otherwise), build the prompt dataset, train, save
    the adapter to `args.out` and write `gate_passed.json` next to it.
    """
    print("=== OrthoRL GRPO Training (spec 1.1) ===")
    print(f" Model: {args.model}")
    print(f" Steps: {args.steps}")
    print(f" Generations: {args.num_generations}")
    print(f" Task ID: {args.task_id}")
    print(f" Force decay: {args.force_decay}")
    print(f" Use vLLM: {args.use_vllm}")
    print(f" Wandb: {args.wandb}")
    print()

    _check_sft_gate(args.from_checkpoint, args.skip_sft_gate)

    # The CLI flag is store_true; False means "use the env's default".
    force_decay: Optional[bool] = args.force_decay or None

    if args.test:
        _run_smoke_test(args, force_decay)
        return

    # ----- Real training -----
    try:
        from trl import GRPOConfig, GRPOTrainer
    except ImportError:
        sys.exit("ERROR: install trl (`uv add trl`) and retry.")

    # NOTE(review): --from-checkpoint is only consulted by the gate check
    # above; the weights below always come from --model. Loading the SFT
    # checkpoint here depends on its on-disk format (adapter vs merged),
    # which this file does not show, so the behaviour is surfaced rather
    # than changed.
    if args.from_checkpoint:
        print(
            f"[grpo] WARNING: --from-checkpoint {args.from_checkpoint} is used for the "
            f"gate check only; weights are loaded from --model {args.model}.",
            file=sys.stderr,
        )

    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]

    use_unsloth = True
    try:
        from unsloth import FastLanguageModel  # type: ignore
    except Exception:
        use_unsloth = False

    if use_unsloth:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=args.model,
            max_seq_length=args.max_seq_length,
            load_in_4bit=True,
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r=args.lora_r,
            target_modules=lora_targets,
            lora_alpha=args.lora_r * 2,
            lora_dropout=0.0,
        )
    else:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from peft import LoraConfig, get_peft_model

        tokenizer = AutoTokenizer.from_pretrained(args.model)
        model = AutoModelForCausalLM.from_pretrained(args.model)
        model = get_peft_model(model, LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_r * 2,
            target_modules=lora_targets,
            task_type="CAUSAL_LM",
        ))

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    prompts = generate_prompts(
        n=max(args.steps, 50),
        task_id=args.task_id,
        force_decay=force_decay,
    )
    if not prompts:
        sys.exit("ERROR: prompt generation failed for every seed (see stderr above).")
    print(f"[grpo] generated {len(prompts)} training prompts")

    # Extra dataset columns are forwarded by TRL to the reward functions as
    # per-prompt kwargs. Without these, rewards would always be rolled out
    # on DEFAULT_TASK_ID with no force decay, regardless of the CLI flags
    # the prompts were generated with.
    rows = [{**p, "task_id": args.task_id} for p in prompts]
    if force_decay is not None:
        for row in rows:
            row["force_decay"] = force_decay

    config = GRPOConfig(
        output_dir=args.out,
        max_steps=args.steps,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        save_steps=max(1, args.steps // 5),
        logging_steps=1,
        report_to="wandb" if args.wandb else "none",
        bf16=True,
        use_vllm=args.use_vllm,
    )

    from datasets import Dataset
    train_ds = Dataset.from_list(rows)

    reward_funcs = active_reward_funcs()
    print(f"[grpo] reward functions: {[f.__name__ for f in reward_funcs]}")

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=config,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.out)

    # Spec 1.1 contract: emit gate_passed.json so spec 1.15 (rejection FT)
    # can resume.
    os.makedirs(args.out, exist_ok=True)
    gate_path = os.path.join(args.out, "gate_passed.json")
    with open(gate_path, "w", encoding="utf-8") as f:
        json.dump({
            "passed": True,
            "stage": "grpo",
            "steps": args.steps,
            "model": args.model,
            "task_id": args.task_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        }, f, indent=2)
    print(f"[grpo] training complete. Adapter saved to {args.out}, gate at {gate_path}.")


def analyze_emergent_behaviors(log_dir: str = "./dental_grpo_logs") -> None:
    """Optional post-hoc analysis (referenced by --analyze).

    Currently only prints the list of metrics to track; `log_dir` is
    accepted but not read.
    """
    print("=== Emergent Behavior Analysis ===")
    print("Metrics to track per episode:")
    print(" 1. Staging correlation: spearmanr(priority_ranks, movement_start_stages)")
    print(" 2. Max per-step delta: should decrease over training (velocity clamping)")
    print(" 3. Molar start stage: should increase (anchor strategy)")
    print(" 4. Anterior recovery speed: should be > posterior (after jitter)")
    print()
    print("Compare episode 1 vs episode 50 for each metric.")


def main() -> None:
    """Parse CLI arguments and dispatch to `--analyze` or training."""
    parser = argparse.ArgumentParser(description="OrthoRL GRPO training (spec 1.1)")
    parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct-bnb-4bit")
    parser.add_argument("--steps", type=int, default=300, help="GRPO training steps")
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--max-prompt-length", type=int, default=512)
    parser.add_argument("--max-completion-length", type=int, default=512)
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--out", default="./checkpoints/grpo")
    parser.add_argument("--task-id", default=DEFAULT_TASK_ID,
                        choices=["task_easy", "task_medium", "task_hard"])
    parser.add_argument("--force-decay", action="store_true",
                        help="Spec 1.3: enable pharmacokinetic force decay during training")
    parser.add_argument("--use-vllm", action="store_true")
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--from-checkpoint", default=None,
                        help="Resume from a SFT checkpoint (requires gate_passed.json)")
    parser.add_argument("--skip-sft-gate", action="store_true")
    parser.add_argument("--test", action="store_true",
                        help="Verify reward functions without GPU/training")
    parser.add_argument("--analyze", action="store_true",
                        help="Run post-hoc emergent-behaviour analysis")
    args = parser.parse_args()

    if args.analyze:
        analyze_emergent_behaviors()
    else:
        train(args)


if __name__ == "__main__":
    main()