modular_arithmetic / train.py
etwk
Horner-RNN modular-multiplication model (tiers 1-5, up to 2^64)
3d2c226
Raw
History Blame
9.06 kB
"""Train the horner_rnn transition cell (bit-level Horner step) + chain fine-tuning.
Stage 1: train cell f(t, bit, b, p) = (2t + bit*b) mod p (quotients {0,1,2},
easier than base-4's {0..6}) with grad clipping, EMA, hard-boundary mining.
Stage 2 (optional, default off): fine-tune end-to-end through the 16-step
chain with a straight-through estimator on the quantized state, loss on every
step's ground-truth intermediate. In practice this was destructive at lr2=5e-5
(chain val collapsed); the shipped weights come from stage 1 alone, which
reaches chain val ~0.998 on held-out primes. Kept for further experimentation
at lower learning rates.
"""
from __future__ import annotations
import argparse
import time
import sys
from pathlib import Path
import torch
import torch.nn as nn
# Import the shared architecture from the sibling model.py.
HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(HERE))
from model import HornerCell, BITS, _to_bits as to_bits # noqa: E402
def sieve_primes(limit: int) -> list[int]:
    """Return every prime strictly below ``limit``, in ascending order.

    Implemented as a Sieve of Eratosthenes over a ``bytearray`` of flags.

    Args:
        limit: Exclusive upper bound of the search range.

    Returns:
        The primes ``p`` with ``2 <= p < limit``; an empty list when
        ``limit <= 2`` (including zero and negative limits).
    """
    # There are no primes below 2. Returning early also protects the writes
    # to flags 0 and 1 below, which would raise IndexError on a buffer that
    # is shorter than two entries.
    if limit <= 2:
        return []

    is_p = bytearray([1]) * limit
    is_p[0] = is_p[1] = 0
    for i in range(2, int(limit ** 0.5) + 1):
        if is_p[i]:
            # Clear i*i, i*i + i, ... in one slice assignment; the range
            # gives the slice length without copying the slice first.
            is_p[i * i :: i] = bytearray(len(range(i * i, limit, i)))
    return [i for i, flag in enumerate(is_p) if flag]
def sample_batch(primes_t, n, device, hard_frac=0.5):
    """Draw ``n`` random examples of one Horner step ``z = (2t + bit*b) mod p``.

    The first ``int(n * hard_frac)`` examples are "hard": ``t`` is chosen so
    that ``2t + bit*b`` sits within a few units of ``q * p`` for a random
    ``q`` in {0, 1, 2} (before the clamp and reduction mod ``p``). The
    remaining examples draw ``t`` uniformly from ``[0, p)``.

    Args:
        primes_t: 1-D tensor of moduli to sample from (assumed to be an
            integer tensor living on ``device`` -- confirm against caller).
        n: Number of examples to draw.
        device: Device on which the random tensors are created.
        hard_frac: Fraction of the batch drawn near a reduction boundary.

    Returns:
        Tuple ``(t, bit, b, p, z)`` of length-``n`` tensors, where
        ``z == (2 * t + bit * b) % p``.
    """
    # NOTE: the order of the random draws is kept fixed so that seeded runs
    # stay reproducible.
    prime_idx = torch.randint(len(primes_t), (n,), device=device)
    p = primes_t[prime_idx]
    # The clamp guards against a float product rounding up to exactly p.
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    bit = torch.randint(0, 2, (n,), device=device)

    hard_count = int(n * hard_frac)
    easy_count = n - hard_count
    t = torch.empty(n, dtype=torch.long, device=device)

    # Tail of the batch: t uniform over [0, p).
    uniform_t = torch.rand(easy_count, device=device) * p[hard_count:]
    t[hard_count:] = uniform_t.long()

    if hard_count:
        hard_p = p[:hard_count]
        hard_b = b[:hard_count]
        hard_bit = bit[:hard_count]
        quotient = torch.randint(0, 3, (hard_count,), device=device)
        offset = torch.randint(-2, 3, (hard_count,), device=device)
        # Solve 2t + bit*b = q*p + offset for t (floor division by two),
        # then fold the result back into [0, p).
        near_boundary = (quotient * hard_p + offset - hard_bit * hard_b) >> 1
        t[:hard_count] = near_boundary.clamp(min=0) % hard_p

    z = (2 * t + bit * b) % p
    return t, bit, b, p, z
@torch.no_grad()
def exact_rate(model, primes_t, device, n=200_000, bs=65536) -> float:
    """Estimate how often the cell predicts a single Horner step exactly.

    Draws ``n`` uniform samples (no hard-boundary mining) in chunks of at
    most ``bs`` and counts a sample as correct only when every thresholded
    output bit matches the bit encoding of the true result.

    Args:
        model: Callable taking ``(t_bits, bit, b_bits, p_bits)`` and
            returning per-bit logits.
        primes_t: Tensor of moduli to sample from.
        device: Device used for sampling.
        n: Total number of samples to evaluate.
        bs: Maximum number of samples per forward pass.

    Returns:
        Fraction of the ``n`` samples predicted exactly.
    """
    correct = 0
    for offset in range(0, n, bs):
        chunk = min(bs, n - offset)
        t, bit, b, p, z = sample_batch(primes_t, chunk, device, hard_frac=0.0)
        logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
        predicted = (logits > 0).long()  # logit > 0 <=> predicted bit is 1
        target = to_bits(z).long()
        rows_ok = (predicted == target).all(dim=1)
        correct += rows_ok.sum().item()
    return correct / n
@torch.no_grad()
def chain_exact_rate(model, primes_t, device, n=20_000) -> float:
    """Estimate how often the full ``BITS``-step chain yields ``a*b mod p``.

    Samples ``n`` triples ``(a, b, p)`` with ``a, b < p``, runs the cell once
    per bit of ``a`` (most significant first), re-quantizing the state to
    hard bits after every step, and compares the decoded final state with
    the exact product.

    Args:
        model: Callable taking ``(state_bits, bit, b_bits, p_bits)`` and
            returning per-bit logits.
        primes_t: Tensor of moduli to sample from.
        device: Device used for sampling and evaluation.
        n: Number of products to evaluate.

    Returns:
        Fraction of products reproduced exactly.
    """
    p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
    a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    expected = (a * b) % p

    b_bits = to_bits(b)
    p_bits = to_bits(p)
    state = torch.zeros(n, BITS, device=device)
    # Horner evaluation: consume the bits of ``a`` from high to low.
    for shift in reversed(range(BITS)):
        a_bit = ((a >> shift) & 1).float().unsqueeze(1)
        state = (model(state, a_bit, b_bits, p_bits) > 0).float()

    # Decode the state, treating column k as the 2**k bit.
    weights = 1 << torch.arange(BITS, device=device)
    decoded = (state.long() * weights).sum(dim=1)
    return (decoded == expected).float().mean().item()
def chain_finetune_batch(model, primes_t, n, device, loss_fn):
    """Compute the end-to-end chain loss for one random batch.

    Unrolls the cell over all ``BITS`` steps. The state fed to the next step
    is the hard-thresholded output in the forward pass, with gradients
    routed through the sigmoid (straight-through estimator). Each step is
    penalised against the bit encoding of the true intermediate value.

    Args:
        model: Callable taking ``(state_bits, bit, b_bits, p_bits)`` and
            returning per-bit logits.
        primes_t: Tensor of moduli to sample from.
        n: Batch size.
        device: Device used for sampling and computation.
        loss_fn: Callable ``loss_fn(logits, target_bits)`` returning a
            scalar loss.

    Returns:
        Scalar tensor: the per-step losses averaged over the ``BITS`` steps.
    """
    p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
    a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    b_bits, p_bits = to_bits(b), to_bits(p)

    state = torch.zeros(n, BITS, device=device)
    running = torch.zeros_like(a)  # exact intermediate value of the chain
    total = torch.zeros((), device=device)
    for shift in reversed(range(BITS)):
        a_bit = (a >> shift) & 1
        running = (2 * running + a_bit * b) % p
        logits = model(state, a_bit.float().unsqueeze(1), b_bits, p_bits)
        total = total + loss_fn(logits, to_bits(running))
        probs = torch.sigmoid(logits)
        # Straight-through: the forward value is the hard bit, while the
        # backward pass sees the gradient of the sigmoid.
        state = (logits > 0).float() + (probs - probs.detach())
    return total / BITS
def main() -> int:
    """Train the Horner cell and optionally fine-tune it through the chain.

    Parses the command-line options, splits the primes in ``[256, 2**16)``
    90/10 into train/validation sets (primes below 256 are all used for
    training), and runs up to two time-boxed stages:

    * Stage 1 trains the single-step cell on batches from ``sample_batch``.
    * Stage 2 (skipped when ``--stage2-minutes`` is 0) fine-tunes end to end
      with ``chain_finetune_batch``.

    An exponential moving average (EMA) of the parameters is maintained
    throughout; the EMA weights are written to ``--out`` whenever their
    validation chain accuracy reaches a new best.

    Returns:
        Process exit status (always 0).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--stage1-minutes", type=float, default=50.0)
    ap.add_argument("--stage2-minutes", type=float, default=0.0)
    ap.add_argument("--batch", type=int, default=32768)
    ap.add_argument("--chain-batch", type=int, default=4096)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--lr2", type=float, default=5e-5)
    ap.add_argument("--width", type=int, default=4096)
    ap.add_argument("--depth", type=int, default=4)
    ap.add_argument("--init", type=str, default="")
    ap.add_argument("--out", type=str, default=str(HERE / "weights16.pt"))
    args = ap.parse_args()

    # Prefer the GPU, but fall back to the CPU so the script still runs
    # (slowly) on machines without CUDA instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)

    # Make sure the checkpoint directory exists before any training time is
    # spent, so the first save cannot fail on a missing folder.
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)

    # Deterministic train/validation split of the primes >= 256.
    small = sieve_primes(256)
    primes = [p for p in sieve_primes(1 << 16) if p >= 256]
    g = torch.Generator().manual_seed(1)
    perm = torch.randperm(len(primes), generator=g).tolist()
    n_val = len(primes) // 10
    val_primes = torch.tensor([primes[i] for i in perm[:n_val]], device=device)
    train_primes = torch.tensor(
        small + [primes[i] for i in perm[n_val:]], device=device
    )
    print(f"train primes {len(train_primes)}, val primes {len(val_primes)}")

    model = HornerCell(args.width, args.depth).to(device)
    if args.init:
        ckpt = torch.load(args.init, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["state_dict"])
        print(f"initialised from {args.init}")

    # The EMA shadow model starts as a copy of the model and is never
    # trained directly.
    ema = HornerCell(args.width, args.depth).to(device)
    ema.load_state_dict(model.state_dict())
    for q in ema.parameters():
        q.requires_grad_(False)
    print(f"params: {sum(t.numel() for t in model.parameters()):,}")

    loss_fn = nn.BCEWithLogitsLoss()
    EMA_DECAY = 0.999

    def update_ema():
        """Move each EMA parameter a fraction (1 - EMA_DECAY) toward the model."""
        with torch.no_grad():
            for q, w in zip(ema.parameters(), model.parameters()):
                q.lerp_(w, 1 - EMA_DECAY)

    best_chain = -1.0

    def save_if_best():
        """Evaluate the EMA chain accuracy; checkpoint on a new best."""
        nonlocal best_chain
        ch = chain_exact_rate(ema, val_primes, device)
        if ch > best_chain:
            best_chain = ch
            torch.save({"state_dict": ema.state_dict(), "config": ema.config}, args.out)
        return ch

    def run_stage(label, *, minutes, lr, steps_per_sec, eta_min_frac, eval_every, compute_loss):
        """Run one time-boxed training stage.

        ``steps_per_sec`` is only an estimate used to size the cosine
        schedule; the loop itself stops on the wall-clock deadline.
        """
        opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
        total_steps = int(minutes * 60 * steps_per_sec)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=total_steps, eta_min=lr * eta_min_frac
        )
        deadline = time.monotonic() + minutes * 60
        start = time.monotonic()
        step = 0
        while time.monotonic() < deadline:
            loss = compute_loss()
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Stop stepping the schedule once the estimated horizon has
            # passed, so the learning rate stays at its floor.
            if step < total_steps:
                sched.step()
            update_ema()
            step += 1
            if step % eval_every == 0:
                va = exact_rate(ema, val_primes, device, n=100_000)
                ch = save_if_best()
                print(
                    f"{label} step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} "
                    f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s",
                    flush=True,
                )

    def cell_loss():
        """Stage 1 loss: BCE on a single Horner step."""
        t, bit, b, p, z = sample_batch(train_primes, args.batch, device)
        logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
        return loss_fn(logits, to_bits(z))

    def chain_loss():
        """Stage 2 loss: per-step BCE through the full chain (STE)."""
        return chain_finetune_batch(model, train_primes, args.chain_batch, device, loss_fn)

    # ----- Stage 1: cell training -----
    if args.stage1_minutes > 0:
        run_stage(
            "S1",
            minutes=args.stage1_minutes,
            lr=args.lr,
            steps_per_sec=16,
            eta_min_frac=0.02,
            eval_every=1000,
            compute_loss=cell_loss,
        )

    # ----- Stage 2: end-to-end chain fine-tuning (STE) -----
    if args.stage2_minutes > 0:
        run_stage(
            "S2",
            minutes=args.stage2_minutes,
            lr=args.lr2,
            steps_per_sec=3,
            eta_min_frac=0.1,
            eval_every=200,
            compute_loss=chain_loss,
        )

    va = exact_rate(ema, val_primes, device, n=500_000)
    ch = chain_exact_rate(ema, val_primes, device, n=50_000)
    print(f"FINAL ema cell val {va:.6f} | chain val {ch:.4f} | best chain {best_chain:.4f}")
    return 0
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())