dagloop5 commited on
Commit
c963051
·
verified ·
1 Parent(s): 1c287c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py CHANGED
@@ -95,6 +95,73 @@ try:
95
  except Exception as e:
96
  print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  logging.getLogger().setLevel(logging.INFO)
99
 
100
  MAX_SEED = np.iinfo(np.int32).max
 
95
  except Exception as e:
96
  print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
97
 
98
+ # Add this patch after imports in app.py
99
+
100
+ def _patch_attention_for_kv_cache():
101
+ """Patch Attention.forward to accept pre-projected K/V."""
102
+ from ltx_core.model.transformer.attention import Attention
103
+
104
+ _original_forward = Attention.forward
105
+
106
+ def patched_forward(
107
+ self,
108
+ x: torch.Tensor,
109
+ context: torch.Tensor | None = None,
110
+ mask: torch.Tensor | None = None,
111
+ pe: torch.Tensor | None = None,
112
+ k_pe: torch.Tensor | None = None,
113
+ perturbation_mask: torch.Tensor | None = None,
114
+ all_perturbed: bool = False,
115
+ # NEW: pre-computed KV for cross-attention
116
+ cached_k: torch.Tensor | None = None,
117
+ cached_v: torch.Tensor | None = None,
118
+ ) -> torch.Tensor:
119
+ context = x if context is None else context
120
+ use_attention = not all_perturbed
121
+
122
+ v = cached_v if cached_v is not None else self.to_v(context)
123
+
124
+ if not use_attention:
125
+ out = v
126
+ else:
127
+ if cached_k is not None:
128
+ q = self.to_q(x)
129
+ q = self.q_norm(q)
130
+ k = cached_k
131
+ if pe is not None:
132
+ q = apply_rotary_emb(q, pe, self.rope_type)
133
+ k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)
134
+ else:
135
+ q = self.to_q(x)
136
+ k = self.to_k(context)
137
+
138
+ q = self.q_norm(q)
139
+ k = self.k_norm(k)
140
+
141
+ if pe is not None:
142
+ q = apply_rotary_emb(q, pe, self.rope_type)
143
+ k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)
144
+
145
+ out = self.attention_function(q, k, v, self.heads, mask)
146
+
147
+ if perturbation_mask is not None:
148
+ out = out * perturbation_mask + v * (1 - perturbation_mask)
149
+
150
+ # Gating logic remains the same
151
+ if self.to_gate_logits is not None:
152
+ gate_logits = self.to_gate_logits(x)
153
+ b, t, _ = out.shape
154
+ out = out.view(b, t, self.heads, self.dim_head)
155
+ gates = 2.0 * torch.sigmoid(gate_logits)
156
+ out = out * gates.unsqueeze(-1)
157
+ out = out.view(b, t, self.heads * self.dim_head)
158
+
159
+ return self.to_out(out)
160
+
161
+ Attention.forward = patched_forward
162
+
163
+ _patch_attention_for_kv_cache()
164
+
165
  logging.getLogger().setLevel(logging.INFO)
166
 
167
  MAX_SEED = np.iinfo(np.int32).max