NePe committed on
Commit
2be40ac
·
verified ·
1 Parent(s): 36f62ed

Remove incompatible code; add method to get VariableCache

Browse files

Usage:
past_key_values = model.getVariableCache(batch_size=1, max_cache_len=4096)
model.generate(..., past_key_values=past_key_values)

Files changed (1) hide show
  1. modeling_decilm.py +7 -3
modeling_decilm.py CHANGED
@@ -27,7 +27,7 @@ import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
- from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
33
  from transformers.utils import (
@@ -809,8 +809,9 @@ class DeciLMPreTrainedModel(PreTrainedModel):
809
  ) -> tuple[GenerationConfig, dict]:
810
  # DeciLM-specific code
811
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
812
- generation_config.cache_implementation = "variable"
813
- NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
 
814
  return generation_config, model_kwargs
815
 
816
 
@@ -1133,6 +1134,9 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1133
  def get_decoder(self):
1134
  return self.model
1135
 
 
 
 
1136
  @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
1137
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1138
  def forward(
 
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
+ from transformers.generation.utils import GenerationMixin, GenerateOutput
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
33
  from transformers.utils import (
 
809
  ) -> tuple[GenerationConfig, dict]:
810
  # DeciLM-specific code
811
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
812
+ generation_config.disable_compile = True
813
+ #generation_config.cache_implementation = "variable"
814
+ #NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
815
  return generation_config, model_kwargs
816
 
817
 
 
1134
  def get_decoder(self):
1135
  return self.model
1136
 
1137
+ def getVariableCache(self, batch_size=1, max_cache_len=4096, dtype=torch.bfloat16):
1138
+ return VariableCache(config=self.config, batch_size=batch_size, max_batch_size=batch_size, max_cache_len=max_cache_len, dtype=dtype)
1139
+
1140
  @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
1141
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1142
  def forward(