"""Multi-seed mean ± std figure for M4 and M5 across the 3 n=1000 grokking seeds (s42, s123, s456) plus single-seed standard baselines as reference lines. Reads: experiments/runs//mechinterp/m4_ablation_avgpool_trajectory.json experiments/runs//mechinterp/m5_steering_*.json Outputs: paper_figures/figure_multiseed_intervention.{png,pdf} """ from __future__ import annotations import glob import json from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np ROOT = Path(__file__).resolve().parent.parent def _load_summary(run_dir: Path) -> dict: p = run_dir / "results" / "summary.json" return json.loads(p.read_text()) if p.exists() else {} def _gather(): """Return per-(cond, n) → list of run dicts with M4 traj + M5 sweep.""" out = {} # Use both 2026-05-05 (initial 7) and 2026-05-08 (4 new) runs run_dirs = sorted(list(ROOT.glob("experiments/runs/20260505-*/")) + list(ROOT.glob("experiments/runs/20260508-*/"))) for run_dir in run_dirs: s = _load_summary(run_dir) cond = s.get("condition") n = s.get("n_train") if cond is None or n is None: continue m4_p = run_dir / "mechinterp" / "m4_ablation_avgpool_trajectory.json" m5_p = list((run_dir / "mechinterp").glob("m5_steering_*.json")) m4 = json.loads(m4_p.read_text()) if m4_p.exists() else None m5 = json.loads(m5_p[0].read_text()) if m5_p else None if m4 is None and m5 is None: continue out.setdefault((cond, n), []).append({ "run_id": run_dir.name, "seed": s.get("seed"), "best_ood": s.get("best_ood"), "m4": m4, "m5": m5, }) return out def _stack_m4(runs): """Align M4 trajectories on shared epochs and return: (eps, raw_mat, abl_mat, delta_mat) — each (n_runs, n_eps).""" if not runs: return None, None, None, None eps_set = set.intersection(*[ {r["epoch"] for r in run["m4"]} for run in runs if run["m4"] ]) eps = sorted(eps_set) raw = np.array([ [next(r for r in run["m4"] if r["epoch"] == e)["head_ood_acc_raw"] for e in eps] for run in runs ]) abl = np.array([ [next(r for r in run["m4"] if r["epoch"] == e)["head_ood_acc_ablated"] for e in eps] for run in runs ]) return np.array(eps), raw, abl, abl - raw def _stack_m5(runs): """Align M5 sweeps on shared α and return (alphas, head_mat, hosp_mat, tum_mat).""" runs = [run for run in runs if run["m5"]] if not runs: return None, None, None, None alphas_set = set.intersection(*[ {r["alpha"] for r in run["m5"]["sweep"]} for run in runs ]) alphas = sorted(alphas_set) head = np.array([ [next(r for r in run["m5"]["sweep"] if r["alpha"] == a)["head_ood_acc"] for a in alphas] for run in runs ]) return np.array(alphas), head def main(): data = _gather() n1000_grok = data.get(("grokking", 1000), []) n1000_std = data.get(("standard", 1000), []) print(f" grokking n=1000 seeds: {[r['seed'] for r in n1000_grok]}") print(f" standard n=1000 seeds: {[r['seed'] for r in n1000_std]}") if len(n1000_grok) < 2: print(" Need ≥2 grokking n=1000 seeds; aborting.") return fig, axes = plt.subplots(1, 2, figsize=(14, 5.5)) # ─────────────── Panel A: M4 trajectory mean ± std ─────────────── ax = axes[0] eps, raw_g, abl_g, _ = _stack_m4(n1000_grok) if eps is not None: n_seeds_g = len(n1000_grok) raw_mu, raw_sd = raw_g.mean(0), raw_g.std(0) abl_mu, abl_sd = abl_g.mean(0), abl_g.std(0) ax.plot(eps, raw_mu, "-", color="navy", lw=2.2, label=f"grok raw (n={n_seeds_g} seeds)") ax.fill_between(eps, raw_mu - raw_sd, raw_mu + raw_sd, color="navy", alpha=0.18) ax.plot(eps, abl_mu, "--", color="darkorange", lw=2.2, label=f"grok ablated (n={n_seeds_g} seeds)") ax.fill_between(eps, abl_mu - abl_sd, abl_mu + abl_sd, color="darkorange", alpha=0.18) if n1000_std: eps_s, raw_s, abl_s, _ = _stack_m4(n1000_std) if eps_s is not None: ax.plot(eps_s, raw_s.mean(0), "-", color="darkred", lw=1.6, alpha=0.9, label=f"standard raw (n={len(n1000_std)} seed)") ax.plot(eps_s, abl_s.mean(0), "--", color="darkred", lw=1.6, alpha=0.9, label=f"standard ablated (n={len(n1000_std)} seed)") ax.set_xlabel("Training epoch") ax.set_ylabel("Head OOD (H4) accuracy") ax.set_title("M4 — Shortcut subspace ablation across training\n" "(mean ± std over seeds for n=1000)", fontweight="bold", fontsize=10) ax.legend(fontsize=8, loc="lower right"); ax.grid(alpha=0.3) ax.set_ylim(0.42, 0.80) # ─────────────── Panel B: M5 steering mean ± std ─────────────── ax = axes[1] a_g, head_g = _stack_m5(n1000_grok) if a_g is not None and len(n1000_grok) >= 2: mu = head_g.mean(0); sd = head_g.std(0) ax.plot(a_g, mu, "-o", color="navy", lw=2.2, ms=7, label=f"grokking n=1000 (n={len(n1000_grok)} seeds)") ax.fill_between(a_g, mu - sd, mu + sd, color="navy", alpha=0.20) if n1000_std: a_s, head_s = _stack_m5(n1000_std) if a_s is not None: ax.plot(a_s, head_s.mean(0), "-s", color="darkred", lw=1.8, ms=6, alpha=0.9, label=f"standard n=1000 (n={len(n1000_std)} seed)") ax.axvline(0, color="gray", ls=":", lw=1, alpha=0.5) ax.set_xlabel("Steering coefficient α (σ-units of v_s)") ax.set_ylabel("Head OOD (H4) accuracy") ax.set_title("M5 — Activation steering response\n" "(mean ± std over seeds for n=1000)", fontweight="bold", fontsize=10) ax.legend(fontsize=8); ax.grid(alpha=0.3) fig.suptitle("Multi-seed Causal-Intervention Robustness (n=1000, 3 seeds for grokking)", fontsize=12, fontweight="bold", y=1.01) plt.tight_layout() out = ROOT / "paper_figures" / "figure_multiseed_intervention" fig.savefig(out.with_suffix(".png"), dpi=180, bbox_inches="tight") fig.savefig(out.with_suffix(".pdf"), bbox_inches="tight") plt.close(fig) print(f" Saved {out}.png + .pdf") if __name__ == "__main__": main()