""" utils.metrics — generic, reusable metric functions. Shared by every experiment in experiments/. Anything specific to a particular dataset or training recipe stays out of this module. """ from __future__ import annotations import numpy as np import torch import torch.nn.functional as F @torch.no_grad() def accuracy(model, loader, device): model.eval() correct = total = 0 for imgs, labels in loader: imgs = imgs.to(device) labels = labels.squeeze().long().to(device) preds = model(imgs).argmax(1) correct += (preds == labels).sum().item() total += labels.size(0) return correct / total @torch.no_grad() def weight_norm(model): return sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5 @torch.no_grad() def feature_rank(model, loader, device, n=200, hook_module_attr="avgpool"): """ Effective rank of penultimate-layer features = exp(entropy of normalised singular values). Fan et al. 2024 — the most reliable progress measure for grokking. The rank collapses at the transition. `hook_module_attr` is the attribute name on the model whose forward output we treat as the penultimate representation. Defaults to ResNet's avgpool. """ model.eval() feats = [] target = getattr(model, hook_module_attr) hook = target.register_forward_hook( lambda m, i, o: feats.append(o.view(o.size(0), -1).cpu()) ) count = 0 for imgs, _ in loader: model(imgs.to(device)) count += imgs.size(0) if count >= n: break hook.remove() F_mat = torch.cat(feats)[:n] try: _, s, _ = torch.svd(F_mat) s = s / (s.sum() + 1e-10) return torch.exp(-(s * torch.log(s + 1e-10)).sum()).item() except Exception: return float("nan") def irm_penalty(model, envs, device): """ IRMv1 penalty (Arjovsky et al. 2019). For each environment, the squared gradient of the loss w.r.t. a dummy scalar w=1 in the logits. LOW value = invariant predictor = causal features found. HIGH value = environment-specific (spurious) features still in use. `envs` is a list of dicts {"x": tensor, "y": tensor} already on `device`. Returns (mean, var) over environments. """ model.eval() penalties = [] for env in envs: w = torch.tensor(1.0, requires_grad=True, device=device) logits = model(env["x"]) * w loss = F.cross_entropy(logits, env["y"]) grad = torch.autograd.grad(loss, w, create_graph=False)[0] penalties.append(grad.item() ** 2) t = torch.tensor(penalties) return t.mean().item(), t.var().item() @torch.no_grad() def shortcut_ratio(model, loader, device): """ Border-confidence / center-confidence proxy for artifact reliance. Ratio > 1 means the model trusts the borders (where scanner markers, laterality letters, and other artefacts live) more than the center (where actual anatomy is). On CheXpert, replace with GradCAM pointed at known artifact locations vs. anatomical regions. Returns (center_conf_mean, border_conf_mean). """ model.eval() cc, bc = [], [] for imgs, _ in loader: imgs = imgs.to(device) B, C, H, W = imgs.shape hs, he = H // 4, 3 * H // 4 ws, we = W // 4, 3 * W // 4 center = F.interpolate(imgs[:, :, hs:he, ws:we], size=(H, W), mode='bilinear', align_corners=False) border = imgs.clone(); border[:, :, hs:he, ws:we] = 0. cc.append(F.softmax(model(center), 1).max(1).values.mean().item()) bc.append(F.softmax(model(border), 1).max(1).values.mean().item()) return float(np.mean(cc)), float(np.mean(bc))