Trying to fix issues with extra arguments to the model

#10
by shmuli - opened
Files changed (1)
  1. modeling_decilm.py +5 -2
modeling_decilm.py CHANGED
@@ -27,7 +27,8 @@ import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
- from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
33
  from transformers.utils import (
@@ -503,6 +504,7 @@ class DeciLMFlashAttention2(DeciLMAttention):
503
  use_cache: bool = False,
504
  cache_position: Optional[torch.LongTensor] = None,
505
  position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
 
506
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
507
  output_attentions = False
508
 
@@ -810,7 +812,7 @@ class DeciLMPreTrainedModel(PreTrainedModel):
810
  # DeciLM-specific code
811
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
812
  generation_config.cache_implementation = "variable"
813
- NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
814
  return generation_config, model_kwargs
815
 
816
 
@@ -1148,6 +1150,7 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1148
  output_hidden_states: Optional[bool] = None,
1149
  return_dict: Optional[bool] = None,
1150
  cache_position: Optional[torch.LongTensor] = None,
 
1151
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1152
  r"""
1153
  Args:
 
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
+ from transformers.generation.utils import GenerationMixin, GenerateOutput
31
+ from transformers.generation.configuration_utils import ALL_STATIC_CACHE_IMPLEMENTATIONS
32
  from transformers.modeling_utils import PreTrainedModel
33
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
34
  from transformers.utils import (
 
504
  use_cache: bool = False,
505
  cache_position: Optional[torch.LongTensor] = None,
506
  position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
507
+ **kwargs,
508
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
509
  output_attentions = False
510
 
 
812
  # DeciLM-specific code
813
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
814
  generation_config.cache_implementation = "variable"
815
+ ALL_STATIC_CACHE_IMPLEMENTATIONS["variable"] = VariableCache
816
  return generation_config, model_kwargs
817
 
818
 
 
1150
  output_hidden_states: Optional[bool] = None,
1151
  return_dict: Optional[bool] = None,
1152
  cache_position: Optional[torch.LongTensor] = None,
1153
+ **kwargs,
1154
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1155
  r"""
1156
  Args: