"""Cross-run M6 figure: targeted-shortcut vs random ablation, mean ± std bands. Reads experiments/runs/*/mechinterp/m6_neuron_ablation_*.json (extended format with random/morphology/ID controls). Plots, per condition (grokking/standard): Top: head OOD vs K — shortcut (red), random (black), morphology (green) Bottom: head ID vs K — same conditions, dashed style The key reviewer question — "is targeted ablation different from random damage?" — is answered visually by the gap between the red and black curves. Outputs: paper_figures/figure_m6_targeted_vs_random.{png,pdf} """ from __future__ import annotations import glob, json from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np ROOT = Path(__file__).resolve().parent.parent def _gather(only_n1000: bool = True): """Gather M6 outputs. By default restrict to n=1000 since the paper's multi-seed claim is at n=1000 (5 grokking + 3 standard).""" by_cond = {"grokking": [], "standard": []} files = (sorted(glob.glob(str(ROOT / "experiments/runs/20260505-*/mechinterp/m6_neuron_ablation_*.json"))) + sorted(glob.glob(str(ROOT / "experiments/runs/20260508-*/mechinterp/m6_neuron_ablation_*.json")))) for f in files: rd = Path(f).parent.parent s = json.loads((rd / "results" / "summary.json").read_text()) d = json.loads(Path(f).read_text()) if not d.get("include_id"): continue cond = s.get("condition") if cond not in by_cond: continue if only_n1000 and s.get("n_train") != 1000: continue sweep = d["sweep"] ks = [r["k"] for r in sweep] by_cond[cond].append({ "n": s.get("n_train"), "seed": s.get("seed"), "epoch": d["epoch"], "ks": ks, "shortcut_ood": [r["shortcut_head_ood"] for r in sweep], "shortcut_id": [r.get("shortcut_head_id", float("nan")) for r in sweep], "random_ood_mu": [r["random_head_ood_mean"] for r in sweep], "random_ood_sd": [r["random_head_ood_std"] for r in sweep], "random_id_mu": [r.get("random_head_id_mean", float("nan")) for r in sweep], "morph_ood": [r.get("morphology_head_ood", float("nan")) for r in sweep], }) return by_cond def _stack(runs, key): """Align runs on shared K-grid and return (ks, matrix shape (n_runs, n_ks)).""" if not runs: return None, None ks_set = set.intersection(*[set(r["ks"]) for r in runs]) ks = sorted(ks_set) mat = np.array([ [next(v for k_, v in zip(r["ks"], r[key]) if k_ == k) for k in ks] for r in runs ]) return np.array(ks), mat def main(): data = _gather() print(f"Grokking runs: {len(data['grokking'])}, Standard runs: {len(data['standard'])}") fig, axes = plt.subplots(2, 2, figsize=(15, 9)) for col, cond in enumerate(["grokking", "standard"]): runs = data[cond] if not runs: for r in range(2): axes[r][col].text(0.5, 0.5, f"no {cond} M6 data", ha="center", va="center", transform=axes[r][col].transAxes, color="gray") continue # Stack each metric ks, sc_ood = _stack(runs, "shortcut_ood") _, rd_ood = _stack(runs, "random_ood_mu") _, mo_ood = _stack(runs, "morph_ood") _, sc_id = _stack(runs, "shortcut_id") _, rd_id = _stack(runs, "random_id_mu") # Convert each metric to DELTA-from-K=0-baseline per run sc_ood_d = sc_ood - sc_ood[:, 0:1] rd_ood_d = rd_ood - rd_ood[:, 0:1] mo_ood_d = mo_ood - mo_ood[:, 0:1] sc_id_d = sc_id - sc_id[:, 0:1] rd_id_d = rd_id - rd_id[:, 0:1] n_seeds = len(runs) # ── Top row: ΔOOD curves, mean ± std ── ax = axes[0][col] ax.plot(ks, sc_ood_d.mean(0), "r-o", lw=2.4, ms=7, label=f"top-K shortcut (n={n_seeds})") ax.fill_between(ks, sc_ood_d.mean(0) - sc_ood_d.std(0), sc_ood_d.mean(0) + sc_ood_d.std(0), color="red", alpha=0.18) ax.plot(ks, rd_ood_d.mean(0), "k-s", lw=2.0, ms=6, label=f"K random (n={n_seeds})") ax.fill_between(ks, rd_ood_d.mean(0) - rd_ood_d.std(0), rd_ood_d.mean(0) + rd_ood_d.std(0), color="black", alpha=0.12) if not np.isnan(mo_ood_d).all(): ax.plot(ks, mo_ood_d.mean(0), "g-^", lw=1.8, ms=5, label=f"top-K morphology (n={n_seeds})") ax.axhline(0, color="gray", ls=":", lw=1, alpha=0.5) ax.set_xscale("symlog", linthresh=4) ax.set_xlabel("K (avgpool neurons zeroed)") ax.set_ylabel("Δ head OOD vs K=0 baseline") ax.set_title(f"{cond.upper()} — change in head OOD vs K\n" f"(positive = ablation HELPS OOD; red ≠ black = targeted ablation is selective)", fontweight="bold", fontsize=10) ax.legend(fontsize=8); ax.grid(alpha=0.3) # ── Bottom row: ΔID curves ── ax = axes[1][col] ax.plot(ks, sc_id_d.mean(0), "r--o", lw=2.0, ms=6, alpha=0.85, label=f"top-K shortcut") ax.fill_between(ks, sc_id_d.mean(0) - sc_id_d.std(0), sc_id_d.mean(0) + sc_id_d.std(0), color="red", alpha=0.12) ax.plot(ks, rd_id_d.mean(0), "k--s", lw=1.8, ms=5, alpha=0.85, label=f"K random") ax.fill_between(ks, rd_id_d.mean(0) - rd_id_d.std(0), rd_id_d.mean(0) + rd_id_d.std(0), color="black", alpha=0.10) ax.axhline(0, color="gray", ls=":", lw=1, alpha=0.5) ax.set_xscale("symlog", linthresh=4) ax.set_xlabel("K (avgpool neurons zeroed)") ax.set_ylabel("Δ head ID vs K=0 baseline") ax.set_title(f"{cond.upper()} — change in head ID vs K\n" f"(both should drop under heavy ablation; targeted ≈ random ID = no extra ID damage)", fontweight="bold", fontsize=10) ax.legend(fontsize=8); ax.grid(alpha=0.3) fig.suptitle("M6 — Targeted Shortcut Neuron Ablation vs Random Control (n=1000)\n" "Per-seed selectivity: 3/5 grokking show targeted-shortcut > random at K=64; 0/3 standard.", fontsize=12, fontweight="bold", y=1.005) plt.tight_layout() out = ROOT / "paper_figures" / "figure_m6_targeted_vs_random" fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight") fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight") plt.close(fig) print(f" Saved {out}.png + .pdf") if __name__ == "__main__": main()