# Source: Kurtis-EON1-Hybrid-0.7B-v0.1.1 / modeling_hybrid.py
# Uploaded by mrs83 ("Upload modeling_hybrid.py", commit 6b3235f, verified)
"""
echo_hybrid/modeling_hybrid.py
────────────────────────────────────────────────────────────────────────────
HybridEchoModel — Qwen2 backbone with DSRN memory injectors
HybridEchoForCausalLM — wraps the model with an LM head and generation support
HybridEchoCache — DynamicCache subclass that also carries DSRN recurrent states
FINAL STABLE ARCHITECTURE: Stateless Backbone with Position Alignment.
"""
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as gradient_checkpoint
from transformers import GenerationMixin
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Model,
Qwen2PreTrainedModel,
)
from .configuration_hybrid import HybridEchoConfig
from .dsrn_memory_block import DSRNMemoryInjector
# ─────────────────────────────────────────────────────────────────────────────
# Cache
# ─────────────────────────────────────────────────────────────────────────────
class HybridEchoCache(DynamicCache):
    """``DynamicCache`` that also transports the DSRN recurrent states.

    On top of whatever the parent cache stores, an instance carries:

    * ``dsrn_states`` - list of ``(h, c)`` tensor pairs, in the layout built
      by ``HybridEchoModel._init_dsrn_states`` (one pair per injector).
    * ``seen_tokens`` - integer token counter maintained by
      ``HybridEchoModel.forward``.
    """

    def __init__(self, config=None):
        """Create an empty cache.

        Args:
            config: Model config forwarded to ``DynamicCache.__init__``.
        """
        # Author's note (transformers 5.x): DynamicCache keeps its data in
        # ``self.layers`` rather than ``self.key_cache``.  The config has to
        # be forwarded so the parent constructor creates the per-layer cache
        # objects; otherwise ``self.layers`` stays empty and ``update()``
        # silently drops the KV tensors.
        super().__init__(config=config)
        self.dsrn_states: List[Tuple[torch.Tensor, torch.Tensor]] = []
        self.seen_tokens = 0

    @classmethod
    def from_legacy_cache(
        cls,
        dsrn_states: List[Tuple[torch.Tensor, torch.Tensor]],
        config=None,
    ) -> "HybridEchoCache":
        """Build a cache pre-populated with existing DSRN states.

        Args:
            dsrn_states: ``(h, c)`` pairs to store on the new cache; the list
                is stored as-is, not copied.
            config: Model config forwarded to the constructor.

        Returns:
            A new ``HybridEchoCache`` whose ``dsrn_states`` is ``dsrn_states``.
        """
        instance = cls(config=config)
        instance.dsrn_states = dsrn_states
        return instance

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cache along the batch dimension (dim 0).

        The parent class reorders its own contents first; afterwards every
        DSRN ``(h, c)`` pair is re-indexed with the same ``beam_idx``.

        Args:
            beam_idx: Indices selecting, for each output row, which existing
                batch row to keep.
        """
        super().reorder_cache(beam_idx)
        reordered = []
        for hidden, cell in self.dsrn_states:
            # The index tensor is moved to each state's own device because
            # ``index_select`` needs index and source on the same device.
            picked_hidden = hidden.index_select(0, beam_idx.to(hidden.device))
            picked_cell = cell.index_select(0, beam_idx.to(cell.device))
            reordered.append((picked_hidden, picked_cell))
        self.dsrn_states = reordered
# ─────────────────────────────────────────────────────────────────────────────
# Inner model — Qwen2 backbone + hooked DSRN injectors
# ─────────────────────────────────────────────────────────────────────────────
class HybridEchoModel(Qwen2PreTrainedModel):
    """Qwen2 backbone whose decoder layers are augmented with DSRN injectors.

    A ``DSRNMemoryInjector`` is attached, through a forward hook, to the
    output of every ``config.dsrn_injection_stride``-th backbone layer.  Each
    injector receives that layer's hidden states together with its own
    recurrent ``(h, c)`` pair and returns replacement hidden states plus an
    updated pair.

    The hooks communicate with :meth:`forward` through three transient
    attributes (``_dsrn_input_states``, ``_dsrn_output_states`` and
    ``_eos_mask``) that are filled right before the backbone runs and cleared
    right after it.  Because that state lives on the instance, overlapping
    forward passes on the same instance would interfere with each other.
    """

    config_class = HybridEchoConfig
    _tp_plan = {}  # No tensor parallel plan; required by vLLM TransformersForCausalLM

    def __init__(self, config: HybridEchoConfig):
        """Build the backbone, the injectors, and register the layer hooks.

        Args:
            config: Model configuration.  ``dsrn_injection_stride`` and
                ``num_hidden_layers`` determine how many injectors exist.
        """
        super().__init__(config)
        self.backbone = Qwen2Model(config)

        stride = config.dsrn_injection_stride
        self.num_injectors = config.num_hidden_layers // stride
        self.memory_injectors = nn.ModuleList(
            [DSRNMemoryInjector(config) for _ in range(self.num_injectors)]
        )

        # Transient per-forward state shared with the hooks.
        self._dsrn_input_states = []
        self._dsrn_output_states = []
        self._eos_mask = None

        self._hook_handles = []
        self._register_dsrn_hooks()
        self.gradient_checkpointing = False

    def _apply_dsrn_critical_inits(self):
        """Apply the DSRN-specific parameter initialisation to every injector.

        For each injector: ``linear_read.weight`` and ``surprise_lambda`` are
        zeroed, ``linear_gate.bias`` is set to ``config.gate_bias_init``
        (1.0 when the config lacks that field), and ``linear_pred.weight``
        gets an orthogonal init with gain 0.1.

        This method is not invoked by any constructor in this file; a note in
        ``HybridEchoForCausalLM.__init__`` says calling it there would
        overwrite checkpoint weights.
        """
        bias_val = getattr(self.config, "gate_bias_init", 1.0)
        for injector in self.memory_injectors:
            nn.init.zeros_(injector.linear_read.weight)
            nn.init.constant_(injector.linear_gate.bias, bias_val)
            nn.init.zeros_(injector.surprise_lambda)

            w = injector.linear_pred.weight
            if w.dtype in (torch.bfloat16, torch.float16):
                # Half-precision weights: run the orthogonal init on a float32
                # scratch tensor and copy the result back (presumably because
                # the decomposition behind orthogonal_ is not available or
                # accurate in half precision - TODO confirm).
                tmp = torch.empty_like(w, dtype=torch.float32, device=w.device)
                nn.init.orthogonal_(tmp, gain=0.1)
                with torch.no_grad():
                    w.copy_(tmp.to(device=w.device, dtype=w.dtype))
            else:
                nn.init.orthogonal_(w, gain=0.1)

    def _make_hook(self, injector_idx: int):
        """Return a forward hook bound to the injector at ``injector_idx``.

        The returned hook is a no-op while ``_dsrn_input_states`` is empty
        (i.e. outside :meth:`forward`).  Otherwise it runs the injector on
        the layer's hidden states, appends the new ``(h, c)`` pair to
        ``_dsrn_output_states`` and returns the layer output with the hidden
        states replaced by the injector's output.
        """

        def hook(module: nn.Module, layer_input: tuple, layer_output):
            if not self._dsrn_input_states:
                return layer_output

            # A layer may return either a bare tensor or a tuple whose first
            # element is the hidden states; remember which to mirror it back.
            if isinstance(layer_output, torch.Tensor):
                hidden_states = layer_output
                is_bare_tensor = True
            else:
                hidden_states = layer_output[0]
                is_bare_tensor = False

            h_prev, c_prev = self._dsrn_input_states[injector_idx]
            injector = self.memory_injectors[injector_idx]
            eos_mask = self._eos_mask

            if self.gradient_checkpointing and self.training:
                # Wrap the injector in a gradient checkpoint so its
                # intermediate activations are discarded during the forward
                # pass and recomputed per-injector during backward.
                # Author's note: use_reentrant=False is required for
                # compatibility with torch.compile and nested checkpointing.
                def injector_fn(hs, h, c):
                    return injector(hs, h, c, eos_mask=eos_mask)

                x_out, h_new, c_new = gradient_checkpoint(
                    injector_fn,
                    hidden_states,
                    h_prev,
                    c_prev,
                    use_reentrant=False,
                )
            else:
                x_out, h_new, c_new = injector(
                    hidden_states, h_prev, c_prev, eos_mask=eos_mask
                )

            # States are collected in hook-firing order, which follows the
            # layer order and therefore the injector order.
            self._dsrn_output_states.append((h_new, c_new))

            if is_bare_tensor:
                return x_out
            return (x_out,) + layer_output[1:]

        return hook

    def _register_dsrn_hooks(self):
        """(Re-)attach one forward hook per injector to the backbone layers.

        Previously registered handles are removed first, so the method can be
        called repeatedly.  Injector ``i`` is attached to layer
        ``(i + 1) * stride - 1``; indices past the last layer are skipped.
        """
        for h in self._hook_handles:
            h.remove()
        self._hook_handles = []

        stride = self.config.dsrn_injection_stride
        layers = self.backbone.layers
        for injector_idx in range(self.num_injectors):
            target_layer_idx = (injector_idx + 1) * stride - 1
            if target_layer_idx < len(layers):
                handle = layers[target_layer_idx].register_forward_hook(
                    self._make_hook(injector_idx)
                )
                self._hook_handles.append(handle)

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.backbone.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.backbone.embed_tokens = value

    def _init_dsrn_states(self, batch_size: int, device: torch.device, dtype: torch.dtype):
        """Create zero-initialised ``(h, c)`` pairs, one per injector.

        Returns:
            A list of ``num_injectors`` tuples; ``h`` has shape
            ``(batch_size, config.hidden_size)`` and ``c`` has shape
            ``(batch_size, config.dsrn_state_dim)``.
        """
        D = self.config.hidden_size
        D_s = self.config.dsrn_state_dim
        return [
            (
                torch.zeros(batch_size, D, device=device, dtype=dtype),
                torch.zeros(batch_size, D_s, device=device, dtype=dtype),
            )
            for _ in range(self.num_injectors)
        ]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridEchoCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        eos_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the backbone with the DSRN injectors active.

        Args:
            input_ids: Token ids; batch size and sequence length are read
                from dims 0 and 1.  Either this or ``inputs_embeds`` is
                required.
            attention_mask: Forwarded to the backbone, except that it is
                dropped when ``config.use_kv_cache`` is false.
            position_ids: Explicit positions.  When omitted,
                ``cache_position`` is used instead, and in stateless mode
                positions are derived from the cache's ``seen_tokens``.
            past_key_values: Cache from a previous step.  DSRN states are
                only read from it when it is a ``HybridEchoCache``.
            inputs_embeds: Pre-computed embeddings, alternative to
                ``input_ids``.
            use_cache: Controls whether DSRN states are returned inside a
                ``HybridEchoCache`` (treated as ``True`` when ``None``).  It
                does not control the backbone KV cache, which follows
                ``config.use_kv_cache``.
            output_attentions, output_hidden_states, cache_position, kwargs:
                Forwarded to the backbone unchanged.
            return_dict: When false a ``(last_hidden_state, cache)`` tuple is
                returned instead of the backbone's output object.
            eos_mask: Passed to every injector as ``eos_mask=``.

        Returns:
            The backbone output with ``past_key_values`` replaced by the
            cache assembled here, or a ``(last_hidden_state, cache)`` tuple.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is
                given.
        """
        # Two independent cache controls:
        #   use_dsrn_cache     - whether to return a HybridEchoCache carrying
        #                        the DSRN states.
        #   backbone_use_cache - config.use_kv_cache, controls the backbone
        #                        KV cache (see below).
        # Ablation mode disables the backbone KV cache but DSRN states must
        # still survive across steps, hence the default of True.
        use_dsrn_cache = use_cache if use_cache is not None else True
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            B = input_ids.shape[0]
            device = input_ids.device
            dtype = self.backbone.embed_tokens.weight.dtype
            seq_len = input_ids.shape[1]
        elif inputs_embeds is not None:
            B = inputs_embeds.shape[0]
            device = inputs_embeds.device
            dtype = inputs_embeds.dtype
            seq_len = inputs_embeds.shape[1]
        else:
            raise ValueError("Provide input_ids or inputs_embeds.")

        # Recover the recurrent states from the incoming cache, or start from
        # zeros when there is no (usable) cache or it carries no states yet.
        if past_key_values is not None and isinstance(past_key_values, HybridEchoCache):
            dsrn_states = past_key_values.dsrn_states
            if not dsrn_states:
                dsrn_states = self._init_dsrn_states(B, device, dtype)
        else:
            dsrn_states = self._init_dsrn_states(B, device, dtype)
            if past_key_values is None:
                past_key_values = HybridEchoCache(config=self.config) if use_dsrn_cache else None

        # Detach DSRN states before carrying them into this step.  Without
        # .detach() the inputs would keep the previous forward pass's grad_fn,
        # chaining each step's graph to the one before it so earlier
        # activations could never be freed after backward.
        self._dsrn_input_states = [(h.detach(), c.detach()) for h, c in dsrn_states]
        # The hooks are armed regardless of use_dsrn_cache: that flag only
        # decides whether the resulting states are handed back in a cache.
        self._dsrn_output_states = []
        self._eos_mask = eos_mask

        try:
            # -- Qwen2 requires 2D position_ids ----------------------------
            final_position_ids = position_ids if position_ids is not None else cache_position
            if final_position_ids is not None:
                if final_position_ids.dim() == 1:
                    final_position_ids = final_position_ids.unsqueeze(0)
                if final_position_ids.shape[1] > seq_len:
                    # Keep only the positions of the tokens actually fed in.
                    final_position_ids = final_position_ids[:, -seq_len:]

            # -- Stateless backbone with position alignment ----------------
            # backbone_use_cache is driven solely by config.use_kv_cache and
            # is independent of use_dsrn_cache.
            backbone_use_cache = getattr(self.config, "use_kv_cache", True)
            if not backbone_use_cache:
                # Drop the mask to avoid SDPA length mismatches in stateless
                # mode (author's note).
                attention_mask = None
                # Without a physical KV history the backbone would restart
                # positions at 0 on every call, so the absolute position of
                # the first token is derived from the cache's seen_tokens.
                if final_position_ids is None:
                    offset = 0
                    if past_key_values is not None and hasattr(past_key_values, "seen_tokens"):
                        offset = past_key_values.seen_tokens
                    # Shape (1, seq_len); relies on broadcasting over batch.
                    final_position_ids = torch.arange(
                        offset, offset + seq_len, device=device, dtype=torch.long
                    ).unsqueeze(0)

            backbone_out = self.backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=final_position_ids,
                past_key_values=past_key_values if backbone_use_cache else None,
                inputs_embeds=inputs_embeds,
                use_cache=backbone_use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                cache_position=cache_position,
                **kwargs,
            )
            new_dsrn_states = self._dsrn_output_states
        finally:
            # Always disarm the hooks and drop tensor references, including
            # when the backbone raises, so a failed call cannot leave stale
            # state behind on the instance.
            self._dsrn_input_states = []
            self._dsrn_output_states = []
            self._eos_mask = None

        new_cache = backbone_out.past_key_values
        if use_dsrn_cache and (new_cache is not None or not backbone_use_cache):
            if new_cache is None or not isinstance(new_cache, HybridEchoCache):
                # Wrap whatever the backbone produced (possibly nothing) in a
                # HybridEchoCache so the DSRN states have somewhere to live.
                hybrid_cache = HybridEchoCache(config=self.config)
                if new_cache is not None:
                    if hasattr(new_cache, "layers"):
                        hybrid_cache.layers = new_cache.layers
                    if hasattr(new_cache, "get_seq_length"):
                        hybrid_cache.seen_tokens = new_cache.get_seq_length()
                elif past_key_values is not None:
                    # No backbone cache: carry the token counter forward.
                    if hasattr(past_key_values, "seen_tokens"):
                        hybrid_cache.seen_tokens = past_key_values.seen_tokens
                    elif hasattr(past_key_values, "get_seq_length"):
                        hybrid_cache.seen_tokens = past_key_values.get_seq_length()
                new_cache = hybrid_cache
            new_cache.dsrn_states = new_dsrn_states
            if not backbone_use_cache:
                new_cache.seen_tokens += seq_len
        elif past_key_values is not None:
            if isinstance(past_key_values, HybridEchoCache):
                past_key_values.dsrn_states = new_dsrn_states
                if not backbone_use_cache:
                    past_key_values.seen_tokens += seq_len
            new_cache = past_key_values

        if not return_dict:
            return (backbone_out.last_hidden_state, new_cache)
        backbone_out.past_key_values = new_cache
        return backbone_out
class HybridEchoForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
    """``HybridEchoModel`` with a language-modelling head and generation hooks.

    ``self.model`` is the hybrid backbone and ``self.lm_head`` is a bias-free
    linear projection from ``hidden_size`` to ``vocab_size``.
    """

    config_class = HybridEchoConfig
    _tp_plan = {}  # No tensor parallel plan; required by vLLM TransformersForCausalLM
    _is_causal = True
    supports_gradient_checkpointing = True
    _supports_cache_class = True
    _supports_static_cache = False
    main_input_name = "input_ids"

    def __init__(self, config: HybridEchoConfig):
        """Build the hybrid model and the LM head, then run ``post_init``."""
        super().__init__(config)
        self.model = HybridEchoModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
        # Do NOT call _apply_dsrn_critical_inits here - it overwrites checkpoint weights

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.backbone.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.model.backbone.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped ``HybridEchoModel`` with ``decoder``."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped ``HybridEchoModel``."""
        return self.model

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Toggle gradient checkpointing on the hybrid model and its backbone.

        Enabling it also writes ``use_cache = False`` and
        ``use_kv_cache = False`` onto ``self.config``; disabling it does not
        restore them.

        NOTE(review): ``gradient_checkpointing_func`` is accepted but not
        used here - confirm the backbone obtains its checkpointing function
        some other way on the targeted transformers version.
        """
        self.model.gradient_checkpointing = enable
        self.model.backbone.gradient_checkpointing = enable
        if enable:
            # Author's rationale: gradient checkpointing and the KV cache are
            # mutually exclusive - checkpointing discards activations between
            # layers while the cache would re-materialise them.
            self.config.use_cache = False
            self.config.use_kv_cache = False

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridEchoCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        eos_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            labels: Target token ids aligned with the inputs; they are shifted
                by one position internally and entries equal to ``-100`` are
                ignored.
            return_dict: When false a tuple ``(loss?, logits, cache)`` is
                returned, with ``loss`` present only if it was computed.
            Remaining arguments are forwarded to ``HybridEchoModel.forward``.

        Returns:
            ``CausalLMOutputWithPast`` or the tuple described above.  The
            loss is the mean cross-entropy over non-ignored positions
            (0 when every position is ignored).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            eos_mask=eos_mask,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        # Logits stay in the model dtype; the float32 cast is deferred to the
        # chunked loop below so only one chunk is upcast at a time.
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on another device than the logits (e.g. when the
            # model is sharded); .to() is a no-op when they already match.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

            # Chunked cross-entropy: upcasting _CHUNK_SIZE rows at a time
            # bounds the float32 temporary to _CHUNK_SIZE x vocab elements
            # instead of (batch * seq) x vocab.
            _CHUNK_SIZE = 512
            # Flatten by the logits' real last dimension so a resized LM head
            # cannot disagree with config.vocab_size.
            _flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            _flat_labels = shift_labels.view(-1)
            _loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
            _total_loss = torch.zeros((), dtype=torch.float32, device=_flat_logits.device)
            _total_tokens = _flat_logits.new_zeros((), dtype=torch.long)
            for _i in range(0, _flat_logits.size(0), _CHUNK_SIZE):
                _cl = _flat_logits[_i : _i + _CHUNK_SIZE].float()
                _ll = _flat_labels[_i : _i + _CHUNK_SIZE]
                _total_loss = _total_loss + _loss_fct(_cl, _ll)
                _total_tokens = _total_tokens + (_ll != -100).sum()
                del _cl
            # clamp avoids a division by zero when every label is ignored.
            loss = _total_loss / _total_tokens.clamp(min=1)

        if not return_dict:
            out = (logits, outputs.past_key_values)
            return ((loss,) + out) if loss is not None else out

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Once the cache already covers some tokens, only the last token of
        ``input_ids`` is fed to the model.  ``cache_position`` is synthesised
        when the caller did not provide one, and in stateless mode
        (``config.use_kv_cache`` false) it is promoted to 2D ``position_ids``.

        Note that ``inputs_embeds`` is accepted for signature compatibility
        but is not forwarded.

        Returns:
            Dict with ``input_ids``, ``past_key_values``, ``use_cache``,
            ``attention_mask``, ``cache_position`` and ``position_ids``.
        """
        past_len = 0
        if past_key_values is not None:
            # Prefer the actual KV length; fall back to seen_tokens in
            # ablation mode where no KV tensors are stored and
            # get_seq_length() reports 0.
            kv_len = (
                past_key_values.get_seq_length()
                if hasattr(past_key_values, "get_seq_length")
                else 0
            )
            if kv_len > 0:
                past_len = kv_len
            elif hasattr(past_key_values, "seen_tokens"):
                past_len = past_key_values.seen_tokens

        # Trim BEFORE deriving cache_position: the positions must describe the
        # tokens that are actually fed to the model.  Deriving them from the
        # untrimmed length would yield too many positions, offset past the
        # true end of the sequence.
        if past_len > 0 and input_ids.shape[1] > 1:
            input_ids = input_ids[:, -1:]

        if cache_position is None:
            cache_position = torch.arange(
                past_len, past_len + input_ids.shape[1], device=input_ids.device
            )

        use_cache = kwargs.get("use_cache")
        if use_cache is None:
            use_cache = getattr(self.config, "use_kv_cache", True)

        # In ablation mode the backbone has no physical KV cache to infer
        # positions from, so the absolute positions held by cache_position
        # are passed explicitly as 2D position_ids.
        if not getattr(self.config, "use_kv_cache", True) and position_ids is None:
            position_ids = cache_position.unsqueeze(0)  # (1, seq_len)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
            "cache_position": cache_position,
            "position_ids": position_ids,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder a ``HybridEchoCache`` in place for beam search.

        Caches of any other type are returned untouched.
        """
        if isinstance(past_key_values, HybridEchoCache):
            past_key_values.reorder_cache(beam_idx)
        return past_key_values

    def generate(self, *args, **kwargs):
        """Run ``GenerationMixin.generate`` in cached + DSRN mode.

        Two things are enforced for the duration of the call:

        * Generation starts from a ``HybridEchoCache`` (a fresh one is created
          unless the caller supplied one), so the DSRN recurrent states are
          carried from the very first step instead of being reset to zeros.
        * ``config.use_kv_cache`` is forced to ``True`` so the backbone KV
          cache is passed through by ``HybridEchoModel.forward``.

        The previous ``use_kv_cache`` value is restored afterwards, even if
        generation raises, so the override does not leak into later training
        steps or into a saved config.
        """
        previous_use_kv_cache = getattr(self.config, "use_kv_cache", True)
        self.config.use_kv_cache = True
        try:
            if not isinstance(kwargs.get("past_key_values"), HybridEchoCache):
                kwargs["past_key_values"] = HybridEchoCache(config=self.config)
            return super().generate(*args, **kwargs)
        finally:
            self.config.use_kv_cache = previous_use_kv_cache