"""Compose the MI workshop's causal-intervention figure (M4 + M5 across runs).

Reads M4 trajectory + M5 sweep JSONs from every run that has them, groups by
condition (grokking vs standard), and produces a 2×2 layout:

  Top row    — M4 ablation: head OOD acc (raw) vs (shortcut-ablated),
               trajectories across all available runs in each condition.
  Bottom row — M5 steering: head OOD acc as a function of α, one curve per run.

Saves:
  paper_figures/figure_intervention_comparison.{png,pdf}
"""
from __future__ import annotations

import glob
import json
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

ROOT = Path(__file__).resolve().parent.parent

# Conditions in column order (left, right) of the figure.
CONDITIONS = ("grokking", "standard")

# Run-directory patterns, relative to ROOT, that are scanned for results.
RUN_GLOBS = ("experiments/runs/20260505-*/", "experiments/runs/20260508-*/")


def _load_summary(run_dir: Path) -> dict:
    """Return the parsed ``results/summary.json`` of *run_dir*, or ``{}`` if absent."""
    p = run_dir / "results" / "summary.json"
    if not p.exists():
        return {}
    return json.loads(p.read_text(encoding="utf-8"))


def _gather() -> dict[str, list[dict]]:
    """Collect intervention data for every matching run, grouped by condition.

    Returns:
        A dict mapping each condition in ``CONDITIONS`` to a list of per-run
        dicts with keys ``run_dir`` (directory name), ``n_train``, ``seed``,
        ``best_ood`` (all copied from the run summary, ``None`` when missing),
        ``m4`` (parsed M4 trajectory JSON or ``None``) and ``m5`` (parsed M5
        steering JSON or ``None``).  Runs whose summary has an unknown or
        missing ``condition``, or that have neither M4 nor M5 data, are skipped.
    """
    out: dict[str, list[dict]] = {cond: [] for cond in CONDITIONS}
    run_dirs = sorted(p for pattern in RUN_GLOBS for p in ROOT.glob(pattern))
    for run_dir in run_dirs:
        s = _load_summary(run_dir)
        cond = s.get("condition")
        if cond not in out:
            continue
        mi_dir = run_dir / "mechinterp"
        m4_path = mi_dir / "m4_ablation_avgpool_trajectory.json"
        # glob() order is filesystem-dependent; sort so that the file picked
        # when several sweeps exist is the same on every machine.
        m5_paths = sorted(mi_dir.glob("m5_steering_*.json"))
        m4 = json.loads(m4_path.read_text(encoding="utf-8")) if m4_path.exists() else None
        m5 = json.loads(m5_paths[0].read_text(encoding="utf-8")) if m5_paths else None
        if m4 is None and m5 is None:
            continue
        out[cond].append({
            "run_dir": run_dir.name,
            "n_train": s.get("n_train"),
            "seed": s.get("seed"),
            "best_ood": s.get("best_ood"),
            "m4": m4,
            "m5": m5,
        })
    return out


def _shade(cmap, index: int, count: int):
    """Return the colour for curve *index* of *count*, spread over 0.4–0.9 of *cmap*."""
    return cmap(0.4 + 0.5 * index / max(1, count - 1))


def _placeholder(ax, message: str, title: str) -> None:
    """Fill an empty panel with a centred grey *message* and a bold *title*."""
    ax.text(0.5, 0.5, message, ha="center", va="center",
            transform=ax.transAxes, color="gray", fontsize=11)
    ax.set_title(title, fontweight="bold")


def _plot_m4_panel(ax, cond: str, runs: list[dict], cmap) -> None:
    """Draw raw (solid) and ablated (dashed) head-OOD trajectories for *cond*."""
    # Filter first: a run may be present only because it has M5 data, and the
    # colour ramp should be spread over the curves that are actually drawn.
    m4_runs = [r for r in runs if r["m4"]]
    if not m4_runs:
        _placeholder(ax, f"no {cond} M4 data yet",
                     f"{cond.upper()}: head OOD raw vs ablated")
        return
    for i, run in enumerate(m4_runs):
        color = _shade(cmap, i, len(m4_runs))
        traj = run["m4"]
        eps = [r["epoch"] for r in traj]
        raw = [r["head_ood_acc_raw"] for r in traj]
        abl = [r["head_ood_acc_ablated"] for r in traj]
        label_n = f"n={run['n_train']} s{run['seed']}"
        ax.plot(eps, raw, "-", color=color, alpha=0.8, lw=1.6,
                label=f"{label_n} raw")
        ax.plot(eps, abl, "--", color=color, alpha=0.8, lw=1.6,
                label=f"{label_n} ablated")
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title(f"M4 ablation — {cond.upper()}\n"
                 f"raw (—) vs shortcut-ablated (- -)",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=7, ncol=2)
    ax.grid(alpha=0.3)
    # Fixed y-range, identical for both conditions; points outside are not shown.
    ax.set_ylim(0.30, 0.85)


def _plot_m5_panel(ax, cond: str, runs: list[dict], cmap) -> None:
    """Draw head-OOD accuracy against steering coefficient α, one curve per run."""
    m5_runs = [r for r in runs if r["m5"]]
    if not m5_runs:
        _placeholder(ax, f"no {cond} M5 data yet",
                     f"{cond.upper()}: head OOD vs steering α")
        return
    for i, run in enumerate(m5_runs):
        color = _shade(cmap, i, len(m5_runs))
        m5 = run["m5"]
        alphas = [r["alpha"] for r in m5["sweep"]]
        heads = [r["head_ood_acc"] for r in m5["sweep"]]
        ax.plot(alphas, heads, "-o", color=color, lw=2, ms=6,
                label=f"n={run['n_train']} s{run['seed']} ep{m5['epoch']}")
    # Dotted vertical marker at α = 0.
    ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax.set_xlabel("Steering coefficient α (σ-units of v_s)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_title(f"M5 steering — {cond.upper()}\n"
                 "α=0 baseline; |α|↑ = stronger shortcut activation",
                 fontweight="bold", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
    # Fixed y-range, identical for both conditions; points outside are not shown.
    ax.set_ylim(0.45, 0.80)


def main():
    """Gather all runs, draw the 2×2 intervention figure, and save PNG + PDF."""
    data = _gather()
    n_grok = len(data["grokking"])
    n_std = len(data["standard"])
    print(f"Found {n_grok} grokking runs and {n_std} standard runs with intervention data")

    fig, axes = plt.subplots(2, 2, figsize=(15, 9))
    cmaps = {"grokking": plt.cm.Blues, "standard": plt.cm.Reds}

    for col, cond in enumerate(CONDITIONS):
        # ─────────────── M4 ablation trajectories ───────────────
        _plot_m4_panel(axes[0][col], cond, data[cond], cmaps[cond])
        # ─────────────── M5 steering sweeps ───────────────
        _plot_m5_panel(axes[1][col], cond, data[cond], cmaps[cond])

    fig.suptitle(
        "Causal Interventions on the Shortcut Subspace (avgpool, ResNet-18, Camelyon17)\n"
        "M4 = ablate-and-evaluate, M5 = steer-and-evaluate",
        fontsize=12, fontweight="bold", y=1.005,
    )
    plt.tight_layout()

    out = ROOT / "paper_figures" / "figure_intervention_comparison"
    # savefig does not create missing directories; make sure the target exists.
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight")
    fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out}.png + .pdf")


if __name__ == "__main__":
    main()