""" clinical_exam.py — Spec 2.2: Clinical Knowledge Exam (Pre/Post Training) 10 multiple-choice questions that test whether a trained model has learned the environment's strategy-reward mapping — NOT generic orthodontic knowledge. Usage: # Without a model (mock mode — scores questions randomly for testing) python clinical_exam.py --mock # With a model (GPU recommended for large models) python clinical_exam.py --model unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit # Score a trained checkpoint python clinical_exam.py --model ./checkpoints/final # Generate exam score curve from multiple checkpoints python clinical_exam.py --checkpoints ./checkpoints/step-50 ./checkpoints/step-100 ./checkpoints/step-150 Pitch line: "Before training: 3/10. After 300 GRPO steps: 8/10. The agent learned orthodontic reasoning without being taught a single fact." """ from __future__ import annotations import argparse import os import re import sys from typing import Any, Dict, List, Optional # --------------------------------------------------------------------------- # Exam questions (environment-specific — tests reward structure, not textbook) # --------------------------------------------------------------------------- EXAM_QUESTIONS: List[Dict[str, Any]] = [ # --- Strategy selection (4 questions) --- { "question": ( "A patient presents with Class II Division 1 malocclusion (molar AP distance 3.5mm, " "proclined upper incisors). Which treatment strategy maximises reward in this environment?" ), "choices": { "A": "anterior_first", "B": "retraction_first", "C": "expansion_first", "D": "intrusion_first", }, "answer": "B", "category": "strategy_selection", }, { "question": ( "Moderate crowding (6mm arch deficit) with Class I molar relationship and no overbite problem. " "Which strategy receives a 1.2× reward multiplier in this environment?" ), "choices": { "A": "retraction_first", "B": "distal_first", "C": "expansion_first", "D": "intrusion_first", }, "answer": "C", "category": "strategy_selection", }, { "question": ( "Class III case with mesial molar relationship and -2.5mm AP offset (lower jaw forward). " "Optimal strategy for this environment?" ), "choices": { "A": "distal_first", "B": "anterior_first", "C": "retraction_first", "D": "expansion_first", }, "answer": "A", "category": "strategy_selection", }, { "question": ( "Deep bite (5.2mm overbite) with Class II Division 2 malocclusion. " "Which strategy unlocks the deep-bite override bonus in this environment?" ), "choices": { "A": "expansion_first", "B": "anterior_first", "C": "retraction_first", "D": "intrusion_first", }, "answer": "D", "category": "strategy_selection", }, # --- Diagnosis interpretation (3 questions) --- { "question": ( "The tool diagnose_angle_class measures molar AP distance as 0.8mm. " "What Angle classification does it return?" ), "choices": { "A": "Class I", "B": "Class II Division 1", "C": "Class II Division 2", "D": "Class III", }, "answer": "A", "category": "diagnosis", }, { "question": ( "measure_crowding returns crowding_mm = 11.2. " "What severity category does this correspond to in the environment?" ), "choices": { "A": "mild (< 4mm)", "B": "moderate (4–8mm)", "C": "severe (> 8mm)", "D": "none (< 1mm)", }, "answer": "C", "category": "diagnosis", }, { "question": ( "measure_overbite returns overbite_mm = -1.5 (negative = lower arch anterior). " "What is the correct classification?" ), "choices": { "A": "deep_bite", "B": "normal", "C": "open_bite", "D": "edge_to_edge", }, "answer": "C", "category": "diagnosis", }, # --- Staging order (2 questions) --- { "question": ( "In a retraction_first treatment strategy, which teeth are prioritised for movement " "in the first 8 aligner stages?" ), "choices": { "A": "Lower incisors (31, 32, 41, 42)", "B": "Upper molars (16, 17, 26, 27)", "C": "Upper canines and premolars — retraction arc", "D": "All 28 teeth equally", }, "answer": "C", "category": "staging", }, { "question": ( "The environment penalises a tooth movement that exceeds which per-stage translation limit?" ), "choices": { "A": "0.10 mm", "B": "0.25 mm", "C": "0.50 mm", "D": "1.00 mm", }, "answer": "B", "category": "staging", }, # --- Failure recovery (1 question) --- { "question": ( "An agent selects anterior_first on a Class II Div 1 case and receives a 0.6× reward multiplier. " "What is the correct corrective action?" ), "choices": { "A": "Increase movement fraction for all teeth", "B": "Call diagnose_angle_class first, then select retraction_first", "C": "Skip diagnostic tools and rely on the default strategy", "D": "Switch to expansion_first instead", }, "answer": "B", "category": "recovery", }, ] assert len(EXAM_QUESTIONS) == 10, "Exam must have exactly 10 questions" assert all(len(q["choices"]) == 4 for q in EXAM_QUESTIONS), "Each question must have 4 choices" # --------------------------------------------------------------------------- # Formatting + answer extraction # --------------------------------------------------------------------------- def _format_mcq(q: Dict[str, Any]) -> str: """Format a question dict as a multiple-choice prompt string.""" lines = [q["question"], ""] for key in ("A", "B", "C", "D"): lines.append(f" {key}) {q['choices'][key]}") lines.append("\nAnswer with just the letter (A, B, C, or D):") return "\n".join(lines) def _extract_answer(text: str) -> str: """ Extract the first A/B/C/D letter from model output. Handles: "B", "The answer is B", "B: Distal-first", "(B)", etc. Returns "X" if no valid letter found. """ clean = text.strip().upper() # Try patterns in order of specificity for pattern in [ r"\bANSWER[:\s]+([ABCD])\b", # "Answer: B" r"\bTHE ANSWER IS ([ABCD])\b", # "The answer is B" r"^\s*([ABCD])[)\.:\s]", # "B)" or "B." or "B:" r"\b([ABCD])\b", # any standalone letter ]: m = re.search(pattern, clean) if m: return m.group(1) return "X" # --------------------------------------------------------------------------- # Exam runner # --------------------------------------------------------------------------- def run_exam( model: Any, tokenizer: Any, verbose: bool = False, ) -> Dict[str, Any]: """ Run the 10-question clinical exam on the given model. Parameters ---------- model : transformers/unsloth model (must support generate()) tokenizer : corresponding tokenizer verbose : bool If True, print each question, predicted answer, and correct answer. Returns ------- dict with keys: score : int — number correct total : int — 10 pct : float — score / total details : list[dict] — per-question results """ import torch correct = 0 details: List[Dict[str, Any]] = [] for i, q in enumerate(EXAM_QUESTIONS): prompt = _format_mcq(q) messages = [ { "role": "system", "content": ( "You are an AI assistant taking a clinical orthodontics exam. " "Answer each question with a single letter: A, B, C, or D." ), }, {"role": "user", "content": prompt}, ] # Tokenize using chat template if available try: input_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) except Exception: input_text = prompt inputs = tokenizer(input_text, return_tensors="pt") if hasattr(model, "device"): inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=16, do_sample=False, temperature=1.0, pad_token_id=tokenizer.eos_token_id, ) generated = tokenizer.decode( out[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True ) predicted = _extract_answer(generated) is_correct = predicted == q["answer"] correct += int(is_correct) details.append( { "question_idx": i + 1, "category": q["category"], "predicted": predicted, "correct": q["answer"], "is_correct": is_correct, "raw_output": generated.strip()[:80], } ) if verbose: mark = "✓" if is_correct else "✗" print( f" Q{i + 1} ({q['category']:20s}) {mark} predicted={predicted} correct={q['answer']}" ) # Category breakdown by_cat: Dict[str, Dict[str, int]] = {} for d in details: cat = d["category"] by_cat.setdefault(cat, {"correct": 0, "total": 0}) by_cat[cat]["total"] += 1 by_cat[cat]["correct"] += int(d["is_correct"]) return { "score": correct, "total": len(EXAM_QUESTIONS), "pct": correct / len(EXAM_QUESTIONS), "details": details, "by_category": by_cat, } def run_exam_mock(rng_seed: int = 0) -> Dict[str, Any]: """ Mock exam runner for testing without a GPU model. Returns a result dict with random answers (useful for structural tests). """ import random rng = random.Random(rng_seed) correct = 0 details = [] for i, q in enumerate(EXAM_QUESTIONS): predicted = rng.choice(["A", "B", "C", "D"]) is_correct = predicted == q["answer"] correct += int(is_correct) details.append( { "question_idx": i + 1, "category": q["category"], "predicted": predicted, "correct": q["answer"], "is_correct": is_correct, "raw_output": predicted, } ) by_cat: Dict[str, Dict[str, int]] = {} for d in details: cat = d["category"] by_cat.setdefault(cat, {"correct": 0, "total": 0}) by_cat[cat]["total"] += 1 by_cat[cat]["correct"] += int(d["is_correct"]) return { "score": correct, "total": 10, "pct": correct / 10, "details": details, "by_category": by_cat, } # --------------------------------------------------------------------------- # Plotting # --------------------------------------------------------------------------- def plot_exam_curve( scores_by_step: Dict[int, int], output_path: str = "results/exam_curve.png", ) -> str: """ Plot exam score vs training step. Parameters ---------- scores_by_step : dict mapping step → score (0-10) output_path : where to save Returns ------- str : saved path """ import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt steps = sorted(scores_by_step.keys()) scores = [scores_by_step[s] for s in steps] fig, ax = plt.subplots(figsize=(8, 5)) ax.plot(steps, scores, "bo-", linewidth=2, markersize=8, label="Exam score") ax.axhline(y=2.5, color="gray", linestyle=":", alpha=0.5, label="Random baseline (2.5/10)") ax.set_xlabel("Training Step", fontsize=12) ax.set_ylabel("Exam Score (/ 10)", fontsize=12) ax.set_title("Clinical Knowledge: Score vs Training Steps", fontsize=13, fontweight="bold") ax.set_ylim(0, 10.5) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) # Annotate first and last points if steps: ax.annotate( f"{scores[0]}/10", (steps[0], scores[0]), textcoords="offset points", xytext=(5, 5), fontsize=9, ) ax.annotate( f"{scores[-1]}/10", (steps[-1], scores[-1]), textcoords="offset points", xytext=(5, 5), fontsize=9, ) os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) return os.path.abspath(output_path) # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def _print_result(result: Dict[str, Any], label: str = "") -> None: prefix = f"[{label}] " if label else "" print(f"{prefix}Score: {result['score']}/{result['total']} ({result['pct']:.0%})") print(f"{prefix}By category:") for cat, counts in result["by_category"].items(): pct = counts["correct"] / max(counts["total"], 1) print(f" {cat:25s}: {counts['correct']}/{counts['total']} ({pct:.0%})") if __name__ == "__main__": parser = argparse.ArgumentParser(description="OrthoRL Clinical Knowledge Exam (spec 2.2)") parser.add_argument("--model", default=None, help="Model ID or checkpoint path") parser.add_argument( "--mock", action="store_true", help="Run with random answers (no GPU required, for testing)" ) parser.add_argument("--verbose", action="store_true", help="Print per-question results") parser.add_argument( "--checkpoints", nargs="+", default=None, help="List of checkpoint paths for score-vs-step curve", ) parser.add_argument( "--output", default="results/exam_curve.png", help="Output path for exam curve plot" ) args = parser.parse_args() if args.mock: result = run_exam_mock() _print_result(result, label="mock") sys.exit(0) if args.checkpoints: # Multi-checkpoint mode: plot score curve try: from transformers import AutoModelForCausalLM, AutoTokenizer except ImportError: print("ERROR: pip install transformers") sys.exit(1) scores_by_step: Dict[int, int] = {} for i, ckpt in enumerate(args.checkpoints): step = (i + 1) * 50 # assume 50-step intervals print(f"\nLoading checkpoint {ckpt} (step {step}) ...") try: tok = AutoTokenizer.from_pretrained(ckpt) mdl = AutoModelForCausalLM.from_pretrained(ckpt) res = run_exam(mdl, tok, verbose=args.verbose) scores_by_step[step] = res["score"] _print_result(res, label=f"step-{step}") except Exception as e: print(f" WARNING: failed to load {ckpt}: {e}") if scores_by_step: saved = plot_exam_curve(scores_by_step, args.output) print(f"\nExam curve saved: {saved}") sys.exit(0) if not args.model: print("Provide --model, --mock, or --checkpoints. Use --mock for a quick test.") sys.exit(1) # Single model mode try: from transformers import AutoModelForCausalLM, AutoTokenizer except ImportError: print("ERROR: pip install transformers") sys.exit(1) print(f"Loading model: {args.model}") tokenizer = AutoTokenizer.from_pretrained(args.model) model = AutoModelForCausalLM.from_pretrained(args.model) print(f"\nRunning {len(EXAM_QUESTIONS)}-question clinical exam ...") if args.verbose: print() result = run_exam(model, tokenizer, verbose=args.verbose) print() _print_result(result, label=args.model.split("/")[-1][:20])