"""
CausalGrok — Camelyon17 Training Loop
Nilesh

Camelyon17: 5-hospital histopathology dataset with natural stain-color shortcuts.
Real hospital environment labels enable true IRM without pseudo-environments.

Run via:
    python -m experiments.causalgrok_camelyon --condition grokking --n_train 300
"""

from __future__ import annotations

import argparse
import json
import os
import time
from datetime import datetime, timezone

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import timm
import wandb

from utils.metrics import (
    accuracy,
    weight_norm,
    feature_rank,
    irm_penalty,
    shortcut_ratio,
)
from utils.grokfast import gradfilter_ema
from utils.camelyon_data import get_camelyon_subsets
from utils.run_dir import make_run_dir, ensure_run_dir, save_config


# ──────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────

def get_config(condition):
    """Return the hyper-parameter dict for one experimental condition.

    Args:
        condition: ``"standard"`` or ``"grokking"``.

    Returns:
        A dict holding the shared defaults plus the condition-specific
        optimisation settings (``lr``, ``weight_decay``, ``n_epochs``,
        ``init_scale``, ``use_grokfast`` and, for grokking, the Grokfast
        ``alpha`` / ``lamb`` values).

    Raises:
        ValueError: if ``condition`` is not one of the two known names.
    """
    base = dict(
        seed=42,
        n_train=300,
        batch_size=32,
        img_size=96,
        n_classes=2,
        log_every=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    if condition == "standard":
        base.update(dict(
            condition="standard",
            lr=1e-3,
            weight_decay=1e-4,
            n_epochs=300,
            init_scale=1.0,
            use_grokfast=False,
        ))
    elif condition == "grokking":
        base.update(dict(
            condition="grokking",
            lr=1e-3,
            weight_decay=5e-3,
            n_epochs=3000,
            init_scale=4.0,
            use_grokfast=True,
            grokfast_alpha=0.98,
            grokfast_lamb=2.0,
        ))
    else:
        # Fail fast: a config without lr / n_epochs would only surface later
        # as an opaque KeyError deep inside main() or train().
        raise ValueError(
            f"Unknown condition {condition!r}; expected 'standard' or 'grokking'"
        )
    return base


# ──────────────────────────────────────────────
# DATA
# ──────────────────────────────────────────────

class TransformWrapper:
    """Apply an image transform on top of a dataset yielding 3-tuples.

    Each item of the wrapped dataset is unpacked as
    ``(img, label, metadata)``; only ``img`` is passed through the transform.
    Defined at module level so instances stay picklable if DataLoader
    workers are ever enabled.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label, metadata = self.dataset[idx]
        return self.transform(img), label, metadata


def get_dataloaders(cfg, data_root):
    """Load Camelyon17 via WILDS. Hospitals become IRM environments.

    Args:
        cfg: config dict; reads ``img_size``, ``seed``, ``n_train``,
            ``batch_size`` and ``device``.
        data_root: directory passed to ``get_camelyon_subsets`` as
            ``root_dir`` (called with ``download=True``).

    Returns:
        ``(train_loader, id_val_loader, ood_test_loader, train_subset)``
        where ``train_subset`` is the seeded random subsample of the
        training split that ``train_loader`` iterates over.
    """
    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=True
    )

    transform = transforms.Compose([
        transforms.Resize((cfg["img_size"], cfg["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Wrap datasets with transforms
    train_ds = TransformWrapper(train_ds, transform)
    id_val_ds = TransformWrapper(id_val_ds, transform)
    ood_test_ds = TransformWrapper(ood_test_ds, transform)

    # Subsample training set to n_train (deterministic by seed).
    # NOTE: this re-seeds the *global* torch RNG; it is kept that way so
    # that everything drawn afterwards stays reproducible across runs.
    torch.manual_seed(cfg["seed"])
    # Plain Python ints are the safest index type to hand down through
    # Subset -> TransformWrapper -> the underlying dataset.
    indices = torch.randperm(len(train_ds))[:cfg["n_train"]].tolist()
    train_subset = Subset(train_ds, indices)

    # Pinned host memory only helps host->GPU copies; on CPU it just warns.
    pin = str(cfg.get("device", "cpu")).startswith("cuda")

    train_loader = DataLoader(
        train_subset, batch_size=cfg["batch_size"],
        shuffle=True, num_workers=0, pin_memory=pin
    )
    id_val_loader = DataLoader(
        id_val_ds, batch_size=256,
        shuffle=False, num_workers=0, pin_memory=pin
    )
    ood_test_loader = DataLoader(
        ood_test_ds, batch_size=256,
        shuffle=False, num_workers=0, pin_memory=pin
    )

    return train_loader, id_val_loader, ood_test_loader, train_subset


def get_hospital_environments(train_subset, device):
    """
    Build IRM environments from hospital labels.
    Returns list of {x, y} dicts — format irm_penalty expects.

    Args:
        train_subset: dataset whose items are ``(img, label, metadata)``.
        device: device the per-environment tensors are moved to.

    Returns:
        One dict per distinct hospital ID with keys ``"x"`` (images),
        ``"y"`` (labels) and ``"hospital"`` (int ID). Empty list when the
        subset contains no samples.
    """
    # Get full images + labels + metadata for the subset
    loader = DataLoader(train_subset, batch_size=512, shuffle=False, num_workers=0)

    all_imgs, all_labels, all_meta = [], [], []
    for batch in loader:
        imgs, labels, meta = batch[0], batch[1], batch[2]
        all_imgs.append(imgs)
        # reshape(-1) rather than squeeze(): squeeze() turns a batch of one
        # sample into a 0-dim tensor, which torch.cat refuses to concatenate.
        all_labels.append(labels.reshape(-1).long())
        all_meta.append(meta)

    if not all_imgs:
        # torch.cat([]) raises; an empty subset simply has no environments.
        return []

    all_imgs = torch.cat(all_imgs)      # [N, 3, H, W]
    all_labels = torch.cat(all_labels)  # [N]
    all_meta = torch.cat(all_meta)      # [N, n_meta_fields]

    hospitals = all_meta[:, 0].long()   # hospital ID is field 0

    envs = []
    for h in torch.unique(hospitals):
        mask = hospitals == h
        envs.append({
            "x": all_imgs[mask].to(device),
            "y": all_labels[mask].to(device),
            "hospital": int(h),
        })
        # Every environment here comes from the training subset, so report
        # them all. The previous `h <= 2` filter hid any training hospital
        # with a larger ID (WILDS Camelyon17 reportedly trains on 0, 3, 4 —
        # TODO confirm against get_camelyon_subsets).
        print(f" Env hospital={int(h)}: {mask.sum().item()} samples")

    return envs


# ──────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────

def build_model(cfg):
    """Create a randomly initialised timm ResNet-18 and move it to the device.

    When ``cfg["init_scale"]`` differs from 1.0, every parameter whose name
    contains ``"weight"`` and that has more than one dimension is multiplied
    in place by that factor.
    """
    model = timm.create_model("resnet18", pretrained=False,
                              num_classes=cfg["n_classes"])
    scale = cfg["init_scale"]
    if scale != 1.0:
        with torch.no_grad():
            for name, p in model.named_parameters():
                if "weight" in name and p.dim() > 1:
                    p.mul_(scale)
    return model.to(cfg["device"])


# ──────────────────────────────────────────────
# WILDS-COMPATIBLE METRICS (handle 3-tuple batches)
# ──────────────────────────────────────────────

@torch.no_grad()
def accuracy_wilds(model, loader, device):
    """Top-1 accuracy of ``model`` over ``loader``.

    Batches are indexed positionally (``batch[0]`` images, ``batch[1]``
    labels). The model is switched to eval mode and left there.

    Returns:
        Fraction of correct predictions, or NaN when the loader yields no
        samples.
    """
    model.eval()
    correct = total = 0
    for batch in loader:
        imgs = batch[0].to(device)
        # reshape(-1) keeps a 1-D target even when the batch holds one sample.
        labels = batch[1].reshape(-1).long().to(device)
        preds = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        return float("nan")
    return correct / total


@torch.no_grad()
def feature_rank_wilds(model, loader, device, n=200):
    """Entropy-based rank of the features produced by ``model.layer4[-1]``.

    A forward hook average-pools the output of ``model.layer4[-1]`` to one
    vector per sample. Over (up to) the first ``n`` samples the singular
    values of the feature matrix are normalised to sum to one and the
    exponential of their Shannon entropy is returned.

    Returns:
        The rank estimate as a float, or NaN when no features were
        collected or the SVD fails.
    """
    model.eval()
    feats = []

    def hook_fn(module, _inputs, output):
        pooled = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
        feats.append(pooled.view(pooled.size(0), -1).cpu())

    hook = model.layer4[-1].register_forward_hook(hook_fn)
    try:
        count = 0
        for batch in loader:
            model(batch[0].to(device))
            count += batch[0].size(0)
            if count >= n:
                break
    finally:
        # Always detach the hook, otherwise a failed forward pass would leave
        # it attached and it would keep appending on later calls.
        hook.remove()

    if not feats:
        return float("nan")

    F_mat = torch.cat(feats)[:n]
    try:
        # Only the singular values are needed; svdvals replaces the
        # deprecated torch.svd. LinAlgError is a RuntimeError subclass.
        s = torch.linalg.svdvals(F_mat)
    except RuntimeError:
        return float("nan")
    s = s / (s.sum() + 1e-10)
    return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item()


@torch.no_grad()
def shortcut_ratio_wilds(model, loader, device, n_samples=200):
    """Compare model confidence on centre-only vs. border-only images.

    For each batch (until ``n_samples`` images have been seen) two variants
    are built: the central half-height/half-width crop bilinearly upsampled
    back to full size, and the full image with that central region zeroed.
    The mean max-softmax probability is recorded for each variant.

    Returns:
        ``(center_conf, border_conf)`` averaged over the processed batches,
        or ``(1.0, 1.0)`` when the loader yields nothing.
    """
    import torch.nn.functional as F

    model.eval()
    cc, bc = [], []
    count = 0
    for batch in loader:
        if count >= n_samples:
            break
        imgs = batch[0].to(device)
        B, C, H, W = imgs.shape
        hs, he, ws, we = H // 4, 3 * H // 4, W // 4, 3 * W // 4

        center = F.interpolate(imgs[:, :, hs:he, ws:we], (H, W),
                               mode='bilinear', align_corners=False)
        border = imgs.clone()
        border[:, :, hs:he, ws:we] = 0.

        cc.append(F.softmax(model(center), 1).max(1).values.mean().item())
        bc.append(F.softmax(model(border), 1).max(1).values.mean().item())
        count += imgs.size(0)

    if not cc or not bc:
        return 1.0, 1.0
    return float(np.mean(cc)), float(np.mean(bc))


# ──────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────

def _irm_decision_numbers(history, grok_epoch):
    """Derive (irm_drop_pct, irm_drop_epoch, epoch_gap) from the logged history."""
    irm_drop_pct = float("nan")
    irm_drop_ep = -1
    epoch_gap = -1

    if not history:
        return irm_drop_pct, irm_drop_ep, epoch_gap

    irm0 = history[0]["irm_mean"]
    irm_min = min(r["irm_mean"] for r in history)
    if irm0:
        irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100.0

    if len(history) > 1:
        biggest = 0.0
        for prev, cur in zip(history[:-1], history[1:]):
            # NOTE(review): this picks the largest *absolute* change between
            # consecutive log points, so a sharp rise also qualifies as the
            # "drop" epoch — confirm that is intended.
            d = abs(cur["irm_mean"] - prev["irm_mean"])
            if d > biggest:
                biggest = d
                irm_drop_ep = cur["epoch"]

    if grok_epoch and irm_drop_ep > 0:
        epoch_gap = abs(grok_epoch - irm_drop_ep)

    return irm_drop_pct, irm_drop_ep, epoch_gap


def train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          hospital_envs, optimizer, run_dir):
    """Run the training loop, log metrics, and write results to ``run_dir``.

    Every ``cfg["log_every"]`` epochs (and at epoch 1) the model is
    evaluated; the metrics row is appended to the history, sent to wandb,
    and the whole history is re-written to ``results/history.json``.
    After the last epoch the OOD test accuracy is computed, and
    ``results/summary.json`` plus ``checkpoints/final.pt`` are written.

    Grokking is flagged at the first log point where at least
    ``plateau_window - 2`` of the previous ``plateau_window`` logged
    validation accuracies lie within ``plateau_eps`` of the most recent
    one and the current validation accuracy beats the best so far by
    more than 0.05.

    Args:
        cfg: config dict (reads ``condition``, ``n_epochs``,
            ``weight_decay``, ``init_scale``, ``device``, ``log_every``,
            ``n_train``, ``seed`` and the optional ``grad_clip``,
            ``use_grokfast``, ``grokfast_alpha``, ``grokfast_lamb``,
            ``run_id``).
        model: network to optimise.
        train_loader: loader yielding ``(imgs, labels, metadata)``.
        id_val_loader: in-distribution validation loader.
        ood_test_loader: out-of-distribution test loader.
        hospital_envs: environments handed to ``irm_penalty``.
        optimizer: optimiser already bound to ``model``'s parameters.
        run_dir: output directory for this run.

    Returns:
        The list of logged metric rows.
    """
    criterion = nn.CrossEntropyLoss()
    grads_ema = None
    history = []
    best_val = 0.0
    grok_epoch = None
    irm_base = None

    print(f"\n{'='*55}")
    print(f" {cfg['condition'].upper()} | {cfg['n_epochs']} epochs | "
          f"WD={cfg['weight_decay']} | α={cfg['init_scale']}")
    print(f" run_dir: {run_dir}")
    print(f"{'='*55}", flush=True)

    # Create the output folders here so train() does not depend on the
    # caller having prepared the run directory layout.
    results_dir = os.path.join(run_dir, "results")
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    history_path = os.path.join(results_dir, "history.json")
    grad_clip = cfg.get("grad_clip", 1.0)
    plateau_window = 10
    plateau_eps = 0.01
    device = cfg["device"]

    for epoch in range(1, cfg["n_epochs"] + 1):
        model.train()
        loss_sum = 0.0
        n_b = 0

        for imgs, labels, metadata in train_loader:
            imgs = imgs.to(device)
            # reshape(-1): a trailing batch of one sample must stay 1-D for
            # CrossEntropyLoss (squeeze() would make it 0-dim).
            labels = labels.reshape(-1).long().to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()

            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0)
                )

            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=grad_clip)

            optimizer.step()
            loss_sum += loss.item()
            n_b += 1

        if epoch % cfg["log_every"] == 0 or epoch == 1:
            tr_acc = accuracy_wilds(model, train_loader, device)
            val_acc = accuracy_wilds(model, id_val_loader, device)
            wn = weight_norm(model)
            fr = feature_rank_wilds(model, id_val_loader, device)
            irm_m, irm_v = irm_penalty(model, hospital_envs, device)
            cconf, bconf = shortcut_ratio_wilds(model, id_val_loader, device)

            if irm_base is None:
                irm_base = irm_m

            # Robust grokking detection
            if grok_epoch is None and len(history) >= plateau_window:
                last = history[-plateau_window:]
                ref = last[-1]["val_acc"]
                flat = sum(1 for r in last
                           if abs(r["val_acc"] - ref) < plateau_eps)
                if flat >= plateau_window - 2 and val_acc > best_val + 0.05:
                    grok_epoch = epoch
                    irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                    print(f"\n *** GROKKING at epoch {epoch} ***")
                    print(f" Val: {best_val:.3f}→{val_acc:.3f} | IRM drop: {irm_drop:.1f}%",
                          flush=True)

            if val_acc > best_val:
                best_val = val_acc

            # Capped so a near-zero centre confidence cannot blow the ratio up.
            sc_ratio = min(bconf / (cconf + 1e-8), 10.0)

            row = dict(
                epoch=epoch,
                # max(..., 1) avoids ZeroDivisionError on an empty loader.
                train_loss=loss_sum / max(n_b, 1),
                train_acc=tr_acc,
                val_acc=val_acc,
                weight_norm=wn,
                feature_rank=fr,
                irm_mean=irm_m,
                irm_var=irm_v,
                center_conf=cconf,
                border_conf=bconf,
                shortcut_ratio=sc_ratio,
                grokking_detected=grok_epoch is not None
            )
            history.append(row)
            wandb.log(row)

            # Re-written in full each time so a crash loses at most one row.
            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)

            print(f" ep {epoch:5d} | tr {tr_acc:.3f} | vl {val_acc:.3f} | "
                  f"‖W‖ {wn:.1f} | rank {fr:.1f} | "
                  f"IRM {irm_m:.4f} | sc {sc_ratio:.2f}x", flush=True)

    ood_test_acc = accuracy_wilds(model, ood_test_loader, device)
    wandb.log({"ood_test_acc": ood_test_acc, "grokking_epoch": grok_epoch or -1})

    # Decision numbers
    irm_drop_pct, irm_drop_ep, epoch_gap = _irm_decision_numbers(history, grok_epoch)

    last_row = history[-1] if history else None
    summary = dict(
        # .get(): run_id is only set by main(); train() may be called directly.
        run_id=cfg.get("run_id"),
        condition=cfg["condition"],
        n_train=cfg["n_train"],
        seed=cfg["seed"],
        id_val_acc=best_val,
        ood_test_acc=ood_test_acc,
        grokking_epoch=grok_epoch if grok_epoch else -1,
        irm_drop_pct=irm_drop_pct,
        irm_drop_epoch=irm_drop_ep,
        epoch_gap=epoch_gap,
        final_weight_norm=last_row["weight_norm"] if last_row else None,
        final_feature_rank=last_row["feature_rank"] if last_row else None,
        final_irm=last_row["irm_mean"] if last_row else None,
        final_shortcut_ratio=last_row["shortcut_ratio"] if last_row else None,
    )
    with open(os.path.join(results_dir, "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    ckpt_path = os.path.join(ckpt_dir, "final.pt")
    torch.save(model.state_dict(), ckpt_path)

    print(f"\n ID Val acc: {best_val:.4f} | OOD Test acc: {ood_test_acc:.4f} | Grokking at: {grok_epoch}")
    print(f" History → {history_path}")
    print(f" Checkpoint → {ckpt_path}", flush=True)

    return history


# ──────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────

def main():
    """Parse CLI arguments, assemble the config, and launch one training run."""
    p = argparse.ArgumentParser()
    p.add_argument("--condition", default="grokking",
                   choices=["standard", "grokking"])
    p.add_argument("--n_train", type=int, default=300)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--log_every", type=int, default=50)
    p.add_argument("--wandb_project", default="causalgrok")
    p.add_argument("--wandb_mode", default="online",
                   choices=["online", "offline", "disabled"])
    p.add_argument("--run_dir", default=None)
    p.add_argument("--data_root", default="data/wilds")
    # Per-knob overrides
    p.add_argument("--weight_decay", type=float, default=None)
    p.add_argument("--init_scale", type=float, default=None)
    p.add_argument("--n_epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--grokfast", choices=["on", "off"], default=None)
    p.add_argument("--grad_clip", type=float, default=1.0)
    args = p.parse_args()

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed, log_every=args.log_every)

    # CLI overrides
    if args.weight_decay is not None:
        cfg["weight_decay"] = args.weight_decay
    if args.init_scale is not None:
        cfg["init_scale"] = args.init_scale
    if args.n_epochs is not None:
        cfg["n_epochs"] = args.n_epochs
    if args.lr is not None:
        cfg["lr"] = args.lr
    if args.grokfast is not None:
        cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["grad_clip"] = args.grad_clip

    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is None:
        parts = [cfg["condition"], f"n{cfg['n_train']}", f"s{cfg['seed']}"]
        run_dir, run_id = make_run_dir(parts)
    else:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        # The run ID is the last path component of the supplied directory.
        run_id = os.path.basename(os.path.normpath(run_dir))

    cfg["run_id"] = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    wandb.init(project=args.wandb_project, config=cfg, name=run_id,
               mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice: {cfg['device']}")
    print(f"Run ID: {run_id}")
    print(f"Started (UTC): {datetime.now(timezone.utc).isoformat()}", flush=True)

    train_loader, id_val_loader, ood_test_loader, train_subset = get_dataloaders(
        cfg, args.data_root
    )
    hospital_envs = get_hospital_environments(train_subset, cfg["device"])
    model = build_model(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg["lr"],
                                  weight_decay=cfg["weight_decay"])

    print(f"Train: {len(train_subset)} | ID Val: {len(id_val_loader.dataset)} | "
          f"OOD Test: {len(ood_test_loader.dataset)}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Hospital environments: {len(hospital_envs)} envs", flush=True)

    t0 = time.time()
    train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          hospital_envs, optimizer, run_dir)
    print(f"\nWall time: {(time.time() - t0) / 60:.1f} min", flush=True)

    wandb.finish()


if __name__ == "__main__":
    main()