|
|
| from gc import enable
|
|
|
| from functorch import dim
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import numpy as np
|
| from torch.utils.data import Dataset, DataLoader
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU activation over the last dimension.

    The last dimension is split into two halves ``(value, gate)`` with
    ``chunk(2)``; the result is ``silu(gate) * value`` — half the input width
    for even widths.
    """

    def __init__(self):
        super().__init__()
        self.silu = nn.SiLU()

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        gated = self.silu(gate)
        return gated * value
|
|
|
| class GatedResidual(nn.Module):
|
| def __init__(self, dim):
|
| super().__init__()
|
|
|
|
|
|
|
| self.gate_layer = nn.Linear(dim * 2, dim)
|
| self.output_norm = nn.RMSNorm(dim)
|
| self.gate_x = nn.Linear(dim, dim, bias=False)
|
| self.gate_r = nn.Linear(dim, dim, bias=False)
|
|
|
| def forward(self, x, residual):
|
| """
|
| x: The new information (e.g., from Attention)
|
| residual: The current memory/state (the 'highway')
|
| """
|
|
|
|
|
|
|
|
|
|
|
|
|
| gate = torch.sigmoid(self.gate_x(x) + self.gate_r(residual))
|
| mixed = (1 - gate) * residual + gate * x
|
|
|
| return self.output_norm(mixed)
|
|
|
def rotate_half(x):
    """Rotate half the hidden dims of ``x``: ``[x1, x2] -> [-x2, x1]``.

    The last dimension is split at ``x.shape[-1] // 2``.
    """
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


class RoPE(nn.Module):
    """Rotary position embedding with a cos/sin cache that grows on demand.

    Channel ``i`` is paired with channel ``i + d/2`` of the last dimension
    (the ``rotate_half`` layout), and each pair is rotated by the angle
    ``position * inv_freq[i]`` for positions ``0 .. seq_len - 1``.

    Args:
        head_dim: width the frequency table is built for. Inputs may be this
            wide or narrower, as long as their width is even.
        max_seq_len: number of positions cached up front.
        base: RoPE frequency base.
    """

    def __init__(self, head_dim, max_seq_len=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len):
        """(Re)build the ``cos_cached``/``sin_cached`` buffers for ``seq_len`` positions.

        Both buffers have shape ``(1, 1, seq_len, 2 * len(inv_freq))`` and
        layout ``[f, f]``.
        """
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Registering an existing buffer name again replaces the old tensor.
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len):
        """Apply the rotation for positions ``0 .. seq_len - 1`` to ``x``.

        Args:
            x: tensor whose last two dims are ``(seq_len, d)``. The caches are
                shaped ``(1, 1, seq, d)``, so a 4-D ``(batch, heads, seq, d)``
                input keeps its shape; assumed to be the caller's layout.
            seq_len: number of positions; must match ``x.shape[-2]``.

        Returns:
            The rotated tensor, in ``x``'s dtype.

        Raises:
            ValueError: if ``d`` is odd or wider than the frequency table.
        """
        if seq_len > self.max_seq_len:
            self._build_cache(seq_len)
            # Remember the new capacity so later calls don't rebuild again.
            self.max_seq_len = seq_len

        width = x.shape[-1]
        half = width // 2
        n_freqs = self.inv_freq.numel()
        if width % 2 or half > n_freqs:
            raise ValueError(
                f"RoPE expects an even last dimension of at most {2 * n_freqs}, got {width}"
            )

        # Use the first `half` frequencies and duplicate them, so that channel i
        # and channel i + half share one angle (a true 2-D rotation). Simply
        # slicing the cached [f, f] table to `width` would pair different
        # frequencies whenever width < head_dim.
        cos = self.cos_cached[:, :, :seq_len, :half]
        sin = self.sin_cached[:, :, :seq_len, :half]
        # Match the input dtype so bf16 activations are not promoted to fp32.
        cos = torch.cat((cos, cos), dim=-1).to(x.dtype)
        sin = torch.cat((sin, sin), dim=-1).to(x.dtype)
        return (x * cos) + (rotate_half(x) * sin)
|
|
|
| class RoPEAttention(nn.Module):
|
| def __init__(self, dim, heads, kv_heads=None, bottleneck=256):
|
| super().__init__()
|
|
|
| self.dim = dim
|
| self.bottleneck = bottleneck
|
| self.heads = heads
|
| self.kv_heads = kv_heads or heads
|
|
|
| assert heads % self.kv_heads == 0
|
|
|
| self.head_dim = bottleneck // heads
|
|
|
|
|
| self.latent = nn.Linear(dim, bottleneck, bias=False)
|
|
|
| self.q_proj = nn.Linear(bottleneck, bottleneck, bias=False)
|
| self.k_proj = nn.Linear(bottleneck, self.kv_heads * self.head_dim, bias=False)
|
| self.v_proj = nn.Linear(bottleneck, self.kv_heads * self.head_dim, bias=False)
|
|
|
| self.out_proj = nn.Linear(bottleneck, dim, bias=False)
|
|
|
| def forward(self, x, rope):
|
| x = self.latent(x)
|
|
|
| b, t, _ = x.shape
|
|
|
| q = self.q_proj(x).view(b, t, self.heads, self.head_dim).transpose(1, 2)
|
| k = self.k_proj(x).view(b, t, self.kv_heads, self.head_dim).transpose(1, 2)
|
| v = self.v_proj(x).view(b, t, self.kv_heads, self.head_dim).transpose(1, 2)
|
|
|
| split_size = self.head_dim // 2
|
|
|
|
|
| q_c, q_p = q.split([split_size, split_size], dim=-1)
|
| k_c, k_p = k.split([split_size, split_size], dim=-1)
|
|
|
|
|
| q_p = rope(q_p, t)
|
| k_p = rope(k_p, t)
|
|
|
| q = torch.cat([q_c, q_p], dim=-1)
|
| k = torch.cat([k_c, k_p], dim=-1)
|
|
|
| out = F.scaled_dot_product_attention(
|
| q, k, v,
|
| enable_gqa=True,
|
| is_causal=True
|
| )
|
|
|
| out = out.transpose(1, 2).contiguous().view(b, t, self.bottleneck)
|
| return self.out_proj(out)
|
|
|
class StreamDataset(Dataset):
    """Next-token training samples cut from a flat binary file of token ids.

    The file is memory-mapped as a 1-D ``uint16`` array and split into
    consecutive, non-overlapping windows of ``seq_len + 1`` tokens. A trailing
    partial window is dropped. Sample ``i`` is ``(x, y)``, where ``y`` is ``x``
    shifted one position to the left. Both are ``torch.long`` tensors of
    length ``seq_len``.

    Args:
        bin_file: path to the raw ``uint16`` token file.
        seq_len: number of input tokens per sample.
    """

    def __init__(self, bin_file, seq_len):
        # NOTE(review): np.memmap pickles like a plain ndarray (its data is
        # serialised), so DataLoader workers started with 'spawn' may each get
        # an in-memory copy of the whole file. Consider opening the memmap
        # lazily per worker if that matters for your setup.
        self.data = np.memmap(bin_file, dtype=np.uint16, mode='r')
        self.seq_len = seq_len
        self.n_samples = len(self.data) // (seq_len + 1)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx``.

        Negative indices count from the end, as for a sequence.

        Raises:
            IndexError: if ``idx`` is out of range. Plain ``for s in dataset``
                relies on this to stop; otherwise it would keep slicing empty
                windows past the end of the file forever.
        """
        if idx < 0:
            idx += self.n_samples
        if not 0 <= idx < self.n_samples:
            raise IndexError(f"sample index out of range for {self.n_samples} samples")

        window = self.seq_len + 1
        start = idx * window
        chunk = self.data[start:start + window]
        # Widen to int64 in NumPy first: older torch releases cannot build a
        # tensor from a uint16 ndarray. Each half is copied separately so that
        # x and y do not share storage.
        x = torch.from_numpy(chunk[:-1].astype(np.int64))
        y = torch.from_numpy(chunk[1:].astype(np.int64))
        return x, y
|
|
|
| class FastMemoryCell(nn.Module):
|
| """
|
| Drop-in replacement for memory1 + GatedResidual.
|
|
|
| Speed over your original:
|
| Original: 4 matmuls (gate_x, gate_r x2 GR calls) + lin1 = 5 total
|
| This: 1 matmul (fused_proj) = 5x fewer weight multiplications
|
|
|
| Quality over a vanilla GRU:
|
| - Bidirectional: returns BOTH a new hidden state AND a context vector
|
| - Reset gate (like GRU) for selective forgetting
|
| - Shared candidate prevents the two gate paths from fighting each other
|
| - RMSNorm instead of LayerNorm (no mean subtraction = ~20% faster norm)
|
| """
|
|
|
| def __init__(self, dim: int):
|
| super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| self.fused_proj = nn.Linear(dim * 2, dim * 3, bias=True)
|
|
|
|
|
|
|
| self.norm = nn.RMSNorm(dim)
|
|
|
| def forward(self, x: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
| proj = self.fused_proj(torch.cat([x, state], dim=-1))
|
|
|
|
|
| g_update, g_reset, g_context = proj.chunk(3, dim=-1)
|
|
|
| g_update = g_update.sigmoid()
|
| g_reset = g_reset.sigmoid()
|
| g_context = g_context.sigmoid()
|
|
|
|
|
| candidate = (g_reset * state).tanh()
|
|
|
|
|
|
|
|
|
| new_state = (1.0 - g_update) * state + g_update * candidate
|
|
|
|
|
|
|
| context = (1.0 - g_context) * x + g_context * new_state
|
|
|
| return self.norm(new_state), context
|
|
|
|
|
| class RMSNorm(nn.Module):
|
| def __init__(self, dim, eps=1e-5):
|
| super().__init__()
|
| self.eps = eps
|
| self.weight = nn.Parameter(torch.ones(dim))
|
|
|
| def forward(self, x):
|
|
|
|
|
| return torch.rms_norm(x.to(self.weight.dtype), x.shape[-1:], self.weight, self.eps)
|
|
|
class FlashCrossAttention(nn.Module):
    """Multi-head attention built on ``F.scaled_dot_product_attention``.

    Intended as a drop-in for ``nn.MultiheadAttention(dim, heads,
    bias=False, batch_first=True)``:

        self.layerMA1 = AI_ex.FlashCrossAttention(dim, heads)
        y, _ = self.layerMA1(query, key, value, attn_mask=mask)

    ``attn_mask`` follows the ``nn.MultiheadAttention`` convention:
      - bool masks: ``True`` means "may NOT attend";
      - float masks are added to the attention scores;
      - shape ``(Sq, Skv)`` or ``(batch * heads, Sq, Skv)``.

    It is translated to SDPA's convention internally. Without a mask, SDPA may
    dispatch to a fused (flash / memory-efficient) kernel when the backend
    supports it; with an explicit mask a kernel that accepts masks is chosen.
    The second return value is always ``None``: attention weights are never
    materialised.

    Args:
        dim: model width; must be divisible by ``heads``.
        heads: number of attention heads.
        bottleneck: unused; kept for signature compatibility.

    Raises:
        ValueError: if ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads, bottleneck=256):
        super().__init__()
        if dim % heads:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dim // heads
        self.dim = dim

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    @staticmethod
    def _to_sdpa_mask(attn_mask, B, H, Sq, Skv, dtype):
        """Convert an ``nn.MultiheadAttention``-style mask to SDPA's convention."""
        if attn_mask is None:
            return None
        if attn_mask.dim() == 3:
            # MHA packs per-head masks batch-major as (B * H, Sq, Skv).
            attn_mask = attn_mask.reshape(B, H, Sq, Skv)
        if attn_mask.dtype == torch.bool:
            # MHA: True = blocked.  SDPA: True = allowed.
            return ~attn_mask
        # Float masks are additive in both APIs; SDPA wants the query dtype.
        return attn_mask.to(dtype)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` (B, Sq, dim) over ``key``/``value`` (B, Skv, dim).

        Returns:
            ``(output, None)``, where ``output`` has shape (B, Sq, dim).
        """
        B, Sq, D = query.shape
        Skv = key.size(1)
        H, Hd = self.heads, self.head_dim

        q = self.q_proj(query).view(B, Sq, H, Hd).transpose(1, 2)
        k = self.k_proj(key).view(B, Skv, H, Hd).transpose(1, 2)
        v = self.v_proj(value).view(B, Skv, H, Hd).transpose(1, 2)

        # The mask used to be discarded while is_causal=False, so a causal mask
        # passed by a decoder silently had no effect.
        mask = self._to_sdpa_mask(attn_mask, B, H, Sq, Skv, q.dtype)
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            is_causal=False,
        )

        out = out.transpose(1, 2).contiguous().view(B, Sq, D)
        return self.out_proj(out), None
|
|
|
class ThinkingRouter(nn.Module):
    """Soft expert router driven by how the current iterate is evolving.

    Three signals feed a small SwiGLU MLP that produces softmax weights over
    ``n_experts``:

      1. delta: mean L2 distance between ``y`` and ``y_prev`` (change this step);
      2. drift: mean L2 distance between ``y`` and ``linguistic_anchor``;
      3. iteration: a learned 16-d embedding of ``iter_idx``, clamped to
         ``max_iter - 1``.

    delta and drift are detached and divided by their batch mean, so only
    relative magnitudes within the batch matter.

    Args:
        dim: accepted for interface symmetry; not used.
        n_experts: number of routing outputs.
        max_iter: size of the iteration-embedding table.
    """

    def __init__(self, dim: int, n_experts: int = 2, max_iter: int = 3):
        super().__init__()
        self.n_experts = n_experts
        self.max_iter = max_iter

        self.iter_embed = nn.Embedding(max_iter, 16)

        # Input features: delta (1) + drift (1) + iteration embedding (16) = 18.
        # Linear(18, 64) feeds SwiGLU, which halves the width to 32.
        self.router = nn.Sequential(
            nn.Linear(18, 64),
            SwiGLU(),
            nn.Linear(32, n_experts, bias=False),
        )

        # Small init keeps the initial routing close to uniform.
        for layer_idx in (0, 2):
            nn.init.normal_(self.router[layer_idx].weight, std=0.01)

        # NOTE(review): never assigned in forward(); kept for external code that
        # may read or set it.
        self.last_weights = None

    @staticmethod
    def _mean_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # L2 over the last dim, averaged over the next one -> (batch, 1) for 3-D input.
        return (a - b).norm(dim=-1).mean(dim=-1, keepdim=True)

    @staticmethod
    def _batch_relative(signal: torch.Tensor) -> torch.Tensor:
        # Detached, and scaled by the batch mean (epsilon guards a zero mean).
        signal = signal.detach()
        return signal / (signal.mean() + 1e-8)

    def forward(self,
                y: torch.Tensor,
                y_prev: torch.Tensor,
                linguistic_anchor: torch.Tensor,
                iter_idx: int
                ) -> torch.Tensor:
        """Return routing weights of shape ``[batch, n_experts]``.

        ``y`` must be 3-D (batch, seq, dim) for the per-sample signals to be
        ``(batch, 1)`` and line up with the iteration embedding. ``iter_idx``
        may be a Python int or a one-element tensor.
        """
        delta = self._batch_relative(self._mean_distance(y, y_prev))
        drift = self._batch_relative(self._mean_distance(y, linguistic_anchor))

        step = iter_idx if isinstance(iter_idx, int) else iter_idx.item()
        step = min(step, self.max_iter - 1)
        step_tensor = torch.as_tensor(step, device=y.device, dtype=torch.long)
        step_emb = self.iter_embed(step_tensor).unsqueeze(0).expand(y.size(0), -1)

        features = torch.cat([delta, drift, step_emb], dim=-1)
        return torch.softmax(self.router(features), dim=-1)
|
|
|
|
|
class MoLLayer(nn.Module):
    """Dense soft mixture of SwiGLU feed-forward experts.

    Every expert runs on every token. Their outputs are summed with per-sample
    weights from a :class:`ThinkingRouter`.

    Args:
        dim: input/output width.
        ffndim: hidden width of each expert (after the SwiGLU halving).
        n_experts: number of experts.
        max_iter: forwarded to the router's iteration embedding.
        bias: whether the expert Linear layers have biases.
    """

    def __init__(self, dim: int, ffndim: int,
                 n_experts: int = 2, max_iter: int = 3, bias=True):
        super().__init__()
        self.n_experts = n_experts

        self.router = ThinkingRouter(dim, n_experts, max_iter)
        experts = []
        for _ in range(n_experts):
            experts.append(nn.Sequential(
                nn.Linear(dim, ffndim * 2, bias=bias),
                SwiGLU(),
                nn.Linear(ffndim, dim, bias=bias),
            ))
        self.experts = nn.ModuleList(experts)

    def forward(self, x: torch.Tensor, x_prev: torch.Tensor, linguistic_anchor: torch.Tensor, iter_idx: int):
        """Weighted sum of all expert outputs.

        Routing weights are ``[batch, n_experts]`` and are broadcast over the
        remaining two dims of ``x``, assumed to be (batch, seq, dim).
        """
        mix = self.router(x, x_prev, linguistic_anchor, iter_idx)

        # Accumulate in place so the result keeps x's dtype.
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            out += mix[:, i, None, None] * expert(x)
        return out
|
|
|
class custom_mem(nn.Module):
    """Fuse a current, a previous and an anchor representation into one update.

    Each input is RMS-normalised on its own. The three are concatenated along
    the last dim, projected to ``num_heads * head_dim`` features, passed through
    a SiLU-gated unit, and projected back to ``dim``.

    Args:
        dim: last-dimension width of all three inputs and of the output.
        num_heads: multiplied by ``head_dim`` to give the hidden width.
        head_dim: per-head width; defaults to ``(dim // num_heads) // 2``.
        dtype: dtype of the Linear layers. The RMSNorm modules are created
            without an explicit dtype.
    """

    def __init__(self, dim, num_heads, head_dim=None, dtype=torch.bfloat16):
        super().__init__()
        self.dim = dim
        self.heads = num_heads

        self.head_dim = head_dim if head_dim else (dim // num_heads) // 2
        self.total_mid_dim = self.heads * self.head_dim

        self.norm_curr = RMSNorm(dim)
        self.norm_prev = RMSNorm(dim)
        self.norm_anch = RMSNorm(dim)

        self.input_proj = nn.Linear(dim * 3, self.total_mid_dim, bias=False, dtype=dtype)
        # Produces [gate, value] halves for the gated unit in forward().
        self.w12 = nn.Linear(self.total_mid_dim, self.total_mid_dim * 2, bias=False, dtype=dtype)
        self.w3 = nn.Linear(self.total_mid_dim, dim, bias=False, dtype=dtype)
        # NOTE(review): unused by forward(). The SwiGLU module splits
        # (value, gate), the opposite of the (gate, value) order below, so
        # swapping it in would change the meaning of trained weights.
        self.swiglu = SwiGLU()

    def forward(self, current, previous, anchor):
        """Return a ``dim``-wide update computed from the three inputs."""
        c = self.norm_curr(current)
        p = self.norm_prev(previous)
        a = self.norm_anch(anchor)

        # The norms return their weight dtype (float32 unless the module was
        # cast), while the projections were created in `dtype`. Cast here so
        # F.linear sees matching dtypes even outside autocast; under a bf16
        # autocast region the projection would perform this same cast anyway.
        combined = torch.cat([c, p, a], dim=-1).to(self.input_proj.weight.dtype)
        x = self.input_proj(combined)

        gate, val = self.w12(x).chunk(2, dim=-1)
        x = F.silu(gate) * val

        return self.w3(x)
|
|
|
class engram(nn.Module):
    """Hashed n-gram memory lookup, gated by the hidden state.

    For every position, the window of the last ``ngram`` token ids is hashed
    into ``num_heads`` slots of a learned ``memory`` table; the left edge is
    zero-padded. The slot vectors are mixed with softmax weights computed from
    ``x``, RMS-normalised, passed through a SwiGLU bottleneck MLP, and finally
    scaled by a sigmoid "percent" gate.

    Args:
        dim: width of ``x``, of the memory vectors, and of the output.
        num_heads: hash functions per position (at most ``len(salts)`` = 4).
        ngram: tokens per hashed window (at most ``len(coeffs)`` = 4).
        memory_size: number of rows in the memory table.
        bottleneck: hidden width of the post-lookup MLP.
        dtype: dtype of the Linear/Embedding parameters.

    Raises:
        ValueError: if ``ngram`` or ``num_heads`` exceed the number of built-in
            hash coefficients / salts.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 4,
        ngram: int = 3,
        memory_size: int = 16384,
        bottleneck: int = 128,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.ngram = ngram
        self.memory_size = memory_size

        self.key_gate = nn.Linear(dim, num_heads, bias=False, dtype=dtype)
        self.memory = nn.Embedding(memory_size, dim, dtype=dtype)

        self.post = nn.Sequential(
            nn.Linear(dim, bottleneck * 2, bias=False, dtype=dtype),
            SwiGLU(),
            nn.Linear(bottleneck, dim, bias=False, dtype=dtype),
        )

        self.norm = RMSNorm(dim)
        self.percent = nn.Linear(dim, 1, bias=False, dtype=dtype)

        # One multiplier per n-gram position.
        self.register_buffer(
            "coeffs",
            torch.tensor([1, 1315423911, 2654435761, 2246822519], dtype=torch.long),
            persistent=False,
        )
        # One additive offset per hash head.
        self.register_buffer(
            "salts",
            torch.tensor([0, 97, 193, 389], dtype=torch.long),
            persistent=False,
        )

        # These used to be sliced silently and then fail later with an opaque
        # broadcast error inside forward().
        if ngram > self.coeffs.numel():
            raise ValueError(
                f"ngram must be at most {self.coeffs.numel()} (one hash coefficient per position)"
            )
        if num_heads > self.salts.numel():
            raise ValueError(
                f"num_heads must be at most {self.salts.numel()} (one hash salt per head)"
            )

    def _hash_ngrams(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map ``input_ids`` (B, S) to memory indices of shape (B, S, num_heads).

        NOTE(review): heads differ only by an additive salt, so two windows
        whose base hashes collide modulo ``memory_size`` collide in every head.
        Padding uses id 0, so it is indistinguishable from a real token 0.
        """
        B, S = input_ids.shape

        if self.ngram > 1:
            pad = torch.zeros((B, self.ngram - 1), device=input_ids.device, dtype=input_ids.dtype)
            ids = torch.cat([pad, input_ids], dim=1)
            windows = ids.unfold(1, self.ngram, 1).long()  # (B, S, ngram)
        else:
            windows = input_ids.long().unsqueeze(-1)  # (B, S, 1)

        coeffs = self.coeffs[:windows.size(-1)].view(1, 1, -1)
        base = (windows * coeffs).sum(dim=-1)  # (B, S)

        # Exact integer ops: identical to stacking one remainder per salt.
        salts = self.salts[:self.num_heads]
        return torch.remainder(base.unsqueeze(-1) + salts, self.memory_size)

    def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Retrieve and gate memory for each position.

        Args:
            x: hidden states whose leading dims match ``input_ids``; presumably
                (B, S, dim).
            input_ids: token ids of shape (B, S).

        Returns:
            Tensor of shape (B, S, dim) in the parameter dtype.
        """
        idx = self._hash_ngrams(input_ids)
        mem = self.memory(idx)  # (B, S, num_heads, dim)

        x = x.to(self.key_gate.weight.dtype)

        gates = torch.softmax(self.key_gate(x), dim=-1).unsqueeze(-1)
        out = (mem * gates).sum(dim=2)

        # RMSNorm returns its weight dtype (float32 unless cast); bring it back
        # to the MLP's dtype so `post` also works outside autocast.
        normed = self.norm(out).to(self.post[0].weight.dtype)
        out = self.post(normed)
        return torch.sigmoid(self.percent(out)) * out
|
|
|
|
|