"""
model_v2.py -- SpikeWhaleLM v2: optimized base architecture.

Changes vs model.py (v1):

PERFORMANCE
  - SparseMoEFFN: sort-based expert dispatch (one contiguous slice per expert,
    index_add_ scatter-back) replaces per-expert boolean masking. Far fewer
    kernel launches, torch.compile-friendly (no data-dependent boolean
    indexing in the hot path).
  - Shared experts fused into ONE ExpertFFN with n_shared * intermediate width
    (mathematically equivalent to the averaged sum, 1 matmul set instead of N).

QUALITY / STABILITY
  - QK-Norm: per-head RMSNorm on Q and K before RoPE (Gemma2/OLMo2-style).
    Stabilizes attention logits, tolerates higher LR.
    (cfg.use_qk_norm, default ON)
  - z-loss on lm_head logits: zloss_coef * mean(log^2 Z) over non-ignored
    positions. Prevents logit drift.
    (cfg.zloss_coef, default 1e-4; set 0 to disable)
  - MTP heads REDESIGNED: instead of K independent full H x V matrices (which
    at 50M params dwarfed the model), each MTP head is now a small zero-init
    residual H x H projection (x + proj(x)) feeding the SHARED lm_head.
    Param cost per head: H^2 instead of H*V. MTP loss is down-weighted by
    cfg.mtp_loss_weight (default 0.3).
  - HC output: learned softmax mix over streams (HCOutputMix) instead of
    mean().
  - Value-embedding residual (nanoGPT-speedrun style): a per-layer learned
    scalar gate (zero-init => exact no-op at init) adds tanh(gate) * token
    embedding into each block's input.
    (cfg.use_value_embed, default OFF = opt-in)

All new config keys are read with getattr(cfg, key, default) so your existing
config.py works unmodified.

NOTE: QK-Norm and HCOutputMix add parameters, so v1 checkpoints need
load_state_dict(strict=False) (new params keep their init; QK-Norm at init is
NOT identity -- prefer training v2 from scratch, or set use_qk_norm=False to
stay v1-loadable).

XSA is kept byte-identical to v1, but read the note in MLADerfXSAAttention:
with num_kv_heads == 1 it removes the SAME rank-1 value subspace from every
head. A/B it at 50M before keeping it in the final base.
"""
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as gradient_checkpoint
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    # Dotted import so HuggingFace's trust_remote_code loader fetches config.py
    # as a relative dependency; falls back to flat import for local script use.
    from .config import SpikeWhaleConfig
except ImportError:
    from config import SpikeWhaleConfig


# ---------------------------------------------------------------------------
# Primitives
# ---------------------------------------------------------------------------


class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dimension with a learned scale.

    The mean-square reduction is computed in at least float32, then cast back
    to the input dtype before the learned ``weight`` is applied. Reducing
    directly in bf16/fp16 loses precision in exactly the statistic the norm
    depends on.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # promote_types keeps float64 inputs in float64; only upcasts low precision.
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_c = x.to(compute_dtype)
        normed = x_c * torch.rsqrt(x_c.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.to(input_dtype) * self.weight


class RotaryEmbedding(nn.Module):
    """RoPE for the rope partition of Q and K (``qk_rope_head_dim`` dims only).

    Uses the half-split rotation: the first ``dim // 2`` features are paired
    with the last ``dim // 2``. Angles are precomputed for ``max_positions``
    positions. Indexing with a position id outside that range raises an
    IndexError.
    """

    def __init__(self, dim: int, max_positions: int = 4096, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        t = torch.arange(max_positions).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("cos_cache", freqs.cos())
        self.register_buffer("sin_cache", freqs.sin())

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` (``[..., heads, S, dim]``) by the angles at ``position_ids``.

        Returns a tensor with the same dtype as ``x``. Without the final cast,
        fp32 caches would silently upcast a bf16 ``x``, and with it the
        concatenated q/k and the KV cache.
        """
        cos = self.cos_cache[position_ids].unsqueeze(1)  # [B, 1, S, rope_dim//2]
        sin = self.sin_cache[position_ids].unsqueeze(1)
        d = cos.shape[-1]
        x1, x2 = x[..., :d], x[..., d:]
        rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return rotated.to(x.dtype)


# ---------------------------------------------------------------------------
# Engram: N-gram hash lookup + DERF gate (unchanged from v1)
# ---------------------------------------------------------------------------


class TokenCompressor(nn.Module):
    """Frozen random linear projection ``embed_dim -> compress_dim`` used as hash input."""

    def __init__(self, embed_dim: int, compress_dim: int):
        super().__init__()
        self.proj = nn.Linear(embed_dim, compress_dim, bias=False)
        nn.init.normal_(self.proj.weight, std=0.02)
        # Frozen LSH-style projection: gradient never reaches it through the
        # .long() hash cast, so a fixed random projection is correct (see v1).
        self.proj.weight.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class MultiHeadHashLookup(nn.Module):
    """Hashed n-gram embedding lookup (n = 1..max_ngram) over ``num_heads`` tables.

    For each n-gram order ``n`` and each head, a scalar hash is formed by
    summing fixed random unit-norm projections of the ``n`` consecutive
    compressed tokens ending at each position. Its truncated absolute value,
    modulo ``table_size``, indexes that head's table. Retrieved vectors are
    averaged over all (order, head) pairs that exist at each position
    (positions ``< n - 1`` have no n-gram of order ``n``).
    """

    def __init__(self, num_heads: int, table_size: int, compress_dim: int,
                 out_dim: int, max_ngram: int = 3):
        super().__init__()
        self.num_heads = num_heads
        self.table_size = table_size
        self.max_ngram = max_ngram
        self.out_dim = out_dim
        self.tables = nn.ModuleList([
            nn.Embedding(table_size, out_dim) for _ in range(num_heads)
        ])
        for t in self.tables:
            nn.init.normal_(t.weight, std=0.01)
        for n in range(1, max_ngram + 1):
            for k in range(n):
                proj = torch.randn(num_heads, compress_dim)
                proj = proj / (proj.norm(dim=1, keepdim=True) + 1e-8)
                self.register_buffer(f"hash_proj_n{n}_p{k}", proj)

    def forward(self, compressed: torch.Tensor) -> torch.Tensor:
        """Look up n-gram embeddings.

        Args:
            compressed: ``[B, S, compress_dim]`` output of ``TokenCompressor``.

        Returns:
            ``[B, S, out_dim]`` averaged retrievals, in ``compressed.dtype``.
        """
        B, S, _ = compressed.shape
        device = compressed.device
        out = torch.zeros(B, S, self.out_dim, device=device, dtype=compressed.dtype)
        norm = torch.zeros(S, device=device)
        compressed_f = compressed.float()  # hoisted: loop-invariant upcast
        for n in range(1, self.max_ngram + 1):
            if S < n:
                continue
            valid_len = S - n + 1
            start = n - 1
            h = torch.zeros(B, valid_len, self.num_heads, device=device)
            for k in range(n):
                # .float(): the buffer becomes bf16 after model.to(bfloat16),
                # which would make this fp32 matmul fail on a dtype mismatch.
                proj = getattr(self, f"hash_proj_n{n}_p{k}").float()
                h = h + torch.matmul(compressed_f[:, k:k + valid_len, :], proj.t())
            # NOTE(review): .long() truncates toward zero, so every |h| < 1 maps
            # to bucket 0. With small-magnitude compressed inputs (e.g. near init)
            # most n-grams may collide. Consider scaling h before quantizing.
            idx = h.abs().long() % self.table_size
            for head_idx, table in enumerate(self.tables):
                out[:, start:, :] = out[:, start:, :] + table(idx[:, :, head_idx])
            norm[start:] += self.num_heads
        return (out / norm.view(1, -1, 1).clamp(min=1)).to(compressed.dtype)


class DERFContextGate(nn.Module):
    """Gate ``retrieved`` by ``gamma * (erf(alpha * proj([retrieved, x]) + bias) + 1) / 2``.

    With the default ``init_bias = -4`` the erf term starts near -1, so the
    gate (and the Engram contribution) starts near zero.
    """

    def __init__(self, obs_size: int, init_bias: float = -4.0):
        super().__init__()
        self.proj = nn.Linear(obs_size * 2, obs_size)
        self.alpha = nn.Parameter(torch.ones(obs_size))
        self.bias = nn.Parameter(torch.full((obs_size,), init_bias))
        self.gamma = nn.Parameter(torch.ones(obs_size))

    def forward(self, retrieved: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        logits = self.proj(torch.cat([retrieved, x], dim=-1))
        gate = self.gamma * ((torch.erf(self.alpha * logits + self.bias) + 1.0) / 2.0)
        return retrieved * gate


class EngramModule(nn.Module):
    """Compress (detached) embeddings, hash-look-up n-gram vectors, and gate them."""

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.compressor = TokenCompressor(cfg.hidden_size, cfg.engram_compress_dim)
        self.lookup = MultiHeadHashLookup(
            cfg.engram_num_heads, cfg.engram_table_size,
            cfg.engram_compress_dim, cfg.hidden_size, cfg.engram_max_ngram,
        )
        self.gate = DERFContextGate(cfg.hidden_size, cfg.engram_gate_init_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        compressed = self.compressor(x.detach())
        retrieved = self.lookup(compressed)
        return self.gate(retrieved, x)


# ---------------------------------------------------------------------------
# Hyper-Connections
# ---------------------------------------------------------------------------


class HyperConnectionLayer(nn.Module):
    """Simplified HC: softmax pre-mix / post-distribute over hc_mult streams.

    Asymmetric init (v1 bugfix) so streams diverge and gradients flow.
    ``sinkhorn_iters`` and ``eps`` are accepted for config compatibility but are
    not used by this simplified version.
    """

    def __init__(self, hidden_size: int, hc_mult: int,
                 sinkhorn_iters: int = 20, eps: float = 1e-6):
        super().__init__()
        self.hc_mult = hc_mult
        self.pre_weight = nn.Parameter(
            torch.linspace(0.5, -0.5, hc_mult) / max(hc_mult, 1)
        )
        self.post_weight = nn.Parameter(
            torch.linspace(-0.5, 0.5, hc_mult) / max(hc_mult, 1)
        )

    def pre_op(self, copies: torch.Tensor) -> torch.Tensor:
        """Mix ``[B, hc_mult, S, H]`` streams into one ``[B, S, H]`` input."""
        w = F.softmax(self.pre_weight, dim=0)
        return (copies * w.view(1, -1, 1, 1)).sum(dim=1)

    def post_op(self, copies: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Add ``delta`` (``[B, S, H]``) to every stream with softmax weights."""
        w = F.softmax(self.post_weight, dim=0)
        return copies + delta.unsqueeze(1) * w.view(1, -1, 1, 1)


class HCOutputMix(nn.Module):
    """
    NEW (v2): learned combination of the hc_mult streams at the model output,
    replacing the v1 mean(dim=1). Mean forces the streams toward redundancy at
    exactly the point where you want them specialized. Initialized uniform so
    it starts identical to mean() -- a strict generalization, zero risk.
    """

    def __init__(self, hc_mult: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hc_mult))  # softmax(0)=uniform=mean

    def forward(self, copies: torch.Tensor) -> torch.Tensor:
        w = F.softmax(self.weight, dim=0)
        return (copies * w.view(1, -1, 1, 1)).sum(dim=1)


# ---------------------------------------------------------------------------
# MLA + (DERF) + XSA Attention, now with QK-Norm
# ---------------------------------------------------------------------------


class MLADerfXSAAttention(nn.Module):
    """
    v2 additions:
      - QK-Norm (cfg.use_qk_norm, default True): per-head RMSNorm applied to Q
        and K BEFORE the rope/nope split. Bounds attention logits, the standard
        modern stability fix; composes cleanly with SDPA and partial RoPE.
      - Masking: when no explicit mask is given, both the DERF and the SDPA
        paths apply a causal mask offset by the cache length, including during
        multi-token decode on top of a KV cache.

    XSA NOTE (unchanged mechanics, important caveat): with num_kv_heads == 1
    (MQA) every query head shares the same value vector, so the self-projection
    subtraction removes the SAME rank-1 value subspace from all heads -- much
    more aggressive than per-head XSA. Ablate use_xsa on/off at 50M before
    locking the base config.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.num_heads = cfg.num_attention_heads
        self.num_kv_heads = cfg.num_key_value_heads
        self.head_dim = cfg.head_dim
        self.qk_rope_head_dim = cfg.qk_rope_head_dim
        self.nope_head_dim = cfg.nope_head_dim
        self.hidden_size = cfg.hidden_size
        self.use_derf = cfg.use_derf
        self.use_xsa = cfg.use_xsa
        self.dropout_p = cfg.attention_dropout
        self.kv_groups = self.num_heads // self.num_kv_heads
        self.use_qk_norm = getattr(cfg, "use_qk_norm", True)

        self.q_a_proj = nn.Linear(cfg.hidden_size, cfg.q_lora_rank, bias=False)
        self.q_a_norm = RMSNorm(cfg.q_lora_rank, cfg.rms_norm_eps)
        self.q_b_proj = nn.Linear(cfg.q_lora_rank, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_a_proj = nn.Linear(self.num_heads * self.head_dim, cfg.o_lora_rank, bias=False)
        self.o_b_proj = nn.Linear(cfg.o_lora_rank, cfg.hidden_size, bias=False)

        # QK-Norm: one RMSNorm over head_dim, shared across heads (Gemma-2 style).
        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim, cfg.rms_norm_eps)
            self.k_norm = RMSNorm(self.head_dim, cfg.rms_norm_eps)

        self.rope = RotaryEmbedding(
            self.qk_rope_head_dim,
            max_positions=cfg.max_position_embeddings,
            theta=cfg.rope_theta,
        )
        if self.use_derf:
            self.derf_alpha = nn.Parameter(torch.ones(self.num_heads))
            self.derf_bias = nn.Parameter(torch.zeros(self.num_heads))
            self.derf_gamma = nn.Parameter(torch.ones(self.num_heads))
        for m in (self.q_a_proj, self.q_b_proj, self.k_proj,
                  self.v_proj, self.o_a_proj, self.o_b_proj):
            nn.init.normal_(m.weight, std=cfg.initializer_range)

    @staticmethod
    def _causal_mask(q_len: int, kv_len: int, device: torch.device) -> torch.Tensor:
        """Boolean ``[1, 1, q_len, kv_len]`` mask, True where attention is FORBIDDEN.

        The last ``q_len`` keys are the current queries, so query ``i`` may see
        keys ``0 .. kv_len - q_len + i``.
        """
        return torch.triu(
            torch.ones(q_len, kv_len, dtype=torch.bool, device=device),
            diagonal=kv_len - q_len + 1,
        ).unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over the cached + current keys.

        Args:
            x: ``[B, S, hidden]`` block input.
            position_ids: positions used for RoPE on the current ``S`` tokens.
            attention_mask: optional additive mask. Entries ``< -1.0`` are
                treated as masked, and it must broadcast against
                ``[B, heads, S, N]``. When ``None``, a causal mask is used.
            past_key_value: optional ``(k, v)`` cache, concatenated on dim 2.
            use_cache: if True, return the updated ``(k, v)``.

        Returns:
            ``(y, present)`` with ``y`` of shape ``[B, S, hidden]``.
        """
        B, S, _ = x.shape
        q = self.q_a_norm(self.q_a_proj(x))
        q = self.q_b_proj(q).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # QK-Norm before RoPE (v2). Cache stores the NORMALIZED k so prefill and
        # incremental decode agree.
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        q_nope = q[..., :self.nope_head_dim]
        q_rope = q[..., self.nope_head_dim:]
        k_nope = k[..., :self.nope_head_dim]
        k_rope = k[..., self.nope_head_dim:]
        q_rope = self.rope(q_rope, position_ids)
        k_rope = self.rope(k_rope, position_ids)
        q = torch.cat([q_nope, q_rope], dim=-1)
        k = torch.cat([k_nope, k_rope], dim=-1)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None

        N = k.shape[2]
        if self.kv_groups > 1:
            k = k.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape(
                B, self.num_heads, N, self.head_dim)
            v = v.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape(
                B, self.num_heads, N, self.head_dim)

        if self.use_derf:
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                is_masked = attention_mask < -1.0
            else:
                # Causal mask offset by the cache length. v1 used an all-False
                # mask whenever a cache was present, letting a multi-token chunk
                # attend to its own future tokens (for S == 1 both are no-ops).
                is_masked = self._causal_mask(S, N, scores.device)
            safe_scores = scores.masked_fill(is_masked, -10000.0)
            a = self.derf_alpha.view(1, -1, 1, 1)
            b = self.derf_bias.view(1, -1, 1, 1)
            g = self.derf_gamma.view(1, -1, 1, 1)
            attn_weights = g * torch.erf(a * safe_scores + b)
            attn_weights = (attn_weights + g) / 2.0
            attn_weights = attn_weights.masked_fill(is_masked, 0.0)
            attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-8)
            if self.dropout_p > 0 and self.training:
                attn_weights = F.dropout(attn_weights, p=self.dropout_p)
            y = torch.matmul(attn_weights, v)
        else:
            q = q.contiguous()
            k = k.contiguous()
            v = v.contiguous()
            drop = self.dropout_p if self.training else 0.0
            if past_key_value is None and attention_mask is None:
                y = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop)
            else:
                if attention_mask is not None:
                    is_masked = attention_mask < -1.0
                else:
                    is_masked = self._causal_mask(S, N, q.device)
                # SDPA boolean masks mean True = "may attend", hence the inversion.
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=~is_masked, dropout_p=drop)

        if self.use_xsa:
            past_len = N - S
            v_self = v[:, :, past_len:past_len + S, :]
            vn = v_self / (v_self.norm(dim=-1, keepdim=True) + 1e-8)
            projection = (y * vn).sum(dim=-1, keepdim=True) * vn
            y = y - projection

        y = y.transpose(1, 2).contiguous().view(B, S, self.num_heads * self.head_dim)
        y = self.o_b_proj(self.o_a_proj(y))
        return y, present


# ---------------------------------------------------------------------------
# MoE FFN -- v2: sort-based dispatch + fused shared expert
# ---------------------------------------------------------------------------


class ExpertFFN(nn.Module):
    """Single SwiGLU expert."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


def sqrtsoftplus(x: torch.Tensor) -> torch.Tensor:
    """Return ``sqrt(softplus(x) + 1e-8)`` (positive, smooth router score)."""
    return torch.sqrt(F.softplus(x) + 1e-8)


class SparseMoEFFN(nn.Module):
    """
    v2 changes:
      - FUSED shared expert: one ExpertFFN with width n_shared * intermediate,
        scaled by 1/n_shared on output -- equivalent to v1's averaged Python
        loop, one fused matmul set. (state-dict key changes: shared_expert.*)
      - SORT-BASED dispatch for routed experts: flatten (token, slot) pairs,
        argsort by expert id, run each expert on ONE contiguous slice, weighted
        index_add_ back. No boolean masks, no nonzero(), no per-expert scatter.
    Routing logic (hash routing, sqrtsoftplus, aux loss) is unchanged.
    """

    def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int = 0):
        super().__init__()
        self.n_routed_experts = cfg.n_routed_experts
        self.n_shared_experts = cfg.n_shared_experts
        self.num_experts_per_tok = cfg.num_experts_per_tok
        self.norm_topk_prob = cfg.norm_topk_prob
        self.scoring_func = cfg.scoring_func
        self.routed_scaling_factor = cfg.routed_scaling_factor
        self.use_hash_routing = layer_idx < cfg.num_hash_layers
        self.aux_loss_coef = cfg.moe_aux_loss_coef
        self.router = nn.Linear(cfg.hidden_size, cfg.n_routed_experts, bias=False)
        self.experts = nn.ModuleList([
            ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size)
            for _ in range(cfg.n_routed_experts)
        ])
        # Fused shared expert (v2)
        self.shared_expert = (
            ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size * cfg.n_shared_experts)
            if cfg.n_shared_experts > 0 else None
        )
        self._last_aux_loss: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Route each of the ``B * S`` tokens to ``num_experts_per_tok`` experts.

        Args:
            x: ``[B, S, H]`` normalized FFN input.
            position_ids: used only by hash routing. Anything that broadcasts to
                ``[B, S]`` (e.g. ``[1, S]``) is accepted.

        Returns:
            ``[B, S, H]`` routed (+ shared) expert output. The load-balancing aux
            loss (learned routing only) is stored for ``get_aux_loss()``.
        """
        B, S, H = x.shape
        T = B * S
        K = self.num_experts_per_tok
        x_flat = x.reshape(T, H)

        # Shared expert: always active, single fused pass.
        if self.shared_expert is not None:
            shared_out = self.shared_expert(x_flat)
            if self.n_shared_experts > 1:
                shared_out = shared_out / self.n_shared_experts
        else:
            shared_out = None

        # ---- Routing (unchanged logic) ----
        if self.use_hash_routing:
            if position_ids is not None:
                # expand() lets broadcast position ids ([1, S] or [S]) serve a
                # batch of B; reshape(T, 1) alone would fail for B > 1.
                flat_pos = position_ids.expand(B, S).reshape(T, 1)
                base = (flat_pos % self.n_routed_experts).long()
            else:
                base = (torch.arange(T, device=x.device) % self.n_routed_experts).unsqueeze(1)
            offsets = torch.arange(K, device=x.device)
            top_k_indices = (base + offsets.unsqueeze(0)) % self.n_routed_experts  # [T, K]
            top_k_weights = torch.full((T, K), 1.0 / K, device=x.device, dtype=x_flat.dtype)
            self._last_aux_loss = None
        else:
            router_logits = self.router(x_flat)
            if self.scoring_func == "sqrtsoftplus":
                routing_scores = sqrtsoftplus(router_logits)
            else:
                routing_scores = F.softmax(router_logits, dim=-1)
            top_k_scores, top_k_indices = torch.topk(routing_scores, K, dim=-1)
            if self.norm_topk_prob:
                top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-8)
            else:
                top_k_weights = top_k_scores
            top_k_weights = top_k_weights * self.routed_scaling_factor

            softmax_probs = F.softmax(router_logits, dim=-1)
            expert_mask = torch.zeros_like(softmax_probs)
            expert_mask.scatter_(1, top_k_indices, 1.0)
            f_e = expert_mask.mean(0)
            p_e = softmax_probs.mean(0)
            self._last_aux_loss = self.n_routed_experts * (f_e * p_e).sum() * self.aux_loss_coef

        # ---- Sort-based dispatch (v2) ----
        # Flatten the (token, slot) assignment: T*K rows total.
        flat_expert = top_k_indices.reshape(-1)                            # [T*K]
        flat_weight = top_k_weights.reshape(-1, 1)                         # [T*K, 1]
        flat_token = torch.arange(T, device=x.device).repeat_interleave(K)  # [T*K]

        order = torch.argsort(flat_expert, stable=True)  # group by expert
        sorted_expert = flat_expert[order]
        sorted_token = flat_token[order]
        sorted_weight = flat_weight[order]

        counts = torch.bincount(sorted_expert, minlength=self.n_routed_experts)
        # Boundaries per expert in the sorted order (CPU sync once per forward;
        # unavoidable without grouped-GEMM, still vastly cheaper than v1's
        # per-expert nonzero/masking).
        counts_list = counts.tolist()

        gathered = x_flat[sorted_token]  # [T*K, H]
        out_flat = torch.zeros_like(x_flat)
        start = 0
        for expert_idx, cnt in enumerate(counts_list):
            if cnt == 0:
                continue
            end = start + cnt
            seg_out = self.experts[expert_idx](gathered[start:end]) * sorted_weight[start:end]
            out_flat.index_add_(0, sorted_token[start:end], seg_out.to(out_flat.dtype))
            start = end

        if shared_out is not None:
            out_flat = out_flat + shared_out
        return out_flat.view(B, S, H)

    def get_aux_loss(self) -> Optional[torch.Tensor]:
        return self._last_aux_loss


class DenseFFN(nn.Module):
    """Dense SwiGLU FFN for non-MoE layers (uses ``cfg.moe_intermediate_size``)."""

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        # NOTE(review): dense layers reuse moe_intermediate_size -- confirm this
        # is intended rather than a separate dense intermediate size.
        self.gate_proj = nn.Linear(cfg.hidden_size, cfg.moe_intermediate_size, bias=False)
        self.up_proj = nn.Linear(cfg.hidden_size, cfg.moe_intermediate_size, bias=False)
        self.down_proj = nn.Linear(cfg.moe_intermediate_size, cfg.hidden_size, bias=False)

    def forward(self, x: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

    def get_aux_loss(self) -> Optional[torch.Tensor]:
        return None


# ---------------------------------------------------------------------------
# Transformer block
# ---------------------------------------------------------------------------


class TransformerBlock(nn.Module):
    """Pre-norm attention + FFN (dense or MoE) block, optionally with Hyper-Connections."""

    def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int):
        super().__init__()
        self.use_hc = cfg.use_hyper_connections
        self.hidden_dropout = cfg.hidden_dropout
        self.use_value_embed = getattr(cfg, "use_value_embed", False)

        self.attn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.attn = MLADerfXSAAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)

        if cfg.use_moe and layer_idx in cfg.moe_layers:
            self.ffn = SparseMoEFFN(cfg, layer_idx)
            self.is_moe = True
        else:
            self.ffn = DenseFFN(cfg)
            self.is_moe = False

        if self.use_hc:
            self.hc_attn = HyperConnectionLayer(cfg.hidden_size, cfg.hc_mult,
                                                cfg.hc_sinkhorn_iters, cfg.hc_eps)
            self.hc_ffn = HyperConnectionLayer(cfg.hidden_size, cfg.hc_mult,
                                               cfg.hc_sinkhorn_iters, cfg.hc_eps)

        # NEW (v2, opt-in): value-embedding residual. Zero-init gate -> exact
        # no-op at init; learns to mix raw token-embedding signal into each
        # block's input (nanoGPT-speedrun "value embedding"/U-net skip family).
        if self.use_value_embed:
            self.ve_gate = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        x: torch.Tensor,                               # [B, hc_mult, S, H] if HC else [B, S, H]
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple] = None,
        use_cache: bool = False,
        token_embed: Optional[torch.Tensor] = None,    # [B, S, H] (value-embed)
    ) -> Tuple[torch.Tensor, Optional[Tuple], Optional[torch.Tensor]]:
        """Return ``(new_x, present_kv, ffn_aux_loss)``."""
        # --- Attention sub-layer ---
        if self.use_hc:
            h = self.hc_attn.pre_op(x)
        else:
            h = x

        # NOTE(review): without HC, the gated embedding is added to ``h``, which
        # is also the residual stream carried forward (so it accumulates across
        # layers). With HC it only feeds this block's input and ``x`` is left
        # untouched. Confirm this asymmetry is intended.
        if self.use_value_embed and token_embed is not None:
            h = h + torch.tanh(self.ve_gate) * token_embed

        attn_out, present = self.attn(
            self.attn_norm(h), position_ids, attention_mask, past_key_value, use_cache
        )
        attn_out = F.dropout(attn_out, p=self.hidden_dropout, training=self.training)

        if self.use_hc:
            x = self.hc_attn.post_op(x, attn_out)
            h = self.hc_ffn.pre_op(x)
        else:
            h = h + attn_out

        # --- FFN sub-layer ---
        ffn_out = self.ffn(self.ffn_norm(h), position_ids)
        ffn_out = F.dropout(ffn_out, p=self.hidden_dropout, training=self.training)

        if self.use_hc:
            x = self.hc_ffn.post_op(x, ffn_out)
        else:
            x = h + ffn_out

        return x, present, self.ffn.get_aux_loss()


# ---------------------------------------------------------------------------
# HRM refinement (unchanged)
# ---------------------------------------------------------------------------


class HRMRefinementBlock(nn.Module):
    """``steps`` gated residual refinement updates conditioned on the input anchor.

    ``up`` and the per-step gates are zero-initialized, so the block is the
    identity at init.
    """

    def __init__(self, hidden_size: int, refine_dim: int, steps: int, eps: float = 1e-6):
        super().__init__()
        self.steps = steps
        self.norm = RMSNorm(hidden_size, eps)
        self.down = nn.Linear(hidden_size * 2, refine_dim, bias=False)
        self.up = nn.Linear(refine_dim, hidden_size, bias=False)
        self.gate = nn.Parameter(torch.zeros(steps))
        nn.init.normal_(self.down.weight, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        anchor = x
        h = x
        for t in range(self.steps):
            inp = torch.cat([self.norm(h), anchor], dim=-1)
            update = self.up(F.silu(self.down(inp)))
            h = h + torch.tanh(self.gate[t]) * update
        return h


# ---------------------------------------------------------------------------
# Full model
# ---------------------------------------------------------------------------


class SpikeWhaleModel(nn.Module):
    """Decoder stack without LM head."""

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        nn.init.normal_(self.embed_tokens.weight, std=cfg.initializer_range)
        self.engram = EngramModule(cfg) if cfg.use_engram else None
        self.layers = nn.ModuleList([
            TransformerBlock(cfg, layer_idx=i) for i in range(cfg.num_hidden_layers)
        ])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.hc_out_mix = (
            HCOutputMix(cfg.hc_mult) if cfg.use_hyper_connections else None
        )
        self.hrm_refine = (
            HRMRefinementBlock(cfg.hidden_size, cfg.hrm_refine_dim,
                               cfg.hrm_refine_steps, cfg.rms_norm_eps)
            if getattr(cfg, "use_hrm_refine", False) else None
        )
        self.use_value_embed = getattr(cfg, "use_value_embed", False)
        self.gradient_checkpointing = False

    @staticmethod
    def _prepare_attention_mask(
        attention_mask: Optional[torch.Tensor],
        seq_len: int,
        past_len: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Normalize ``attention_mask`` to the additive form the attention layers expect.

        The layers treat entries ``< -1.0`` as masked and broadcast the mask
        against ``[B, heads, S, N]`` scores. A HuggingFace-style 2-D padding mask
        (1 = keep, 0 = pad) never satisfies ``< -1.0``, so passing it through raw
        silently DISABLED causal masking. It is converted here:

        - ``None`` -> ``None`` (layers build the causal mask themselves).
        - non-2-D tensors -> returned unchanged (assumed already additive).
        - 2-D mask of length ``past_len + seq_len`` (or ``seq_len``, in which
          case cached positions count as valid) -> additive ``[B, 1, S, N]``
          causal + padding mask, or ``None`` if nothing is padded. ``None`` keeps
          the fused ``is_causal`` SDPA path; ``keep.all()`` costs one host sync.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        total_len = past_len + seq_len
        keep = attention_mask.to(device=device, dtype=torch.bool)
        if keep.shape[-1] == seq_len and past_len > 0:
            keep = torch.cat([keep.new_ones(keep.shape[0], past_len), keep], dim=-1)
        if keep.shape[-1] != total_len:
            raise ValueError(
                f"2-D attention_mask has length {attention_mask.shape[-1]}; expected "
                f"{total_len} (past_len + seq_len) or {seq_len} (seq_len)"
            )
        if bool(keep.all()):
            return None
        masked = torch.triu(
            torch.ones(seq_len, total_len, dtype=torch.bool, device=device),
            diagonal=past_len + 1,
        ).view(1, 1, seq_len, total_len)
        masked = masked | ~keep.view(-1, 1, 1, total_len)
        # Always let each query see its own key so no row is fully masked (a
        # fully masked row makes SDPA emit NaN). Only padded queries are
        # affected, and valid queries never attend to padded keys.
        q_idx = torch.arange(seq_len, device=device)
        masked[:, :, q_idx, past_len + q_idx] = False
        return torch.zeros(masked.shape, dtype=torch.float32, device=device).masked_fill(
            masked, -10000.0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple]], torch.Tensor]:
        """Return ``(hidden [B, S, H], present_key_values or None, total_aux_loss)``.

        Raises:
            ValueError: if ``use_cache`` is requested while training with
                gradient checkpointing, or a 2-D mask has the wrong length.
        """
        B, S = input_ids.shape
        device = input_ids.device

        # Gradient checkpointing is incompatible with use_cache (the cache from
        # the discarded forward would be silently wrong on recompute).
        if self.gradient_checkpointing and self.training and use_cache:
            raise ValueError("use_cache=True is not supported with gradient checkpointing")

        past_len = past_key_values[0][0].shape[2] if past_key_values else 0
        if position_ids is None:
            position_ids = torch.arange(
                past_len, past_len + S, device=device
            ).unsqueeze(0).expand(B, -1)
        attention_mask = self._prepare_attention_mask(attention_mask, S, past_len, device)

        x = self.embed_tokens(input_ids)
        token_embed = x if self.use_value_embed else None

        if self.engram is not None:
            x = x + self.engram(x)

        if self.cfg.use_hyper_connections:
            x = x.unsqueeze(1).expand(-1, self.cfg.hc_mult, -1, -1).clone()

        present_key_values = [] if use_cache else None
        total_aux_loss = torch.tensor(0.0, device=device)

        for layer_idx, layer in enumerate(self.layers):
            pkv = past_key_values[layer_idx] if past_key_values else None
            if self.gradient_checkpointing and self.training:
                x, present, aux_loss = gradient_checkpoint(
                    layer, x, position_ids, attention_mask, None, False, token_embed,
                    use_reentrant=False,
                )
            else:
                x, present, aux_loss = layer(
                    x, position_ids, attention_mask, pkv, use_cache, token_embed)
            if use_cache:
                present_key_values.append(present)
            if aux_loss is not None:
                total_aux_loss = total_aux_loss + aux_loss

        if self.cfg.use_hyper_connections:
            x = self.hc_out_mix(x)  # v2: learned mix (init == mean)

        if self.hrm_refine is not None:
            x = self.hrm_refine(x)

        x = self.norm(x)
        return x, present_key_values, total_aux_loss


class MTPHead(nn.Module):
    """
    v2 MTP head: small zero-init residual H x H projection feeding the SHARED
    lm_head. Cost per head: H^2 params (e.g. 1M at H=1024) instead of H*V
    (e.g. 50M+).

    Because ``proj`` is zero-initialized, ``forward`` is the identity at step
    0. Each MTP head therefore starts out producing exactly the trunk's
    next-token logits (via the shared lm_head) and learns only its
    offset-specific correction through ``proj``. The residual form
    (x + proj(x)) keeps it anchored to the trunk representation.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        nn.init.zeros_(self.proj.weight)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return hidden + self.proj(hidden)


class SpikeWhaleLM(PreTrainedModel):
    """
    v2 loss = CE
            + zloss_coef * z-loss
            + mtp_loss_weight * mean(MTP CE)
            + MoE aux loss
    """

    config_class = SpikeWhaleConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__(cfg)
        self.model = SpikeWhaleModel(cfg)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        nn.init.normal_(self.lm_head.weight, std=cfg.initializer_range)

        self.zloss_coef = getattr(cfg, "zloss_coef", 1e-4)
        self.mtp_loss_weight = getattr(cfg, "mtp_loss_weight", 0.3)

        # v2 MTP: H x H residual projections sharing lm_head (see MTPHead).
        self.mtp_heads = nn.ModuleList([
            MTPHead(cfg.hidden_size) for _ in range(cfg.num_nextn_predict_layers)
        ]) if cfg.num_nextn_predict_layers > 0 else None

        # NOTE(review): the zero-inits above (MTPHead.proj, HRM up/gates) assume
        # post_init() does not re-draw nn.Linear weights. Some transformers
        # versions may ship a default _init_weights that does -- confirm for the
        # pinned version.
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self, **kwargs):
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def save_pretrained(self, *args, **kwargs):
        """Save with a temporarily untied lm_head copy, then re-tie."""
        tied = (
            self.config.tie_word_embeddings
            and self.lm_head.weight.data_ptr() == self.model.embed_tokens.weight.data_ptr()
        )
        if tied:
            self.lm_head.weight = nn.Parameter(self.model.embed_tokens.weight.detach().clone())
        try:
            super().save_pretrained(*args, **kwargs)
        finally:
            if tied:
                self.lm_head.weight = self.model.embed_tokens.weight

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, SpikeWhaleModel):
            module.gradient_checkpointing = value

    @staticmethod
    def _safe_mean_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Mean CE over non-ignored (-100) labels; 0 instead of NaN when none are valid."""
        flat_labels = labels.reshape(-1)
        ce_sum = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), flat_labels,
            ignore_index=-100, reduction="sum",
        )
        return ce_sum / (flat_labels != -100).sum().clamp(min=1)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple]] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the decoder and lm_head. If ``labels`` is given, compute the v2 loss.

        ``labels`` are unshifted token ids with -100 marking ignored positions.
        Shifting happens here: the main head predicts t+1, and MTP head ``k``
        (1-based) predicts t+k+1.
        """
        hidden, present_kvs, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            flat_labels = shift_labels.view(-1)
            loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

            # z-loss (v2): penalize log^2 of the partition function on valid
            # positions. Keeps logits from drifting; pairs well with Muon.
            # Masked mean instead of boolean indexing: no host sync, same value,
            # and exactly 0 when no position is valid.
            if self.zloss_coef > 0:
                valid = flat_labels != -100
                log_z = torch.logsumexp(flat_logits.float(), dim=-1)
                z_sq = (log_z ** 2).masked_fill(~valid, 0.0)
                loss = loss + self.zloss_coef * z_sq.sum() / valid.sum().clamp(min=1)

            # MTP (v2): residual H x H head -> shared lm_head, down-weighted.
            if self.mtp_heads is not None and self.mtp_loss_weight > 0:
                mtp_total = loss.new_zeros(())
                n_active = 0
                for k, head in enumerate(self.mtp_heads, start=1):
                    offset = k + 1
                    if hidden.size(1) <= offset:
                        continue
                    mtp_logits = self.lm_head(head(hidden[..., :-offset, :]))
                    # Safe mean: a fully ignored label slice (short SFT answers)
                    # would otherwise yield NaN and poison the whole loss.
                    mtp_total = mtp_total + self._safe_mean_ce(
                        mtp_logits, labels[..., offset:])
                    n_active += 1
                if n_active > 0:
                    loss = loss + self.mtp_loss_weight * mtp_total / n_active

            loss = loss + aux_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=present_kvs,
        )

    def count_parameters(self) -> int:
        """Total parameter count (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters())