variable_cache.py compatibility for v4.57.2 / python3.12

#12
by NePe - opened
Files changed (2) hide show
  1. modeling_decilm.py +7 -3
  2. variable_cache.py +17 -11
modeling_decilm.py CHANGED
@@ -27,7 +27,7 @@ import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
- from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
33
  from transformers.utils import (
@@ -809,8 +809,9 @@ class DeciLMPreTrainedModel(PreTrainedModel):
809
  ) -> tuple[GenerationConfig, dict]:
810
  # DeciLM-specific code
811
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
812
- generation_config.cache_implementation = "variable"
813
- NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
 
814
  return generation_config, model_kwargs
815
 
816
 
@@ -1133,6 +1134,9 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1133
  def get_decoder(self):
1134
  return self.model
1135
 
 
 
 
1136
  @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
1137
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1138
  def forward(
 
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
+ from transformers.generation.utils import GenerationMixin, GenerateOutput
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
33
  from transformers.utils import (
 
809
  ) -> tuple[GenerationConfig, dict]:
810
  # DeciLM-specific code
811
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
812
+ generation_config.disable_compile = True
813
+ #generation_config.cache_implementation = "variable"
814
+ #NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
815
  return generation_config, model_kwargs
816
 
817
 
 
1134
  def get_decoder(self):
1135
  return self.model
1136
 
1137
+ def getVariableCache(self, batch_size=1, max_cache_len=4096, dtype=torch.bfloat16):
1138
+ return VariableCache(config=self.config, batch_size=batch_size, max_batch_size=batch_size, max_cache_len=max_cache_len, dtype=dtype)
1139
+
1140
  @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
1141
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1142
  def forward(
variable_cache.py CHANGED
@@ -31,6 +31,9 @@ class VariableCache(Cache_4_44_2, Cache):
31
  The default implementation for the layer caches is StaticCache.
32
  The cache of each layer is allocated to the same gpu as the layer itself.
33
  """
 
 
 
34
 
35
  def __init__(
36
  self,
@@ -50,7 +53,7 @@ class VariableCache(Cache_4_44_2, Cache):
50
  self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
51
  self.dtype = dtype
52
 
53
- self.layer_caches: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
54
  self.layer_devices: list[torch.device | None] = [None] * config.num_hidden_layers
55
 
56
  def update(
@@ -60,11 +63,11 @@ class VariableCache(Cache_4_44_2, Cache):
60
  layer_idx: int,
61
  cache_kwargs: Optional[Dict[str, Any]] = None,
62
  ) -> Tuple[torch.Tensor, torch.Tensor]:
63
- if self.layer_caches[layer_idx] is None:
64
  self.layer_devices[layer_idx] = key_states.device
65
  self._init_layer_cache(layer_idx)
66
 
67
- layer_cache = self.layer_caches[layer_idx]
68
  assert layer_cache is not None, f"Trying to update the cache of a cache-less layer: {layer_idx=}"
69
 
70
  k_out, v_out = layer_cache.update(key_states=key_states,
@@ -93,37 +96,37 @@ class VariableCache(Cache_4_44_2, Cache):
93
  if attention_config.window_length is not None:
94
  if not attention_config.is_sink:
95
  config.sliding_window = attention_config.window_length
96
- self.layer_caches[layer_idx] = SlidingWindowCache(config=config,
97
  max_batch_size=self.max_batch_size,
98
  max_cache_len=self.max_cache_len,
99
  device=device,
100
  dtype=self.dtype)
101
  return
102
  elif not attention_config.unshifted_sink:
103
- self.layer_caches[layer_idx] = SinkCache(window_length=attention_config.window_length,
104
  num_sink_tokens=attention_config.num_sink_tokens)
105
  return
106
 
107
- self.layer_caches[layer_idx] = StaticCache(config=config,
108
  max_batch_size=self.max_batch_size,
109
  max_cache_len=self.max_cache_len,
110
  device=device,
111
  dtype=self.dtype)
112
 
113
  def _get_first_real_cache(self) -> Cache:
114
- for layer_cache in self.layer_caches:
115
  if layer_cache is not None:
116
  return layer_cache
117
  raise ValueError(f"No real cache found, all layer caches are None.")
118
 
119
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
120
- if layer_idx == 0 and self.layer_caches[0] is None:
121
  try:
122
  layer_cache = self._get_first_real_cache()
123
  except ValueError:
124
  return 0
125
  else:
126
- layer_cache = self.layer_caches[layer_idx]
127
  return layer_cache.get_seq_length()
128
 
129
  def get_max_length(self) -> Optional[int]:
@@ -131,9 +134,12 @@ class VariableCache(Cache_4_44_2, Cache):
131
  return self.max_cache_len
132
 
133
  def reset(self):
134
- for layer_idx in range(len(self.layer_caches)):
135
- layer_cache = self.layer_caches[layer_idx]
136
  if hasattr(layer_cache, "reset"):
137
  layer_cache.reset()
138
  else:
139
  self._init_layer_cache(layer_idx)
 
 
 
 
31
  The default implementation for the layer caches is StaticCache.
32
  The cache of each layer is allocated to the same gpu as the layer itself.
33
  """
34
+
35
+ max_batch_size = None
36
+ max_cache_len = None
37
 
38
  def __init__(
39
  self,
 
53
  self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
54
  self.dtype = dtype
55
 
56
+ self.layers: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
57
  self.layer_devices: list[torch.device | None] = [None] * config.num_hidden_layers
58
 
59
  def update(
 
63
  layer_idx: int,
64
  cache_kwargs: Optional[Dict[str, Any]] = None,
65
  ) -> Tuple[torch.Tensor, torch.Tensor]:
66
+ if self.layers[layer_idx] is None:
67
  self.layer_devices[layer_idx] = key_states.device
68
  self._init_layer_cache(layer_idx)
69
 
70
+ layer_cache = self.layers[layer_idx]
71
  assert layer_cache is not None, f"Trying to update the cache of a cache-less layer: {layer_idx=}"
72
 
73
  k_out, v_out = layer_cache.update(key_states=key_states,
 
96
  if attention_config.window_length is not None:
97
  if not attention_config.is_sink:
98
  config.sliding_window = attention_config.window_length
99
+ self.layers[layer_idx] = SlidingWindowCache(config=config,
100
  max_batch_size=self.max_batch_size,
101
  max_cache_len=self.max_cache_len,
102
  device=device,
103
  dtype=self.dtype)
104
  return
105
  elif not attention_config.unshifted_sink:
106
+ self.layers[layer_idx] = SinkCache(window_length=attention_config.window_length,
107
  num_sink_tokens=attention_config.num_sink_tokens)
108
  return
109
 
110
+ self.layers[layer_idx] = StaticCache(config=config,
111
  max_batch_size=self.max_batch_size,
112
  max_cache_len=self.max_cache_len,
113
  device=device,
114
  dtype=self.dtype)
115
 
116
  def _get_first_real_cache(self) -> Cache:
117
+ for layer_cache in self.layers:
118
  if layer_cache is not None:
119
  return layer_cache
120
  raise ValueError(f"No real cache found, all layer caches are None.")
121
 
122
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
123
+ if layer_idx == 0 and self.layers[0] is None:
124
  try:
125
  layer_cache = self._get_first_real_cache()
126
  except ValueError:
127
  return 0
128
  else:
129
+ layer_cache = self.layers[layer_idx]
130
  return layer_cache.get_seq_length()
131
 
132
  def get_max_length(self) -> Optional[int]:
 
134
  return self.max_cache_len
135
 
136
  def reset(self):
137
+ for layer_idx in range(len(self.layers)):
138
+ layer_cache = self.layers[layer_idx]
139
  if hasattr(layer_cache, "reset"):
140
  layer_cache.reset()
141
  else:
142
  self._init_layer_cache(layer_idx)
143
+
144
+ def is_compileable(self) -> bool:
145
+ return False