"""
echo_hybrid/configuration_hybrid.py
─────────────────────────────────────────────────────────────────────────────
HybridEchoConfig: extends Qwen2Config with DSRN memory-injector parameters.

Design rationale
────────────────
Rather than inventing an entirely new config, we subclass Qwen2Config so that
every existing Qwen2 hyper-parameter (hidden_size, num_hidden_layers, etc.) is
available without duplication. The only additions are five new fields (the
four DSRN-specific ones plus the use_kv_cache switch), all documented in the
class docstring below.

CRITICAL NOTES (from AGENTS.md)
─────────────────────────────────
• model_type MUST be "echo_hybrid" so AutoConfig routing works after
  AutoConfig.register("echo_hybrid", HybridEchoConfig).
• Do NOT use this config with EchoForCausalLM; that model expects EchoConfig.
"""

from transformers import Qwen2Config


class HybridEchoConfig(Qwen2Config):
    """
    Qwen2Config subclass that adds DSRN memory-injector fields.

    New fields
    ──────────
    dsrn_state_dim : int
        Dimension of the c_t slow-state vector maintained by each
        DSRNMemoryInjector. Defaults to 512. Can be set equal to
        hidden_size (896 for Qwen2-0.5B) for a richer slow-state, at the
        cost of extra parameters per injector.

    dsrn_injection_stride : int
        Insert one DSRNMemoryInjector after every N transformer layers.
        For Qwen2-0.5B (24 layers) the default of 4 yields 24 / 4 = 6
        injectors.

    dsrn_use_triton : bool
        Route the parallel scan to the custom Triton kernel defined in
        echo_hf/triton_scan.py. Disabled by default because the Triton
        kernel targets CUDA/ROCm and is not available everywhere.

    gate_bias_init : float
        Initial value of linear_gate.bias in every injector. A positive
        value (~1.0) keeps the memory gates open at init, allowing gradients
        to flow into c_t immediately. Increase to 2.0 if c_t norms stay near
        0.0 after the Phase-1 warm-up.

    use_kv_cache : bool
        Controls the Qwen2 backbone KV-cache, independently of use_cache
        (which controls whether the DSRN state is returned).

        - True (default, recommended): Standard Hybrid mode (mode 2).
          Backbone KV-cache active; attention handles the fast state and
          DSRN handles the slow state. Best quality and lowest peak VRAM.
        - False: Ablation / stateless mode (mode 1).
          Backbone KV-cache disabled; every forward pass re-feeds the full,
          growing context so attention stays coherent. The DSRN slow-state
          is the sole cross-step memory. Useful for ablation studies and
          "Attention Tax" vs "Recurrent Gain" benchmarks.
    """

    model_type = "echo_hybrid"

    def __init__(
        self,
        dsrn_state_dim: int = 512,
        dsrn_injection_stride: int = 4,
        dsrn_use_triton: bool = False,
        gate_bias_init: float = 1.0,
        use_kv_cache: bool = True,  # Kill-switch: False = DSRN-only ablation
        **kwargs,
    ):
        # Qwen2Config handles every standard hyper-parameter
        # (hidden_size, num_hidden_layers, use_cache, ...).
        super().__init__(**kwargs)

        # DSRN memory-injector fields (see class docstring).
        self.dsrn_state_dim = dsrn_state_dim
        self.dsrn_injection_stride = dsrn_injection_stride
        self.dsrn_use_triton = dsrn_use_triton
        self.gate_bias_init = gate_bias_init

        # Backbone KV-cache switch; independent of use_cache.
        self.use_kv_cache = use_kv_cache

        # Point the Auto* classes at this repo's files so the config and
        # models can be loaded with trust_remote_code=True.
        self.auto_map = {
            "AutoConfig": "configuration_hybrid.HybridEchoConfig",
            "AutoModel": "modeling_hybrid.HybridEchoModel",
            "AutoModelForCausalLM": "modeling_hybrid.HybridEchoForCausalLM",
        }
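

# ---------------------------------------------------------------------------
# Example sketch: where dsrn_injection_stride places the injectors.
# Hypothetical helper, not used by modeling_hybrid.py. It assumes "after every
# N transformer layers" means after layers N, 2N, 3N, ... (1-based) and returns
# 0-based layer indices; both conventions are assumptions for illustration.
# ---------------------------------------------------------------------------
from typing import List


def injector_layer_indices(num_hidden_layers: int, stride: int) -> List[int]:
    """Return 0-based indices of the decoder layers followed by an injector."""
    if stride <= 0:
        raise ValueError("dsrn_injection_stride must be a positive integer")
    # 1-based layers stride, 2*stride, ... converted to 0-based indices.
    return [layer - 1 for layer in range(stride, num_hidden_layers + 1, stride)]


# Usage (Qwen2-0.5B: 24 layers, default stride 4):
#   injector_layer_indices(24, 4)  ->  [3, 7, 11, 15, 19, 23], i.e. 6 injectors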
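

# ---------------------------------------------------------------------------
# Example sketch: honouring dsrn_use_triton with a graceful fallback.
# Hypothetical helper; the backend names "triton" and "reference" are
# placeholders, not names used by echo_hf/triton_scan.py. It only checks that
# the `triton` package is importable; checking for a CUDA/ROCm device is left
# to the caller.
# ---------------------------------------------------------------------------
import importlib.util


def select_scan_backend(dsrn_use_triton: bool) -> str:
    """Return which scan backend to use, falling back when Triton is missing."""
    if dsrn_use_triton and importlib.util.find_spec("triton") is not None:
        return "triton"
    # Default path: flag disabled, or Triton not installed.
    return "reference"


# Usage:
#   select_scan_backend(False)  ->  "reference"
#   select_scan_backend(True)   ->  "triton" only if the package is installed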
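

# ---------------------------------------------------------------------------
# Example sketch: why a positive gate_bias_init keeps the gates "open" at init.
# Assumption for illustration: each injector's memory gate is
# sigmoid(linear_gate(x)) and the gate weights start near zero, so the initial
# gate value is roughly sigmoid(gate_bias_init). Hypothetical helper; the real
# gating lives in the modeling code.
# ---------------------------------------------------------------------------
import math


def initial_gate_opening(gate_bias_init: float) -> float:
    """Approximate gate value at init, assuming a sigmoid gate with ~0 weights."""
    return 1.0 / (1.0 + math.exp(-gate_bias_init))


# Usage:
#   initial_gate_opening(0.0)  ->  0.5
#   initial_gate_opening(1.0)  ->  ~0.731  (default)
#   initial_gate_opening(2.0)  ->  ~0.881  (suggested if c_t norms stay near 0)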
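

# ---------------------------------------------------------------------------
# Example sketch: the "Attention Tax" of use_kv_cache=False, in token counts.
# Hypothetical helper: for a prompt of `prompt_len` tokens and `new_tokens`
# generated tokens, it lists how many tokens each forward pass feeds to the
# backbone. It counts sequence length only; it does not model attention FLOPs,
# DSRN cost, or VRAM.
# ---------------------------------------------------------------------------
from typing import List


def tokens_fed_per_step(
    prompt_len: int, new_tokens: int, use_kv_cache: bool
) -> List[int]:
    """Tokens fed to the backbone at each decoding step (one step per new token)."""
    steps = []
    for step in range(new_tokens):
        if use_kv_cache:
            # Mode 2: the prompt is processed once; afterwards only the newest
            # token is fed and cached keys/values cover the rest.
            steps.append(prompt_len if step == 0 else 1)
        else:
            # Mode 1: no backbone cache, so the whole growing context is re-fed.
            steps.append(prompt_len + step)
    return steps


# Usage (8-token prompt, 4 generated tokens):
#   tokens_fed_per_step(8, 4, use_kv_cache=True)   ->  [8, 1, 1, 1]
#   tokens_fed_per_step(8, 4, use_kv_cache=False)  ->  [8, 9, 10, 11]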