# orthorl / clinical_exam.py
# Author: sri-manikanta — initial deploy: spec 1.5 (commit cc2303a)
"""
clinical_exam.py — Spec 2.2: Clinical Knowledge Exam (Pre/Post Training)
10 multiple-choice questions that test whether a trained model has learned
the environment's strategy-reward mapping — NOT generic orthodontic knowledge.
Usage:
# Without a model (mock mode — answers questions at random; for smoke-testing)
python clinical_exam.py --mock
# With a model (GPU recommended for large models)
python clinical_exam.py --model unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit
# Score a trained checkpoint
python clinical_exam.py --model ./checkpoints/final
# Generate exam score curve from multiple checkpoints
python clinical_exam.py --checkpoints ./checkpoints/step-50 ./checkpoints/step-100 ./checkpoints/step-150
Pitch line: "Before training: 3/10. After 300 GRPO steps: 8/10. The agent
learned orthodontic reasoning without being taught a single fact."
"""
from __future__ import annotations
import argparse
import os
import re
import sys
from typing import Any, Dict, List, Optional
# ---------------------------------------------------------------------------
# Exam questions (environment-specific — tests reward structure, not textbook)
# ---------------------------------------------------------------------------
# Each entry: "question" (prompt text), "choices" (exactly the keys A–D),
# "answer" (the correct key) and "category" (used for the per-category breakdown).
EXAM_QUESTIONS: List[Dict[str, Any]] = [
    # --- Strategy selection (4 questions) ---
    {
        "question": (
            "A patient presents with Class II Division 1 malocclusion (molar AP distance 3.5mm, "
            "proclined upper incisors). Which treatment strategy maximises reward in this environment?"
        ),
        "choices": {
            "A": "anterior_first",
            "B": "retraction_first",
            "C": "expansion_first",
            "D": "intrusion_first",
        },
        "answer": "B",
        "category": "strategy_selection",
    },
    {
        "question": (
            "Moderate crowding (6mm arch deficit) with Class I molar relationship and no overbite problem. "
            "Which strategy receives a 1.2× reward multiplier in this environment?"
        ),
        "choices": {
            "A": "retraction_first",
            "B": "distal_first",
            "C": "expansion_first",
            "D": "intrusion_first",
        },
        "answer": "C",
        "category": "strategy_selection",
    },
    {
        "question": (
            "Class III case with mesial molar relationship and -2.5mm AP offset (lower jaw forward). "
            "Optimal strategy for this environment?"
        ),
        "choices": {
            "A": "distal_first",
            "B": "anterior_first",
            "C": "retraction_first",
            "D": "expansion_first",
        },
        "answer": "A",
        "category": "strategy_selection",
    },
    {
        "question": (
            "Deep bite (5.2mm overbite) with Class II Division 2 malocclusion. "
            "Which strategy unlocks the deep-bite override bonus in this environment?"
        ),
        "choices": {
            "A": "expansion_first",
            "B": "anterior_first",
            "C": "retraction_first",
            "D": "intrusion_first",
        },
        "answer": "D",
        "category": "strategy_selection",
    },
    # --- Diagnosis interpretation (3 questions) ---
    {
        "question": (
            "The tool diagnose_angle_class measures molar AP distance as 0.8mm. "
            "What Angle classification does it return?"
        ),
        "choices": {
            "A": "Class I",
            "B": "Class II Division 1",
            "C": "Class II Division 2",
            "D": "Class III",
        },
        "answer": "A",
        "category": "diagnosis",
    },
    {
        "question": (
            "measure_crowding returns crowding_mm = 11.2. "
            "What severity category does this correspond to in the environment?"
        ),
        "choices": {
            "A": "mild (< 4mm)",
            "B": "moderate (4–8mm)",
            "C": "severe (> 8mm)",
            "D": "none (< 1mm)",
        },
        "answer": "C",
        "category": "diagnosis",
    },
    {
        "question": (
            "measure_overbite returns overbite_mm = -1.5 (negative = lower arch anterior). "
            "What is the correct classification?"
        ),
        "choices": {
            "A": "deep_bite",
            "B": "normal",
            "C": "open_bite",
            "D": "edge_to_edge",
        },
        "answer": "C",
        "category": "diagnosis",
    },
    # --- Staging order (2 questions) ---
    {
        "question": (
            "In a retraction_first treatment strategy, which teeth are prioritised for movement "
            "in the first 8 aligner stages?"
        ),
        "choices": {
            "A": "Lower incisors (31, 32, 41, 42)",
            "B": "Upper molars (16, 17, 26, 27)",
            "C": "Upper canines and premolars — retraction arc",
            "D": "All 28 teeth equally",
        },
        "answer": "C",
        "category": "staging",
    },
    {
        "question": (
            "The environment penalises a tooth movement that exceeds which per-stage translation limit?"
        ),
        "choices": {
            "A": "0.10 mm",
            "B": "0.25 mm",
            "C": "0.50 mm",
            "D": "1.00 mm",
        },
        "answer": "B",
        "category": "staging",
    },
    # --- Failure recovery (1 question) ---
    {
        "question": (
            "An agent selects anterior_first on a Class II Div 1 case and receives a 0.6× reward multiplier. "
            "What is the correct corrective action?"
        ),
        "choices": {
            "A": "Increase movement fraction for all teeth",
            "B": "Call diagnose_angle_class first, then select retraction_first",
            "C": "Skip diagnostic tools and rely on the default strategy",
            "D": "Switch to expansion_first instead",
        },
        "answer": "B",
        "category": "recovery",
    },
]


def _validate_exam_questions(questions: List[Dict[str, Any]]) -> None:
    """
    Raise ``ValueError`` unless *questions* is a well-formed 10-question exam.

    Checks that there are exactly 10 questions, that every question offers
    exactly the choice keys A, B, C and D (the formatter indexes those keys),
    and that every ``answer`` names one of those keys. Explicit raises are used
    instead of ``assert`` so the checks still run under ``python -O``.
    """
    if len(questions) != 10:
        raise ValueError("Exam must have exactly 10 questions")
    expected_keys = {"A", "B", "C", "D"}
    for idx, q in enumerate(questions, start=1):
        if set(q["choices"]) != expected_keys:
            raise ValueError(
                f"Each question must have 4 choices (A, B, C, D); "
                f"question {idx} has {sorted(q['choices'])}"
            )
        if q["answer"] not in q["choices"]:
            raise ValueError(
                f"Question {idx}: answer {q['answer']!r} is not one of its choices"
            )


# Validate once at import time so a malformed edit to the question bank fails fast.
_validate_exam_questions(EXAM_QUESTIONS)
# ---------------------------------------------------------------------------
# Formatting + answer extraction
# ---------------------------------------------------------------------------
def _format_mcq(q: Dict[str, Any]) -> str:
    """
    Render one exam question as a plain-text multiple-choice prompt.

    Layout: the question stem, a blank line, one " X) text" line per choice
    in fixed A–D order, then a blank line and the single-letter answer cue.
    """
    option_lines = [f" {letter}) {q['choices'][letter]}" for letter in "ABCD"]
    answer_cue = "\nAnswer with just the letter (A, B, C, or D):"
    return "\n".join([q["question"], "", *option_lines, answer_cue])
def _extract_answer(text: str) -> str:
    """
    Extract the chosen A/B/C/D letter from model output.

    Handles: "B", "b", "The answer is B", "Answer: b", "B: Distal-first",
    "(B)", etc. Explicit answer phrases and a leading letter are matched
    case-insensitively. The last-resort "any standalone letter" scan is
    case-sensitive, so the English article "a" in free-form text
    (e.g. "Based on a deep bite, D") is not misread as choice A.

    Returns "X" if no valid letter found.
    """
    stripped = text.strip()
    upper = stripped.upper()
    # Case-insensitive patterns, in order of specificity.
    for pattern in (
        r"\bANSWER[:\s]+([ABCD])\b",  # "Answer: B"
        r"\bTHE ANSWER IS ([ABCD])\b",  # "The answer is B"
        r"^\s*\(?([ABCD])(?:[)\.:\s]|$)",  # "B", "B)", "B.", "B:", "(B)"
    ):
        m = re.search(pattern, upper)
        if m:
            return m.group(1)
    # Fallback: any standalone *uppercase* choice letter in the original text.
    m = re.search(r"\b([ABCD])\b", stripped)
    return m.group(1) if m else "X"
# ---------------------------------------------------------------------------
# Exam runner
# ---------------------------------------------------------------------------
def run_exam(
    model: Any,
    tokenizer: Any,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    Run the 10-question clinical exam on the given model.

    Each question is rendered with ``_format_mcq`` and wrapped in a short
    system + user chat. The chat is passed through the tokenizer's chat template
    when one is available; otherwise the bare prompt is used. The model answers
    greedily (``do_sample=False``, at most 16 new tokens), and the decoded
    continuation is graded with ``_extract_answer``.

    Parameters
    ----------
    model : transformers/unsloth model (must support generate())
    tokenizer : corresponding tokenizer
    verbose : bool
        If True, print one line per question with its category, a ✓/✗ mark,
        and the predicted and correct letters.

    Returns
    -------
    dict with keys:
        score : int — number correct
        total : int — number of questions (10)
        pct : float — score / total
        details : list[dict] — per-question results
        by_category : dict[str, dict[str, int]] — {"correct", "total"} per category
    """
    import torch

    correct = 0
    details: List[Dict[str, Any]] = []
    for i, q in enumerate(EXAM_QUESTIONS):
        prompt = _format_mcq(q)
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an AI assistant taking a clinical orthodontics exam. "
                    "Answer each question with a single letter: A, B, C, or D."
                ),
            },
            {"role": "user", "content": prompt},
        ]
        # Prefer the chat template. Its rendered text already contains every
        # special token the model expects (BOS, role markers), so the
        # tokenizer() call below must not add them again. Otherwise, e.g.
        # Llama-family models would see a duplicated BOS token.
        try:
            input_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            add_special_tokens = False
        except Exception:
            # No usable chat template: fall back to the bare prompt and let the
            # tokenizer add its usual special tokens.
            input_text = prompt
            add_special_tokens = True
        inputs = tokenizer(
            input_text, return_tensors="pt", add_special_tokens=add_special_tokens
        )
        if hasattr(model, "device"):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=16,
                do_sample=False,
                temperature=1.0,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens (skip the echoed prompt).
        generated = tokenizer.decode(
            out[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
        )
        predicted = _extract_answer(generated)
        is_correct = predicted == q["answer"]
        correct += int(is_correct)
        details.append(
            {
                "question_idx": i + 1,
                "category": q["category"],
                "predicted": predicted,
                "correct": q["answer"],
                "is_correct": is_correct,
                "raw_output": generated.strip()[:80],
            }
        )
        if verbose:
            mark = "✓" if is_correct else "✗"
            print(
                f" Q{i + 1} ({q['category']:20s}) {mark} predicted={predicted} correct={q['answer']}"
            )

    # Category breakdown
    by_cat: Dict[str, Dict[str, int]] = {}
    for d in details:
        counts = by_cat.setdefault(d["category"], {"correct": 0, "total": 0})
        counts["total"] += 1
        counts["correct"] += int(d["is_correct"])

    return {
        "score": correct,
        "total": len(EXAM_QUESTIONS),
        "pct": correct / len(EXAM_QUESTIONS),
        "details": details,
        "by_category": by_cat,
    }
def run_exam_mock(rng_seed: int = 0) -> Dict[str, Any]:
    """
    Mock exam runner for testing without a GPU model.

    Picks a uniformly random letter for every question, using a
    ``random.Random(rng_seed)`` instance so results are reproducible, and
    grades it against the answer key. The returned dict has the same keys as
    ``run_exam``'s result, so it is useful for structural tests of reporting
    and plotting code.
    """
    import random

    rng = random.Random(rng_seed)
    # Derive the total from the question bank (as run_exam does) rather than
    # hard-coding 10, so the two runners cannot drift apart.
    n_questions = len(EXAM_QUESTIONS)
    correct = 0
    details: List[Dict[str, Any]] = []
    for i, q in enumerate(EXAM_QUESTIONS):
        predicted = rng.choice(["A", "B", "C", "D"])
        is_correct = predicted == q["answer"]
        correct += int(is_correct)
        details.append(
            {
                "question_idx": i + 1,
                "category": q["category"],
                "predicted": predicted,
                "correct": q["answer"],
                "is_correct": is_correct,
                "raw_output": predicted,
            }
        )

    by_cat: Dict[str, Dict[str, int]] = {}
    for d in details:
        counts = by_cat.setdefault(d["category"], {"correct": 0, "total": 0})
        counts["total"] += 1
        counts["correct"] += int(d["is_correct"])

    return {
        "score": correct,
        "total": n_questions,
        "pct": correct / n_questions,
        "details": details,
        "by_category": by_cat,
    }
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def plot_exam_curve(
    scores_by_step: Dict[int, int],
    output_path: str = "results/exam_curve.png",
) -> str:
    """
    Plot exam score vs training step and save the figure to *output_path*.

    The figure is built with matplotlib's object-oriented ``Figure`` API
    instead of ``pyplot``. The previous ``matplotlib.use("Agg")`` call switched
    the process-wide backend as a side effect. When pyplot was already in use,
    for example in a notebook, that went through ``pyplot.switch_backend``,
    which closes all open figures. Saving a standalone ``Figure`` to a file
    needs no backend selection and also works headless.

    Parameters
    ----------
    scores_by_step : dict mapping step → score (0-10)
    output_path : where to save; parent directories are created if missing

    Returns
    -------
    str : absolute path of the saved image
    """
    from matplotlib.figure import Figure

    steps = sorted(scores_by_step)
    scores = [scores_by_step[s] for s in steps]

    fig = Figure(figsize=(8, 5))
    ax = fig.subplots()
    ax.plot(steps, scores, "bo-", linewidth=2, markersize=8, label="Exam score")
    # Random guessing over 4 choices scores 10 * 1/4 = 2.5 on average.
    ax.axhline(y=2.5, color="gray", linestyle=":", alpha=0.5, label="Random baseline (2.5/10)")
    ax.set_xlabel("Training Step", fontsize=12)
    ax.set_ylabel("Exam Score (/ 10)", fontsize=12)
    ax.set_title("Clinical Knowledge: Score vs Training Steps", fontsize=13, fontweight="bold")
    ax.set_ylim(0, 10.5)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    # Annotate first and last points (skipped when there is nothing to plot).
    if steps:
        ax.annotate(
            f"{scores[0]}/10",
            (steps[0], scores[0]),
            textcoords="offset points",
            xytext=(5, 5),
            fontsize=9,
        )
        ax.annotate(
            f"{scores[-1]}/10",
            (steps[-1], scores[-1]),
            textcoords="offset points",
            xytext=(5, 5),
            fontsize=9,
        )

    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    return os.path.abspath(output_path)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _print_result(result: Dict[str, Any], label: str = "") -> None:
    """
    Print an exam result: the overall score line, then one line per category.

    When *label* is non-empty, the two header lines are tagged "[label] ".
    Categories with zero questions report 0% instead of dividing by zero.
    """
    tag = "" if not label else f"[{label}] "
    print(f"{tag}Score: {result['score']}/{result['total']} ({result['pct']:.0%})")
    print(f"{tag}By category:")
    for category, tally in result["by_category"].items():
        hits, asked = tally["correct"], tally["total"]
        ratio = hits / max(asked, 1)
        print(f" {category:25s}: {hits}/{asked} ({ratio:.0%})")
def _infer_step(ckpt: str, fallback: int) -> int:
    """
    Return the training step encoded in a checkpoint directory name.

    Recognises names ending in ``step-N``, ``step_N``, ``stepN`` or
    ``checkpoint-N`` (case-insensitive; trailing path separators are ignored).
    Any other name, such as ``final``, yields *fallback*.
    """
    name = os.path.basename(os.path.normpath(ckpt))
    m = re.search(r"(?:step|checkpoint)[-_]?(\d+)$", name, flags=re.IGNORECASE)
    return int(m.group(1)) if m else fallback


def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point. Returns the process exit code.

    Modes (checked in this order):
      --mock         random answers, no model or transformers needed
      --checkpoints  evaluate several checkpoints and plot score vs step
      --model        evaluate a single model / checkpoint
    """
    parser = argparse.ArgumentParser(description="OrthoRL Clinical Knowledge Exam (spec 2.2)")
    parser.add_argument("--model", default=None, help="Model ID or checkpoint path")
    parser.add_argument(
        "--mock", action="store_true", help="Run with random answers (no GPU required, for testing)"
    )
    parser.add_argument("--verbose", action="store_true", help="Print per-question results")
    parser.add_argument(
        "--checkpoints",
        nargs="+",
        default=None,
        help=(
            "List of checkpoint paths for score-vs-step curve (step parsed from names like "
            "step-100 / checkpoint-100; otherwise 50-step intervals are assumed)"
        ),
    )
    parser.add_argument(
        "--output", default="results/exam_curve.png", help="Output path for exam curve plot"
    )
    args = parser.parse_args(argv)

    if args.mock:
        result = run_exam_mock()
        _print_result(result, label="mock")
        return 0

    if not args.checkpoints and not args.model:
        print("Provide --model, --mock, or --checkpoints. Use --mock for a quick test.")
        return 1

    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        print("ERROR: pip install transformers")
        return 1

    if args.checkpoints:
        # Multi-checkpoint mode: plot score curve
        scores_by_step: Dict[int, int] = {}
        for i, ckpt in enumerate(args.checkpoints):
            step = _infer_step(ckpt, fallback=(i + 1) * 50)
            if step in scores_by_step:
                print(f" WARNING: step {step} already evaluated; {ckpt} will overwrite it")
            print(f"\nLoading checkpoint {ckpt} (step {step}) ...")
            try:
                tok = AutoTokenizer.from_pretrained(ckpt)
                mdl = AutoModelForCausalLM.from_pretrained(ckpt)
                res = run_exam(mdl, tok, verbose=args.verbose)
            except Exception as e:
                # Best-effort sweep: one broken checkpoint should not abort the others.
                print(f" WARNING: failed to evaluate {ckpt}: {e}")
                continue
            finally:
                # Drop references before the next load so two models are never
                # resident at the same time.
                tok = mdl = None
            scores_by_step[step] = res["score"]
            _print_result(res, label=f"step-{step}")

        if not scores_by_step:
            print("\nNo checkpoint could be evaluated; exam curve not written.")
            return 1
        saved = plot_exam_curve(scores_by_step, args.output)
        print(f"\nExam curve saved: {saved}")
        return 0

    # Single model mode
    print(f"Loading model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(args.model)
    print(f"\nRunning {len(EXAM_QUESTIONS)}-question clinical exam ...")
    if args.verbose:
        print()
    result = run_exam(model, tokenizer, verbose=args.verbose)
    print()
    # basename(normpath(...)) so "./checkpoints/final/" gives "final", not "".
    label = os.path.basename(os.path.normpath(args.model))[:20]
    _print_result(result, label=label)
    return 0


if __name__ == "__main__":
    sys.exit(main())