orthorl / train_grpo.py
sri-manikanta's picture
Spec 1.9 + 1.10: anchorage priors and mesh collision
6aab25b verified
Raw
History Blame Contribute Delete
30.4 kB
"""
GRPO Training Pipeline — Spec 1.1 (Stage 3 of the SFT→GRPO chain).
Pipeline:
Stage 0 (1.12) Format SFT → checkpoints/sft0/ + gate_passed.json
Stage 1 (1.13) Tool-use SFT → checkpoints/sft1/ (optional)
Stage 2 (1.14) BC SFT → checkpoints/sft2/ (optional)
Stage 3 (THIS) GRPO → checkpoints/grpo/ + gate_passed.json
Stage 4 (1.15) Rejection FT → checkpoints/rsft/ (optional)
Up to five reward functions are registered (only reward_terminal goes through
spec 2.5's wide-range scaler; the others report values in [0, 1]):
reward_terminal terminal score scaled to [-2, +8]
reward_occlusion Andrews' Six Keys composite
reward_strategy strategy-multiplier (0.6 / 1.0 / 1.2 → [0, 1])
reward_format JSON / shape / integer-teeth / fraction-range / strategy gate
reward_anchorage empirical movement-realism prior (spec 1.9); registered only
when server.movement_priors is importable
The env is **embedded** (no HTTP) for training-loop throughput — we never
need the FastAPI hop while the LLM is generating completions. Episode
results are cached per `(completion, seed, task_id, force_decay)`, so the
reward functions that need a rollout share one 24-stage episode instead of
running it once per reward (reward_format only parses and never rolls out).
Pre-flight: when `--from-checkpoint <path>` is supplied, this script
refuses to start unless `<path>/gate_passed.json` exists (spec 1.12).
Usage:
uv run python train_grpo.py --test
uv run python train_grpo.py --steps 100 --from-checkpoint checkpoints/sft0
uv run python train_grpo.py --steps 300 --from-checkpoint checkpoints/sft0 \
--use-vllm --wandb --task-id task_medium
Refs: GRPO (DeepSeekMath arXiv:2402.03300), Unsloth Qwen2.5-3B GRPO recipe.
"""
from __future__ import annotations
import argparse
import functools
import json
import math
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
# ---------------------------------------------------------------------------
# Embedded env import — paid once at module import, not per reward call.
# ---------------------------------------------------------------------------
# Put this file's directory on sys.path so the sibling `server` package
# resolves even when the script is launched from another working directory.
_HERE = os.path.split(os.path.abspath(__file__))[0]
if _HERE not in sys.path:
    sys.path.insert(0, _HERE)
# NOTE: do NOT import `_STEPWISE_SESSIONS`. Reaching into private module
# state was the regression flagged in the review on 422d8f1. The env
# observation now carries `episode_id`; that's the public contract.
from server.dental_environment import StepwiseDentalEnvironment # noqa: E402
from server.dental_constants import N_STAGES, N_TEETH, TOOTH_IDS # noqa: E402
from server.quaternion_utils import ( # noqa: E402
quaternion_normalize,
quaternion_slerp,
)
from server.clinical_profiles import STRATEGIES # noqa: E402
from server.reward_scaler import ( # noqa: E402
detect_collision,
detect_pdl_stress,
scale_reward,
)
# ---------------------------------------------------------------------------
# Defaults
# ---------------------------------------------------------------------------
# Task used whenever a caller (or TRL's per-prompt kwargs) does not name one.
DEFAULT_TASK_ID: str = "task_easy"

# One env instance shared by prompt generation and every reward rollout.
# Rollouts are kept apart by the episode_id passed to reset()/step(), not by
# separate env objects.
_ENV: StepwiseDentalEnvironment = StepwiseDentalEnvironment()
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
# Stage-planning prompt. format_obs_as_prompt() fills `stage`, `n_stages` and
# `tooth_lines` via str.format(); the doubled braces render as literal `{` /
# `}` so the JSON schema example survives formatting.
_PROMPT_TEMPLATE = """\
You are an orthodontic treatment planner. Plan aligner stage {stage} of {n_stages}.
CURRENT PER-TOOTH STATE (mm to target, top-12 by displacement):
{tooth_lines}
CONSTRAINTS:
- max 0.25 mm translation per tooth per stage
- max 2.0 deg rotation per tooth per stage
Return ONLY a JSON object with this shape (no prose):
{{
"strategy": "anterior_first" | "distal_first" | "retraction_first" | "intrusion_first" | "expansion_first",
"tooth_groups": [
{{"teeth": [<FDI ids>], "fraction": <0..1>, "priority": "high|medium|low"}},
...
]
}}
`fraction` is the SLERP fraction toward target for that group at THIS stage."""
def _format_tooth_lines(obs: Dict[str, Any]) -> str:
    """Render the 12 teeth farthest from their target as prompt lines.

    "Farthest" is the Euclidean length of the translation delta taken from
    pose elements 4..6 (rotation is not considered). Only indices present in
    both `current_config` and `target_config` (and below N_TEETH) are
    ranked; a missing `per_tooth_progress` entry is shown as 0%.
    """
    current = obs.get("current_config") or []
    goal = obs.get("target_config") or []
    progress = obs.get("per_tooth_progress") or []
    n_usable = min(N_TEETH, len(current), len(goal))

    ranked = []
    for idx in range(n_usable):
        here, there = current[idx], goal[idx]
        dx = there[4] - here[4]
        dy = there[5] - here[5]
        dz = there[6] - here[6]
        prog = progress[idx] if idx < len(progress) else 0.0
        ranked.append(
            (TOOTH_IDS[idx], math.sqrt(dx * dx + dy * dy + dz * dz), dx, dy, dz, prog)
        )
    # Negated key keeps the stable, largest-distance-first ordering.
    ranked = sorted(ranked, key=lambda row: -row[1])

    return "\n".join(
        f" FDI {fdi:2d}: dist={dist:.2f}mm d=({dx:+.2f},{dy:+.2f},{dz:+.2f}) prog={prog:.0%}"
        for fdi, dist, dx, dy, dz, prog in ranked[:12]
    )
def format_obs_as_prompt(obs: Dict[str, Any], stage: int = 1) -> str:
    """Render one env observation as the planning prompt for `stage`.

    This is the one place the prompt shape is defined. The GRPO trainer, the
    SFT data builder (scripts/build_sft_format_data.py, which reuses the
    tooth-line block) and the eval CLI are meant to go through it.
    """
    tooth_block = _format_tooth_lines(obs)
    return _PROMPT_TEMPLATE.format(
        stage=stage, n_stages=N_STAGES, tooth_lines=tooth_block,
    )
def generate_prompts(
    n: int = 50,
    seed_start: int = 0,
    task_id: str = DEFAULT_TASK_ID,
    force_decay: Optional[bool] = None,
) -> List[Dict[str, Any]]:
    """Build up to `n` prompt records for seeds seed_start .. seed_start+n-1.

    Each record is a dict with:
        prompt:      the stage-1 planning prompt for that seed's episode
        seed:        the env seed the prompt was generated from
        task_id:     the task the prompt was generated from
        force_decay: present only when `force_decay` is not None

    GRPOTrainer passes every non-prompt dataset column to the reward
    functions as a per-completion list (``reward_fn(completions, seed=[...],
    task_id=[...])``). Carrying `task_id` / `force_decay` next to `seed`
    makes the reward rollouts replay the same task the prompt was built
    from. Without them the reward functions would fall back to
    DEFAULT_TASK_ID with no force decay, whatever task is being trained.

    Seeds whose reset or formatting fails are reported on stderr and
    skipped, so the returned list can be shorter than `n`.
    """
    out: List[Dict[str, Any]] = []
    for seed in range(seed_start, seed_start + n):
        try:
            obs = _ENV.reset(
                task_id=task_id, seed=seed, force_decay=force_decay,
                episode_id=f"prompt_gen_{seed}",
            )
            prompt = format_obs_as_prompt(obs, stage=1)
        except Exception as exc:  # best effort: one bad seed must not abort the batch
            print(f"[grpo] prompt-gen seed={seed} failed: {exc}", file=sys.stderr)
            continue
        record: Dict[str, Any] = {"prompt": prompt, "seed": seed, "task_id": task_id}
        if force_decay is not None:
            record["force_decay"] = force_decay
        out.append(record)
    return out
# ---------------------------------------------------------------------------
# Completion parser
# ---------------------------------------------------------------------------
def _extract_json(text: str) -> Optional[Dict[str, Any]]:
    """Pull the JSON object out of a model completion.

    Tries, in order:
      1. the span from the first ``{`` to the last ``}`` (a bare object, or
         one wrapped in prose / markdown fences);
      2. the first complete JSON value starting at the first ``{`` (an
         object followed by trailing text that itself contains ``}``, which
         makes span 1 invalid).

    Returns the parsed dict, or None on any failure (empty text, no braces,
    invalid JSON, non-object value). Never raises.
    """
    if not text:
        return None
    start = text.find("{")
    if start < 0:
        return None
    end = text.rfind("}")
    if end > start:
        try:
            obj = json.loads(text[start : end + 1])
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            obj = None
        if isinstance(obj, dict):
            return obj
    try:
        obj, _ = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        return None
    return obj if isinstance(obj, dict) else None
def parse_completion_to_poses(
    completion: str,
    initial: List[List[float]],
    target: List[List[float]],
    stage: int,
) -> List[List[float]]:
    """Convert a high-level plan completion into a 28×7 pose list for `stage`.

    The plan format:
        {"strategy": str, "tooth_groups": [{"teeth": [...], "fraction": float}, ...]}

    Each output row is 4 rotation components (SLERP between the initial and
    target quaternions, then re-normalised) followed by 3 translation
    components (linear interpolation). The interpolation fraction for a
    tooth is the clamped `fraction` of the last valid group naming it.
    Teeth not in any valid group use the uniform stage-based alpha
    (stage + 1) / 25.

    Malformed plans never raise. A missing or non-list `tooth_groups`,
    non-dict groups, non-list `teeth`, and non-numeric or NaN `fraction`
    values are all skipped. Garbage completions therefore fall back to
    uniform SLERP, and `reward_terminal([garbage])` stays finite.
    """
    # NOTE(review): with 0-based stages below N_STAGES (24 per the module
    # docstring) this alpha tops out at 24/25 and never reaches 1.0 —
    # confirm that the hard-coded 25 is intended.
    alpha_default = max(0.0, min(1.0, (stage + 1) / 25.0))
    tooth_alpha: Dict[int, float] = {}

    plan = _extract_json(completion)
    groups = plan.get("tooth_groups") if isinstance(plan, dict) else None
    for group in groups if isinstance(groups, list) else []:
        if not isinstance(group, dict):
            continue
        try:
            frac = float(group.get("fraction", alpha_default))
        except (TypeError, ValueError, OverflowError):
            continue
        if math.isnan(frac):  # NaN would otherwise clamp to 1.0
            continue
        frac = max(0.0, min(1.0, frac))
        teeth = group.get("teeth")
        if not isinstance(teeth, list):
            continue
        for tid in teeth:
            if isinstance(tid, int):
                tooth_alpha[tid] = frac

    poses: List[List[float]] = []
    for i, tid in enumerate(TOOTH_IDS):
        frac = tooth_alpha.get(tid, alpha_default)
        q0 = np.asarray(initial[i][:4], dtype=np.float64)
        q1 = np.asarray(target[i][:4], dtype=np.float64)
        q = quaternion_normalize(quaternion_slerp(q0, q1, frac))
        t0 = np.asarray(initial[i][4:7], dtype=np.float64)
        t1 = np.asarray(target[i][4:7], dtype=np.float64)
        t = (1.0 - frac) * t0 + frac * t1
        poses.append([float(q[0]), float(q[1]), float(q[2]), float(q[3]),
                      float(t[0]), float(t[1]), float(t[2])])
    return poses
# ---------------------------------------------------------------------------
# Episode runner — workhorse, results cached
# ---------------------------------------------------------------------------
@functools.lru_cache(maxsize=512)
def _cached_episode(
    completion: str,
    seed: int,
    task_id: str,
    force_decay: Optional[bool],
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], Optional[str]]:
    """Run one full episode and return (final_obs, parse_quality, error).

    Memoised on all four arguments so the reward functions that need a
    rollout (terminal, occlusion, strategy, anchorage) share a single
    episode per (completion, seed, task_id, force_decay). Cache size
    512 ≈ 128 prompts × 4 generations.

    Never raises. If the reset, plan parsing or an env step fails, the
    failure comes back as an error string in the third slot with
    final_obs None, and the reward functions score that as their minimum.
    `parse_quality` is None only when the reset itself failed.
    """
    # hash() is process-salted and masked to 16 bits, so ids can collide.
    # Rollouts run sequentially here and each starts with reset(), which is
    # assumed to replace any earlier session with the same id —
    # NOTE(review): confirm against StepwiseDentalEnvironment.
    eid = f"grpo_{seed}_{abs(hash(completion)) & 0xffff}"
    try:
        obs = _ENV.reset(
            task_id=task_id, seed=seed,
            force_decay=force_decay, episode_id=eid,
        )
        initial = obs["current_config"]
        target = obs["target_config"]
    except Exception as exc:
        return None, None, f"reset_failed: {exc}"

    # Parse-quality pre-pass: gives reward_format-style visibility into what
    # failed, even when the rollout itself succeeds via the SLERP fallback.
    try:
        parse_quality: Optional[Dict[str, Any]] = _format_quality(
            _extract_json(completion), completion,
        )
    except Exception:
        parse_quality = {"total": 0.0}

    final_obs: Optional[Dict[str, Any]] = None
    for stage in range(N_STAGES):
        try:
            poses = parse_completion_to_poses(completion, initial, target, stage)
        except Exception as exc:
            return None, parse_quality, f"parse_failed_at_{stage}: {exc}"
        try:
            final_obs = _ENV.step(eid, poses)
        except Exception as exc:
            return None, parse_quality, f"step_failed_at_{stage}: {exc}"
        if final_obs.get("done"):
            break
    return final_obs, parse_quality, None
def _fraction_in_unit_range(group: Any) -> bool:
    """True when `group` is a dict whose `fraction` converts to a float in [0, 1]."""
    if not isinstance(group, dict) or "fraction" not in group:
        return False
    try:
        value = float(group["fraction"])
    except (TypeError, ValueError, OverflowError):
        return False
    return 0.0 <= value <= 1.0  # NaN fails both comparisons


def _format_quality(plan: Optional[Dict[str, Any]], raw: str) -> Dict[str, Any]:
    """Compute partial-credit format scores from the parsed plan.

    Args:
        plan: the dict returned by `_extract_json`, or None when the
            completion held no parseable JSON object.
        raw: the raw completion text. Currently unused; kept for interface
            compatibility.

    Returns a dict of 0/1 component scores plus their mean:
        parse:       1.0 if `plan` is not None (JSON parsed)
        shape:       1.0 if `tooth_groups` is a non-empty list
        teeth_ints:  1.0 if every group is a dict whose `teeth` is a list of ints
        fraction_ok: 1.0 if every group's `fraction` converts to a float in [0, 1]
        strategy_ok: 1.0 if `strategy` is a string found in STRATEGIES
        total:       average of the five

    teeth_ints / fraction_ok are only evaluated when `shape` is 1.0. Never
    raises on malformed plans (e.g. a list-valued `strategy`).
    """
    out = {
        "parse": 0.0, "shape": 0.0, "teeth_ints": 0.0,
        "fraction_ok": 0.0, "strategy_ok": 0.0,
    }
    if plan is None:
        out["total"] = 0.0
        return out
    out["parse"] = 1.0
    is_dict = isinstance(plan, dict)

    groups = plan.get("tooth_groups") if is_dict else None
    if isinstance(groups, list) and groups:
        out["shape"] = 1.0
        teeth_ok = all(
            isinstance(g, dict) and isinstance(g.get("teeth"), list)
            and all(isinstance(t, int) for t in g["teeth"])
            for g in groups
        )
        out["teeth_ints"] = 1.0 if teeth_ok else 0.0
        fraction_ok = all(_fraction_in_unit_range(g) for g in groups)
        out["fraction_ok"] = 1.0 if fraction_ok else 0.0

    # Only strings can be strategy names; this also avoids a TypeError when
    # an unhashable value is looked up in a dict/set-typed STRATEGIES.
    strategy = plan.get("strategy") if is_dict else None
    if isinstance(strategy, str) and strategy in STRATEGIES:
        out["strategy_ok"] = 1.0

    out["total"] = (
        out["parse"] + out["shape"] + out["teeth_ints"]
        + out["fraction_ok"] + out["strategy_ok"]
    ) / 5.0
    return out
def run_episode(
    completion: str,
    seed: int,
    task_id: str = DEFAULT_TASK_ID,
    force_decay: Optional[bool] = None,
) -> Dict[str, Any]:
    """Run (or fetch from cache) one full episode for `completion`.

    Returns a dict with:
        "obs":    the final observation dict, or None if the rollout failed
        "format": the format-quality dict; {"total": 0.0} when none was computed
        "error":  the runner's error string, or None on success
    """
    final_obs, quality, error = _cached_episode(completion, seed, task_id, force_decay)
    if not quality:
        quality = {"total": 0.0}
    return {"obs": final_obs, "format": quality, "error": error}
# ---------------------------------------------------------------------------
# Reward functions — TRL contract: list of completions + per-prompt kwargs
# ---------------------------------------------------------------------------
def _seed_for(idx: int, seed_kw: Optional[List[int]]) -> int:
    """Return the env seed for completion number `idx`.

    TRL hands each dataset column to the reward functions as a list aligned
    with `completions`. When that list is missing, empty or too short, the
    seed falls back to the deterministic value ``idx + 12345``.
    """
    have_seed = bool(seed_kw) and idx < len(seed_kw)
    return int(seed_kw[idx]) if have_seed else idx + 12345
def reward_terminal(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Terminal episode reward per completion, scaled per spec 2.5.

    The env's `terminal_reward` is passed through `scale_reward` (target
    range [-2, +8]). The collision / PDL-stress flags come from
    `detect_collision` / `detect_pdl_stress` on the `reward_breakdown`
    entries `collision_free` and `pdl_feasibility`. The hard-fail overrides
    are documented as collision_free < 0.9 → −1.0 and pdl_feasibility < 0.5
    → −0.5; those thresholds live in server.reward_scaler. A failed rollout
    scores -2.0. Garbage completions still get a finite number because
    `parse_completion_to_poses` falls back to uniform SLERP.
    """
    scores: List[float] = []
    for idx, text in enumerate(completions):
        episode = run_episode(
            text,
            _seed_for(idx, seed),
            task_id[idx] if task_id and idx < len(task_id) else DEFAULT_TASK_ID,
            force_decay[idx] if force_decay and idx < len(force_decay) else None,
        )
        final = episode["obs"]
        if final is None:
            scores.append(-2.0)
            continue
        raw_score = float(final.get("terminal_reward") or 0.0)
        breakdown = final.get("reward_breakdown") or {}
        collided = detect_collision(float(breakdown.get("collision_free", 1.0)))
        overstressed = detect_pdl_stress(float(breakdown.get("pdl_feasibility", 1.0)))
        scaled, _details = scale_reward(
            raw_score, collision=collided, pdl_stress_exceeded=overstressed,
        )
        scores.append(float(scaled))
    return scores
def reward_occlusion(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Andrews' Six Keys composite at the final committed stage, in [0, 1].

    Reads `occlusion_composite` from the final observation's
    `reward_breakdown`. The score is 0.0 when the rollout failed or the key
    is absent.
    """
    scores: List[float] = []
    for idx, text in enumerate(completions):
        episode = run_episode(
            text,
            _seed_for(idx, seed),
            task_id[idx] if task_id and idx < len(task_id) else DEFAULT_TASK_ID,
            force_decay[idx] if force_decay and idx < len(force_decay) else None,
        )
        final = episode["obs"]
        if final is None:
            scores.append(0.0)
        else:
            breakdown = final.get("reward_breakdown") or {}
            scores.append(float(breakdown.get("occlusion_composite", 0.0)))
    return scores
def reward_strategy(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Map the env's strategy multiplier linearly onto [0, 1].

        wrong   (0.6) → 0.0
        neutral (1.0) → 0.5
        optimal (1.2) → 1.0       i.e. clamp((mult - 0.6) / 0.6, 0, 1)

    A missing multiplier reads as neutral (1.0), and a failed rollout
    scores 0.0.
    """
    scores: List[float] = []
    for idx, text in enumerate(completions):
        episode = run_episode(
            text,
            _seed_for(idx, seed),
            task_id[idx] if task_id and idx < len(task_id) else DEFAULT_TASK_ID,
            force_decay[idx] if force_decay and idx < len(force_decay) else None,
        )
        final = episode["obs"]
        if final is None:
            scores.append(0.0)
            continue
        breakdown = final.get("reward_breakdown") or {}
        # The same plan is replayed at every stage, so the multiplier is
        # the same throughout the episode and the final one is representative.
        multiplier = float(breakdown.get("strategy_multiplier", 1.0))
        scores.append(max(0.0, min(1.0, (multiplier - 0.6) / 0.6)))
    return scores
def reward_format(
    completions: List[str],
    seed: Optional[List[int]] = None,
    **kwargs: Any,
) -> List[float]:
    """Format-only reward: the `total` of `_format_quality` per completion.

    1.0 means valid JSON, correct shape, integer teeth, fractions in [0, 1]
    and a recognised strategy. Partial structure earns partial credit, and
    unparseable or empty text earns 0.0. No episode is run; this is a pure
    parse-time check, so `seed` and the other kwargs are ignored.
    """
    return [
        float(_format_quality(_extract_json(text), text)["total"])
        for text in completions
    ]
@functools.lru_cache(maxsize=1)
def _movement_priors_available() -> bool:
    """Return True when spec 1.9's prior-mining module can be imported.

    Probes every movement_priors name this file relies on, including
    CombinedPrior, which is the class reward_anchorage actually constructs.
    The result is memoised for the life of the process, so the trainer's
    reward-list builder and reward_anchorage can ask on every batch without
    re-running the import machinery. A module that appears mid-run is not
    picked up. Only ImportError means "absent"; any other error raised while
    importing the module propagates.
    """
    try:
        from server.movement_priors import (  # noqa: F401
            AnchoragePrior,
            CombinedPrior,
            RealismPrior,
        )
    except ImportError:
        return False
    return True
@functools.lru_cache(maxsize=512)
def _initial_config_for(
    seed: int,
    task_id: str,
    force_decay: Optional[bool],
) -> Optional[np.ndarray]:
    """Starting pose array for the episode `(task_id, seed, force_decay)`.

    Resets the shared env under a dedicated episode id, with the same reset
    arguments `_cached_episode` uses. Returns a read-only float64 copy of
    `current_config`, or None if the reset fails or the config is not array
    shaped. Assumes reset() is deterministic per seed (prompt generation
    relies on the same property) — NOTE(review): confirm for force-decay
    tasks.
    """
    try:
        obs = _ENV.reset(
            task_id=task_id, seed=seed, force_decay=force_decay,
            episode_id=f"anchorage_init_{task_id}_{seed}_{force_decay}",
        )
        config = np.array(obs["current_config"], dtype=np.float64)
    except Exception:
        return None
    config.flags.writeable = False  # shared through the cache; never mutate
    return config


def reward_anchorage(
    completions: List[str],
    seed: Optional[List[int]] = None,
    task_id: Optional[List[str]] = None,
    force_decay: Optional[List[bool]] = None,
    **kwargs: Any,
) -> List[float]:
    """Empirical movement-realism prior (spec 1.9), one score per completion.

    Scores the movement the plan actually produced. The episode's starting
    pose array (from `_initial_config_for`) and the reached pose array
    (`current_config` of the final observation) go to
    `CombinedPrior.score(initial, final)`. Per spec 1.9 this composes:
        AnchoragePrior - penalises molar displacement above the empirical
                         90th percentile (mined from 195 real patients)
        RealismPrior   - KDE log-likelihood per tooth class
    and CombinedPrior clamps the result to [0, 1].

    The score is 0.0 when the rollout failed or either pose array is not
    28×7. If the priors module is missing, every completion gets a flat
    0.5; active_reward_funcs() normally keeps this reward unregistered in
    that case.
    """
    if not _movement_priors_available():
        return [0.5] * len(completions)
    prior = _get_combined_prior()  # process-wide singleton
    rewards: List[float] = []
    for i, comp in enumerate(completions):
        s = _seed_for(i, seed)
        tid = task_id[i] if task_id and i < len(task_id) else DEFAULT_TASK_ID
        fd = force_decay[i] if force_decay and i < len(force_decay) else None
        obs = run_episode(comp, s, tid, fd)["obs"]
        if obs is None:
            rewards.append(0.0)
            continue
        initial = _initial_config_for(s, tid, fd)
        final = np.asarray(obs.get("current_config") or [], dtype=np.float64)
        if initial is None or initial.shape != (28, 7) or final.shape != (28, 7):
            rewards.append(0.0)
            continue
        rewards.append(float(prior.score(initial, final)))
    return rewards
@functools.lru_cache(maxsize=1)
def _get_combined_prior():
    """Build spec 1.9's CombinedPrior once and return the same instance after.

    The import is deferred to the first call so this module still imports
    when server.movement_priors is absent. It is memoised because loading
    the KDEs is not free (noted at roughly 50 ms).
    """
    from server.movement_priors import CombinedPrior as _CombinedPrior

    return _CombinedPrior()
def active_reward_funcs() -> List:
    """Reward functions to register with GRPOTrainer, in registration order.

    The four core rewards are always present. reward_anchorage (spec 1.9)
    is appended only when server.movement_priors can be imported. Otherwise
    a note is printed and the reward is left out entirely, so a placeholder
    score never distorts the group-relative advantages.
    """
    core = [reward_terminal, reward_occlusion, reward_strategy, reward_format]
    if _movement_priors_available():
        return core + [reward_anchorage]
    print(
        "[grpo] NOTE: spec 1.9 (server.movement_priors) not on disk; "
        "training with 4 reward functions. reward_anchorage will be "
        "enabled automatically once 1.9 ships.",
        flush=True,
    )
    return core
# Backwards-compatible aliases for callers that still use the older
# `*_reward_func` names from elsewhere in the project. They are kept so those
# callers keep working through the spec 1.1 transition.
accuracy_reward_func = reward_terminal
staging_reward_func = reward_strategy
occlusion_reward_func = reward_occlusion
# reward_format is the closest one-arg analogue of the old compliance reward.
compliance_reward_func = reward_format
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _check_sft_gate(checkpoint: Optional[str], skip: bool) -> None:
    """Spec 1.12 pre-flight: refuse to start GRPO without an SFT sign-off.

    When `checkpoint` is given and `skip` is False,
    `<checkpoint>/gate_passed.json` must exist, parse as a JSON object, and
    must not record `"passed": false`. Otherwise an error goes to stderr and
    the process exits with status 1. On success the file's `metrics` are
    printed. `skip` bypasses the check with a warning (--skip-sft-gate).
    """
    if checkpoint and not skip:
        gate = os.path.join(checkpoint, "gate_passed.json")
        if not os.path.exists(gate):
            print(
                f"ERROR: spec 1.12 pre-flight failed — {gate} missing.\n"
                f" Run: python scripts/sft_stage0.py --out {checkpoint} && \\\n"
                f" python scripts/sft_gate_eval.py --checkpoint {checkpoint}\n"
                f" Or pass --skip-sft-gate to ignore (NOT recommended).",
                file=sys.stderr,
            )
            sys.exit(1)
        try:
            with open(gate, encoding="utf-8") as f:
                gate_data = json.load(f)
        except (OSError, ValueError) as exc:  # unreadable / invalid JSON / bad encoding
            print(
                f"ERROR: spec 1.12 pre-flight failed — cannot read {gate}: {exc}",
                file=sys.stderr,
            )
            sys.exit(1)
        if not isinstance(gate_data, dict) or gate_data.get("passed") is False:
            print(
                f"ERROR: spec 1.12 pre-flight failed — {gate} does not record a passing gate.",
                file=sys.stderr,
            )
            sys.exit(1)
        metrics = gate_data.get("metrics", {})
        print(f"[grpo] SFT gate passed: {metrics}")
    elif skip:
        print("[grpo] WARNING: skipping spec 1.12 SFT gate (format errors expected for ~150 steps).")
def _smoke_test_rewards(args: argparse.Namespace) -> None:
    """--test mode: build a few prompts and run every reward on a fixed plan.

    Loads no model and takes no training step. The reward kwargs carry the
    CLI's task_id / force_decay so the smoke rollouts use the same task the
    prompts were built from. Exits if no prompt could be generated.
    """
    print("=== TEST MODE — no model load, no training step ===")
    prompts = generate_prompts(
        n=4, task_id=args.task_id,
        force_decay=(args.force_decay or None),
    )
    if not prompts:
        sys.exit("[grpo] TEST FAILED: prompt generation returned no prompts (see stderr).")
    print(f"[grpo] generated {len(prompts)} prompts; first prompt = {len(prompts[0]['prompt'])} chars")
    # Smoke each reward function on a SLERP completion.
    slerp_completion = json.dumps({
        "strategy": "anterior_first",
        "tooth_groups": [
            {"teeth": [11, 12, 21, 22], "fraction": 0.6},
            {"teeth": [13, 23, 33, 43], "fraction": 0.45},
            {"teeth": [16, 17, 26, 27, 36, 37, 46, 47], "fraction": 0.2},
        ],
    })
    seeds = [p["seed"] for p in prompts[:2]]
    comps = [slerp_completion, slerp_completion]
    reward_kw: Dict[str, Any] = {"seed": seeds, "task_id": [args.task_id] * len(comps)}
    if args.force_decay:
        reward_kw["force_decay"] = [True] * len(comps)
    print(f"[grpo] reward_terminal: {reward_terminal(comps, **reward_kw)}")
    print(f"[grpo] reward_occlusion: {reward_occlusion(comps, **reward_kw)}")
    print(f"[grpo] reward_strategy: {reward_strategy(comps, **reward_kw)}")
    print(f"[grpo] reward_format: {reward_format(comps, **reward_kw)}")
    if _movement_priors_available():
        print(f"[grpo] reward_anchorage: {reward_anchorage(comps, **reward_kw)}")
    else:
        print("[grpo] reward_anchorage: SKIPPED — spec 1.9 not on disk yet")
    print(f"[grpo] active reward funcs: {[f.__name__ for f in active_reward_funcs()]}")
    print("[grpo] TEST OK")


def _load_model_and_tokenizer(
    args: argparse.Namespace,
    model_source: str,
    fast_language_model: Any,
) -> Tuple[Any, Any]:
    """Load `model_source` and attach LoRA adapters on the attention/MLP projections.

    Uses Unsloth's FastLanguageModel (4-bit) when `fast_language_model` is
    not None, and plain transformers + peft otherwise. Sets pad_token to
    eos_token when the tokenizer has none.
    NOTE(review): if `model_source` is a LoRA adapter directory, confirm the
    backend reuses its adapters rather than stacking a fresh LoRA on top.
    """
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"]
    if fast_language_model is not None:
        model, tokenizer = fast_language_model.from_pretrained(
            model_name=model_source,
            max_seq_length=args.max_seq_length,
            load_in_4bit=True,
        )
        model = fast_language_model.get_peft_model(
            model, r=args.lora_r,
            target_modules=target_modules,
            lora_alpha=args.lora_r * 2, lora_dropout=0.0,
        )
    else:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from peft import LoraConfig, get_peft_model

        tokenizer = AutoTokenizer.from_pretrained(model_source)
        model = AutoModelForCausalLM.from_pretrained(model_source)
        model = get_peft_model(model, LoraConfig(
            r=args.lora_r, lora_alpha=args.lora_r * 2,
            target_modules=target_modules,
            task_type="CAUSAL_LM",
        ))
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _write_grpo_gate(args: argparse.Namespace, model_source: str) -> str:
    """Write `<args.out>/gate_passed.json` (spec 1.1 contract) and return its path.

    Spec 1.15 (rejection FT) checks for this file before resuming.
    """
    os.makedirs(args.out, exist_ok=True)
    gate_path = os.path.join(args.out, "gate_passed.json")
    with open(gate_path, "w", encoding="utf-8") as f:
        json.dump({
            "passed": True,
            "stage": "grpo",
            "steps": args.steps,
            "model": args.model,
            "init_checkpoint": model_source,
            "task_id": args.task_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        }, f, indent=2)
    return gate_path


def train(args: argparse.Namespace) -> None:
    """Run GRPO training using TRL + Unsloth on the embedded env.

    The steps are:
      1. Print the run configuration.
      2. Run the spec 1.12 SFT gate pre-flight, which may exit.
      3. Either run the --test smoke path, or do a real run: load the
         weights (from --from-checkpoint when given, else --model), build
         the prompt dataset, train, save the adapter and write
         gate_passed.json.
    """
    # --from-checkpoint must actually seed the weights; otherwise GRPO
    # silently restarts from the base model and discards the SFT stages.
    model_source = args.from_checkpoint or args.model
    print("=== OrthoRL GRPO Training (spec 1.1) ===")
    print(f" Model: {args.model}")
    print(f" Init from: {model_source}")
    print(f" Steps: {args.steps}")
    print(f" Generations: {args.num_generations}")
    print(f" Task ID: {args.task_id}")
    print(f" Force decay: {args.force_decay}")
    print(f" Use vLLM: {args.use_vllm}")
    print(f" Wandb: {args.wandb}")
    print()
    _check_sft_gate(args.from_checkpoint, args.skip_sft_gate)

    if args.test:
        _smoke_test_rewards(args)
        return

    # ----- Real training -----
    # Unsloth patches trl/transformers when imported, so it has to be
    # imported before trl for its optimisations to apply.
    try:
        from unsloth import FastLanguageModel  # type: ignore
    except Exception:
        FastLanguageModel = None
    try:
        from trl import GRPOConfig, GRPOTrainer
    except ImportError:
        sys.exit("ERROR: install trl (`uv add trl`) and retry.")

    model, tokenizer = _load_model_and_tokenizer(args, model_source, FastLanguageModel)

    prompts = generate_prompts(
        n=max(args.steps, 50), task_id=args.task_id,
        force_decay=(args.force_decay or None),
    )
    if not prompts:
        sys.exit("ERROR: prompt generation produced no prompts (see [grpo] errors above).")
    print(f"[grpo] generated {len(prompts)} training prompts")

    config = GRPOConfig(
        output_dir=args.out,
        max_steps=args.steps,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        save_steps=max(1, args.steps // 5),
        logging_steps=1,
        report_to="wandb" if args.wandb else "none",
        bf16=True,
        use_vllm=args.use_vllm,
    )
    from datasets import Dataset

    # Non-prompt columns (seed, task_id, force_decay) reach the reward
    # functions as per-completion kwargs.
    train_ds = Dataset.from_list(prompts)
    reward_funcs = active_reward_funcs()
    print(f"[grpo] reward functions: {[f.__name__ for f in reward_funcs]}")
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=config,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.out)
    gate_path = _write_grpo_gate(args, model_source)
    print(f"[grpo] training complete. Adapter saved to {args.out}, gate at {gate_path}.")
def analyze_emergent_behaviors(log_dir: str = "./dental_grpo_logs") -> None:
    """Print the post-hoc checklist of emergent behaviours (used by --analyze).

    This only prints the metrics to inspect. `log_dir` is accepted for
    future use but is not read.
    """
    checklist = (
        "=== Emergent Behavior Analysis ===",
        "Metrics to track per episode:",
        " 1. Staging correlation: spearmanr(priority_ranks, movement_start_stages)",
        " 2. Max per-step delta: should decrease over training (velocity clamping)",
        " 3. Molar start stage: should increase (anchor strategy)",
        " 4. Anterior recovery speed: should be > posterior (after jitter)",
        "",
        "Compare episode 1 vs episode 50 for each metric.",
    )
    for line in checklist:
        print(line)
def main() -> None:
    """CLI entry point: parse flags, then train (default) or run --analyze."""
    parser = argparse.ArgumentParser(description="OrthoRL GRPO training (spec 1.1)")
    add = parser.add_argument

    # Model and optimisation.
    add("--model", default="unsloth/Qwen2.5-3B-Instruct-bnb-4bit")
    add("--steps", type=int, default=300, help="GRPO training steps")
    add("--num-generations", type=int, default=4)
    add("--batch-size", type=int, default=2)
    add("--lr", type=float, default=5e-6)
    add("--lora-r", type=int, default=16)
    add("--max-prompt-length", type=int, default=512)
    add("--max-completion-length", type=int, default=512)
    add("--max-seq-length", type=int, default=1024)
    add("--out", default="./checkpoints/grpo")

    # Environment.
    add("--task-id", default=DEFAULT_TASK_ID,
        choices=["task_easy", "task_medium", "task_hard"])
    add("--force-decay", action="store_true",
        help="Spec 1.3: enable pharmacokinetic force decay during training")

    # Infrastructure and pipeline control.
    add("--use-vllm", action="store_true")
    add("--wandb", action="store_true")
    add("--from-checkpoint", default=None,
        help="Resume from a SFT checkpoint (requires gate_passed.json)")
    add("--skip-sft-gate", action="store_true")
    add("--test", action="store_true",
        help="Verify reward functions without GPU/training")
    add("--analyze", action="store_true",
        help="Run post-hoc emergent-behaviour analysis")

    args = parser.parse_args()
    if not args.analyze:
        train(args)
        return
    analyze_emergent_behaviors()


if __name__ == "__main__":
    main()