Spaces:
Sleeping
Sleeping
"""
clinical_exam.py — Spec 2.2: Clinical Knowledge Exam (Pre/Post Training)

10 multiple-choice questions that test whether a trained model has learned
the environment's strategy-reward mapping — NOT generic orthodontic knowledge.

Usage:
    # Without a model (mock mode — answers every question randomly, for testing)
    python clinical_exam.py --mock

    # With a model (GPU recommended for large models)
    python clinical_exam.py --model unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit

    # Score a trained checkpoint
    python clinical_exam.py --model ./checkpoints/final

    # Generate an exam-score curve from multiple checkpoints
    python clinical_exam.py --checkpoints ./checkpoints/step-50 ./checkpoints/step-100 ./checkpoints/step-150

Pitch line: "Before training: 3/10. After 300 GRPO steps: 8/10. The agent
learned orthodontic reasoning without being taught a single fact."
"""
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import re | |
| import sys | |
| from typing import Any, Dict, List, Optional | |
# ---------------------------------------------------------------------------
# Exam questions (environment-specific — tests reward structure, not textbook)
# ---------------------------------------------------------------------------
# Each entry: "question" (prompt stem), "choices" (letter -> option text, keys
# A-D), "answer" (correct letter) and "category" (used for the per-category
# breakdown in the exam results).
EXAM_QUESTIONS: List[Dict[str, Any]] = [
    # --- Strategy selection (4 questions) ---
    {
        "question": (
            "A patient presents with Class II Division 1 malocclusion (molar AP distance 3.5mm, "
            "proclined upper incisors). Which treatment strategy maximises reward in this environment?"
        ),
        "choices": {
            "A": "anterior_first",
            "B": "retraction_first",
            "C": "expansion_first",
            "D": "intrusion_first",
        },
        "answer": "B",
        "category": "strategy_selection",
    },
    {
        "question": (
            "Moderate crowding (6mm arch deficit) with Class I molar relationship and no overbite problem. "
            "Which strategy receives a 1.2× reward multiplier in this environment?"
        ),
        "choices": {
            "A": "retraction_first",
            "B": "distal_first",
            "C": "expansion_first",
            "D": "intrusion_first",
        },
        "answer": "C",
        "category": "strategy_selection",
    },
    {
        "question": (
            "Class III case with mesial molar relationship and -2.5mm AP offset (lower jaw forward). "
            "Optimal strategy for this environment?"
        ),
        "choices": {
            "A": "distal_first",
            "B": "anterior_first",
            "C": "retraction_first",
            "D": "expansion_first",
        },
        "answer": "A",
        "category": "strategy_selection",
    },
    {
        "question": (
            "Deep bite (5.2mm overbite) with Class II Division 2 malocclusion. "
            "Which strategy unlocks the deep-bite override bonus in this environment?"
        ),
        "choices": {
            "A": "expansion_first",
            "B": "anterior_first",
            "C": "retraction_first",
            "D": "intrusion_first",
        },
        "answer": "D",
        "category": "strategy_selection",
    },
    # --- Diagnosis interpretation (3 questions) ---
    {
        "question": (
            "The tool diagnose_angle_class measures molar AP distance as 0.8mm. "
            "What Angle classification does it return?"
        ),
        "choices": {
            "A": "Class I",
            "B": "Class II Division 1",
            "C": "Class II Division 2",
            "D": "Class III",
        },
        "answer": "A",
        "category": "diagnosis",
    },
    {
        "question": (
            "measure_crowding returns crowding_mm = 11.2. "
            "What severity category does this correspond to in the environment?"
        ),
        "choices": {
            "A": "mild (< 4mm)",
            "B": "moderate (4–8mm)",
            "C": "severe (> 8mm)",
            "D": "none (< 1mm)",
        },
        "answer": "C",
        "category": "diagnosis",
    },
    {
        # NOTE(review): clinically, a negative overbite means no vertical incisor
        # overlap (open bite); "lower arch anterior" reads like negative overjet.
        # Confirm this wording matches the environment's measure_overbite docs.
        "question": (
            "measure_overbite returns overbite_mm = -1.5 (negative = lower arch anterior). "
            "What is the correct classification?"
        ),
        "choices": {
            "A": "deep_bite",
            "B": "normal",
            "C": "open_bite",
            "D": "edge_to_edge",
        },
        "answer": "C",
        "category": "diagnosis",
    },
    # --- Staging order (2 questions) ---
    {
        "question": (
            "In a retraction_first treatment strategy, which teeth are prioritised for movement "
            "in the first 8 aligner stages?"
        ),
        "choices": {
            "A": "Lower incisors (31, 32, 41, 42)",
            "B": "Upper molars (16, 17, 26, 27)",
            "C": "Upper canines and premolars — retraction arc",
            "D": "All 28 teeth equally",
        },
        "answer": "C",
        "category": "staging",
    },
    {
        "question": (
            "The environment penalises a tooth movement that exceeds which per-stage translation limit?"
        ),
        "choices": {
            "A": "0.10 mm",
            "B": "0.25 mm",
            "C": "0.50 mm",
            "D": "1.00 mm",
        },
        "answer": "B",
        "category": "staging",
    },
    # --- Failure recovery (1 question) ---
    {
        "question": (
            "An agent selects anterior_first on a Class II Div 1 case and receives a 0.6× reward multiplier. "
            "What is the correct corrective action?"
        ),
        "choices": {
            "A": "Increase movement fraction for all teeth",
            "B": "Call diagnose_angle_class first, then select retraction_first",
            "C": "Skip diagnostic tools and rely on the default strategy",
            "D": "Switch to expansion_first instead",
        },
        "answer": "B",
        "category": "recovery",
    },
]


def _validate_exam(questions: List[Dict[str, Any]], expected_count: int = 10) -> None:
    """
    Raise ValueError if the question bank is malformed.

    These checks replace bare ``assert`` statements so they still run under
    ``python -O``. Every question must offer exactly the choice keys A, B, C
    and D (the prompt formatter looks options up by those letters), and its
    ``answer`` must be one of those keys.
    """
    if len(questions) != expected_count:
        raise ValueError(f"Exam must have exactly {expected_count} questions")
    for idx, q in enumerate(questions, start=1):
        if set(q["choices"]) != {"A", "B", "C", "D"}:
            raise ValueError(
                f"Each question must have 4 choices (A-D); Q{idx} has {sorted(q['choices'])}"
            )
        if q["answer"] not in q["choices"]:
            raise ValueError(f"Q{idx} answer {q['answer']!r} is not one of its choices")


_validate_exam(EXAM_QUESTIONS)
# ---------------------------------------------------------------------------
# Formatting + answer extraction
# ---------------------------------------------------------------------------
def _format_mcq(q: Dict[str, Any]) -> str:
    """
    Render one question dict as a plain-text multiple-choice prompt.

    Layout: the question stem, a blank line, one " X) option" line per
    letter A-D, then a blank line and the answer-format instruction.
    """
    option_lines = [f" {letter}) {q['choices'][letter]}" for letter in "ABCD"]
    return "\n".join(
        [q["question"], "", *option_lines, "\nAnswer with just the letter (A, B, C, or D):"]
    )
# Ordered (pattern, case_sensitive) pairs; the first pattern that matches wins.
# Case-sensitive patterns run on the original text and only accept capital
# A-D, so lowercase English words (the article "a", the "d" in "I'd") are not
# mistaken for answers. The case-insensitive patterns run on the upper-cased
# text and keep the original lenient matching as a fallback (e.g. a bare "b").
_ANSWER_PATTERNS = (
    (re.compile(r"\bANSWER[:\s]+([ABCD])\b"), False),  # "Answer: B"
    (re.compile(r"\bTHE ANSWER IS ([ABCD])\b"), False),  # "The answer is B"
    (re.compile(r"^\s*([ABCD])(?:[)\.:\s]|$)"), True),  # "B", "B)", "B.", "B:"
    (re.compile(r"\b([ABCD])\b"), True),  # standalone capital letter, e.g. "(B)"
    (re.compile(r"^\s*([ABCD])[)\.:\s]"), False),  # leading letter, any case: "b)"
    (re.compile(r"\b([ABCD])\b"), False),  # any standalone letter, any case
)


def _extract_answer(text: str) -> str:
    """
    Extract the model's A/B/C/D choice from its raw output.

    Patterns are tried in order of specificity: explicit "answer" phrasing,
    then a leading letter, then any standalone letter. Capital letters in the
    original text are preferred before falling back to a case-insensitive match.
    Handles: "B", "The answer is B", "B: Distal-first", "(B)", "b", etc.

    Returns "X" if no valid letter found.
    """
    original = text.strip()
    upper = original.upper()
    for pattern, case_sensitive in _ANSWER_PATTERNS:
        match = pattern.search(original if case_sensitive else upper)
        if match:
            return match.group(1)
    return "X"
# ---------------------------------------------------------------------------
# Exam runner
# ---------------------------------------------------------------------------
def run_exam(
    model: Any,
    tokenizer: Any,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    Run the 10-question clinical exam on the given model.

    Each question is formatted with ``_format_mcq``, wrapped in a system + user
    chat (rendered through the tokenizer's chat template when that succeeds),
    answered with greedy decoding of at most 16 new tokens, and graded with
    ``_extract_answer``.

    Parameters
    ----------
    model : transformers/unsloth model (must support generate())
    tokenizer : corresponding tokenizer
    verbose : bool
        If True, print one line per question with its category, a ✓/✗ mark,
        the predicted letter and the correct letter.

    Returns
    -------
    dict with keys:
        score : int — number correct
        total : int — number of questions (len(EXAM_QUESTIONS))
        pct : float — score / total
        details : list[dict] — per-question results
        by_category : dict[str, dict[str, int]] — {"correct", "total"} per category
    """
    import torch

    system_prompt = (
        "You are an AI assistant taking a clinical orthodontics exam. "
        "Answer each question with a single letter: A, B, C, or D."
    )
    correct = 0
    details: List[Dict[str, Any]] = []
    for i, q in enumerate(EXAM_QUESTIONS):
        prompt = _format_mcq(q)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        # Tokenize using chat template if available.
        try:
            input_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Chat templates typically already include the model's special
            # tokens (e.g. BOS). The Transformers docs advise tokenizing such
            # text with add_special_tokens=False to avoid duplicating them.
            add_special_tokens = False
        except Exception:
            # Best-effort fallback (e.g. the tokenizer has no chat template):
            # use the bare prompt and let the tokenizer add its special tokens.
            input_text = prompt
            add_special_tokens = True
        inputs = tokenizer(
            input_text, return_tensors="pt", add_special_tokens=add_special_tokens
        )
        if hasattr(model, "device"):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=16,
                do_sample=False,  # greedy decoding keeps grading deterministic
                temperature=1.0,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, skipping the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        generated = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        predicted = _extract_answer(generated)
        is_correct = predicted == q["answer"]
        correct += int(is_correct)
        details.append(
            {
                "question_idx": i + 1,
                "category": q["category"],
                "predicted": predicted,
                "correct": q["answer"],
                "is_correct": is_correct,
                "raw_output": generated.strip()[:80],
            }
        )
        if verbose:
            mark = "✓" if is_correct else "✗"
            print(
                f" Q{i + 1} ({q['category']:20s}) {mark} predicted={predicted} correct={q['answer']}"
            )

    # Category breakdown
    by_cat: Dict[str, Dict[str, int]] = {}
    for d in details:
        cat = d["category"]
        by_cat.setdefault(cat, {"correct": 0, "total": 0})
        by_cat[cat]["total"] += 1
        by_cat[cat]["correct"] += int(d["is_correct"])

    return {
        "score": correct,
        "total": len(EXAM_QUESTIONS),
        "pct": correct / len(EXAM_QUESTIONS),
        "details": details,
        "by_category": by_cat,
    }
def run_exam_mock(
    rng_seed: int = 0,
    questions: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """
    Mock exam runner for testing without a GPU model.

    Picks a random letter (A-D) for every question using a private
    ``random.Random(rng_seed)``, so results are reproducible for a given seed.
    It then grades exactly like ``run_exam``. Useful for structural tests of
    the result dict.

    Parameters
    ----------
    rng_seed : int
        Seed for the random answer generator.
    questions : list of question dicts, optional
        Question bank to answer; defaults to ``EXAM_QUESTIONS``.

    Returns
    -------
    dict with the same keys as ``run_exam``: score, total, pct, details and
    by_category. ``total`` is the number of questions asked. ``pct`` is 0.0
    for an empty bank.
    """
    import random

    if questions is None:
        questions = EXAM_QUESTIONS
    rng = random.Random(rng_seed)
    correct = 0
    details: List[Dict[str, Any]] = []
    for i, q in enumerate(questions):
        predicted = rng.choice(["A", "B", "C", "D"])
        is_correct = predicted == q["answer"]
        correct += int(is_correct)
        details.append(
            {
                "question_idx": i + 1,
                "category": q["category"],
                "predicted": predicted,
                "correct": q["answer"],
                "is_correct": is_correct,
                "raw_output": predicted,
            }
        )

    by_cat: Dict[str, Dict[str, int]] = {}
    for d in details:
        cat = d["category"]
        by_cat.setdefault(cat, {"correct": 0, "total": 0})
        by_cat[cat]["total"] += 1
        by_cat[cat]["correct"] += int(d["is_correct"])

    # Derive the total from the questions actually asked (previously hard-coded
    # to 10), matching run_exam's len(EXAM_QUESTIONS).
    total = len(questions)
    return {
        "score": correct,
        "total": total,
        "pct": correct / total if total else 0.0,
        "details": details,
        "by_category": by_cat,
    }
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def plot_exam_curve(
    scores_by_step: Dict[int, int],
    output_path: str = "results/exam_curve.png",
) -> str:
    """
    Plot exam score vs training step and save it as an image.

    Parameters
    ----------
    scores_by_step : dict mapping step → score (0-10)
    output_path : where to save; missing parent directories are created

    Returns
    -------
    str : absolute path of the saved image
    """
    import matplotlib

    matplotlib.use("Agg")  # non-interactive backend: render straight to a file
    import matplotlib.pyplot as plt

    steps = sorted(scores_by_step)
    scores = [scores_by_step[s] for s in steps]

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(steps, scores, "bo-", linewidth=2, markersize=8, label="Exam score")
        # 4 choices per question → random guessing expects 10 / 4 = 2.5 correct.
        ax.axhline(y=2.5, color="gray", linestyle=":", alpha=0.5, label="Random baseline (2.5/10)")
        ax.set_xlabel("Training Step", fontsize=12)
        ax.set_ylabel("Exam Score (/ 10)", fontsize=12)
        ax.set_title("Clinical Knowledge: Score vs Training Steps", fontsize=13, fontweight="bold")
        ax.set_ylim(0, 10.5)
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

        # Annotate first and last points; with a single point they coincide,
        # so the set avoids drawing the same label twice.
        endpoints = sorted({0, len(steps) - 1}) if steps else []
        for idx in endpoints:
            ax.annotate(
                f"{scores[idx]}/10",
                (steps[idx], scores[idx]),
                textcoords="offset points",
                xytext=(5, 5),
                fontsize=9,
            )

        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return os.path.abspath(output_path)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _print_result(result: Dict[str, Any], label: str = "") -> None:
    """
    Print the overall score and the per-category breakdown of an exam result.

    When *label* is non-empty, both header lines are prefixed with "[label] ".
    Categories with a zero total are shown as 0%.
    """
    tag = "" if not label else f"[{label}] "
    print(f"{tag}Score: {result['score']}/{result['total']} ({result['pct']:.0%})")
    print(f"{tag}By category:")
    for name, tally in result["by_category"].items():
        hits, asked = tally["correct"], tally["total"]
        ratio = hits / max(asked, 1)
        print(f" {name:25s}: {hits}/{asked} ({ratio:.0%})")
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the exam script."""
    parser = argparse.ArgumentParser(description="OrthoRL Clinical Knowledge Exam (spec 2.2)")
    parser.add_argument("--model", default=None, help="Model ID or checkpoint path")
    parser.add_argument(
        "--mock", action="store_true", help="Run with random answers (no GPU required, for testing)"
    )
    parser.add_argument("--verbose", action="store_true", help="Print per-question results")
    parser.add_argument(
        "--checkpoints",
        nargs="+",
        default=None,
        help=(
            "List of checkpoint paths for score-vs-step curve (step read from a "
            "trailing step-N / checkpoint-N directory name, else 50-step spacing assumed)"
        ),
    )
    parser.add_argument(
        "--output", default="results/exam_curve.png", help="Output path for exam curve plot"
    )
    return parser


def _load_transformers() -> Any:
    """Import and return (AutoModelForCausalLM, AutoTokenizer); exit(1) if missing."""
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        print("ERROR: pip install transformers")
        sys.exit(1)
    return AutoModelForCausalLM, AutoTokenizer


def _checkpoint_step(path: str, index: int, interval: int = 50) -> int:
    """
    Infer the training step a checkpoint path corresponds to.

    Reads a trailing number from the final path component when it is named
    like ``step-150``, ``checkpoint-300`` or ``global_step500``. Otherwise it
    falls back to fixed ``interval``-step spacing by position:
    ``(index + 1) * interval``.
    """
    name = os.path.basename(os.path.normpath(path))
    match = re.search(r"(?:step|checkpoint)[-_]?(\d+)$", name, re.IGNORECASE)
    if match:
        return int(match.group(1))
    return (index + 1) * interval


def main() -> None:
    """Command-line entry point; see the module docstring for usage examples."""
    args = _build_parser().parse_args()

    if args.mock:
        result = run_exam_mock()
        _print_result(result, label="mock")
        sys.exit(0)

    if args.checkpoints:
        # Multi-checkpoint mode: score every checkpoint, then plot the curve.
        AutoModelForCausalLM, AutoTokenizer = _load_transformers()
        scores_by_step: Dict[int, int] = {}
        for i, ckpt in enumerate(args.checkpoints):
            step = _checkpoint_step(ckpt, i)
            print(f"\nLoading checkpoint {ckpt} (step {step}) ...")
            try:
                tok = AutoTokenizer.from_pretrained(ckpt)
                mdl = AutoModelForCausalLM.from_pretrained(ckpt)
            except Exception as e:  # best-effort: skip checkpoints that fail to load
                print(f" WARNING: failed to load {ckpt}: {e}")
                continue
            try:
                res = run_exam(mdl, tok, verbose=args.verbose)
            except Exception as e:  # best-effort: keep scoring the remaining checkpoints
                print(f" WARNING: exam failed on {ckpt}: {e}")
                continue
            finally:
                # Drop references before the next load so the previous model
                # can be garbage-collected.
                del mdl, tok
            scores_by_step[step] = res["score"]
            _print_result(res, label=f"step-{step}")
        if not scores_by_step:
            print("No checkpoint could be evaluated; no exam curve written.")
            sys.exit(1)
        saved = plot_exam_curve(scores_by_step, args.output)
        print(f"\nExam curve saved: {saved}")
        sys.exit(0)

    if not args.model:
        print("Provide --model, --mock, or --checkpoints. Use --mock for a quick test.")
        sys.exit(1)

    # Single model mode
    AutoModelForCausalLM, AutoTokenizer = _load_transformers()
    print(f"Loading model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # NOTE(review): from_pretrained is called without device_map / .to(), so the
    # model presumably stays on CPU unless its own config places it elsewhere —
    # confirm whether GPU placement is intended ("GPU recommended" in the usage).
    model = AutoModelForCausalLM.from_pretrained(args.model)
    print(f"\nRunning {len(EXAM_QUESTIONS)}-question clinical exam ...")
    if args.verbose:
        print()
    result = run_exam(model, tokenizer, verbose=args.verbose)
    print()
    # basename(normpath(...)) so a trailing separator on a local path still
    # yields a meaningful label (split("/")[-1] would give "").
    label = os.path.basename(os.path.normpath(args.model))[:20]
    _print_result(result, label=label)


if __name__ == "__main__":
    main()