""" CausalGrok — Main Training Loop Nilesh Core experiment: does the IRM invariance penalty drop at the SAME epoch as validation accuracy jumps (the grokking transition)? If yes → the paper's central claim is confirmed. Run via the launchers (always nohup-detached so SSH disconnects don't kill it): bash scripts/launch.sh grokking 500 42 All artifacts (config, logs, history, checkpoints, figures) for every invocation land in experiments/runs// and are kept forever. """ from __future__ import annotations import argparse import json import os import time from datetime import datetime, timezone import numpy as np import torch import torch.nn as nn import torchvision.transforms as transforms from torch.utils.data import DataLoader, Dataset, Subset from torchvision.models import resnet18 from medmnist import PneumoniaMNIST import wandb from utils.metrics import ( accuracy, weight_norm, feature_rank, irm_penalty, shortcut_ratio, ) from utils.grokfast import gradfilter_ema from utils.pseudo_envs import make_brightness_envs from utils.run_dir import make_run_dir, ensure_run_dir, save_config # ────────────────────────────────────────────── # CONFIG # ────────────────────────────────────────────── def get_config(condition): base = dict( seed=42, n_train=500, batch_size=32, img_size=28, n_classes=2, log_every=50, n_pseudo_envs=3, device="cuda" if torch.cuda.is_available() else "cpu", ) if condition == "standard": base.update(dict(condition="standard", lr=1e-3, weight_decay=1e-4, n_epochs=300, init_scale=1.0, use_grokfast=False)) elif condition == "grokking": base.update(dict(condition="grokking", lr=1e-3, weight_decay=1e-3, n_epochs=3000, init_scale=4.0, use_grokfast=True, grokfast_alpha=0.98, grokfast_lamb=2.0)) return base # ────────────────────────────────────────────── # DATA # ────────────────────────────────────────────── class SpuriousColorPatchDataset(Dataset): """ Wraps a (image-tensor, label) dataset and stamps a colored corner patch correlated with the label at 
probability `rho`. Encoding (after Normalize mean=.5/std=.5, image is in [-1,1] across 3 identical grayscale channels): encoded label 0 → channel-0 high, channels 1,2 low (red corner) encoded label 1 → channel-2 high, channels 0,1 low (blue corner) With prob rho the encoded label matches the true label — a usable shortcut. With prob (1-rho) it's flipped — pure noise on the patch. The same `seed` produces the same per-sample correlation decisions across val/test so the spurious feature is stable across runs and the ceiling effect (val plateau ≈ rho before grokking) is clean. """ def __init__(self, base, rho=0.8, patch_size=4, seed=0, hi=1.0, lo=-1.0): self.base = base self.rho = float(rho) self.patch_size = int(patch_size) self.hi = hi self.lo = lo rng = torch.Generator().manual_seed(int(seed)) self.is_correlated = (torch.rand(len(base), generator=rng) < self.rho) def __len__(self): return len(self.base) def __getitem__(self, idx): img, label = self.base[idx] # label may be a 1-element tensor or a python scalar try: label_int = int(label.squeeze().item()) except AttributeError: label_int = int(label) encoded = label_int if bool(self.is_correlated[idx]) else (1 - label_int) ps = self.patch_size if encoded == 0: img[0, :ps, :ps] = self.hi img[1, :ps, :ps] = self.lo img[2, :ps, :ps] = self.lo else: img[0, :ps, :ps] = self.lo img[1, :ps, :ps] = self.lo img[2, :ps, :ps] = self.hi return img, label def get_dataloaders(cfg, data_root): # medmnist 3.x raises if root doesn't exist; create it ourselves # rather than relying on its default-root fallback. 
os.makedirs(data_root, exist_ok=True) transform = transforms.Compose([ transforms.Resize((cfg["img_size"], cfg["img_size"])), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5]), transforms.Lambda(lambda x: x.repeat(3, 1, 1)), ]) train_ds = PneumoniaMNIST(split="train", transform=transform, download=True, root=data_root) val_ds = PneumoniaMNIST(split="val", transform=transform, download=True, root=data_root) test_ds = PneumoniaMNIST(split="test", transform=transform, download=True, root=data_root) # Spurious-feature injection: colored corner patch at correlation rho. # Same rho on all splits so the shortcut model plateaus at val≈rho; # grokking transition is the model breaking through that ceiling. rho = cfg.get("spurious_rho") if rho: ps = cfg.get("spurious_patch_size", 4) sd = cfg.get("spurious_seed", cfg["seed"]) train_ds = SpuriousColorPatchDataset(train_ds, rho=rho, patch_size=ps, seed=sd + 1) val_ds = SpuriousColorPatchDataset(val_ds, rho=rho, patch_size=ps, seed=sd + 2) test_ds = SpuriousColorPatchDataset(test_ds, rho=rho, patch_size=ps, seed=sd + 3) torch.manual_seed(cfg["seed"]) indices = torch.randperm(len(train_ds))[:cfg["n_train"]] train_subset = Subset(train_ds, indices) train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"], shuffle=True, num_workers=4, pin_memory=True) val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=4, pin_memory=True) test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=4, pin_memory=True) return train_loader, val_loader, test_loader, train_subset # ────────────────────────────────────────────── # MODEL # ────────────────────────────────────────────── def build_model(cfg): model = resnet18(weights=None, num_classes=cfg["n_classes"]) if cfg["init_scale"] != 1.0: with torch.no_grad(): for name, p in model.named_parameters(): if "weight" in name and p.dim() > 1: p.data *= cfg["init_scale"] return model.to(cfg["device"]) # 
# ──────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────
def train(cfg, model, train_loader, val_loader, test_loader, pseudo_envs,
          optimizer, run_dir):
    """Run the full training loop and write history, summary and checkpoint.

    Every ``cfg["log_every"]`` epochs (and at epoch 1) the model is evaluated,
    a metrics row is appended to the history, logged to wandb, and the whole
    history is rewritten to ``<run_dir>/results/history.json``.  After the
    last epoch the test accuracy is computed, ``results/summary.json`` is
    written and the weights are saved to ``checkpoints/final.pt``.

    Args:
        cfg: config dict; reads ``condition``, ``n_epochs``, ``weight_decay``,
            ``init_scale``, ``device``, ``log_every``, ``run_id``,
            ``n_train``, ``seed`` and the optional ``grad_clip``,
            ``use_grokfast``, ``grokfast_alpha``, ``grokfast_lamb`` keys.
        model: network to train (updated in place).
        train_loader, val_loader, test_loader: data loaders yielding
            ``(imgs, labels)`` batches.
        pseudo_envs: passed unchanged to ``irm_penalty``.
        optimizer: optimiser already bound to ``model``'s parameters.
        run_dir: run directory; its ``results/`` and ``checkpoints/``
            sub-directories must already exist.

    Returns:
        The list of per-checkpoint metric dicts (the history).
    """
    criterion = nn.CrossEntropyLoss()
    grads_ema = None
    history = []
    best_val = 0.0
    grok_epoch = None
    irm_base = None

    print(f"\n{'='*55}")
    print(f" {cfg['condition'].upper()} | {cfg['n_epochs']} epochs | "
          f"WD={cfg['weight_decay']} | α={cfg['init_scale']}")
    print(f" run_dir: {run_dir}")
    print(f"{'='*55}", flush=True)

    history_path = os.path.join(run_dir, "results", "history.json")
    grad_clip = cfg.get("grad_clip", 1.0)
    plateau_window = 10
    plateau_eps = 0.01  # |Δval_acc| within this counts as flat

    for epoch in range(1, cfg["n_epochs"] + 1):
        model.train()
        loss_sum = 0.0
        n_b = 0
        for imgs, labels in train_loader:
            imgs = imgs.to(cfg["device"])
            # Flatten to shape (B,).  A bare squeeze() would also drop the
            # batch axis when the final batch holds a single sample, and
            # CrossEntropyLoss rejects a 0-d target against (1, C) logits.
            labels = labels.reshape(-1).long().to(cfg["device"])

            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()

            # Order matters: Grokfast amplifies, THEN we clip the
            # amplified result.  Clipping before Grokfast would let the
            # amplification re-blow up the gradient and partially
            # undo the safety bound.
            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0))
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=grad_clip)
            optimizer.step()

            loss_sum += loss.item()
            n_b += 1

        if epoch % cfg["log_every"] == 0 or epoch == 1:
            # Guard the mean against a loader that yielded no batches.
            train_loss = loss_sum / max(n_b, 1)

            tr_acc = accuracy(model, train_loader, cfg["device"])
            vl_acc = accuracy(model, val_loader, cfg["device"])
            wn = weight_norm(model)
            fr = feature_rank(model, val_loader, cfg["device"])
            irm_m, irm_v = irm_penalty(model, pseudo_envs, cfg["device"])
            cconf, bconf = shortcut_ratio(model, val_loader, cfg["device"])

            # First logged IRM value is the baseline for the "drop" message.
            if irm_base is None:
                irm_base = irm_m

            # Robust grokking detection — require a sustained plateau in
            # val_acc (≥ plateau_window-2 of the last `plateau_window`
            # checkpoints flat within `plateau_eps`) BEFORE the jump.
            # Otherwise early-training noise (0.50 → 0.56) can trigger.
            if grok_epoch is None and len(history) >= plateau_window:
                last = history[-plateau_window:]
                ref = last[-1]["val_acc"]
                flat = sum(1 for r in last
                           if abs(r["val_acc"] - ref) < plateau_eps)
                if flat >= plateau_window - 2 and vl_acc > best_val + 0.05:
                    grok_epoch = epoch
                    irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                    print(f"\n *** GROKKING at epoch {epoch} ***")
                    print(f" Val: {best_val:.3f}→{vl_acc:.3f} | IRM drop: {irm_drop:.1f}%",
                          flush=True)

            # Updated only after detection so the jump is measured against
            # the best value seen at earlier checkpoints.
            if vl_acc > best_val:
                best_val = vl_acc

            # Cap the shortcut ratio — early training can give cconf≈bconf≈0
            # which makes the raw ratio explode.
            sc_ratio = min(bconf / (cconf + 1e-8), 10.0)

            row = dict(epoch=epoch,
                       train_loss=train_loss,
                       train_acc=tr_acc,
                       val_acc=vl_acc,
                       weight_norm=wn,
                       feature_rank=fr,
                       irm_mean=irm_m,
                       irm_var=irm_v,
                       center_conf=cconf,
                       border_conf=bconf,
                       shortcut_ratio=sc_ratio,
                       grokking_detected=grok_epoch is not None)
            history.append(row)
            wandb.log(row)
            # Rewrite the whole history each time so a killed run still
            # leaves a complete, parseable file on disk.
            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)

            print(f" ep {epoch:5d} | loss {train_loss:.4f} | "
                  f"tr {tr_acc:.3f} | vl {vl_acc:.3f} | "
                  f"‖W‖ {wn:.1f} | rank {fr:.1f} | "
                  f"IRM {irm_m:.4f} | sc {sc_ratio:.2f}x", flush=True)

    test_acc = accuracy(model, test_loader, cfg["device"])
    wandb.log({"test_acc": test_acc, "grokking_epoch": grok_epoch or -1})

    # Compute the four decision numbers right here so summary.json is
    # the single source of truth for go/no-go.
    irm_drop_pct = float("nan")
    irm_drop_ep = -1
    epoch_gap = -1
    if history:
        irm0 = history[0]["irm_mean"]
        irm_min = min(r["irm_mean"] for r in history)
        if irm0:
            irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100.0
        # Epoch of biggest IRM step-change (proxy for "the IRM drop")
        if len(history) > 1:
            biggest = 0.0
            for prev, cur in zip(history[:-1], history[1:]):
                d = abs(cur["irm_mean"] - prev["irm_mean"])
                if d > biggest:
                    biggest = d
                    irm_drop_ep = cur["epoch"]
        if grok_epoch and irm_drop_ep > 0:
            epoch_gap = abs(grok_epoch - irm_drop_ep)

    summary = dict(
        run_id=cfg["run_id"],
        condition=cfg["condition"],
        n_train=cfg["n_train"],
        seed=cfg["seed"],
        test_acc=test_acc,
        best_val=best_val,
        grokking_epoch=grok_epoch if grok_epoch else -1,
        irm_drop_pct=irm_drop_pct,
        irm_drop_epoch=irm_drop_ep,
        epoch_gap=epoch_gap,
        final_weight_norm=history[-1]["weight_norm"] if history else None,
        final_feature_rank=history[-1]["feature_rank"] if history else None,
        final_irm=history[-1]["irm_mean"] if history else None,
        final_shortcut_ratio=history[-1]["shortcut_ratio"] if history else None,
    )
    with open(os.path.join(run_dir, "results", "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    ckpt_path = os.path.join(run_dir, "checkpoints", "final.pt")
    torch.save(model.state_dict(), ckpt_path)

    print(f"\n Test acc: {test_acc:.4f} | Grokking at: {grok_epoch}")
    print(f" History → {history_path}")
    print(f" Checkpoint → {ckpt_path}", flush=True)
    return history


# ──────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────
def main():
    """Parse CLI arguments, assemble the config, and launch one training run.

    The ``--condition`` preset supplies the defaults; any per-knob flag that
    is given on the command line overrides the corresponding preset value.
    The resolved config is saved into the run directory before training
    starts, and the wandb run is opened around the call to ``train``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--condition", default="grokking",
                   choices=["standard", "grokking"])
    p.add_argument("--n_train", type=int, default=500)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--log_every", type=int, default=50)
    p.add_argument("--wandb_project", default="causalgrok")
    p.add_argument("--wandb_mode", default="online",
                   choices=["online", "offline", "disabled"])
    p.add_argument("--run_dir", default=None,
                   help="Override the auto-generated experiments/runs// path")
    p.add_argument("--data_root", default="data",
                   help="Where MedMNIST cache lives")
    # Per-knob overrides for the ablation grid. When set, they override
    # the preset chosen by --condition. When omitted, the preset wins.
    p.add_argument("--weight_decay", type=float, default=None)
    p.add_argument("--init_scale", type=float, default=None)
    p.add_argument("--n_epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--grokfast", choices=["on", "off"], default=None,
                   help="Force Grokfast on/off, overriding the preset")
    p.add_argument("--grad_clip", type=float, default=1.0,
                   help="Max ℓ2 gradient norm; 0 disables clipping")
    # Spurious-feature injection (Outcome-C variant).
    p.add_argument("--spurious_rho", type=float, default=None,
                   help="Probability that the colored corner patch is correctly "
                        "correlated with the label. None/0 disables injection.")
    p.add_argument("--spurious_patch_size", type=int, default=4)
    p.add_argument("--spurious_seed", type=int, default=None,
                   help="Defaults to --seed; controls per-sample correlation decisions")
    args = p.parse_args()

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed, log_every=args.log_every)

    # CLI overrides take precedence over preset
    if args.weight_decay is not None:
        cfg["weight_decay"] = args.weight_decay
    if args.init_scale is not None:
        cfg["init_scale"] = args.init_scale
    if args.n_epochs is not None:
        cfg["n_epochs"] = args.n_epochs
    if args.lr is not None:
        cfg["lr"] = args.lr
    if args.grokfast is not None:
        cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["grad_clip"] = args.grad_clip
    cfg["spurious_rho"] = args.spurious_rho
    cfg["spurious_patch_size"] = args.spurious_patch_size
    cfg["spurious_seed"] = (args.spurious_seed
                            if args.spurious_seed is not None else args.seed)

    # ── Use the remaining compute on a shared GPU more aggressively ──
    # TF32 matmuls are A100-native and ~2× faster than fp32 with no
    # measurable effect on grokking dynamics for our scale of model.
    # cudnn.benchmark autotunes conv algorithms for our fixed shape.
    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is None:
        # Tag spurious runs in the run_id so the dirs are
        # distinguishable on disk and globs like
        # `experiments/runs/*spurious*/` work without ambiguity.
        parts = [cfg["condition"]]
        if cfg.get("spurious_rho"):
            parts.append(f"spurious{cfg['spurious_rho']}")
        parts += [f"n{cfg['n_train']}", f"s{cfg['seed']}"]
        run_dir, run_id = make_run_dir(parts)
    else:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        # With an explicit directory the run id is simply its last component.
        run_id = os.path.basename(os.path.normpath(run_dir))

    cfg["run_id"] = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    wandb.init(project=args.wandb_project, config=cfg, name=run_id,
               mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice: {cfg['device']}")
    print(f"Run ID: {run_id}")
    print(f"Started (UTC): {datetime.now(timezone.utc).isoformat()}", flush=True)

    train_loader, val_loader, test_loader, train_subset = get_dataloaders(
        cfg, args.data_root)
    pseudo_envs = make_brightness_envs(train_subset, cfg["n_pseudo_envs"],
                                       cfg["device"])
    model = build_model(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg["lr"],
                                  weight_decay=cfg["weight_decay"])

    print(f"Train: {len(train_subset)} | Val: {len(val_loader.dataset)} | "
          f"Test: {len(test_loader.dataset)}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}", flush=True)

    t0 = time.time()
    train(cfg, model, train_loader, val_loader, test_loader, pseudo_envs,
          optimizer, run_dir)
    print(f"\nWall time: {(time.time() - t0) / 60:.1f} min", flush=True)
    wandb.finish()


if __name__ == "__main__":
    main()