""" CausalGrok — Camelyon17 Training Loop v2 Nilesh KEY CHANGE FROM v1: OOD test accuracy (H4 — unseen hospital) is now tracked at EVERY checkpoint, not just at the end. Grokking detection watches OOD acc, not ID val acc. This is the correct signal. The paper claim: after ID accuracy converges (fast, expected), the model undergoes a delayed phase transition in OOD generalization — grokking the cross-hospital invariant causal features. This co-occurs with a drop in IRM penalty. That is the grokking we care about for clinical deployment. Two curves to watch: val_acc (H3 ID val) — converges fast, expected ~0.86 by ep 50 ood_acc (H4 OOD test) — should plateau then JUMP (the grokking) Run via: python -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 300 """ from __future__ import annotations import argparse import json import os import time from datetime import datetime, timezone import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms from torch.utils.data import DataLoader, Subset import timm try: import wandb except ImportError: wandb = None from utils.grokfast import gradfilter_ema from utils.camelyon_data import get_camelyon_subsets from utils.run_dir import make_run_dir, ensure_run_dir, save_config # ────────────────────────────────────────────── # CONFIG # ────────────────────────────────────────────── def get_config(condition): base = dict( seed=42, n_train=300, batch_size=32, img_size=96, n_classes=2, log_every=50, device="cuda" if torch.cuda.is_available() else "cpu", ) if condition == "standard": base.update(dict( condition="standard", lr=1e-3, weight_decay=1e-4, # Default 3000 epochs to match grokking config and the # paper's reported runs; previously defaulted to 300 which # made the standard baseline trivially under-trained # relative to grokking. See paper Limitations §M3. n_epochs=3000, init_scale=1.0, use_grokfast=False, )) elif condition == "grokking": base.update(dict( condition="grokking", lr=1e-3, weight_decay=5e-3, n_epochs=3000, init_scale=4.0, use_grokfast=True, grokfast_alpha=0.98, grokfast_lamb=2.0, )) return base # ────────────────────────────────────────────── # WILDS-SAFE METRICS # All handle the (imgs, labels, metadata) 3-tuple WILDS batch format. # ────────────────────────────────────────────── @torch.no_grad() def accuracy_wilds(model, loader, device, max_samples=None): model.eval() correct = total = 0 for batch in loader: imgs = batch[0].to(device) labels = batch[1].squeeze().long().to(device) preds = model(imgs).argmax(1) correct += (preds == labels).sum().item() total += labels.size(0) if max_samples and total >= max_samples: break return correct / max(total, 1) @torch.no_grad() def weight_norm_fn(model): return sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5 @torch.no_grad() def feature_rank_wilds(model, loader, device, n=300): model.eval() feats = [] def hook_fn(module, input, output): avg_pool = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1)) feats.append(avg_pool.view(avg_pool.size(0), -1).cpu()) hook = model.layer4[-1].register_forward_hook(hook_fn) count = 0 for batch in loader: model(batch[0].to(device)) count += batch[0].size(0) if count >= n: break hook.remove() if not feats: return float("nan") F_mat = torch.cat(feats)[:n] try: _, s, _ = torch.svd(F_mat) s = s / (s.sum() + 1e-10) return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item() except Exception: return float("nan") @torch.no_grad() def shortcut_ratio_wilds(model, loader, device, n_samples=200): """ Stain shortcut proxy: compare model confidence on center crop (tissue — causal features) vs. border region (stain — spurious). sc > 1.0 = relying on border stain more than tissue (shortcut) sc < 1.0 = relying on tissue center more than stain (causal) The transition from > 1.0 to < 1.0 during training is the attribution-level signature of the grokking transition. """ model.eval() cc, bc = [], [] count = 0 for batch in loader: if count >= n_samples: break imgs = batch[0].to(device) B, C, H, W = imgs.shape hs, he = H // 4, 3 * H // 4 ws, we = W // 4, 3 * W // 4 center = F.interpolate( imgs[:, :, hs:he, ws:we], size=(H, W), mode="bilinear", align_corners=False ) border = imgs.clone() border[:, :, hs:he, ws:we] = 0.0 cc.append(F.softmax(model(center), 1).max(1).values.mean().item()) bc.append(F.softmax(model(border), 1).max(1).values.mean().item()) count += imgs.size(0) cconf = float(np.mean(cc)) if cc else 0.5 bconf = float(np.mean(bc)) if bc else 0.5 return cconf, bconf def irm_penalty_wilds(model, envs, device): """ IRMv1 penalty across TRAINING hospital environments (H0-H2). Diagnostic version: uses create_graph=False, returns floats. Used as a monitoring metric only (logged per epoch). """ model.eval() penalties = [] for env in envs: w = torch.tensor(1.0, requires_grad=True, device=device) logits = model(env["x"]) * w loss = F.cross_entropy(logits, env["y"]) grad = torch.autograd.grad(loss, w, create_graph=False)[0] penalties.append(grad.item() ** 2) t = torch.tensor(penalties) return t.mean().item(), t.var().item() def irm_penalty_train_time(logits_list, y_list): """ IRMv1 penalty for use INSIDE the training loss (differentiable). Splits a batch by environment, computes per-env loss with a virtual scale variable, takes the squared gradient of each per-env loss w.r.t. that scale, returns the mean across envs. Args: logits_list: list of (per-env) logits tensors y_list: list of (per-env) label tensors Returns: scalar tensor (differentiable), the IRM penalty contribution. """ penalty = 0.0 n = 0 for logits, y in zip(logits_list, y_list): if logits.shape[0] == 0: continue scale = torch.tensor(1.0, requires_grad=True, device=logits.device) loss = F.cross_entropy(logits * scale, y) grad = torch.autograd.grad(loss, scale, create_graph=True)[0] penalty = penalty + grad ** 2 n += 1 if n == 0: return torch.tensor(0.0, device=logits_list[0].device) return penalty / n def eval_irm_penalty_wilds(model, id_val_loader, ood_test_loader, device): """ IRM penalty evaluated on HELD-OUT environments (H3 and H4). This avoids the measurement artifact of training on H0-H2 where loss→0. HIGH penalty = model relies on hospital-discriminating features = shortcuts. LOW penalty = model ignores hospital labels = causal features. """ model.eval() penalties = [] # Create environment views from eval data for loader, hospital_label in [ (id_val_loader, "H3"), (ood_test_loader, "H4"), ]: xs, ys = [], [] count = 0 with torch.no_grad(): for batch in loader: imgs = batch[0].to(device) labels = batch[1].squeeze().long().to(device) xs.append(model(imgs)) ys.append(labels) count += imgs.size(0) if count >= 500: break if xs: x = torch.cat(xs) y = torch.cat(ys) w = torch.tensor(1.0, requires_grad=True, device=device) logits = x * w loss = F.cross_entropy(logits, y) try: grad = torch.autograd.grad(loss, w, create_graph=False)[0] penalties.append(grad.item() ** 2) except: penalties.append(float("nan")) if penalties and not any(np.isnan(p) for p in penalties): return float(np.mean(penalties)), float(np.var(penalties)) else: return float("nan"), float("nan") # ────────────────────────────────────────────── # DATA # ────────────────────────────────────────────── class TransformWrapper: def __init__(self, dataset, transform): self.dataset = dataset self.transform = transform def __len__(self): return len(self.dataset) def __getitem__(self, idx): img, label, metadata = self.dataset[idx] return self.transform(img), label, metadata def get_dataloaders(cfg, data_root): transform = transforms.Compose([ transforms.Resize((cfg["img_size"], cfg["img_size"])), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets( root_dir=data_root, download=True) # Subsample training set torch.manual_seed(cfg["seed"]) indices = torch.randperm(len(train_ds))[:cfg["n_train"]] train_subset = Subset(train_ds, indices) # Wrap with TransformWrapper to apply transforms train_subset = TransformWrapper(train_subset, transform) id_val_ds = TransformWrapper(id_val_ds, transform) ood_test_ds = TransformWrapper(ood_test_ds, transform) train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"], shuffle=True, num_workers=0, pin_memory=True) id_val_loader = DataLoader(id_val_ds, batch_size=256, shuffle=False, num_workers=0, pin_memory=True) ood_test_loader = DataLoader(ood_test_ds, batch_size=256, shuffle=False, num_workers=0, pin_memory=True) return train_loader, id_val_loader, ood_test_loader, train_subset def get_hospital_environments(train_subset, device): """ Build IRM environments from ground-truth hospital labels. Returns list of {x, y} dicts — one per unique hospital in the subset. Hospitals in Camelyon17 train split: 0, 1, 2. """ loader = DataLoader(train_subset, batch_size=512, shuffle=False, num_workers=4) all_imgs, all_labels, all_meta = [], [], [] for imgs, labels, meta in loader: all_imgs.append(imgs) all_labels.append(labels.squeeze().long()) all_meta.append(meta) all_imgs = torch.cat(all_imgs) all_labels = torch.cat(all_labels) hospitals = torch.cat(all_meta)[:, 0].long() # field 0 = hospital ID envs = [] for h in torch.unique(hospitals): mask = hospitals == h n = mask.sum().item() envs.append({ "x": all_imgs[mask].to(device), "y": all_labels[mask].to(device), "hospital": int(h), }) pos_rate = all_labels[mask].float().mean().item() print(f" Env hospital={int(h)}: {n} samples, " f"positive rate={pos_rate:.2f}") return envs # ────────────────────────────────────────────── # MODEL # ────────────────────────────────────────────── def build_model(cfg): model = timm.create_model("resnet18", pretrained=False, num_classes=cfg["n_classes"]) if cfg["init_scale"] != 1.0: with torch.no_grad(): for name, p in model.named_parameters(): if "weight" in name and p.dim() > 1: p.data *= cfg["init_scale"] return model.to(cfg["device"]) # ────────────────────────────────────────────── # TRAIN # ────────────────────────────────────────────── def train(cfg, model, train_loader, id_val_loader, ood_test_loader, envs, optimizer, run_dir): criterion = nn.CrossEntropyLoss() grads_ema = None history = [] best_id_val = 0.0 best_ood = 0.0 peak_ood_epoch = None # Epoch where best_ood was achieved grok_epoch = None irm_base = None history_path = os.path.join(run_dir, "results", "history.json") grad_clip = cfg.get("grad_clip", 1.0) # Grokking detection parameters. # We watch OOD accuracy (H4), not ID val accuracy (H3). # ID val converges fast (expected). OOD is what should grok. plateau_window = 10 plateau_eps = 0.01 # Ungrokking early stopping parameters. # If OOD peaks then declines, stop at the peak rather than training to convergence. ood_patience = cfg.get("ood_patience", 20) # checkpoints to wait before stopping ood_min_delta = cfg.get("ood_min_delta", 0.01) # minimum improvement threshold use_ood_early_stop = cfg.get("use_ood_early_stop", False) print(f"\n{'='*60}") print(f" {cfg['condition'].upper()} | Camelyon17 v2 | {cfg['n_epochs']} epochs") print(f" WD={cfg['weight_decay']} | α={cfg['init_scale']} | n={cfg['n_train']}") print(f" Tracking: ID val (H3) + OOD test (H4) at every checkpoint") print(f" Grokking detection: watching OOD acc, not ID val acc") print(f" IRM envs: {len(envs)} hospitals") print(f"{'='*60}", flush=True) irm_weight = float(cfg.get("irm_weight", 0.0)) use_irm_in_loss = irm_weight > 0.0 if use_irm_in_loss: print(f" IRM-in-loss: ENABLED, alpha={irm_weight}", flush=True) else: print(f" IRM-in-loss: disabled (CE-only training; IRM penalty is diagnostic)", flush=True) for epoch in range(1, cfg["n_epochs"] + 1): # ── Train step ──────────────────────────────────────────────── model.train() loss_sum = n_b = 0 for imgs, labels, metadata in train_loader: imgs = imgs.to(cfg["device"]) labels = labels.squeeze().long().to(cfg["device"]) optimizer.zero_grad() logits = model(imgs) ce_loss = criterion(logits, labels) if use_irm_in_loss: # Split this batch by training hospital (H0/H1/H2) and # compute IRMv1 penalty as a differentiable scalar. hosp_ids = metadata[:, 0].long().to(cfg["device"]) logits_per_env, y_per_env = [], [] for h in [0, 1, 2]: mask = (hosp_ids == h) if mask.sum() < 2: continue logits_per_env.append(logits[mask]) y_per_env.append(labels[mask]) if len(logits_per_env) >= 2: irm_term = irm_penalty_train_time(logits_per_env, y_per_env) loss = ce_loss + irm_weight * irm_term else: loss = ce_loss else: loss = ce_loss loss.backward() if cfg.get("use_grokfast"): grads_ema = gradfilter_ema( model, grads_ema, alpha=cfg.get("grokfast_alpha", 0.98), lamb=cfg.get("grokfast_lamb", 2.0)) if grad_clip > 0: torch.nn.utils.clip_grad_norm_( model.parameters(), max_norm=grad_clip) optimizer.step() loss_sum += loss.item() n_b += 1 # ── Checkpoint metrics ──────────────────────────────────────── if epoch % cfg["log_every"] == 0 or epoch == 1: tr_acc = accuracy_wilds(model, train_loader, cfg["device"]) id_acc = accuracy_wilds(model, id_val_loader, cfg["device"]) ood_acc = accuracy_wilds(model, ood_test_loader, cfg["device"]) # KEY wn = weight_norm_fn(model) fr = feature_rank_wilds(model, id_val_loader, cfg["device"]) irm_m, irm_v = irm_penalty_wilds(model, envs, cfg["device"]) cconf, bconf = shortcut_ratio_wilds( model, id_val_loader, cfg["device"]) if irm_base is None: irm_base = irm_m # ── OOD grokking detection ──────────────────────────────── # Require sustained plateau in OOD acc before the jump. # The ID val acc plateau is expected and not grokking. if grok_epoch is None and len(history) >= plateau_window: last = history[-plateau_window:] ref = last[-1]["ood_acc"] flat = sum(1 for r in last if abs(r["ood_acc"] - ref) < plateau_eps) if flat >= plateau_window - 2 and ood_acc > best_ood + 0.05: grok_epoch = epoch irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100 print(f"\n *** OOD GROKKING at epoch {epoch} ***") print(f" OOD: {best_ood:.3f} → {ood_acc:.3f} | " f"IRM drop: {irm_drop:.1f}%", flush=True) if id_acc > best_id_val: best_id_val = id_acc if ood_acc > best_ood: best_ood = ood_acc peak_ood_epoch = epoch # Track when peak was achieved sc_ratio = min(bconf / (cconf + 1e-8), 10.0) # OOD gap: how much worse is OOD vs ID? # This should shrink at the grokking transition. ood_gap = id_acc - ood_acc row = dict( epoch = epoch, train_loss = loss_sum / n_b, train_acc = tr_acc, id_val_acc = id_acc, ood_acc = ood_acc, # ← primary grokking signal ood_gap = ood_gap, # ← should narrow at transition weight_norm = wn, feature_rank = fr, irm_mean = irm_m, irm_var = irm_v, center_conf = cconf, border_conf = bconf, shortcut_ratio = sc_ratio, grokking_detected = grok_epoch is not None, ) history.append(row) if wandb: wandb.log(row) with open(history_path, "w") as f: json.dump(history, f, indent=2) # Save periodic checkpoint for M1 analysis (every 200 epochs) if epoch % 200 == 0: ckpt_dir = os.path.join(run_dir, "checkpoints") os.makedirs(ckpt_dir, exist_ok=True) ckpt_path = os.path.join(ckpt_dir, f"ep{epoch:05d}.pt") torch.save(model.state_dict(), ckpt_path) print(f" ✓ Checkpoint → ep{epoch:05d}.pt", flush=True) # ── OOD-aware early stopping (if ungrokking detected) ─────── # If OOD peaks then declines, stop at the peak rather than full epochs. if use_ood_early_stop and peak_ood_epoch is not None and len(history) >= ood_patience: recent_ood = [r["ood_acc"] for r in history[-ood_patience:]] ood_trend = max(recent_ood) - min(recent_ood) if ood_acc < best_ood - ood_min_delta: print(f"\n *** EARLY STOP (OOD declining) at epoch {epoch} ***", flush=True) print(f" Peak OOD: {best_ood:.4f} at epoch {peak_ood_epoch}", flush=True) print(f" Current: {ood_acc:.4f} ({ood_acc-best_ood:+.4f})", flush=True) # Save peak checkpoint separately for clinical deployment if peak_ood_epoch and peak_ood_epoch % 200 == 0: peak_src = os.path.join(run_dir, "checkpoints", f"ep{peak_ood_epoch:05d}.pt") peak_dst = os.path.join(run_dir, "checkpoints", "peak_ood.pt") if os.path.exists(peak_src): import shutil shutil.copy(peak_src, peak_dst) print(f" Saved peak → checkpoints/peak_ood.pt", flush=True) break # Exit training loop print(f" ep {epoch:5d} | " f"tr {tr_acc:.3f} | " f"id {id_acc:.3f} | " f"ood {ood_acc:.3f} | " f"gap {ood_gap:+.3f} | " # + means OOD worse than ID f"‖W‖ {wn:.1f} | " f"rank {fr:.1f} | " f"IRM {irm_m:.4f} | " f"sc {sc_ratio:.2f}x", flush=True) # ── Final summary ───────────────────────────────────────────────── # One final OOD eval at the very end final_ood = accuracy_wilds(model, ood_test_loader, cfg["device"]) if wandb: wandb.log({"final_ood_acc": final_ood, "grokking_epoch": grok_epoch or -1}) # Decision numbers irm_drop_pct = float("nan") irm_drop_ep = epoch_gap = -1 if history: irm0 = history[0]["irm_mean"] irm_min = min(r["irm_mean"] for r in history) if irm0: irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100 if len(history) > 1: biggest = 0.0 for prev, cur in zip(history[:-1], history[1:]): d = abs(cur["irm_mean"] - prev["irm_mean"]) if d > biggest: biggest = d irm_drop_ep = cur["epoch"] if grok_epoch and irm_drop_ep > 0: epoch_gap = abs(grok_epoch - irm_drop_ep) # OOD grokking: did OOD acc improve significantly after ID convergence? # Measure: max OOD acc in last 20% of training vs. OOD acc when ID # first plateaued (epoch ~200-300 for standard training). ood_early = np.mean([r["ood_acc"] for r in history[:5]]) if history else 0 ood_late = np.mean([r["ood_acc"] for r in history[-5:]]) if history else 0 ood_improvement = ood_late - ood_early # Ungrokking detection: did OOD collapse after peaking? ood_delta = final_ood - best_ood # Negative = ungrokking summary = dict( run_id = cfg["run_id"], condition = cfg["condition"], n_train = cfg["n_train"], seed = cfg["seed"], best_id_val = best_id_val, best_ood = best_ood, peak_ood_epoch = peak_ood_epoch or -1, # When peak was achieved final_ood = final_ood, ood_delta = ood_delta, # final - best (ungrokking signal) ood_improvement = ood_improvement, # ← key: did OOD grok? grokking_epoch = grok_epoch or -1, irm_drop_pct = irm_drop_pct, irm_drop_epoch = irm_drop_ep, epoch_gap = epoch_gap, final_weight_norm = history[-1]["weight_norm"] if history else None, final_feature_rank= history[-1]["feature_rank"] if history else None, final_irm = history[-1]["irm_mean"] if history else None, final_shortcut_ratio = history[-1]["shortcut_ratio"] if history else None, final_ood_gap = history[-1]["ood_gap"] if history else None, ) with open(os.path.join(run_dir, "results", "summary.json"), "w") as f: json.dump(summary, f, indent=2) torch.save(model.state_dict(), os.path.join(run_dir, "checkpoints", "final.pt")) print(f"\n Best ID val (H3): {best_id_val:.4f}") print(f" Best OOD (H4): {best_ood:.4f}") print(f" OOD improvement: {ood_improvement:+.4f} ← did OOD grok?") print(f" Grokking at: {grok_epoch}") print(f" IRM drop: {irm_drop_pct:.1f}%", flush=True) return history # ────────────────────────────────────────────── # MAIN # ────────────────────────────────────────────── def main(): p = argparse.ArgumentParser() p.add_argument("--condition", default="grokking", choices=["standard", "grokking"]) p.add_argument("--n_train", type=int, default=300) p.add_argument("--seed", type=int, default=42) p.add_argument("--log_every", type=int, default=50) p.add_argument("--wandb_project", default="causalgrok") p.add_argument("--wandb_mode", default="offline", choices=["online", "offline", "disabled"]) p.add_argument("--run_dir", default=None) p.add_argument("--data_root", default="data/wilds") p.add_argument("--weight_decay", type=float, default=None) p.add_argument("--init_scale", type=float, default=None) p.add_argument("--n_epochs", type=int, default=None) p.add_argument("--lr", type=float, default=None) p.add_argument("--grokfast", choices=["on", "off"], default=None) p.add_argument("--grad_clip", type=float, default=1.0) p.add_argument("--irm_weight", type=float, default=0.0, help="IRMv1 penalty weight added to training loss " "(0 = pure cross-entropy / diagnostic-only IRM).") args = p.parse_args() cfg = get_config(args.condition) cfg.update(n_train=args.n_train, seed=args.seed, log_every=args.log_every, grad_clip=args.grad_clip) if args.weight_decay is not None: cfg["weight_decay"] = args.weight_decay if args.init_scale is not None: cfg["init_scale"] = args.init_scale if args.n_epochs is not None: cfg["n_epochs"] = args.n_epochs if args.lr is not None: cfg["lr"] = args.lr if args.grokfast is not None: cfg["use_grokfast"] = (args.grokfast == "on") cfg["irm_weight"] = args.irm_weight if cfg["device"] == "cuda": torch.set_float32_matmul_precision("high") torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.manual_seed(cfg["seed"]) np.random.seed(cfg["seed"]) if args.run_dir is None: run_dir, run_id = make_run_dir( ["camelyon_v2", cfg["condition"], f"n{cfg['n_train']}", f"s{cfg['seed']}"]) else: run_dir = args.run_dir ensure_run_dir(run_dir) run_id = os.path.basename(os.path.normpath(run_dir)) cfg["run_id"] = run_id cfg["run_dir"] = run_dir save_config(cfg, run_dir) if wandb: wandb.init(project=args.wandb_project, config=cfg, name=run_id, mode=args.wandb_mode, dir=run_dir) print(f"\nDevice: {cfg['device']}") print(f"Run ID: {run_id}") print(f"Started: {datetime.now(timezone.utc).isoformat()}", flush=True) train_loader, id_val_loader, ood_test_loader, train_subset = \ get_dataloaders(cfg, args.data_root) envs = get_hospital_environments(train_subset, cfg["device"]) model = build_model(cfg) print(f"Train: {len(train_subset)} | " f"ID val (H3): {len(id_val_loader.dataset)} | " f"OOD test (H4): {len(ood_test_loader.dataset)}") print(f"Params: {sum(p.numel() for p in model.parameters()):,}", flush=True) optimizer = torch.optim.AdamW( model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"]) t0 = time.time() train(cfg, model, train_loader, id_val_loader, ood_test_loader, envs, optimizer, run_dir) print(f"\nWall time: {(time.time()-t0)/60:.1f} min", flush=True) if wandb: wandb.finish() if __name__ == "__main__": main()