""" CausalGrok — Paper Figure Generator Reads every experiments/runs//results/history.json on disk and produces: paper_figures/figure1_smoking_gun.png|pdf ← IRM penalty + val acc paper_figures/figure2_mechanisms.png ← weight norm + feature rank paper_figures/figure3_shortcut.png ← shortcut ratio over training paper_figures/table1_ablations.csv ← summary across runs Per-run figures are also saved into experiments/runs//figures/. Run after experiments complete: bash scripts/plot_all.sh # or: python -m experiments.plot_results """ from __future__ import annotations import argparse import glob import json import os from typing import Dict, List import matplotlib import matplotlib.pyplot as plt import pandas as pd from utils.run_dir import DEFAULT_BASE matplotlib.rcParams.update({"font.size": 12, "figure.dpi": 150}) # ────────────────────────────────────────────── # LOADING # ────────────────────────────────────────────── def discover_runs(runs_dir: str = DEFAULT_BASE) -> List[Dict]: """One record per run that has a history.json.""" runs = [] for run_dir in sorted(glob.glob(os.path.join(runs_dir, "*"))): hist_path = os.path.join(run_dir, "results", "history.json") cfg_path = os.path.join(run_dir, "config.json") if not os.path.isfile(hist_path): continue try: df = pd.DataFrame(json.load(open(hist_path))) except Exception: continue # Normalize column names for v1 vs v2 compatibility # v1 uses: val_acc, train_acc # v2 uses: id_val_acc, ood_acc, train_acc if "id_val_acc" in df.columns and "val_acc" not in df.columns: df = df.rename(columns={"id_val_acc": "val_acc"}) cfg = json.load(open(cfg_path)) if os.path.isfile(cfg_path) else {} runs.append(dict(run_dir=run_dir, df=df, cfg=cfg, run_id=os.path.basename(run_dir))) return runs def average_by_condition(runs: List[Dict]) -> Dict[str, pd.DataFrame]: """ Group runs by (condition, n_train) so we never average across incompatible dataset sizes. Returned key is "_n". 
""" grouped: Dict[tuple, List[pd.DataFrame]] = {} for r in runs: cond = r["cfg"].get("condition") if cond is None: cond = "grokking" if "grokking" in r["run_id"] else "standard" n_train = r["cfg"].get("n_train", 0) grouped.setdefault((cond, n_train), []).append(r["df"]) out: Dict[str, pd.DataFrame] = {} for (cond, n), dfs in grouped.items(): merged = pd.concat(dfs, ignore_index=True) numeric_cols = [c for c in merged.columns if c != "epoch" and pd.api.types.is_numeric_dtype(merged[c])] out[f"{cond}_n{n}"] = merged.groupby("epoch")[numeric_cols].mean().reset_index() return out def pick_headline_curves(data: Dict[str, pd.DataFrame]): """ Pick one grokking curve and one standard curve for the headline figure. Heuristic: prefer n=500 (the canonical small-data regime for this paper); otherwise fall back to the smallest n_train available. Large-dataset runs grok fast and the plateau disappears, washing out the visual story. """ def best(cond_prefix): keys = [k for k in data if k.startswith(f"{cond_prefix}_n")] if not keys: return None target = f"{cond_prefix}_n500" if target in keys: return target keys.sort(key=lambda k: int(k.split("_n")[-1])) return keys[0] return best("grokking"), best("standard") # ────────────────────────────────────────────── # FIGURE 1 — THE SMOKING GUN # ────────────────────────────────────────────── def figure1_smoking_gun(data: Dict[str, pd.DataFrame], save_dir: str): fig, axes = plt.subplots(1, 2, figsize=(14, 5)) grok_key, std_key = pick_headline_curves(data) panels = [ (axes[0], grok_key, "#2563EB", f"Grokking-Favorable Training\n({grok_key or 'no data'})"), (axes[1], std_key, "#DC2626", f"Standard Training\n({std_key or 'no data'})"), ] for ax, cond, color, title in panels: if cond is None or cond not in data: ax.text(0.5, 0.5, f"No {cond} data yet", ha="center", va="center", transform=ax.transAxes) ax.set_title(title, fontweight="bold") continue df = data[cond] ax2 = ax.twinx() ax.plot(df["epoch"], df["val_acc"], color=color, lw=2.5, 
label="ID Val Accuracy (H3)", zorder=3) # For v2 runs: also show OOD accuracy (the actual grokking signal) if "ood_acc" in df.columns: ax.plot(df["epoch"], df["ood_acc"], color=color, lw=2.5, ls="--", alpha=0.7, label="OOD Accuracy (H4)", zorder=3) ax2.plot(df["epoch"], df["irm_mean"], color="#F59E0B", lw=2, ls="--", label="IRM Penalty ↓", zorder=2) if "grokking_detected" in df.columns: grok = df[df["grokking_detected"].astype(bool)] if len(grok): ep = int(grok["epoch"].min()) ax.axvline(ep, color="gray", ls=":", lw=1.5) ax.annotate(f"Grokking\nep.{ep}", xy=(ep, 0.5), xytext=(ep + ep * 0.05, 0.3), fontsize=9, color="gray", arrowprops=dict(arrowstyle="->", color="gray")) ax.set_xlabel("Epoch") ax.set_ylabel("Val Accuracy", color=color) ax2.set_ylabel("IRM Penalty (↓ = causal)", color="#F59E0B") ax.set_title(title, fontweight="bold") ax.tick_params(axis="y", labelcolor=color) ax2.tick_params(axis="y", labelcolor="#F59E0B") ax.set_ylim([0, 1.05]) ax.grid(alpha=0.3) h1, l1 = ax.get_legend_handles_labels() h2, l2 = ax2.get_legend_handles_labels() ax.legend(h1 + h2, l1 + l2, loc="center left", fontsize=9) fig.suptitle( "Figure 1 — IRM Invariance Penalty Drops at the Grokking Transition\n" "Causal feature discovery and delayed generalization are the same event", fontsize=12, y=1.02 ) plt.tight_layout() plt.savefig(os.path.join(save_dir, "figure1_smoking_gun.png"), bbox_inches="tight") plt.savefig(os.path.join(save_dir, "figure1_smoking_gun.pdf"), bbox_inches="tight") print(" Figure 1 saved") plt.close() def figure2_mechanisms(data: Dict[str, pd.DataFrame], save_dir: str): grok_key, _ = pick_headline_curves(data) if grok_key is None: print(" Skipping Figure 2 (no grokking data)") return df = data[grok_key] fig, ax1 = plt.subplots(figsize=(10, 5)) ax2 = ax1.twinx() ax3 = ax1.twinx() ax3.spines["right"].set_position(("outward", 60)) ax1.plot(df["epoch"], df["val_acc"], "#2563EB", lw=2.5, label="Val Acc") ax2.plot(df["epoch"], df["weight_norm"], "#10B981", lw=2, ls="--", 
label="Weight Norm ‖W‖") ax3.plot(df["epoch"], df["feature_rank"], "#F59E0B", lw=2, ls="-.", label="Feature Rank") ax1.set_xlabel("Epoch"); ax1.set_ylabel("Val Accuracy", color="#2563EB") ax2.set_ylabel("Weight Norm", color="#10B981") ax3.set_ylabel("Feature Rank", color="#F59E0B") ax1.tick_params(axis="y", labelcolor="#2563EB") ax2.tick_params(axis="y", labelcolor="#10B981") ax3.tick_params(axis="y", labelcolor="#F59E0B") handles = (ax1.get_legend_handles_labels()[0] + ax2.get_legend_handles_labels()[0] + ax3.get_legend_handles_labels()[0]) labels = (ax1.get_legend_handles_labels()[1] + ax2.get_legend_handles_labels()[1] + ax3.get_legend_handles_labels()[1]) ax1.legend(handles, labels, loc="center left", fontsize=9) ax1.set_title( "Figure 2 — Training Dynamics: Weight Norm + Feature Rank as Progress Measures", fontweight="bold") ax1.grid(alpha=0.3) plt.tight_layout() plt.savefig(os.path.join(save_dir, "figure2_mechanisms.png"), bbox_inches="tight") print(" Figure 2 saved") plt.close() def figure3_shortcut(data: Dict[str, pd.DataFrame], save_dir: str): grok_key, _ = pick_headline_curves(data) if grok_key is None: print(" Skipping Figure 3 (no grokking data)") return df = data[grok_key] fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(df["epoch"], df["center_conf"], "#2563EB", lw=2, label="Center (anatomy) confidence") ax.plot(df["epoch"], df["border_conf"], "#DC2626", lw=2, ls="--", label="Border (artifact) confidence") ax.plot(df["epoch"], df["shortcut_ratio"], "#F59E0B", lw=2, ls="-.", label="Shortcut ratio (border/center)") ax.axhline(1.0, color="gray", ls=":", lw=1, alpha=0.7, label="Ratio = 1 (equal reliance)") ax.set_xlabel("Epoch"); ax.set_ylabel("Confidence / Ratio") ax.set_title( "Figure 3 — Shortcut Reliance: Model shifts from artifacts to anatomy at grokking", fontweight="bold") ax.legend(fontsize=10); ax.grid(alpha=0.3) plt.tight_layout() plt.savefig(os.path.join(save_dir, "figure3_shortcut.png"), bbox_inches="tight") print(" Figure 3 saved") plt.close() 
def _first_grokking_epoch(df: pd.DataFrame) -> int:
    """Return the earliest epoch flagged in ``grokking_detected``, or -1 if none."""
    if "grokking_detected" not in df or "epoch" not in df:
        return -1
    flags = df["grokking_detected"]
    # NaN is truthy under astype(bool); mask it out so a missing value is
    # never counted as a detection.
    hits = df[flags.notna() & flags.astype(bool)]
    return int(hits["epoch"].min()) if len(hits) else -1


def _irm_drop_epoch(df: pd.DataFrame) -> int:
    """Return the epoch of the steepest one-step IRM decrease, or -1 if it never decreases."""
    if "irm_mean" not in df or "epoch" not in df or len(df) < 2:
        return -1
    step = df["irm_mean"].diff()
    if not step.notna().any():
        return -1
    # Signed difference: the most negative step is the fastest drop. Taking
    # the absolute value would let an upward spike win.
    steepest = step.idxmin()
    if step.loc[steepest] >= 0:
        return -1
    return int(df.loc[steepest, "epoch"])


def table1_ablations(runs: List[Dict], save_dir: str):
    """
    Write table1_ablations.csv: one summary row per run with a non-empty history.

    Columns: run_id, condition, n_train, seed, best_val_acc, grokking_epoch,
    irm_drop_epoch, epoch_gap, irm_drop_pct, final_shortcut_ratio, run_dir.
    Epoch columns use -1 to mean "not available"; metric columns use NaN when
    the source column was not logged. Rows are sorted by best_val_acc,
    descending. The table is also printed. Nothing is written when no run
    qualifies.
    """
    nan = float("nan")
    rows = []
    for r in runs:
        df = r["df"]
        if df.empty:
            continue

        has_irm = "irm_mean" in df
        irm0 = df["irm_mean"].iloc[0] if has_irm else nan
        irm_min = df["irm_mean"].min() if has_irm else nan

        # Co-movement: compare the epoch where val_acc jumped (grokking
        # transition) vs. the epoch where IRM dropped fastest. Small gap
        # ⇒ same event ⇒ paper's central claim. Large gap ⇒ separate
        # events ⇒ weaker, lagged claim.
        grok_ep = _first_grokking_epoch(df)
        irm_drop_ep = _irm_drop_epoch(df)
        # -1 is the "missing" sentinel, so epoch 0 is a legitimate value.
        if grok_ep >= 0 and irm_drop_ep >= 0:
            epoch_gap = abs(grok_ep - irm_drop_ep)
        else:
            epoch_gap = -1

        rows.append({
            "run_id": r["run_id"],
            "condition": r["cfg"].get("condition", ""),
            "n_train": r["cfg"].get("n_train"),
            "seed": r["cfg"].get("seed"),
            "best_val_acc": df["val_acc"].max() if "val_acc" in df else nan,
            "grokking_epoch": grok_ep,
            "irm_drop_epoch": irm_drop_ep,
            "epoch_gap": epoch_gap,
            # The small epsilon guards against division by zero.
            "irm_drop_pct": (irm0 - irm_min) / (irm0 + 1e-8) * 100,
            "final_shortcut_ratio": (df["shortcut_ratio"].iloc[-1]
                                     if "shortcut_ratio" in df else nan),
            "run_dir": r["run_dir"],
        })

    if not rows:
        print(" No runs to summarize.")
        return

    table = pd.DataFrame(rows).sort_values("best_val_acc", ascending=False)
    out_path = os.path.join(save_dir, "table1_ablations.csv")
    table.to_csv(out_path, index=False)
    print(f"\nTable 1 ({len(table)} runs):")
    print(table.to_string(index=False))
    print(f"\n Saved → {out_path}")


def per_run_figure(r: Dict):
    """
    Save training_curves.png for one run into its own figures/ directory.

    Plots validation and training accuracy on the left axis and the IRM
    penalty on a twin right axis. The figures/ directory is created if it
    does not exist. Curves whose column was not logged are skipped. Returns
    without writing anything when the history is empty or has no ``epoch``
    column.
    """
    df = r["df"]
    if df.empty or "epoch" not in df.columns:
        return

    fig_dir = os.path.join(r["run_dir"], "figures")
    # savefig does not create parent directories.
    os.makedirs(fig_dir, exist_ok=True)
    out = os.path.join(fig_dir, "training_curves.png")

    fig, ax = plt.subplots(figsize=(9, 4.5))
    ax2 = ax.twinx()
    curves = (
        (ax, "val_acc", dict(color="#2563EB", lw=2, label="Val Acc")),
        (ax, "train_acc", dict(color="#9CA3AF", lw=1, ls=":", label="Train Acc")),
        (ax2, "irm_mean", dict(color="#F59E0B", lw=2, ls="--", label="IRM")),
    )
    for axis, column, style in curves:
        if column in df.columns:
            axis.plot(df["epoch"], df[column], **style)

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax2.set_ylabel("IRM penalty")
    ax.set_title(r["run_id"], fontsize=10)
    ax.grid(alpha=0.3)

    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    if h1 or h2:
        ax.legend(h1 + h2, l1 + l2, loc="center left", fontsize=8)

    fig.tight_layout()
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)


def main():
    """
    Command-line entry point.

    Discovers runs under --runs_dir, writes one per-run figure each, then
    writes the three cross-run figures and the summary table into --save_dir
    (created if missing). Returns early when no runs are found.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--runs_dir", default=DEFAULT_BASE)
    p.add_argument("--save_dir", default="paper_figures")
    args = p.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    runs = discover_runs(args.runs_dir)
    print(f"Found {len(runs)} runs in {args.runs_dir}/")
    if not runs:
        return

    for r in runs:
        per_run_figure(r)

    data = average_by_condition(runs)
    print(f"Conditions averaged: {sorted(data.keys())}")

    figure1_smoking_gun(data, args.save_dir)
    figure2_mechanisms(data, args.save_dir)
    figure3_shortcut(data, args.save_dir)
    table1_ablations(runs, args.save_dir)

    print(f"\nAll cross-run artifacts in {args.save_dir}/")


if __name__ == "__main__":
    main()