""" CausalGrok — Ablation Grid Runner Launches every condition × size × seed as its own subprocess. Each run gets its own experiments/runs// directory with isolated logs, results, checkpoints, and figures. Use the launchers (they fork this under nohup): bash scripts/run_quick_ablations.sh bash scripts/run_full_grid.sh You can also call directly: python -m experiments.run_ablations --quick python -m experiments.run_ablations --parallel """ from __future__ import annotations import argparse import itertools import os import subprocess import sys from datetime import datetime, timezone from utils.run_dir import DEFAULT_BASE, ensure_run_dir # (weight_decay, init_scale, use_grokfast, label) — labels feed into wandb tags # For Camelyon17: baseline WD is 5e-3 (empirically optimal) GRID = [ (0.0, 1.0, False, "wd0_a1"), # control — no regularization (1e-4, 1.0, False, "wd1e4_a1"), # mild WD only (5e-3, 1.0, False, "wd5e3_a1"), # baseline WD, standard init (5e-3, 4.0, False, "wd5e3_a4"), # baseline WD + large init (5e-3, 4.0, True, "wd5e3_a4_gf"), # full recipe (main condition) (1e-2, 4.0, True, "wd1e2_a4_gf"), # higher WD variant ] SIZES = [100, 250, 500, 1000, 2000] SEEDS = [42, 123, 456] def cmd_for(wd, alpha, gf, n_train, seed, run_dir, n_epochs=None): """ Build the per-cell training command. The trainer exposes --weight_decay, --init_scale, --grokfast {on,off}, --n_epochs as overrides, so each grid cell actually runs with its own knobs. Compute saver: cells that cannot grok (low/no weight decay AND Grokfast disabled) get only 300 epochs instead of 3000. They flatline either way; we just want the baseline non-grokking accuracy to compare against. This trims roughly 30% off the full grid wall time without losing any signal. """ if n_epochs is None: can_grok = (wd >= 1e-3) or gf n_epochs = 3000 if can_grok else 300 return [ sys.executable, "-m", "experiments.causalgrok_camelyon_v2", "--condition", "grokking", "--n_train", str(n_train), "--seed", str(seed), "--weight_decay", str(wd), "--init_scale", str(alpha), "--n_epochs", str(n_epochs), "--grokfast", "on" if gf else "off", "--wandb_project", "causalgrok", "--run_dir", run_dir, ] def build_run_dir(label, n_train, seed): stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") run_id = f"{stamp}_{label}_n{n_train}_s{seed}" run_dir = os.path.join(DEFAULT_BASE, run_id) ensure_run_dir(run_dir) return run_dir def run_quick(): """Minimal sanity probe: control vs. full recipe at n=500, seed 42.""" for wd, alpha, gf, label in [GRID[0], GRID[4]]: run_dir = build_run_dir(label, 500, 42) log = os.path.join(run_dir, "logs", "train.log") err = os.path.join(run_dir, "logs", "train.err") print(f"\n>> {label} n=500 seed=42 → {run_dir}") with open(log, "w") as out, open(err, "w") as ferr: subprocess.run(cmd_for(wd, alpha, gf, 500, 42, run_dir), stdout=out, stderr=ferr, check=True) def run_parallel(grid, sizes, seeds, n_gpus=None): if n_gpus is None: try: import torch n_gpus = max(1, torch.cuda.device_count()) except Exception: n_gpus = 1 procs = [] for idx, ((wd, alpha, gf, label), n, seed) in enumerate( itertools.product(grid, sizes, seeds)): env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = str(idx % n_gpus) run_dir = build_run_dir(label, n, seed) log = os.path.join(run_dir, "logs", "train.log") err = os.path.join(run_dir, "logs", "train.err") out_f = open(log, "w") err_f = open(err, "w") p = subprocess.Popen(cmd_for(wd, alpha, gf, n, seed, run_dir), env=env, stdout=out_f, stderr=err_f) print(f" GPU {idx % n_gpus}: {label} n={n} seed={seed} PID={p.pid} → {run_dir}") procs.append((p, out_f, err_f, run_dir)) print(f"\nLaunched {len(procs)} jobs. Waiting...", flush=True) for p, out_f, err_f, run_dir in procs: rc = p.wait() out_f.close(); err_f.close() status = "OK" if rc == 0 else f"FAILED rc={rc}" print(f" {status} {run_dir}", flush=True) print("All done.") if __name__ == "__main__": p = argparse.ArgumentParser() p.add_argument("--quick", action="store_true") p.add_argument("--parallel", action="store_true") p.add_argument("--n_gpus", type=int, default=None, help="Override torch.cuda.device_count()") args = p.parse_args() if args.quick: run_quick() elif args.parallel: run_parallel(GRID, SIZES, SEEDS, n_gpus=args.n_gpus) else: run_parallel(GRID[:1], [500], [42], n_gpus=args.n_gpus)