---
license: cc-by-4.0
tags:
- grokking
- grokfast
- mechanistic-interpretability
- out-of-distribution
- medical-imaging
- camelyon17
- resnet18
- activation-steering
library_name: pytorch
pipeline_tag: image-classification
---

# CausalGrok: Grokking-Favorable Training Artifact Archive

Complete artifact bundle for the project **"Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training"**.

This Hugging Face repository is the **full preservation archive**: every model checkpoint, per-run training artifact, mechanistic-interpretability output, log, figure, and the paper source. It mirrors the heavy artifacts that do not fit in the source repository.

- **Code (browsable)**: https://github.com/nileshsarkar-ai/CausalGrok
- **This archive (heavy artifacts)**: 240 checkpoints (~10 GB), run JSONs, logs, figures

---

## What this work studies

A small-sample study on Camelyon17 (WILDS hospital-shift out-of-distribution (OOD) pathology classification). We train ResNet-18 under two regimes and ask whether **grokking-favorable optimization** (high weight decay, expanded init scale, Grokfast gradient-EMA filtering) changes the internal shortcut representation, even when OOD accuracy does not improve.

- **Empirical result**: across 14 runs, every model *ungroks*: OOD accuracy peaks early and decays during training; no delayed-generalization (grokking) transition occurs.
- **Interventional result**: activation steering along the dominant between-hospital direction `v_s` gives a monotonic OOD response in 4/5 grokking-favorable seeds vs 1/3 standard seeds (Mann-Whitney p=0.071, non-significant at this sample size). We read this as *shortcut concentration, not elimination* under heavy regularization.

---

## The grokking-favorable recipe (code)

Two training configurations, both pure cross-entropy + AdamW, 3000 epochs. They differ along three axes (weight decay, init scale, and Grokfast filtering), which is a deliberate, disclosed confound.

```python
# code/experiments/causalgrok_camelyon_v2.py: get_config()
def get_config(condition):
    base = dict(seed=42, n_train=300, batch_size=32, img_size=96,
                n_classes=2, log_every=50, device="cuda")
    if condition == "standard":
        base.update(condition="standard", lr=1e-3, weight_decay=1e-4,
                    n_epochs=3000, init_scale=1.0, use_grokfast=False)
    elif condition == "grokking":
        base.update(condition="grokking", lr=1e-3, weight_decay=5e-3,
                    n_epochs=3000, init_scale=4.0, use_grokfast=True,
                    grokfast_alpha=0.98, grokfast_lamb=2.0)
    return base
```

**Init-scale rescaling** (grokking-favorable only; every multi-dim weight tensor is scaled 4× at init):

```python
if cfg["init_scale"] != 1.0:
    for name, p in model.named_parameters():
        if "weight" in name and p.dim() > 1:
            p.data *= cfg["init_scale"]
```

**Grokfast EMA**: amplifies the slow-varying gradient component (Lee et al. 2024, arXiv:2405.20233). Applied after `loss.backward()`, before `optimizer.step()`:

```python
# code/utils/grokfast.py: gradfilter_ema()
for name, p in model.named_parameters():
    if p.requires_grad and p.grad is not None:
        if name not in grads_ema:
            grads_ema[name] = p.grad.data.detach().clone()
        else:
            grads_ema[name] = grads_ema[name] * alpha + p.grad.data * (1 - alpha)  # alpha=0.98
        p.grad.data = p.grad.data + grads_ema[name] * lamb  # lamb=2.0
```

**Training loss** is cross-entropy only. An IRM-style invariance penalty (Arjovsky et al. 2019) is computed every epoch across training-hospital environments as a **diagnostic only** (logged, never added to the loss):

```python
criterion = nn.CrossEntropyLoss()
logits = model(imgs)
loss = criterion(logits, labels)  # pure CE; irm_weight = 0.0 for every reported run
loss.backward()
if cfg["use_grokfast"]:
    grads_ema = gradfilter_ema(model, grads_ema, alpha=0.98, lamb=2.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

Checkpoints are saved every 200 epochs (15 periodic + 1 `final.pt` per run).

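To make the effect of the Grokfast EMA filter from this section concrete, here is a minimal pure-Python sketch on scalar gradients. It is an illustration only, not the repository's implementation: the function name `ema_filter` and the two toy gradient sequences are assumptions introduced here, while `alpha=0.98` and `lamb=2.0` are the values quoted in the recipe. A constant gradient keeps the EMA equal to the gradient itself, so the filtered gradient is scaled by `1 + lamb`; a sign-alternating gradient largely cancels inside the EMA and passes through close to unchanged.

```python
def ema_filter(grads, alpha=0.98, lamb=2.0):
    """Filter a sequence of scalar gradients with the EMA rule.

    ema_0 = g_0, ema_t = alpha * ema_{t-1} + (1 - alpha) * g_t,
    and the returned value at each step is g_t + lamb * ema_t.
    (Hypothetical helper for illustration; the real code works per parameter tensor.)
    """
    filtered = []
    ema = None
    for g in grads:
        # Seed the EMA with the first raw gradient, then update it recursively.
        ema = g if ema is None else alpha * ema + (1 - alpha) * g
        filtered.append(g + lamb * ema)
    return filtered


# Slow component: a constant gradient of 1.0 for 50 steps.
slow = ema_filter([1.0] * 50)

# Fast component: a gradient that flips sign every step for 400 steps.
fast = ema_filter([1.0 if t % 2 == 0 else -1.0 for t in range(400)])

# Gain relative to a raw gradient of magnitude 1.0 at the last step.
slow_gain = abs(slow[-1])
fast_gain = abs(fast[-1])
print(f"slow gain: {slow_gain:.3f}, fast gain: {fast_gain:.3f}")
```

In the training loop quoted above, the filtered gradient is then clipped to max-norm 1.0, so the amplification changes the direction of the update more than its overall size.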
| Hyperparameter | Standard | Grokking-favorable |
| --- | --- | --- |
| Optimizer | AdamW | AdamW |
| Learning rate | 1e-3 | 1e-3 |
| Weight decay | 1e-4 | 5e-3 (50×) |
| Init scale | 1.0 | 4.0 |
| Grokfast EMA | off | on (alpha 0.98, lamb 2.0) |
| Grad clip (max-norm) | 1.0 | 1.0 |
| Batch size | 32 | 32 |
| Epochs | 3000 | 3000 |
| IRM weight in loss | 0.0 (diagnostic) | 0.0 (diagnostic) |

Model: ResNet-18 (`timm`, no ImageNet pretraining), 96×96 input, 2-class head, 11,177,538 parameters.

---

## Mechanistic-interpretability suite (code)

Four probes on `avgpool` features (`D=512`); run on the saved checkpoints.

- **M1: layer-wise linear probing** (`code/experiments/mechinterp_m1.py`): logistic-regression hospital and tumor probes at six ResNet stages.
- **M4: subspace ablation** (`code/experiments/mechinterp_m4_ablation.py`): project features orthogonal to a ~35-dim LDA-style hospital subspace, re-classify with the original head.
- **M5: activation steering** (`code/experiments/mechinterp_m5_steering.py`): steer `h' = h + alpha · sigma · v_s` along the dominant between-hospital direction, `alpha ∈ [-3, +3]`.
- **M6: targeted neuron ablation** (`code/experiments/mechinterp_m6_neuron_ablation.py`): zero top-K hospital-discriminating channels vs random-K and morphology-K controls, `K ∈ {0,4,8,16,32,64,128,256}`.

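The steering rule `h' = h + alpha · sigma · v_s` used by M5 can be sketched in a few lines of pure Python. This is a toy under stated assumptions, not the code in `mechinterp_m5_steering.py`: it uses only two hospitals, takes `v_s` to be the normalized difference of their feature means, and takes `sigma` to be the standard deviation of the features projected onto `v_s`. The README does not define `sigma` or how the dominant direction is extracted, so both choices, the 3-dim toy features, and all helper names are hypothetical.

```python
import math


def mean_vec(rows):
    """Column-wise mean of a list of equal-length feature vectors."""
    return [sum(col) / len(rows) for col in zip(*rows)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def steering_direction(feats_a, feats_b):
    """Unit vector from hospital A's mean to hospital B's mean (two-hospital assumption)."""
    diff = [b - a for a, b in zip(mean_vec(feats_a), mean_vec(feats_b))]
    norm = math.sqrt(dot(diff, diff))
    return [d / norm for d in diff]


def projection_std(feats, v):
    """Std of the features projected onto v (the assumed meaning of sigma)."""
    proj = [dot(h, v) for h in feats]
    mu = sum(proj) / len(proj)
    return math.sqrt(sum((p - mu) ** 2 for p in proj) / len(proj))


def steer(h, v, sigma, alpha):
    """Apply h' = h + alpha * sigma * v to one feature vector."""
    return [x + alpha * sigma * vi for x, vi in zip(h, v)]


# Toy 3-dim features standing in for the 512-dim avgpool features.
hospital_a = [[0.0, 1.0, 0.2], [0.2, 0.8, 0.1], [0.1, 1.1, 0.3]]
hospital_b = [[1.0, 1.0, 0.2], [1.2, 0.9, 0.1], [0.9, 1.0, 0.3]]

v_s = steering_direction(hospital_a, hospital_b)
sigma = projection_std(hospital_a + hospital_b, v_s)

# Sweep alpha over a few values in [-3, +3] for a single feature vector.
h = hospital_a[0]
sweep = {alpha: steer(h, v_s, sigma, alpha) for alpha in (-3, -1, 0, 1, 3)}
```

In the actual probe, each steered feature vector would then be passed through the original classification head, and OOD accuracy would be recorded as a function of `alpha`.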
---

## Repository layout

```
runs//
├── config.json                                     # full hyperparameter config
├── checkpoints/ep00200.pt … ep03000.pt, final.pt   # ~44 MB each
├── results/history.json                            # per-checkpoint metrics (61 rows)
├── results/summary.json                            # final-summary fields
├── logs/train.log                                  # launch command + per-checkpoint log lines
├── wandb/                                          # offline wandb run metadata
└── mechinterp/                                     # m1/m4/m5/m6 JSON + PNG outputs
figures/    # 7 paper figures (PNG + PDF) + m6_summary.csv (88-row results table)
paper/      # main.tex, example_paper.bib, compiled PDF, style files
code/       # training + mechanistic-interpretability source
logs/       # top-level training and M1/M4/M5/M6 driver logs
docs/       # TRAINING_DETAILS.md: exhaustive code/hyperparameter/metric/results reference
```

`.pt` checkpoints total ~10 GB across 240 files.

---

## Run inventory (n=1000)

| Cond. | Seed | Run ID | Peak OOD | Peak ep | Final OOD |
| --- | --- | --- | --- | --- | --- |
| Grok | 7 | 20260508-183413_grokking_n1000_s7 | 0.6876 | 50 | 0.5882 |
| Grok | 42 | 20260505-080445_grokking_n1000_s42 | 0.7336 | 350 | 0.6639 |
| Grok | 123 | 20260505-100720_grokking_n1000_s123 | 0.7270 | 350 | 0.6447 |
| Grok | 456 | 20260505-100720_grokking_n1000_s456 | 0.6722 | 1100 | 0.5224 |
| Grok | 2024 | 20260508-183413_grokking_n1000_s2024 | 0.7056 | 400 | 0.5506 |
| Std | 42 | 20260505-100720_standard_n1000_s42 | 0.7615 | 1 | 0.6482 |
| Std | 123 | 20260508-183413_standard_n1000_s123 | 0.8880* | 1 | 0.6645 |
| Std | 456 | 20260508-183413_standard_n1000_s456 | 0.7450 | 1050 | 0.5783 |

\* Std s123 peaks at epoch 1 on the random initialization (artifact).

Additional runs at n=300 and n=500; full 14-run table and all metrics in `docs/TRAINING_DETAILS.md`.

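As a quick check of the "peak early, then decay" pattern, the peak-to-final OOD drop can be computed directly from the n=1000 table. The sketch below copies the Peak OOD and Final OOD columns by hand; the names `runs`, `drops`, and `mean_drop` are illustrative and do not come from the repository. Note that the Std s123 peak is flagged above as an initialization artifact, so its drop should be read with that caveat.

```python
from statistics import mean

# (condition, seed, peak OOD, final OOD), copied from the n=1000 run inventory.
runs = [
    ("Grok", 7, 0.6876, 0.5882),
    ("Grok", 42, 0.7336, 0.6639),
    ("Grok", 123, 0.7270, 0.6447),
    ("Grok", 456, 0.6722, 0.5224),
    ("Grok", 2024, 0.7056, 0.5506),
    ("Std", 42, 0.7615, 0.6482),
    ("Std", 123, 0.8880, 0.6645),  # peak flagged as an artifact in the table
    ("Std", 456, 0.7450, 0.5783),
]

# Drop from peak to final OOD accuracy, per run.
drops = {(cond, seed): round(peak - final, 4) for cond, seed, peak, final in runs}

# Average drop per condition.
mean_drop = {
    cond: mean(d for (c, _), d in drops.items() if c == cond)
    for cond in ("Grok", "Std")
}

for key, d in drops.items():
    print(key, d)
print(mean_drop)
```

Every drop is positive, so each of the eight listed runs finishes below its own peak OOD accuracy.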
---

## Loading a checkpoint

```python
import timm
import torch

model = timm.create_model("resnet18", pretrained=False, num_classes=2)
sd = torch.load(
    "runs/20260505-080445_grokking_n1000_s42/checkpoints/ep00400.pt",
    map_location="cpu",
)
model.load_state_dict(sd)
model.eval()
```

Checkpoints are plain `state_dict` files; the init-scale rescaling is baked into the trained weights.

```python
from huggingface_hub import hf_hub_download, snapshot_download

# one file
p = hf_hub_download(
    "nileshsarkar-ai/CausalGrok",
    "runs/20260505-080445_grokking_n1000_s42/checkpoints/ep00400.pt",
)

# whole archive
snapshot_download("nileshsarkar-ai/CausalGrok", local_dir="CausalGrok")
```

---

## Dataset

Experiments use **Camelyon17** from the [WILDS benchmark](https://wilds.stanford.edu/). The raw dataset (~10 GB) is **not** mirrored here (public benchmark); `code/utils/camelyon_data.py::get_camelyon_subsets` auto-downloads it via the `wilds` package.

---

## Citation

```bibtex
@misc{causalgrok2026,
  title  = {Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training},
  author = {Sarkar, Nilesh},
  year   = {2026},
  url    = {https://github.com/nileshsarkar-ai/CausalGrok}
}
```

## License

CC BY 4.0. The Camelyon17 dataset retains its own WILDS license; this archive does not redistribute it.