# Biggerbrain2_136m / ai_extras.py
# Author: Skull18500
# Commit message: Upload the code and the weights
# Commit: dc84c90 (verified)
from gc import enable
from functorch import dim
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
class SwiGLU(nn.Module):
    """Gated activation that halves the size of the last dimension.

    The last axis of the input is split into two equal chunks. The first
    chunk is the value path; the second chunk is passed through SiLU and
    multiplied onto the value path as a gate.
    """

    def __init__(self):
        super().__init__()
        # SiLU is PyTorch's name for the Swish activation.
        self.silu = nn.SiLU()

    def forward(self, x):
        """Return ``silu(second_half) * first_half`` along the last axis."""
        value, gate = torch.chunk(x, 2, dim=-1)
        activated_gate = self.silu(gate)
        return activated_gate * value
class GatedResidual(nn.Module):
def __init__(self, dim):
super().__init__()
# This layer looks at both the current state and the new info
# to decide what to keep.
#layernorm
self.gate_layer = nn.Linear(dim * 2, dim)
self.output_norm = nn.RMSNorm(dim)
self.gate_x = nn.Linear(dim, dim, bias=False)
self.gate_r = nn.Linear(dim, dim, bias=False)
def forward(self, x, residual):
"""
x: The new information (e.g., from Attention)
residual: The current memory/state (the 'highway')
"""
# 1. Concatenate them and calculate the 'Valve' (0 to 1)
# 2. The 'Convex Combination' - pure stability
# If gate is 0, we keep only the old memory.
# If gate is 1, we take only the new info.
gate = torch.sigmoid(self.gate_x(x) + self.gate_r(residual)) # no cat needed
mixed = (1 - gate) * residual + gate * x
# 3. Final cleanup
return self.output_norm(mixed)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the moved-up half.

    The last axis is cut at index ``x.shape[-1] // 2`` into ``[a, b]`` and
    the result is ``[-b, a]``.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)
class RoPE(nn.Module):
    """Rotary position embedding backed by cached cos/sin tables.

    The tables are stored as non-persistent buffers of shape
    ``[1, 1, seq_len, 2 * len(inv_freq)]`` and are rebuilt when a longer
    sequence than the cached length is requested.
    """

    def __init__(self, head_dim, max_seq_len=2048, base=10000):
        super().__init__()
        # One inverse frequency per pair of channels.
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        positions = torch.arange(
            seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Frequencies are repeated so that the two halves of the channel axis
        # share the same angles.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        # Record the cached length. Without this, every forward call longer
        # than the *initial* limit would rebuild the tables from scratch.
        self.max_seq_len = seq_len

    def forward(self, x, seq_len):
        """Apply the rotation to ``x`` for the first ``seq_len`` positions.

        The cached tables are sliced to ``seq_len`` positions and to the
        width of ``x``'s last axis before being broadcast against ``x``.
        """
        if seq_len > self.max_seq_len:
            self._build_cache(seq_len)
        width = x.shape[-1]
        cos = self.cos_cached[:, :, :seq_len, :width]
        sin = self.sin_cached[:, :, :seq_len, :width]
        return (x * cos) + (rotate_half(x) * sin)
class RoPEAttention(nn.Module):
def __init__(self, dim, heads, kv_heads=None, bottleneck=256):
super().__init__()
self.dim = dim
self.bottleneck = bottleneck
self.heads = heads
self.kv_heads = kv_heads or heads
assert heads % self.kv_heads == 0
self.head_dim = bottleneck // heads
# single latent projection
self.latent = nn.Linear(dim, bottleneck, bias=False)
self.q_proj = nn.Linear(bottleneck, bottleneck, bias=False)
self.k_proj = nn.Linear(bottleneck, self.kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(bottleneck, self.kv_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(bottleneck, dim, bias=False)
def forward(self, x, rope):
x = self.latent(x)
b, t, _ = x.shape
q = self.q_proj(x).view(b, t, self.heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(b, t, self.kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(b, t, self.kv_heads, self.head_dim).transpose(1, 2)
split_size = self.head_dim // 2
q_c, q_p = q.split([split_size, split_size], dim=-1)
k_c, k_p = k.split([split_size, split_size], dim=-1)
# Apply RoPE ONLY to the Position half
q_p = rope(q_p, t)
k_p = rope(k_p, t)
q = torch.cat([q_c, q_p], dim=-1)
k = torch.cat([k_c, k_p], dim=-1)
out = F.scaled_dot_product_attention(
q, k, v,
enable_gqa=True,
is_causal=True
)
out = out.transpose(1, 2).contiguous().view(b, t, self.bottleneck)
return self.out_proj(out)
class StreamDataset(Dataset):
    """Fixed-length next-token samples read from a memory-mapped token file.

    The file is cut into consecutive, non-overlapping windows of
    ``seq_len + 1`` tokens; each window yields an ``(input, target)`` pair
    where the target is the input shifted by one token. Trailing tokens that
    do not fill a whole window are ignored.
    """

    def __init__(self, bin_file, seq_len, dtype=np.uint16):
        """Map ``bin_file`` read-only.

        ``dtype`` must match the dtype the file was written with
        (e.g. via ``ndarray.tofile``).
        """
        self.data = np.memmap(bin_file, dtype=dtype, mode='r')
        self.seq_len = seq_len
        # One extra token per window provides the shifted target.
        self.n_samples = len(self.data) // (seq_len + 1)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return ``(x, y)`` int64 tensors of length ``seq_len`` for sample ``idx``.

        Negative indices count from the end. Raises ``IndexError`` when the
        index is out of range, which also lets plain iteration terminate.
        """
        idx = int(idx)
        if idx < 0:
            idx += self.n_samples
        if not 0 <= idx < self.n_samples:
            raise IndexError(
                f"sample index out of range for dataset of size {self.n_samples}"
            )
        window = self.seq_len + 1
        start = idx * window
        # Convert to int64 in NumPy first: this copies out of the read-only
        # memmap and avoids relying on torch support for the on-disk dtype.
        tokens = np.asarray(self.data[start:start + window], dtype=np.int64)
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        return x, y
class FastMemoryCell(nn.Module):
"""
Drop-in replacement for memory1 + GatedResidual.
Speed over your original:
Original: 4 matmuls (gate_x, gate_r x2 GR calls) + lin1 = 5 total
This: 1 matmul (fused_proj) = 5x fewer weight multiplications
Quality over a vanilla GRU:
- Bidirectional: returns BOTH a new hidden state AND a context vector
- Reset gate (like GRU) for selective forgetting
- Shared candidate prevents the two gate paths from fighting each other
- RMSNorm instead of LayerNorm (no mean subtraction = ~20% faster norm)
"""
def __init__(self, dim: int):
super().__init__()
# ONE big Linear replaces:
# self.GR.gate_x (dim -> dim)
# self.GR.gate_r (dim -> dim)
# self.GR1.gate_x (dim -> dim)
# self.GR1.gate_r (dim -> dim)
# self.lin1 (dim -> dim)
#
# In C++ terms: instead of 4 small GEMM calls,
# we do 1 large GEMM — GPU loves wide matmuls.
self.fused_proj = nn.Linear(dim * 2, dim * 3, bias=True)
# RMSNorm: skips mean subtraction vs LayerNorm, ~20% faster
# Requires PyTorch >= 2.1. Fall back to nn.LayerNorm if needed.
self.norm = nn.RMSNorm(dim)
def forward(self, x: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Single cat + single matmul for ALL gate logic
# Shape: [batch, dim*3]
proj = self.fused_proj(torch.cat([x, state], dim=-1))
# Slice into 3 equal parts along last dim — no memory copy, just views
g_update, g_reset, g_context = proj.chunk(3, dim=-1)
g_update = g_update.sigmoid() # How much new info enters the state
g_reset = g_reset.sigmoid() # GRU-style: what old state to use for candidate
g_context = g_context.sigmoid() # How much updated state leaks into context output
# GRU-style candidate: reset gate filters what old state matters
candidate = (g_reset * state).tanh()
# --- Two outputs, shared computation ---
# 1. New hidden state (equivalent to your: x = GR(input, state) → lin1)
new_state = (1.0 - g_update) * state + g_update * candidate
# 2. Context vector (equivalent to your: w = GR(state, input))
# Blends raw input with the freshly updated state
context = (1.0 - g_context) * x + g_context * new_state
return self.norm(new_state), context
#My rms norm implementation. This upcasts to fp32, then casts to whatever the input dtype was.
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) # stays float32 always
def forward(self, x):
# Upcast input to float32 for norm (more numerically stable anyway)
# then cast result back to whatever dtype x was
return torch.rms_norm(x.to(self.weight.dtype), x.shape[-1:], self.weight, self.eps)
class FlashCrossAttention(nn.Module):
"""
Drop-in replacement for nn.MultiheadAttention that uses
Flash Attention (O(seq) memory instead of O(seq²)).
Usage identical to your existing MA1/MA2:
self.layerMA1 = AI_ex.FlashCrossAttention(dim, heads)
y, _ = self.layerMA1(query, key, value, attn_mask=mask)
The attn_mask parameter is accepted but ignored —
Flash Attention handles causality internally and is
always more memory efficient than passing an explicit mask.
"""
def __init__(self, dim, heads, bottleneck=256):
super().__init__()
self.heads = heads
self.head_dim = dim // heads
self.dim = dim
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
def forward(self, query, key, value, attn_mask=None):
B, Sq, D = query.shape
Skv = key.size(1)
H, Hd = self.heads, self.head_dim
q = self.q_proj(query).view(B, Sq, H, Hd).transpose(1, 2)
k = self.k_proj(key).view(B, Skv, H, Hd).transpose(1, 2)
v = self.v_proj(value).view(B, Skv, H, Hd).transpose(1, 2)
# Flash Attention — O(seq) memory, same result as standard attention
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None, # cross attention doesn't need causal mask
is_causal=False,
enable_gqa=True # use GQA optimization for even more speed and memory savings
)
out = out.transpose(1, 2).contiguous().view(B, Sq, D)
return self.out_proj(out), None # None matches nn.MultiheadAttention signature
class ThinkingRouter(nn.Module):
    """Produce expert-mixing weights from three summary signals.

    The router does not look at the hidden state's content directly.
    Its input per batch element is:

    1. delta: mean distance between ``y`` and ``y_prev``;
    2. drift: mean distance between ``y`` and ``linguistic_anchor``;
    3. a learned 16-dim embedding of the (clamped) iteration index.

    Both scalar signals are detached and divided by their batch mean, so
    they are relative to the other elements of the same batch.
    """

    def __init__(self, dim: int, n_experts: int = 2, max_iter: int = 3):
        super().__init__()
        self.n_experts = n_experts
        self.max_iter = max_iter

        iter_width = 16
        # One learned vector per iteration index.
        self.iter_embed = nn.Embedding(max_iter, iter_width)

        # Two scalar signals plus the iteration embedding = 18 inputs.
        # SwiGLU halves the 64 hidden channels to 32.
        signal_width = iter_width + 2
        self.router = nn.Sequential(
            nn.Linear(signal_width, 64),
            SwiGLU(),
            nn.Linear(32, n_experts, bias=False),
        )
        # Small initial weights keep the initial routing close to uniform.
        for layer_idx in (0, 2):
            nn.init.normal_(self.router[layer_idx].weight, std=0.01)

        # Placeholder for a balancing loss; forward() does not populate it.
        self.last_weights = None

    def forward(self,
                y: torch.Tensor,
                y_prev: torch.Tensor,
                linguistic_anchor: torch.Tensor,
                iter_idx: int
                ) -> torch.Tensor:
        """Return routing probabilities of shape ``[batch, n_experts]``.

        ``iter_idx`` may be a Python int or a one-element tensor; values
        above ``max_iter - 1`` are clamped down to ``max_iter - 1``.
        """

        def relative_distance(reference):
            # L2 norm over the last axis, mean over the next axis, then
            # scaled by the batch mean. Detached so that no gradient flows
            # from the router back into the hidden states.
            dist = (y - reference).norm(dim=-1).mean(dim=-1, keepdim=True).detach()
            return dist / (dist.mean() + 1e-8)

        delta = relative_distance(y_prev)             # [batch, 1]
        drift = relative_distance(linguistic_anchor)  # [batch, 1]

        step = iter_idx if isinstance(iter_idx, int) else iter_idx.item()
        step = min(step, self.max_iter - 1)
        step_tensor = torch.as_tensor(step, device=y.device, dtype=torch.long)
        step_emb = self.iter_embed(step_tensor).unsqueeze(0).expand(y.size(0), -1)  # [batch, 16]

        features = torch.cat([delta, drift, step_emb], dim=-1)  # [batch, 18]
        logits = self.router(features)
        # The last axis indexes experts, so softmax over dim=-1 gives one
        # probability distribution per batch element.
        return torch.softmax(logits, dim=-1)
class MoLLayer(nn.Module):
    """Dense mixture of SwiGLU feed-forward experts.

    Every expert is evaluated on the full input; the outputs are summed
    with per-batch-element weights produced by a ``ThinkingRouter``.
    """

    def __init__(self, dim: int, ffndim: int,
                 n_experts: int = 2, max_iter: int = 3, bias=True):
        super().__init__()
        self.n_experts = n_experts
        self.router = ThinkingRouter(dim, n_experts, max_iter)

        def make_expert():
            # The up-projection is twice as wide because SwiGLU halves it.
            return nn.Sequential(
                nn.Linear(dim, ffndim * 2, bias=bias),
                SwiGLU(),
                nn.Linear(ffndim, dim, bias=bias),
            )

        self.experts = nn.ModuleList([make_expert() for _ in range(n_experts)])

    def forward(self, x:torch.Tensor, x_prev:torch.Tensor, linguistic_anchor: torch.Tensor, iter_idx:int):
        """Return the router-weighted sum of all expert outputs.

        The routing weights are indexed ``[batch, expert]`` and broadcast
        over the two trailing axes of each expert's output.
        """
        weights = self.router(x, x_prev, linguistic_anchor, iter_idx)
        out = torch.zeros_like(x)
        for gate, expert in zip(weights.unbind(dim=-1), self.experts):
            # Accumulate in place so the result keeps the dtype of x.
            out += gate[:, None, None] * expert(x)
        return out
class custom_mem(nn.Module):
    """Fuse current, previous and anchor states through a gated bottleneck.

    Each of the three inputs is RMS-normalised with its own norm, the
    results are concatenated on the last axis, projected down to
    ``num_heads * head_dim`` channels, passed through a SiLU-gated unit and
    projected back to ``dim``.
    """

    def __init__(self, dim, num_heads, head_dim=None, dtype=torch.bfloat16):
        super().__init__()
        self.dim = dim
        self.heads = num_heads
        # Without an explicit head_dim, use half of the usual per-head width.
        self.head_dim = head_dim if head_dim else (dim // num_heads) // 2
        self.total_mid_dim = self.heads * self.head_dim
        self.norm_curr = RMSNorm(dim)
        self.norm_prev = RMSNorm(dim)
        self.norm_anch = RMSNorm(dim)
        self.input_proj = nn.Linear(dim * 3, self.total_mid_dim, bias=False, dtype=dtype)
        # w12 emits the gate and the value halves in a single matmul.
        self.w12 = nn.Linear(self.total_mid_dim, self.total_mid_dim * 2, bias=False, dtype=dtype)
        self.w3 = nn.Linear(self.total_mid_dim, dim, bias=False, dtype=dtype)
        # Not used by forward(); kept so the module tree is unchanged.
        self.swiglu = SwiGLU()

    def forward(self, current, previous, anchor):
        """Return the fused update, with last dimension ``dim``."""
        normed = (
            self.norm_curr(current),
            self.norm_prev(previous),
            self.norm_anch(anchor),
        )
        combined = torch.cat(normed, dim=-1)
        # The norms may hand back a higher-precision tensor than the
        # projection weights; align dtypes so the matmul does not fail
        # outside of autocast.
        combined = combined.to(self.input_proj.weight.dtype)
        hidden = self.input_proj(combined)

        # Here the FIRST chunk is the gate and the second is the value
        # (the reverse of the chunk order used by the SwiGLU module).
        gate, val = self.w12(hidden).chunk(2, dim=-1)
        hidden = F.silu(gate) * val

        return self.w3(hidden)
class engram(nn.Module):
    """Hashed n-gram memory lookup mixed by input-dependent head gates.

    For every position, the window of the last ``ngram`` token ids is hashed
    once per head into a shared embedding table. The retrieved rows are
    mixed with a softmax over heads computed from ``x``, normalised, passed
    through a SwiGLU bottleneck and scaled by a learned sigmoid gate.

    Raises:
        ValueError: if ``ngram`` or ``num_heads`` exceeds the number of
            built-in hash coefficients / salts (four each).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 4,
        ngram: int = 3,
        memory_size: int = 16384,
        bottleneck: int = 128,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        hash_coeffs = [1, 1315423911, 2654435761, 2246822519]
        head_salts = [0, 97, 193, 389]
        # Fail at construction: otherwise the tables are silently truncated
        # and forward() dies later with an opaque shape mismatch.
        if ngram > len(hash_coeffs):
            raise ValueError(
                f"ngram must be <= {len(hash_coeffs)}, got {ngram}"
            )
        if num_heads > len(head_salts):
            raise ValueError(
                f"num_heads must be <= {len(head_salts)}, got {num_heads}"
            )

        self.dim = dim
        self.num_heads = num_heads
        self.ngram = ngram
        self.memory_size = memory_size
        self.key_gate = nn.Linear(dim, num_heads, bias=False, dtype=dtype)
        self.memory = nn.Embedding(memory_size, dim, dtype=dtype)
        self.post = nn.Sequential(
            nn.Linear(dim, bottleneck * 2, bias=False, dtype=dtype),
            SwiGLU(),
            nn.Linear(bottleneck, dim, bias=False, dtype=dtype),
        )
        self.norm = RMSNorm(dim)
        self.percent = nn.Linear(dim, 1, bias=False, dtype=dtype)
        # Non-persistent: these constants are rebuilt here, not loaded from
        # a checkpoint.
        self.register_buffer(
            "coeffs",
            torch.tensor(hash_coeffs, dtype=torch.long),
            persistent=False,
        )
        self.register_buffer(
            "salts",
            torch.tensor(head_salts, dtype=torch.long),
            persistent=False,
        )

    def _hash_ngrams(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map ``input_ids`` [B, S] to memory indices [B, S, num_heads].

        The sequence is left-padded with ``ngram - 1`` zeros so that every
        position has a full window ending at that position. Each head's
        index is the same weighted window sum shifted by that head's salt,
        reduced modulo ``memory_size``.
        """
        B, S = input_ids.shape
        if self.ngram > 1:
            pad = torch.zeros((B, self.ngram - 1), device=input_ids.device, dtype=input_ids.dtype)
            ids = torch.cat([pad, input_ids], dim=1)
            windows = ids.unfold(1, self.ngram, 1).long()  # [B, S, ngram]
        else:
            windows = input_ids.long().unsqueeze(-1)  # [B, S, 1]
        coeffs = self.coeffs[:windows.size(-1)].view(1, 1, -1)
        base = (windows * coeffs).sum(dim=-1)  # [B, S]
        # Broadcast all head salts at once instead of looping per head.
        salts = self.salts[:self.num_heads]
        return torch.remainder(base.unsqueeze(-1) + salts, self.memory_size)  # [B, S, H]

    def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the gated memory read-out, shaped like ``x``."""
        compute_dtype = self.key_gate.weight.dtype
        idx = self._hash_ngrams(input_ids)  # [B, S, H]
        mem = self.memory(idx)  # [B, S, H, D]
        x = x.to(compute_dtype)
        gates = torch.softmax(self.key_gate(x), dim=-1).unsqueeze(-1)  # [B, S, H, 1]
        out = (mem * gates).sum(dim=2)  # [B, S, D]
        # The norm may return a higher-precision tensor than the linear
        # weights; align dtypes so the matmul does not fail outside autocast.
        normed = self.norm(out).to(compute_dtype)
        out = self.post(normed)
        return torch.sigmoid(self.percent(out)) * out