"""M6 — Neuron-level Ablation (the textbook reviewer-asked intervention). Pipeline: 1. At a chosen checkpoint (default: peak_ood_epoch), extract avgpool features for train (H0-H2) and OOD (H4) splits. 2. Score each of the 512 avgpool channels by *how predictive its activation is of hospital ID*: we use a one-vs-rest logistic-regression coefficient per channel × class as the per-neuron shortcut score: score_c = max_h |β_{h,c}| (β = coefficients of LR fit per channel) ↑ score_c → channel c is more strongly stain-shortcut-aligned. 3. Sweep top-K ∈ {0, 8, 16, 32, 64, 128} ablated neurons (zero out their activations) and measure: - head OOD acc (raw vs ablated) - hospital-probe acc on raw vs ablated features - tumor-probe acc on raw vs ablated features 4. Strong mechanistic claim: - hospital-probe acc collapses sharply with K (these neurons are carrying hospital info) - head OOD acc *improves* (or at least preserves) at small K (the model was using shortcut neurons to harm OOD) - tumor-probe acc stays flat (causal info is distributed elsewhere) Usage ----- python -m experiments.mechinterp_m6_neuron_ablation \\ --run_dir experiments/runs/ \\ --data_root data/wilds \\ [--epoch 50] [--max_samples 1000] \\ [--ks "0,4,8,16,32,64,128,256"] """ from __future__ import annotations import argparse import json from pathlib import Path from typing import Dict, List import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import torch from sklearn.linear_model import LogisticRegression from sklearn.preprocessing import StandardScaler from torch.utils.data import DataLoader, Subset from torchvision import transforms from experiments.mechinterp_m1 import ( register_hooks, extract_features, load_model_from_checkpoint, ) from experiments.mechinterp_m4_ablation import ( _select_epoch, _TransformWrapper, _classifier_logits_from_features, _accuracy, ) from utils.camelyon_data import get_camelyon_subsets def _build_loaders_with_id(data_root: str, max_samples: int, seed: int = 42): """Like M4's _build_loaders but also returns an ID validation loader so we can track ID acc and compute the OOD/ID degradation ratio.""" transform = transforms.Compose([ transforms.Resize((96, 96)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets( root_dir=data_root, download=False ) train_t = _TransformWrapper(train_ds, transform) id_t = _TransformWrapper(id_val_ds, transform) ood_t = _TransformWrapper(ood_test_ds, transform) torch.manual_seed(seed) train_idx = torch.randperm(len(train_t))[:max_samples] id_idx = torch.randperm(len(id_t))[:max_samples // 2] ood_idx = torch.randperm(len(ood_t))[:max_samples // 2] train_loader = DataLoader(Subset(train_t, train_idx), batch_size=128, shuffle=False, num_workers=0) id_loader = DataLoader(Subset(id_t, id_idx), batch_size=128, shuffle=False, num_workers=0) ood_loader = DataLoader(Subset(ood_t, ood_idx), batch_size=128, shuffle=False, num_workers=0) return train_loader, id_loader, ood_loader def _per_neuron_shortcut_scores(X_n: np.ndarray, hosp: np.ndarray) -> np.ndarray: """Return a (D,) array — score per channel c, larger = more hospital-predictive. Uses a 1-feature-at-a-time log-reg fit's |coef| would be dominated by feature scale; instead we fit a single multiclass LR over all features and use the L2 norm of (β_{:,c}) — the column norm of the LR coefficient matrix — as channel c's hospital-discrimination score. """ clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", multi_class="auto", n_jobs=-1).fit(X_n, hosp) W = clf.coef_ # (n_classes, D) # column norms — large means many class-discriminations rely on this neuron return np.linalg.norm(W, axis=0) # (D,) def _ablate_and_eval( X_n, mask, scaler, head_target, model, layer, device, hosp_clf, tumor_clf, hosp_target, tumor_target, ): """Apply mask to normalized features, unscale, evaluate everything.""" X_ablated_n = X_n * mask[None, :] X_ablated = scaler.inverse_transform(X_ablated_n) logits = _classifier_logits_from_features(model, X_ablated, layer, device) head_acc = _accuracy(logits, head_target) hosp_acc = hosp_clf.score(X_ablated_n, hosp_target) if hosp_clf is not None and len(np.unique(hosp_target)) > 1 else float("nan") tumor_acc = tumor_clf.score(X_ablated_n, tumor_target) return head_acc, hosp_acc, tumor_acc def run_neuron_ablation( run_dir: Path, data_root: str, epoch: int | None = None, max_samples: int = 1000, device: str = "cuda", ks: List[int] = None, n_random_samples: int = 5, include_morphology: bool = True, include_id: bool = True, ) -> Dict: if ks is None: # Dose-response curve emphasizing small K (per reviewer guidance) ks = [0, 4, 8, 16, 32, 64, 128, 256] epoch, ckpt_path = _select_epoch(run_dir, epoch) print(f"\n M6 — Neuron Ablation (with random + morphology controls)") print(f" run_dir : {run_dir.name}") print(f" epoch : {epoch} ({ckpt_path.name})") print(f" ks : {ks}") print(f" random ablation: {n_random_samples} samplings per K") model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device) model.eval() register_hooks(model) cfg_path = run_dir / "config.json" seed = 42 if cfg_path.exists(): seed = json.loads(cfg_path.read_text()).get("seed", 42) if include_id: train_loader, id_loader, ood_loader = _build_loaders_with_id(data_root, max_samples, seed=seed) else: from experiments.mechinterp_m4_ablation import _build_loaders as _bl train_loader, ood_loader = _bl(data_root, max_samples, seed=seed) id_loader = None print(f" Extracting features (train + id + ood)...") feats_train, hosp_train, tumor_train = extract_features( model, train_loader, device, max_samples=max_samples ) feats_ood, hosp_ood, tumor_ood = extract_features( model, ood_loader, device, max_samples=max_samples // 2 ) feats_id, hosp_id, tumor_id = (None, None, None) if id_loader is not None: feats_id, hosp_id, tumor_id = extract_features( model, id_loader, device, max_samples=max_samples // 2 ) layer = "avgpool" def _to_2d(arr): a = np.asarray(arr); return a.reshape(a.shape[0], -1) X_tr = _to_2d(feats_train[layer]) X_ood = _to_2d(feats_ood[layer]) X_id = _to_2d(feats_id[layer]) if feats_id is not None else None hosp_train = np.asarray(hosp_train); hosp_ood = np.asarray(hosp_ood) tumor_train = np.asarray(tumor_train); tumor_ood = np.asarray(tumor_ood) if X_id is not None: hosp_id = np.asarray(hosp_id); tumor_id = np.asarray(tumor_id) scaler = StandardScaler().fit(X_tr) X_tr_n = scaler.transform(X_tr) X_ood_n = scaler.transform(X_ood) X_id_n = scaler.transform(X_id) if X_id is not None else None # 1. Per-neuron scores: shortcut (hospital) and morphology (tumor) print(f" Scoring {X_tr.shape[1]} avgpool channels...") shortcut_scores = _per_neuron_shortcut_scores(X_tr_n, hosp_train) morphology_scores = _per_neuron_shortcut_scores(X_tr_n, tumor_train) if include_morphology else None rank_shortcut = np.argsort(-shortcut_scores) rank_morphology = np.argsort(-morphology_scores) if morphology_scores is not None else None hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train) tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train) rng = np.random.default_rng(seed) D = X_tr.shape[1] sweep = [] for k in ks: row = {"k": int(k)} # Mask helpers def make_mask(indices): m = np.ones(D) if k > 0: m[indices[:k]] = 0.0 return m # ── A: top-K SHORTCUT neurons (the targeted ablation) ── mask_s = make_mask(rank_shortcut) h_ood, hp_ood, tp_ood = _ablate_and_eval( X_ood_n, mask_s, scaler, tumor_ood, model, layer, device, hosp_clf, tumor_clf, hosp_ood, tumor_ood, ) row["shortcut_head_ood"] = float(h_ood) row["shortcut_hosp_probe"] = float(hp_ood) row["shortcut_tumor_probe"] = float(tp_ood) if X_id_n is not None: h_id, _, _ = _ablate_and_eval( X_id_n, mask_s, scaler, tumor_id, model, layer, device, None, tumor_clf, hosp_id, tumor_id, ) row["shortcut_head_id"] = float(h_id) # ── B: top-K MORPHOLOGY neurons (control: ablate the causal neurons) ── if include_morphology and rank_morphology is not None: mask_m = make_mask(rank_morphology) h_ood_m, _, _ = _ablate_and_eval( X_ood_n, mask_m, scaler, tumor_ood, model, layer, device, None, tumor_clf, hosp_ood, tumor_ood, ) row["morphology_head_ood"] = float(h_ood_m) if X_id_n is not None: h_id_m, _, _ = _ablate_and_eval( X_id_n, mask_m, scaler, tumor_id, model, layer, device, None, tumor_clf, hosp_id, tumor_id, ) row["morphology_head_id"] = float(h_id_m) # ── C: K RANDOM neurons (control: damage uniformly) ── if k > 0: r_oods, r_ids = [], [] for s_ in range(n_random_samples): idx = rng.permutation(D)[:k] m = np.ones(D); m[idx] = 0.0 h_ood_r, _, _ = _ablate_and_eval( X_ood_n, m, scaler, tumor_ood, model, layer, device, None, tumor_clf, hosp_ood, tumor_ood, ) r_oods.append(h_ood_r) if X_id_n is not None: h_id_r, _, _ = _ablate_and_eval( X_id_n, m, scaler, tumor_id, model, layer, device, None, tumor_clf, hosp_id, tumor_id, ) r_ids.append(h_id_r) row["random_head_ood_mean"] = float(np.mean(r_oods)) row["random_head_ood_std"] = float(np.std(r_oods)) if r_ids: row["random_head_id_mean"] = float(np.mean(r_ids)) row["random_head_id_std"] = float(np.std(r_ids)) else: row["random_head_ood_mean"] = row["shortcut_head_ood"] # K=0 same as baseline row["random_head_ood_std"] = 0.0 if X_id_n is not None: row["random_head_id_mean"] = row.get("shortcut_head_id", float("nan")) row["random_head_id_std"] = 0.0 sweep.append(row) # Concise log line print(f" K={k:>4} shortcut={row['shortcut_head_ood']:.3f} " f"random={row.get('random_head_ood_mean', float('nan')):.3f}±" f"{row.get('random_head_ood_std', 0):.3f} " + (f"morphology={row.get('morphology_head_ood', float('nan')):.3f}" if include_morphology else "")) return { "run_id": run_dir.name, "epoch": epoch, "layer": layer, "max_samples": max_samples, "feature_dim": int(X_tr.shape[1]), "shortcut_scores_top10": [int(i) for i in rank_shortcut[:10]], "morphology_scores_top10": ([int(i) for i in rank_morphology[:10]] if rank_morphology is not None else []), "n_random_samples": n_random_samples, "include_id": include_id, "include_morphology": include_morphology, "sweep": sweep, } def plot_neuron_ablation(result: Dict, out_path: Path): sweep = result["sweep"] ks = [r["k"] for r in sweep] has_id = result.get("include_id", False) has_morph = result.get("include_morphology", False) fig, axes = plt.subplots(1, 2 if has_id else 1, figsize=(13, 5)) if has_id else \ plt.subplots(1, 1, figsize=(8, 5)) if not has_id: axes = [axes] # Panel A — Head OOD: shortcut vs random (vs morphology) ax = axes[0] shortcut_ood = [r.get("shortcut_head_ood") for r in sweep] random_ood_mu = [r.get("random_head_ood_mean") for r in sweep] random_ood_sd = [r.get("random_head_ood_std", 0) for r in sweep] morphology_ood = [r.get("morphology_head_ood") for r in sweep] if has_morph else None ax.plot(ks, shortcut_ood, "r-o", lw=2.2, ms=7, label="top-K shortcut neurons (targeted)") ax.plot(ks, random_ood_mu, "k-s", lw=1.8, ms=6, label="K random neurons (control)") ax.fill_between(ks, [m - s for m, s in zip(random_ood_mu, random_ood_sd)], [m + s for m, s in zip(random_ood_mu, random_ood_sd)], color="black", alpha=0.15) if has_morph and morphology_ood is not None: ax.plot(ks, morphology_ood, "g-^", lw=1.8, ms=6, label="top-K morphology neurons (control)") base = shortcut_ood[0] ax.axhline(base, color="gray", ls=":", lw=1, alpha=0.5, label=f"K=0 baseline ({base:.3f})") ax.set_xlabel("K (neurons zeroed at avgpool)") ax.set_ylabel("Head OOD (H4) accuracy") ax.set_xscale("symlog", linthresh=4) ax.set_title("Targeted vs random ablation — OOD effect\n" "(separation = shortcut neurons selectively hurt OOD)", fontweight="bold", fontsize=10) ax.legend(loc="best", fontsize=8); ax.grid(alpha=0.3) # Panel B — ID/OOD tradeoff if has_id: ax = axes[1] shortcut_id = [r.get("shortcut_head_id") for r in sweep] random_id_mu = [r.get("random_head_id_mean") for r in sweep] random_id_sd = [r.get("random_head_id_std", 0) for r in sweep] ax.plot(ks, shortcut_id, "r--o", lw=2, ms=7, alpha=0.85, label="ID (shortcut ablation)") ax.plot(ks, shortcut_ood, "r-o", lw=2, ms=7, label="OOD (shortcut ablation)") ax.plot(ks, random_id_mu, "k--s", lw=1.6, ms=5, alpha=0.7, label="ID (random ablation)") ax.plot(ks, random_ood_mu, "k-s", lw=1.6, ms=5, alpha=0.7, label="OOD (random ablation)") ax.set_xlabel("K (neurons zeroed at avgpool)") ax.set_ylabel("Head accuracy") ax.set_xscale("symlog", linthresh=4) ax.set_title("ID vs OOD degradation tradeoff\n" "(targeted: OOD steady or ↑ while ID slowly ↓ = good)", fontweight="bold", fontsize=10) ax.legend(fontsize=8, loc="best"); ax.grid(alpha=0.3) fig.suptitle(f"M6 — Targeted Neuron Ablation vs Random Control: {result['run_id']} " f"• ep{result['epoch']}", fontsize=11, fontweight="bold") plt.tight_layout() fig.savefig(out_path, dpi=180, bbox_inches="tight") plt.close(fig) def main(): p = argparse.ArgumentParser() p.add_argument("--run_dir", required=True) p.add_argument("--data_root", default="data/wilds") p.add_argument("--epoch", type=int, default=None) p.add_argument("--max_samples", type=int, default=1000) p.add_argument("--device", default="cuda") p.add_argument("--ks", default=None, help="Comma-separated K values, e.g. '0,4,8,16,32,64,128,256'") p.add_argument("--n_random_samples", type=int, default=5, help="Random ablation: averages over this many random K-subsets") p.add_argument("--no_morphology", action="store_true", help="Skip the morphology-targeted ablation control") p.add_argument("--no_id", action="store_true", help="Skip ID accuracy evaluation (faster but loses ID/OOD ratio)") args = p.parse_args() ks = None if args.ks is not None: ks = [int(x) for x in args.ks.split(",")] run_dir = Path(args.run_dir) out_dir = run_dir / "mechinterp" out_dir.mkdir(parents=True, exist_ok=True) result = run_neuron_ablation( run_dir=run_dir, data_root=args.data_root, epoch=args.epoch, max_samples=args.max_samples, device=args.device, ks=ks, n_random_samples=args.n_random_samples, include_morphology=not args.no_morphology, include_id=not args.no_id, ) base = out_dir / f"m6_neuron_ablation_ep{result['epoch']:05d}" base.with_suffix(".json").write_text(json.dumps(result, indent=2)) plot_neuron_ablation(result, base.with_suffix(".png")) print(f"\n → {base.with_suffix('.json')}") print(f" → {base.with_suffix('.png')}") if __name__ == "__main__": main()