from gc import enable from functorch import dim import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from torch.utils.data import Dataset, DataLoader class SwiGLU(nn.Module): def __init__(self): super().__init__() self.silu = nn.SiLU() # SiLU is the same as Swish, and is available in PyTorch def forward(self, x): x, gate = x.chunk(2, dim=-1) return self.silu(gate) * x class GatedResidual(nn.Module): def __init__(self, dim): super().__init__() # This layer looks at both the current state and the new info # to decide what to keep. #layernorm self.gate_layer = nn.Linear(dim * 2, dim) self.output_norm = nn.RMSNorm(dim) self.gate_x = nn.Linear(dim, dim, bias=False) self.gate_r = nn.Linear(dim, dim, bias=False) def forward(self, x, residual): """ x: The new information (e.g., from Attention) residual: The current memory/state (the 'highway') """ # 1. Concatenate them and calculate the 'Valve' (0 to 1) # 2. The 'Convex Combination' - pure stability # If gate is 0, we keep only the old memory. # If gate is 1, we take only the new info. gate = torch.sigmoid(self.gate_x(x) + self.gate_r(residual)) # no cat needed mixed = (1 - gate) * residual + gate * x # 3. Final cleanup return self.output_norm(mixed) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) class RoPE(nn.Module): def __init__(self, head_dim, max_seq_len=2048, base=10000): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self.max_seq_len = max_seq_len self._build_cache(max_seq_len) def _build_cache(self, seq_len): t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) def forward(self, x, seq_len): if seq_len > self.max_seq_len: self._build_cache(seq_len) cos = self.cos_cached[:, :, :seq_len, :x.shape[-1]] sin = self.sin_cached[:, :, :seq_len, :x.shape[-1]] return (x * cos) + (rotate_half(x) * sin) class RoPEAttention(nn.Module): def __init__(self, dim, heads, kv_heads=None, bottleneck=256): super().__init__() self.dim = dim self.bottleneck = bottleneck self.heads = heads self.kv_heads = kv_heads or heads assert heads % self.kv_heads == 0 self.head_dim = bottleneck // heads # single latent projection self.latent = nn.Linear(dim, bottleneck, bias=False) self.q_proj = nn.Linear(bottleneck, bottleneck, bias=False) self.k_proj = nn.Linear(bottleneck, self.kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(bottleneck, self.kv_heads * self.head_dim, bias=False) self.out_proj = nn.Linear(bottleneck, dim, bias=False) def forward(self, x, rope): x = self.latent(x) b, t, _ = x.shape q = self.q_proj(x).view(b, t, self.heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(b, t, self.kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(b, t, self.kv_heads, self.head_dim).transpose(1, 2) split_size = self.head_dim // 2 q_c, q_p = q.split([split_size, split_size], dim=-1) k_c, k_p = k.split([split_size, split_size], dim=-1) # Apply RoPE ONLY to the Position half q_p = rope(q_p, t) k_p = rope(k_p, t) q = torch.cat([q_c, q_p], dim=-1) k = torch.cat([k_c, k_p], dim=-1) out = F.scaled_dot_product_attention( q, k, v, enable_gqa=True, is_causal=True ) out = out.transpose(1, 2).contiguous().view(b, t, self.bottleneck) return self.out_proj(out) class StreamDataset(Dataset): def __init__(self, bin_file, seq_len): # dtype MUST match what you used in tofile() self.data = np.memmap(bin_file, dtype=np.uint16, mode='r') self.seq_len = seq_len # We need seq_len + 1 to get a (input, target) pair self.n_samples = len(self.data) // (seq_len + 1) def __len__(self): return self.n_samples def __getitem__(self, idx): start = idx * (self.seq_len + 1) end = start + self.seq_len + 1 chunk = self.data[start:end] x = chunk[:-1] y = chunk[1:] return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long) class FastMemoryCell(nn.Module): """ Drop-in replacement for memory1 + GatedResidual. Speed over your original: Original: 4 matmuls (gate_x, gate_r x2 GR calls) + lin1 = 5 total This: 1 matmul (fused_proj) = 5x fewer weight multiplications Quality over a vanilla GRU: - Bidirectional: returns BOTH a new hidden state AND a context vector - Reset gate (like GRU) for selective forgetting - Shared candidate prevents the two gate paths from fighting each other - RMSNorm instead of LayerNorm (no mean subtraction = ~20% faster norm) """ def __init__(self, dim: int): super().__init__() # ONE big Linear replaces: # self.GR.gate_x (dim -> dim) # self.GR.gate_r (dim -> dim) # self.GR1.gate_x (dim -> dim) # self.GR1.gate_r (dim -> dim) # self.lin1 (dim -> dim) # # In C++ terms: instead of 4 small GEMM calls, # we do 1 large GEMM — GPU loves wide matmuls. self.fused_proj = nn.Linear(dim * 2, dim * 3, bias=True) # RMSNorm: skips mean subtraction vs LayerNorm, ~20% faster # Requires PyTorch >= 2.1. Fall back to nn.LayerNorm if needed. self.norm = nn.RMSNorm(dim) def forward(self, x: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # Single cat + single matmul for ALL gate logic # Shape: [batch, dim*3] proj = self.fused_proj(torch.cat([x, state], dim=-1)) # Slice into 3 equal parts along last dim — no memory copy, just views g_update, g_reset, g_context = proj.chunk(3, dim=-1) g_update = g_update.sigmoid() # How much new info enters the state g_reset = g_reset.sigmoid() # GRU-style: what old state to use for candidate g_context = g_context.sigmoid() # How much updated state leaks into context output # GRU-style candidate: reset gate filters what old state matters candidate = (g_reset * state).tanh() # --- Two outputs, shared computation --- # 1. New hidden state (equivalent to your: x = GR(input, state) → lin1) new_state = (1.0 - g_update) * state + g_update * candidate # 2. Context vector (equivalent to your: w = GR(state, input)) # Blends raw input with the freshly updated state context = (1.0 - g_context) * x + g_context * new_state return self.norm(new_state), context #My rms norm implementation. This upcasts to fp32, then casts to whatever the input dtype was. class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) # stays float32 always def forward(self, x): # Upcast input to float32 for norm (more numerically stable anyway) # then cast result back to whatever dtype x was return torch.rms_norm(x.to(self.weight.dtype), x.shape[-1:], self.weight, self.eps) class FlashCrossAttention(nn.Module): """ Drop-in replacement for nn.MultiheadAttention that uses Flash Attention (O(seq) memory instead of O(seq²)). Usage identical to your existing MA1/MA2: self.layerMA1 = AI_ex.FlashCrossAttention(dim, heads) y, _ = self.layerMA1(query, key, value, attn_mask=mask) The attn_mask parameter is accepted but ignored — Flash Attention handles causality internally and is always more memory efficient than passing an explicit mask. """ def __init__(self, dim, heads, bottleneck=256): super().__init__() self.heads = heads self.head_dim = dim // heads self.dim = dim self.q_proj = nn.Linear(dim, dim, bias=False) self.k_proj = nn.Linear(dim, dim, bias=False) self.v_proj = nn.Linear(dim, dim, bias=False) self.out_proj = nn.Linear(dim, dim, bias=False) def forward(self, query, key, value, attn_mask=None): B, Sq, D = query.shape Skv = key.size(1) H, Hd = self.heads, self.head_dim q = self.q_proj(query).view(B, Sq, H, Hd).transpose(1, 2) k = self.k_proj(key).view(B, Skv, H, Hd).transpose(1, 2) v = self.v_proj(value).view(B, Skv, H, Hd).transpose(1, 2) # Flash Attention — O(seq) memory, same result as standard attention out = F.scaled_dot_product_attention( q, k, v, attn_mask=None, # cross attention doesn't need causal mask is_causal=False, enable_gqa=True # use GQA optimization for even more speed and memory savings ) out = out.transpose(1, 2).contiguous().view(B, Sq, D) return self.out_proj(out), None # None matches nn.MultiheadAttention signature class ThinkingRouter(nn.Module): """ Routes to different experts based on the QUALITY of current thinking, not just the content. Uses three signals: 1. Delta: how much y changed this iteration (uncertainty signal) 2. Drift: how far y is from the linguistic anchor (grounding signal) 3. Iter: which iteration we're on (stage signal) These directly describe WHERE we are in the thinking process, making routing decisions interpretable and meaningful. """ def __init__(self, dim: int, n_experts: int = 2, max_iter: int = 3): super().__init__() self.n_experts = n_experts self.max_iter = max_iter # Iteration embedding — gives each iteration a learned "personality" # iter 1 = "first pass", iter 2 = "refinement", iter 3 = "verification" self.iter_embed = nn.Embedding(max_iter, 16) # Project the three signals into routing logits # Input: delta_scalar + drift_scalar + iter_embed(16) = 18 dims self.router = nn.Sequential( nn.Linear(18, 64), SwiGLU(), nn.Linear(32, n_experts, bias=False) ) # Init router to be nearly uniform at start # → experts start with equal load, specialization emerges nn.init.normal_(self.router[0].weight, std=0.01) nn.init.normal_(self.router[2].weight, std=0.01) self.last_weights = None # store for balancing loss def forward(self, y: torch.Tensor, # current hidden state y_prev: torch.Tensor, # hidden state from last iter linguistic_anchor: torch.Tensor, # what the input said iter_idx: int # which iteration (0-indexed) ) -> torch.Tensor: """ Returns routing weights [batch, n_experts]. """ # Signal 1: Delta — how much thinking changed this step # High delta = uncertain, still changing a lot # Low delta = converging, changes are subtle delta = (y - y_prev).norm(dim=-1).mean(dim=-1, keepdim=True) # delta shape: [batch, 1] # Signal 2: Drift — how far current thinking is from the input # High drift = model is thinking abstractly, far from literal input # Low drift = model is still closely following the input drift = (y - linguistic_anchor).norm(dim=-1).mean(dim=-1, keepdim=True) # drift shape: [batch, 1] # Normalize both signals so they're comparable # Detach to avoid routing gradients affecting main computation delta = delta.detach() / (delta.detach().mean() + 1e-8) drift = drift.detach() / (drift.detach().mean() + 1e-8) # Signal 3: Iteration stage embedding iter_clamped = min(iter_idx if isinstance(iter_idx, int) else iter_idx.item(), self.max_iter - 1) # clamp to valid range iter_tensor = torch.as_tensor(iter_clamped, device=y.device, dtype=torch.long) iter_emb = self.iter_embed(iter_tensor) iter_emb = iter_emb.unsqueeze(0).expand(y.size(0), -1) # [batch, 16] # Combine signals routing_input = torch.cat([delta, drift, iter_emb], dim=-1) # [batch, 18] logits = self.router(routing_input) # [batch, n_experts] prob = torch.softmax(logits, dim=-1) # is this the right dimension/axis to view? I need to make sure that this is the axis of probabilities. return prob class MoLLayer(nn.Module): def __init__(self, dim: int, ffndim: int, n_experts: int = 2, max_iter: int = 3, bias=True): super().__init__() self.n_experts = n_experts self.router = ThinkingRouter(dim, n_experts, max_iter) self.experts = nn.ModuleList([ nn.Sequential( nn.Linear(dim, ffndim * 2, bias=bias), SwiGLU(), nn.Linear(ffndim, dim, bias=bias) ) for _ in range(n_experts) ]) def forward(self, x:torch.Tensor, x_prev:torch.Tensor, linguistic_anchor: torch.Tensor, iter_idx:int):# -> torch.Tensor #self.router.last_weights = None weights = self.router(x, x_prev, linguistic_anchor, iter_idx) out = torch.zeros_like(x) for i in range(self.n_experts): out += weights[:, i].unsqueeze(1).unsqueeze(2) * self.experts[i](x) return out class custom_mem(nn.Module): def __init__(self, dim, num_heads, head_dim=None, dtype=torch.bfloat16): super().__init__() self.dim = dim self.heads = num_heads # If head_dim isn't set, we'll default to a light bottleneck self.head_dim = head_dim if head_dim else (dim // num_heads) // 2 self.total_mid_dim = self.heads * self.head_dim self.norm_curr = RMSNorm(dim) self.norm_prev = RMSNorm(dim) self.norm_anch = RMSNorm(dim) self.input_proj = nn.Linear(dim * 3, self.total_mid_dim, bias=False, dtype=dtype) self.w12 = nn.Linear(self.total_mid_dim, self.total_mid_dim * 2, bias=False, dtype=dtype) #self.scale1 = nn.Linear(self.total_mid_dim // 2, self.total_mid_dim, bias=False, dtype=dtype) self.w3 = nn.Linear(self.total_mid_dim, dim, bias=False, dtype=dtype) self.swiglu = SwiGLU() def forward(self, current, previous, anchor): # Apply RMSNorm to inputs to stabilize the "Trinity" c = self.norm_curr(current) p = self.norm_prev(previous) a = self.norm_anch(anchor) # Concatenate and project to the multi-head space # x shape: [Batch, Seq, (num_heads * head_dim)] combined = torch.cat([c, p, a], dim=-1) x = self.input_proj(combined) # SwiGLU Logic: Split the mid_dim in half for gate vs value gate_val = self.w12(x) gate, val = gate_val.chunk(2, dim=-1) # Apply swiglu to the gate and multiply by the value # This happens in parallel across all heads x = F.silu(gate) * val # Final Mixing: w3 sees all heads at once and collapses them back to 'dim' return self.w3(x) class engram(nn.Module): def __init__( self, dim: int, num_heads: int = 4, ngram: int = 3, memory_size: int = 16384, bottleneck: int = 128, dtype=torch.bfloat16, ): super().__init__() assert dim % num_heads == 0, "dim must be divisible by num_heads" self.dim = dim self.num_heads = num_heads self.ngram = ngram self.memory_size = memory_size self.key_gate = nn.Linear(dim, num_heads, bias=False, dtype=dtype) self.memory = nn.Embedding(memory_size, dim, dtype=dtype) self.post = nn.Sequential( nn.Linear(dim, bottleneck * 2, bias=False, dtype=dtype), SwiGLU(), nn.Linear(bottleneck, dim, bias=False, dtype=dtype), ) self.norm = RMSNorm(dim) self.percent = nn.Linear(dim, 1, bias=False, dtype=dtype) self.register_buffer( "coeffs", torch.tensor([1, 1315423911, 2654435761, 2246822519], dtype=torch.long), persistent=False, ) self.register_buffer( "salts", torch.tensor([0, 97, 193, 389], dtype=torch.long), persistent=False, ) def _hash_ngrams(self, input_ids: torch.Tensor) -> torch.Tensor: B, S = input_ids.shape if self.ngram > 1: pad = torch.zeros((B, self.ngram - 1), device=input_ids.device, dtype=input_ids.dtype) ids = torch.cat([pad, input_ids], dim=1) windows = ids.unfold(1, self.ngram, 1).long() # [B, S, ngram] else: windows = input_ids.long().unsqueeze(-1) # [B, S, 1] coeffs = self.coeffs[:windows.size(-1)].view(1, 1, -1) base = (windows * coeffs).sum(dim=-1) # [B, S] hashes = [] for salt in self.salts[:self.num_heads]: hashes.append(torch.remainder(base + salt, self.memory_size)) return torch.stack(hashes, dim=-1) # [B, S, H] def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: idx = self._hash_ngrams(input_ids) # [B, S, H] mem = self.memory(idx) # [B, S, H, D] x = x.to(self.key_gate.weight.dtype) gates = torch.softmax(self.key_gate(x), dim=-1).unsqueeze(-1) # [B, S, H, 1] out = (mem * gates).sum(dim=2) # [B, S, D] out = self.post(self.norm(out)) return torch.sigmoid(self.percent(out)) * out