mrs83 committed on
Commit
6b3235f
·
verified ·
1 Parent(s): 4737c54

Upload modeling_hybrid.py

Browse files
Files changed (1) hide show
  1. modeling_hybrid.py +6 -1
modeling_hybrid.py CHANGED
@@ -215,7 +215,12 @@ class HybridEchoModel(Qwen2PreTrainedModel):
215
  if past_key_values is None:
216
  past_key_values = HybridEchoCache(config=self.config) if use_dsrn_cache else None
217
 
218
- self._dsrn_input_states = dsrn_states # ALWAYS run injectors regardless of cache mode.
 
 
 
 
 
219
  # use_dsrn_cache only gates whether output states are *returned* in the
220
  # HybridEchoCache. Setting this to [] when use_cache=False (training mode)
221
  # caused the hook to exit early → injectors bypassed → grad_norm=0.
 
215
  if past_key_values is None:
216
  past_key_values = HybridEchoCache(config=self.config) if use_dsrn_cache else None
217
 
218
+ # Detach DSRN states before carrying forward to the next step.
219
+ # Without .detach(), _dsrn_input_states holds tensors with grad_fn
220
+ # from the previous forward pass. Step N+1's graph then includes
221
+ # step N's graph as a parent, preventing step N's activation tensors
222
+ # from being freed after backward — accumulating ~11.78 GiB per step.
223
+ self._dsrn_input_states = [(h.detach(), c.detach()) for h, c in dsrn_states]
224
  # use_dsrn_cache only gates whether output states are *returned* in the
225
  # HybridEchoCache. Setting this to [] when use_cache=False (training mode)
226
  # caused the hook to exit early → injectors bypassed → grad_norm=0.