blewis-hir committed on
Commit
2b88640
·
verified ·
1 Parent(s): f4a9fe8

Make changes backwards compatible with old `transformers` versions

Browse files
Files changed (1) hide show
  1. modeling_decilm.py +12 -2
modeling_decilm.py CHANGED
@@ -19,8 +19,10 @@
19
  # limitations under the License.
20
 
21
  import math
 
22
  from typing import List, Optional, Tuple, Union
23
 
 
24
  import torch
25
  import torch.nn.functional as F
26
  import torch.utils.checkpoint
@@ -28,7 +30,12 @@ from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers import GenerationConfig
30
  from transformers.generation.utils import GenerationMixin, GenerateOutput
31
- from transformers.generation.configuration_utils import ALL_STATIC_CACHE_IMPLEMENTATIONS
 
 
 
 
 
32
  from transformers.modeling_utils import PreTrainedModel
33
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
34
  from transformers.utils import (
@@ -811,7 +818,10 @@ class DeciLMPreTrainedModel(PreTrainedModel):
811
  # DeciLM-specific code
812
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
813
  generation_config.cache_implementation = "variable"
814
- ALL_STATIC_CACHE_IMPLEMENTATIONS["variable"] = VariableCache
 
 
 
815
  return generation_config, model_kwargs
816
 
817
 
 
19
  # limitations under the License.
20
 
21
  import math
22
+ import importlib.metadata
23
  from typing import List, Optional, Tuple, Union
24
 
25
+ from packaging.version import Version
26
  import torch
27
  import torch.nn.functional as F
28
  import torch.utils.checkpoint
 
30
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
  from transformers import GenerationConfig
32
  from transformers.generation.utils import GenerationMixin, GenerateOutput
33
+
34
+
35
+ if Version(importlib.metadata.version("transformers")) >= Version("4.56.0.dev0"):
36
+ from transformers.generation.configuration_utils import NEED_SETUP_CACHE_CLASSES_MAPPING
37
+ else:
38
+ from transformers.generation.configuration_utils import ALL_STATIC_CACHE_IMPLEMENTATIONS
39
  from transformers.modeling_utils import PreTrainedModel
40
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
41
  from transformers.utils import (
 
818
  # DeciLM-specific code
819
  generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
820
  generation_config.cache_implementation = "variable"
821
+ if Version(importlib.metadata.version("transformers")) >= Version("4.56.0.dev0"):
822
+ NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
823
+ else:
824
+ ALL_STATIC_CACHE_IMPLEMENTATIONS["variable"] = VariableCache
825
  return generation_config, model_kwargs
826
 
827