"""Train the horner_rnn transition cell (bit-level Horner step) + chain fine-tuning. Stage 1: train cell f(t, bit, b, p) = (2t + bit*b) mod p (quotients {0,1,2}, easier than base-4's {0..6}) with grad clipping, EMA, hard-boundary mining. Stage 2 (optional, default off): fine-tune end-to-end through the 16-step chain with a straight-through estimator on the quantized state, loss on every step's ground-truth intermediate. In practice this was destructive at lr2=5e-5 (chain val collapsed); the shipped weights come from stage 1 alone, which reaches chain val ~0.998 on held-out primes. Kept for further experimentation at lower learning rates. """ from __future__ import annotations import argparse import time import sys from pathlib import Path import torch import torch.nn as nn # Import the shared architecture from the sibling model.py. HERE = Path(__file__).resolve().parent sys.path.insert(0, str(HERE)) from model import HornerCell, BITS, _to_bits as to_bits # noqa: E402 def sieve_primes(limit: int) -> list[int]: is_p = bytearray([1]) * limit is_p[0] = is_p[1] = 0 for i in range(2, int(limit ** 0.5) + 1): if is_p[i]: is_p[i * i :: i] = bytearray(len(is_p[i * i :: i])) return [i for i in range(2, limit) if is_p[i]] def sample_batch(primes_t, n, device, hard_frac=0.5): p = primes_t[torch.randint(len(primes_t), (n,), device=device)] b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1) bit = torch.randint(0, 2, (n,), device=device) n_hard = int(n * hard_frac) t = torch.empty(n, dtype=torch.long, device=device) t[n_hard:] = (torch.rand(n - n_hard, device=device) * p[n_hard:]).long() if n_hard: ph, bh, bith = p[:n_hard], b[:n_hard], bit[:n_hard] q = torch.randint(0, 3, (n_hard,), device=device) delta = torch.randint(-2, 3, (n_hard,), device=device) th = (q * ph + delta - bith * bh) >> 1 t[:n_hard] = th.clamp(min=0) % ph z = (2 * t + bit * b) % p return t, bit, b, p, z @torch.no_grad() def exact_rate(model, primes_t, device, n=200_000, bs=65536) -> float: ok = 0 for i in range(0, n, bs): m = min(bs, n - i) t, bit, b, p, z = sample_batch(primes_t, m, device, hard_frac=0.0) logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p)) ok += ((logits > 0).long() == to_bits(z).long()).all(dim=1).sum().item() return ok / n @torch.no_grad() def chain_exact_rate(model, primes_t, device, n=20_000) -> float: p = primes_t[torch.randint(len(primes_t), (n,), device=device)] a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1) b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1) truth = (a * b) % p bb, pb = to_bits(b), to_bits(p) tb = torch.zeros(n, BITS, device=device) for i in range(BITS - 1, -1, -1): bit = ((a >> i) & 1).float().unsqueeze(1) tb = (model(tb, bit, bb, pb) > 0).float() pred = (tb.long() * (1 << torch.arange(BITS, device=device))).sum(dim=1) return (pred == truth).float().mean().item() def chain_finetune_batch(model, primes_t, n, device, loss_fn): """One end-to-end pass: STE state, per-step CE against true intermediates.""" p = primes_t[torch.randint(len(primes_t), (n,), device=device)] a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1) b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1) bb, pb = to_bits(b), to_bits(p) tb = torch.zeros(n, BITS, device=device) t_true = torch.zeros_like(a) loss = torch.zeros((), device=device) for i in range(BITS - 1, -1, -1): bit_i = (a >> i) & 1 t_true = (2 * t_true + bit_i * b) % p logits = model(tb, bit_i.float().unsqueeze(1), bb, pb) loss = loss + loss_fn(logits, to_bits(t_true)) hard = (logits > 0).float() soft = torch.sigmoid(logits) tb = hard + (soft - soft.detach()) # straight-through return loss / BITS def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--stage1-minutes", type=float, default=50.0) ap.add_argument("--stage2-minutes", type=float, default=0.0) ap.add_argument("--batch", type=int, default=32768) ap.add_argument("--chain-batch", type=int, default=4096) ap.add_argument("--lr", type=float, default=3e-4) ap.add_argument("--lr2", type=float, default=5e-5) ap.add_argument("--width", type=int, default=4096) ap.add_argument("--depth", type=int, default=4) ap.add_argument("--init", type=str, default="") ap.add_argument("--out", type=str, default=str(HERE / "weights16.pt")) args = ap.parse_args() device = torch.device("cuda") torch.manual_seed(0) small = sieve_primes(256) primes = [p for p in sieve_primes(1 << 16) if p >= 256] g = torch.Generator().manual_seed(1) perm = torch.randperm(len(primes), generator=g).tolist() val_primes = torch.tensor([primes[i] for i in perm[: len(primes) // 10]], device=device) train_primes = torch.tensor( small + [primes[i] for i in perm[len(primes) // 10 :]], device=device ) print(f"train primes {len(train_primes)}, val primes {len(val_primes)}") model = HornerCell(args.width, args.depth).to(device) if args.init: ckpt = torch.load(args.init, map_location=device, weights_only=True) model.load_state_dict(ckpt["state_dict"]) print(f"initialised from {args.init}") ema = HornerCell(args.width, args.depth).to(device) ema.load_state_dict(model.state_dict()) for q in ema.parameters(): q.requires_grad_(False) print(f"params: {sum(t.numel() for t in model.parameters()):,}") loss_fn = nn.BCEWithLogitsLoss() EMA_DECAY = 0.999 def update_ema(): with torch.no_grad(): for q, w in zip(ema.parameters(), model.parameters()): q.lerp_(w, 1 - EMA_DECAY) best_chain = -1.0 def save_if_best(tag): nonlocal best_chain ch = chain_exact_rate(ema, val_primes, device) if ch > best_chain: best_chain = ch torch.save({"state_dict": ema.state_dict(), "config": ema.config}, args.out) return ch # ----- Stage 1: cell training ----- if args.stage1_minutes > 0: opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5) total_steps = int(args.stage1_minutes * 60 * 16) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=args.lr * 0.02) deadline = time.monotonic() + args.stage1_minutes * 60 start = time.monotonic() step = 0 while time.monotonic() < deadline: t, bit, b, p, z = sample_batch(train_primes, args.batch, device) logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p)) loss = loss_fn(logits, to_bits(z)) opt.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() if step < total_steps: sched.step() update_ema() step += 1 if step % 1000 == 0: va = exact_rate(ema, val_primes, device, n=100_000) ch = save_if_best("s1") print( f"S1 step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} " f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s", flush=True, ) # ----- Stage 2: end-to-end chain fine-tuning (STE) ----- if args.stage2_minutes > 0: opt = torch.optim.AdamW(model.parameters(), lr=args.lr2, weight_decay=1e-5) total_steps = int(args.stage2_minutes * 60 * 3) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=args.lr2 * 0.1) deadline = time.monotonic() + args.stage2_minutes * 60 start = time.monotonic() step = 0 while time.monotonic() < deadline: loss = chain_finetune_batch(model, train_primes, args.chain_batch, device, loss_fn) opt.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() if step < total_steps: sched.step() update_ema() step += 1 if step % 200 == 0: va = exact_rate(ema, val_primes, device, n=100_000) ch = save_if_best("s2") print( f"S2 step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} " f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s", flush=True, ) va = exact_rate(ema, val_primes, device, n=500_000) ch = chain_exact_rate(ema, val_primes, device, n=50_000) print(f"FINAL ema cell val {va:.6f} | chain val {ch:.4f} | best chain {best_chain:.4f}") return 0 if __name__ == "__main__": raise SystemExit(main())