"""M4 — Representation Ablation: causal intervention on the shortcut subspace.

Pipeline:
  1. Pick a checkpoint: the checkpoint nearest ``peak_ood_epoch`` from
     ``results/summary.json`` by default, else the last checkpoint.
  2. Extract features at avgpool (or ``--layer``; only ``avgpool`` is currently
     accepted by the CLI) for train (H0-H2) and OOD (H4) splits.
  3. Build a *shortcut subspace* W from standardized train features:
       - ``--subspace_method probe``: the weight rows of a
         hospital-classification logistic-regression probe;
       - ``--subspace_method lda`` (default): per-hospital mean directions,
         augmented with principal components of the train features
         (see ``_build_shortcut_subspace``).
  4. Build the orthogonal projector P onto rowspace(W). It is computed via
     SVD and equals W^T (W W^T)^-1 W when W has full row rank. Then define
     ``ablate(h) = h - P h`` in the standardized feature space.
  5. Re-classify OOD images with the *same* trained classifier head, fed:
       (a) raw features h           — baseline OOD accuracy
       (b) ablated features h - Ph  — post-intervention OOD accuracy
  6. Also report:
       (c) hospital (shortcut) probe accuracy on h vs h - Ph
           (NaN when the OOD split does not contain at least two hospitals)
       (d) tumor probe accuracy on h vs h - Ph (sanity: the causal feature
           should survive the intervention)
       (e) the head's tumor-classification accuracy on H4 with raw vs
           ablated features

If the intervention is causal:
  - shortcut probe accuracy: collapses
  - OOD accuracy:            improves (or at least doesn't decay as much)
  - tumor probe accuracy:    largely preserved

Usage
-----
    python -m experiments.mechinterp_m4_ablation \\
        --run_dir experiments/runs/<run_id> \\
        --data_root data/wilds \\
        --layer avgpool \\
        [--epoch 50]                   # default: peak_ood_epoch from summary.json
        [--max_samples 1000]           # train subset; OOD uses max_samples // 2
        [--subspace_method lda|probe]  # default: lda
        [--subspace_dim 32]            # target dim for the lda method
        [--all_epochs]                 # sweep every checkpoint

Output:
    <run_dir>/mechinterp/m4_ablation_<layer>_ep<epoch:05d>.json
    <run_dir>/mechinterp/m4_ablation_<layer>_ep<epoch:05d>.png
    With --all_epochs:
    <run_dir>/mechinterp/m4_ablation_<layer>_trajectory.json / .png
"""
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Dict, Tuple

import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, so figures
# can be written to disk on headless machines.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Repository root: two directories above this file. It is prepended to
# sys.path so the project-local `experiments.*` / `utils.*` imports below
# resolve when this file is run directly.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# Re-use M1 helpers — hooks, model loader, feature extraction, ckpt discovery.
from experiments.mechinterp_m1 import (
    register_hooks,
    extract_features,
    load_model_from_checkpoint,
    find_checkpoints,
)
from utils.camelyon_data import get_camelyon_subsets


# Settings shared by every linear probe in this module. ``multi_class`` is
# intentionally not passed: "auto" is scikit-learn's default, and the
# parameter itself is deprecated since scikit-learn 1.5.
_PROBE_KWARGS = dict(max_iter=500, C=1.0, solver="lbfgs", n_jobs=-1)

# Error raised whenever a head-level intervention is requested at a layer
# other than avgpool (only pooled features can be fed straight to the head).
_HEAD_LAYER_ERROR = (
    "Re-applying the classifier head from intermediate spatial layers "
    "is not yet supported. Use --layer avgpool for the head-level "
    "ablation."
)


def _fit_probe(X: np.ndarray, y: np.ndarray) -> LogisticRegression:
    """Fit and return a logistic-regression probe using ``_PROBE_KWARGS``."""
    return LogisticRegression(**_PROBE_KWARGS).fit(X, y)


def _probe_score(clf: LogisticRegression, X: np.ndarray, y) -> float:
    """Accuracy of ``clf`` restricted to samples whose label it was trained on.

    A probe can never predict a class absent from ``clf.classes_``. Samples
    from such classes (e.g. a held-out OOD hospital) would therefore only pull
    accuracy towards 0 without measuring anything.

    Returns:
        The restricted accuracy, or NaN when fewer than two known classes
        remain (single-class accuracy does not test discrimination).
    """
    y = np.asarray(y)
    known = np.isin(y, clf.classes_)
    if np.unique(y[known]).size < 2:
        return float("nan")
    return float(clf.score(X[known], y[known]))


class _TransformWrapper:
    """Dataset view that applies ``transform`` to the image of every item.

    The wrapped dataset is expected to yield ``(img, label, metadata)``
    triples; label and metadata pass through unchanged.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label, metadata = self.dataset[idx]
        return self.transform(img), label, metadata


def _build_loaders(data_root: str, max_samples: int, seed: int = 42):
    """Return ``(train_loader, ood_loader)`` over seeded random subsets.

    Uses up to ``max_samples`` train items and up to ``max_samples // 2``
    OOD-test items from ``get_camelyon_subsets``. Images are resized to 96x96
    and normalized with the ImageNet mean/std. Loaders are unshuffled
    (batch 128, no workers), so feature order is reproducible.

    Note: this reseeds torch's global RNG with ``seed``.
    """
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_ds, _id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    train_t = _TransformWrapper(train_ds, transform)
    ood_t = _TransformWrapper(ood_test_ds, transform)

    # Both permutations are drawn from the same seeded stream, in this order,
    # so the subsets are reproducible for a given seed.
    torch.manual_seed(seed)
    train_idx = torch.randperm(len(train_t))[:max_samples]
    ood_idx = torch.randperm(len(ood_t))[:max_samples // 2]

    train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128,
                              shuffle=False, num_workers=0)
    ood_loader = DataLoader(Subset(ood_t, ood_idx), batch_size=128,
                            shuffle=False, num_workers=0)
    return train_loader, ood_loader


def _select_epoch(run_dir: Path, requested: int | None) -> Tuple[int, Path]:
    """Resolve the ``(epoch, checkpoint_path)`` to analyze.

    Selection order:
      - ``requested`` given: it must match a checkpoint epoch exactly,
        otherwise ``ValueError`` is raised;
      - else the checkpoint nearest ``peak_ood_epoch`` from
        ``results/summary.json``, when that value is present and > 0;
      - else the last entry returned by ``find_checkpoints``.

    Raises:
        FileNotFoundError: the run has no checkpoints.
        ValueError: ``requested`` is not among the checkpoint epochs.
    """
    ckpts = find_checkpoints(str(run_dir))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints in {run_dir}/checkpoints/")

    if requested is not None:
        for ep, path in ckpts:
            if ep == requested:
                return ep, Path(path)
        raise ValueError(f"Requested epoch {requested} not in checkpoints "
                         f"({[ep for ep, _ in ckpts]})")

    summary_path = run_dir / "results" / "summary.json"
    peak = None
    if summary_path.exists():
        peak = json.loads(summary_path.read_text()).get("peak_ood_epoch", None)
    if peak is not None and peak > 0:
        # Checkpoints are periodic, so the peak epoch itself may not be saved.
        ep, path = min(ckpts, key=lambda c: abs(c[0] - peak))
        return ep, Path(path)

    # NOTE(review): assumes find_checkpoints returns epochs in ascending
    # order — confirm in mechinterp_m1.
    ep, path = ckpts[-1]
    return ep, Path(path)


def _build_projector(W: np.ndarray) -> np.ndarray:
    """Return the (d, d) orthogonal projector onto rowspace(W).

    Args:
        W: (k, d) matrix. A 1-D ``W`` is treated as a single row.

    The projector is built from an SVD, so a rank-deficient ``W`` is handled:
    singular values at or below ``max(shape) * eps * s_max`` are dropped.
    For a full-row-rank ``W`` this equals ``W^T (W W^T)^-1 W``. An empty
    ``W`` yields the zero projector.
    """
    W = np.atleast_2d(np.asarray(W))
    if W.size == 0:
        return np.zeros((W.shape[1], W.shape[1]), dtype=float)
    _, s, Vt = np.linalg.svd(W, full_matrices=False)
    tol = max(W.shape) * np.finfo(s.dtype).eps * s.max()
    basis = Vt[s > tol]  # (k', d) orthonormal rows spanning rowspace(W)
    return basis.T @ basis


def _build_shortcut_subspace(
    X: np.ndarray, hospital_ids: np.ndarray, method: str = "lda",
    subspace_dim: int = 32,
) -> np.ndarray:
    """Return a (k, d) matrix whose row span is the 'shortcut subspace'.

    method='probe'
        ``coef_`` of a hospital logistic-regression probe fit on ``X``. This
        is one row for two hospitals, one row per hospital otherwise.
    method='lda'
        The per-hospital mean offsets from the global mean (n_hospitals rows,
        rank <= n_hospitals - 1). When ``subspace_dim`` exceeds n_hospitals,
        they are augmented with the ``subspace_dim - n_hospitals`` principal
        components of ``X`` whose scores correlate most strongly with hospital
        identity. The score for each PC is its max |Pearson r| against the
        one-hot hospital indicators. Every non-degenerate PC is a candidate,
        so selection ranks by hospital correlation rather than by variance.

    Args:
        X: (N, d) features (standardized by the caller).
        hospital_ids: (N,) hospital labels.
        method: 'lda' or 'probe'.
        subspace_dim: target number of rows for 'lda'.

    Raises:
        ValueError: unknown ``method``, or 'lda' with fewer than two hospitals.
    """
    hospital_ids = np.asarray(hospital_ids)
    if method == "probe":
        return _fit_probe(X, hospital_ids).coef_
    if method != "lda":
        raise ValueError(f"Unknown method: {method}")

    classes = np.unique(hospital_ids)
    if classes.size < 2:
        raise ValueError("Need at least two hospitals to define a shortcut "
                         f"subspace; got {classes.tolist()}")

    global_mean = X.mean(axis=0, keepdims=True)
    between = np.vstack([
        X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean
        for c in classes
    ])  # (n_classes, d)

    n_extra = subspace_dim - between.shape[0]
    if n_extra <= 0:
        return between

    # Candidate directions: every principal component of the centered
    # features with a non-negligible singular value.
    Xc = X - global_mean
    U, s, Vt = np.linalg.svd(Xc, full_matrices=False)
    if s.size == 0:
        return between
    keep = s > max(Xc.shape) * np.finfo(s.dtype).eps * s.max()
    U, s, Vt = U[:, keep], s[keep], Vt[keep]
    scores = U * s  # (N, r) PC scores == Xc @ Vt.T; columns are zero-mean

    # Vectorized Pearson r between each PC score and each hospital indicator.
    # Zero-variance columns get r = 0 instead of NaN.
    one_hot = (hospital_ids[:, None] == classes[None, :]).astype(np.float64)
    one_hot -= one_hot.mean(axis=0, keepdims=True)
    cov = scores.T @ one_hot  # (r, n_classes)
    norms = np.outer(np.linalg.norm(scores, axis=0),
                     np.linalg.norm(one_hot, axis=0))
    corr = np.divide(cov, norms, out=np.zeros_like(cov), where=norms > 0)
    hosp_corr = np.abs(corr).max(axis=1)  # (r,)

    # Keep the n_extra most hospital-correlated PCs. The stable sort makes
    # ties keep variance order.
    order = np.argsort(-hosp_corr, kind="stable")[:n_extra]
    return np.vstack([between, Vt[order]])


def _classifier_logits_from_features(
    model: nn.Module, features: np.ndarray, layer: str, device: str
) -> np.ndarray:
    """Feed (possibly modified) pooled features through the model's head.

    Only ``layer='avgpool'`` is supported: its (N, C) features are passed
    directly to the classifier head. The head is found as
    ``model.get_classifier()``, ``model.fc`` or ``model.classifier``, in that
    order.

    Returns:
        The head's logits as a NumPy array.

    Raises:
        NotImplementedError: ``layer`` is not ``'avgpool'``.
        RuntimeError: no classifier head could be located.
    """
    if layer != "avgpool":
        raise NotImplementedError(_HEAD_LAYER_ERROR)

    if hasattr(model, "get_classifier"):
        head = model.get_classifier()
    elif hasattr(model, "fc"):
        head = model.fc
    elif hasattr(model, "classifier"):
        head = model.classifier
    else:
        raise RuntimeError("Could not locate classifier head on the model.")

    head = head.to(device).eval()
    with torch.no_grad():
        x = torch.as_tensor(features, dtype=torch.float32, device=device)
        logits = head(x).cpu().numpy()
    return logits


def _accuracy(logits: np.ndarray, labels) -> float:
    """Fraction of samples whose predicted class equals ``labels``.

    1-D or single-column logits are thresholded at 0; otherwise argmax.
    ``labels`` is flattened so (N, 1) labels don't broadcast to (N, N).
    """
    logits = np.asarray(logits)
    if logits.ndim == 1 or logits.shape[1] == 1:
        pred = (logits.ravel() > 0).astype(int)
    else:
        pred = logits.argmax(axis=1)
    return float((pred == np.asarray(labels).ravel()).mean())


def run_ablation(
    run_dir: Path,
    data_root: str,
    layer: str = "avgpool",
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    subspace_method: str = "lda",
    subspace_dim: int = 32,
) -> Dict:
    """Ablate the hospital-shortcut subspace for one checkpoint and score it.

    Args:
        run_dir: run directory. It holds ``checkpoints/``, and optionally
            ``config.json`` (``seed``) and ``results/summary.json``
            (``peak_ood_epoch``).
        data_root: dataset root passed to ``get_camelyon_subsets``.
        layer: feature layer; only ``"avgpool"`` is supported.
        epoch: checkpoint epoch; ``None`` selects one via ``_select_epoch``.
        max_samples: train subset size (the OOD subset uses ``max_samples // 2``).
        device: torch device. A CUDA request falls back to CPU when CUDA is
            unavailable.
        subspace_method: ``"lda"`` or ``"probe"`` (see ``_build_shortcut_subspace``).
        subspace_dim: target subspace size for ``"lda"``.

    Returns:
        Dict of metrics.
          - ``hospital_probe_ood_*`` is NaN unless the OOD split contains at
            least two hospitals the probe was trained on.
          - ``hospital_probe_train_ablated`` is the in-sample train-split
            accuracy after ablation. It is always defined, so it carries the
            shortcut-collapse signal when the OOD metric is NaN.

    Raises:
        NotImplementedError: ``layer`` is not ``"avgpool"``.
        KeyError: ``layer`` is missing from the extracted features.
    """
    # Fail fast: the head-level step only supports pooled features, so don't
    # spend minutes extracting features just to raise at the end.
    if layer != "avgpool":
        raise NotImplementedError(_HEAD_LAYER_ERROR)
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print(f" [warn] CUDA not available; using CPU instead of {device!r}.")
        device = "cpu"

    epoch, ckpt_path = _select_epoch(run_dir, epoch)
    print(f"\n M4 — Representation Ablation")
    print(f" run_dir : {run_dir.name}")
    print(f" epoch : {epoch} ({ckpt_path.name})")
    print(f" layer : {layer}")

    # Load model and dataloaders
    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)
    train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)

    # Extract features
    n_ood_max = max_samples // 2
    print(f" Extracting features (up to {max_samples} train / "
          f"{n_ood_max} OOD samples)...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=n_ood_max
    )
    if layer not in feats_train:
        raise KeyError(f"Layer '{layer}' not in extracted features "
                       f"({list(feats_train.keys())})")

    hosp_train, tumor_train = np.asarray(hosp_train), np.asarray(tumor_train)
    hosp_ood, tumor_ood = np.asarray(hosp_ood), np.asarray(tumor_ood)

    X_tr = np.asarray(feats_train[layer])   # (N_tr, D) after flattening
    X_ood = np.asarray(feats_ood[layer])    # (N_ood, D) after flattening
    if X_tr.ndim > 2:
        # Hooked pooling outputs may keep singleton spatial dims,
        # e.g. (N, C, 1, 1); flatten to (N, D).
        X_tr = X_tr.reshape(X_tr.shape[0], -1)
        X_ood = X_ood.reshape(X_ood.shape[0], -1)

    # The probes are scale-sensitive, so they work on standardized features.
    # The head was trained on raw features, so ablated features are mapped
    # back before being re-fed to it.
    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)

    # ──────────── 1. Fit hospital probe + build shortcut subspace
    print(f" Fitting hospital probe on H0/H1/H2 train features...")
    hosp_clf = _fit_probe(X_tr_n, hosp_train)
    hosp_acc_train = float(hosp_clf.score(X_tr_n, hosp_train))

    if subspace_method == "probe":
        # Same estimator settings and same data as the probe just fitted:
        # reuse its weights instead of fitting an identical second probe.
        W = hosp_clf.coef_
    else:
        W = _build_shortcut_subspace(X_tr_n, hosp_train,
                                     method=subspace_method,
                                     subspace_dim=subspace_dim)
    P = _build_projector(W)  # (D, D)
    # P is an orthogonal projector, so its trace equals its rank.
    rank_subspace = int(round(float(np.trace(P))))
    print(f" Shortcut subspace: dim={rank_subspace} method={subspace_method} "
          f"(probe train acc {hosp_acc_train:.3f})")

    # ──────────── 2. Build ablated versions of features
    def ablate_norm(X_n):
        """Remove the shortcut-subspace component from standardized features."""
        return X_n - X_n @ P.T

    X_tr_ablated_n = ablate_norm(X_tr_n)
    X_ood_ablated_n = ablate_norm(X_ood_n)
    X_ood_ablated = scaler.inverse_transform(X_ood_ablated_n)

    # Sanity probe: tumor information should survive the ablation.
    print(f" Re-fitting tumor probe on train features...")
    tumor_clf = _fit_probe(X_tr_n, tumor_train)
    tumor_acc_train = float(tumor_clf.score(X_tr_n, tumor_train))

    # Hospital probe on raw vs ablated features. The train-split numbers are
    # in-sample (the probe was fit on X_tr_n) but stay defined when the OOD
    # split has a single, unseen hospital.
    hosp_acc_train_ablated = float(hosp_clf.score(X_tr_ablated_n, hosp_train))
    hosp_acc_ood_raw = _probe_score(hosp_clf, X_ood_n, hosp_ood)
    hosp_acc_ood_ablated = _probe_score(hosp_clf, X_ood_ablated_n, hosp_ood)
    tumor_acc_ood_raw = float(tumor_clf.score(X_ood_n, tumor_ood))
    tumor_acc_ood_ablated = float(tumor_clf.score(X_ood_ablated_n, tumor_ood))

    # ──────────── 3. Head-level OOD classification accuracy
    print(f" Re-classifying OOD with model head (raw vs ablated features)...")
    logits_raw = _classifier_logits_from_features(model, X_ood, layer, device)
    logits_ablated = _classifier_logits_from_features(model, X_ood_ablated, layer, device)
    head_acc_raw = _accuracy(logits_raw, tumor_ood)
    head_acc_ablated = _accuracy(logits_ablated, tumor_ood)

    # ──────────── 4. Pack + report
    result = {
        "run_id": run_dir.name,
        "epoch": epoch,
        "layer": layer,
        "max_samples": max_samples,
        "subspace_method": subspace_method,
        "shortcut_subspace_dim": rank_subspace,
        "hospital_probe_train_acc": hosp_acc_train,
        "hospital_probe_train_ablated": hosp_acc_train_ablated,
        "tumor_probe_train_acc": tumor_acc_train,
        "hospital_probe_ood_raw": hosp_acc_ood_raw,
        "hospital_probe_ood_ablated": hosp_acc_ood_ablated,
        "tumor_probe_ood_raw": tumor_acc_ood_raw,
        "tumor_probe_ood_ablated": tumor_acc_ood_ablated,
        "head_ood_acc_raw": head_acc_raw,
        "head_ood_acc_ablated": head_acc_ablated,
        "intervention_effect": {
            "shortcut_collapse": hosp_acc_ood_raw - hosp_acc_ood_ablated,
            "shortcut_collapse_train": hosp_acc_train - hosp_acc_train_ablated,
            "ood_improvement": head_acc_ablated - head_acc_raw,
            "tumor_preservation": tumor_acc_ood_ablated - tumor_acc_ood_raw,
        },
    }
    effect = result["intervention_effect"]
    print(f"\n RESULTS")
    print(f" hospital probe (train): {hosp_acc_train:.3f} → {hosp_acc_train_ablated:.3f} "
          f"(Δ {effect['shortcut_collapse_train']:+.3f})")
    print(f" hospital probe (OOD): {hosp_acc_ood_raw:.3f} → {hosp_acc_ood_ablated:.3f} "
          f"(Δ {effect['shortcut_collapse']:+.3f})")
    print(f" tumor probe (OOD) : {tumor_acc_ood_raw:.3f} → {tumor_acc_ood_ablated:.3f} "
          f"(Δ {effect['tumor_preservation']:+.3f})")
    print(f" head OOD acc : {head_acc_raw:.3f} → {head_acc_ablated:.3f} "
          f"(Δ {effect['ood_improvement']:+.3f})")
    return result


def plot_ablation(result: Dict, out_path: Path):
    """Bar-plot raw vs shortcut-ablated accuracies for one ``run_ablation`` result.

    The hospital-probe bars use the OOD metrics when both are finite.
    Otherwise they fall back to the (in-sample) train metrics, labelled as
    such, so that panel is not left empty. Non-finite bars get no value label.
    """
    hosp_raw = result["hospital_probe_ood_raw"]
    hosp_abl = result["hospital_probe_ood_ablated"]
    hosp_label = "Hospital probe\n(↓ = causal effect)"
    if (not (np.isfinite(hosp_raw) and np.isfinite(hosp_abl))
            and "hospital_probe_train_ablated" in result):
        hosp_raw = result["hospital_probe_train_acc"]
        hosp_abl = result["hospital_probe_train_ablated"]
        hosp_label = "Hospital probe (train)\n(↓ = causal effect)"

    labels = [hosp_label,
              "Tumor probe\n(stable = good)",
              "Head OOD acc\n(↑ = causal effect)"]
    raws = [hosp_raw, result["tumor_probe_ood_raw"], result["head_ood_acc_raw"]]
    ablateds = [hosp_abl, result["tumor_probe_ood_ablated"],
                result["head_ood_acc_ablated"]]

    fig, ax = plt.subplots(figsize=(9, 5))
    x = np.arange(len(labels))
    w = 0.35
    b1 = ax.bar(x - w / 2, raws, w, label="raw features", color="#444")
    b2 = ax.bar(x + w / 2, ablateds, w, label="shortcut-ablated", color="#c33")
    for bars in (b1, b2):
        for b in bars:
            height = b.get_height()
            if not np.isfinite(height):
                continue  # undefined metric: no bar, no label
            ax.text(b.get_x() + b.get_width() / 2, height + 0.005,
                    f"{height:.3f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=9)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Accuracy")
    ax.set_title(f"M4 — Causal Ablation of Shortcut Subspace\n"
                 f"{result['run_id']} • ep{result['epoch']} • layer={result['layer']} "
                 f"• subspace dim={result['shortcut_subspace_dim']}",
                 fontsize=10, fontweight="bold")
    ax.legend(loc="upper right")
    ax.grid(alpha=0.3, axis="y")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


def main():
    """CLI entry point: run one checkpoint (default) or sweep all of them."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_dir", required=True)
    parser.add_argument("--data_root", default="data/wilds")
    parser.add_argument("--layer", default="avgpool", choices=["avgpool"])  # head-level intervention only at avgpool
    parser.add_argument("--epoch", type=int, default=None,
                        help="Specific checkpoint epoch; default = peak_ood_epoch from summary.json")
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--device", default="cuda",
                        help="torch device; falls back to cpu if CUDA is unavailable")
    parser.add_argument("--subspace_method", default="lda", choices=["lda", "probe"],
                        help="lda = LDA-style between-class + hospital-correlated PCs; "
                             "probe = LR probe row-space (small, often only 2-D)")
    parser.add_argument("--subspace_dim", type=int, default=32,
                        help="Target subspace dim for lda method")
    parser.add_argument("--all_epochs", action="store_true",
                        help="Sweep across all periodic checkpoints")
    args = parser.parse_args()

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    common = dict(
        run_dir=run_dir, data_root=args.data_root, layer=args.layer,
        max_samples=args.max_samples, device=args.device,
        subspace_method=args.subspace_method, subspace_dim=args.subspace_dim,
    )

    if args.all_epochs:
        # Sweep every periodic checkpoint and build a trajectory. final.pt may
        # share an epoch with the last ep*.pt, so de-duplicate epochs while
        # keeping first-occurrence order.
        epochs = list(dict.fromkeys(ep for ep, _ in find_checkpoints(str(run_dir))))
        traj = []
        for ep in epochs:
            try:
                traj.append(run_ablation(epoch=ep, **common))
            except Exception as e:
                # Best-effort sweep: one bad checkpoint must not abort the rest.
                print(f" [skip ep{ep}] {type(e).__name__}: {e}")
        stem = f"m4_ablation_{args.layer}_trajectory"
        json_path = out_dir / f"{stem}.json"
        json_path.write_text(json.dumps(traj, indent=2))
        print(f"\n → {json_path}")
        if traj:
            png_path = out_dir / f"{stem}.png"
            plot_trajectory(traj, png_path)
            print(f" → {png_path}")
        else:
            print(" [warn] no checkpoint succeeded; trajectory plot skipped.")
        return

    result = run_ablation(epoch=args.epoch, **common)
    # Build file names explicitly: Path.with_suffix would truncate layer
    # names containing a '.'.
    stem = f"m4_ablation_{args.layer}_ep{result['epoch']:05d}"
    json_path = out_dir / f"{stem}.json"
    png_path = out_dir / f"{stem}.png"
    json_path.write_text(json.dumps(result, indent=2))
    plot_ablation(result, png_path)
    print(f"\n → {json_path}")
    print(f" → {png_path}")


def plot_trajectory(traj, out_path: Path):
    """Plot the intervention effect across training epochs.

    ``traj`` is a list of ``run_ablation`` result dicts, in sweep order.
    Panel A shows head OOD accuracy (raw vs ablated). Panel B shows tumor-probe
    OOD accuracy (raw vs ablated).
    """
    eps = [r["epoch"] for r in traj]
    head_raw = [r["head_ood_acc_raw"] for r in traj]
    head_abl = [r["head_ood_acc_ablated"] for r in traj]
    tum_raw = [r["tumor_probe_ood_raw"] for r in traj]
    tum_abl = [r["tumor_probe_ood_ablated"] for r in traj]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Panel A: head OOD acc raw vs ablated
    ax = axes[0]
    ax.plot(eps, head_raw, "k-o", lw=2, label="raw features")
    ax.plot(eps, head_abl, "r-s", lw=2, label="shortcut-ablated features")
    ax.fill_between(eps, head_raw, head_abl,
                    where=[a > b for a, b in zip(head_abl, head_raw)],
                    color="seagreen", alpha=0.3, label="ablation helps")
    ax.fill_between(eps, head_raw, head_abl,
                    where=[a < b for a, b in zip(head_abl, head_raw)],
                    color="salmon", alpha=0.3, label="ablation hurts")
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("OOD (H4) head accuracy")
    ax.set_title("Head OOD accuracy: raw vs shortcut-ablated", fontweight="bold")
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)

    # Panel B: tumor probe survival
    ax = axes[1]
    ax.plot(eps, tum_raw, "k-o", lw=2, label="raw features")
    ax.plot(eps, tum_abl, "g-s", lw=2, label="shortcut-ablated features")
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Tumor probe OOD accuracy")
    ax.set_title("Tumor probe survival under ablation\n(stable line = causal feature preserved)",
                 fontweight="bold")
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    ax.set_ylim(0.4, 1.0)

    rid = traj[0]["run_id"] if traj else "?"
    layer = traj[0]["layer"] if traj else "?"
    fig.suptitle(f"M4 — Causal Ablation Trajectory: {rid} • layer={layer}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


if __name__ == "__main__":
    main()