Spaces:
Sleeping
Sleeping
| """ | |
| GRPO Training Pipeline — Spec 1.1 (Stage 3 of the SFT→GRPO chain). | |
| Pipeline: | |
| Stage 0 (1.12) Format SFT → checkpoints/sft0/ + gate_passed.json | |
| Stage 1 (1.13) Tool-use SFT → checkpoints/sft1/ (optional) | |
| Stage 2 (1.14) BC SFT → checkpoints/sft2/ (optional) | |
| Stage 3 (THIS) GRPO → checkpoints/grpo/ + gate_passed.json | |
| Stage 4 (1.15) Rejection FT → checkpoints/rsft/ (optional) | |
| Five reward functions used here (composed by spec 2.5's wide-range scaler): | |
| reward_terminal terminal score scaled to [-2, +8] | |
| reward_occlusion Andrews' Six Keys composite | |
| reward_strategy strategy-multiplier (0.6 / 1.0 / 1.2 → [0, 1]) | |
| reward_format JSON / shape / unit-quat / fraction-range gate | |
| reward_anchorage empirical movement-realism (stub until spec 1.9) | |
| The env is **embedded** (no HTTP) for training-loop throughput — we never | |
| need the FastAPI hop while the LLM is generating completions. Episode | |
| results are cached per `(completion, seed)` so the five reward functions | |
| share one 24-stage rollout instead of paying for it five times. | |
| Pre-flight: when `--from-checkpoint <path>` is supplied, this script | |
| refuses to start unless `<path>/gate_passed.json` exists (spec 1.12). | |
| Usage: | |
| uv run python train_grpo.py --test | |
| uv run python train_grpo.py --steps 100 --from-checkpoint checkpoints/sft0 | |
| uv run python train_grpo.py --steps 300 --from-checkpoint checkpoints/sft0 \ | |
| --use-vllm --wandb --task-id task_medium | |
| Refs: GRPO (DeepSeekMath arXiv:2402.03300), Unsloth Qwen2.5-3B GRPO recipe. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import functools | |
| import json | |
| import math | |
| import os | |
| import sys | |
| import time | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| # --------------------------------------------------------------------------- | |
| # Embedded env import — paid once at module import, not per reward call. | |
| # --------------------------------------------------------------------------- | |
| _HERE = os.path.dirname(os.path.abspath(__file__)) | |
| if _HERE not in sys.path: | |
| sys.path.insert(0, _HERE) | |
| # NOTE: do NOT import `_STEPWISE_SESSIONS`. Reaching into private module | |
| # state was the regression flagged in the review on 422d8f1. The env | |
| # observation now carries `episode_id`; that's the public contract. | |
| from server.dental_environment import StepwiseDentalEnvironment # noqa: E402 | |
| from server.dental_constants import N_STAGES, N_TEETH, TOOTH_IDS # noqa: E402 | |
| from server.quaternion_utils import ( # noqa: E402 | |
| quaternion_normalize, | |
| quaternion_slerp, | |
| ) | |
| from server.clinical_profiles import STRATEGIES # noqa: E402 | |
| from server.reward_scaler import ( # noqa: E402 | |
| detect_collision, | |
| detect_pdl_stress, | |
| scale_reward, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Defaults | |
| # --------------------------------------------------------------------------- | |
# Task id used whenever the CLI or a reward-function kwarg does not name one.
DEFAULT_TASK_ID = "task_easy"
# Single shared env, constructed once at import time. Every reset() in this
# file passes its own episode_id, and step() is addressed by that id.
# NOTE(review): the original comment states that concurrent rollouts within
# a TRL group are safe; that depends on how StepwiseDentalEnvironment keys
# its sessions, which is not visible from this file — confirm.
_ENV = StepwiseDentalEnvironment()
| # --------------------------------------------------------------------------- | |
| # Prompt builder | |
| # --------------------------------------------------------------------------- | |
| _PROMPT_TEMPLATE = """\ | |
| You are an orthodontic treatment planner. Plan aligner stage {stage} of {n_stages}. | |
| CURRENT PER-TOOTH STATE (mm to target, top-12 by displacement): | |
| {tooth_lines} | |
| CONSTRAINTS: | |
| - max 0.25 mm translation per tooth per stage | |
| - max 2.0 deg rotation per tooth per stage | |
| Return ONLY a JSON object with this shape (no prose): | |
| {{ | |
| "strategy": "anterior_first" | "distal_first" | "retraction_first" | "intrusion_first" | "expansion_first", | |
| "tooth_groups": [ | |
| {{"teeth": [<FDI ids>], "fraction": <0..1>, "priority": "high|medium|low"}}, | |
| ... | |
| ] | |
| }} | |
| `fraction` is the SLERP fraction toward target for that group at THIS stage.""" | |
def _format_tooth_lines(obs: Dict[str, Any]) -> str:
    """Render the twelve teeth furthest from target, one text line each.

    Reads `current_config`, `target_config` and `per_tooth_progress` from
    the observation. Elements 4..6 of each pose are used as the position;
    the distance is the Euclidean norm of target minus current. Teeth
    without a progress entry are shown with 0% progress.
    """
    current = obs.get("current_config") or []
    goal = obs.get("target_config") or []
    done_fraction = obs.get("per_tooth_progress") or []

    # Only indices present in BOTH pose lists can be measured.
    usable = min(N_TEETH, len(current), len(goal))

    measured = []
    for idx in range(usable):
        here, there = current[idx], goal[idx]
        dx = there[4] - here[4]
        dy = there[5] - here[5]
        dz = there[6] - here[6]
        dist = math.sqrt(dx * dx + dy * dy + dz * dz)
        prog = done_fraction[idx] if idx < len(done_fraction) else 0.0
        measured.append((TOOTH_IDS[idx], dist, dx, dy, dz, prog))

    # Largest displacement first; ties keep their original tooth order.
    ranked = sorted(measured, key=lambda row: -row[1])

    return "\n".join(
        f" FDI {fdi_id:2d}: dist={dist:.2f}mm d=({dx:+.2f},{dy:+.2f},{dz:+.2f}) prog={prog:.0%}"
        for fdi_id, dist, dx, dy, dz, prog in ranked[:12]
    )
def format_obs_as_prompt(obs: Dict[str, Any], stage: int = 1) -> str:
    """Build the planner prompt for one aligner stage from an observation.

    Fills the module-level prompt template with the stage number, the
    total stage count and the per-tooth summary block. According to the
    original notes this is the single source of truth for the prompt
    shape shared by GRPO training (this file), the SFT data builder
    (scripts/build_sft_format_data.py, same tooth-line block) and the
    eval CLI.
    """
    fields = {
        "stage": stage,
        "n_stages": N_STAGES,
        "tooth_lines": _format_tooth_lines(obs),
    }
    return _PROMPT_TEMPLATE.format(**fields)
def generate_prompts(
    n: int = 50,
    seed_start: int = 0,
    task_id: str = DEFAULT_TASK_ID,
    force_decay: Optional[bool] = None,
) -> List[Dict[str, Any]]:
    """Create up to `n` training records, one per consecutive seed.

    Each record is `{"prompt": <stage-1 prompt text>, "seed": <int>}`.
    The seed travels with the prompt so the reward functions can replay
    the same episode (TRL hands per-prompt columns to reward functions
    as keyword lists, e.g. `reward_fn(completions, seed=[...])`).

    A seed whose env reset or prompt formatting fails is reported on
    stderr and skipped, so the result may hold fewer than `n` records.
    """
    records: List[Dict[str, Any]] = []
    for seed in range(seed_start, seed_start + n):
        try:
            first_obs = _ENV.reset(
                task_id=task_id,
                seed=seed,
                force_decay=force_decay,
                episode_id=f"prompt_gen_{seed}",
            )
            records.append(
                {"prompt": format_obs_as_prompt(first_obs, stage=1), "seed": seed}
            )
        except Exception as exc:
            # Best effort: one bad seed must not abort prompt generation.
            print(f"[grpo] prompt-gen seed={seed} failed: {exc}", file=sys.stderr)
    return records
| # --------------------------------------------------------------------------- | |
| # Completion parser | |
| # --------------------------------------------------------------------------- | |
def _extract_json(text: str) -> Optional[Dict[str, Any]]:
    """Parse the JSON object that starts at the first `{` in `text`.

    The object is decoded with `JSONDecoder.raw_decode`, which stops at the
    object's own closing brace. Text after the object — including text that
    contains further braces — is therefore ignored. The previous
    implementation sliced from the first `{` to the LAST `}` and rejected
    such completions even though they started with a valid plan.

    Returns:
        The decoded dict, or None when `text` is empty, is not a string,
        contains no `{`, or the text at the first `{` is not a complete
        JSON object (e.g. a truncated completion).
    """
    if not text or not isinstance(text, str):
        return None
    start = text.find("{")
    if start < 0:
        return None
    try:
        value, _end = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError guards
        # against pathologically deep nesting in model output.
        return None
    # Decoding starts on "{", so a successful decode yields a dict; the
    # check keeps the declared return type honest regardless.
    return value if isinstance(value, dict) else None
def parse_completion_to_poses(
    completion: str,
    initial: List[List[float]],
    target: List[List[float]],
    stage: int,
) -> List[List[float]]:
    """Convert a high-level plan completion into per-tooth poses for `stage`.

    Expected plan format:
        {"strategy": str,
         "tooth_groups": [{"teeth": [...], "fraction": float}, ...]}

    Each tooth listed in a group is interpolated from its `initial` pose
    toward its `target` pose by that group's `fraction` (clamped to
    [0, 1]). Pose elements 0..3 are treated as a quaternion (SLERP, then
    re-normalised) and elements 4..6 as a translation (linear blend).
    Teeth not covered by any usable group use the uniform default
    `(stage + 1) / 25`, clamped to [0, 1]. When a tooth appears in several
    groups the last group wins.

    Malformed plan content is skipped rather than raised on, so garbage
    completions degrade to the uniform default:
      - unparseable completion, or `tooth_groups` that is not a list
      - a group that is not a dict
      - a `fraction` that cannot be converted to float, or is NaN
      - a `teeth` value that is not a list; non-int entries inside it

    Args:
        completion: raw model output.
        initial: starting poses, indexed in the same order as TOOTH_IDS.
        target: target poses, same indexing as `initial`.
        stage: zero-based stage index; only used for the default fraction.

    Returns:
        One 7-float pose `[q0, q1, q2, q3, t0, t1, t2]` per TOOTH_IDS entry.
    """
    plan = _extract_json(completion)
    alpha_default = max(0.0, min(1.0, (stage + 1) / 25.0))

    groups = plan.get("tooth_groups") if isinstance(plan, dict) else None
    if not isinstance(groups, list):
        # Previously a scalar here (e.g. `"tooth_groups": 3`) raised
        # TypeError during iteration.
        groups = []

    tooth_alpha: Dict[int, float] = {}
    for group in groups:
        if not isinstance(group, dict):
            continue
        try:
            fraction = float(group.get("fraction", alpha_default))
        except (TypeError, ValueError, OverflowError):
            continue
        if math.isnan(fraction):
            # min()/max() would silently turn NaN into 1.0; treat it as a
            # missing fraction instead.
            continue
        fraction = max(0.0, min(1.0, fraction))
        teeth = group.get("teeth")
        if not isinstance(teeth, list):
            # Previously `"teeth": 11` raised TypeError during iteration.
            continue
        for tid in teeth:
            if isinstance(tid, int):
                tooth_alpha[tid] = fraction

    poses: List[List[float]] = []
    for i, tid in enumerate(TOOTH_IDS):
        frac = tooth_alpha.get(tid, alpha_default)
        q0 = np.asarray(initial[i][:4], dtype=np.float64)
        q1 = np.asarray(target[i][:4], dtype=np.float64)
        # Re-normalise after SLERP to honour the unit-quaternion contract.
        q = quaternion_normalize(quaternion_slerp(q0, q1, frac))
        t0 = np.asarray(initial[i][4:7], dtype=np.float64)
        t1 = np.asarray(target[i][4:7], dtype=np.float64)
        t = (1.0 - frac) * t0 + frac * t1
        poses.append([float(q[0]), float(q[1]), float(q[2]), float(q[3]),
                      float(t[0]), float(t[1]), float(t[2])])
    return poses
| # --------------------------------------------------------------------------- | |
| # Episode runner — workhorse, results cached | |
| # --------------------------------------------------------------------------- | |
@functools.lru_cache(maxsize=512)
def _cached_episode(
    completion: str,
    seed: int,
    task_id: str,
    force_decay: Optional[bool],
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], Optional[str]]:
    """Run one full episode and return `(final_obs, parse_quality, error)`.

    Memoised on `(completion, seed, task_id, force_decay)` so that every
    reward function evaluating the same completion/seed pair shares a
    single rollout instead of replaying it. The docstrings described this
    cache but the decorator was absent, so each reward function paid for
    its own rollout. Cache size 512 ≈ 128 prompts × 4 generations.

    Because results are memoised, the returned dicts are shared between
    callers and must be treated as read-only. Failures are memoised too.

    Returns:
        final_obs: observation from the last `step()` call, or None when
            the reset or a step raised.
        parse_quality: format-quality dict for the completion, or None
            when the reset itself failed.
        error: None on success, otherwise `"reset_failed: ..."` or
            `"step_failed_at_<stage>: ..."`. Reward functions treat a
            None `final_obs` as the minimum reward.
    """
    # NOTE(review): str hashes are salted per process, so this id differs
    # between runs, and it keeps only 16 bits of the hash.
    eid = f"grpo_{seed}_{abs(hash(completion)) & 0xffff}"
    try:
        obs = _ENV.reset(
            task_id=task_id, seed=seed,
            force_decay=force_decay, episode_id=eid,
        )
    except Exception as exc:
        return None, None, f"reset_failed: {exc}"
    initial = obs["current_config"]
    target = obs["target_config"]
    # Parse-quality pre-pass: records what was wrong with the completion
    # even when the rollout itself succeeds via the SLERP fallback.
    plan = _extract_json(completion)
    parse_quality = _format_quality(plan, completion)
    final_obs: Optional[Dict[str, Any]] = None
    for stage in range(N_STAGES):
        poses = parse_completion_to_poses(completion, initial, target, stage)
        try:
            final_obs = _ENV.step(eid, poses)
        except Exception as exc:
            return None, parse_quality, f"step_failed_at_{stage}: {exc}"
        if final_obs.get("done"):
            break
    return final_obs, parse_quality, None
def _format_quality(plan: Optional[Dict[str, Any]], raw: str) -> Dict[str, Any]:
    """Compute partial-credit format scores for a parsed plan.

    Args:
        plan: result of `_extract_json`; None (or any non-dict) means the
            completion could not be parsed.
        raw: the raw completion text. Currently unused; kept for interface
            stability.

    Returns:
        A dict of 0.0/1.0 flags plus their mean:
            parse:       plan is a dict
            shape:       `tooth_groups` is a non-empty list
            teeth_ints:  every group is a dict whose `teeth` is a list of ints
            fraction_ok: every group has a `fraction` convertible to a float
                         in [0, 1]
            strategy_ok: `strategy` is a string found in STRATEGIES
            total:       average of the five flags
    """
    out = {
        "parse": 0.0, "shape": 0.0, "teeth_ints": 0.0,
        "fraction_ok": 0.0, "strategy_ok": 0.0,
    }
    if not isinstance(plan, dict):
        out["total"] = 0.0
        return out
    out["parse"] = 1.0
    groups = plan.get("tooth_groups")
    if isinstance(groups, list) and groups:
        out["shape"] = 1.0
        teeth_ok = all(
            isinstance(g, dict) and isinstance(g.get("teeth"), list)
            and all(isinstance(t, int) for t in g["teeth"])
            for g in groups
        )
        out["teeth_ints"] = 1.0 if teeth_ok else 0.0
        try:
            fraction_ok = all(
                "fraction" in g and 0.0 <= float(g["fraction"]) <= 1.0
                for g in groups
            )
        except Exception:
            # Non-dict groups or non-numeric fractions count as a failed
            # check, never as a crash.
            fraction_ok = False
        out["fraction_ok"] = 1.0 if fraction_ok else 0.0
    strategy = plan.get("strategy")
    # Only a string can name a strategy. The isinstance guard also stops an
    # unhashable value (e.g. `"strategy": ["a"]`) from raising TypeError
    # if STRATEGIES is a dict or set.
    # NOTE(review): assumes the members/keys of STRATEGIES are plain str.
    if isinstance(strategy, str) and strategy in STRATEGIES:
        out["strategy_ok"] = 1.0
    out["total"] = (
        out["parse"] + out["shape"] + out["teeth_ints"]
        + out["fraction_ok"] + out["strategy_ok"]
    ) / 5.0
    return out
def run_episode(
    completion: str,
    seed: int,
    task_id: str = DEFAULT_TASK_ID,
    force_decay: Optional[bool] = None,
) -> Dict[str, Any]:
    """Run (or fetch) one episode and package the outcome as a dict.

    Thin public wrapper over `_cached_episode`. The result has three keys:
        "obs":    the final observation dict, or None if the episode aborted
        "format": the format-quality dict; `{"total": 0.0}` when none was
                  produced
        "error":  an error string, or None on success
    """
    final_obs, fmt, failure = _cached_episode(completion, seed, task_id, force_decay)
    if not fmt:
        fmt = {"total": 0.0}
    return {"obs": final_obs, "format": fmt, "error": failure}
| # --------------------------------------------------------------------------- | |
| # Reward functions — TRL contract: list of completions + per-prompt kwargs | |
| # --------------------------------------------------------------------------- | |
def _seed_for(idx: int, seed_kw: Optional[List[int]]) -> int:
    """Return the seed for completion number `idx`.

    TRL passes each per-prompt column as a list; when that list is
    missing, empty, or too short for `idx`, fall back to the
    deterministic value `idx + 12345`.
    """
    if not seed_kw or idx >= len(seed_kw):
        return idx + 12345  # deterministic fallback
    return int(seed_kw[idx])
def reward_terminal(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Scaled terminal episode reward, one value per completion.

    The raw `terminal_reward` from the final observation is passed through
    `scale_reward` together with the collision and PDL-stress flags derived
    from the reward breakdown (missing breakdown entries default to 1.0).
    The original notes give the scaled range as [-2, +8] per spec 2.5, with
    hard-fail overrides for collision_free < 0.9 and pdl_feasibility < 0.5.
    An aborted episode scores -2.0. Garbage completions still roll out,
    because pose parsing falls back to uniform SLERP.
    """
    def _column_value(column, pos, default):
        # Per-prompt kwargs arrive as lists; a short or missing list
        # yields the default.
        if column and pos < len(column):
            return column[pos]
        return default

    scores: List[float] = []
    for pos, text in enumerate(completions):
        episode_seed = _seed_for(pos, seed)
        episode_task = _column_value(task_id, pos, DEFAULT_TASK_ID)
        episode_decay = _column_value(force_decay, pos, None)
        final_obs = run_episode(text, episode_seed, episode_task, episode_decay)["obs"]
        if final_obs is None:
            scores.append(-2.0)
            continue
        raw_score = float(final_obs.get("terminal_reward") or 0.0)
        breakdown = final_obs.get("reward_breakdown") or {}
        collided = detect_collision(float(breakdown.get("collision_free", 1.0)))
        overstressed = detect_pdl_stress(float(breakdown.get("pdl_feasibility", 1.0)))
        scaled, _ = scale_reward(
            raw_score, collision=collided, pdl_stress_exceeded=overstressed,
        )
        scores.append(float(scaled))
    return scores
def reward_occlusion(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Occlusion composite read from the final observation, per completion.

    Returns `reward_breakdown["occlusion_composite"]` (0.0 when absent) and
    0.0 for an aborted episode. The original notes describe the value as
    the Andrews' Six Keys composite in [0, 1].
    """
    def _column_value(column, pos, default):
        if column and pos < len(column):
            return column[pos]
        return default

    def _score(pos: int, text: str) -> float:
        episode_seed = _seed_for(pos, seed)
        episode_task = _column_value(task_id, pos, DEFAULT_TASK_ID)
        episode_decay = _column_value(force_decay, pos, None)
        final_obs = run_episode(text, episode_seed, episode_task, episode_decay)["obs"]
        if final_obs is None:
            return 0.0
        breakdown = final_obs.get("reward_breakdown") or {}
        return float(breakdown.get("occlusion_composite", 0.0))

    return [_score(pos, text) for pos, text in enumerate(completions)]
def reward_strategy(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Map the episode's strategy multiplier linearly onto [0, 1].

    The mapping is `(mult - 0.6) / 0.6`, clamped to [0, 1]:
        wrong   (0.6) -> 0.0
        neutral (1.0) -> ~0.67
        optimal (1.2) -> 1.0
    A missing multiplier is treated as 1.0; an aborted episode scores 0.0.
    """
    scores: List[float] = []
    for pos, text in enumerate(completions):
        episode_seed = _seed_for(pos, seed)
        if task_id and pos < len(task_id):
            episode_task = task_id[pos]
        else:
            episode_task = DEFAULT_TASK_ID
        if force_decay and pos < len(force_decay):
            episode_decay = force_decay[pos]
        else:
            episode_decay = None

        final_obs = run_episode(text, episode_seed, episode_task, episode_decay)["obs"]
        if final_obs is None:
            scores.append(0.0)
            continue

        breakdown = final_obs.get("reward_breakdown") or {}
        # Per the original note, the same plan is applied at every stage,
        # so the multiplier is constant within an episode and the last
        # observation's value is representative.
        mult = float(breakdown.get("strategy_multiplier", 1.0))
        scores.append(max(0.0, min(1.0, (mult - 0.6) / 0.6)))
    return scores
def reward_format(
    completions: List[str],
    seed: Optional[List[int]] = None,
    **kwargs: Any,
) -> List[float]:
    """Format-only reward: the `total` format-quality score per completion.

    A completion scores 1.0 when it is valid JSON with the right shape,
    integer teeth, fractions in [0, 1] and a recognised strategy; it gets
    partial credit when only some checks pass and 0.0 when it cannot be
    parsed at all. No episode is run — this is a pure parse-time check, so
    `seed` and the other kwargs are ignored.
    """
    return [
        float(_format_quality(_extract_json(text), text)["total"])
        for text in completions
    ]
@functools.lru_cache(maxsize=None)
def _movement_priors_available() -> bool:
    """Report whether spec 1.9's `server.movement_priors` can be imported.

    The probe result is memoised for the life of the process, so callers
    (the reward-list builder, `reward_anchorage` on every batch) can ask
    repeatedly without re-running the import machinery. The docstring
    already promised this caching, but the decorator was missing.

    Returns:
        True when `RealismPrior` and `AnchoragePrior` import cleanly,
        False when the import raises ImportError.
    """
    try:
        from server.movement_priors import RealismPrior, AnchoragePrior  # noqa: F401
        return True
    except ImportError:
        return False
def reward_anchorage(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Empirical movement-realism prior (spec 1.9), one score per completion.

    Per the original notes, the combined prior is composed of:
        AnchoragePrior — penalises molar displacement above the empirical
                         90th percentile (mined from 195 real patients).
        RealismPrior   — KDE log-likelihood per tooth class.
    and its score is clamped to [0, 1].

    What is scored: the final observation's `current_config` (the state
    the rollout reached) against its `target_config`. The final
    observation does not expose the episode's starting poses, so this is
    NOT an initial→final displacement score.
    NOTE(review): confirm that `CombinedPrior.score(reached, target)` is
    the intended argument pair; the original docstring called it
    `score(initial, final)`.

    Returns:
        0.5 per completion when the prior module is unavailable (not
        expected, since `active_reward_funcs()` filters this function out
        in that case); 0.0 for an aborted episode or when either config
        is not of shape (28, 7); otherwise the prior's score as a float.
    """
    if not _movement_priors_available():
        # Should not happen — active_reward_funcs() filters this out.
        return [0.5] * len(completions)
    prior = _get_combined_prior()
    rewards: List[float] = []
    for i, comp in enumerate(completions):
        s = _seed_for(i, seed)
        tid = (task_id[i] if task_id and i < len(task_id) else DEFAULT_TASK_ID)
        fd = (force_decay[i] if force_decay and i < len(force_decay) else None)
        obs = run_episode(comp, s, tid, fd)["obs"]
        if obs is None:
            rewards.append(0.0)
            continue
        # State the agent ended at, and the (constant) target.
        reached = np.asarray(obs.get("current_config") or [], dtype=np.float64)
        target = np.asarray(obs.get("target_config") or [], dtype=np.float64)
        if reached.shape != (28, 7) or target.shape != (28, 7):
            rewards.append(0.0)
            continue
        # Cast so every reward function returns plain Python floats.
        rewards.append(float(prior.score(reached, target)))
    return rewards
@functools.lru_cache(maxsize=None)
def _get_combined_prior():
    """Return the process-wide `CombinedPrior` instance.

    The first call imports `server.movement_priors` and constructs the
    prior; later calls return the same object. The docstring described a
    cached singleton (loading the KDEs was noted as costing ~50 ms), but a
    new instance was built on every call.

    Raises:
        ImportError: when `server.movement_priors` is not importable. A
            raised call is not cached, so a later call retries.
    """
    from server.movement_priors import CombinedPrior
    return CombinedPrior()
def active_reward_funcs() -> List:
    """Pick the reward functions to hand to GRPOTrainer.

    The four core rewards are always returned. The anchorage-realism
    reward (spec 1.9) is appended only when `server.movement_priors` is
    importable; otherwise a note is printed and four rewards are used, so
    that a stub reward cannot silently distort group-relative advantages.
    """
    selected = [reward_terminal, reward_occlusion, reward_strategy, reward_format]
    if not _movement_priors_available():
        print(
            '[grpo] NOTE: spec 1.9 (server.movement_priors) not on disk; '
            'training with 4 reward functions. reward_anchorage will be '
            'enabled automatically once 1.9 ships.',
            flush=True,
        )
        return selected
    selected.append(reward_anchorage)
    return selected
# Backwards-compat aliases (the SF-winner naming convention used in the
# rest of the project). Kept so `accuracy_reward_func`-flavoured callers
# don't break during the spec 1.1 transition. Each alias is the very same
# function object as the reward it names, not a wrapper.
accuracy_reward_func = reward_terminal
occlusion_reward_func = reward_occlusion
compliance_reward_func = reward_format  # closest one-arg analogue
staging_reward_func = reward_strategy
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def _check_sft_gate(checkpoint: Optional[str], skip: bool) -> None:
    """Spec 1.12 pre-flight check for the SFT gate file.

    - `skip` set: print a warning and return without checking anything.
    - no `checkpoint`: nothing to verify, return silently.
    - otherwise `<checkpoint>/gate_passed.json` must exist. If it is
      missing, print remediation steps to stderr and exit with status 1.
      If it is present, print its "metrics" entry.
    """
    if skip:
        print("[grpo] WARNING: skipping spec 1.12 SFT gate (format errors expected for ~150 steps).")
        return
    if not checkpoint:
        return

    gate = os.path.join(checkpoint, "gate_passed.json")
    if not os.path.exists(gate):
        message = (
            f"ERROR: spec 1.12 pre-flight failed — {gate} missing.\n"
            f" Run: python scripts/sft_stage0.py --out {checkpoint} && \\\n"
            f" python scripts/sft_gate_eval.py --checkpoint {checkpoint}\n"
            f" Or pass --skip-sft-gate to ignore (NOT recommended)."
        )
        print(message, file=sys.stderr)
        sys.exit(1)

    with open(gate) as handle:
        gate_record = json.load(handle)
    metrics = gate_record.get("metrics", {})
    print(f"[grpo] SFT gate passed: {metrics}")
def train(args: argparse.Namespace) -> None:
    """Run GRPO training using TRL (+ Unsloth when importable) on the embedded env.

    Flow:
      1. Print the run configuration and enforce the SFT gate.
      2. `--test`: generate a few prompts, smoke-test every reward function
         on a hand-written completion, then return without loading a model.
      3. Otherwise load the model with LoRA adapters (Unsloth if it imports,
         plain transformers + peft if not), build the prompt dataset, train,
         save the adapter and write `<out>/gate_passed.json`.

    Fix over the previous version: each dataset row now carries `task_id`
    (and `force_decay` when the flag is set) alongside `seed`. TRL forwards
    extra dataset columns to the reward functions as keyword lists. Without
    these columns the rewards always rolled out on DEFAULT_TASK_ID with
    `force_decay=None`, even though the prompts had been generated for
    `--task-id` / `--force-decay`. The `--test` smoke run passes the same
    values. An empty prompt list now exits with a clear message instead of
    an IndexError or an empty dataset.

    Args:
        args: parsed CLI namespace (see `main`).
    """
    print("=== OrthoRL GRPO Training (spec 1.1) ===")
    print(f" Model: {args.model}")
    print(f" Steps: {args.steps}")
    print(f" Generations: {args.num_generations}")
    print(f" Task ID: {args.task_id}")
    print(f" Force decay: {args.force_decay}")
    print(f" Use vLLM: {args.use_vllm}")
    print(f" Wandb: {args.wandb}")
    print()
    _check_sft_gate(args.from_checkpoint, args.skip_sft_gate)

    # The flag is store_true: map "not set" to None so the env default applies.
    force_decay = args.force_decay or None
    no_prompts_msg = (
        "ERROR: prompt generation produced no prompts "
        "(see the per-seed failures on stderr)."
    )

    if args.test:
        print("=== TEST MODE — no model load, no training step ===")
        prompts = generate_prompts(
            n=4, task_id=args.task_id,
            force_decay=force_decay,
        )
        if not prompts:
            sys.exit(no_prompts_msg)
        print(f"[grpo] generated {len(prompts)} prompts; first prompt = {len(prompts[0]['prompt'])} chars")
        # Smoke each reward function on a SLERP completion.
        slerp_completion = json.dumps({
            "strategy": "anterior_first",
            "tooth_groups": [
                {"teeth": [11, 12, 21, 22], "fraction": 0.6},
                {"teeth": [13, 23, 33, 43], "fraction": 0.45},
                {"teeth": [16, 17, 26, 27, 36, 37, 46, 47], "fraction": 0.2},
            ],
        })
        seeds = [p["seed"] for p in prompts[:2]]
        comps = [slerp_completion, slerp_completion]
        # Roll out on the same task / decay setting the prompts were built
        # for, mirroring what the dataset columns provide during training.
        rollout_kwargs: Dict[str, Any] = {
            "seed": seeds,
            "task_id": [args.task_id] * len(comps),
        }
        if force_decay:
            rollout_kwargs["force_decay"] = [True] * len(comps)
        print(f"[grpo] reward_terminal: {reward_terminal(comps, **rollout_kwargs)}")
        print(f"[grpo] reward_occlusion: {reward_occlusion(comps, **rollout_kwargs)}")
        print(f"[grpo] reward_strategy: {reward_strategy(comps, **rollout_kwargs)}")
        print(f"[grpo] reward_format: {reward_format(comps, **rollout_kwargs)}")
        if _movement_priors_available():
            print(f"[grpo] reward_anchorage: {reward_anchorage(comps, **rollout_kwargs)}")
        else:
            print("[grpo] reward_anchorage: SKIPPED — spec 1.9 not on disk yet")
        print(f"[grpo] active reward funcs: {[f.__name__ for f in active_reward_funcs()]}")
        print("[grpo] TEST OK")
        return

    # ----- Real training -----
    try:
        from trl import GRPOConfig, GRPOTrainer
    except ImportError:
        sys.exit("ERROR: install trl (`uv add trl`) and retry.")

    # Unsloth is optional; any import failure selects the plain HF path.
    use_unsloth = True
    try:
        from unsloth import FastLanguageModel  # type: ignore
    except Exception:
        use_unsloth = False

    # NOTE(review): args.from_checkpoint is only used for the gate check
    # above; the weights below are loaded from args.model. Confirm whether
    # the SFT checkpoint was meant to be loaded here.
    lora_targets = ["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"]
    if use_unsloth:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=args.model,
            max_seq_length=args.max_seq_length,
            load_in_4bit=True,
        )
        model = FastLanguageModel.get_peft_model(
            model, r=args.lora_r,
            target_modules=lora_targets,
            lora_alpha=args.lora_r * 2, lora_dropout=0.0,
        )
    else:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from peft import LoraConfig, get_peft_model
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        model = AutoModelForCausalLM.from_pretrained(args.model)
        model = get_peft_model(model, LoraConfig(
            r=args.lora_r, lora_alpha=args.lora_r * 2,
            target_modules=lora_targets,
            task_type="CAUSAL_LM",
        ))
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    prompts = generate_prompts(
        n=max(args.steps, 50), task_id=args.task_id,
        force_decay=force_decay,
    )
    if not prompts:
        sys.exit(no_prompts_msg)
    # Carry the rollout settings as dataset columns so the reward functions
    # receive them as kwargs and replay the episode each prompt came from.
    for row in prompts:
        row["task_id"] = args.task_id
        if force_decay:
            row["force_decay"] = True
    print(f"[grpo] generated {len(prompts)} training prompts")

    config = GRPOConfig(
        output_dir=args.out,
        max_steps=args.steps,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        save_steps=max(1, args.steps // 5),
        logging_steps=1,
        report_to="wandb" if args.wandb else "none",
        bf16=True,
        use_vllm=args.use_vllm,
    )
    from datasets import Dataset
    train_ds = Dataset.from_list(prompts)
    reward_funcs = active_reward_funcs()
    print(f"[grpo] reward functions: {[f.__name__ for f in reward_funcs]}")
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=config,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.out)

    # Spec 1.1 contract: emit gate_passed.json so spec 1.15 (rejection FT)
    # can resume. The file is written unconditionally after training.
    gate_path = os.path.join(args.out, "gate_passed.json")
    with open(gate_path, "w") as f:
        json.dump({
            "passed": True,
            "stage": "grpo",
            "steps": args.steps,
            "model": args.model,
            "task_id": args.task_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        }, f, indent=2)
    print(f"[grpo] training complete. Adapter saved to {args.out}, gate at {gate_path}.")
def analyze_emergent_behaviors(log_dir: str = "./dental_grpo_logs") -> None:
    """Print the checklist of emergent-behaviour metrics (used by --analyze).

    This only prints guidance text; `log_dir` is accepted but not read.
    """
    report = (
        "=== Emergent Behavior Analysis ===",
        "Metrics to track per episode:",
        " 1. Staging correlation: spearmanr(priority_ranks, movement_start_stages)",
        " 2. Max per-step delta: should decrease over training (velocity clamping)",
        " 3. Molar start stage: should increase (anchor strategy)",
        " 4. Anterior recovery speed: should be > posterior (after jitter)",
        "",
        "Compare episode 1 vs episode 50 for each metric.",
    )
    for line in report:
        print(line)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Assemble the command-line interface for the GRPO trainer."""
    cli = argparse.ArgumentParser(description="OrthoRL GRPO training (spec 1.1)")
    add = cli.add_argument

    # Model and optimisation settings.
    add("--model", default="unsloth/Qwen2.5-3B-Instruct-bnb-4bit")
    add("--steps", type=int, default=300, help="GRPO training steps")
    add("--num-generations", type=int, default=4)
    add("--batch-size", type=int, default=2)
    add("--lr", type=float, default=5e-6)
    add("--lora-r", type=int, default=16)
    add("--max-prompt-length", type=int, default=512)
    add("--max-completion-length", type=int, default=512)
    add("--max-seq-length", type=int, default=1024)
    add("--out", default="./checkpoints/grpo")

    # Environment settings.
    add("--task-id", default=DEFAULT_TASK_ID,
        choices=["task_easy", "task_medium", "task_hard"])
    add("--force-decay", action="store_true",
        help="Spec 1.3: enable pharmacokinetic force decay during training")

    # Runtime integrations.
    add("--use-vllm", action="store_true")
    add("--wandb", action="store_true")

    # Pipeline gating and run modes.
    add("--from-checkpoint", default=None,
        help="Resume from a SFT checkpoint (requires gate_passed.json)")
    add("--skip-sft-gate", action="store_true")
    add("--test", action="store_true",
        help="Verify reward functions without GPU/training")
    add("--analyze", action="store_true",
        help="Run post-hoc emergent-behaviour analysis")
    return cli


def main() -> None:
    """CLI entry point: run the analysis helper or the trainer."""
    options = _build_arg_parser().parse_args()
    if not options.analyze:
        train(options)
        return
    analyze_emergent_behaviors()
| if __name__ == "__main__": | |
| main() | |