ZFTurbo commited on
Commit
bae083d
·
verified ·
1 Parent(s): 7c48f7d

Update modeling_phi4mm.py

Browse files

Fix to work with transformers==4.57.6

Files changed (1) hide show
  1. modeling_phi4mm.py +21 -3
modeling_phi4mm.py CHANGED
@@ -1037,6 +1037,24 @@ class Phi4MMMLP(nn.Module):
1037
  return self.down_proj(up_states)
1038
 
1039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1040
  # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
1041
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1042
  """
@@ -1134,7 +1152,7 @@ class Phi4MMAttention(nn.Module):
1134
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1135
  "with a layer index."
1136
  )
1137
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1138
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1139
 
1140
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -1229,7 +1247,7 @@ class Phi4MMFlashAttention2(Phi4MMAttention):
1229
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1230
  "with a layer index."
1231
  )
1232
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1233
 
1234
  # Because the input can be padded, the absolute sequence length depends on the max position id.
1235
  rotary_seq_len = (
@@ -1351,7 +1369,7 @@ class Phi4MMSdpaAttention(Phi4MMAttention):
1351
 
1352
  kv_seq_len = key_states.shape[-2]
1353
  if past_key_value is not None:
1354
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1355
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1356
 
1357
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
1037
  return self.down_proj(up_states)
1038
 
1039
 
1040
+ def _get_usable_past_kv_length(cache: Cache, new_seq_length: int, layer_idx: int = 0) -> int:
1041
+ """Compute the usable past length for the given cache and upcoming new sequence length.
1042
+
1043
+ This mirrors the previous `get_usable_length(new_seq_length, layer_idx)` behavior that existed in
1044
+ Transformers < 4.45, while being compatible with the new Cache API.
1045
+ """
1046
+ try:
1047
+ previous_length = cache.get_seq_length(layer_idx)
1048
+ # Dynamic layers return -1, static layers return an int
1049
+ max_length = cache.get_max_cache_shape(layer_idx)
1050
+ if max_length is not None and max_length != -1 and previous_length + new_seq_length > max_length:
1051
+ return max_length - new_seq_length
1052
+ return previous_length
1053
+ except Exception:
1054
+ # Best-effort fallback
1055
+ return cache.get_seq_length(layer_idx) if hasattr(cache, "get_seq_length") else 0
1056
+
1057
+
1058
  # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
1059
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1060
  """
 
1152
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1153
  "with a layer index."
1154
  )
1155
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
1156
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1157
 
1158
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
1247
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1248
  "with a layer index."
1249
  )
1250
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
1251
 
1252
  # Because the input can be padded, the absolute sequence length depends on the max position id.
1253
  rotary_seq_len = (
 
1369
 
1370
  kv_seq_len = key_states.shape[-2]
1371
  if past_key_value is not None:
1372
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
1373
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1374
 
1375
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)