"""
CausalGrok — Camelyon17 Training Loop v2
Nilesh

KEY CHANGE FROM v1:
OOD test accuracy (H4 — the unseen hospital) is now tracked at EVERY
checkpoint, not just at the end. Grokking detection watches OOD accuracy,
not ID validation accuracy, because OOD accuracy is the signal of interest.

The paper's claim: after ID accuracy converges (fast, as expected), the model
undergoes a delayed phase transition in OOD generalization — grokking the
cross-hospital invariant causal features. This co-occurs with a drop in the
IRM penalty. That is the grokking we care about for clinical deployment.

Two curves to watch:
  id_val_acc (H3 ID val)  — converges fast, expected ~0.86 by epoch 50
  ood_acc    (H4 OOD test) — should plateau, then JUMP (the grokking)

Run via:
  python -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 300
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import time |
| from datetime import datetime, timezone |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader, Subset |
| import timm |
| try: |
| import wandb |
| except ImportError: |
| wandb = None |
|
|
| from utils.grokfast import gradfilter_ema |
| from utils.camelyon_data import get_camelyon_subsets |
| from utils.run_dir import make_run_dir, ensure_run_dir, save_config |
|
|
|
|
| |
| |
| |
|
|
def get_config(condition):
    """Return the hyper-parameter dict for a named training condition.

    Args:
        condition: ``"standard"`` (AdamW baseline, weight decay 1e-4, default
            init scale, no Grokfast) or ``"grokking"`` (weight decay 5e-3,
            4x init scale, Grokfast EMA filter with alpha=0.98, lamb=2.0).

    Returns:
        A fresh dict with the shared settings (seed, n_train, batch_size,
        img_size, n_classes, log_every, device) plus the condition-specific
        optimiser / initialisation settings.

    Raises:
        ValueError: if ``condition`` is not a known condition. Previously an
            unknown name returned a partial config (no ``lr``, ``n_epochs``,
            ...) that only failed later with a ``KeyError``.
    """
    base = dict(
        seed=42, n_train=300, batch_size=32, img_size=96,
        n_classes=2, log_every=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    if condition == "standard":
        base.update(
            condition="standard",
            lr=1e-3, weight_decay=1e-4,
            n_epochs=3000, init_scale=1.0, use_grokfast=False,
        )
    elif condition == "grokking":
        base.update(
            condition="grokking",
            lr=1e-3, weight_decay=5e-3,
            n_epochs=3000, init_scale=4.0, use_grokfast=True,
            grokfast_alpha=0.98, grokfast_lamb=2.0,
        )
    else:
        raise ValueError(
            f"unknown condition {condition!r}; "
            "expected 'standard' or 'grokking'")
    return base
|
|
|
|
| |
| |
| |
| |
|
|
@torch.no_grad()
def accuracy_wilds(model, loader, device, max_samples=None):
    """Top-1 accuracy of ``model`` over the batches yielded by ``loader``.

    Args:
        model: classifier returning ``(batch, n_classes)`` logits. It is put
            in eval mode and left there.
        loader: iterable of batches whose first two items are images and
            integer labels. Extra items, such as WILDS metadata, are ignored.
        device: device the images and labels are moved to.
        max_samples: if truthy, stop after the first batch at which the
            running sample count reaches this value. Whole batches are
            counted, so slightly more samples may be scored.

    Returns:
        Fraction of correct predictions, or 0.0 when the loader is empty.
    """
    model.eval()
    correct = total = 0
    for batch in loader:
        imgs = batch[0].to(device)
        # reshape(-1) rather than squeeze(): squeeze() turns a size-1 batch
        # into a 0-d tensor, and labels.size(0) then raises IndexError.
        labels = batch[1].reshape(-1).long().to(device)
        preds = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        if max_samples and total >= max_samples:
            break
    return correct / max(total, 1)
|
|
|
|
@torch.no_grad()
def weight_norm_fn(model):
    """Global L2 norm of all of ``model``'s parameters, as a Python float.

    Each parameter's L2 norm is squared and accumulated in parameter order.
    The square root of the total is returned, or 0.0 for a parameter-less
    model.
    """
    squared_total = 0
    for param in model.parameters():
        squared_total += param.data.norm(2).item() ** 2
    return squared_total ** 0.5
|
|
|
|
@torch.no_grad()
def feature_rank_wilds(model, loader, device, n=300):
    """Effective rank of pooled ``model.layer4[-1]`` features on <= ``n`` samples.

    A forward hook on ``model.layer4[-1]`` (a ResNet-style attribute layout is
    assumed) global-average-pools each activation map to one vector per
    sample. Batches are fed until at least ``n`` samples are seen. The first
    ``n`` rows form the feature matrix. The effective rank is ``exp(H(p))``,
    where ``p`` are the singular values normalised to sum to 1 and ``H`` is
    their Shannon entropy.

    Returns:
        The effective rank as a float, or NaN if no features were collected
        or the SVD fails.
    """
    model.eval()
    feats = []

    def hook_fn(module, inputs, output):
        pooled = F.adaptive_avg_pool2d(output, (1, 1))
        feats.append(pooled.view(pooled.size(0), -1).cpu())

    hook = model.layer4[-1].register_forward_hook(hook_fn)
    try:
        count = 0
        for batch in loader:
            model(batch[0].to(device))
            count += batch[0].size(0)
            if count >= n:
                break
    finally:
        # Always detach the hook. If the forward pass raised, a leaked hook
        # would keep appending features on every later forward pass,
        # including training steps.
        hook.remove()

    if not feats:
        return float("nan")
    feat_mat = torch.cat(feats)[:n]
    try:
        # svdvals replaces the deprecated torch.svd; only singular values are needed.
        s = torch.linalg.svdvals(feat_mat)
    except RuntimeError:
        # e.g. SVD non-convergence (LinAlgError subclasses RuntimeError)
        return float("nan")
    s = s / (s.sum() + 1e-10)
    return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item()
|
|
|
|
@torch.no_grad()
def shortcut_ratio_wilds(model, loader, device, n_samples=200):
    """Mean max-softmax confidence on centre-only vs. border-only views.

    Batches are consumed until at least ``n_samples`` images have been seen.
    For each batch, two views of every image are scored:

    * centre view: the central window (rows H/4..3H/4, cols W/4..3W/4)
      cropped and bilinearly upsampled back to ``H x W``;
    * border view: the image with that central window set to 0.

    The caller forms the stain-shortcut proxy ``border_conf / center_conf``.
    Above 1.0, the model is more confident from the border (stain, presumed
    spurious) than from the tissue centre (presumed causal). Below 1.0, the
    reverse. A move from > 1.0 to < 1.0 during training is the expected
    attribution-level signature of the grokking transition.

    Returns:
        ``(center_conf, border_conf)``: each is the unweighted mean of the
        per-batch mean confidences, or 0.5 if no batch was processed.
    """
    model.eval()
    center_scores = []
    border_scores = []
    seen = 0

    def _mean_confidence(views):
        probs = F.softmax(model(views), 1)
        return probs.max(1).values.mean().item()

    for batch in loader:
        if seen >= n_samples:
            break
        images = batch[0].to(device)
        _, _, height, width = images.shape
        top, bottom = height // 4, 3 * height // 4
        left, right = width // 4, 3 * width // 4

        centre_view = F.interpolate(
            images[:, :, top:bottom, left:right],
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        border_view = images.clone()
        border_view[:, :, top:bottom, left:right] = 0.0

        center_scores.append(_mean_confidence(centre_view))
        border_scores.append(_mean_confidence(border_view))
        seen += images.size(0)

    center_conf = float(np.mean(center_scores)) if center_scores else 0.5
    border_conf = float(np.mean(border_scores)) if border_scores else 0.5
    return center_conf, border_conf
|
|
|
|
def irm_penalty_wilds(model, envs, device):
    """Diagnostic IRMv1 penalty over the training-hospital environments (H0-H2).

    Each environment ``{"x": images, "y": labels}`` is scored in a single
    forward pass. Its logits are multiplied by a dummy scalar ``w = 1.0``,
    and the squared gradient of the mean cross-entropy w.r.t. ``w`` is taken
    with ``create_graph=False``, because this is a logged metric only. The
    gradient w.r.t. ``w`` is required, so this must not run under
    ``torch.no_grad()``.

    Returns:
        ``(mean, variance)`` of the per-environment penalties as floats. The
        variance is torch's default (unbiased) estimator, so it is NaN for a
        single environment. Both are NaN for an empty ``envs``.
    """
    model.eval()

    def _env_penalty(env):
        dummy = torch.tensor(1.0, requires_grad=True, device=device)
        ce = F.cross_entropy(model(env["x"]) * dummy, env["y"])
        (g,) = torch.autograd.grad(ce, dummy, create_graph=False)
        return g.item() ** 2

    per_env = torch.tensor([_env_penalty(env) for env in envs])
    return per_env.mean().item(), per_env.var().item()
|
|
|
|
def irm_penalty_train_time(logits_list, y_list):
    """Differentiable IRMv1 penalty for use inside the training loss.

    The caller splits the batch by environment and passes aligned
    per-environment logits and labels. For each non-empty environment, the
    logits are multiplied by a dummy scale ``1.0``. The squared gradient of
    that environment's mean cross-entropy w.r.t. the scale is taken with
    ``create_graph=True``, so the penalty itself can be back-propagated.

    Args:
        logits_list: per-environment logits tensors, each ``(n_e, n_classes)``.
        y_list: per-environment integer label tensors, aligned with
            ``logits_list``.

    Returns:
        Scalar tensor: the mean squared gradient over the non-empty
        environments. If there are no non-empty environments, a constant 0
        tensor is returned instead.
    """
    penalty = 0.0
    n_envs = 0
    for logits, y in zip(logits_list, y_list):
        if logits.shape[0] == 0:
            continue
        scale = torch.tensor(1.0, requires_grad=True, device=logits.device)
        loss = F.cross_entropy(logits * scale, y)
        grad = torch.autograd.grad(loss, scale, create_graph=True)[0]
        penalty = penalty + grad ** 2
        n_envs += 1
    if n_envs == 0:
        # Guard the empty-list case: logits_list[0] used to raise IndexError.
        device = logits_list[0].device if logits_list else None
        return torch.tensor(0.0, device=device)
    return penalty / n_envs
|
|
|
|
def eval_irm_penalty_wilds(model, id_val_loader, ood_test_loader, device,
                           max_samples=500):
    """IRMv1-style penalty evaluated on the held-out environments H3 and H4.

    Measuring on held-out hospitals avoids the artifact of the training
    environments (H0-H2), where the training loss, and with it the penalty,
    goes to 0. The intended reading is a hypothesis of this experiment, not
    something checked here. A high penalty suggests reliance on
    hospital-specific (shortcut) features. A low penalty suggests
    hospital-invariant (causal) features.

    For each loader, logits for roughly the first ``max_samples`` images
    (whole batches) are computed without gradients. Only the dummy scale
    ``w`` is differentiated.

    Returns:
        ``(mean, population variance)`` of the per-environment penalties over
        the non-empty loaders. Returns ``(nan, nan)`` if no penalty was
        produced or any penalty computation failed.
    """
    model.eval()
    penalties = []
    for loader in (id_val_loader, ood_test_loader):  # H3, then H4
        xs, ys = [], []
        count = 0
        with torch.no_grad():
            for batch in loader:
                imgs = batch[0].to(device)
                # reshape(-1) rather than squeeze(): squeeze() makes a size-1
                # batch 0-d, and torch.cat cannot concatenate 0-d tensors.
                labels = batch[1].reshape(-1).long().to(device)
                xs.append(model(imgs))
                ys.append(labels)
                count += imgs.size(0)
                if count >= max_samples:
                    break
        if not xs:
            continue
        logits = torch.cat(xs)
        y = torch.cat(ys)
        w = torch.tensor(1.0, requires_grad=True, device=device)
        loss = F.cross_entropy(logits * w, y)
        try:
            grad = torch.autograd.grad(loss, w, create_graph=False)[0]
        except RuntimeError:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # genuine programming errors are no longer swallowed.
            penalties.append(float("nan"))
        else:
            penalties.append(grad.item() ** 2)

    if penalties and not any(np.isnan(p) for p in penalties):
        return float(np.mean(penalties)), float(np.var(penalties))
    return float("nan"), float("nan")
|
|
|
|
| |
| |
| |
|
|
class TransformWrapper:
    """Dataset view that applies ``transform`` to each sample's image.

    Wraps any indexable dataset whose items are ``(image, label, metadata)``
    triples. Labels and metadata are passed through unchanged.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image, target, meta = sample
        transformed = self.transform(image)
        return transformed, target, meta
|
|
|
|
def get_dataloaders(cfg, data_root):
    """Build the train / ID-val (H3) / OOD-test (H4) loaders for Camelyon17.

    Every split is resized to ``cfg["img_size"]`` squared, converted to a
    tensor and normalised with ImageNet statistics. The training set is a
    random subset of ``cfg["n_train"]`` samples, chosen after seeding the
    global torch RNG with ``cfg["seed"]``.

    Returns:
        ``(train_loader, id_val_loader, ood_test_loader, train_subset)``,
        where ``train_subset`` is the transformed training subset.
    """
    side = cfg["img_size"]
    preprocess = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # get_camelyon_subsets returns four splits; the fourth is unused here.
    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=True)

    # This seeds the *global* torch RNG, which makes the subset reproducible.
    # It also fixes the RNG state for everything that runs afterwards
    # (e.g. model initialisation in main()).
    torch.manual_seed(cfg["seed"])
    chosen = torch.randperm(len(train_ds))[:cfg["n_train"]]

    train_wrapped = TransformWrapper(Subset(train_ds, chosen), preprocess)
    id_val_wrapped = TransformWrapper(id_val_ds, preprocess)
    ood_test_wrapped = TransformWrapper(ood_test_ds, preprocess)

    def _loader(dataset, batch_size, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=0, pin_memory=True)

    train_loader = _loader(train_wrapped, cfg["batch_size"], True)
    id_val_loader = _loader(id_val_wrapped, 256, False)
    ood_test_loader = _loader(ood_test_wrapped, 256, False)
    return train_loader, id_val_loader, ood_test_loader, train_wrapped
|
|
|
|
def get_hospital_environments(train_subset, device, num_workers=4):
    """Group the training subset into per-hospital IRM environments.

    The whole subset is read once, in order, and split by the id in column 0
    of each sample's metadata. That column is assumed to be the hospital
    field, as in WILDS Camelyon17, where the train split holds hospitals 0,
    1 and 2; confirm this if the dataset changes. Each environment's images
    and labels are moved to ``device`` in full, so memory grows with
    ``n_train``.

    Args:
        train_subset: dataset yielding ``(image, label, metadata)`` samples.
        device: device for each environment's tensors.
        num_workers: DataLoader worker processes used to read the subset.
            Defaults to 4, the previously hard-coded value.

    Returns:
        A list of ``{"x", "y", "hospital"}`` dicts, one per distinct hospital
        id present, in ascending id order. Each environment's size and
        positive-label rate are also printed.
    """
    loader = DataLoader(train_subset, batch_size=512,
                        shuffle=False, num_workers=num_workers)
    all_imgs, all_labels, all_meta = [], [], []
    for imgs, labels, meta in loader:
        all_imgs.append(imgs)
        # reshape(-1) rather than squeeze(): a final batch of one sample
        # would otherwise become 0-d and break torch.cat below.
        all_labels.append(labels.reshape(-1).long())
        all_meta.append(meta)

    all_imgs = torch.cat(all_imgs)
    all_labels = torch.cat(all_labels)
    hospitals = torch.cat(all_meta)[:, 0].long()

    envs = []
    for h in torch.unique(hospitals):
        mask = hospitals == h
        n = mask.sum().item()
        envs.append({
            "x": all_imgs[mask].to(device),
            "y": all_labels[mask].to(device),
            "hospital": int(h),
        })
        pos_rate = all_labels[mask].float().mean().item()
        print(f" Env hospital={int(h)}: {n} samples, "
              f"positive rate={pos_rate:.2f}")
    return envs
|
|
|
|
| |
| |
| |
|
|
def build_model(cfg):
    """Create an untrained timm ResNet-18 and move it to ``cfg["device"]``.

    The classifier head has ``cfg["n_classes"]`` outputs. When
    ``cfg["init_scale"]`` differs from 1.0, every parameter whose name
    contains "weight" and that has more than one dimension (conv / linear
    kernels, but not 1-D BatchNorm weights or biases) is multiplied by that
    scale in place.
    """
    net = timm.create_model("resnet18", pretrained=False,
                            num_classes=cfg["n_classes"])
    scale = cfg["init_scale"]
    if scale != 1.0:
        with torch.no_grad():
            kernels = (
                param for pname, param in net.named_parameters()
                if "weight" in pname and param.dim() > 1
            )
            for param in kernels:
                param.data *= scale
    return net.to(cfg["device"])
|
|
|
|
| |
| |
| |
|
|
def _irm_summary(history, grok_epoch):
    """Summarise the diagnostic IRM curve recorded in ``history``.

    Returns:
        ``(irm_drop_pct, irm_drop_epoch, epoch_gap)``:
          * percentage drop from the first logged ``irm_mean`` to the minimum
            logged value (NaN if there is no history or the first value is 0);
          * epoch of the largest absolute change between consecutive
            evaluations (-1 with fewer than two evaluations or no change);
          * ``|grok_epoch - irm_drop_epoch|`` (-1 if either is unavailable).
    """
    irm_drop_pct = float("nan")
    irm_drop_ep = epoch_gap = -1
    if not history:
        return irm_drop_pct, irm_drop_ep, epoch_gap
    irm0 = history[0]["irm_mean"]
    irm_min = min(r["irm_mean"] for r in history)
    if irm0:
        irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100
    biggest = 0.0
    for prev, cur in zip(history, history[1:]):
        change = abs(cur["irm_mean"] - prev["irm_mean"])
        if change > biggest:
            biggest = change
            irm_drop_ep = cur["epoch"]
    if grok_epoch and irm_drop_ep > 0:
        epoch_gap = abs(grok_epoch - irm_drop_ep)
    return irm_drop_pct, irm_drop_ep, epoch_gap


def train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir):
    """Train ``model`` while tracking ID-val (H3) and OOD-test (H4) metrics.

    Every ``cfg["log_every"]`` epochs, and at epoch 1, the following metrics
    are appended to the returned history:

    * train / ID-val / OOD accuracy and the OOD gap;
    * weight norm and feature effective rank;
    * the diagnostic IRM penalty over ``envs``;
    * centre/border confidences and their ratio.

    The history is also rewritten to ``results/history.json`` and logged to
    wandb when it is importable.

    OOD grokking is flagged at the first evaluation where two conditions
    hold. First, the previous ``plateau_window`` OOD accuracies lie within
    ``plateau_eps`` of the most recent one, allowing 2 outliers. Second, the
    current OOD accuracy beats the best earlier OOD accuracy by more than
    0.05.

    Behaviour controlled by ``cfg``:
      * ``irm_weight`` > 0: add the differentiable IRMv1 penalty over
        hospitals 0/1/2 (metadata column 0) to the CE loss. This happens only
        when at least two of them have >= 2 samples in the batch.
      * ``use_grokfast``: apply the Grokfast EMA gradient filter.
      * ``grad_clip`` > 0 (default 1.0): clip the global gradient norm.
      * ``use_ood_early_stop``: stop at the first evaluation where OOD
        accuracy is more than ``ood_min_delta`` below its best.
        ``ood_patience`` only sets the minimum number of evaluations before
        this check applies.

    Side effects: writes ``results/history.json`` and
    ``results/summary.json``. Also writes ``checkpoints/epXXXXX.pt`` every
    200 epochs and ``checkpoints/final.pt``.

    Returns:
        The list of per-evaluation metric dicts.
    """
    device = cfg["device"]
    criterion = nn.CrossEntropyLoss()
    grads_ema = None
    history = []
    best_id_val = 0.0
    best_ood = 0.0
    peak_ood_epoch = None
    grok_epoch = None
    irm_base = None
    grad_clip = cfg.get("grad_clip", 1.0)

    # Create output dirs up front. checkpoints/ used to be created only on
    # 200-epoch boundaries, so the final torch.save could fail on short runs
    # unless something else had already created it.
    results_dir = os.path.join(run_dir, "results")
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    history_path = os.path.join(results_dir, "history.json")

    # Grokking detection: OOD accuracy must have been flat for the last
    # `plateau_window` evaluations (to within `plateau_eps`) before a jump.
    plateau_window = 10
    plateau_eps = 0.01

    # Optional OOD early stopping (off unless cfg["use_ood_early_stop"]).
    ood_patience = cfg.get("ood_patience", 20)
    ood_min_delta = cfg.get("ood_min_delta", 0.01)
    use_ood_early_stop = cfg.get("use_ood_early_stop", False)

    print(f"\n{'='*60}")
    print(f" {cfg['condition'].upper()} | Camelyon17 v2 | {cfg['n_epochs']} epochs")
    print(f" WD={cfg['weight_decay']} | Ξ±={cfg['init_scale']} | n={cfg['n_train']}")
    print(f" Tracking: ID val (H3) + OOD test (H4) at every checkpoint")
    print(f" Grokking detection: watching OOD acc, not ID val acc")
    print(f" IRM envs: {len(envs)} hospitals")
    print(f"{'='*60}", flush=True)

    irm_weight = float(cfg.get("irm_weight", 0.0))
    use_irm_in_loss = irm_weight > 0.0
    if use_irm_in_loss:
        print(f" IRM-in-loss: ENABLED, alpha={irm_weight}", flush=True)
    else:
        print(f" IRM-in-loss: disabled (CE-only training; IRM penalty is diagnostic)", flush=True)

    for epoch in range(1, cfg["n_epochs"] + 1):
        model.train()
        loss_sum = n_b = 0
        for imgs, labels, metadata in train_loader:
            imgs = imgs.to(device)
            # reshape(-1) rather than squeeze(): keeps a size-1 batch 1-D.
            labels = labels.reshape(-1).long().to(device)
            optimizer.zero_grad()
            logits = model(imgs)
            ce_loss = criterion(logits, labels)

            loss = ce_loss
            if use_irm_in_loss:
                # Split the batch into hospital environments (metadata col 0).
                hosp_ids = metadata[:, 0].long().to(device)
                logits_per_env, y_per_env = [], []
                for h in (0, 1, 2):
                    mask = hosp_ids == h
                    if mask.sum() < 2:
                        continue
                    logits_per_env.append(logits[mask])
                    y_per_env.append(labels[mask])
                # The penalty compares environments, so it needs at least two.
                if len(logits_per_env) >= 2:
                    irm_term = irm_penalty_train_time(logits_per_env, y_per_env)
                    loss = ce_loss + irm_weight * irm_term

            loss.backward()
            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0))
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=grad_clip)
            optimizer.step()
            loss_sum += loss.item()
            n_b += 1

        if epoch % cfg["log_every"] != 0 and epoch != 1:
            continue

        tr_acc = accuracy_wilds(model, train_loader, device)
        id_acc = accuracy_wilds(model, id_val_loader, device)
        ood_acc = accuracy_wilds(model, ood_test_loader, device)
        wn = weight_norm_fn(model)
        fr = feature_rank_wilds(model, id_val_loader, device)
        irm_m, irm_v = irm_penalty_wilds(model, envs, device)
        cconf, bconf = shortcut_ratio_wilds(model, id_val_loader, device)

        if irm_base is None:
            irm_base = irm_m

        # Compare against best_ood *before* it is updated with this epoch.
        if grok_epoch is None and len(history) >= plateau_window:
            window = history[-plateau_window:]
            ref = window[-1]["ood_acc"]
            flat = sum(1 for r in window
                       if abs(r["ood_acc"] - ref) < plateau_eps)
            if flat >= plateau_window - 2 and ood_acc > best_ood + 0.05:
                grok_epoch = epoch
                irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                print(f"\n *** OOD GROKKING at epoch {epoch} ***")
                print(f" OOD: {best_ood:.3f} β {ood_acc:.3f} | "
                      f"IRM drop: {irm_drop:.1f}%", flush=True)

        if id_acc > best_id_val:
            best_id_val = id_acc
        if ood_acc > best_ood:
            best_ood = ood_acc
            peak_ood_epoch = epoch

        # > 1: more confident from the border (stain) than the centre (tissue).
        sc_ratio = min(bconf / (cconf + 1e-8), 10.0)
        ood_gap = id_acc - ood_acc

        row = dict(
            epoch=epoch,
            # NaN instead of ZeroDivisionError when train_loader is empty.
            train_loss=loss_sum / n_b if n_b else float("nan"),
            train_acc=tr_acc,
            id_val_acc=id_acc,
            ood_acc=ood_acc,
            ood_gap=ood_gap,
            weight_norm=wn,
            feature_rank=fr,
            irm_mean=irm_m,
            irm_var=irm_v,
            center_conf=cconf,
            border_conf=bconf,
            shortcut_ratio=sc_ratio,
            grokking_detected=grok_epoch is not None,
        )
        history.append(row)
        if wandb:
            wandb.log(row)

        with open(history_path, "w") as f:
            json.dump(history, f, indent=2)

        if epoch % 200 == 0:
            ckpt_path = os.path.join(ckpt_dir, f"ep{epoch:05d}.pt")
            torch.save(model.state_dict(), ckpt_path)
            print(f" β Checkpoint β ep{epoch:05d}.pt", flush=True)

        if (use_ood_early_stop and peak_ood_epoch is not None
                and len(history) >= ood_patience
                and ood_acc < best_ood - ood_min_delta):
            print(f"\n *** EARLY STOP (OOD declining) at epoch {epoch} ***", flush=True)
            print(f" Peak OOD: {best_ood:.4f} at epoch {peak_ood_epoch}", flush=True)
            print(f" Current: {ood_acc:.4f} ({ood_acc-best_ood:+.4f})", flush=True)

            # A peak checkpoint only exists when the peak fell on a
            # 200-epoch checkpoint boundary.
            if peak_ood_epoch and peak_ood_epoch % 200 == 0:
                peak_src = os.path.join(ckpt_dir, f"ep{peak_ood_epoch:05d}.pt")
                peak_dst = os.path.join(ckpt_dir, "peak_ood.pt")
                if os.path.exists(peak_src):
                    import shutil
                    shutil.copy(peak_src, peak_dst)
                    print(f" Saved peak β checkpoints/peak_ood.pt", flush=True)
            break

        print(f" ep {epoch:5d} | "
              f"tr {tr_acc:.3f} | "
              f"id {id_acc:.3f} | "
              f"ood {ood_acc:.3f} | "
              f"gap {ood_gap:+.3f} | "
              f"βWβ {wn:.1f} | "
              f"rank {fr:.1f} | "
              f"IRM {irm_m:.4f} | "
              f"sc {sc_ratio:.2f}x",
              flush=True)

    final_ood = accuracy_wilds(model, ood_test_loader, device)
    if wandb:
        wandb.log({"final_ood_acc": final_ood,
                   "grokking_epoch": grok_epoch or -1})

    irm_drop_pct, irm_drop_ep, epoch_gap = _irm_summary(history, grok_epoch)

    # OOD improvement: mean of the last 5 evaluations minus the first 5.
    ood_early = np.mean([r["ood_acc"] for r in history[:5]]) if history else 0
    ood_late = np.mean([r["ood_acc"] for r in history[-5:]]) if history else 0
    ood_improvement = ood_late - ood_early
    ood_delta = final_ood - best_ood

    final_row = history[-1] if history else None
    summary = dict(
        run_id=cfg["run_id"],
        condition=cfg["condition"],
        n_train=cfg["n_train"],
        seed=cfg["seed"],
        best_id_val=best_id_val,
        best_ood=best_ood,
        peak_ood_epoch=peak_ood_epoch or -1,
        final_ood=final_ood,
        ood_delta=ood_delta,
        ood_improvement=ood_improvement,
        grokking_epoch=grok_epoch or -1,
        irm_drop_pct=irm_drop_pct,
        irm_drop_epoch=irm_drop_ep,
        epoch_gap=epoch_gap,
        final_weight_norm=final_row["weight_norm"] if final_row else None,
        final_feature_rank=final_row["feature_rank"] if final_row else None,
        final_irm=final_row["irm_mean"] if final_row else None,
        final_shortcut_ratio=final_row["shortcut_ratio"] if final_row else None,
        final_ood_gap=final_row["ood_gap"] if final_row else None,
    )
    with open(os.path.join(results_dir, "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    torch.save(model.state_dict(), os.path.join(ckpt_dir, "final.pt"))

    print(f"\n Best ID val (H3): {best_id_val:.4f}")
    print(f" Best OOD (H4): {best_ood:.4f}")
    print(f" OOD improvement: {ood_improvement:+.4f} β did OOD grok?")
    print(f" Grokking at: {grok_epoch}")
    print(f" IRM drop: {irm_drop_pct:.1f}%",
          flush=True)
    return history
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point: parse flags, build config/data/model, train.

    Steps, in order:
      1. Parse the CLI and start from the preset for ``--condition``.
      2. Apply any hyper-parameter overrides that were given.
      3. Enable TF32 on CUDA.
      4. Seed torch/numpy.
      5. Create or reuse the run directory and save the config.
      6. Initialise wandb (if importable).
      7. Build the loaders, IRM environments, model and AdamW optimizer.
      8. Run ``train``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--condition", default="grokking",
                        choices=["standard", "grokking"])
    parser.add_argument("--n_train", type=int, default=300)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--log_every", type=int, default=50)
    parser.add_argument("--wandb_project", default="causalgrok")
    parser.add_argument("--wandb_mode", default="offline",
                        choices=["online", "offline", "disabled"])
    parser.add_argument("--run_dir", default=None)
    parser.add_argument("--data_root", default="data/wilds")
    parser.add_argument("--weight_decay", type=float, default=None)
    parser.add_argument("--init_scale", type=float, default=None)
    parser.add_argument("--n_epochs", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--grokfast", choices=["on", "off"], default=None)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--irm_weight", type=float, default=0.0,
                        help="IRMv1 penalty weight added to training loss "
                             "(0 = pure cross-entropy / diagnostic-only IRM).")
    args = parser.parse_args()

    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed,
               log_every=args.log_every, grad_clip=args.grad_clip)

    # Optional overrides: only flags that were actually passed replace the preset.
    overrides = {
        "weight_decay": args.weight_decay,
        "init_scale": args.init_scale,
        "n_epochs": args.n_epochs,
        "lr": args.lr,
    }
    cfg.update({key: val for key, val in overrides.items() if val is not None})
    if args.grokfast is not None:
        cfg["use_grokfast"] = args.grokfast == "on"
    cfg["irm_weight"] = args.irm_weight

    if cfg["device"] == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    if args.run_dir is not None:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        run_id = os.path.basename(os.path.normpath(run_dir))
    else:
        run_dir, run_id = make_run_dir(
            ["camelyon_v2", cfg["condition"],
             f"n{cfg['n_train']}", f"s{cfg['seed']}"])

    cfg["run_id"] = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    if wandb:
        wandb.init(project=args.wandb_project, config=cfg, name=run_id,
                   mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice: {cfg['device']}")
    print(f"Run ID: {run_id}")
    print(f"Started: {datetime.now(timezone.utc).isoformat()}", flush=True)

    loaders = get_dataloaders(cfg, args.data_root)
    train_loader, id_val_loader, ood_test_loader, train_subset = loaders

    envs = get_hospital_environments(train_subset, cfg["device"])
    model = build_model(cfg)

    print(f"Train: {len(train_subset)} | "
          f"ID val (H3): {len(id_val_loader.dataset)} | "
          f"OOD test (H4): {len(ood_test_loader.dataset)}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {n_params:,}", flush=True)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    started_at = time.time()
    train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir)
    print(f"\nWall time: {(time.time()-started_at)/60:.1f} min", flush=True)
    if wandb:
        wandb.finish()


if __name__ == "__main__":
    main()
|
|