"""M6 — Neuron Ablation cross-run comparison figure. Reads experiments/runs/*/mechinterp/m6_neuron_ablation_*.json Plots head OOD acc vs K (top-K hospital-correlated neurons zeroed) for every run, grouped by condition (grokking vs standard). """ from __future__ import annotations import glob, json from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np ROOT = Path(__file__).resolve().parent.parent def main(): fig, axes = plt.subplots(1, 2, figsize=(14, 5)) by_cond = {"grokking": [], "standard": []} for f in sorted(glob.glob(str(ROOT / "experiments/runs/20260505-*/mechinterp/m6_neuron_ablation_*.json"))): rd = Path(f).parent.parent s = json.loads((rd / "results" / "summary.json").read_text()) d = json.loads(Path(f).read_text()) cond = s.get("condition") if cond not in by_cond: continue by_cond[cond].append({ "n": s.get("n_train"), "seed": s.get("seed"), "epoch": d["epoch"], "ks": [r["k"] for r in d["sweep"]], "head": [r["head_ood_acc"] for r in d["sweep"]], "tum": [r["tumor_probe"] for r in d["sweep"]], }) cmaps = {"grokking": plt.cm.Blues, "standard": plt.cm.Reds} for col, cond in enumerate(["grokking", "standard"]): ax = axes[col] runs = by_cond[cond] if not runs: ax.text(0.5, 0.5, f"no {cond} M6 data", ha="center", va="center", transform=ax.transAxes, color="gray", fontsize=11) continue cmap = cmaps[cond] for i, r in enumerate(runs): color = cmap(0.45 + 0.45 * i / max(1, len(runs) - 1)) ax.plot(r["ks"], r["head"], "-o", color=color, lw=2, ms=6, label=f"n={r['n']} s{r['seed']} ep{r['epoch']}") # Aggregate at the most-common K-grid common_ks = sorted(set.intersection(*[set(r["ks"]) for r in runs])) head_mat = np.array([ [next(h for k_, h in zip(r["ks"], r["head"]) if k_ == k) for k in common_ks] for r in runs ]) if len(runs) >= 2: mu = head_mat.mean(0); sd = head_mat.std(0) ax.plot(common_ks, mu, "k-", lw=2.5, label=f"mean (n={len(runs)} runs)") ax.fill_between(common_ks, mu - sd, mu + sd, color="black", alpha=0.15) ax.set_xlabel("K (top-K hospital-correlated neurons zeroed)") ax.set_ylabel("Head OOD (H4) accuracy") ax.set_xscale("symlog", linthresh=4) ax.set_title(f"{cond.upper()} — head OOD vs K\n" f"(decreasing K=0→K=256 measures shortcut-neuron causal weight)", fontweight="bold", fontsize=10) ax.legend(fontsize=8, ncol=1); ax.grid(alpha=0.3) fig.suptitle("M6 — Neuron-level Causal Ablation across all runs", fontsize=12, fontweight="bold", y=1.01) plt.tight_layout() out = ROOT / "paper_figures" / "figure_m6_neuron_comparison" fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight") fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight") plt.close(fig) print(f" Saved {out}.png + .pdf") if __name__ == "__main__": main()