CausalGrok / code /experiments /causalgrok_camelyon_v2.py
nileshsarkar-ai's picture
Upload code/experiments
50fa85c verified
"""
CausalGrok — Camelyon17 Training Loop v2
Nilesh
KEY CHANGE FROM v1:
OOD test accuracy (H4 — unseen hospital) is now tracked at EVERY
checkpoint, not just at the end. Grokking detection watches OOD acc,
not ID val acc. This is the correct signal.
The paper claim: after ID accuracy converges (fast, expected), the model
undergoes a delayed phase transition in OOD generalization — grokking
the cross-hospital invariant causal features. This co-occurs with a drop
in IRM penalty. That is the grokking we care about for clinical deployment.
Two curves to watch:
val_acc (H3 ID val) — converges fast, expected ~0.86 by ep 50
ood_acc (H4 OOD test) — should plateau then JUMP (the grokking)
Run via:
python -m experiments.causalgrok_camelyon_v2 --condition grokking --n_train 300
"""
from __future__ import annotations
import argparse
import json
import os
import time
from datetime import datetime, timezone
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import timm
try:
import wandb
except ImportError:
wandb = None
from utils.grokfast import gradfilter_ema
from utils.camelyon_data import get_camelyon_subsets
from utils.run_dir import make_run_dir, ensure_run_dir, save_config
# ──────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────
def get_config(condition):
    """Return the hyper-parameter dictionary for one experimental condition.

    Args:
        condition: ``"standard"`` (baseline: small weight decay, unscaled
            init, no Grokfast) or ``"grokking"`` (larger weight decay,
            init scaled by 4, Grokfast gradient filter enabled).

    Returns:
        A fresh ``dict`` holding the settings shared by both conditions
        plus the condition-specific ones.

    Raises:
        ValueError: if ``condition`` is not a known name. Previously an
            unknown name silently returned a config without ``lr``,
            ``n_epochs`` or ``condition``, which only failed later with a
            ``KeyError`` far from the cause.
    """
    base = dict(
        seed=42, n_train=300, batch_size=32, img_size=96,
        n_classes=2, log_every=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    if condition == "standard":
        base.update(dict(
            condition="standard",
            lr=1e-3, weight_decay=1e-4,
            # Default 3000 epochs to match the grokking config and the
            # paper's reported runs; previously defaulted to 300, which
            # made the standard baseline trivially under-trained
            # relative to grokking. See paper Limitations §M3.
            n_epochs=3000, init_scale=1.0, use_grokfast=False,
        ))
    elif condition == "grokking":
        base.update(dict(
            condition="grokking",
            lr=1e-3, weight_decay=5e-3,
            n_epochs=3000, init_scale=4.0, use_grokfast=True,
            grokfast_alpha=0.98, grokfast_lamb=2.0,
        ))
    else:
        raise ValueError(
            f"unknown condition {condition!r}; "
            "expected 'standard' or 'grokking'"
        )
    return base
# ──────────────────────────────────────────────
# WILDS-SAFE METRICS
# All handle the (imgs, labels, metadata) 3-tuple WILDS batch format.
# ──────────────────────────────────────────────
@torch.no_grad()
def accuracy_wilds(model, loader, device, max_samples=None):
    """Top-1 accuracy of ``model`` over ``loader``.

    Args:
        model: module mapping an image batch to class logits; it is put
            into eval mode and left there.
        loader: iterable of batches whose element 0 is the image tensor
            and element 1 the label tensor (further elements are ignored).
        device: device the batch tensors are moved to.
        max_samples: if truthy, stop after the first batch that brings the
            number of evaluated samples to at least this value.

    Returns:
        Fraction of correct predictions as a float; ``0.0`` when the
        loader yields nothing.
    """
    model.eval()
    correct = total = 0
    for batch in loader:
        imgs = batch[0].to(device)
        # reshape(-1), not squeeze(): squeeze() turns a batch of one label
        # into a 0-dim tensor, and labels.size(0) below then raises.
        labels = batch[1].reshape(-1).long().to(device)
        preds = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        if max_samples and total >= max_samples:
            break
    return correct / max(total, 1)
@torch.no_grad()
def weight_norm_fn(model):
    """Return the global L2 norm taken over every parameter of ``model``.

    The 2-norm of each parameter tensor is squared and accumulated as a
    Python number; the square root of that total is returned.
    """
    squared_total = 0
    for param in model.parameters():
        squared_total += param.data.norm(2).item() ** 2
    return squared_total ** 0.5
@torch.no_grad()
def feature_rank_wilds(model, loader, device, n=300):
    """Effective rank of the ``layer4`` features over up to ``n`` samples.

    A forward hook on ``model.layer4[-1]`` captures that module's output,
    average-pools it to one value per channel and stores it on the CPU.
    The singular values of the stacked feature matrix are normalised to
    sum to one, and ``exp`` of their Shannon entropy is returned.

    Args:
        model: network exposing an indexable ``layer4`` attribute whose
            last element produces a 3-D/4-D output (as required by
            ``adaptive_avg_pool2d``).
        loader: iterable of batches; element 0 is the image tensor.
        device: device the images are moved to.
        n: number of samples to use; collection stops after the batch
            that reaches it and the feature matrix is truncated to ``n``
            rows.

    Returns:
        The effective rank as a float, or ``nan`` if no features were
        collected or the decomposition fails.
    """
    model.eval()
    feats = []

    def hook_fn(module, inputs, output):
        pooled = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
        feats.append(pooled.view(pooled.size(0), -1).cpu())

    hook = model.layer4[-1].register_forward_hook(hook_fn)
    try:
        count = 0
        for batch in loader:
            model(batch[0].to(device))
            count += batch[0].size(0)
            if count >= n:
                break
    finally:
        # Always detach the hook: if a forward pass raises, a leaked hook
        # would keep appending to ``feats`` on every later forward call.
        hook.remove()

    if not feats:
        return float("nan")

    feat_mat = torch.cat(feats)[:n]
    try:
        # Only the singular values are needed; torch.svd is deprecated.
        s = torch.linalg.svdvals(feat_mat)
    except RuntimeError:
        # Covers torch.linalg.LinAlgError (non-convergence, bad input).
        return float("nan")
    s = s / (s.sum() + 1e-10)
    return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item()
@torch.no_grad()
def shortcut_ratio_wilds(model, loader, device, n_samples=200):
    """
    Stain shortcut proxy: compare model confidence on the center crop
    (tissue — causal features) vs. the border region (stain — spurious).

    For every batch two views are scored: the central half-height by
    half-width window upsampled bilinearly back to full size, and the
    original image with that window zeroed out. The mean max-softmax
    probability of each view is recorded per batch, and the per-batch
    means are averaged.

    Returns ``(center_conf, border_conf)``; both default to 0.5 when the
    loader yields no batch. Callers form the ratio border / center:
        sc > 1.0 = relying on border stain more than tissue (shortcut)
        sc < 1.0 = relying on tissue center more than stain (causal)
    The transition from > 1.0 to < 1.0 during training is the
    attribution-level signature of the grokking transition.
    """
    model.eval()
    center_scores, border_scores = [], []
    seen = 0
    for batch in loader:
        if seen >= n_samples:
            break
        imgs = batch[0].to(device)
        _, _, height, width = imgs.shape
        top, bottom = height // 4, 3 * height // 4
        left, right = width // 4, 3 * width // 4

        # Center view: crop the middle window and stretch it to full size.
        center_view = F.interpolate(
            imgs[:, :, top:bottom, left:right],
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        # Border view: same image with the middle window blanked.
        border_view = imgs.clone()
        border_view[:, :, top:bottom, left:right] = 0.0

        center_probs = F.softmax(model(center_view), 1)
        center_scores.append(center_probs.max(1).values.mean().item())
        border_probs = F.softmax(model(border_view), 1)
        border_scores.append(border_probs.max(1).values.mean().item())
        seen += imgs.size(0)

    if center_scores:
        cconf = float(np.mean(center_scores))
    else:
        cconf = 0.5
    if border_scores:
        bconf = float(np.mean(border_scores))
    else:
        bconf = 0.5
    return cconf, bconf
def irm_penalty_wilds(model, envs, device):
    """
    IRMv1 penalty across TRAINING hospital environments (H0-H2).

    Diagnostic version: uses create_graph=False and returns floats. Used
    as a monitoring metric only (logged per checkpoint).

    For each environment the logits are multiplied by a dummy scalar
    ``w = 1`` and the squared gradient of the cross-entropy w.r.t. ``w``
    is recorded.

    Args:
        model: classifier; put into eval mode and left there.
        envs: iterable of dicts with keys ``"x"`` (inputs) and ``"y"``
            (labels), already on ``device``.
        device: device on which the dummy scale is created.

    Returns:
        ``(mean, var)`` of the per-environment penalties as floats.
        ``var`` comes from ``Tensor.var()`` with its default settings.
    """
    model.eval()
    penalties = []
    for env in envs:
        # Only d(loss)/dw is needed, so skip building an autograd graph
        # through the whole backbone; the gradient value is unchanged.
        with torch.no_grad():
            raw_logits = model(env["x"])
        w = torch.tensor(1.0, requires_grad=True, device=device)
        loss = F.cross_entropy(raw_logits * w, env["y"])
        grad = torch.autograd.grad(loss, w, create_graph=False)[0]
        penalties.append(grad.item() ** 2)
    t = torch.tensor(penalties)
    return t.mean().item(), t.var().item()
def irm_penalty_train_time(logits_list, y_list):
    """
    IRMv1 penalty for use INSIDE the training loss (differentiable).

    For every non-empty environment the logits are multiplied by a
    virtual scale of 1, the cross-entropy is differentiated w.r.t. that
    scale with ``create_graph=True`` (so the penalty itself can be
    back-propagated), and the squared gradients are averaged over the
    environments that contributed.

    Args:
        logits_list: list of (per-env) logits tensors
        y_list: list of (per-env) label tensors

    Returns:
        Scalar tensor with the IRM penalty contribution. When no
        environment contributes (empty lists or only empty environments)
        a constant zero tensor is returned, placed on the device of the
        first logits tensor if there is one.
    """
    penalty = 0.0
    n = 0
    for logits, y in zip(logits_list, y_list):
        if logits.shape[0] == 0:
            continue
        scale = torch.tensor(1.0, requires_grad=True, device=logits.device)
        loss = F.cross_entropy(logits * scale, y)
        grad = torch.autograd.grad(loss, scale, create_graph=True)[0]
        penalty = penalty + grad ** 2
        n += 1
    if n == 0:
        # Guard the empty-list case: indexing logits_list[0] would raise.
        device = logits_list[0].device if len(logits_list) > 0 else None
        return torch.tensor(0.0, device=device)
    return penalty / n
def eval_irm_penalty_wilds(model, id_val_loader, ood_test_loader, device,
                           max_samples=500):
    """
    IRM penalty evaluated on HELD-OUT environments (H3 and H4).

    This avoids the measurement artifact of training on H0-H2 where loss→0.
    HIGH penalty = model relies on hospital-discriminating features = shortcuts.
    LOW penalty = model ignores hospital labels = causal features.

    Each loader is treated as one environment: logits are collected
    without gradients, multiplied by a dummy scale ``w = 1``, and the
    squared gradient of the cross-entropy w.r.t. ``w`` is the penalty.

    Args:
        model: classifier; put into eval mode and left there.
        id_val_loader: batches for the first environment (H3).
        ood_test_loader: batches for the second environment (H4).
        device: device the batches and the dummy scale live on.
        max_samples: per-loader sample budget; reading stops after the
            batch that reaches it.

    Returns:
        ``(mean, variance)`` of the per-environment penalties, using
        ``np.mean`` / ``np.var``; ``(nan, nan)`` if no environment
        produced data or any penalty is NaN.
    """
    model.eval()
    penalties = []

    # One environment per held-out loader: H3 (ID val), then H4 (OOD test).
    for loader in (id_val_loader, ood_test_loader):
        xs, ys = [], []
        count = 0
        with torch.no_grad():
            for batch in loader:
                imgs = batch[0].to(device)
                # reshape(-1), not squeeze(): a batch of one label would
                # become 0-dim and torch.cat below would reject it.
                labels = batch[1].reshape(-1).long().to(device)
                xs.append(model(imgs))
                ys.append(labels)
                count += imgs.size(0)
                if count >= max_samples:
                    break
        if not xs:
            continue

        # Outside no_grad so autograd can differentiate w.r.t. the scale.
        x = torch.cat(xs)
        y = torch.cat(ys)
        w = torch.tensor(1.0, requires_grad=True, device=device)
        loss = F.cross_entropy(x * w, y)
        try:
            grad = torch.autograd.grad(loss, w, create_graph=False)[0]
            penalties.append(grad.item() ** 2)
        except RuntimeError:
            # Best-effort diagnostic: record NaN instead of aborting.
            penalties.append(float("nan"))

    if penalties and not any(np.isnan(p) for p in penalties):
        return float(np.mean(penalties)), float(np.var(penalties))
    return float("nan"), float("nan")
# ──────────────────────────────────────────────
# DATA
# ──────────────────────────────────────────────
class TransformWrapper:
    """Dataset view that applies ``transform`` to the image of each sample.

    The wrapped dataset must yield ``(img, label, metadata)`` triples;
    label and metadata are passed through untouched.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        # Same length as the underlying dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        image, target, meta = self.dataset[idx]
        transformed = self.transform(image)
        return transformed, target, meta
def get_dataloaders(cfg, data_root):
    """Build the train / ID-val / OOD-test loaders for Camelyon17.

    Args:
        cfg: config dict; reads ``img_size``, ``seed``, ``n_train`` and
            ``batch_size``.
        data_root: directory handed to ``get_camelyon_subsets`` as
            ``root_dir`` (called with ``download=True``).

    Returns:
        ``(train_loader, id_val_loader, ood_test_loader, train_subset)``
        where ``train_subset`` is the transformed random subsample that
        backs ``train_loader``.
    """
    side = cfg["img_size"]
    transform = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=True)

    # Re-seed the global RNG so the subsample is a function of the seed.
    torch.manual_seed(cfg["seed"])
    keep = torch.randperm(len(train_ds))[:cfg["n_train"]]

    # Transforms are applied lazily through TransformWrapper.
    train_subset = TransformWrapper(Subset(train_ds, keep), transform)
    id_val_ds = TransformWrapper(id_val_ds, transform)
    ood_test_ds = TransformWrapper(ood_test_ds, transform)

    def _eval_loader(dataset):
        # Evaluation loaders: fixed order, large batches.
        return DataLoader(dataset, batch_size=256,
                          shuffle=False, num_workers=0, pin_memory=True)

    train_loader = DataLoader(train_subset, batch_size=cfg["batch_size"],
                              shuffle=True, num_workers=0, pin_memory=True)
    id_val_loader = _eval_loader(id_val_ds)
    ood_test_loader = _eval_loader(ood_test_ds)
    return train_loader, id_val_loader, ood_test_loader, train_subset
def get_hospital_environments(train_subset, device, batch_size=512,
                              num_workers=4):
    """
    Build IRM environments from ground-truth hospital labels.

    The whole subset is loaded into memory, grouped by the value in
    metadata column 0, and each group is moved to ``device``. A one-line
    summary is printed per environment.
    Hospitals in Camelyon17 train split: 0, 1, 2.

    Args:
        train_subset: map-style dataset of ``(img, label, metadata)``
            samples whose metadata collates to a 2-D tensor.
        device: target device for each environment's tensors.
        batch_size: loader batch size used while reading the subset.
        num_workers: loader worker count used while reading the subset.

    Returns:
        List of ``{"x", "y", "hospital"}`` dicts, one per unique hospital
        id, in ascending id order; empty if the subset has no samples.
    """
    loader = DataLoader(train_subset, batch_size=batch_size,
                        shuffle=False, num_workers=num_workers)
    all_imgs, all_labels, all_meta = [], [], []
    for imgs, labels, meta in loader:
        all_imgs.append(imgs)
        # reshape(-1), not squeeze(): a trailing batch of one sample
        # would become 0-dim and torch.cat below would reject it.
        all_labels.append(labels.reshape(-1).long())
        all_meta.append(meta)
    if not all_imgs:
        # torch.cat cannot concatenate an empty list.
        return []

    all_imgs = torch.cat(all_imgs)
    all_labels = torch.cat(all_labels)
    hospitals = torch.cat(all_meta)[:, 0].long()  # field 0 = hospital ID

    envs = []
    for h in torch.unique(hospitals):
        mask = hospitals == h
        n = mask.sum().item()
        envs.append({
            "x": all_imgs[mask].to(device),
            "y": all_labels[mask].to(device),
            "hospital": int(h),
        })
        pos_rate = all_labels[mask].float().mean().item()
        print(f" Env hospital={int(h)}: {n} samples, "
              f"positive rate={pos_rate:.2f}")
    return envs
# ──────────────────────────────────────────────
# MODEL
# ──────────────────────────────────────────────
def build_model(cfg):
    """Create an untrained timm ``resnet18`` and move it to ``cfg["device"]``.

    When ``cfg["init_scale"]`` differs from 1.0, every parameter whose
    name contains ``"weight"`` and which has more than one dimension is
    multiplied in place by that factor. The head size is
    ``cfg["n_classes"]``.
    """
    net = timm.create_model("resnet18", pretrained=False,
                            num_classes=cfg["n_classes"])
    scale = cfg["init_scale"]
    if scale == 1.0:
        # Default initialisation: nothing to rescale.
        return net.to(cfg["device"])

    with torch.no_grad():
        for pname, tensor in net.named_parameters():
            is_matrix_weight = "weight" in pname and tensor.dim() > 1
            if is_matrix_weight:
                tensor.data *= scale
    return net.to(cfg["device"])
# ──────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────
def train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir):
    """Train ``model`` and log ID / OOD metrics at every checkpoint.

    Each epoch runs one pass over ``train_loader`` (cross-entropy, plus
    an IRMv1 term when ``cfg["irm_weight"] > 0``), optionally filters
    gradients with Grokfast and clips them. At epoch 1 and every
    ``cfg["log_every"]`` epochs the function evaluates train / ID-val /
    OOD accuracy, weight norm, feature rank, the diagnostic IRM penalty
    and the center/border confidences, appends a row to the history,
    rewrites ``results/history.json`` and, on multiples of 200 epochs,
    saves a checkpoint. After the loop it writes ``results/summary.json``
    and ``checkpoints/final.pt``.

    Args:
        cfg: config dict. Required keys: ``n_epochs``, ``device``,
            ``log_every``, ``condition``, ``weight_decay``,
            ``init_scale``, ``n_train``, ``seed``, ``run_id``. Optional:
            ``grad_clip``, ``irm_weight``, ``use_grokfast``,
            ``grokfast_alpha``, ``grokfast_lamb``, ``use_ood_early_stop``,
            ``ood_patience``, ``ood_min_delta``.
        model: network to optimise.
        train_loader: yields ``(imgs, labels, metadata)`` batches.
        id_val_loader: in-distribution validation loader (H3).
        ood_test_loader: out-of-distribution test loader (H4).
        envs: per-hospital environments for the diagnostic IRM penalty.
        optimizer: optimiser bound to ``model``'s parameters.
        run_dir: output directory; ``results/`` and ``checkpoints/`` are
            created inside it if missing.

    Returns:
        The list of per-checkpoint metric dicts.
    """
    criterion = nn.CrossEntropyLoss()
    grads_ema = None
    history = []
    best_id_val = 0.0
    best_ood = 0.0
    peak_ood_epoch = None  # epoch where best_ood was achieved
    grok_epoch = None
    irm_base = None

    # Make sure the output folders exist before the first write.
    results_dir = os.path.join(run_dir, "results")
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    history_path = os.path.join(results_dir, "history.json")
    grad_clip = cfg.get("grad_clip", 1.0)

    # Grokking detection parameters.
    # We watch OOD accuracy (H4), not ID val accuracy (H3).
    # ID val converges fast (expected). OOD is what should grok.
    plateau_window = 10
    plateau_eps = 0.01

    # Ungrokking early stopping parameters.
    # NOTE(review): ood_patience only sets the minimum number of logged
    # checkpoints before early stopping may trigger; it is not a
    # "no improvement for N checkpoints" counter.
    ood_patience = cfg.get("ood_patience", 20)
    ood_min_delta = cfg.get("ood_min_delta", 0.01)  # allowed drop below peak
    use_ood_early_stop = cfg.get("use_ood_early_stop", False)

    print(f"\n{'='*60}")
    print(f" {cfg['condition'].upper()} | Camelyon17 v2 | {cfg['n_epochs']} epochs")
    print(f" WD={cfg['weight_decay']} | α={cfg['init_scale']} | n={cfg['n_train']}")
    print(f" Tracking: ID val (H3) + OOD test (H4) at every checkpoint")
    print(f" Grokking detection: watching OOD acc, not ID val acc")
    print(f" IRM envs: {len(envs)} hospitals")
    print(f"{'='*60}", flush=True)

    irm_weight = float(cfg.get("irm_weight", 0.0))
    use_irm_in_loss = irm_weight > 0.0
    if use_irm_in_loss:
        print(f" IRM-in-loss: ENABLED, alpha={irm_weight}", flush=True)
    else:
        print(f" IRM-in-loss: disabled (CE-only training; IRM penalty is diagnostic)", flush=True)

    for epoch in range(1, cfg["n_epochs"] + 1):
        # ── Train step ────────────────────────────────────────────────
        model.train()
        loss_sum = n_b = 0
        for imgs, labels, metadata in train_loader:
            imgs = imgs.to(cfg["device"])
            # reshape(-1), not squeeze(): a batch of one label would
            # become 0-dim and no longer match the (1, C) logits.
            labels = labels.reshape(-1).long().to(cfg["device"])
            optimizer.zero_grad()
            logits = model(imgs)
            ce_loss = criterion(logits, labels)
            if use_irm_in_loss:
                # Split this batch by training hospital (H0/H1/H2) and
                # compute the IRMv1 penalty as a differentiable scalar.
                hosp_ids = metadata[:, 0].long().to(cfg["device"])
                logits_per_env, y_per_env = [], []
                for h in [0, 1, 2]:
                    mask = (hosp_ids == h)
                    if mask.sum() < 2:
                        continue
                    logits_per_env.append(logits[mask])
                    y_per_env.append(labels[mask])
                # The penalty is only used with at least two environments
                # present in the batch.
                if len(logits_per_env) >= 2:
                    irm_term = irm_penalty_train_time(logits_per_env, y_per_env)
                    loss = ce_loss + irm_weight * irm_term
                else:
                    loss = ce_loss
            else:
                loss = ce_loss
            loss.backward()
            if cfg.get("use_grokfast"):
                grads_ema = gradfilter_ema(
                    model, grads_ema,
                    alpha=cfg.get("grokfast_alpha", 0.98),
                    lamb=cfg.get("grokfast_lamb", 2.0))
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=grad_clip)
            optimizer.step()
            loss_sum += loss.item()
            n_b += 1

        # ── Checkpoint metrics ────────────────────────────────────────
        if epoch % cfg["log_every"] == 0 or epoch == 1:
            tr_acc = accuracy_wilds(model, train_loader, cfg["device"])
            id_acc = accuracy_wilds(model, id_val_loader, cfg["device"])
            ood_acc = accuracy_wilds(model, ood_test_loader, cfg["device"])  # KEY
            wn = weight_norm_fn(model)
            fr = feature_rank_wilds(model, id_val_loader, cfg["device"])
            irm_m, irm_v = irm_penalty_wilds(model, envs, cfg["device"])
            cconf, bconf = shortcut_ratio_wilds(
                model, id_val_loader, cfg["device"])
            if irm_base is None:
                irm_base = irm_m

            # ── OOD grokking detection ────────────────────────────────
            # Require a sustained plateau in OOD acc before the jump.
            # The ID val acc plateau is expected and is not grokking.
            if grok_epoch is None and len(history) >= plateau_window:
                last = history[-plateau_window:]
                ref = last[-1]["ood_acc"]
                flat = sum(1 for r in last
                           if abs(r["ood_acc"] - ref) < plateau_eps)
                if flat >= plateau_window - 2 and ood_acc > best_ood + 0.05:
                    grok_epoch = epoch
                    irm_drop = (irm_base - irm_m) / (irm_base + 1e-8) * 100
                    print(f"\n *** OOD GROKKING at epoch {epoch} ***")
                    print(f" OOD: {best_ood:.3f} → {ood_acc:.3f} | "
                          f"IRM drop: {irm_drop:.1f}%", flush=True)

            if id_acc > best_id_val:
                best_id_val = id_acc
            if ood_acc > best_ood:
                best_ood = ood_acc
                peak_ood_epoch = epoch  # track when the peak was achieved

            # Border/center confidence ratio, capped at 10.
            sc_ratio = min(bconf / (cconf + 1e-8), 10.0)
            # OOD gap: how much worse is OOD vs ID?
            # This should shrink at the grokking transition.
            ood_gap = id_acc - ood_acc

            row = dict(
                epoch=epoch,
                # max(..., 1) guards against an empty train loader.
                train_loss=loss_sum / max(n_b, 1),
                train_acc=tr_acc,
                id_val_acc=id_acc,
                ood_acc=ood_acc,  # primary grokking signal
                ood_gap=ood_gap,  # should narrow at the transition
                weight_norm=wn,
                feature_rank=fr,
                irm_mean=irm_m,
                irm_var=irm_v,
                center_conf=cconf,
                border_conf=bconf,
                shortcut_ratio=sc_ratio,
                grokking_detected=grok_epoch is not None,
            )
            history.append(row)
            if wandb:
                wandb.log(row)
            # Rewrite the full history so a crash loses at most one row.
            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)

            # Save periodic checkpoint for M1 analysis (every 200 epochs;
            # only evaluated on logging epochs).
            if epoch % 200 == 0:
                ckpt_path = os.path.join(ckpt_dir, f"ep{epoch:05d}.pt")
                torch.save(model.state_dict(), ckpt_path)
                print(f" ✓ Checkpoint → ep{epoch:05d}.pt", flush=True)

            # ── OOD-aware early stopping (if ungrokking detected) ─────
            # If OOD has fallen more than ood_min_delta below its peak,
            # stop instead of training for the full number of epochs.
            if (use_ood_early_stop and peak_ood_epoch is not None
                    and len(history) >= ood_patience):
                if ood_acc < best_ood - ood_min_delta:
                    print(f"\n *** EARLY STOP (OOD declining) at epoch {epoch} ***", flush=True)
                    print(f" Peak OOD: {best_ood:.4f} at epoch {peak_ood_epoch}", flush=True)
                    print(f" Current: {ood_acc:.4f} ({ood_acc-best_ood:+.4f})", flush=True)
                    # Save the peak checkpoint separately for clinical
                    # deployment. NOTE(review): this only succeeds when
                    # the peak epoch coincides with a periodic checkpoint
                    # (multiple of 200) that exists on disk.
                    if peak_ood_epoch and peak_ood_epoch % 200 == 0:
                        peak_src = os.path.join(run_dir, "checkpoints", f"ep{peak_ood_epoch:05d}.pt")
                        peak_dst = os.path.join(run_dir, "checkpoints", "peak_ood.pt")
                        if os.path.exists(peak_src):
                            import shutil
                            shutil.copy(peak_src, peak_dst)
                            print(f" Saved peak → checkpoints/peak_ood.pt", flush=True)
                    break  # exit the training loop

            print(f" ep {epoch:5d} | "
                  f"tr {tr_acc:.3f} | "
                  f"id {id_acc:.3f} | "
                  f"ood {ood_acc:.3f} | "
                  f"gap {ood_gap:+.3f} | "  # + means OOD worse than ID
                  f"‖W‖ {wn:.1f} | "
                  f"rank {fr:.1f} | "
                  f"IRM {irm_m:.4f} | "
                  f"sc {sc_ratio:.2f}x",
                  flush=True)

    # ── Final summary ─────────────────────────────────────────────────
    # One final OOD eval at the very end.
    final_ood = accuracy_wilds(model, ood_test_loader, cfg["device"])
    if wandb:
        wandb.log({"final_ood_acc": final_ood,
                   "grokking_epoch": grok_epoch or -1})

    # Decision numbers.
    irm_drop_pct = float("nan")
    irm_drop_ep = epoch_gap = -1
    if history:
        irm0 = history[0]["irm_mean"]
        irm_min = min(r["irm_mean"] for r in history)
        if irm0:
            irm_drop_pct = (irm0 - irm_min) / (irm0 + 1e-8) * 100
        if len(history) > 1:
            # NOTE(review): uses the absolute change, so the reported
            # "drop" epoch is the largest move in either direction.
            biggest = 0.0
            for prev, cur in zip(history[:-1], history[1:]):
                d = abs(cur["irm_mean"] - prev["irm_mean"])
                if d > biggest:
                    biggest = d
                    irm_drop_ep = cur["epoch"]
        if grok_epoch and irm_drop_ep > 0:
            epoch_gap = abs(grok_epoch - irm_drop_ep)

    # OOD grokking: did OOD acc improve over training? Compares the mean
    # OOD accuracy of the last five logged checkpoints with the first five.
    ood_early = np.mean([r["ood_acc"] for r in history[:5]]) if history else 0
    ood_late = np.mean([r["ood_acc"] for r in history[-5:]]) if history else 0
    ood_improvement = ood_late - ood_early

    # Ungrokking detection: did OOD collapse after peaking?
    ood_delta = final_ood - best_ood  # negative = ungrokking

    summary = dict(
        run_id=cfg["run_id"],
        condition=cfg["condition"],
        n_train=cfg["n_train"],
        seed=cfg["seed"],
        best_id_val=best_id_val,
        best_ood=best_ood,
        peak_ood_epoch=peak_ood_epoch or -1,  # when the peak was achieved
        final_ood=final_ood,
        ood_delta=ood_delta,  # final - best (ungrokking signal)
        ood_improvement=ood_improvement,  # key: did OOD grok?
        grokking_epoch=grok_epoch or -1,
        irm_drop_pct=irm_drop_pct,
        irm_drop_epoch=irm_drop_ep,
        epoch_gap=epoch_gap,
        final_weight_norm=history[-1]["weight_norm"] if history else None,
        final_feature_rank=history[-1]["feature_rank"] if history else None,
        final_irm=history[-1]["irm_mean"] if history else None,
        final_shortcut_ratio=history[-1]["shortcut_ratio"] if history else None,
        final_ood_gap=history[-1]["ood_gap"] if history else None,
    )
    with open(os.path.join(run_dir, "results", "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
    torch.save(model.state_dict(),
               os.path.join(run_dir, "checkpoints", "final.pt"))

    print(f"\n Best ID val (H3): {best_id_val:.4f}")
    print(f" Best OOD (H4): {best_ood:.4f}")
    print(f" OOD improvement: {ood_improvement:+.4f} ← did OOD grok?")
    print(f" Grokking at: {grok_epoch}")
    print(f" IRM drop: {irm_drop_pct:.1f}%",
          flush=True)
    return history
# ──────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────
def main():
    """Command-line entry point: parse flags, build everything, train.

    Steps: parse arguments, start from the preset for ``--condition`` and
    apply explicit overrides, enable TF32 / cuDNN autotuning on CUDA,
    seed torch and numpy, resolve the run directory and persist the
    config, start wandb when it is importable, build data / environments
    / model / AdamW, run ``train`` and report the wall time.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--condition", default="grokking",
                        choices=["standard", "grokking"])
    parser.add_argument("--n_train", type=int, default=300)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--log_every", type=int, default=50)
    parser.add_argument("--wandb_project", default="causalgrok")
    parser.add_argument("--wandb_mode", default="offline",
                        choices=["online", "offline", "disabled"])
    parser.add_argument("--run_dir", default=None)
    parser.add_argument("--data_root", default="data/wilds")
    parser.add_argument("--weight_decay", type=float, default=None)
    parser.add_argument("--init_scale", type=float, default=None)
    parser.add_argument("--n_epochs", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--grokfast", choices=["on", "off"], default=None)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--irm_weight", type=float, default=0.0,
                        help="IRMv1 penalty weight added to training loss "
                             "(0 = pure cross-entropy / diagnostic-only IRM).")
    args = parser.parse_args()

    # Preset first, then always-applied flags, then optional overrides
    # (a value of None means "keep the preset").
    cfg = get_config(args.condition)
    cfg.update(n_train=args.n_train, seed=args.seed,
               log_every=args.log_every, grad_clip=args.grad_clip)
    for key in ("weight_decay", "init_scale", "n_epochs", "lr"):
        override = getattr(args, key)
        if override is not None:
            cfg[key] = override
    if args.grokfast is not None:
        cfg["use_grokfast"] = (args.grokfast == "on")
    cfg["irm_weight"] = args.irm_weight

    on_cuda = cfg["device"] == "cuda"
    if on_cuda:
        torch.set_float32_matmul_precision("high")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    torch.manual_seed(cfg["seed"])
    np.random.seed(cfg["seed"])

    # Use the caller's directory when given, otherwise create a new one.
    if args.run_dir is not None:
        run_dir = args.run_dir
        ensure_run_dir(run_dir)
        run_id = os.path.basename(os.path.normpath(run_dir))
    else:
        name_parts = ["camelyon_v2", cfg["condition"],
                      f"n{cfg['n_train']}", f"s{cfg['seed']}"]
        run_dir, run_id = make_run_dir(name_parts)
    cfg["run_id"] = run_id
    cfg["run_dir"] = run_dir
    save_config(cfg, run_dir)

    if wandb:
        wandb.init(project=args.wandb_project, config=cfg, name=run_id,
                   mode=args.wandb_mode, dir=run_dir)

    print(f"\nDevice: {cfg['device']}")
    print(f"Run ID: {run_id}")
    print(f"Started: {datetime.now(timezone.utc).isoformat()}", flush=True)

    loaders = get_dataloaders(cfg, args.data_root)
    train_loader, id_val_loader, ood_test_loader, train_subset = loaders
    envs = get_hospital_environments(train_subset, cfg["device"])
    model = build_model(cfg)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Train: {len(train_subset)} | "
          f"ID val (H3): {len(id_val_loader.dataset)} | "
          f"OOD test (H4): {len(ood_test_loader.dataset)}")
    print(f"Params: {n_params:,}",
          flush=True)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    started = time.time()
    train(cfg, model, train_loader, id_val_loader, ood_test_loader,
          envs, optimizer, run_dir)
    print(f"\nWall time: {(time.time()-started)/60:.1f} min", flush=True)

    if wandb:
        wandb.finish()
# Run the experiment only when this module is executed as a script
# (e.g. ``python -m experiments.causalgrok_camelyon_v2``), not on import.
if __name__ == "__main__":
    main()