| """Train the horner_rnn transition cell (bit-level Horner step) + chain fine-tuning. |
| |
| Stage 1: train cell f(t, bit, b, p) = (2t + bit*b) mod p (quotients {0,1,2}, |
| easier than base-4's {0..6}) with grad clipping, EMA, hard-boundary mining. |
| |
| Stage 2 (optional, default off): fine-tune end-to-end through the 16-step |
| chain with a straight-through estimator on the quantized state, loss on every |
| step's ground-truth intermediate. In practice this was destructive at lr2=5e-5 |
| (chain val collapsed); the shipped weights come from stage 1 alone, which |
| reaches chain val ~0.998 on held-out primes. Kept for further experimentation |
| at lower learning rates. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import time |
|
|
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
|
|
| |
| HERE = Path(__file__).resolve().parent |
| sys.path.insert(0, str(HERE)) |
| from model import HornerCell, BITS, _to_bits as to_bits |
|
|
|
|
def sieve_primes(limit: int) -> list[int]:
    """Return every prime strictly below ``limit`` in ascending order.

    Implemented as a byte-per-number sieve of Eratosthenes.

    Args:
        limit: Exclusive upper bound. Any value below 2 yields an empty list.

    Returns:
        The primes ``p`` with ``2 <= p < limit``.
    """
    # There are no primes below 2; returning early also keeps the two flag
    # writes below in range for limit in {0, 1} and for negative limits.
    if limit < 2:
        return []
    is_p = bytearray([1]) * limit
    is_p[0] = is_p[1] = 0
    i = 2
    # Pure integer bound (no float square root), so it is exact for any limit.
    while i * i < limit:
        if is_p[i]:
            # Strike out i*i, i*i + i, ...  The range length gives the number
            # of slots to clear without materialising a copy of the slice.
            is_p[i * i :: i] = bytes(len(range(i * i, limit, i)))
        i += 1
    return [k for k, flag in enumerate(is_p) if flag]
|
|
|
|
def sample_batch(primes_t, n, device, hard_frac=0.5):
    """Draw ``n`` random examples for the single Horner step.

    Each example is ``(t, bit, b, p)`` with target ``z = (2*t + bit*b) % p``.
    The first ``int(n * hard_frac)`` examples are "hard": ``t`` is chosen so
    that ``2*t + bit*b`` lands within +/-2 of ``q*p`` for ``q`` in {0, 1, 2},
    i.e. right where the quotient of the reduction changes. The remaining
    examples draw ``t`` uniformly from ``[0, p)``.

    Args:
        primes_t: 1-D integer tensor of moduli to sample from; indexed with
            an index tensor created on ``device``.
        n: Number of examples to draw.
        device: Device on which the random tensors are created.
        hard_frac: Fraction of the batch given boundary-mined ``t`` values.
            The resulting count is clamped to ``[0, n]``.

    Returns:
        Tuple ``(t, bit, b, p, z)`` of 1-D tensors of length ``n``.
    """
    p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
    # rand() is in [0, 1) but the float product can round up to p; clamp so
    # the value always stays in [0, p).
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    bit = torch.randint(0, 2, (n,), device=device)
    # Keep the hard count inside the batch so an out-of-range hard_frac
    # cannot produce a negative size or a shape mismatch below.
    n_hard = min(n, max(0, int(n * hard_frac)))
    t = torch.empty(n, dtype=torch.long, device=device)
    p_easy = p[n_hard:]
    # Same rounding guard as for b: t is a residue and must stay below p.
    t[n_hard:] = (torch.rand(n - n_hard, device=device) * p_easy).long().clamp(max=p_easy - 1)
    if n_hard:
        ph, bh, bith = p[:n_hard], b[:n_hard], bit[:n_hard]
        q = torch.randint(0, 3, (n_hard,), device=device)
        delta = torch.randint(-2, 3, (n_hard,), device=device)
        # Solve 2*t + bit*b ~= q*p + delta for t (floor division by 2).
        th = (q * ph + delta - bith * bh) >> 1
        # Negative candidates collapse to 0; the modulo folds the rest into [0, p).
        t[:n_hard] = th.clamp(min=0) % ph
    z = (2 * t + bit * b) % p
    return t, bit, b, p, z
|
|
|
|
@torch.no_grad()
def exact_rate(model, primes_t, device, n=200_000, bs=65536) -> float:
    """Estimate the single-step exact-match rate of ``model``.

    Draws ``n`` uniformly sampled cell inputs (no hard-boundary mining) in
    chunks of at most ``bs``, thresholds the logits at zero, and counts an
    example as correct only when every output bit equals the corresponding
    bit of the true ``z``.

    Returns:
        Number of fully correct examples divided by ``n``.
    """
    n_exact = 0
    for offset in range(0, n, bs):
        chunk = min(bs, n - offset)
        t, bit, b, p, z = sample_batch(primes_t, chunk, device, hard_frac=0.0)
        bit_col = bit.float().unsqueeze(1)
        logits = model(to_bits(t), bit_col, to_bits(b), to_bits(p))
        predicted = (logits > 0).long()
        wanted = to_bits(z).long()
        # A row only counts when all of its bits agree.
        row_ok = (predicted == wanted).all(dim=1)
        n_exact += row_ok.sum().item()
    return n_exact / n
|
|
|
|
@torch.no_grad()
def chain_exact_rate(model, primes_t, device, n=20_000) -> float:
    """Estimate how often the full chain reproduces ``(a * b) % p``.

    Samples ``n`` triples ``(a, b, p)`` with ``a, b < p``, starts from an
    all-zero state, and applies ``model`` once per bit of ``a`` from bit
    ``BITS - 1`` down to bit 0, re-quantising the state to {0, 1} after every
    step. The final state bits are recombined with bit ``k`` weighted by
    ``2**k`` and compared with the exact product.

    Returns:
        Fraction of samples whose recombined value equals ``(a * b) % p``.
    """
    prime_idx = torch.randint(len(primes_t), (n,), device=device)
    p = primes_t[prime_idx]
    a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    expected = (a * b) % p
    b_bits = to_bits(b)
    p_bits = to_bits(p)
    state = torch.zeros(n, BITS, device=device)
    for shift in reversed(range(BITS)):
        a_bit = ((a >> shift) & 1).float().unsqueeze(1)
        # Hard threshold: the next step always sees a clean 0/1 state.
        state = (model(state, a_bit, b_bits, p_bits) > 0).float()
    place_values = 1 << torch.arange(BITS, device=device)
    pred = (state.long() * place_values).sum(dim=1)
    hits = (pred == expected).float()
    return hits.mean().item()
|
|
|
|
def chain_finetune_batch(model, primes_t, n, device, loss_fn):
    """Run one end-to-end chain pass and return its mean per-step loss.

    Samples ``n`` triples ``(a, b, p)``, then walks the bits of ``a`` from
    bit ``BITS - 1`` down to bit 0. At every step ``loss_fn`` scores the
    logits against the bits of the exact intermediate
    ``t = (2*t + bit*b) % p``. The state passed to the next step is the
    thresholded output with a straight-through term, so its forward value is
    the hard 0/1 bits while gradients flow through the sigmoid.

    Returns:
        Scalar tensor: the sum of the per-step losses divided by ``BITS``.
    """
    prime_idx = torch.randint(len(primes_t), (n,), device=device)
    p = primes_t[prime_idx]
    a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    b_bits, p_bits = to_bits(b), to_bits(p)
    state = torch.zeros(n, BITS, device=device)
    t_exact = torch.zeros_like(a)
    total = torch.zeros((), device=device)
    for shift in reversed(range(BITS)):
        a_bit = (a >> shift) & 1
        t_exact = (2 * t_exact + a_bit * b) % p
        logits = model(state, a_bit.float().unsqueeze(1), b_bits, p_bits)
        total = total + loss_fn(logits, to_bits(t_exact))
        probs = torch.sigmoid(logits)
        # Straight-through estimator: hard bits forward, sigmoid gradient backward.
        state = (logits > 0).float() + (probs - probs.detach())
    return total / BITS
|
|
|
|
def main() -> int:
    """Train the Horner cell (stage 1) and optionally fine-tune the chain (stage 2).

    Parses the command line, splits the primes in [256, 2**16) 90/10 into
    train/validation (primes below 256 are always added to the training set),
    and runs each stage for its wall-clock budget while maintaining an EMA
    copy of the weights. The EMA weights are written to ``--out`` whenever
    their chain accuracy on the validation primes beats the best seen so far
    in this run, including once more after training ends so that short runs
    and progress made after the last periodic evaluation are not lost.

    Returns:
        0, used as the process exit status.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--stage1-minutes", type=float, default=50.0)
    ap.add_argument("--stage2-minutes", type=float, default=0.0)
    ap.add_argument("--batch", type=int, default=32768)
    ap.add_argument("--chain-batch", type=int, default=4096)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--lr2", type=float, default=5e-5)
    ap.add_argument("--width", type=int, default=4096)
    ap.add_argument("--depth", type=int, default=4)
    ap.add_argument("--init", type=str, default="")
    ap.add_argument("--out", type=str, default=str(HERE / "weights16.pt"))
    args = ap.parse_args()

    # Prefer the GPU, but fall back to CPU so the script still starts on
    # hosts without CUDA instead of failing on the first tensor allocation.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)

    small = sieve_primes(256)
    primes = [p for p in sieve_primes(1 << 16) if p >= 256]
    # Dedicated generator so the train/val split is fixed independently of
    # the global seed used for training.
    g = torch.Generator().manual_seed(1)
    perm = torch.randperm(len(primes), generator=g).tolist()
    n_val = len(primes) // 10
    val_primes = torch.tensor([primes[i] for i in perm[:n_val]], device=device)
    train_primes = torch.tensor(
        small + [primes[i] for i in perm[n_val:]], device=device
    )
    print(f"train primes {len(train_primes)}, val primes {len(val_primes)}")

    model = HornerCell(args.width, args.depth).to(device)
    if args.init:
        ckpt = torch.load(args.init, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["state_dict"])
        print(f"initialised from {args.init}")
    # The EMA copy is what gets evaluated and saved; it never receives gradients.
    ema = HornerCell(args.width, args.depth).to(device)
    ema.load_state_dict(model.state_dict())
    for q in ema.parameters():
        q.requires_grad_(False)
    print(f"params: {sum(t.numel() for t in model.parameters()):,}")
    loss_fn = nn.BCEWithLogitsLoss()
    EMA_DECAY = 0.999

    def update_ema():
        """Move every EMA parameter a fraction (1 - EMA_DECAY) toward the live model."""
        with torch.no_grad():
            for q, w in zip(ema.parameters(), model.parameters()):
                q.lerp_(w, 1 - EMA_DECAY)

    best_chain = -1.0

    def record_chain(ch):
        """Write the EMA checkpoint when ``ch`` beats the best chain score so far."""
        nonlocal best_chain
        if ch > best_chain:
            best_chain = ch
            torch.save({"state_dict": ema.state_dict(), "config": ema.config}, args.out)

    def save_if_best():
        """Measure EMA chain accuracy on the validation primes, checkpoint if improved."""
        ch = chain_exact_rate(ema, val_primes, device)
        record_chain(ch)
        return ch

    def run_stage(label, minutes, lr, steps_per_sec, eta_min_frac, eval_every, compute_loss):
        """Optimise ``compute_loss()`` for ``minutes`` of wall time; return steps taken."""
        opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
        # The schedule length is derived from the time budget; steps_per_sec is
        # presumably a rough throughput estimate for the stage.
        total_steps = int(minutes * 60 * steps_per_sec)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=total_steps, eta_min=lr * eta_min_frac
        )
        deadline = time.monotonic() + minutes * 60
        start = time.monotonic()
        step = 0
        while time.monotonic() < deadline:
            loss = compute_loss()
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Stop stepping the scheduler once the cosine schedule is used up,
            # so the learning rate stays at its final value.
            if step < total_steps:
                sched.step()
            update_ema()
            step += 1
            if step % eval_every == 0:
                va = exact_rate(ema, val_primes, device, n=100_000)
                ch = save_if_best()
                print(
                    f"{label} step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} "
                    f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s",
                    flush=True,
                )
        return step

    def cell_loss():
        """Stage-1 loss: single-step loss on a hard-mined batch."""
        t, bit, b, p, z = sample_batch(train_primes, args.batch, device)
        logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
        return loss_fn(logits, to_bits(z))

    def chain_loss():
        """Stage-2 loss: mean per-step loss through the whole chain."""
        return chain_finetune_batch(model, train_primes, args.chain_batch, device, loss_fn)

    steps_taken = 0
    if args.stage1_minutes > 0:
        steps_taken += run_stage("S1", args.stage1_minutes, args.lr, 16, 0.02, 1000, cell_loss)
    if args.stage2_minutes > 0:
        steps_taken += run_stage("S2", args.stage2_minutes, args.lr2, 3, 0.1, 200, chain_loss)

    va = exact_rate(ema, val_primes, device, n=500_000)
    ch = chain_exact_rate(ema, val_primes, device, n=50_000)
    # Give the final EMA state a chance to be checkpointed too. Skipped when
    # no optimisation step ran, so an evaluation-only invocation never
    # overwrites --out with untrained weights.
    if steps_taken:
        record_chain(ch)
    print(f"FINAL ema cell val {va:.6f} | chain val {ch:.4f} | best chain {best_chain:.4f}")
    return 0
|
|
|
|
# Script entry point: run training and use main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|