Spaces:
Sleeping
Sleeping
File size: 30,383 Bytes
cc2303a 6aab25b cc2303a 6aab25b cc2303a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 
517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 | """
GRPO Training Pipeline — Spec 1.1 (Stage 3 of the SFT→GRPO chain).
Pipeline:
Stage 0 (1.12) Format SFT → checkpoints/sft0/ + gate_passed.json
Stage 1 (1.13) Tool-use SFT → checkpoints/sft1/ (optional)
Stage 2 (1.14) BC SFT → checkpoints/sft2/ (optional)
Stage 3 (THIS) GRPO → checkpoints/grpo/ + gate_passed.json
Stage 4 (1.15) Rejection FT → checkpoints/rsft/ (optional)
Five reward functions used here (composed by spec 2.5's wide-range scaler):
reward_terminal terminal score scaled to [-2, +8]
reward_occlusion Andrews' Six Keys composite
reward_strategy strategy-multiplier (0.6 / 1.0 / 1.2 → [0, 1])
reward_format JSON / shape / unit-quat / fraction-range gate
reward_anchorage empirical movement-realism (stub until spec 1.9)
The env is **embedded** (no HTTP) for training-loop throughput — we never
need the FastAPI hop while the LLM is generating completions. Episode
results are cached per `(completion, seed, task_id, force_decay)` so the
episode-based reward functions share one 24-stage rollout instead of paying
for it once per reward (`reward_format` is parse-only and runs no rollout).
Pre-flight: when `--from-checkpoint <path>` is supplied, this script
refuses to start unless `<path>/gate_passed.json` exists (spec 1.12).
Usage:
uv run python train_grpo.py --test
uv run python train_grpo.py --steps 100 --from-checkpoint checkpoints/sft0
uv run python train_grpo.py --steps 300 --from-checkpoint checkpoints/sft0 \
--use-vllm --wandb --task-id task_medium
Refs: GRPO (DeepSeekMath arXiv:2402.03300), Unsloth Qwen2.5-3B GRPO recipe.
"""
from __future__ import annotations
import argparse
import functools
import json
import math
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
# ---------------------------------------------------------------------------
# Embedded env import — paid once at module import, not per reward call.
# ---------------------------------------------------------------------------
# Put this file's own directory on sys.path (once) so the `server.*` imports
# below resolve when the script is launched directly.
_HERE = os.path.dirname(os.path.abspath(__file__))
if _HERE not in sys.path:
    sys.path.insert(0, _HERE)
# NOTE: do NOT import `_STEPWISE_SESSIONS`. Reaching into private module
# state was the regression flagged in the review on 422d8f1. The env
# observation now carries `episode_id`; that's the public contract.
from server.dental_environment import StepwiseDentalEnvironment # noqa: E402
from server.dental_constants import N_STAGES, N_TEETH, TOOTH_IDS # noqa: E402
from server.quaternion_utils import ( # noqa: E402
quaternion_normalize,
quaternion_slerp,
)
from server.clinical_profiles import STRATEGIES # noqa: E402
from server.reward_scaler import ( # noqa: E402
detect_collision,
detect_pdl_stress,
scale_reward,
)
# ---------------------------------------------------------------------------
# Defaults
# ---------------------------------------------------------------------------
# Task used whenever a caller (CLI flag or reward-function kwargs) names none.
DEFAULT_TASK_ID = "task_easy"
# Single shared env. Each rollout is a fresh episode keyed by a unique
# episode_id, so concurrent rollouts within a TRL group are safe.
_ENV = StepwiseDentalEnvironment()
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
# Rendered with str.format(stage=..., n_stages=..., tooth_lines=...). The
# doubled braces are format() escapes that come out as literal `{` / `}` in
# the JSON example shown to the model.
_PROMPT_TEMPLATE = """\
You are an orthodontic treatment planner. Plan aligner stage {stage} of {n_stages}.
CURRENT PER-TOOTH STATE (mm to target, top-12 by displacement):
{tooth_lines}
CONSTRAINTS:
- max 0.25 mm translation per tooth per stage
- max 2.0 deg rotation per tooth per stage
Return ONLY a JSON object with this shape (no prose):
{{
"strategy": "anterior_first" | "distal_first" | "retraction_first" | "intrusion_first" | "expansion_first",
"tooth_groups": [
{{"teeth": [<FDI ids>], "fraction": <0..1>, "priority": "high|medium|low"}},
...
]
}}
`fraction` is the SLERP fraction toward target for that group at THIS stage."""
def _format_tooth_lines(obs: Dict[str, Any]) -> str:
    """Render the 12 teeth farthest from their target, one line per tooth.

    Distance is the Euclidean norm of the difference between elements 4..6 of
    the target row and the current row. Indices missing from either config
    are left out; a missing progress entry is rendered as 0%.
    """
    current = obs.get("current_config") or []
    goal = obs.get("target_config") or []
    progress = obs.get("per_tooth_progress") or []

    # Only indices present in BOTH configs can be compared.
    usable = min(N_TEETH, len(current), len(goal))

    records = []
    for idx in range(usable):
        here, there = current[idx], goal[idx]
        dx = there[4] - here[4]
        dy = there[5] - here[5]
        dz = there[6] - here[6]
        dist = math.sqrt(dx * dx + dy * dy + dz * dz)
        done = progress[idx] if idx < len(progress) else 0.0
        records.append((TOOTH_IDS[idx], dist, dx, dy, dz, done))

    # Largest displacement first; ties keep their original index order.
    ranked = sorted(records, key=lambda rec: -rec[1])
    return "\n".join(
        f" FDI {fdi_id:2d}: dist={dist:.2f}mm d=({dx:+.2f},{dy:+.2f},{dz:+.2f}) prog={prog:.0%}"
        for fdi_id, dist, dx, dy, dz, prog in ranked[:12]
    )
def format_obs_as_prompt(obs: Dict[str, Any], stage: int = 1) -> str:
    """Build the agent prompt for one stage from an env observation.

    This is the one place the prompt shape is defined; per the original
    notes it is shared by GRPO training (this file), the SFT data builder
    (scripts/build_sft_format_data.py, same tooth-line block) and the eval
    CLI.
    """
    fields = {
        "stage": stage,
        "n_stages": N_STAGES,
        "tooth_lines": _format_tooth_lines(obs),
    }
    return _PROMPT_TEMPLATE.format(**fields)
def generate_prompts(
    n: int = 50,
    seed_start: int = 0,
    task_id: str = DEFAULT_TASK_ID,
    force_decay: Optional[bool] = None,
) -> List[Dict[str, Any]]:
    """Create up to `n` prompt records, one per consecutive seed.

    Each record is a dict with keys `prompt` and `seed`; the seed is what
    the reward functions later receive through TRL's per-prompt kwargs
    (`reward_fn(completions, seed=[...])`).

    A seed whose env reset or prompt formatting fails is reported on stderr
    and skipped, so the result may hold fewer than `n` records.
    """
    records: List[Dict[str, Any]] = []
    for seed in range(seed_start, seed_start + n):
        try:
            first_obs = _ENV.reset(
                task_id=task_id,
                seed=seed,
                force_decay=force_decay,
                episode_id=f"prompt_gen_{seed}",
            )
            records.append(
                {"prompt": format_obs_as_prompt(first_obs, stage=1), "seed": seed}
            )
        except Exception as exc:
            print(f"[grpo] prompt-gen seed={seed} failed: {exc}", file=sys.stderr)
    return records
# ---------------------------------------------------------------------------
# Completion parser
# ---------------------------------------------------------------------------
def _extract_json(text: str) -> Optional[Dict[str, Any]]:
"""Find the first balanced `{...}` and json.loads it. Returns None on
any failure mode (no braces, mismatched, invalid JSON)."""
if not text:
return None
start = text.find("{")
end = text.rfind("}")
if start < 0 or end <= start:
return None
try:
return json.loads(text[start : end + 1])
except Exception:
return None
def parse_completion_to_poses(
    completion: str,
    initial: List[List[float]],
    target: List[List[float]],
    stage: int,
) -> List[List[float]]:
    """Convert a high-level plan completion into a per-tooth pose list.

    The plan format:
        {"strategy": str, "tooth_groups": [{"teeth": [...], "fraction": float}, ...]}

    Each tooth listed in a group is interpolated from `initial` toward
    `target` by that group's fraction (clamped to [0, 1]): SLERP on elements
    0..3 (quaternion, re-normalized) and linear interpolation on elements
    4..6 (translation). Teeth not named by any valid group use a uniform
    stage-based default of (stage + 1) / 25, clamped to [0, 1]. When a tooth
    appears in several groups the last one wins.

    Malformed plans never raise: unparseable text, a non-list
    `tooth_groups`, non-dict groups, a non-list `teeth`, non-integer tooth
    ids, and missing / non-numeric / NaN fractions are all skipped and the
    affected teeth fall back to the default fraction.

    Args:
        completion: Raw model output, expected to contain the JSON plan.
        initial: Starting poses, one 7-element row per entry of TOOTH_IDS.
        target: Target poses, same layout as `initial`.
        stage: Stage index; only used to derive the default fraction.

    Returns:
        One `[q0, q1, q2, q3, tx, ty, tz]` list of floats per entry of
        TOOTH_IDS, in TOOTH_IDS order.
    """
    plan = _extract_json(completion)
    alpha_default = max(0.0, min(1.0, (stage + 1) / 25.0))

    tooth_alpha: Dict[int, float] = {}
    groups = plan.get("tooth_groups") if isinstance(plan, dict) else None
    # Iterating a scalar (e.g. "tooth_groups": 5) used to raise TypeError;
    # only a real list is walked now.
    if isinstance(groups, list):
        for group in groups:
            if not isinstance(group, dict):
                continue
            try:
                frac = float(group.get("fraction", alpha_default))
            except (TypeError, ValueError, OverflowError):
                continue
            if math.isnan(frac):
                # NaN would otherwise slip through the clamp as 1.0.
                continue
            frac = max(0.0, min(1.0, frac))
            teeth = group.get("teeth")
            if not isinstance(teeth, list):
                # e.g. "teeth": 11 — not iterable, treat the group as empty.
                continue
            for tid in teeth:
                if isinstance(tid, int):
                    tooth_alpha[tid] = frac

    poses: List[List[float]] = []
    for i, tid in enumerate(TOOTH_IDS):
        frac = tooth_alpha.get(tid, alpha_default)
        q0 = np.asarray(initial[i][:4], dtype=np.float64)
        q1 = np.asarray(target[i][:4], dtype=np.float64)
        # Re-normalize after SLERP to honour the unit-quaternion contract.
        q = quaternion_normalize(quaternion_slerp(q0, q1, frac))
        t0 = np.asarray(initial[i][4:7], dtype=np.float64)
        t1 = np.asarray(target[i][4:7], dtype=np.float64)
        t = (1.0 - frac) * t0 + frac * t1
        poses.append([float(q[0]), float(q[1]), float(q[2]), float(q[3]),
                      float(t[0]), float(t[1]), float(t[2])])
    return poses
# ---------------------------------------------------------------------------
# Episode runner — workhorse, results cached
# ---------------------------------------------------------------------------
@functools.lru_cache(maxsize=512)
def _cached_episode(
    completion: str,
    seed: int,
    task_id: str,
    force_decay: Optional[bool],
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], Optional[str]]:
    """Run one full episode and return (final_obs, parse_quality, error).

    Results are memoized on all four arguments so that every episode-based
    reward function asked about the same (completion, seed, task_id,
    force_decay) shares one rollout. Cache size 512 ≈ 128 prompts × 4
    generations. Failed episodes are memoized too, so a failure is not
    retried for the same inputs within the process.

    Returns:
        final_obs: Observation from the last `step`, or None on failure.
        parse_quality: Dict from `_format_quality`, or None when the reset
            itself failed.
        error: None on success, else a short message
            (`reset_failed: ...` or `step_failed_at_<stage>: ...`). Pose
            construction failures are reported the same way as step
            failures. Reward functions should treat a non-None error as
            "minimum reward".
    """
    # Episode id is only unique per (seed, 16-bit completion hash) within
    # this process; str hashes are salted per interpreter run.
    eid = f"grpo_{seed}_{abs(hash(completion)) & 0xffff}"
    try:
        obs = _ENV.reset(
            task_id=task_id, seed=seed,
            force_decay=force_decay, episode_id=eid,
        )
    except Exception as exc:
        return None, None, f"reset_failed: {exc}"

    initial = obs["current_config"]
    target = obs["target_config"]

    # Parse-quality pre-pass: gives reward_format-style visibility into what
    # failed, even when the rollout itself succeeded with SLERP fallback.
    parse_quality = _format_quality(_extract_json(completion), completion)

    final_obs: Optional[Dict[str, Any]] = None
    for stage in range(N_STAGES):
        try:
            # Pose construction lives inside the try so a parse/interpolation
            # error becomes an error tuple (as documented) instead of
            # escaping into the reward function and aborting training.
            poses = parse_completion_to_poses(completion, initial, target, stage)
            final_obs = _ENV.step(eid, poses)
        except Exception as exc:
            return None, parse_quality, f"step_failed_at_{stage}: {exc}"
        if final_obs.get("done"):
            break
    return final_obs, parse_quality, None
def _format_quality(plan: Optional[Dict[str, Any]], raw: str) -> Dict[str, Any]:
"""Compute partial-credit format scores from the parsed plan.
Returns a dict with:
parse: 1.0 if json.loads succeeded, else 0.0
shape: 1.0 if tooth_groups is a non-empty list, else 0.0
teeth_ints: 1.0 if every group's `teeth` is a list of ints
fraction_ok: 1.0 if every group's `fraction` is in [0, 1]
strategy_ok: 1.0 if `strategy` is one of the 5 known strategies
total: average of the five
"""
out = {
"parse": 0.0, "shape": 0.0, "teeth_ints": 0.0,
"fraction_ok": 0.0, "strategy_ok": 0.0,
}
if plan is None:
out["total"] = 0.0
return out
out["parse"] = 1.0
groups = plan.get("tooth_groups")
if isinstance(groups, list) and groups:
out["shape"] = 1.0
teeth_ok = all(
isinstance(g, dict) and isinstance(g.get("teeth"), list)
and all(isinstance(t, int) for t in g["teeth"])
for g in groups
)
out["teeth_ints"] = 1.0 if teeth_ok else 0.0
try:
fraction_ok = all(
"fraction" in g and 0.0 <= float(g["fraction"]) <= 1.0
for g in groups
)
except Exception:
fraction_ok = False
out["fraction_ok"] = 1.0 if fraction_ok else 0.0
if plan.get("strategy") in STRATEGIES:
out["strategy_ok"] = 1.0
out["total"] = (
out["parse"] + out["shape"] + out["teeth_ints"]
+ out["fraction_ok"] + out["strategy_ok"]
) / 5.0
return out
def run_episode(
    completion: str,
    seed: int,
    task_id: str = DEFAULT_TASK_ID,
    force_decay: Optional[bool] = None,
) -> Dict[str, Any]:
    """Run (or fetch from cache) one episode and package it as a dict.

    The result has three keys:
        "obs":    final observation dict, or None when the episode failed
        "format": format-quality dict ({"total": 0.0} when none was produced)
        "error":  error message string, or None on success
    """
    obs, fmt, error = _cached_episode(completion, seed, task_id, force_decay)
    if not fmt:
        fmt = {"total": 0.0}
    return {"obs": obs, "format": fmt, "error": error}
# ---------------------------------------------------------------------------
# Reward functions — TRL contract: list of completions + per-prompt kwargs
# ---------------------------------------------------------------------------
def _seed_for(idx: int, seed_kw: Optional[List[int]]) -> int:
"""TRL forwards each prompt's kwargs as a list. Pull the seed for
completion `idx`, default to a hash-based fallback if absent."""
if seed_kw and idx < len(seed_kw):
return int(seed_kw[idx])
return idx + 12345 # deterministic fallback
def reward_terminal(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Terminal episode reward after `scale_reward` (spec 2.5, [-2, +8]).

    For each completion the (cached) episode is run and its
    `terminal_reward` is passed to `scale_reward` together with the
    collision / PDL-stress flags derived from the `collision_free` and
    `pdl_feasibility` entries of `reward_breakdown` (both default to 1.0).
    The hard-fail overrides themselves are applied by the scaler. An episode
    that produced no final observation scores -2.0.

    Per-completion `seed`, `task_id` and `force_decay` come from TRL's
    per-prompt kwargs; missing entries fall back to `_seed_for`,
    DEFAULT_TASK_ID and None respectively.
    """
    scores: List[float] = []
    for idx, text in enumerate(completions):
        ep_seed = _seed_for(idx, seed)
        task = DEFAULT_TASK_ID
        if task_id and idx < len(task_id):
            task = task_id[idx]
        decay = None
        if force_decay and idx < len(force_decay):
            decay = force_decay[idx]

        final_obs = run_episode(text, ep_seed, task, decay)["obs"]
        if final_obs is None:
            scores.append(-2.0)
            continue

        raw_score = float(final_obs.get("terminal_reward") or 0.0)
        breakdown = final_obs.get("reward_breakdown") or {}
        collided = detect_collision(float(breakdown.get("collision_free", 1.0)))
        overstressed = detect_pdl_stress(float(breakdown.get("pdl_feasibility", 1.0)))
        scaled, _ = scale_reward(
            raw_score, collision=collided, pdl_stress_exceeded=overstressed,
        )
        scores.append(float(scaled))
    return scores
def reward_occlusion(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Occlusion composite (Andrews' Six Keys) of each episode's final obs.

    Reads `reward_breakdown["occlusion_composite"]` from the final
    observation (0.0 when the key is absent). An episode that produced no
    final observation scores 0.0.
    """
    scores: List[float] = []
    for idx, text in enumerate(completions):
        ep_seed = _seed_for(idx, seed)
        task = DEFAULT_TASK_ID
        if task_id and idx < len(task_id):
            task = task_id[idx]
        decay = None
        if force_decay and idx < len(force_decay):
            decay = force_decay[idx]

        final_obs = run_episode(text, ep_seed, task, decay)["obs"]
        if final_obs is None:
            scores.append(0.0)
        else:
            breakdown = final_obs.get("reward_breakdown") or {}
            scores.append(float(breakdown.get("occlusion_composite", 0.0)))
    return scores
def reward_strategy(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Strategy multiplier rescaled linearly onto [0, 1].

    The mapping is (mult - 0.6) / 0.6, clamped to [0, 1], applied to
    `reward_breakdown["strategy_multiplier"]` of the final observation
    (1.0 when the key is absent). An episode that produced no final
    observation scores 0.0.
    """
    scores: List[float] = []
    for idx, text in enumerate(completions):
        ep_seed = _seed_for(idx, seed)
        task = DEFAULT_TASK_ID
        if task_id and idx < len(task_id):
            task = task_id[idx]
        decay = None
        if force_decay and idx < len(force_decay):
            decay = force_decay[idx]

        final_obs = run_episode(text, ep_seed, task, decay)["obs"]
        if final_obs is None:
            scores.append(0.0)
            continue

        breakdown = final_obs.get("reward_breakdown") or {}
        # Per the original note, the same plan is applied at every stage, so
        # the multiplier on the last observation stands for the episode.
        multiplier = float(breakdown.get("strategy_multiplier", 1.0))
        rescaled = (multiplier - 0.6) / 0.6
        scores.append(max(0.0, min(1.0, rescaled)))
    return scores
def reward_format(
    completions: List[str],
    seed: Optional[List[int]] = None,
    **kwargs: Any,
) -> List[float]:
    """Format-only reward: the `total` score from `_format_quality`.

    Each completion is parsed with `_extract_json` and graded purely on
    structure; no episode is run and `seed` is ignored. Unparseable or empty
    completions score 0.0, fully well-formed plans 1.0, everything else gets
    partial credit.
    """
    return [
        float(_format_quality(_extract_json(text), text)["total"])
        for text in completions
    ]
def _movement_priors_available() -> bool:
"""Probe for spec 1.9's prior-mining module without paying the import cost
twice. Cached so the trainer's reward-list builder can ask repeatedly."""
try:
from server.movement_priors import RealismPrior, AnchoragePrior # noqa: F401
return True
except ImportError:
return False
def reward_anchorage(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Empirical movement-realism prior (spec 1.9).

    Per the spec notes the combined prior is made of:
        AnchoragePrior — penalises molar displacement above the empirical
            90th percentile (mined from 195 real patients).
        RealismPrior — KDE log-likelihood per tooth class.

    What this function actually scores is the pose the episode ENDED at
    (`current_config` of the final observation) against `target_config`,
    via `prior.score(reached, target)`.
    NOTE(review): the earlier docstring described `score(initial, final)`;
    the final observation does not carry the starting pose, so confirm the
    intended argument pair against spec 1.9.

    Returns:
        One float per completion. 0.0 when the episode produced no final
        observation or either config is not a well-formed 28×7 array;
        0.5 for every completion when the priors module is unavailable
        (should not happen — `active_reward_funcs` filters this reward out).
    """
    if not _movement_priors_available():
        return [0.5] * len(completions)

    prior = _get_combined_prior()  # process-wide singleton
    rewards: List[float] = []
    for i, comp in enumerate(completions):
        s = _seed_for(i, seed)
        tid = (task_id[i] if task_id and i < len(task_id) else DEFAULT_TASK_ID)
        fd = (force_decay[i] if force_decay and i < len(force_decay) else None)
        obs = run_episode(comp, s, tid, fd)["obs"]
        if obs is None:
            rewards.append(0.0)
            continue
        try:
            reached = np.asarray(obs.get("current_config") or [], dtype=np.float64)
            target = np.asarray(obs.get("target_config") or [], dtype=np.float64)
        except (TypeError, ValueError):
            # Ragged or non-numeric rows cannot form an array.
            rewards.append(0.0)
            continue
        if reached.shape != (28, 7) or target.shape != (28, 7):
            rewards.append(0.0)
            continue
        # Cast so this reward returns plain floats like the others.
        rewards.append(float(prior.score(reached, target)))
    return rewards
@functools.lru_cache(maxsize=1)
def _get_combined_prior():
    """Build the `CombinedPrior` once and hand back the same instance after.

    The import is deferred to call time so this module still loads when
    `server.movement_priors` is absent. Per the original note, loading the
    KDEs costs roughly 50 ms, hence the single-slot cache.
    """
    from server.movement_priors import CombinedPrior

    prior = CombinedPrior()
    return prior
def active_reward_funcs() -> List:
    """Choose the reward functions to hand to GRPOTrainer.

    The four core rewards are always present. `reward_anchorage` (spec 1.9)
    is appended only when `server.movement_priors` is importable; otherwise
    a notice is printed and four rewards are registered, so a stub cannot
    silently distort group-relative advantages.
    """
    core = [reward_terminal, reward_occlusion, reward_strategy, reward_format]
    if not _movement_priors_available():
        print(
            '[grpo] NOTE: spec 1.9 (server.movement_priors) not on disk; '
            'training with 4 reward functions. reward_anchorage will be '
            'enabled automatically once 1.9 ships.',
            flush=True,
        )
        return core
    return core + [reward_anchorage]
# Backwards-compat aliases (the SF-winner naming convention used in the
# rest of the project). Kept so `accuracy_reward_func`-flavoured callers
# don't break during the spec 1.1 transition.
# Each alias is the very same function object as the reward it points at.
accuracy_reward_func = reward_terminal
occlusion_reward_func = reward_occlusion
compliance_reward_func = reward_format # closest one-arg analogue
staging_reward_func = reward_strategy
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _check_sft_gate(checkpoint: Optional[str], skip: bool) -> None:
"""Spec 1.12 pre-flight: refuse to start GRPO unless the SFT-stage-0
gate has been signed off. Bypassable via --skip-sft-gate (warns)."""
if checkpoint and not skip:
gate = os.path.join(checkpoint, "gate_passed.json")
if not os.path.exists(gate):
print(
f"ERROR: spec 1.12 pre-flight failed — {gate} missing.\n"
f" Run: python scripts/sft_stage0.py --out {checkpoint} && \\\n"
f" python scripts/sft_gate_eval.py --checkpoint {checkpoint}\n"
f" Or pass --skip-sft-gate to ignore (NOT recommended).",
file=sys.stderr,
)
sys.exit(1)
with open(gate) as f:
metrics = json.load(f).get("metrics", {})
print(f"[grpo] SFT gate passed: {metrics}")
elif skip:
print("[grpo] WARNING: skipping spec 1.12 SFT gate (format errors expected for ~150 steps).")
def _run_reward_smoke_test(args: argparse.Namespace) -> None:
    """Exercise prompt generation and every reward function, no model load."""
    print("=== TEST MODE — no model load, no training step ===")
    prompts = generate_prompts(
        n=4, task_id=args.task_id,
        force_decay=(args.force_decay or None),
    )
    if not prompts:
        # generate_prompts swallows per-seed failures; indexing prompts[0]
        # on an empty list used to die with a bare IndexError.
        sys.exit("ERROR: prompt generation produced no prompts (see [grpo] errors above).")
    print(f"[grpo] generated {len(prompts)} prompts; first prompt = {len(prompts[0]['prompt'])} chars")

    # Smoke each reward function on a SLERP completion.
    slerp_completion = json.dumps({
        "strategy": "anterior_first",
        "tooth_groups": [
            {"teeth": [11, 12, 21, 22], "fraction": 0.6},
            {"teeth": [13, 23, 33, 43], "fraction": 0.45},
            {"teeth": [16, 17, 26, 27, 36, 37, 46, 47], "fraction": 0.2},
        ],
    })
    seeds = [p["seed"] for p in prompts[:2]]
    comps = [slerp_completion] * len(seeds)
    # Score on the task the prompts were generated from, not the default.
    episode_kwargs: Dict[str, Any] = {
        "seed": seeds,
        "task_id": [args.task_id] * len(comps),
    }
    if args.force_decay:
        episode_kwargs["force_decay"] = [True] * len(comps)

    print(f"[grpo] reward_terminal: {reward_terminal(comps, **episode_kwargs)}")
    print(f"[grpo] reward_occlusion: {reward_occlusion(comps, **episode_kwargs)}")
    print(f"[grpo] reward_strategy: {reward_strategy(comps, **episode_kwargs)}")
    print(f"[grpo] reward_format: {reward_format(comps, seed=seeds)}")
    if _movement_priors_available():
        print(f"[grpo] reward_anchorage: {reward_anchorage(comps, **episode_kwargs)}")
    else:
        print("[grpo] reward_anchorage: SKIPPED — spec 1.9 not on disk yet")
    print(f"[grpo] active reward funcs: {[f.__name__ for f in active_reward_funcs()]}")
    print("[grpo] TEST OK")


def _load_policy(args: argparse.Namespace):
    """Load `args.model` with LoRA adapters; return (model, tokenizer).

    Unsloth is used when it imports; otherwise plain transformers + peft.
    """
    lora_targets = ["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"]
    try:
        from unsloth import FastLanguageModel  # type: ignore
    except Exception:
        # Unsloth can fail to import for reasons other than ImportError
        # (e.g. unsupported hardware); any failure selects the fallback.
        FastLanguageModel = None

    if FastLanguageModel is not None:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=args.model,
            max_seq_length=args.max_seq_length,
            load_in_4bit=True,
        )
        model = FastLanguageModel.get_peft_model(
            model, r=args.lora_r,
            target_modules=lora_targets,
            lora_alpha=args.lora_r * 2, lora_dropout=0.0,
        )
    else:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from peft import LoraConfig, get_peft_model
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        model = AutoModelForCausalLM.from_pretrained(args.model)
        model = get_peft_model(model, LoraConfig(
            r=args.lora_r, lora_alpha=args.lora_r * 2,
            target_modules=lora_targets,
            task_type="CAUSAL_LM",
        ))
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def train(args: argparse.Namespace) -> None:
    """Run GRPO training using TRL + Unsloth on the embedded env.

    Steps: print the run configuration, enforce the spec 1.12 SFT gate,
    then either run the reward smoke test (`--test`) or load the model,
    build the prompt dataset, train, save the adapter and write
    `<out>/gate_passed.json`.

    Args:
        args: Parsed CLI namespace produced by `main`.

    Exits the process (SystemExit) when `trl` is missing, the SFT gate
    fails, or prompt generation yields nothing.
    """
    print("=== OrthoRL GRPO Training (spec 1.1) ===")
    print(f" Model: {args.model}")
    print(f" Steps: {args.steps}")
    print(f" Generations: {args.num_generations}")
    print(f" Task ID: {args.task_id}")
    print(f" Force decay: {args.force_decay}")
    print(f" Use vLLM: {args.use_vllm}")
    print(f" Wandb: {args.wandb}")
    print()
    _check_sft_gate(args.from_checkpoint, args.skip_sft_gate)

    if args.test:
        _run_reward_smoke_test(args)
        return

    # ----- Real training -----
    try:
        from trl import GRPOConfig, GRPOTrainer
    except ImportError:
        sys.exit("ERROR: install trl (`uv add trl`) and retry.")

    if args.from_checkpoint:
        # NOTE(review): the checkpoint path is only consulted for the gate
        # file; weights below come from args.model. Loading the SFT adapter
        # here needs a decision on adapter stacking, so it is surfaced
        # rather than silently changed.
        print(
            "[grpo] WARNING: --from-checkpoint is only used for the SFT gate check; "
            "weights are loaded from --model.",
            file=sys.stderr,
        )

    model, tokenizer = _load_policy(args)

    prompts = generate_prompts(
        n=max(args.steps, 50), task_id=args.task_id,
        force_decay=(args.force_decay or None),
    )
    if not prompts:
        sys.exit("ERROR: prompt generation produced no prompts (see [grpo] errors above).")
    print(f"[grpo] generated {len(prompts)} training prompts")

    config = GRPOConfig(
        output_dir=args.out,
        max_steps=args.steps,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        save_steps=max(1, args.steps // 5),
        logging_steps=1,
        report_to="wandb" if args.wandb else "none",
        bf16=True,
        use_vllm=args.use_vllm,
    )

    from datasets import Dataset
    # TRL forwards every non-prompt dataset column to the reward functions
    # as a per-sample list. Without these columns the rewards rolled out on
    # DEFAULT_TASK_ID with force_decay=None regardless of the CLI flags,
    # i.e. prompts and rewards could describe different tasks.
    episode_columns: Dict[str, Any] = {"task_id": args.task_id}
    if args.force_decay:
        episode_columns["force_decay"] = True
    train_ds = Dataset.from_list([{**p, **episode_columns} for p in prompts])

    reward_funcs = active_reward_funcs()
    print(f"[grpo] reward functions: {[f.__name__ for f in reward_funcs]}")
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=config,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.out)

    # Spec 1.1 contract: emit gate_passed.json so spec 1.15 (rejection FT)
    # can resume.
    gate_path = os.path.join(args.out, "gate_passed.json")
    with open(gate_path, "w", encoding="utf-8") as f:
        json.dump({
            "passed": True,
            "stage": "grpo",
            "steps": args.steps,
            "model": args.model,
            "task_id": args.task_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        }, f, indent=2)
    print(f"[grpo] training complete. Adapter saved to {args.out}, gate at {gate_path}.")
def analyze_emergent_behaviors(log_dir: str = "./dental_grpo_logs") -> None:
    """Print the post-hoc analysis checklist (referenced by --analyze).

    Only prints a fixed list of metrics to track; `log_dir` is accepted but
    not read.
    """
    checklist = (
        "=== Emergent Behavior Analysis ===",
        "Metrics to track per episode:",
        " 1. Staging correlation: spearmanr(priority_ranks, movement_start_stages)",
        " 2. Max per-step delta: should decrease over training (velocity clamping)",
        " 3. Molar start stage: should increase (anchor strategy)",
        " 4. Anterior recovery speed: should be > posterior (after jitter)",
        "",
        "Compare episode 1 vs episode 50 for each metric.",
    )
    print("\n".join(checklist))
def main() -> None:
    """Parse the command line and dispatch to analysis or training.

    `--analyze` prints the emergent-behaviour checklist; anything else goes
    to `train` (which itself handles `--test`).
    """
    parser = argparse.ArgumentParser(description="OrthoRL GRPO training (spec 1.1)")
    add = parser.add_argument

    add("--model", default="unsloth/Qwen2.5-3B-Instruct-bnb-4bit")
    add("--steps", type=int, default=300, help="GRPO training steps")

    # Plain typed hyper-parameters, registered in their original order.
    typed_options = (
        ("--num-generations", int, 4),
        ("--batch-size", int, 2),
        ("--lr", float, 5e-6),
        ("--lora-r", int, 16),
        ("--max-prompt-length", int, 512),
        ("--max-completion-length", int, 512),
        ("--max-seq-length", int, 1024),
    )
    for flag, kind, default in typed_options:
        add(flag, type=kind, default=default)

    add("--out", default="./checkpoints/grpo")
    add("--task-id", default=DEFAULT_TASK_ID,
        choices=["task_easy", "task_medium", "task_hard"])
    add("--force-decay", action="store_true",
        help="Spec 1.3: enable pharmacokinetic force decay during training")
    add("--use-vllm", action="store_true")
    add("--wandb", action="store_true")
    add("--from-checkpoint", default=None,
        help="Resume from a SFT checkpoint (requires gate_passed.json)")
    add("--skip-sft-gate", action="store_true")
    add("--test", action="store_true",
        help="Verify reward functions without GPU/training")
    add("--analyze", action="store_true",
        help="Run post-hoc emergent-behaviour analysis")

    args = parser.parse_args()
    if args.analyze:
        analyze_emergent_behaviors()
        return
    train(args)
# Run the CLI only when executed as a script, not when imported for its
# reward functions or prompt builder.
if __name__ == "__main__":
    main()
|