"""Generate MI workshop Figure 1: grokking vs standard probe heatmap. Reads M1 probe outputs from experiments/runs/*/mechinterp/m1_probe_data.json, produces a 2x2 grid (rows = hospital/tumor probe; cols = grokking/standard) with epoch-x-layer heatmaps. Hospital = Reds (want fading); Tumor = Greens (want rising). Picks the strongest grokking run by best_ood and the standard control with periodic checkpoints (or final.pt if that's all that's available). Usage: python -m experiments.figure_mi_comparison \ [--grok-run experiments/runs/] [--std-run experiments/runs/] \ [--out paper_figures/figure1_MI_probe_comparison] """ from __future__ import annotations import argparse import glob import json import os import sys from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np ROOT = Path(__file__).resolve().parent.parent def _load_summary(run_dir: Path) -> dict: p = run_dir / "results" / "summary.json" if not p.exists(): return {} try: return json.loads(p.read_text()) except Exception: return {} def _pick_best_grok_run() -> Path | None: candidates = [] for f in glob.glob(str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")): run_dir = Path(f).parent.parent s = _load_summary(run_dir) if s.get("condition") != "grokking": continue best = s.get("best_ood", 0) or 0 candidates.append((best, run_dir)) if not candidates: return None candidates.sort(reverse=True) return candidates[0][1] def _pick_std_run() -> Path | None: candidates = [] for f in glob.glob(str(ROOT / "experiments/runs/*/mechinterp/m1_probe_data.json")): run_dir = Path(f).parent.parent s = _load_summary(run_dir) if s.get("condition") != "standard": continue best = s.get("best_ood", s.get("ood_test_acc", 0)) or 0 candidates.append((best, run_dir)) if not candidates: return None candidates.sort(reverse=True) return candidates[0][1] def _load_probe(run_dir: Path) -> dict | None: p = run_dir / "mechinterp" / "m1_probe_data.json" if not p.exists(): return None return json.loads(p.read_text()) def _heatmap(ax, data: dict, key: str, title: str, cmap: str): if data is None: ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes) ax.set_xticks([]) ax.set_yticks([]) ax.set_title(title, fontweight="bold", fontsize=10) return None epochs = data["epochs"] layers = data["layers"] mat = np.array(data[key]) # shape (n_epochs, n_layers) if mat.ndim != 2: mat = mat.reshape(len(epochs), len(layers)) im = ax.imshow( mat.T, aspect="auto", cmap=cmap, vmin=0.0, vmax=1.0, interpolation="nearest", origin="lower", ) ax.set_xticks(range(len(epochs))) ax.set_xticklabels(epochs, rotation=45, ha="right", fontsize=7) ax.set_yticks(range(len(layers))) ax.set_yticklabels(layers, fontsize=8) ax.set_xlabel("Epoch", fontsize=9) ax.set_title(title, fontweight="bold", fontsize=10) return im def make_figure(grok_dir: Path, std_dir: Path | None, out_base: Path): grok = _load_probe(grok_dir) std = _load_probe(std_dir) if std_dir else None if grok is None: print(f"[ERROR] no probe data at {grok_dir}/mechinterp/m1_probe_data.json") sys.exit(1) # Hospital probe on H3 (held-in held-out hospital) — has signal. # H4 version is degenerate (probe class set excludes H4 by construction → ≡ 0). hosp_key = "hospital_probe_id" if "hospital_probe_id" in grok else "hospital_probe" # Tumor probe on H4 (truly OOD hospital) — measures causal-feature transfer. tumor_key = "tumor_probe_ood" if "tumor_probe_ood" in grok else "tumor_probe" fig, axes = plt.subplots(2, 2, figsize=(14, 10)) grok_title = f"Grokking ({grok_dir.name[-30:]})" std_title = f"Standard ({std_dir.name[-30:]})" if std_dir else "Standard (no data)" im00 = _heatmap(axes[0][0], grok, hosp_key, f"{grok_title}\nHospital probe on H3 (shortcut recoverability, ↓ good)", "Reds") im01 = _heatmap(axes[0][1], std, hosp_key, f"{std_title}\nHospital probe on H3 (shortcut recoverability, ↓ good)", "Reds") im10 = _heatmap(axes[1][0], grok, tumor_key, f"{grok_title}\nTumor probe on H4 (causal transfer, ↑ good)", "Greens") im11 = _heatmap(axes[1][1], std, tumor_key, f"{std_title}\nTumor probe on H4 (causal transfer, ↑ good)", "Greens") for im, ax in [(im00, axes[0][0]), (im01, axes[0][1]), (im10, axes[1][0]), (im11, axes[1][1])]: if im is not None: plt.colorbar(im, ax=ax, fraction=0.04, pad=0.02) fig.suptitle( "Figure 1 — Layer-wise circuit analysis: grokking-favorable vs standard training\n" "Grokking: deep-layer hospital recoverability (Reds) drops over training while tumor recoverability (Greens) is preserved.\n" "Standard: no localized scrubbing of hospital information.", fontsize=11, y=1.005, fontweight="bold", ) plt.tight_layout() out_base.parent.mkdir(parents=True, exist_ok=True) png = out_base.with_suffix(".png") pdf = out_base.with_suffix(".pdf") fig.savefig(png, bbox_inches="tight", dpi=200) fig.savefig(pdf, bbox_inches="tight") plt.close(fig) print(f"Saved {png}") print(f"Saved {pdf}") def main(): ap = argparse.ArgumentParser() ap.add_argument("--grok-run", default=None, help="Path to grokking run dir; auto-pick best if omitted") ap.add_argument("--std-run", default=None, help="Path to standard run dir; auto-pick best if omitted") ap.add_argument("--out", default="paper_figures/figure1_MI_probe_comparison") args = ap.parse_args() grok_dir = Path(args.grok_run) if args.grok_run else _pick_best_grok_run() std_dir = Path(args.std_run) if args.std_run else _pick_std_run() if grok_dir is None: print("[ERROR] No grokking run with M1 probe data found.") print(" Run experiments/mechinterp_m1.py on a grokking run first.") sys.exit(2) print(f"Grokking run : {grok_dir}") print(f"Standard run : {std_dir if std_dir else '(none — figure will show only grokking)'}") out_base = ROOT / args.out make_figure(grok_dir, std_dir, out_base) if __name__ == "__main__": main()