mrs83 committed on
Commit
cb9201c
·
verified ·
1 Parent(s): f9e3cdb

Upload 3 files

Browse files
Files changed (3) hide show
  1. configuration_echo.py +64 -0
  2. modeling_echo.py +980 -0
  3. triton_scan.py +521 -0
configuration_echo.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class EchoConfig(PretrainedConfig):
    """Configuration object for the Echo model.

    Besides the usual size hyper-parameters it records the Echo-specific
    switches (``gate_bias_init``, ``use_hybrid_attention``, ``use_rmsnorm``)
    and attaches the ``auto_map`` and tensor/pipeline parallel plans as
    instance attributes.
    """

    model_type = "echo"

    def __init__(
        self,
        vocab_size=49152,
        embed_dim=768,
        num_layers=4,
        num_heads=4,
        mlp_ratio=4,
        gate_bias_init=0.0,
        use_hybrid_attention=True,
        use_rmsnorm=True,
        **kwargs,
    ):
        # ``hidden_size`` may be supplied through kwargs. When it disagrees
        # with ``embed_dim`` the larger of the two wins and both are set to it.
        hidden_size = kwargs.pop("hidden_size", embed_dim)
        if hidden_size != embed_dim:
            embed_dim = hidden_size = max(embed_dim, hidden_size)

        # Model dimensions.
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        # Echo-specific switches.
        self.gate_bias_init = gate_bias_init
        self.use_hybrid_attention = use_hybrid_attention
        self.use_rmsnorm = use_rmsnorm

        # Aliases under the attribute names used by standard HF configs.
        self.num_hidden_layers = num_layers
        self.num_attention_heads = num_heads

        # Maps the Auto* entry points to the classes in this repository.
        self.auto_map = {
            "AutoConfig": "configuration_echo.EchoConfig",
            "AutoModel": "modeling_echo.EchoModel",
            "AutoModelForCausalLM": "modeling_echo.EchoForCausalLM",
        }

        # Tensor-parallel sharding style per parameter pattern.
        self.base_model_tp_plan = {
            "model.embedding": "rowwise",
            "lm_head": "colwise",
            "model.blocks.*.attn.qkv_proj": "colwise",
            "model.blocks.*.attn.out_proj": "rowwise",
            "model.blocks.*.mlp_up": "colwise",
            "model.blocks.*.mlp_down": "rowwise",
            "model.blocks.*.linear_gate": "colwise",
            "model.blocks.*.linear_memory": "colwise",
            "model.blocks.*.linear_read": "rowwise",
        }

        # Pipeline-parallel plan: (input names, output names) for ``blocks``.
        self.base_model_pp_plan = {
            "blocks": (["x", "state_prev"], ["x", "h_new_full"]),
        }

        # Remaining kwargs are handled by the base class.
        super().__init__(**kwargs)
modeling_echo.py ADDED
@@ -0,0 +1,980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import GenerationMixin, PreTrainedModel
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+
9
+ from .configuration_echo import EchoConfig
10
+
11
+ try:
12
+ from vllm.model_executor.models.transformers import ALL_ATTENTION_FUNCTIONS
13
+ except ImportError:
14
+ ALL_ATTENTION_FUNCTIONS = {}
15
+
16
+ try:
17
+ from transformers.cache_utils import Cache
18
+ except ImportError:
19
+
20
+ class Cache:
21
+ pass
22
+
23
+
24
class EchoCache(Cache):
    """Cache container holding one state tuple per layer.

    Each entry of ``states`` is either ``(h, c)`` or ``(h, c, k_attn, v_attn)``.
    The class exists so that the attention key/value tensors of the 4-tuple
    are kept together with the recurrent state.

    ``layers`` is exposed as a property over ``states`` so the two names can
    never refer to different lists, even when callers rebind ``states``.
    """

    def __init__(self, states=None):
        # Note: the base-class initialiser is intentionally not invoked here,
        # matching the original implementation.
        self.states = states if states is not None else []

    @property
    def layers(self):
        """Alias of ``states`` under the attribute name HF code reads."""
        return self.states

    @layers.setter
    def layers(self, value):
        self.states = value

    @property
    def is_compileable(self):
        return False

    def get_seq_length(self, layer_idx=0):
        """Return the cached attention length for ``layer_idx``.

        The length is read from dim 2 of the cached keys; layers that carry
        only ``(h, c)`` (or are missing) report 0.
        """
        if not self.states or len(self.states) <= layer_idx:
            return 0
        state = self.states[layer_idx]
        if len(state) == 4:
            return state[2].shape[2]
        return 0

    def get_max_length(self):
        return None

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Protocol shim: return the stored (k, v) for the layer if present.

        Nothing is written; the blocks return updated state tuples themselves.
        When the layer has no 4-tuple state the inputs are returned unchanged.
        """
        if len(self.states) > layer_idx:
            state = self.states[layer_idx]
            if len(state) == 4:
                return state[2], state[3]
        return key_states, value_states

    def get_usable_length(self, new_seq_length, layer_idx=0):
        return self.get_seq_length(layer_idx)

    def __getitem__(self, idx):
        return self.states[idx]

    def __len__(self):
        return len(self.states)

    def __iter__(self):
        return iter(self.states)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Re-index every tensor of every layer along dim 0 by ``beam_idx``."""
        self.states = [
            tuple(tensor.index_select(0, beam_idx.to(tensor.device)) for tensor in layer_state)
            for layer_state in self.states
        ]
85
+
86
+
87
+ # --- STANDALONE KERNELS (AUTOMAGICALLY INLINED) ---
88
+ def _sequential_scan(a, b, h):
89
+ """
90
+ Core sequential scan for a batch of sequences.
91
+ Vectorized across all dimensions except time.
92
+ """
93
+ a.shape[:-1]
94
+ a.shape[-1]
95
+ # a, b: (..., T, D)
96
+ # h: (..., D)
97
+ T = a.shape[-2]
98
+
99
+ res = torch.empty_like(b)
100
+ curr_h = h
101
+ for t in range(T):
102
+ curr_h = a[..., t, :] * curr_h + b[..., t, :]
103
+ res[..., t, :] = curr_h
104
+ return res, curr_h
105
+
106
+
107
def dsrn_parallel_scan(g_t, m_t, c_0=None, chunk_size=32, use_triton=False):
    """Compute all states of ``c_t = (1 - g_t) * c_{t-1} + g_t * m_t``.

    The PyTorch path is a chunked scan: the sequence is split into chunks of
    ``chunk_size`` steps, every chunk is scanned from a zero state in parallel
    (``chunk_size`` loop iterations), the per-chunk summaries are scanned
    sequentially (``num_chunks`` iterations), and the two are combined.
    Arithmetic is done in float32 and the result is cast back to ``g_t.dtype``.

    Args:
        g_t: gate values, shape ``(B, T, D)``.
        m_t: candidate values, same shape as ``g_t``.
        c_0: optional initial state ``(B, D)``; zeros when omitted.
        chunk_size: number of time steps per chunk.
        use_triton: when true and ``g_t`` is on CUDA, try the Triton kernel
            from ``triton_scan`` first and fall back to PyTorch (with a
            warning) if it cannot be imported.

    Returns:
        Tensor of shape ``(B, T, D)`` holding ``c_1 .. c_T``.
    """
    if use_triton and g_t.is_cuda:
        try:
            from .triton_scan import triton_dsrn_parallel_scan

            return triton_dsrn_parallel_scan(g_t, m_t, c_0)
        except ImportError:
            import warnings

            warnings.warn("Triton scan unavailable. Falling back to PyTorch scan.", UserWarning)

    orig_dtype = g_t.dtype
    a = (1.0 - g_t).float()
    b = (g_t * m_t).float()

    B, T, D = a.shape
    device = a.device

    # Nothing to scan: the chunked reshape below cannot infer sizes for an
    # empty time axis, so return the empty result directly.
    if T == 0:
        return torch.empty(B, 0, D, device=device, dtype=orig_dtype)

    # Pad T up to a multiple of chunk_size. a=1, b=0 carries the state through
    # the padded steps unchanged, so padding never alters real positions.
    pad_len = (chunk_size - (T % chunk_size)) % chunk_size
    if pad_len > 0:
        a = F.pad(a, (0, 0, 0, pad_len), value=1.0)
        b = F.pad(b, (0, 0, 0, pad_len), value=0.0)

    new_T = T + pad_len
    num_chunks = new_T // chunk_size

    # 1. (B, num_chunks, chunk_size, D). reshape (not view) so inputs whose
    #    memory layout is not contiguous are accepted as well.
    a_chunks = a.reshape(B, num_chunks, chunk_size, D)
    b_chunks = b.reshape(B, num_chunks, chunk_size, D)

    # 2. Local scan inside every chunk, each starting from a zero state.
    h_init_local = torch.zeros(B, num_chunks, D, device=device, dtype=torch.float32)
    c_res, c_final = _sequential_scan(a_chunks, b_chunks, h_init_local)

    # Total decay across each chunk: product of a over the chunk.
    a_final = torch.prod(a_chunks, dim=2)  # (B, num_chunks, D)

    # 3. Global scan across chunk summaries.
    h_0 = c_0.float() if c_0 is not None else torch.zeros(B, D, device=device, dtype=torch.float32)

    # h_chunk_outputs[:, j] is the state AFTER chunk j, so the state BEFORE
    # chunk j is h_chunk_outputs[:, j - 1] (h_0 for the first chunk).
    h_chunk_outputs, _ = _sequential_scan(a_final, c_final, h_0)
    h_starts = torch.cat([h_0.unsqueeze(1), h_chunk_outputs[:, :-1]], dim=1)

    # 4. Combine: h_{j, i} = a_prefix_{j, i} * h_starts[j] + c_res[j, i]
    a_prefix = torch.cumprod(a_chunks, dim=2)
    final_h = a_prefix * h_starts.unsqueeze(2) + c_res

    # Flatten chunks, drop padding, restore the caller's dtype.
    return final_h.reshape(B, new_T, D)[:, :T].to(orig_dtype)
167
+
168
+
169
def rms_norm_fn(hidden_states, weight, eps=1e-6):
    """Functional RMS normalisation over the last dimension.

    The input is upcast to float32, divided by the root of its mean square
    (plus ``eps``), cast back to the input dtype and scaled by ``weight``.
    """
    original_dtype = hidden_states.dtype
    as_fp32 = hidden_states.contiguous().to(torch.float32)
    mean_square = (as_fp32 * as_fp32).mean(-1, keepdim=True)
    normalized = as_fp32 * torch.rsqrt(mean_square + eps)
    return weight * normalized.to(original_dtype)
175
+
176
+
177
def dsrn_parallel_kernel_legacy(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """LayerNorm variant of the DSRN block computation.

    Pipeline: LayerNorm -> fast-state scan (``h``) -> surprise-modulated
    slow-state scan (``c``) -> LayerNorm + MLP residual -> unconditional read
    of the slow state through ``linear_read``.

    Args:
        model_block: module providing ``norm_fast``, ``gru_cell``,
            ``linear_pred``, ``surprise_lambda``, ``linear_gate``,
            ``linear_memory``, ``norm_ff``, ``mlp_up``/``mlp_act``/``mlp_down``
            and ``linear_read``.
        x: input of shape ``(B, T, D)``.
        h_prev: fast state entering the sequence.
        c_prev: slow state entering the sequence.
        eos_mask: optional ``(B, T)`` mask; non-zero marks an EOS position.

    Returns:
        ``(x_out, h_new, c_new, gate_stats)``: block output, the states after
        the last step, and the slow gate averaged over its last dimension.
    """
    _, _, width = x.shape
    triton_flag = getattr(model_block, "use_triton", False)

    # 1. Norm and projections.
    normed = F.layer_norm(
        x,
        (width,),
        weight=model_block.norm_fast.weight,
        bias=model_block.norm_fast.bias,
    )

    # Fast-state path: only the input-side GRU weights are used; the first
    # ``width`` columns give the update gate, the last ``width`` the candidate.
    cell = model_block.gru_cell
    gru_proj = F.linear(normed, cell.weight_ih, cell.bias_ih)
    z_all = torch.sigmoid(gru_proj[:, :, :width])
    r_all = torch.tanh(gru_proj[:, :, 2 * width :])

    # Positions that FOLLOW an EOS inside this chunk. Index 0 is cleared: a
    # reset there depends on the previous chunk and is expressed through the
    # zeroed h_prev/c_prev passed in by the caller.
    after_eos = None
    if eos_mask is not None:
        rolled = torch.roll(eos_mask, shifts=1, dims=1)
        rolled[:, 0] = 0
        after_eos = rolled.unsqueeze(-1) > 0
        # z = 1 makes h_t = r_t, i.e. the previous fast state is discarded.
        z_all = torch.where(after_eos, torch.ones_like(z_all), z_all)

    # h_t = (1 - z_t) * h_{t-1} + z_t * r_t
    h_all = dsrn_parallel_scan(z_all, r_all, h_prev, use_triton=triton_flag)
    h_new = h_all[:, -1]

    # 2. Slow-state path.
    # Causal shift: x[t] is predicted from h[t-1], so prepend h_prev and drop
    # the final state to obtain [h_0, ..., h_{T-1}].
    h_shifted = torch.cat([h_prev.unsqueeze(1), h_all[:, :-1, :]], dim=1)

    prediction_gap = x - model_block.linear_pred(h_shifted)
    error = torch.clamp(prediction_gap * prediction_gap, max=10.0).mean(dim=-1, keepdim=True)
    # softplus keeps the scale positive, so a larger error can only raise the
    # gate logits.
    surprise_signal = error * torch.nn.functional.softplus(model_block.surprise_lambda)

    g_all = torch.sigmoid(model_block.linear_gate(h_all) + surprise_signal)
    m_all = torch.tanh(model_block.linear_memory(h_all))

    if after_eos is not None:
        # NOTE(review): g = 0 gives c_t = c_{t-1}, i.e. the slow state is held
        # rather than cleared at these positions — confirm this is intended.
        g_all = torch.where(after_eos, torch.zeros_like(g_all), g_all)

    c_all = dsrn_parallel_scan(g_all, m_all, c_prev, use_triton=triton_flag)
    c_new = c_all[:, -1]

    # Inter-chunk reset: when the LAST token is EOS, the states handed to the
    # next chunk are zeroed.
    if eos_mask is not None:
        keep = (1.0 - eos_mask[:, -1].float()).unsqueeze(-1)  # (B, 1)
        h_new = h_new * keep
        c_new = c_new * keep
    gate_stats = g_all.mean(dim=-1)

    # 3. Feed-forward path on the fast states, added to the block input.
    h_norm = F.layer_norm(
        h_all, (width,), weight=model_block.norm_ff.weight, bias=model_block.norm_ff.bias
    )
    x_out = x + model_block.mlp_down(model_block.mlp_act(model_block.mlp_up(h_norm)))

    # Continuous read of the slow state (always applied on this path).
    x_out = x_out + model_block.linear_read(c_all)

    return x_out, h_new, c_new, gate_stats
271
+
272
+
273
def dsrn_parallel_kernel_hybrid(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """RMSNorm variant of the DSRN block computation.

    Same pipeline as the LayerNorm kernel, except that both norms use
    ``rms_norm_fn`` with the block's norm weights and the slow-state read is
    applied only when ``model_block.use_hybrid_attention`` is true.

    Args:
        model_block: module providing the block's parameters and flags.
        x: input of shape ``(B, T, D)``.
        h_prev: fast state entering the sequence.
        c_prev: slow state entering the sequence.
        eos_mask: optional ``(B, T)`` mask; non-zero marks an EOS position.

    Returns:
        ``(x_out, h_new, c_new, gate_stats)``.
    """
    _, _, width = x.shape
    triton_flag = getattr(model_block, "use_triton", False)

    # 1. RMSNorm, then the input-side GRU projection.
    normed = rms_norm_fn(x, model_block.norm_fast.weight)
    cell = model_block.gru_cell
    gru_proj = F.linear(normed, cell.weight_ih, cell.bias_ih)
    z_all = torch.sigmoid(gru_proj[:, :, :width])
    r_all = torch.tanh(gru_proj[:, :, 2 * width :])

    # Positions following an EOS within the chunk (index 0 never resets here).
    after_eos = None
    if eos_mask is not None:
        rolled = torch.roll(eos_mask, shifts=1, dims=1)
        rolled[:, 0] = 0
        after_eos = rolled.unsqueeze(-1) > 0
        # z = 1 discards the previous fast state at those positions.
        z_all = torch.where(after_eos, torch.ones_like(z_all), z_all)

    h_all = dsrn_parallel_scan(z_all, r_all, h_prev, use_triton=triton_flag)
    h_new = h_all[:, -1]

    # 2. Slow state. Causal shift: predict x[t] from h[t-1].
    h_shifted = torch.cat([h_prev.unsqueeze(1), h_all[:, :-1, :]], dim=1)

    prediction_gap = x - model_block.linear_pred(h_shifted)
    error = torch.clamp(prediction_gap * prediction_gap, max=10.0).mean(dim=-1, keepdim=True)
    # softplus keeps the surprise scale positive.
    surprise_signal = error * torch.nn.functional.softplus(model_block.surprise_lambda)

    g_all = torch.sigmoid(model_block.linear_gate(h_all) + surprise_signal)
    m_all = torch.tanh(model_block.linear_memory(h_all))

    if after_eos is not None:
        # NOTE(review): g = 0 holds c_{t-1} instead of clearing it — confirm.
        g_all = torch.where(after_eos, torch.zeros_like(g_all), g_all)

    c_all = dsrn_parallel_scan(g_all, m_all, c_prev, use_triton=triton_flag)
    c_new = c_all[:, -1]

    # Inter-chunk reset: zero the outgoing states when the last token is EOS.
    if eos_mask is not None:
        keep = (1.0 - eos_mask[:, -1].float()).unsqueeze(-1)
        h_new = h_new * keep
        c_new = c_new * keep
    gate_stats = g_all.mean(dim=-1)

    # 3. Feed-forward on the fast states, added to the block input.
    h_norm = rms_norm_fn(h_all, model_block.norm_ff.weight)
    x_out = x + model_block.mlp_down(model_block.mlp_act(model_block.mlp_up(h_norm)))

    # Continuous read of the slow state, gated by the hybrid flag.
    if model_block.use_hybrid_attention:
        x_out = x_out + model_block.linear_read(c_all)

    return x_out, h_new, c_new, gate_stats
347
+
348
+
349
def dsrn_parallel_kernel(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatch to the RMSNorm or LayerNorm kernel.

    The RMSNorm (hybrid) kernel is chosen when ``model_block.use_rmsnorm`` is
    truthy; a block without that attribute uses the LayerNorm (legacy) kernel.
    """
    wants_rmsnorm = getattr(model_block, "use_rmsnorm", False)
    kernel = dsrn_parallel_kernel_hybrid if wants_rmsnorm else dsrn_parallel_kernel_legacy
    return kernel(model_block, x, h_prev, c_prev, eos_mask=eos_mask)
362
+
363
+
364
class HymbaRMSNorm(nn.Module):
    """RMS normalisation layer with a learned per-feature scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """Create the scale (initialised to ones) and store ``eps``.

        HymbaRMSNorm is equivalent to T5LayerNorm.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise over the last dimension in float32, then rescale."""
        caller_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(caller_dtype)
379
+
380
+
381
class EchoRotaryEmbedding(nn.Module):
    """Lazily built rotary position (cos, sin) lookup tables.

    The tables are plain attributes rather than registered buffers (the
    original author observed buffers being corrupted during weight loading),
    and are (re)built on the first forward pass, when a longer sequence is
    requested, or when the input lives on a different device.
    """

    def __init__(self, dim, max_position_embeddings=4096, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.device = device

        # Defined up front so the attribute exists before the first build.
        self.max_seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build tables covering positions ``0 .. seq_len - 1`` in ``dtype``."""
        self.max_seq_len_cached = seq_len
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
        )
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate the frequencies so the table spans the full ``dim``.
        emb = torch.cat((freqs, freqs), dim=-1)

        self._cos_cached = emb.cos().to(dtype)
        self._sin_cached = emb.sin().to(dtype)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``(seq_len, dim)`` in ``x.dtype``.

        Args:
            x: tensor used only for its device and dtype (and, when
                ``seq_len`` is omitted, its second-to-last dimension).
            seq_len: number of positions required. Defaults to ``x.shape[-2]``
                — assumes x is laid out as (..., T, head_dim).
        """
        if seq_len is None:
            seq_len = x.shape[-2]

        if (
            self._cos_cached is None
            or seq_len > self.max_seq_len_cached
            or self._cos_cached.device != x.device
        ):
            # Cache in float32 and cast per call: caching in the first caller's
            # dtype would hand rounded values to a later higher-precision caller.
            self._set_cos_sin_cache(
                seq_len=max(seq_len, self.max_position_embeddings),
                device=x.device,
                dtype=torch.float32,
            )

        return (
            self._cos_cached[:seq_len].to(dtype=x.dtype),
            self._sin_cached[:seq_len].to(dtype=x.dtype),
        )
423
+
424
+
425
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
430
+
431
+
432
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to queries and keys.

    ``cos``/``sin`` are indexed by ``position_ids`` and given an extra axis at
    ``unsqueeze_dim`` so they broadcast against ``q`` and ``k``.

    Returns:
        ``(q_embed, k_embed)`` with the shapes of ``q`` and ``k``.
    """
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)  # (B, 1, T, D)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)  # (B, 1, T, D)

    def _rotate(t):
        return (t * cos_at_pos) + (rotate_half(t) * sin_at_pos)

    return _rotate(q), _rotate(k)
438
+
439
+
440
class SlidingWindowAttention(nn.Module):
    """Multi-head self-attention restricted to a sliding causal window.

    Keys and values for the full history are returned to the caller for
    caching; the window only limits which of them each query may attend to.
    """

    def __init__(self, config: EchoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.window_size = getattr(config, "window_size", 128)

        self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary_emb = EchoRotaryEmbedding(
            self.head_dim,
            base=getattr(config, "rope_theta", 10000.0),
        )

    def forward(
        self,
        x,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Attend over the cached history plus the new tokens.

        Args:
            x: input of shape ``(B, T, C)``.
            past_key_values: optional ``(k_past, v_past)``, each
                ``(B, num_heads, T_past, head_dim)``.
            position_ids: optional positions of the new tokens; when omitted
                they continue on from the cached length.
            **kwargs: ``attn_implementation`` selects an entry of
                ``ALL_ATTENTION_FUNCTIONS``; unknown names fall back to
                ``F.scaled_dot_product_attention``.

        Returns:
            ``(output, (k, v))`` where ``k``/``v`` contain the full history.
        """
        B, T, C = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        def _to_heads(t):
            # (B, T, C) -> (B, num_heads, T, head_dim)
            return t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = _to_heads(q), _to_heads(k), _to_heads(v)

        past_len = 0 if past_key_values is None else past_key_values[0].shape[2]

        # --- RoPE ---
        if position_ids is None:
            # Fallback: the new tokens occupy positions past_len .. past_len+T-1.
            position_ids = (
                torch.arange(past_len, past_len + T, dtype=torch.long, device=x.device)
                .unsqueeze(0)
                .view(-1, T)
            )

        cos, sin = self.rotary_emb(v, seq_len=T + past_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        if past_key_values is not None:
            k_past, v_past = past_key_values
            k = torch.cat([k_past, k], dim=2)
            v = torch.cat([v_past, v], dim=2)

        # The cache keeps the FULL history; window truncation below applies to
        # the attention computation only.
        current_key_value = (k, v)

        # NOTE(review): entries of ALL_ATTENTION_FUNCTIONS are called with the
        # F.scaled_dot_product_attention signature — confirm they accept it.
        attn_fn = ALL_ATTENTION_FUNCTIONS.get(
            kwargs.get("attn_implementation", "sdpa"), F.scaled_dot_product_attention
        )

        if T > 1:
            # Training / prefill: attend over all keys with a band-limited
            # causal mask. Query row r sits at absolute position r + offset.
            total_len = k.shape[2]
            offset = total_len - T
            query_pos = torch.arange(T, device=x.device).view(-1, 1) + offset
            key_pos = torch.arange(total_len, device=x.device).view(1, -1)

            # Blocked: future keys, and keys window_size or more steps back.
            blocked = key_pos > query_pos
            if self.window_size is not None:
                blocked = blocked | ((query_pos - key_pos) >= self.window_size)

            # Additive float mask: 0 where allowed, -inf where blocked.
            mask = torch.zeros((T, total_len), device=x.device, dtype=x.dtype)
            mask = mask.masked_fill(blocked, float("-inf"))

            y = attn_fn(q, k, v, attn_mask=mask[None, None])
        else:
            # Decoding: the single query may see every key it is given, so
            # only the most recent window_size keys are passed and no causal
            # mask is needed.
            k_attn, v_attn = k, v
            if self.window_size is not None and k.shape[2] > self.window_size:
                k_attn = k[:, :, -self.window_size :, :]
                v_attn = v[:, :, -self.window_size :, :]
            y = attn_fn(q, k_attn, v_attn, is_causal=False)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y), current_key_value
550
+
551
+
552
class DSRNBlock(nn.Module):
    """One Echo layer: fast/slow recurrent states plus optional attention.

    The recurrent computation is delegated to ``dsrn_parallel_kernel``; when
    ``use_hybrid_attention`` is set, a ``SlidingWindowAttention`` branch is
    added on top and its key/value cache is appended to the returned state.
    """

    def __init__(self, config: EchoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.state_size = config.hidden_size * config.num_heads
        self.use_triton = getattr(config, "use_triton", True)
        self.use_hybrid_attention = getattr(config, "use_hybrid_attention", True)
        self.use_rmsnorm = getattr(config, "use_rmsnorm", True)

        # Fast state (GRU input projection)
        if self.use_rmsnorm:
            self.norm_fast = HymbaRMSNorm(config.hidden_size)
        else:
            self.norm_fast = nn.LayerNorm(config.hidden_size)

        self.gru_cell = nn.GRUCell(config.hidden_size, config.hidden_size)

        # Hybrid attention
        if self.use_hybrid_attention:
            self.attn = SlidingWindowAttention(config)

        # Slow state (DSRN)
        self.linear_read = nn.Linear(self.state_size, config.hidden_size, bias=False)
        self.linear_gate = nn.Linear(config.hidden_size, self.state_size)
        self.linear_memory = nn.Linear(config.hidden_size, self.state_size)

        # Surprise mechanism
        self.linear_pred = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.surprise_lambda = nn.Parameter(torch.zeros(self.state_size))

        # Feed-forward
        if self.use_rmsnorm:
            self.norm_ff = HymbaRMSNorm(config.hidden_size)
        else:
            self.norm_ff = nn.LayerNorm(config.hidden_size)

        # Simple MLP: Linear -> GELU -> Linear.
        # mlp_up / mlp_act / mlp_down are the ONLY registered submodules.
        # No self.mlp alias - that caused double-registration and spurious
        # "missing keys".
        intermediate_size = getattr(
            config, "intermediate_size", int(config.hidden_size * getattr(config, "mlp_ratio", 4.0))
        )
        self.mlp_up = nn.Linear(config.hidden_size, intermediate_size)
        self.mlp_act = nn.GELU()
        self.mlp_down = nn.Linear(intermediate_size, config.hidden_size)

    def forward(
        self, x: torch.Tensor, state_prev: Tuple[torch.Tensor, ...], **kwargs
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
        """Run the block on ``x`` starting from ``state_prev``.

        Args:
            x: block input.
            state_prev: ``(h, c)`` or ``(h, c, k_attn, v_attn)``.
            **kwargs: forwarded to the attention branch.

        Returns:
            ``(x_out, new_state, gate_stats)``. ``new_state`` is
            ``(h, c, k, v)`` when the attention branch ran, else ``(h, c)``.
        """
        h_prev, c_prev = state_prev[0], state_prev[1]

        # Triton selection happens inside the scan, driven by self.use_triton.
        x_out, h_new, c_new, gate_stats = dsrn_parallel_kernel(self, x, h_prev, c_prev)

        if not self.use_hybrid_attention:
            return x_out, (h_new, c_new), gate_stats

        # The attention branch consumes its own normalised copy of the input.
        x_norm = self.norm_fast(x)

        # A 2-tuple state carries no attention cache yet (first call).
        attn_kv = (state_prev[2], state_prev[3]) if len(state_prev) == 4 else None

        attn_out, new_attn_kv = self.attn(x_norm, past_key_values=attn_kv, **kwargs)
        x_out = x_out + attn_out

        if new_attn_kv is None:
            new_state = (h_new, c_new)
        else:
            new_state = (h_new, c_new, new_attn_kv[0], new_attn_kv[1])

        return x_out, new_state, gate_stats
640
+
641
+
642
class EchoPreTrainedModel(PreTrainedModel):
    """Shared base class: config binding, load-time key filters, weight init."""

    config_class = EchoConfig
    base_model_prefix = "model"
    _no_split_modules = ["DSRNBlock"]

    # Older local training checkpoints may still contain mlp.0.* / mlp.1.* /
    # mlp.2.* alias keys from before the self.mlp alias was removed. They are
    # ignored on load; the canonical names are mlp_up.* / mlp_act.* / mlp_down.*.
    _keys_to_ignore_on_load_unexpected = [
        r".*\.mlp\.0\..*",
        r".*\.mlp\.1\..*",
        r".*\.mlp\.2\..*",
    ]

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
            return

        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return

        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
667
+
668
+ class EchoModel(EchoPreTrainedModel):
669
+ supports_gradient_checkpointing = True
670
+ _supports_attention_backend = True
671
+
672
+ def __init__(self, config: EchoConfig):
673
+ super().__init__(config)
674
+ self.embed_dim = config.embed_dim
675
+ self.num_layers = config.num_layers
676
+ self.num_heads = config.num_heads
677
+ self.state_dim = config.embed_dim * config.num_heads
678
+
679
+ self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
680
+ self.blocks = nn.ModuleList([DSRNBlock(config) for _ in range(config.num_layers)])
681
+
682
+ if getattr(config, "use_rmsnorm", False):
683
+ self.final_norm = HymbaRMSNorm(config.hidden_size)
684
+ else:
685
+ self.final_norm = nn.LayerNorm(config.hidden_size)
686
+
687
+ self.gradient_checkpointing = False
688
+
689
+ self.post_init()
690
+
691
+ # --- ZOMBIE GRADIENT PATCH (FIXED) ---
692
+ # Fixed: Now using controlled bias defaults to 1.0 to encourage open gates initially
693
+ bias_val = getattr(config, "gate_bias_init", 1.0)
694
+ for block in self.blocks:
695
+ nn.init.constant_(block.linear_gate.bias, bias_val)
696
+ # Init Surprise
697
+ if (
698
+ block.linear_pred.weight.dtype in (torch.bfloat16, torch.float16)
699
+ and block.linear_pred.weight.is_cuda
700
+ ):
701
+ _device = block.linear_pred.weight.device
702
+ _dtype = block.linear_pred.weight.dtype
703
+ temp_w = torch.empty_like(
704
+ block.linear_pred.weight, dtype=torch.float32, device="cpu"
705
+ )
706
+ nn.init.orthogonal_(temp_w, gain=0.1)
707
+ with torch.no_grad():
708
+ block.linear_pred.weight.copy_(temp_w.to(device=_device, dtype=_dtype))
709
+ else:
710
+ nn.init.orthogonal_(block.linear_pred.weight, gain=0.1)
711
+
712
+ nn.init.zeros_(block.surprise_lambda)
713
+ # CRITICAL: Zero-Init Residual Output (Identity Start)
714
+ nn.init.zeros_(block.mlp_down.weight)
715
+ nn.init.zeros_(block.mlp_down.bias)
716
+
717
+ def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
718
+ """Enable/disable gradient checkpointing."""
719
+ self.gradient_checkpointing = enable
720
+
721
+ def get_input_embeddings(self):
722
+ return self.embedding
723
+
724
+ def set_input_embeddings(self, value):
725
+ self.embedding = value
726
+
727
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    output_dsrn_telemetry: Optional[bool] = False,
    **kwargs,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
    """Run every block over the input, threading per-layer recurrent state.

    Args:
        input_ids: 2-D token ids. Mutually exclusive with ``inputs_embeds``.
        past_key_values: Per-layer states from a previous call. ``None``, an
            empty list/tuple, or an object whose ``get_seq_length()`` is 0
            all mean "start from zero states".
        inputs_embeds: 3-D pre-computed embeddings, used instead of
            ``input_ids``.
        position_ids: Accepted but not used by this method.
        output_dsrn_telemetry: When truthy, also collect ``out[2]`` and
            ``out[1][1]`` from every block.
        **kwargs: Forwarded unchanged to every block.

    Returns:
        ``(x, next_states)``, or ``(x, next_states, all_c_states,
        all_gate_stats)`` when ``output_dsrn_telemetry`` is truthy. ``x`` is
        the output of ``self.final_norm``; ``next_states`` is an
        ``EchoCache`` (the incoming one is updated in place and reused when
        it already is an ``EchoCache``).

    Raises:
        ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
            are given.
    """
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        # Unpacking (rather than shape[0]) keeps the implicit rank check.
        batch_size, _ = input_ids.shape
        x = self.embedding(input_ids)
    elif inputs_embeds is not None:
        batch_size, _, _ = inputs_embeds.shape
        x = inputs_embeds
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    device = x.device

    def _zero_state():
        # Always match the activations' dtype so a fallback state does not
        # silently mix float32 into a half-precision forward pass.
        h0 = torch.zeros(batch_size, self.embed_dim, device=device, dtype=x.dtype)
        c0 = torch.zeros(batch_size, self.state_dim, device=device, dtype=x.dtype)
        return (h0, c0)

    # Treat "no cache", an empty Cache object and an empty list/tuple alike;
    # indexing an empty container below would otherwise raise IndexError.
    is_empty_cache = (
        hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0
    ) or (isinstance(past_key_values, (list, tuple)) and len(past_key_values) == 0)
    if past_key_values is None or is_empty_cache:
        past_key_values = [_zero_state() for _ in range(self.num_layers)]

    current_states = past_key_values
    next_states = []

    all_gate_stats = [] if output_dsrn_telemetry else None
    all_c_states = [] if output_dsrn_telemetry else None

    # NOTE(review): position_ids is captured by the named parameter above and
    # therefore never reaches the blocks via **kwargs — confirm that is intended.

    # Layer-major execution: each block consumes the whole sequence.
    for i, block in enumerate(self.blocks):
        state_i = current_states[i]

        # Valid states are (h, c) or (h, c, <2 attention tensors>); anything
        # else is replaced by fresh zeros.
        if state_i is None or len(state_i) not in (2, 4):
            state_i = _zero_state()

        if self.gradient_checkpointing and self.training:
            # Non-reentrant checkpointing accepts keyword arguments, so the
            # block sees the same inputs as in the non-checkpointed path.
            out = torch.utils.checkpoint.checkpoint(
                block, x, state_i, use_reentrant=False, **kwargs
            )
        else:
            out = block(x, state_i, **kwargs)

        x = out[0]
        next_states.append(out[1])

        if output_dsrn_telemetry:
            all_gate_stats.append(out[2])
            all_c_states.append(out[1][1])

    x = self.final_norm(x)

    if isinstance(current_states, EchoCache):
        # Reuse the caller's cache object so its identity is preserved.
        current_states.states = next_states
        next_states = current_states
    elif EchoCache is not None:
        next_states = EchoCache(next_states)

    if output_dsrn_telemetry:
        return x, next_states, all_c_states, all_gate_stats

    return x, next_states
814
+
815
+
816
class EchoForCausalLM(EchoPreTrainedModel, GenerationMixin):
    """``EchoModel`` backbone followed by a bias-free linear LM head."""

    _is_causal = True
    supports_gradient_checkpointing = True
    _supports_cache_class = False
    _supports_static_cache = False
    main_input_name = "input_ids"

    def __init__(self, config: EchoConfig):
        super().__init__(config)
        self.model = EchoModel(config)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Forward the gradient-checkpointing switch to the backbone."""
        self.model._set_gradient_checkpointing(enable, gradient_checkpointing_func)

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_dsrn_telemetry: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids; may be omitted when ``inputs_embeds`` is given.
            attention_mask: Accepted for API compatibility; not used here.
            position_ids: Passed to the backbone.
            past_key_values: Recurrent states from a previous call.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            labels: Targets; they are shifted by one position internally.
            use_cache: When falsy, ``past_key_values`` is ``None`` in the
                dict-style output. Defaults to ``config.use_cache`` (or True).
            output_attentions: Accepted for API compatibility; ignored.
            output_hidden_states: Accepted for API compatibility; ignored.
            return_dict: Return a ``CausalLMOutputWithPast`` instead of a tuple.
                Defaults to ``config.use_return_dict`` (or True).
            output_dsrn_telemetry: Ask the backbone for telemetry; the result
                is stored on ``self._latest_c_states`` / ``self._latest_gate_stats``.
            **kwargs: Forwarded to the backbone.

        Returns:
            ``CausalLMOutputWithPast`` or, when ``return_dict`` is falsy,
            ``(logits, new_states)`` prefixed by ``loss`` if it was computed.
        """
        use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", True)
        return_dict = (
            return_dict
            if return_dict is not None
            else getattr(self.config, "use_return_dict", True)
        )

        # position_ids is a named parameter, so it can never collide with a
        # key inside **kwargs; pass it explicitly.
        model_out = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            output_dsrn_telemetry=output_dsrn_telemetry,
            **kwargs,
        )

        hidden_states = model_out[0]
        new_states = model_out[1]

        # The backbone returns two extra items only when telemetry was requested.
        if len(model_out) > 2:
            self._latest_c_states = model_out[2]
            self._latest_gate_stats = model_out[3]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            # Use the real logits width: it stays correct if the head was
            # replaced/resized without updating config.vocab_size.
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (logits, new_states)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_states if use_cache else None,
            hidden_states=None,  # EchoModel doesn't expose internal states yet
            attentions=None,  # EchoModel doesn't expose attention weights yet
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        When a non-empty state is supplied only the last token of
        ``input_ids`` is kept. Extra ``kwargs`` (e.g. ``output_dsrn_telemetry``)
        are passed through unless they would overwrite a key set here.
        """
        # "Empty" covers: no cache, a Cache object reporting length 0, or an
        # empty list. Any other object is passed through as-is and left for
        # EchoModel to interpret.
        is_empty = (
            past_key_values is None
            or (
                hasattr(past_key_values, "get_seq_length")
                and past_key_values.get_seq_length() == 0
            )
            or (isinstance(past_key_values, list) and len(past_key_values) == 0)
        )

        # If past_key_values is used, we only need the last token
        if not is_empty:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": kwargs.get("use_cache"),
        }

        # Pass through extra kwargs like output_dsrn_telemetry
        model_inputs.update({k: v for k, v in kwargs.items() if k not in model_inputs})

        return model_inputs

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the batch dimension of every cached tensor.

        Args:
            past_key_values: ``None``, an iterable of per-layer tuples, or a
                wrapper exposing that list as ``.states``.
            beam_idx: Indices selected along dimension 0.

        Returns:
            ``None`` if no cache was given, else a list of per-layer tuples.
        """
        if past_key_values is None:
            return None

        # EchoModel keeps its per-layer list on `.states` of the cache wrapper;
        # read it from there so this does not depend on the wrapper being iterable.
        layers = getattr(past_key_values, "states", past_key_values)

        reordered_past = []
        for layer_past in layers:
            # Each layer_past is a tuple of tensors (h, c) or (h, c, k, v);
            # entries that are None are kept as None.
            reordered_past.append(
                tuple(
                    p.index_select(0, beam_idx.to(p.device)) if p is not None else None
                    for p in layer_past
                )
            )
        return reordered_past
triton_scan.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ # ──────────────────────────────────────────────────────────────
6
+ # FORWARD PASS KERNELS
7
+ # ──────────────────────────────────────────────────────────────
8
+
9
+
10
@triton.jit
def fwd_accumulate_kernel(
    a_ptr,
    b_ptr,
    chunk_a_ptr,
    chunk_c_ptr,
    T,
    D,
    stride_a_b,
    stride_a_t,
    stride_a_d,
    stride_b_b,
    stride_b_t,
    stride_b_d,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Forward pass 1: summarise one time-chunk of c_t = a_t * c_{t-1} + b_t.

    Each program handles one (batch, feature-tile, time-chunk) triple given by
    program ids 0, 1 and 2. It folds the chunk's steps into a pair (A, C) such
    that the state after the chunk equals ``A * state_before + C`` and writes
    the pair to ``chunk_a_ptr`` / ``chunk_c_ptr``, indexed as
    ``[batch, chunk, feature]`` with a dense row-major layout.
    """
    batch_id = tl.program_id(0)
    tile_id = tl.program_id(1)
    chunk_id = tl.program_id(2)

    feat = tile_id * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    feat_ok = feat < D

    # First time index covered by this chunk
    first_t = chunk_id * BLOCK_SIZE_T

    a_row = a_ptr + batch_id * stride_a_b + feat * stride_a_d
    b_row = b_ptr + batch_id * stride_b_b + feat * stride_b_d

    # Identity element of the affine composition: multiply by 1, add 0
    decay = tl.full((BLOCK_SIZE_D,), 1.0, dtype=tl.float32)
    offset = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)

    for i in range(BLOCK_SIZE_T):
        t = first_t + i
        # The last chunk may run past the sequence end; skip those steps.
        if t < T:
            a_t = tl.load(a_row + t * stride_a_t, mask=feat_ok, other=1.0).to(tl.float32)
            b_t = tl.load(b_row + t * stride_b_t, mask=feat_ok, other=0.0).to(tl.float32)

            # (decay, offset) o (a_t, b_t) = (a_t * decay, a_t * offset + b_t)
            offset = a_t * offset + b_t
            decay = a_t * decay

    # Write the summary to slot [batch_id, chunk_id, feat]
    n_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
    slot = batch_id * (n_chunks * D) + chunk_id * D + feat
    tl.store(chunk_a_ptr + slot, decay, mask=feat_ok)
    tl.store(chunk_c_ptr + slot, offset, mask=feat_ok)
60
+
61
+
62
@triton.jit
def fwd_global_scan_kernel(
    chunk_a_ptr,
    chunk_c_ptr,
    chunk_carries_ptr,
    c_0_ptr,
    num_chunks,
    D,
    stride_c0_b,
    stride_c0_d,
    HAS_C_0: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Forward pass 2: sequentially chain the chunk summaries.

    Each program handles one (batch, feature-tile) pair. Starting from
    ``c_0`` (or zeros when ``HAS_C_0`` is false) it walks the chunks in order,
    storing into ``chunk_carries_ptr[batch, j, feature]`` the state that
    enters chunk ``j`` and then applying that chunk's (A, C) summary.
    """
    batch_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    feat = tile_id * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    feat_ok = feat < D

    # State entering the first chunk
    state = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    if HAS_C_0:
        c0_addr = c_0_ptr + batch_id * stride_c0_b + feat * stride_c0_d
        state = tl.load(c0_addr, mask=feat_ok, other=0.0).to(tl.float32)

    # Offset of [batch_id, 0, feat] inside the summary buffers
    row = batch_id * (num_chunks * D) + feat

    for j in range(num_chunks):
        # Record the state as it is *before* chunk j is applied
        tl.store(chunk_carries_ptr + row + j * D, state, mask=feat_ok)

        a_chunk = tl.load(chunk_a_ptr + row + j * D, mask=feat_ok, other=1.0).to(tl.float32)
        c_chunk = tl.load(chunk_c_ptr + row + j * D, mask=feat_ok, other=0.0).to(tl.float32)

        # Advance across chunk j to obtain the state entering chunk j+1
        state = a_chunk * state + c_chunk
100
+
101
+
102
@triton.jit
def fwd_combine_kernel(
    a_ptr,
    b_ptr,
    chunk_carries_ptr,
    c_out_ptr,
    T,
    D,
    stride_a_b,
    stride_a_t,
    stride_a_d,
    stride_b_b,
    stride_b_t,
    stride_b_d,
    stride_c_b,
    stride_c_t,
    stride_c_d,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Forward pass 3: replay each chunk from its incoming carry and emit c_t.

    Each program handles one (batch, feature-tile, time-chunk) triple. It
    loads the carry stored for the chunk in ``chunk_carries_ptr`` and applies
    ``c_t = a_t * c_{t-1} + b_t`` step by step, writing every ``c_t`` to
    ``c_out_ptr``.
    """
    batch_id = tl.program_id(0)
    tile_id = tl.program_id(1)
    chunk_id = tl.program_id(2)

    feat = tile_id * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    feat_ok = feat < D

    n_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
    first_t = chunk_id * BLOCK_SIZE_T

    # State entering this chunk, stored at [batch_id, chunk_id, feat]
    slot = batch_id * (n_chunks * D) + chunk_id * D + feat
    state = tl.load(chunk_carries_ptr + slot, mask=feat_ok, other=0.0).to(tl.float32)

    a_row = a_ptr + batch_id * stride_a_b + feat * stride_a_d
    b_row = b_ptr + batch_id * stride_b_b + feat * stride_b_d
    out_row = c_out_ptr + batch_id * stride_c_b + feat * stride_c_d

    for i in range(BLOCK_SIZE_T):
        t = first_t + i
        # Steps past the sequence end (last chunk only) are skipped.
        if t < T:
            a_t = tl.load(a_row + t * stride_a_t, mask=feat_ok, other=1.0).to(tl.float32)
            b_t = tl.load(b_row + t * stride_b_t, mask=feat_ok, other=0.0).to(tl.float32)

            state = a_t * state + b_t
            tl.store(out_row + t * stride_c_t, state, mask=feat_ok)
148
+
149
+
150
+ # ──────────────────────────────────────────────────────────────
151
+ # BACKWARD PASS KERNELS
152
+ # ──────────────────────────────────────────────────────────────
153
+
154
+
155
@triton.jit
def bwd_accumulate_kernel(
    a_ptr,
    grad_c_out_ptr,
    chunk_a_prod_ptr,
    chunk_g_sum_ptr,
    T,
    D,
    stride_a_b,
    stride_a_t,
    stride_a_d,
    stride_g_b,
    stride_g_t,
    stride_g_d,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Backward pass 1: summarise one time-chunk of the reverse recurrence.

    With G_t = g_t + a_{t+1} * G_{t+1} (g_t being the incoming gradient of
    c_t), each program reduces its chunk [first_t, stop_t) to a pair (P, S)
    such that ``G_{first_t} = S + P * G_{stop_t}``. P is written to
    ``chunk_a_prod_ptr`` and S to ``chunk_g_sum_ptr`` at
    ``[batch, chunk, feature]``.
    """
    batch_id = tl.program_id(0)
    tile_id = tl.program_id(1)
    chunk_id = tl.program_id(2)

    feat = tile_id * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    feat_ok = feat < D

    first_t = chunk_id * BLOCK_SIZE_T
    stop_t = tl.minimum(first_t + BLOCK_SIZE_T, T)

    coeff = tl.full((BLOCK_SIZE_D,), 1.0, dtype=tl.float32)
    partial = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)

    a_row = a_ptr + batch_id * stride_a_b + feat * stride_a_d
    g_row = grad_c_out_ptr + batch_id * stride_g_b + feat * stride_g_d

    # Walk the chunk backwards, from stop_t - 1 down to first_t
    for i in range(stop_t - first_t - 1, -1, -1):
        t = first_t + i
        g_t = tl.load(g_row + t * stride_g_t, mask=feat_ok, other=0.0).to(tl.float32)

        # Weight of everything that comes after step t is a_{t+1}. There is
        # no a_T, so for t == T - 1 the weight stays 1.0.
        a_after = tl.full((BLOCK_SIZE_D,), 1.0, dtype=tl.float32)
        if t + 1 < T:
            a_after = tl.load(a_row + (t + 1) * stride_a_t, mask=feat_ok, other=1.0).to(tl.float32)

        # partial <- g_t + a_{t+1} * partial ; coeff <- a_{t+1} * coeff
        partial = g_t + a_after * partial
        coeff = a_after * coeff

    n_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
    slot = batch_id * (n_chunks * D) + chunk_id * D + feat
    tl.store(chunk_a_prod_ptr + slot, coeff, mask=feat_ok)
    tl.store(chunk_g_sum_ptr + slot, partial, mask=feat_ok)
209
+
210
+
211
@triton.jit
def bwd_global_scan_kernel(
    chunk_a_prod_ptr,
    chunk_g_sum_ptr,
    chunk_grad_carries_ptr,
    num_chunks,
    D,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward pass 2: chain the chunk summaries from the last chunk to the first.

    Each program handles one (batch, feature-tile) pair. Starting from a zero
    gradient it stores, for every chunk ``j``, the gradient arriving at the
    end of that chunk in ``chunk_grad_carries_ptr[batch, j, feature]`` and
    then applies the chunk's (P, S) summary to obtain the value for chunk
    ``j - 1``.
    """
    batch_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    feat = tile_id * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    feat_ok = feat < D

    # Nothing follows the last chunk, so the reverse carry starts at zero
    incoming = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    row = batch_id * (num_chunks * D) + feat

    for j in range(num_chunks - 1, -1, -1):
        # Gradient entering chunk j from the chunks after it
        tl.store(chunk_grad_carries_ptr + row + j * D, incoming, mask=feat_ok)

        coeff = tl.load(chunk_a_prod_ptr + row + j * D, mask=feat_ok, other=1.0).to(
            tl.float32
        )
        partial = tl.load(chunk_g_sum_ptr + row + j * D, mask=feat_ok, other=0.0).to(tl.float32)

        # G at the start of chunk j = S_j + P_j * (G at the end of chunk j)
        incoming = partial + coeff * incoming
242
+
243
+
244
@triton.jit
def bwd_combine_kernel(
    a_ptr,
    c_out_ptr,
    c_0_ptr,
    grad_c_out_ptr,
    chunk_grad_carries_ptr,
    grad_a_ptr,
    grad_b_ptr,
    grad_c_0_ptr,
    T,
    D,
    stride_a_b,
    stride_a_t,
    stride_a_d,
    stride_c_b,
    stride_c_t,
    stride_c_d,
    stride_g_b,
    stride_g_t,
    stride_g_d,
    stride_gb_b,
    stride_gb_t,
    stride_gb_d,
    stride_c0_b,
    stride_c0_d,
    HAS_C_0: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Backward pass 3: replay each chunk in reverse and emit the gradients.

    For every step t of the chunk, with G_t = g_t + a_{t+1} * G_{t+1}:
    ``grad_b[t] = G_t`` and ``grad_a[t] = c_{t-1} * G_t`` (c_{-1} being
    ``c_0`` when ``HAS_C_0`` is true, else zero). The program for chunk 0
    additionally writes ``a_0 * G_0`` to ``grad_c_0_ptr`` when ``HAS_C_0`` is
    true. ``grad_a`` is addressed with the strides of ``a``.
    """
    batch_id = tl.program_id(0)
    tile_id = tl.program_id(1)
    chunk_id = tl.program_id(2)

    feat = tile_id * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    feat_ok = feat < D

    n_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
    first_t = chunk_id * BLOCK_SIZE_T
    stop_t = tl.minimum(first_t + BLOCK_SIZE_T, T)

    # G_{stop_t}: gradient arriving at the end of this chunk (from pass 2)
    grad_after = tl.load(
        chunk_grad_carries_ptr + batch_id * (n_chunks * D) + chunk_id * D + feat,
        mask=feat_ok,
        other=0.0,
    ).to(tl.float32)

    a_row = a_ptr + batch_id * stride_a_b + feat * stride_a_d
    c_row = c_out_ptr + batch_id * stride_c_b + feat * stride_c_d
    g_row = grad_c_out_ptr + batch_id * stride_g_b + feat * stride_g_d
    ga_row = grad_a_ptr + batch_id * stride_a_b + feat * stride_a_d
    gb_row = grad_b_ptr + batch_id * stride_gb_b + feat * stride_gb_d

    # `upstream` always holds a_{t+1} * G_{t+1} for the step about to be
    # processed. For the chunk's last step that needs a_{stop_t}, which only
    # exists when the chunk is not the final one.
    a_after = tl.full((BLOCK_SIZE_D,), 1.0, dtype=tl.float32)
    if stop_t < T:
        a_after = tl.load(a_row + stop_t * stride_a_t, mask=feat_ok, other=1.0).to(tl.float32)

    upstream = a_after * grad_after

    # Reverse walk over the chunk
    for i in range(stop_t - first_t - 1, -1, -1):
        t = first_t + i
        g_t = tl.load(g_row + t * stride_g_t, mask=feat_ok, other=0.0).to(tl.float32)

        # G_t = g_t + a_{t+1} * G_{t+1}
        total = g_t + upstream

        # d/db_t = G_t
        tl.store(gb_row + t * stride_gb_t, total, mask=feat_ok)

        # d/da_t = c_{t-1} * G_t
        c_before = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
        if t > 0:
            c_before = tl.load(c_row + (t - 1) * stride_c_t, mask=feat_ok, other=0.0).to(
                tl.float32
            )
        elif HAS_C_0:
            c_before = tl.load(
                c_0_ptr + batch_id * stride_c0_b + feat * stride_c0_d, mask=feat_ok, other=0.0
            ).to(tl.float32)

        tl.store(ga_row + t * stride_a_t, c_before * total, mask=feat_ok)

        # Prepare a_t * G_t for step t - 1
        a_t = tl.load(a_row + t * stride_a_t, mask=feat_ok, other=1.0).to(tl.float32)
        upstream = a_t * total

    # After step t = 0, `upstream` equals a_0 * G_0, the gradient of c_0
    if chunk_id == 0 and HAS_C_0:
        tl.store(
            grad_c_0_ptr + batch_id * stride_c0_b + feat * stride_c0_d, upstream, mask=feat_ok
        )
344
+
345
+
346
+ # ──────────────────────────────────────────────────────────────
347
+ # PYTORCH WRAPPER
348
+ # ──────────────────────────────────────────────────────────────
349
+
350
+
351
class DSRNScanTriton(torch.autograd.Function):
    """Autograd function for the scan ``c_t = a_t * c_{t-1} + b_t``.

    Both directions run three kernel launches: per-chunk accumulation, a
    sequential scan over chunk summaries, and a per-chunk combine.
    """

    @staticmethod
    def _strides3(tensor):
        """Return the strides of the first three dimensions of ``tensor``."""
        return tensor.stride(0), tensor.stride(1), tensor.stride(2)

    @staticmethod
    def forward(ctx, a, b, c_0=None):
        """Compute every ``c_t`` for ``a`` and ``b`` unpacked as ``(B, T, D)``.

        ``c_0`` is the optional initial state; the kernels index it with two
        strides.
        """
        B, T, D = a.shape

        a = a.contiguous()
        b = b.contiguous()
        has_c0 = c_0 is not None
        if has_c0:
            c_0 = c_0.contiguous()
        c0_stride_b = c_0.stride(0) if has_c0 else 0
        c0_stride_d = c_0.stride(1) if has_c0 else 0

        c_out = torch.empty_like(a)

        block_t = 64
        block_d = triton.next_power_of_2(min(128, D))
        n_chunks = (T + block_t - 1) // block_t

        def scratch():
            # float32 workspace with one slot per (batch, chunk, feature)
            return torch.empty((B, n_chunks, D), device=a.device, dtype=torch.float32)

        chunk_a = scratch()
        chunk_c = scratch()
        chunk_carries = scratch()

        per_chunk_grid = (B, triton.cdiv(D, block_d), n_chunks)
        per_row_grid = (B, triton.cdiv(D, block_d))

        # Pass 1: reduce every chunk to its (A, C) summary
        fwd_accumulate_kernel[per_chunk_grid](
            a,
            b,
            chunk_a,
            chunk_c,
            T,
            D,
            *DSRNScanTriton._strides3(a),
            *DSRNScanTriton._strides3(b),
            BLOCK_SIZE_D=block_d,
            BLOCK_SIZE_T=block_t,
        )

        # Pass 2: chain the summaries to get the state entering each chunk
        fwd_global_scan_kernel[per_row_grid](
            chunk_a,
            chunk_c,
            chunk_carries,
            c_0,
            n_chunks,
            D,
            c0_stride_b,
            c0_stride_d,
            HAS_C_0=has_c0,
            BLOCK_SIZE_D=block_d,
        )

        # Pass 3: replay each chunk from its carry and write c_out
        fwd_combine_kernel[per_chunk_grid](
            a,
            b,
            chunk_carries,
            c_out,
            T,
            D,
            *DSRNScanTriton._strides3(a),
            *DSRNScanTriton._strides3(b),
            *DSRNScanTriton._strides3(c_out),
            BLOCK_SIZE_D=block_d,
            BLOCK_SIZE_T=block_t,
        )

        ctx.save_for_backward(a, c_out, c_0)
        ctx.BLOCK_SIZE_T = block_t
        ctx.BLOCK_SIZE_D = block_d

        return c_out

    @staticmethod
    def backward(ctx, grad_c_out):
        """Return gradients for ``a``, ``b`` and ``c_0`` (``None`` if absent)."""
        a, c_out, c_0 = ctx.saved_tensors
        B, T, D = a.shape
        has_c0 = c_0 is not None

        grad_c_out = grad_c_out.contiguous()
        grad_a = torch.empty_like(a)
        grad_b = torch.empty_like(a)
        grad_c_0 = torch.zeros_like(c_0) if has_c0 else None

        block_t = ctx.BLOCK_SIZE_T
        block_d = ctx.BLOCK_SIZE_D
        n_chunks = (T + block_t - 1) // block_t

        def scratch():
            return torch.empty((B, n_chunks, D), device=a.device, dtype=torch.float32)

        # Per-chunk product of a's, per-chunk gradient sum, and reverse carries
        chunk_a_prod = scratch()
        chunk_g_sum = scratch()
        chunk_grad_carries = scratch()

        per_chunk_grid = (B, triton.cdiv(D, block_d), n_chunks)
        per_row_grid = (B, triton.cdiv(D, block_d))

        # Pass 1: reduce every chunk of the reverse recurrence
        bwd_accumulate_kernel[per_chunk_grid](
            a,
            grad_c_out,
            chunk_a_prod,
            chunk_g_sum,
            T,
            D,
            *DSRNScanTriton._strides3(a),
            *DSRNScanTriton._strides3(grad_c_out),
            BLOCK_SIZE_D=block_d,
            BLOCK_SIZE_T=block_t,
        )

        # Pass 2: chain the summaries from the last chunk to the first
        bwd_global_scan_kernel[per_row_grid](
            chunk_a_prod,
            chunk_g_sum,
            chunk_grad_carries,
            n_chunks,
            D,
            BLOCK_SIZE_D=block_d,
        )

        # Pass 3: replay each chunk in reverse and write the gradients
        bwd_combine_kernel[per_chunk_grid](
            a,
            c_out,
            c_0,
            grad_c_out,
            chunk_grad_carries,
            grad_a,
            grad_b,
            grad_c_0,
            T,
            D,
            *DSRNScanTriton._strides3(a),
            *DSRNScanTriton._strides3(c_out),
            *DSRNScanTriton._strides3(grad_c_out),
            *DSRNScanTriton._strides3(grad_b),
            c_0.stride(0) if has_c0 else 0,
            c_0.stride(1) if has_c0 else 0,
            HAS_C_0=has_c0,
            BLOCK_SIZE_D=block_d,
            BLOCK_SIZE_T=block_t,
        )

        return grad_a, grad_b, grad_c_0
511
+
512
+
513
def triton_dsrn_parallel_scan(g_t, m_t, c_0=None):
    """Run the gated scan ``c_t = (1 - g_t) * c_{t-1} + g_t * m_t`` with Triton.

    Args:
        g_t: Gate values; its dtype is the dtype of the returned tensor.
        m_t: Candidate values, multiplied element-wise with ``g_t``.
        c_0: Optional initial state.

    Returns:
        The scan output cast back to ``g_t``'s original dtype.
    """
    orig_dtype = g_t.dtype

    # Upcast BEFORE the arithmetic. Computing `1 - g` in bf16/fp16 rounds the
    # result to exactly 1.0 for small gates, which would freeze the state;
    # float32 inputs are unaffected by this ordering.
    gate = g_t.float()
    a = 1.0 - gate
    b = gate * m_t.float()
    if c_0 is not None:
        c_0 = c_0.float()

    out = DSRNScanTriton.apply(a, b, c_0)
    return out.to(orig_dtype)