"""Echo causal language model built from DSRN blocks.

Each ``DSRNBlock`` combines a recurrent fast state ``h``, a gated slow state ``c`` driven by a
"surprise" (prediction-error) signal, an MLP, and optionally sliding-window attention.
"""

import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_echo import EchoConfig

# Kept for backward compatibility of this module's namespace. Attention below does NOT
# dispatch through it: registry entries use the HF
# ``(module, query, key, value, attention_mask, ...)`` calling convention, which is
# incompatible with the SDPA-style call made in SlidingWindowAttention.
try:
    from vllm.model_executor.models.transformers import ALL_ATTENTION_FUNCTIONS
except ImportError:
    ALL_ATTENTION_FUNCTIONS = {}

try:
    from transformers.cache_utils import Cache
except ImportError:

    class Cache:
        pass


class EchoCache(Cache):
    """
    Cache of per-layer DSRN states.

    Each entry of ``states`` is either ``(h, c)`` or ``(h, c, k_attn, v_attn)``. A custom class
    is used so Hugging Face's DynamicCache cannot drop the (k_attn, v_attn) elements of the
    4-tuple state.

    Args:
        states: initial list of per-layer state tuples (empty list when None).
        num_seen_tokens: number of tokens already folded into ``states``. Needed because
            ``(h, c)``-only states have no sequence axis from which a length can be read.
    """

    def __init__(self, states=None, num_seen_tokens: int = 0):
        self.states = states if states is not None else []
        self.num_seen_tokens = num_seen_tokens

    @property
    def states(self):
        """Per-layer state tuples."""
        return self._states

    @states.setter
    def states(self, value):
        self._states = value
        # HF expectation: ``layers`` mirrors the per-layer entries. Re-bound on every
        # assignment so it never points at a stale list.
        self.layers = value

    @property
    def is_compileable(self):
        return False

    def add_seen_tokens(self, num_tokens: int) -> None:
        """Record that ``num_tokens`` more tokens have been folded into the states."""
        self.num_seen_tokens += num_tokens

    def get_seq_length(self, layer_idx=0):
        """
        Return how many tokens the cache covers for ``layer_idx`` (0 when the layer is absent).

        For 4-tuple states this is the length of the stored attention keys (which keep the
        full history); for ``(h, c)`` states it is the tracked ``num_seen_tokens``.
        """
        if len(self.states) <= layer_idx:
            return 0
        state = self.states[layer_idx]
        if len(state) == 4:
            return state[2].shape[2]
        return self.num_seen_tokens

    def get_max_length(self):
        return None

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Shim to satisfy the Cache protocol.

        EchoModel updates its states inside the blocks, so this only returns the stored
        (k, v) for 4-tuple layers, or echoes the inputs back otherwise.
        """
        if len(self.states) > layer_idx:
            state = self.states[layer_idx]
            if len(state) == 4:
                return state[2], state[3]
        return key_states, value_states

    def get_usable_length(self, new_seq_length, layer_idx=0):
        return self.get_seq_length(layer_idx)

    def __getitem__(self, idx):
        return self.states[idx]

    def __len__(self):
        return len(self.states)

    def __iter__(self):
        return iter(self.states)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Select batch rows ``beam_idx`` (dim 0) from every tensor of every layer state, in place."""
        self.states = [
            tuple(tensor.index_select(0, beam_idx.to(tensor.device)) for tensor in layer_state)
            for layer_state in self.states
        ]


# --- STANDALONE KERNELS ---


def _sequential_scan(a, b, h):
    """
    Evaluate ``h_t = a_t * h_{t-1} + b_t`` step by step along the time axis.

    Vectorized over every dimension except time.

    Args:
        a, b: tensors of shape (..., T, D).
        h: initial state of shape (..., D).

    Returns:
        ``(res, h_T)``: every state, shape (..., T, D), and the final state, shape (..., D).
    """
    T = a.shape[-2]
    res = torch.empty_like(b)
    curr_h = h
    for t in range(T):
        curr_h = a[..., t, :] * curr_h + b[..., t, :]
        res[..., t, :] = curr_h
    return res, curr_h


def dsrn_parallel_scan(g_t, m_t, c_0=None, chunk_size=32, use_triton=False):
    """
    Compute the DSRN gated recurrence ``c_t = (1 - g_t) * c_{t-1} + g_t * m_t`` for all t.

    The PyTorch path is a two-level chunked scan:
      1. a sequential scan inside each chunk of ``chunk_size`` steps, from a zero state;
      2. a sequential scan over the per-chunk summaries;
      3. a final combine.
    This costs O(T/K + K) sequential steps. Arithmetic is done in float32 and the result is
    cast back to ``g_t.dtype``.

    Args:
        g_t: gates, shape (B, T, D).
        m_t: candidate values, shape (B, T, D).
        c_0: optional initial state, shape (B, D); zeros when None.
        chunk_size: time steps per chunk for the PyTorch path.
        use_triton: when True and ``g_t`` is on CUDA, try the optional ``.triton_scan`` kernel.
            Falls back to PyTorch, with a warning, if it cannot be imported.

    Returns:
        States ``c_1 .. c_T``, shape (B, T, D).
    """
    if use_triton and g_t.is_cuda:
        try:
            from .triton_scan import triton_dsrn_parallel_scan
        except ImportError:
            warnings.warn("Triton scan unavailable. Falling back to PyTorch scan.", UserWarning)
        else:
            return triton_dsrn_parallel_scan(g_t, m_t, c_0)

    orig_dtype = g_t.dtype
    a = (1.0 - g_t).float()
    b = (g_t * m_t).float()
    B, T, D = a.shape
    device = a.device

    # Pad T to a multiple of chunk_size with identity steps (a=1, b=0). These leave the state
    # unchanged; the padded outputs are cropped off at the end.
    pad_len = (chunk_size - (T % chunk_size)) % chunk_size
    if pad_len > 0:
        a = F.pad(a, (0, 0, 0, pad_len), value=1.0)
        b = F.pad(b, (0, 0, 0, pad_len), value=0.0)
    num_chunks = (T + pad_len) // chunk_size

    # 1. Reshape to (B, num_chunks, chunk_size, D).
    a_chunks = a.reshape(B, num_chunks, chunk_size, D)
    b_chunks = b.reshape(B, num_chunks, chunk_size, D)

    # 2. Local scan inside every chunk at once (vectorized over B and num_chunks).
    h_init_local = torch.zeros(B, num_chunks, D, device=device, dtype=torch.float32)
    c_res, c_final = _sequential_scan(a_chunks, b_chunks, h_init_local)
    # Product of the decay factors across each chunk: (B, num_chunks, D).
    a_final = torch.prod(a_chunks, dim=2)

    # 3. Global scan across chunk summaries; h_chunk_outputs[:, j] is the state AFTER chunk j.
    if c_0 is not None:
        h_0 = c_0.float()
    else:
        h_0 = torch.zeros(B, D, device=device, dtype=torch.float32)
    h_chunk_outputs, _ = _sequential_scan(a_final, c_final, h_0)
    # State BEFORE chunk j.
    h_starts = torch.cat([h_0.unsqueeze(1), h_chunk_outputs[:, :-1]], dim=1)

    # 4. Combine: h_{j,i} = prod(a_{j,0..i}) * h_starts[j] + c_res[j, i].
    a_prefix = torch.cumprod(a_chunks, dim=2)
    final_h = a_prefix * h_starts.unsqueeze(2) + c_res

    # Flatten chunks back to time, crop padding, restore the input dtype.
    return final_h.reshape(B, -1, D)[:, :T].to(orig_dtype)


def rms_norm_fn(hidden_states, weight, eps=1e-6):
    """RMS-normalize over the last dim in float32, cast back, then scale by ``weight``."""
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.contiguous().to(torch.float32)
    variance = (hidden_states * hidden_states).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)


def _shifted_reset_mask(eos_mask: torch.Tensor) -> torch.Tensor:
    """
    Return a mask that is nonzero at the position right after each EOS in ``eos_mask``.

    Position 0 is always cleared. A document boundary at the start of the chunk is instead
    handled by zeroing the carried-out state when the previous chunk ended on EOS.
    """
    reset_mask = torch.roll(eos_mask, shifts=1, dims=1)
    reset_mask[:, 0] = 0
    return reset_mask


def _dsrn_recurrent_core(
    model_block: nn.Module,
    x: torch.Tensor,
    x_norm: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the fast/slow-state recurrence shared by both DSRN kernels.

    Args:
        model_block: module providing gru_cell, linear_pred, linear_gate, linear_memory,
            surprise_lambda and (optionally) use_triton.
        x: block input (B, T, D); used as the prediction target for the surprise signal.
        x_norm: normalized input fed to the fast-state projection.
        h_prev, c_prev: states carried in from the previous chunk.
        eos_mask: optional (B, T) mask marking EOS tokens.

    Returns:
        ``(h_all, c_all, h_new, c_new, gate_stats)``. ``h_new``/``c_new`` are the states to
        carry into the next chunk; ``gate_stats`` is the slow gate averaged over its last dim.
    """
    D = x.shape[-1]
    use_triton = getattr(model_block, "use_triton", False)
    reset_mask = _shifted_reset_mask(eos_mask) if eos_mask is not None else None

    # Fast state: GRU-style update driven only by the input projection (weight_ih/bias_ih).
    # z uses the first D columns, the candidate the last D; the middle slice is unused.
    gru_proj = F.linear(x_norm, model_block.gru_cell.weight_ih, model_block.gru_cell.bias_ih)
    z_all = torch.sigmoid(gru_proj[:, :, :D])
    r_all = torch.tanh(gru_proj[:, :, 2 * D :])
    if reset_mask is not None:
        # z = 1 right after EOS gives h_t = r_t, discarding the previous document's state.
        z_all = torch.where(reset_mask.unsqueeze(-1) > 0, torch.ones_like(z_all), z_all)
    # h_t = (1 - z_t) * h_{t-1} + z_t * r_t
    h_all = dsrn_parallel_scan(z_all, r_all, h_prev, use_triton=use_triton)
    h_new = h_all[:, -1]

    # Surprise: error of predicting x[t] from h[t-1]. h_prev is prepended to shift causally,
    # because h_all holds [h_1, ..., h_T].
    h_shifted = torch.cat([h_prev.unsqueeze(1), h_all[:, :-1, :]], dim=1)
    x_pred = model_block.linear_pred(h_shifted)
    diff = x - x_pred
    error = torch.clamp(diff * diff, max=10.0).mean(dim=-1, keepdim=True)
    # softplus keeps the effective lambda strictly positive, so surprise can only open the gate.
    surprise_signal = error * F.softplus(model_block.surprise_lambda)

    g_all = torch.sigmoid(model_block.linear_gate(h_all) + surprise_signal)
    m_all = torch.tanh(model_block.linear_memory(h_all))
    if reset_mask is not None:
        # g = 1 right after EOS gives c_t = m_t, discarding the previous document's memory.
        # (g = 0 would instead carry c_{t-1} across the boundary unchanged.)
        g_all = torch.where(reset_mask.unsqueeze(-1) > 0, torch.ones_like(g_all), g_all)
    # c_t = (1 - g_t) * c_{t-1} + g_t * m_t
    c_all = dsrn_parallel_scan(g_all, m_all, c_prev, use_triton=use_triton)
    c_new = c_all[:, -1]

    if eos_mask is not None:
        # Inter-chunk reset: if the chunk ends on EOS, the next chunk starts from zero states.
        # The mask is cast per state so low-precision states are not promoted to float32.
        keep = (1.0 - eos_mask[:, -1].float()).unsqueeze(-1)  # (B, 1)
        h_new = h_new * keep.to(h_new.dtype)
        c_new = c_new * keep.to(c_new.dtype)

    gate_stats = g_all.mean(dim=-1)
    return h_all, c_all, h_new, c_new, gate_stats


def dsrn_parallel_kernel_legacy(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Legacy DSRN kernel: LayerNorm normalization and an unconditional slow-state read.

    The continuous read of ``c`` into the output is applied on this path too, so the slow
    state stays connected to the residual stream.

    Returns:
        ``(x_out, h_new, c_new, gate_stats)``.
    """
    D = x.shape[-1]
    norm_fast = model_block.norm_fast
    x_norm = F.layer_norm(
        x,
        (D,),
        weight=norm_fast.weight,
        bias=norm_fast.bias,
        eps=getattr(norm_fast, "eps", 1e-5),
    )

    h_all, c_all, h_new, c_new, gate_stats = _dsrn_recurrent_core(
        model_block, x, x_norm, h_prev, c_prev, eos_mask
    )

    norm_ff = model_block.norm_ff
    h_norm = F.layer_norm(
        h_all, (D,), weight=norm_ff.weight, bias=norm_ff.bias, eps=getattr(norm_ff, "eps", 1e-5)
    )
    mlp_out = model_block.mlp_down(model_block.mlp_act(model_block.mlp_up(h_norm)))
    x_out = x + mlp_out
    x_out = x_out + model_block.linear_read(c_all)
    return x_out, h_new, c_new, gate_stats


def dsrn_parallel_kernel_hybrid(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Hybrid DSRN kernel: RMSNorm normalization.

    The slow-state read is added only when ``model_block.use_hybrid_attention`` is set.

    Returns:
        ``(x_out, h_new, c_new, gate_stats)``.
    """
    norm_fast = model_block.norm_fast
    x_norm = rms_norm_fn(x, norm_fast.weight, getattr(norm_fast, "variance_epsilon", 1e-6))

    h_all, c_all, h_new, c_new, gate_stats = _dsrn_recurrent_core(
        model_block, x, x_norm, h_prev, c_prev, eos_mask
    )

    norm_ff = model_block.norm_ff
    h_norm = rms_norm_fn(h_all, norm_ff.weight, getattr(norm_ff, "variance_epsilon", 1e-6))
    mlp_out = model_block.mlp_down(model_block.mlp_act(model_block.mlp_up(h_norm)))
    x_out = x + mlp_out
    if model_block.use_hybrid_attention:
        x_out = x_out + model_block.linear_read(c_all)
    return x_out, h_new, c_new, gate_stats


def dsrn_parallel_kernel(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatch to the hybrid (RMSNorm) or legacy (LayerNorm) kernel via ``use_rmsnorm``."""
    if getattr(model_block, "use_rmsnorm", False):
        return dsrn_parallel_kernel_hybrid(model_block, x, h_prev, c_prev, eos_mask=eos_mask)
    return dsrn_parallel_kernel_legacy(model_block, x, h_prev, c_prev, eos_mask=eos_mask)


class HymbaRMSNorm(nn.Module):
    """RMSNorm with a learned scale (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        return rms_norm_fn(hidden_states, self.weight, self.variance_epsilon)


class EchoRotaryEmbedding(nn.Module):
    """
    Rotary position embedding tables, built lazily and cached.

    The cos/sin tables are stored in float32 and cast to the caller's dtype on return. This
    means a low-precision first call does not degrade later full-precision calls.
    """

    def __init__(self, dim, max_position_embeddings=4096, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.device = device
        # Plain attributes rather than buffers: buffers were being corrupted by Hugging Face's
        # weight loading for this model, so the tables are computed on the first forward.
        self.max_seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build cos/sin tables of shape (seq_len, dim) on ``device`` in ``dtype``."""
        self.max_seq_len_cached = seq_len
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
        )
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos_cached = emb.cos().to(dtype)
        self._sin_cached = emb.sin().to(dtype)

    def forward(self, x, seq_len=None):
        """
        Return ``(cos, sin)`` for the first ``seq_len`` positions, in ``x.dtype``.

        ``seq_len`` defaults to ``x.shape[-2]``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if (
            self._cos_cached is None
            or seq_len > self.max_seq_len_cached
            or self._cos_cached.device != x.device
        ):
            self._set_cos_sin_cache(
                seq_len=max(seq_len, self.max_position_embeddings),
                device=x.device,
                dtype=torch.float32,
            )
        return (
            self._cos_cached[:seq_len].to(dtype=x.dtype),
            self._sin_cached[:seq_len].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary embeddings to ``q`` and ``k`` at the given ``position_ids``."""
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # (B, 1, T, D)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # (B, 1, T, D)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class SlidingWindowAttention(nn.Module):
    """
    Multi-head attention with RoPE, restricted to the last ``window_size`` positions.

    The returned (k, v) cache keeps the full, untruncated history. The window is applied
    only when computing attention.
    """

    def __init__(self, config: EchoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.window_size = getattr(config, "window_size", 128)

        self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary_emb = EchoRotaryEmbedding(
            self.head_dim,
            base=getattr(config, "rope_theta", 10000.0),
        )

    def _sliding_window_mask(self, q_len, kv_len, device, dtype):
        """
        Build an additive float mask of shape (1, 1, q_len, kv_len) for SDPA.

        Query row i sits at absolute position ``kv_len - q_len + i``. It may attend to key j
        only when ``j <= pos`` and ``pos - j < window_size``. Allowed entries are 0; blocked
        entries are -inf.
        """
        past_len = kv_len - q_len
        q_pos = torch.arange(q_len, device=device).view(-1, 1) + past_len
        k_pos = torch.arange(kv_len, device=device).view(1, -1)
        blocked = k_pos > q_pos
        if self.window_size is not None:
            blocked = blocked | ((q_pos - k_pos) >= self.window_size)
        mask = torch.zeros((q_len, kv_len), device=device, dtype=dtype)
        mask = mask.masked_fill(blocked, float("-inf"))
        return mask.view(1, 1, q_len, kv_len)

    def forward(
        self,
        x,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Attend over ``x`` (B, T, C), optionally continuing from cached ``(k, v)``.

        ``position_ids`` is used when its last dim equals T. Otherwise positions continue
        from the cached length.

        Returns:
            ``(output (B, T, C), (k_full, v_full))``.
        """
        B, T, C = x.shape

        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        past_len = past_key_values[0].shape[2] if past_key_values is not None else 0
        kv_seq_len = past_len + T

        # --- RoPE ---
        if position_ids is None or position_ids.shape[-1] != T:
            position_ids = torch.arange(
                past_len, kv_seq_len, dtype=torch.long, device=x.device
            ).unsqueeze(0)
        cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        if past_key_values is not None:
            k_past, v_past = past_key_values
            k = torch.cat([k_past, k], dim=2)
            v = torch.cat([v_past, v], dim=2)

        # The cache MUST store the full history; never overwrite it with truncated slices.
        current_key_value = (k, v)

        if T > 1:
            # Training/prefill: attend over the full k, v with a band-limited causal mask.
            mask = self._sliding_window_mask(T, k.shape[2], x.device, x.dtype)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        else:
            # Single-token decode: the one query is the newest position, so attending to
            # the last window_size keys needs no mask.
            k_attn, v_attn = k, v
            if self.window_size is not None and k.shape[2] > self.window_size:
                k_attn = k[:, :, -self.window_size :, :]
                v_attn = v[:, :, -self.window_size :, :]
            y = F.scaled_dot_product_attention(q, k_attn, v_attn, is_causal=False)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y), current_key_value


class DSRNBlock(nn.Module):
    """
    One DSRN layer: fast state, surprise-gated slow state, MLP, and optional attention.
    """

    def __init__(self, config: EchoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.state_size = config.hidden_size * config.num_heads
        self.use_triton = getattr(config, "use_triton", True)
        self.use_hybrid_attention = getattr(config, "use_hybrid_attention", True)
        self.use_rmsnorm = getattr(config, "use_rmsnorm", True)

        # Fast State (GRU)
        if self.use_rmsnorm:
            self.norm_fast = HymbaRMSNorm(config.hidden_size)
        else:
            self.norm_fast = nn.LayerNorm(config.hidden_size)
        self.gru_cell = nn.GRUCell(config.hidden_size, config.hidden_size)

        # Hybrid Attention
        if self.use_hybrid_attention:
            self.attn = SlidingWindowAttention(config)

        # Slow State (DSRN)
        self.linear_read = nn.Linear(self.state_size, config.hidden_size, bias=False)
        self.linear_gate = nn.Linear(config.hidden_size, self.state_size)
        self.linear_memory = nn.Linear(config.hidden_size, self.state_size)

        # Surprise mechanism
        self.linear_pred = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.surprise_lambda = nn.Parameter(torch.zeros(self.state_size))

        # Feed-Forward
        if self.use_rmsnorm:
            self.norm_ff = HymbaRMSNorm(config.hidden_size)
        else:
            self.norm_ff = nn.LayerNorm(config.hidden_size)

        # Simple MLP: Linear -> GELU -> Linear.
        # mlp_up / mlp_act / mlp_down are the ONLY registered submodules. There is no
        # self.mlp alias: it caused double-registration and spurious "missing keys".
        intermediate_size = getattr(
            config,
            "intermediate_size",
            int(config.hidden_size * getattr(config, "mlp_ratio", 4.0)),
        )
        self.mlp_up = nn.Linear(config.hidden_size, intermediate_size)
        self.mlp_act = nn.GELU()
        self.mlp_down = nn.Linear(intermediate_size, config.hidden_size)

    def forward(
        self, x: torch.Tensor, state_prev: Tuple[torch.Tensor, ...], **kwargs
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
        """
        Process one chunk ``x`` starting from ``state_prev``.

        Args:
            x: block input (B, T, hidden).
            state_prev: ``(h, c)`` or ``(h, c, k_attn, v_attn)``.
            **kwargs: forwarded to the attention branch (e.g. ``position_ids``).

        Returns:
            ``(x_out, new_state, gate_stats)``. ``new_state`` is ``(h, c, k, v)`` when hybrid
            attention is enabled, otherwise ``(h, c)``.
        """
        h_prev, c_prev = state_prev[0], state_prev[1]

        # Runs on any device: dsrn_parallel_scan itself decides whether Triton can be used.
        x_out, h_new, c_new, gate_stats = dsrn_parallel_kernel(self, x, h_prev, c_prev)

        if not self.use_hybrid_attention:
            return x_out, (h_new, c_new), gate_stats

        # The norm is re-applied for the attention branch. Attention KV is present only
        # when the state is a 4-tuple.
        attn_kv = (state_prev[2], state_prev[3]) if len(state_prev) == 4 else None
        attn_out, new_attn_kv = self.attn(self.norm_fast(x), past_key_values=attn_kv, **kwargs)
        x_out = x_out + attn_out

        if new_attn_kv is not None:
            new_state = (h_new, c_new, new_attn_kv[0], new_attn_kv[1])
        else:
            new_state = (h_new, c_new)
        return x_out, new_state, gate_stats


class EchoPreTrainedModel(PreTrainedModel):
    config_class = EchoConfig
    base_model_prefix = "model"
    _no_split_modules = ["DSRNBlock"]
    # Silently drop legacy mlp.0.*/mlp.1.*/mlp.2.* alias keys that may exist in old local
    # training checkpoints from before the self.mlp alias was removed. The canonical names
    # are mlp_up.* / mlp_act.* / mlp_down.*, which load fine.
    _keys_to_ignore_on_load_unexpected = [
        r".*\.mlp\.0\..*",
        r".*\.mlp\.1\..*",
        r".*\.mlp\.2\..*",
    ]

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)


class EchoModel(EchoPreTrainedModel):
    """Stack of DSRN blocks over a token embedding, followed by a final norm."""

    supports_gradient_checkpointing = True
    _supports_attention_backend = True

    def __init__(self, config: EchoConfig):
        super().__init__(config)
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.state_dim = config.embed_dim * config.num_heads

        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.blocks = nn.ModuleList([DSRNBlock(config) for _ in range(config.num_layers)])
        if getattr(config, "use_rmsnorm", False):
            self.final_norm = HymbaRMSNorm(config.hidden_size)
        else:
            self.final_norm = nn.LayerNorm(config.hidden_size)

        self.gradient_checkpointing = False
        self.post_init()

        # Custom initialization applied after post_init().
        # NOTE(review): confirm that EchoForCausalLM.post_init() does not re-run _init_weights
        # over these modules in the transformers version in use; otherwise these overrides
        # would be clobbered.
        # Gate bias defaults to 1.0 (config.gate_bias_init) to encourage open gates initially.
        bias_val = getattr(config, "gate_bias_init", 1.0)
        for block in self.blocks:
            nn.init.constant_(block.linear_gate.bias, bias_val)

            # Surprise predictor: small orthogonal init. Half-precision CUDA weights are
            # initialized in float32 on CPU first, then copied over.
            if (
                block.linear_pred.weight.dtype in (torch.bfloat16, torch.float16)
                and block.linear_pred.weight.is_cuda
            ):
                _device = block.linear_pred.weight.device
                _dtype = block.linear_pred.weight.dtype
                temp_w = torch.empty_like(
                    block.linear_pred.weight, dtype=torch.float32, device="cpu"
                )
                nn.init.orthogonal_(temp_w, gain=0.1)
                with torch.no_grad():
                    block.linear_pred.weight.copy_(temp_w.to(device=_device, dtype=_dtype))
            else:
                nn.init.orthogonal_(block.linear_pred.weight, gain=0.1)
            nn.init.zeros_(block.surprise_lambda)

            # Zero-init the MLP output projection so each block's MLP branch starts at zero.
            nn.init.zeros_(block.mlp_down.weight)
            nn.init.zeros_(block.mlp_down.bias)

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Enable/disable gradient checkpointing."""
        self.gradient_checkpointing = enable

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def _zero_state(self, batch_size, device, dtype):
        """Return a fresh ``(h, c)`` layer state filled with zeros."""
        h = torch.zeros(batch_size, self.embed_dim, device=device, dtype=dtype)
        c = torch.zeros(batch_size, self.state_dim, device=device, dtype=dtype)
        return (h, c)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_dsrn_telemetry: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Run all blocks layer by layer over one chunk of tokens.

        Args:
            input_ids / inputs_embeds: exactly one must be given.
            past_key_values: EchoCache, a list of per-layer state tuples, or an (empty) HF
                Cache. ``None`` or an empty cache starts from zero states.
            position_ids: forwarded to each block's attention branch.
            output_dsrn_telemetry: also return per-layer ``c`` states and gate statistics.
            **kwargs: forwarded to every block.

        Returns:
            ``(hidden_states, cache)`` or, with telemetry,
            ``(hidden_states, cache, c_states, gate_stats)``. ``cache`` is an EchoCache.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            batch_size, seq_len = input_ids.shape
            x = self.embedding(input_ids)
        elif inputs_embeds is not None:
            batch_size, seq_len, _ = inputs_embeds.shape
            x = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device, dtype = x.device, x.dtype

        # Start from zero states when there is no cache, or the cache holds nothing yet
        # (e.g. an empty HF DynamicCache or an empty list).
        is_empty_cache = (
            hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0
        ) or (isinstance(past_key_values, (list, tuple)) and len(past_key_values) == 0)
        if past_key_values is None or is_empty_cache:
            past_key_values = [
                self._zero_state(batch_size, device, dtype) for _ in range(self.num_layers)
            ]

        current_states = past_key_values
        block_kwargs = dict(kwargs)
        block_kwargs["position_ids"] = position_ids

        next_states = []
        all_gate_stats = [] if output_dsrn_telemetry else None
        all_c_states = [] if output_dsrn_telemetry else None

        # Layer-major execution.
        for i, block in enumerate(self.blocks):
            state_i = current_states[i]
            if len(state_i) not in (2, 4):
                # Empty/malformed layer state: restart this layer from zeros.
                state_i = self._zero_state(batch_size, device, dtype)

            if self.gradient_checkpointing and self.training:
                out = torch.utils.checkpoint.checkpoint(
                    block, x, state_i, use_reentrant=False, **block_kwargs
                )
            else:
                out = block(x, state_i, **block_kwargs)

            x = out[0]
            next_states.append(out[1])
            if output_dsrn_telemetry:
                all_gate_stats.append(out[2])
                all_c_states.append(out[1][1])

        x = self.final_norm(x)

        if isinstance(current_states, EchoCache):
            current_states.states = next_states
            current_states.add_seen_tokens(seq_len)
            next_states = current_states
        else:
            next_states = EchoCache(next_states, num_seen_tokens=seq_len)

        if output_dsrn_telemetry:
            return x, next_states, all_c_states, all_gate_stats
        return x, next_states


class EchoForCausalLM(EchoPreTrainedModel, GenerationMixin):
    """EchoModel with a linear language-modeling head."""

    _is_causal = True
    supports_gradient_checkpointing = True
    _supports_cache_class = False
    _supports_static_cache = False
    main_input_name = "input_ids"

    def __init__(self, config: EchoConfig):
        super().__init__(config)
        self.model = EchoModel(config)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Enable/disable gradient checkpointing."""
        self.model._set_gradient_checkpointing(enable, gradient_checkpointing_func)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_dsrn_telemetry: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Compute next-token logits and, when ``labels`` is given, the shifted CE loss.

        ``output_attentions``/``output_hidden_states`` are accepted for API compatibility but
        not produced.
        NOTE(review): ``attention_mask`` is currently ignored (padding is not masked).
        With telemetry enabled, the latest ``c`` states and gate stats are stored on
        ``self._latest_c_states`` / ``self._latest_gate_stats``.
        """
        use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", True)
        return_dict = (
            return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
        )

        # Extra kwargs that HF generate passes are forwarded as-is; EchoModel and the blocks
        # accept and ignore unknown keys.
        model_out = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            output_dsrn_telemetry=output_dsrn_telemetry,
            **kwargs,
        )
        hidden_states = model_out[0]
        new_states = model_out[1]
        if len(model_out) > 2:
            self._latest_c_states = model_out[2]
            self._latest_gate_stats = model_out[3]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, new_states)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_states if use_cache else None,
            hidden_states=None,  # EchoModel doesn't expose internal states yet
            attentions=None,  # EchoModel doesn't expose attention weights yet
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        """
        Build model inputs for one generation step.

        When the cache already holds state, only the last token is fed. An empty HF cache
        (e.g. a DynamicCache auto-created before step 0) counts as "no cache": EchoModel then
        starts from zero states and returns an EchoCache for later steps.
        """
        if past_key_values is None:
            is_empty = True
        elif hasattr(past_key_values, "get_seq_length"):
            is_empty = past_key_values.get_seq_length() == 0
        elif isinstance(past_key_values, (list, tuple)):
            is_empty = len(past_key_values) == 0
        else:
            is_empty = False

        if not is_empty:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": kwargs.get("use_cache"),
        }
        # Pass through extra kwargs such as output_dsrn_telemetry.
        model_inputs.update({k: v for k, v in kwargs.items() if k not in model_inputs})
        return model_inputs

    def _reorder_cache(self, past_key_values, beam_idx):
        """
        Reorder the batch rows of every cached state for beam/contrastive search.

        An EchoCache is returned as a new EchoCache, preserving its seen-token count.
        Otherwise a list of ``(h, c, ...)`` tuples is returned.
        """
        if past_key_values is None:
            return None
        reordered_past = [
            tuple(p.index_select(0, beam_idx.to(p.device)) for p in layer_past)
            for layer_past in past_key_values
        ]
        if isinstance(past_key_values, EchoCache):
            return EchoCache(reordered_past, num_seen_tokens=past_key_values.num_seen_tokens)
        return reordered_past