Nishant2414 odb9402 committed on
Commit
ef79f27
·
0 Parent(s):

Duplicate from Motif-Technologies/Motif-Video-2B

Co-authored-by: Dongpin oh <odb9402@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,42 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/architecture.png filter=lfs diff=lfs merge=lfs -text
+ assets/banner.png filter=lfs diff=lfs merge=lfs -text
+ assets/showcase_i2v.png filter=lfs diff=lfs merge=lfs -text
+ assets/showcase_t2v.png filter=lfs diff=lfs merge=lfs -text
+ tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ assets/i2v_sample.jpg filter=lfs diff=lfs merge=lfs -text
+ motif-video-technical-report.pdf filter=lfs diff=lfs merge=lfs -text
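The patterns in this `.gitattributes` route matching files through Git LFS. As a rough illustration, the sketch below checks a few file names against a subset of those patterns. It relies on Python's `fnmatch`, which only approximates Git's own attribute matching (for example, `**` is not handled the way Git handles it). The helper name `is_lfs_tracked`, the shortened pattern list, and the example path `transformer/model.safetensors` are hypothetical and chosen for illustration.

```python
from fnmatch import fnmatch
from pathlib import PurePosixPath

# Subset of the patterns from the .gitattributes hunk (illustrative only).
LFS_PATTERNS = [
    "*.safetensors",
    "*.bin",
    "*.pt",
    "*.tar.*",
    "*tfevents*",
    "assets/banner.png",
    "tokenizer/tokenizer.json",
]


def is_lfs_tracked(path: str, patterns=LFS_PATTERNS) -> bool:
    """Approximate check: does `path` match any LFS pattern?

    Patterns without a slash are compared against the base name, patterns
    with a slash against the whole relative path. This mimics, but does not
    fully reproduce, Git's matching rules.
    """
    posix = PurePosixPath(path)
    for pattern in patterns:
        target = str(posix) if "/" in pattern else posix.name
        if fnmatch(target, pattern):
            return True
    return False


if __name__ == "__main__":
    for name in ["transformer/model.safetensors", "README.md", "assets/banner.png"]:
        print(name, is_lfs_tracked(name))
```

For an authoritative answer inside a real checkout, Git's own `git check-attr` command reports the attributes that apply to a path.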
.gitignore ADDED
@@ -0,0 +1,29 @@
+ # Claude Code / Codex
+ .claude/
+ .codex/
+ .codex-review-latest.md
+ .hook-state
+ .hook-state.lock
+ .plans/
+ .manuals/
+ CLAUDE.md
+ _old_claude_files/
+
+ # Internal / test
+ test_local.py
+
+ # Environment
+ .env
+ tmp/
+
+ # Python
+ *.pyc
+ __pycache__/
+
+ # Experiments
+ results/
+ experiments/**/outputs/
+ experiments/**/checkpoints/
+ experiments/**/*.pt
+ experiments/**/*.ckpt
+ assets/i2v_sample.jpg
README.md ADDED
@@ -0,0 +1,314 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ tags:
+ - text-to-video
+ - image-to-video
+ - video-generation
+ - diffusion-transformer
+ pipeline_tag: text-to-video
+ library_name: diffusers
+ ---
+
+ <p align="center">
+ <img src="assets/banner.png" width="100%" alt="Motif-Video 2B teaser"/>
+ </p>
+
+ <p align="center">
+ <h1 align="center">Motif-Video 2B</h1>
+ </p>
+
+ <p align="center">
+ <b>A micro-budget text-to-video diffusion transformer from Motif Technologies</b>
+ </p>
+
+ <p align="center">
+ 📑 <a href="https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf">Technical Report</a> &nbsp;|&nbsp;
+ 🤗 <a href="https://huggingface.co/Motif-Technologies/Motif-Video-2B">Hugging Face</a> &nbsp;|&nbsp;
+ 🌐 <a href="https://motiftech.io/videoshowcase">Project Page</a>
+ </p>
+
+ ---
+
+ ## 🔥 News
+
+ - **[2026-04-14]** We release **Motif-Video 2B**, our 2B-parameter text-to-video and image-to-video diffusion transformer, together with the full [technical report](https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf).
+
+ ---
+
+ ## 📖 Introduction
+
+ Training strong video generation models usually requires massive datasets, large parameter counts, and substantial compute. **Motif-Video 2B** asks whether competitive text-to-video quality is reachable at a much smaller budget — fewer than **10M training clips** and under **100,000 H200 GPU hours** — and shows that the answer is yes, provided the model design explicitly separates objectives that scaling would otherwise leave entangled.
+
+ Our central observation is that prompt alignment, temporal consistency, and fine-detail recovery interfere with one another when handled through the same pathway. Motif-Video 2B addresses this **objective interference** architecturally rather than relying on scale alone, through two contributions:
+
+ - **Shared Cross-Attention.** A residual cross-attention mechanism that reuses self-attention K/V weights to stabilize text–video alignment under long-context token sparsity, where standard joint attention dilutes text influence as the video token sequence grows.
+ - **Three-stage DDT-style backbone.** 12 dual-stream + 16 single-stream + 8 DDT decoder layers, separating early modality fusion, joint representation learning, and high-frequency detail reconstruction into dedicated components. Per-block attention analysis shows that the DDT decoder spontaneously develops inter-frame attention structure absent from the encoder layers.
+
+ These are paired with a micro-budget training recipe combining **TREAD token routing** and early-phase **REPA** with a frozen **V-JEPA** teacher — to our knowledge, the first time this combination has been applied to text-to-video training.
+
+ On VBench, Motif-Video 2B reaches **83.76%**, the highest Total Score among open-source models we evaluate, surpassing Wan2.1-14B at **7× fewer parameters** and roughly an order of magnitude less training data.
+
+ <!--
+ Architecture figure — replace with Figure 2 from the technical report
+ (the three-stage backbone + Shared Cross-Attention diagram).
+ -->
+ <p align="center">
+ <img src="assets/architecture.png" width="90%" alt="Motif-Video 2B architecture"/>
+ </p>
+
+ ---
+
+ ## ✨ Highlights
+
+ - **Two tasks, one set of weights.** A single checkpoint handles both **text-to-video (T2V)** and **image-to-video (I2V)** generation, trained jointly without a learnable task-type embedding.
+ - **Up to 720p, 121 frames.** The final model generates 720p video at 121 frames under the standard rectified flow-matching sampler.
+ - **Architectural specialization over brute-force scale.** Three-stage backbone with role-separated dual-stream / single-stream / DDT decoder layers.
+ - **Shared Cross-Attention.** Stabilizes text alignment under long video-token sequences by grounding cross-attention K/V in the self-attention manifold.
+ - **Micro-budget recipe.** TREAD token routing (≈27% per-step FLOP reduction) + early-phase REPA with V-JEPA teacher + offline bucket-balanced sampler (≈90% data utilization, up from ≈20% baseline).
+ - **Open and reproducible.** Trained on ~64×H200 GPUs with FSDP2, full curriculum and recipe documented in the technical report.
+
+ ---
+
+ ## 🏗️ Architecture
+
+ Motif-Video 2B is a flow-matching diffusion transformer organized around a single principle: each component is assigned a well-defined responsibility, and components with conflicting objectives are not asked to share capacity.
+
+ | Component | Choice |
+ |---|---|
+ | Text encoder | T5Gemma2 (encoder–decoder, UL2-adapted Gemma 3) |
+ | Video tokenizer | Wan2.1 VAE (8×8 spatial, 4× temporal compression), 2×2×1 patchify |
+ | Backbone | 12 dual-stream + 16 single-stream + 8 DDT decoder layers |
+ | Hidden dim / heads | 1536 / 12 heads × 128 |
+ | Normalization | QK-normalization throughout |
+ | Position encoding | RoPE |
+ | Cross-attention | **Shared Cross-Attention** in the single-stream stage |
+ | Objective | Rectified flow matching (velocity prediction) |
+ | I2V conditioning | First-frame latent + SigLIP image embeddings, with timestep-aware blur |
+
+ A high-level walkthrough of the role separation:
+
+ 1. **Dual-stream stage (12 layers).** Text and video tokens are processed through separate self-attention pathways, exchanging information via cross-attention. This prevents premature feature entanglement before either modality has formed coherent representations.
+ 2. **Single-stream stage (16 layers).** Text and video tokens attend freely in a joint sequence. **Shared Cross-Attention** is attached here to repair the text-attention dilution that emerges as the video token sequence grows.
+ 3. **DDT decoder (8 layers).** A dedicated velocity decoder atop the 28-layer encoder, freeing the encoder from high-frequency detail reconstruction. Per-block attention analysis shows that the DDT decoder develops inter-frame attention structure that single-stream layers do not.
+
+ For the full derivation of why Shared Cross-Attention shares K/V but not Q, and why this is necessary in addition to standard zero-init of W_O, see Section 3.3 of the [technical report](https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf).
+
+ <!--
+ Optional: insert Figure 3 (attention heatmaps across the three stages)
+ here as a secondary architecture figure. It is the strongest visual
+ evidence for the role-separation argument.
+ -->
+
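To make the video tokenizer row of the architecture table concrete, the sketch below estimates how many video tokens the backbone sees at the recommended 1280×736, 121-frame setting. It uses the compression factors quoted in the table (8×8 spatial, 4× temporal, 2×2×1 patchify). It also assumes, which this README does not state, that the causal VAE keeps the first frame and compresses the rest, so that T input frames map to (T − 1) / 4 + 1 latent frames. The function name `video_token_count` is hypothetical.

```python
def video_token_count(
    height: int,
    width: int,
    num_frames: int,
    spatial: int = 8,    # VAE spatial compression per side
    temporal: int = 4,   # VAE temporal compression
    patch_hw: int = 2,   # patchify factor along height and width
    patch_t: int = 1,    # patchify factor along time
) -> int:
    """Estimate the video token sequence length after VAE encoding and patchify."""
    # Assumed causal-VAE convention: first frame kept, remaining frames compressed.
    latent_frames = (num_frames - 1) // temporal + 1
    latent_h = height // spatial
    latent_w = width // spatial
    # 2x2x1 patchify: merge 2x2 spatial latents, no merging along time.
    return (latent_frames // patch_t) * (latent_h // patch_hw) * (latent_w // patch_hw)


if __name__ == "__main__":
    print(video_token_count(736, 1280, 121))  # 31 * 46 * 80 = 114080
```

Under these assumptions the recommended setting gives 31 × 46 × 80 = 114,080 video tokens, which illustrates the long-sequence regime in which the README says joint attention dilutes text influence.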
+ ---
+
+ ## 🚀 Quickstart / Usage
+
+ ### Requirements
+
+ - Python 3.10+
+ - CUDA-capable GPU with **24GB+ VRAM** (e.g., A100, H100, RTX 4090)
+
+ ```bash
+ pip install "diffusers>=0.35.2" "transformers>=5.0.0" torch accelerate ftfy einops sentencepiece regex Pillow
+ ```
+
+ ### Text-to-Video (T2V)
+
+ ```python
+ import torch
+ from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
+ from diffusers.utils import export_to_video
+
+ guider = AdaptiveProjectedGuidance(
+     guidance_scale=8.0,
+     adaptive_projected_guidance_rescale=12.0,
+     adaptive_projected_guidance_momentum=0.1,
+     use_original_formulation=True,
+ )
+
+ pipe = DiffusionPipeline.from_pretrained(
+     "Motif-Technologies/Motif-Video-2B",
+     custom_pipeline="pipeline_motif_video",
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+     guider=guider,
+ )
+ pipe = pipe.to("cuda")
+
+ output = pipe(
+     prompt="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
+     height=736,
+     width=1280,
+     num_frames=121,
+     num_inference_steps=50,
+ )
+
+ export_to_video(output.frames[0], "output.mp4", fps=24)
+ ```
+
+ ### Image-to-Video (I2V)
+
+ ```python
+ import torch
+ from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
+ from diffusers.utils import export_to_video, load_image
+
+ guider = AdaptiveProjectedGuidance(
+     guidance_scale=8.0,
+     adaptive_projected_guidance_rescale=12.0,
+     adaptive_projected_guidance_momentum=0.1,
+     use_original_formulation=True,
+ )
+
+ pipe = DiffusionPipeline.from_pretrained(
+     "Motif-Technologies/Motif-Video-2B",
+     custom_pipeline="pipeline_motif_video",
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+     guider=guider,
+ )
+ pipe = pipe.to("cuda")
+
+ image = load_image("https://huggingface.co/Motif-Technologies/Motif-Video-2B/resolve/main/assets/i2v_sample.jpg")
+
+ output = pipe(
+     prompt="Three friends stride through a sun-bleached meadow as a warm breeze ripples the tall dry grass around their legs. The woman on the left turns her head to share a quiet laugh, the woman in the center pushes a loose curl behind her ear, and the man on the right tilts his face toward the sky. The camera drifts gently alongside them at walking pace, handheld, with soft overcast light.",
+     image=image,
+     height=736,
+     width=1280,
+     num_frames=121,
+     num_inference_steps=50,
+ )
+
+ export_to_video(output.frames[0], "output.mp4", fps=24)
+ ```
+
+ ### CLI Inference
+
+ ```bash
+ # Text-to-Video
+ python inference.py \
+     --prompt "A time-lapse of a flower blooming in a dark room, dramatic lighting" \
+     --output t2v_output.mp4
+
+ # Image-to-Video
+ python inference.py \
+     --image assets/i2v_sample.jpg \
+     --prompt "Three friends stride through a meadow as a warm breeze ripples the tall grass" \
+     --output i2v_output.mp4
+ ```
+
+ See `inference.py` for all available options (`--help`).
+
+ ### Recommended Settings
+
+ | Parameter | Default | Notes |
+ |---|---|---|
+ | Resolution | 1280x736 | 720p, best quality |
+ | Frames | 121 | ~5 seconds at 24fps |
+ | Guidance scale | 8.0 | |
+ | Scheduler shift | 15.0 | Pre-configured in scheduler config |
+ | Inference steps | 50 | |
+ | dtype | bfloat16 | Recommended for H100/A100 |
+
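The scheduler shift listed in the recommended settings warps the noise schedule toward high noise levels. The flow-matching scheduler included in this commit applies the static mapping σ' = shift · σ / (1 + (shift − 1) · σ). The sketch below evaluates that mapping for a few values in plain Python. The function name `shift_sigma` is illustrative, and the snippet is independent of the actual scheduler class.

```python
def shift_sigma(sigma: float, shift: float) -> float:
    """Static timestep shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


if __name__ == "__main__":
    # Compare the unshifted schedule (shift=1) with the recommended shift=15.
    for sigma in (0.1, 0.5, 0.9):
        print(f"sigma={sigma:.1f} -> shifted={shift_sigma(sigma, 15.0):.4f}")
```

With shift = 15, a mid-schedule σ = 0.5 maps to 0.9375, so most sampling steps are spent at high noise levels. A shift of 1 leaves the schedule unchanged, and the endpoints σ = 0 and σ = 1 are fixed for any shift.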
+ ---
+
+ ## 📊 Performance
+
+ ### VBench
+
+ Motif-Video 2B achieves the highest **Total Score** among open-source models we evaluate.
+
+ | Model | Params | Total | Quality | Semantic |
+ |---|---|---|---|---|
+ | Wan2.2-T2V (prompt-opt.) | A14B | 84.23 | 85.42 | 79.50 |
+ | **Motif-Video 2B (Ours)** | **2B** | **83.76** | **84.59** | **80.44** |
+ | SANA-Video | 2B | 83.71 | 84.35 | 81.35 |
+ | Wan2.1-T2V | 14B | 83.69 | 85.59 | 76.11 |
+ | OpenSora 2.0 (T2I2V) | 11B | 83.60 | 84.40 | 80.30 |
+ | Wan2.1-T2V | 1.3B | 83.31 | 85.23 | 75.65 |
+ | HunyuanVideo | 13B | 83.24 | 85.09 | 75.82 |
+ | CogVideoX1.5-5B (prompt-opt.) | 5B | 82.17 | 82.78 | 79.76 |
+ | Step-Video-T2V | 30B | 81.83 | 84.46 | 71.28 |
+ | LTX-Video | 2B | 80.00 | 82.30 | 70.79 |
+
+ Notable per-dimension highlights for Motif-Video 2B (open-source):
+
+ - **Spatial Relationship: 83.02%** — best among open-source models
+ - **Semantic Score: 80.44%** — highest among open-source models reporting per-dimension results
+ - **Object Class: 92.93%**, **Multiple Objects: 77.29%**, **Imaging Quality: 70.50%** — second-best in their categories
+
+ The full 16-dimension breakdown is in Table 3 of the [technical report](https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf).
+
+ > **A note on VBench vs. perceptual quality.** Motif-Video 2B leads on VBench Total Score, but in our internal side-by-side comparisons against Wan2.1-T2V-14B we observe a perceptual gap in favor of the larger model on temporal stability and fine human anatomy. We discuss the sources of this gap (uniform dimension weighting, near-correct semantic credit) in Section 7 of the report. We report the gap explicitly rather than smoothing it over.
+
+ ### Human evaluation
+
+ In a blind pairwise study against six contemporaneous open-source baselines (SANA-Video, LTX-Video 2, Wan2.1-14B, Wan2.1-1.3B, Wan2.2-5B, CogVideoX-5B) on 40 LLM-generated prompts, Motif-Video 2B is preferred over both **SANA-Video** (similar parameter count) and **Wan2.1-1.3B** (similar parameter count, larger training corpus) on prompt-following and video-fidelity axes. Wan2.1-14B remains the preferred model overall, consistent with its 7× larger parameter count and substantially larger training data.
+
+ ---
+
+ ## 🎬 Showcase
+
+ <!--
+ Insert the qualitative grids from the technical report here:
+ - Figure 1 / Figure 12: T2V multi-prompt frame strips
+ - Figure 13: I2V example (input image + generated frames)
+ Use full-width or 2-column layout, matching Wan2.1's "Showcase" section.
+ -->
+
+ ### Text-to-Video
+
+ <p align="center">
+ <img src="assets/showcase_t2v.png" width="100%" alt="Motif-Video 2B T2V samples"/>
+ </p>
+
+ ### Image-to-Video
+
+ <p align="center">
+ <img src="assets/showcase_i2v.png" width="100%" alt="Motif-Video 2B I2V samples"/>
+ </p>
+
+ ---
+
+ ## ⚠️ Limitations
+
+ We report limitations as the boundary conditions under which the design decisions in this report should be interpreted, not as caveats.
+
+ - **Micro-scale semantic distortion.** Motif-Video 2B occasionally produces sub-object-level artifacts that leave the category label intact but break perceptual plausibility — distorted hands on close-up human subjects, degraded body structure under high-displacement motion, and attribute leakage between visually similar co-present subjects. We attribute these primarily to data coverage rather than backbone design.
+ - **Temporal failures.** Three distinct modes that frame-level metrics do not surface: (i) physically implausible liquid / cloth / collision dynamics, (ii) coherence loss under high scene complexity (multi-agent crowds), and (iii) unintended mid-clip scene transitions in long sequences.
+ - **Recipe components are evaluated jointly, not in isolation.** We do not present per-component ablations for Shared Cross-Attention, the DDT decoder, REPA phasing, or TREAD routing at full scale. Readers should interpret our results as evidence that the *composed* recipe works at 2B, not as a marginal-contribution claim about any single component.
+
+ We view temporal stability and data coverage — not architectural depth — as the primary remaining ceilings on this model. Both are the most natural axes for a future iteration that the current architecture is built to absorb.
+
+ ---
+
+ ## 📚 Citation
+
+ If you find Motif-Video 2B useful in your research, please cite:
+
+ ```bibtex
+ @techreport{motifvideo2b2026,
+   title = {Motif-Video 2B: Technical Report},
+   author = {Motif Technologies},
+   year = {2026},
+   institution = {Motif Technologies},
+   url = {https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf}
+ }
+ ```
+
+ ---
+
+ ## 🙏 Acknowledgements
+
+ We build on a number of excellent open-source projects, including the **Wan2.1 VAE** [Wan Team, 2025], **T5Gemma / Gemma 3** [Google], **TREAD** [Krause et al., 2025], **REPA** with the **V-JEPA** family of visual encoders [Bardes et al.], **DDT** [Wang et al.], and the broader **diffusers** and **Accelerate** ecosystems. Compute was provisioned on Microsoft Azure and orchestrated with **SkyPilot** on Kubernetes.
+
+ ---
+
+ ## 📄 License
+
+ <!-- TODO: confirm final license — apache-2.0 placeholder above. -->
+
+ This model is released under the Apache 2.0 License. See `LICENSE` for details.
_fm_solvers_unipc.py ADDED
@@ -0,0 +1,759 @@
+ # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
+ # Convert unipc for flow matching
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+ import math
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.schedulers.scheduling_utils import (
+     KarrasDiffusionSchedulers,
+     SchedulerMixin,
+     SchedulerOutput,
+ )
+ from diffusers.utils import deprecate, is_scipy_available
+
+
+ if is_scipy_available():
+     pass
+
+
+ class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
+     """
+     `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
+
+     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+     methods the library implements for all schedulers such as loading and saving.
+
+     Args:
+         num_train_timesteps (`int`, defaults to 1000):
+             The number of diffusion steps to train the model.
+         solver_order (`int`, default `2`):
+             The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
+             due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
+             unconditional sampling.
+         prediction_type (`str`, defaults to "flow_prediction"):
+             Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
+             the flow of the diffusion process.
+         thresholding (`bool`, defaults to `False`):
+             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+             as Stable Diffusion.
+         dynamic_thresholding_ratio (`float`, defaults to 0.995):
+             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+         sample_max_value (`float`, defaults to 1.0):
+             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
+         predict_x0 (`bool`, defaults to `True`):
+             Whether to use the updating algorithm on the predicted x0.
+         solver_type (`str`, default `bh2`):
+             Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
+             otherwise.
+         lower_order_final (`bool`, default `True`):
+             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+         disable_corrector (`list`, default `[]`):
+             Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
+             and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
+             usually disabled during the first few steps.
+         solver_p (`SchedulerMixin`, default `None`):
+             Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
+         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
+             the sigmas are determined according to a sequence of noise levels {σi}.
+         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+         timestep_spacing (`str`, defaults to `"linspace"`):
+             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+         steps_offset (`int`, defaults to 0):
+             An offset added to the inference steps, as required by some model families.
+         final_sigmas_type (`str`, defaults to `"zero"`):
+             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+             sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+     """
+
+     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         solver_order: int = 2,
+         prediction_type: str = "flow_prediction",
+         shift: Optional[float] = 1.0,
+         use_dynamic_shifting=False,
+         thresholding: bool = False,
+         dynamic_thresholding_ratio: float = 0.995,
+         sample_max_value: float = 1.0,
+         predict_x0: bool = True,
+         solver_type: str = "bh2",
+         lower_order_final: bool = True,
+         disable_corrector: List[int] = [],
+         solver_p: Optional[SchedulerMixin] = None,
+         timestep_spacing: str = "linspace",
+         steps_offset: int = 0,
+         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+     ):
+         if solver_type not in ["bh1", "bh2"]:
+             if solver_type in ["midpoint", "heun", "logrho"]:
+                 self.register_to_config(solver_type="bh2")
+             else:
+                 raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+         self.predict_x0 = predict_x0
+         # settable values
+         self.num_inference_steps = None
+         alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
+         sigmas = 1.0 - alphas
+         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
+
+         if not use_dynamic_shifting:
+             # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+             sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore
+
+         self.sigmas = sigmas
+         self.timesteps = sigmas * num_train_timesteps
+
+         self.model_outputs = [None] * solver_order
+         self.timestep_list = [None] * solver_order
+         self.lower_order_nums = 0
+         self.disable_corrector = disable_corrector
+         self.solver_p = solver_p
+         self.last_sample = None
+         self._step_index = None
+         self._begin_index = None
+
+         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+         self.sigma_min = self.sigmas[-1].item()
+         self.sigma_max = self.sigmas[0].item()
+
+     @property
+     def step_index(self):
+         """
+         The index counter for current timestep. It will increase 1 after each scheduler step.
+         """
+         return self._step_index
+
+     @property
+     def begin_index(self):
+         """
+         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+         """
+         return self._begin_index
+
+     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+     def set_begin_index(self, begin_index: int = 0):
+         """
+         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+         Args:
+             begin_index (`int`):
+                 The begin index for the scheduler.
+         """
+         self._begin_index = begin_index
+
+     # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
+     def set_timesteps(
+         self,
+         num_inference_steps: Union[int, None] = None,
+         device: Optional[Union[str, torch.device]] = None,
+         sigmas: Optional[List[float]] = None,
+         mu: Optional[Union[float, None]] = None,
+         shift: Optional[Union[float, None]] = None,
+     ):
+         """
+         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+         Args:
+             num_inference_steps (`int`):
+                 Total number of the spacing of the time steps.
+             device (`str` or `torch.device`, *optional*):
+                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+         """
+
+         if self.config.use_dynamic_shifting and mu is None:
+             raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
+
+         if sigmas is None:
+             sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]  # pyright: ignore
+
+         if self.config.use_dynamic_shifting:
+             sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore
+         else:
+             if shift is None:
+                 shift = self.config.shift
+             sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore
+
+         if self.config.final_sigmas_type == "sigma_min":
+             # `sigma_min` is an instance attribute set in `__init__`, not a registered config entry
+             sigma_last = self.sigma_min
+         elif self.config.final_sigmas_type == "zero":
+             sigma_last = 0
+         else:
+             raise ValueError(
+                 f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+             )
+
+         timesteps = sigmas * self.config.num_train_timesteps
+         sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)  # pyright: ignore
+
+         self.sigmas = torch.from_numpy(sigmas)
+         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
+
+         self.num_inference_steps = len(timesteps)
+
+         self.model_outputs = [
+             None,
+         ] * self.config.solver_order
+         self.lower_order_nums = 0
+         self.last_sample = None
+         if self.solver_p:
+             self.solver_p.set_timesteps(self.num_inference_steps, device=device)
+
+         # add an index counter for schedulers that allow duplicated timesteps
+         self._step_index = None
+         self._begin_index = None
+         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
219
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
220
+ """
221
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
222
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
223
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
224
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
225
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
226
+
227
+ https://arxiv.org/abs/2205.11487
228
+ """
229
+ dtype = sample.dtype
230
+ batch_size, channels, *remaining_dims = sample.shape
231
+
232
+ if dtype not in (torch.float32, torch.float64):
233
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
234
+
235
+ # Flatten sample for doing quantile calculation along each image
236
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
237
+
238
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
239
+
240
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
241
+ s = torch.clamp(
242
+ s, min=1, max=self.config.sample_max_value
243
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
244
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
245
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
246
+
247
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
248
+ sample = sample.to(dtype)
249
+
250
+ return sample
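Review note: a minimal stand-alone sketch of the dynamic thresholding step above. The function name `dynamic_threshold` and the default values for `ratio` and `max_value` are illustrative assumptions; in the scheduler they come from `dynamic_thresholding_ratio` and `sample_max_value` in the config.

```python
import torch


def dynamic_threshold(sample, ratio=0.995, max_value=2.0):
    """Hypothetical stand-alone version of per-sample dynamic thresholding."""
    batch_size = sample.shape[0]
    # Flatten everything except the batch dimension and upcast for quantile().
    flat = sample.float().reshape(batch_size, -1)
    # s is a high percentile of the absolute values, one per batch element.
    s = torch.quantile(flat.abs(), ratio, dim=1)
    # Clamping s to at least 1 means values already in [-1, 1] are left alone.
    s = torch.clamp(s, min=1.0, max=max_value).unsqueeze(1)
    # Clip to [-s, s], then rescale back into [-1, 1].
    flat = torch.clamp(flat, -s, s) / s
    return flat.reshape(sample.shape).to(sample.dtype)


x = torch.tensor([0.5, -3.0, 1.5, 0.1]).reshape(1, 1, 2, 2)
out = dynamic_threshold(x)
```

In this toy input the 99.5th percentile is above `max_value`, so `s` is capped at 2.0 and the outlier `-3.0` is mapped to `-1.0`.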
251
+
252
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
253
+ def _sigma_to_t(self, sigma):
254
+ return sigma * self.config.num_train_timesteps
255
+
256
+ def _sigma_to_alpha_sigma_t(self, sigma):
257
+ return 1 - sigma, sigma
258
+
259
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
260
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
261
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
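Review note: with `sigma = 1.0`, as `set_timesteps` calls it, `time_shift` reduces algebraically to the static shift formula with `shift = exp(mu)`. The sketch below checks this numerically; `static_shift` is a hypothetical helper name and the value `shift = 3.0` is arbitrary.

```python
import math

import numpy as np


def time_shift(mu, sigma, t):
    # Same expression as the scheduler method, written as a free function.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def static_shift(shift, t):
    # Same expression as the non-dynamic branch of set_timesteps.
    return shift * t / (1 + (shift - 1) * t)


# Grid on (0, 1]; t = 0 is excluded because time_shift divides by t.
t = np.linspace(1.0, 0.0, 9)[:-1]
shift = 3.0
dynamic = time_shift(math.log(shift), 1.0, t)
static = static_shift(shift, t)
```

So dynamic shifting can be read as choosing the shift per call through `mu` instead of fixing it in the config.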
262
+
263
+ def convert_model_output(
264
+ self,
265
+ model_output: torch.Tensor,
266
+ *args,
267
+ sample: Optional[torch.Tensor] = None,
268
+ **kwargs,
269
+ ) -> torch.Tensor:
270
+ r"""
271
+ Convert the model output to the corresponding type the UniPC algorithm needs.
272
+
273
+ Args:
274
+ model_output (`torch.Tensor`):
275
+ The direct output from the learned diffusion model.
276
+ timestep (`int`):
277
+ The current discrete timestep in the diffusion chain.
278
+ sample (`torch.Tensor`):
279
+ A current instance of a sample created by the diffusion process.
280
+
281
+ Returns:
282
+ `torch.Tensor`:
283
+ The converted model output.
284
+ """
285
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
286
+ if sample is None:
287
+ if len(args) > 1:
288
+ sample = args[1]
289
+ else:
290
+ raise ValueError("missing `sample` as a required keyword argument")
291
+ if timestep is not None:
292
+ deprecate(
293
+ "timesteps",
294
+ "1.0.0",
295
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
296
+ )
297
+
298
+ sigma = self.sigmas[self.step_index]
299
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
300
+
301
+ if self.predict_x0:
302
+ if self.config.prediction_type == "flow_prediction":
303
+ sigma_t = self.sigmas[self.step_index]
304
+ x0_pred = sample - sigma_t * model_output
305
+ else:
306
+ raise ValueError(
307
+ f"prediction_type given as {self.config.prediction_type} must be `flow_prediction`"
308
+ " for this scheduler."
309
+ )
310
+
311
+ if self.config.thresholding:
312
+ x0_pred = self._threshold_sample(x0_pred)
313
+
314
+ return x0_pred
315
+ else:
316
+ if self.config.prediction_type == "flow_prediction":
317
+ sigma_t = self.sigmas[self.step_index]
318
+ epsilon = sample + (1 - sigma_t) * model_output # x_t = (1 - sigma) * x0 + sigma * eps and v = eps - x0, so eps = x_t + (1 - sigma) * v
319
+ else:
320
+ raise ValueError(
321
+ f"prediction_type given as {self.config.prediction_type} must be `flow_prediction`"
322
+ " for this scheduler."
323
+ )
324
+
325
+ if self.config.thresholding:
326
+ sigma_t = self.sigmas[self.step_index]
327
+ x0_pred = sample - sigma_t * model_output
328
+ x0_pred = self._threshold_sample(x0_pred)
329
+ epsilon = model_output + x0_pred
330
+
331
+ return epsilon
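Review note: the `flow_prediction` conversion can be checked on synthetic data. The sketch assumes the interpolation `x_t = (1 - sigma) * x0 + sigma * eps` implied by `_sigma_to_alpha_sigma_t`, and assumes the model predicts the velocity `v = eps - x0`; the second assumption is inferred from the `x0_pred` formula and is not stated in the source.

```python
import torch

torch.manual_seed(0)

# Synthetic clean latent and noise, shaped (batch, channels, frames, height, width).
x0 = torch.randn(2, 4, 3, 8, 8)
eps = torch.randn_like(x0)
sigma = torch.tensor(0.7)

# Forward interpolation: alpha_t = 1 - sigma, sigma_t = sigma.
x_t = (1 - sigma) * x0 + sigma * eps
# Assumed flow-prediction target (velocity from data towards noise).
v = eps - x0

# Same formula as the predict_x0 branch of convert_model_output.
x0_pred = x_t - sigma * v
# Same identity as the thresholding branch: eps = v + x0.
eps_pred = v + x0_pred
```

With a perfect velocity prediction both quantities are recovered exactly, up to floating-point error.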
332
+
333
+ def multistep_uni_p_bh_update(
334
+ self,
335
+ model_output: torch.Tensor,
336
+ *args,
337
+ sample: Optional[torch.Tensor] = None,
338
+ order: Optional[int] = None,
339
+ **kwargs,
340
+ ) -> torch.Tensor:
341
+ """
342
+ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
343
+
344
+ Args:
345
+ model_output (`torch.Tensor`):
346
+ The direct output from the learned diffusion model at the current timestep.
347
+ prev_timestep (`int`):
348
+ The previous discrete timestep in the diffusion chain.
349
+ sample (`torch.Tensor`):
350
+ A current instance of a sample created by the diffusion process.
351
+ order (`int`):
352
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
353
+
354
+ Returns:
355
+ `torch.Tensor`:
356
+ The sample tensor at the previous timestep.
357
+ """
358
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
359
+ if sample is None:
360
+ if len(args) > 1:
361
+ sample = args[1]
362
+ else:
363
+ raise ValueError("missing `sample` as a required keyword argument")
364
+ if order is None:
365
+ if len(args) > 2:
366
+ order = args[2]
367
+ else:
368
+ raise ValueError("missing `order` as a required keyword argument")
369
+ if prev_timestep is not None:
370
+ deprecate(
371
+ "prev_timestep",
372
+ "1.0.0",
373
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
374
+ )
375
+ model_output_list = self.model_outputs
376
+
377
+ s0 = self.timestep_list[-1]
378
+ m0 = model_output_list[-1]
379
+ x = sample
380
+
381
+ if self.solver_p:
382
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
383
+ return x_t
384
+
385
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore
386
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
387
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
388
+
389
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
390
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
391
+
392
+ h = lambda_t - lambda_s0
393
+ device = sample.device
394
+
395
+ rks = []
396
+ D1s = []
397
+ for i in range(1, order):
398
+ si = self.step_index - i # pyright: ignore
399
+ mi = model_output_list[-(i + 1)]
400
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
401
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
402
+ rk = (lambda_si - lambda_s0) / h
403
+ rks.append(rk)
404
+ D1s.append((mi - m0) / rk) # pyright: ignore
405
+
406
+ rks.append(1.0)
407
+ rks = torch.tensor(rks, device=device)
408
+
409
+ R = []
410
+ b = []
411
+
412
+ hh = -h if self.predict_x0 else h
413
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
414
+ h_phi_k = h_phi_1 / hh - 1
415
+
416
+ factorial_i = 1
417
+
418
+ if self.config.solver_type == "bh1":
419
+ B_h = hh
420
+ elif self.config.solver_type == "bh2":
421
+ B_h = torch.expm1(hh)
422
+ else:
423
+ raise NotImplementedError()
424
+
425
+ for i in range(1, order + 1):
426
+ R.append(torch.pow(rks, i - 1))
427
+ b.append(h_phi_k * factorial_i / B_h)
428
+ factorial_i *= i + 1
429
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
430
+
431
+ R = torch.stack(R)
432
+ b = torch.tensor(b, device=device)
433
+
434
+ if len(D1s) > 0:
435
+ D1s = torch.stack(D1s, dim=1) # (B, K)
436
+ # for order 2, we use a simplified version
437
+ if order == 2:
438
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
439
+ else:
440
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
441
+ else:
442
+ D1s = None
443
+
444
+ if self.predict_x0:
445
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
446
+ if D1s is not None:
447
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
448
+ else:
449
+ pred_res = 0
450
+ x_t = x_t_ - alpha_t * B_h * pred_res
451
+ else:
452
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
453
+ if D1s is not None:
454
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
455
+ else:
456
+ pred_res = 0
457
+ x_t = x_t_ - sigma_t * B_h * pred_res
458
+
459
+ x_t = x_t.to(x.dtype)
460
+ return x_t
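Review note: for `order == 1` in x0-prediction mode, the predictor above reduces to a single exponential-integrator step in the half-log-SNR variable `lambda = log(alpha) - log(sigma)`. The sketch below isolates that case; the helper names `half_log_snr` and `first_order_unip_x0` are hypothetical, and the sigma values are arbitrary points strictly inside (0, 1).

```python
import torch


def half_log_snr(sigma):
    # lambda = log(alpha) - log(sigma), with alpha = 1 - sigma.
    return torch.log(1 - sigma) - torch.log(sigma)


def first_order_unip_x0(x, x0_estimate, sigma_s0, sigma_t):
    """First-order UniP step (no D1s correction terms), x0-prediction mode."""
    alpha_t = 1 - sigma_t
    h = half_log_snr(sigma_t) - half_log_snr(sigma_s0)
    h_phi_1 = torch.expm1(-h)  # hh = -h when predict_x0 is True
    return sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * x0_estimate


torch.manual_seed(0)
x0 = torch.randn(1, 4, 2, 8, 8)
eps = torch.randn_like(x0)
sigma_s0, sigma_t = torch.tensor(0.9), torch.tensor(0.8)

# Noisy sample at the current level, then one step to the next level
# using the exact x0 as the model estimate.
x_s0 = (1 - sigma_s0) * x0 + sigma_s0 * eps
x_t = first_order_unip_x0(x_s0, x0, sigma_s0, sigma_t)
```

When the x0 estimate is exact, the step lands exactly on `(1 - sigma_t) * x0 + sigma_t * eps`; the higher-order `D1s` terms only matter when the estimate changes between steps.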
461
+
462
+ def multistep_uni_c_bh_update(
463
+ self,
464
+ this_model_output: torch.Tensor,
465
+ *args,
466
+ last_sample: Optional[torch.Tensor] = None,
467
+ this_sample: Optional[torch.Tensor] = None,
468
+ order: Optional[int] = None,
469
+ **kwargs,
470
+ ) -> torch.Tensor:
471
+ """
472
+ One step for the UniC (B(h) version).
473
+
474
+ Args:
475
+ this_model_output (`torch.Tensor`):
476
+ The model outputs at `x_t`.
477
+ this_timestep (`int`):
478
+ The current timestep `t`.
479
+ last_sample (`torch.Tensor`):
480
+ The generated sample before the last predictor `x_{t-1}`.
481
+ this_sample (`torch.Tensor`):
482
+ The generated sample after the last predictor `x_{t}`.
483
+ order (`int`):
484
+ The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
485
+
486
+ Returns:
487
+ `torch.Tensor`:
488
+ The corrected sample tensor at the current timestep.
489
+ """
490
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
491
+ if last_sample is None:
492
+ if len(args) > 1:
493
+ last_sample = args[1]
494
+ else:
495
+ raise ValueError("missing `last_sample` as a required keyword argument")
496
+ if this_sample is None:
497
+ if len(args) > 2:
498
+ this_sample = args[2]
499
+ else:
500
+ raise ValueError("missing `this_sample` as a required keyword argument")
501
+ if order is None:
502
+ if len(args) > 3:
503
+ order = args[3]
504
+ else:
505
+ raise ValueError("missing `order` as a required keyword argument")
506
+ if this_timestep is not None:
507
+ deprecate(
508
+ "this_timestep",
509
+ "1.0.0",
510
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
511
+ )
512
+
513
+ model_output_list = self.model_outputs
514
+
515
+ m0 = model_output_list[-1]
516
+ x = last_sample
517
+ x_t = this_sample
518
+ model_t = this_model_output
519
+
520
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore
521
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
522
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
523
+
524
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
525
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
526
+
527
+ h = lambda_t - lambda_s0
528
+ device = this_sample.device
529
+
530
+ rks = []
531
+ D1s = []
532
+ for i in range(1, order):
533
+ si = self.step_index - (i + 1) # pyright: ignore
534
+ mi = model_output_list[-(i + 1)]
535
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
536
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
537
+ rk = (lambda_si - lambda_s0) / h
538
+ rks.append(rk)
539
+ D1s.append((mi - m0) / rk) # pyright: ignore
540
+
541
+ rks.append(1.0)
542
+ rks = torch.tensor(rks, device=device)
543
+
544
+ R = []
545
+ b = []
546
+
547
+ hh = -h if self.predict_x0 else h
548
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
549
+ h_phi_k = h_phi_1 / hh - 1
550
+
551
+ factorial_i = 1
552
+
553
+ if self.config.solver_type == "bh1":
554
+ B_h = hh
555
+ elif self.config.solver_type == "bh2":
556
+ B_h = torch.expm1(hh)
557
+ else:
558
+ raise NotImplementedError()
559
+
560
+ for i in range(1, order + 1):
561
+ R.append(torch.pow(rks, i - 1))
562
+ b.append(h_phi_k * factorial_i / B_h)
563
+ factorial_i *= i + 1
564
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
565
+
566
+ R = torch.stack(R)
567
+ b = torch.tensor(b, device=device)
568
+
569
+ if len(D1s) > 0:
570
+ D1s = torch.stack(D1s, dim=1)
571
+ else:
572
+ D1s = None
573
+
574
+ # for order 1, we use a simplified version
575
+ if order == 1:
576
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
577
+ else:
578
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
579
+
580
+ if self.predict_x0:
581
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
582
+ if D1s is not None:
583
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
584
+ else:
585
+ corr_res = 0
586
+ D1_t = model_t - m0
587
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
588
+ else:
589
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
590
+ if D1s is not None:
591
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
592
+ else:
593
+ corr_res = 0
594
+ D1_t = model_t - m0
595
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
596
+ x_t = x_t.to(x.dtype)
597
+ return x_t
598
+
599
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
600
+ if schedule_timesteps is None:
601
+ schedule_timesteps = self.timesteps
602
+
603
+ indices = (schedule_timesteps == timestep).nonzero()
604
+
605
+ # The sigma index that is taken for the **very** first `step`
606
+ # is always the second index (or the last index if there is only 1)
607
+ # This way we can ensure we don't accidentally skip a sigma in
608
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
609
+ pos = 1 if len(indices) > 1 else 0
610
+
611
+ return indices[pos].item()
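Review note: the effect of the `pos` rule in `index_for_timestep` is easiest to see on a schedule with a duplicated timestep. The schedule values below are made up for illustration, and the function is rewritten as a free function that takes the schedule explicitly.

```python
import torch


def index_for_timestep(timestep, schedule_timesteps):
    # All positions in the schedule that match the requested timestep.
    indices = (schedule_timesteps == timestep).nonzero()
    # With duplicates, take the second match; otherwise the only match.
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()


# Hypothetical schedule in which the first timestep appears twice.
schedule = torch.tensor([999, 999, 750, 500, 250])

first = index_for_timestep(999, schedule)
middle = index_for_timestep(500, schedule)
```

Here `first` resolves to index 1 rather than 0, while the unique timestep 500 resolves to its only position, 3.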
612
+
613
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
614
+ def _init_step_index(self, timestep):
615
+ """
616
+ Initialize the step_index counter for the scheduler.
617
+ """
618
+
619
+ if self.begin_index is None:
620
+ if isinstance(timestep, torch.Tensor):
621
+ timestep = timestep.to(self.timesteps.device)
622
+ self._step_index = self.index_for_timestep(timestep)
623
+ else:
624
+ self._step_index = self._begin_index
625
+
626
+ def step(
627
+ self,
628
+ model_output: torch.Tensor,
629
+ timestep: Union[int, torch.Tensor],
630
+ sample: torch.Tensor,
631
+ return_dict: bool = True,
632
+ generator=None,
633
+ ) -> Union[SchedulerOutput, Tuple]:
634
+ """
635
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
636
+ the multistep UniPC.
637
+
638
+ Args:
639
+ model_output (`torch.Tensor`):
640
+ The direct output from the learned diffusion model.
641
+ timestep (`int`):
642
+ The current discrete timestep in the diffusion chain.
643
+ sample (`torch.Tensor`):
644
+ A current instance of a sample created by the diffusion process.
645
+ return_dict (`bool`):
646
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
647
+
648
+ Returns:
649
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
650
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
651
+ tuple is returned where the first element is the sample tensor.
652
+
653
+ """
654
+ if self.num_inference_steps is None:
655
+ raise ValueError(
656
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
657
+ )
658
+
659
+ if self.step_index is None:
660
+ self._init_step_index(timestep)
661
+
662
+ use_corrector = (
663
+ self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None # pyright: ignore
664
+ )
665
+
666
+ model_output_convert = self.convert_model_output(model_output, sample=sample)
667
+ if use_corrector:
668
+ sample = self.multistep_uni_c_bh_update(
669
+ this_model_output=model_output_convert,
670
+ last_sample=self.last_sample,
671
+ this_sample=sample,
672
+ order=self.this_order,
673
+ )
674
+
675
+ for i in range(self.config.solver_order - 1):
676
+ self.model_outputs[i] = self.model_outputs[i + 1]
677
+ self.timestep_list[i] = self.timestep_list[i + 1]
678
+
679
+ self.model_outputs[-1] = model_output_convert
680
+ self.timestep_list[-1] = timestep # pyright: ignore
681
+
682
+ if self.config.lower_order_final:
683
+ this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore
684
+ else:
685
+ this_order = self.config.solver_order
686
+
687
+ self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
688
+ assert self.this_order > 0
689
+
690
+ self.last_sample = sample
691
+ prev_sample = self.multistep_uni_p_bh_update(
692
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
693
+ sample=sample,
694
+ order=self.this_order,
695
+ )
696
+
697
+ if self.lower_order_nums < self.config.solver_order:
698
+ self.lower_order_nums += 1
699
+
700
+ # upon completion increase step index by one
701
+ self._step_index += 1 # pyright: ignore
702
+
703
+ if not return_dict:
704
+ return (prev_sample,)
705
+
706
+ return SchedulerOutput(prev_sample=prev_sample)
707
+
708
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
709
+ """
710
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
711
+ current timestep.
712
+
713
+ Args:
714
+ sample (`torch.Tensor`):
715
+ The input sample.
716
+
717
+ Returns:
718
+ `torch.Tensor`:
719
+ A scaled input sample.
720
+ """
721
+ return sample
722
+
723
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
724
+ def add_noise(
725
+ self,
726
+ original_samples: torch.Tensor,
727
+ noise: torch.Tensor,
728
+ timesteps: torch.IntTensor,
729
+ ) -> torch.Tensor:
730
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
731
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
732
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
733
+ # mps does not support float64
734
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
735
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
736
+ else:
737
+ schedule_timesteps = self.timesteps.to(original_samples.device)
738
+ timesteps = timesteps.to(original_samples.device)
739
+
740
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
741
+ if self.begin_index is None:
742
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
743
+ elif self.step_index is not None:
744
+ # add_noise is called after first denoising step (for inpainting)
745
+ step_indices = [self.step_index] * timesteps.shape[0]
746
+ else:
747
+ # add_noise is called before the first denoising step to create the initial latent (img2img)
748
+ step_indices = [self.begin_index] * timesteps.shape[0]
749
+
750
+ sigma = sigmas[step_indices].flatten()
751
+ while len(sigma.shape) < len(original_samples.shape):
752
+ sigma = sigma.unsqueeze(-1)
753
+
754
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
755
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
756
+ return noisy_samples
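Review note: the last lines of `add_noise` broadcast one sigma per batch element across the remaining dimensions and then interpolate linearly between data and noise. A minimal sketch, with a free function in place of the method and toy tensors chosen so the two endpoints are easy to check:

```python
import torch


def add_noise(original_samples, noise, sigma):
    """Hypothetical free-function version; sigma has shape (batch,)."""
    # Append trailing singleton dimensions until sigma broadcasts over the sample.
    while sigma.dim() < original_samples.dim():
        sigma = sigma.unsqueeze(-1)
    # alpha_t = 1 - sigma, sigma_t = sigma (flow-matching interpolation).
    return (1 - sigma) * original_samples + sigma * noise


x0 = torch.ones(2, 4, 3, 8, 8)      # "clean" latents, all ones
noise = torch.zeros_like(x0)        # "noise", all zeros
sigma = torch.tensor([0.0, 1.0])    # first sample untouched, second fully noised

noisy = add_noise(x0, noise, sigma)
```

At `sigma = 0` the sample is returned unchanged and at `sigma = 1` it is replaced by the noise tensor.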
757
+
758
+ def __len__(self):
759
+ return self.config.num_train_timesteps
assets/architecture.png ADDED

Git LFS Details

  • SHA256: 33f619ed4c78c185e5e40fec1b774dee5573e3f8f3a405785ebe9552b2a02c33
  • Pointer size: 131 Bytes
  • Size of remote file: 260 kB
assets/banner.png ADDED

Git LFS Details

  • SHA256: c01efdcc5579a31fb8926717fbc8ef24c317c9e2d852b4c084132b60da1ca602
  • Pointer size: 132 Bytes
  • Size of remote file: 7.57 MB
assets/showcase_i2v.png ADDED

Git LFS Details

  • SHA256: 2b605ed63797b53532df7a0a52e1168c7847b3f2e6c5c4a4dfad0b901648c2f7
  • Pointer size: 133 Bytes
  • Size of remote file: 11.4 MB
assets/showcase_t2v.png ADDED

Git LFS Details

  • SHA256: c78f3dd58ad8562a082275a955bbbfbfc85cb7e48ce3c28ca55fb1fb625ef139
  • Pointer size: 132 Bytes
  • Size of remote file: 3.85 MB
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "do_convert_rgb": true,
3
+ "do_normalize": true,
4
+ "do_rescale": false,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "SiglipImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "resample": 3,
18
+ "rescale_factor": 0.00392156862745098,
19
+ "size": {
20
+ "height": 896,
21
+ "width": 896
22
+ }
23
+ }
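Review note: with `image_mean` and `image_std` both 0.5, the normalization step maps values in [0, 1] to [-1, 1]. The sketch below shows only that arithmetic in NumPy; it assumes the pixels are already floats in [0, 1] (the config sets `do_rescale` to false), and it does not reproduce resizing or any other behaviour of `SiglipImageProcessor`.

```python
import numpy as np

# Values copied from preprocessor_config.json.
IMAGE_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
IMAGE_STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)


def normalize(pixels):
    # pixels: (H, W, 3) float array, assumed to lie in [0, 1].
    return (pixels - IMAGE_MEAN) / IMAGE_STD


# Toy image whose three channels are constant 0.0, 0.5 and 1.0.
img = np.stack(
    [np.zeros((2, 2)), np.full((2, 2), 0.5), np.ones((2, 2))], axis=-1
).astype(np.float32)
out = normalize(img)
```

The three channels come out as -1, 0 and 1 respectively.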
inference.py ADDED
@@ -0,0 +1,119 @@
1
+ #!/usr/bin/env python3
2
+ """Motif-Video 2B — Text-to-Video & Image-to-Video inference.
3
+
4
+ GPU requirements: ~24GB VRAM for 720p (1280x736, 121 frames).
5
+ Tested with: torch>=2.0, diffusers>=0.35.2, transformers>=5.0.0
6
+
7
+ Uses Adaptive Projected Guidance (APG) by default for best quality.
8
+ """
9
+
10
+ import argparse
11
+
12
+ import torch
13
+ from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
14
+ from diffusers.utils import export_to_video
15
+
16
+
17
+ def parse_args():
18
+ parser = argparse.ArgumentParser(description="Motif-Video 2B Inference (T2V / I2V)")
19
+ parser.add_argument(
20
+ "--model-path",
21
+ type=str,
22
+ default="Motif-Technologies/Motif-Video-2B",
23
+ help="HuggingFace model ID or local checkpoint path (uses trust_remote_code=True)",
24
+ )
25
+ parser.add_argument(
26
+ "--prompt",
27
+ type=str,
28
+ default="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
29
+ help="Text prompt for video generation",
30
+ )
31
+ parser.add_argument(
32
+ "--image",
33
+ type=str,
34
+ default=None,
35
+ help="Path to input image for I2V mode (omit for T2V)",
36
+ )
37
+ parser.add_argument(
38
+ "--negative-prompt",
39
+ type=str,
40
+ default=None,
41
+ help="Negative prompt (default: built-in pipeline default)",
42
+ )
43
+ parser.add_argument("--output", type=str, default="output.mp4", help="Output video file path")
44
+ parser.add_argument("--num-frames", type=int, default=121, help="Number of frames to generate (121 = ~5s at 24fps)")
45
+ parser.add_argument("--height", type=int, default=736, help="Video height in pixels")
46
+ parser.add_argument("--width", type=int, default=1280, help="Video width in pixels")
47
+ parser.add_argument("--guidance-scale", type=float, default=8.0, help="Classifier-free guidance scale")
48
+ parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of denoising steps")
49
+ parser.add_argument("--fps", type=int, default=24, help="Output video frame rate")
50
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
51
+ parser.add_argument(
52
+ "--dtype",
53
+ type=str,
54
+ default="bfloat16",
55
+ choices=["float16", "bfloat16", "float32"],
56
+ help="Model dtype",
57
+ )
58
+ return parser.parse_args()
59
+
60
+
61
+ def main():
62
+ args = parse_args()
63
+
64
+ dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
65
+ torch_dtype = dtype_map[args.dtype]
66
+
67
+ mode = "I2V" if args.image else "T2V"
68
+ print(f"[{mode}] Loading model from: {args.model_path}")
69
+
70
+ guider = AdaptiveProjectedGuidance(
71
+ guidance_scale=args.guidance_scale,
72
+ adaptive_projected_guidance_rescale=12.0,
73
+ adaptive_projected_guidance_momentum=0.1,
74
+ eta=0.0,
75
+ use_original_formulation=True,
76
+ )
77
+
78
+ pipe = DiffusionPipeline.from_pretrained(
79
+ args.model_path,
80
+ custom_pipeline="pipeline_motif_video",
81
+ trust_remote_code=True,
82
+ torch_dtype=torch_dtype,
83
+ guider=guider,
84
+ )
85
+ pipe = pipe.to("cuda")
86
+
87
+ generator = torch.Generator(device="cuda").manual_seed(args.seed)
88
+
89
+ # Load image for I2V mode
90
+ image = None
91
+ if args.image:
92
+ from PIL import Image
93
+
94
+ image = Image.open(args.image).convert("RGB")
95
+ print(f"[I2V] Input image: {args.image} ({image.size[0]}x{image.size[1]})")
96
+
97
+ print(f"Generating video: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps")
98
+ pipe_kwargs = dict(
99
+ prompt=args.prompt,
100
+ image=image,
101
+ height=args.height,
102
+ width=args.width,
103
+ num_frames=args.num_frames,
104
+ num_inference_steps=args.num_inference_steps,
105
+ generator=generator,
106
+ frame_rate=args.fps,
107
+ )
108
+ if args.negative_prompt is not None:
109
+ pipe_kwargs["negative_prompt"] = args.negative_prompt
110
+
111
+ output = pipe(**pipe_kwargs)
112
+
113
+ video_frames = output.frames[0]
114
+ export_to_video(video_frames, args.output, fps=args.fps)
115
+ print(f"Video saved to: {args.output}")
116
+
117
+
118
+ if __name__ == "__main__":
119
+ main()
model_index.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "_class_name": "MotifVideoPipeline",
3
+ "_diffusers_version": "0.35.2",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "FlowMatchEulerDiscreteScheduler"
7
+ ],
8
+ "text_encoder": [
9
+ "transformers",
10
+ "T5Gemma2Model"
11
+ ],
12
+ "tokenizer": [
13
+ "transformers",
14
+ "GemmaTokenizer"
15
+ ],
16
+ "transformer": [
17
+ "transformer_motif_video",
18
+ "MotifVideoTransformer3DModel"
19
+ ],
20
+ "vae": [
21
+ "diffusers",
22
+ "AutoencoderKLWan"
23
+ ],
24
+ "feature_extractor": [
25
+ "transformers",
26
+ "SiglipImageProcessor"
27
+ ]
28
+ }
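Review note: `model_index.json` maps each pipeline component to a `[library, class]` pair. The sketch below parses the file content shown above and lists components whose library is neither `diffusers` nor `transformers`. Treating those as modules shipped inside this repository (which is why `inference.py` passes `trust_remote_code=True`) is an assumption of this sketch, not something the file states.

```python
import json

# Content of model_index.json as added in this commit.
MODEL_INDEX = """
{
  "_class_name": "MotifVideoPipeline",
  "_diffusers_version": "0.35.2",
  "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
  "text_encoder": ["transformers", "T5Gemma2Model"],
  "tokenizer": ["transformers", "GemmaTokenizer"],
  "transformer": ["transformer_motif_video", "MotifVideoTransformer3DModel"],
  "vae": ["diffusers", "AutoencoderKLWan"],
  "feature_extractor": ["transformers", "SiglipImageProcessor"]
}
"""

index = json.loads(MODEL_INDEX)

# Keys starting with "_" are metadata; everything else is a component.
components = {
    name: tuple(spec) for name, spec in index.items() if not name.startswith("_")
}

# Components whose code does not come from an installed library (assumption).
repo_local = sorted(
    name
    for name, (library, _cls) in components.items()
    if library not in ("diffusers", "transformers")
)
```

Only the `transformer` entry points at a non-library module, `transformer_motif_video`.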
motif-video-technical-report.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f931b222d303bee5b62053b9734a021bb9310388cf575fbab83ffdcceca378ab
3
+ size 17260596
pipeline_motif_video.py ADDED
@@ -0,0 +1,1321 @@
1
+ # Copyright 2026 Motif Technologies, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import ftfy
21
+ import numpy as np
22
+ import regex as re
23
+ import torch
24
+ from diffusers import (
25
+ AdaptiveProjectedGuidance,
26
+ AutoencoderKLWan,
27
+ ClassifierFreeGuidance,
28
+ DiffusionPipeline,
29
+ DPMSolverMultistepScheduler,
30
+ FlowMatchEulerDiscreteScheduler,
31
+ SkipLayerGuidance,
32
+ UniPCMultistepScheduler,
33
+ )
34
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
35
+ from diffusers.utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
36
+ from diffusers.utils.torch_utils import randn_tensor
37
+ from diffusers.video_processor import VideoProcessor
38
+ from einops import rearrange
39
+ from PIL import Image
40
+ from torch import Tensor
41
+
42
+ from diffusers.guiders.adaptive_projected_guidance import MomentumBuffer
43
+ from diffusers.guiders.guider_utils import GuiderOutput
44
+ from ._fm_solvers_unipc import FlowUniPCMultistepScheduler
45
+ from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Model
46
+
47
+
48
+ if is_torch_xla_available():
49
+ import torch_xla.core.xla_model as xm
50
+
51
+ XLA_AVAILABLE = True
52
+ else:
53
+ XLA_AVAILABLE = False
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import torch
61
+ >>> from diffusers import DiffusionPipeline
62
+ >>> from diffusers.utils import export_to_video
63
+
64
+ >>> # Load the Motif Video pipeline
65
+ >>> motif_video_model_id = "Motif-Technologies/Motif-Video-2B"
66
+ >>> pipe = DiffusionPipeline.from_pretrained(motif_video_model_id, custom_pipeline="pipeline_motif_video", trust_remote_code=True, torch_dtype=torch.bfloat16)
67
+ >>> pipe.to("cuda")
68
+
69
+ >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
70
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
71
+
72
+ >>> video = pipe(
73
+ ... prompt=prompt,
74
+ ... negative_prompt=negative_prompt,
75
+ ... width=640,
76
+ ... height=352,
77
+ ... num_frames=65,
78
+ ... num_inference_steps=50,
79
+ ... ).frames[0]
80
+ >>> export_to_video(video, "output.mp4", fps=16)
81
+ ```
82
+ """
83
+
84
+
85
+ @dataclass
86
+ class MotifVideoPipelineOutput(BaseOutput):
87
+ r"""
88
+ Output class for Motif Video pipelines.
89
+
90
+ Args:
91
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
92
+ List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
93
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
94
+ `(batch_size, num_frames, channels, height, width)`.
95
+ """
96
+
97
+ frames: torch.Tensor
98
+
99
+
100
+ """Video-aware Adaptive Projected Guidance (APG).
101
+
102
+ Standard APG normalizes over all non-batch dimensions [C, T, H, W], which collapses
103
+ temporal variation. This module normalizes over [C, H, W] only, preserving
104
+ per-frame independence.
105
+ """
106
+
107
+
108
+ def video_normalized_guidance(
109
+ pred_cond: torch.Tensor,
110
+ pred_uncond: torch.Tensor,
111
+ guidance_scale: float,
112
+ momentum_buffer: MomentumBuffer | None = None,
113
+ eta: float = 1.0,
114
+ norm_threshold: float = 0.0,
115
+ use_original_formulation: bool = False,
116
+ ) -> torch.Tensor:
117
+ """APG with video-aware normalization: normalize over [C, H, W], exclude T.
118
+
119
+ For 5D input [B, C, T, H, W], dim=[-1, -2, -4] normalizes per-frame (W, H, C),
120
+ keeping the T dimension independent. For 4D input [B, C, H, W], falls back to
121
+ standard [-1, -2, -3] behavior.
122
+ """
123
+ diff = pred_cond - pred_uncond
124
+
125
+ if len(diff.shape) == 5:
126
+ # [B, C, T, H, W] → normalize over W(-1), H(-2), C(-4), skip T(-3)
127
+ dim = [-1, -2, -4]
128
+ else:
129
+ # [B, C, H, W] → standard behavior
130
+ dim = [-i for i in range(1, len(diff.shape))]
131
+
132
+ if momentum_buffer is not None:
133
+ momentum_buffer.update(diff)
134
+ diff = momentum_buffer.running_average
135
+
136
+ if norm_threshold > 0:
137
+ ones = torch.ones_like(diff)
138
+ diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
139
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
140
+ diff = diff * scale_factor
141
+
142
+ v0, v1 = diff.double(), pred_cond.double()
143
+ v1 = torch.nn.functional.normalize(v1, dim=dim)
144
+ v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
145
+ v0_orthogonal = v0 - v0_parallel
146
+ diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
147
+ normalized_update = diff_orthogonal + eta * diff_parallel
148
+
149
+ pred = pred_cond if use_original_formulation else pred_uncond
150
+ pred = pred + guidance_scale * normalized_update
151
+
152
+ return pred
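The per-frame projection used by `video_normalized_guidance` can be illustrated with a small NumPy sketch. This is not the pipeline's code: it assumes a 5D `[B, C, T, H, W]` layout, uses random stand-in predictions, and omits the small epsilon that `torch.nn.functional.normalize` applies. It only shows that reducing over the C, H, W axes yields one parallel/orthogonal split per frame.

```python
import numpy as np

rng = np.random.default_rng(0)
pred_cond = rng.standard_normal((1, 4, 3, 2, 2))    # [B, C, T, H, W], stand-in values
pred_uncond = rng.standard_normal((1, 4, 3, 2, 2))
diff = pred_cond - pred_uncond

axes = (1, 3, 4)  # reduce over C, H, W; the T axis (2) is left untouched

# Unit vector along pred_cond, computed separately for every frame.
norm = np.sqrt((pred_cond ** 2).sum(axis=axes, keepdims=True))
unit = pred_cond / norm

# Split the guidance direction into components parallel and orthogonal to pred_cond.
parallel = (diff * unit).sum(axis=axes, keepdims=True) * unit
orthogonal = diff - parallel

# One dot product per (batch, frame): each should be numerically zero.
per_frame_dot = (orthogonal * unit).sum(axis=axes)
```

With `eta = 1.0` the update `orthogonal + eta * parallel` equals `diff`, so the projection only changes the result when `eta` differs from 1.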
153
+
154
+
155
+ class VideoAdaptiveProjectedGuidance(AdaptiveProjectedGuidance):
156
+ """APG variant that normalizes over [C, H, W] per frame, excluding the T dimension."""
157
+
158
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
159
+ pred = None
160
+
161
+ if not self._is_apg_enabled():
162
+ pred = pred_cond
163
+ else:
164
+ pred = video_normalized_guidance(
165
+ pred_cond,
166
+ pred_uncond,
167
+ self.guidance_scale,
168
+ self.momentum_buffer,
169
+ self.eta,
170
+ self.adaptive_projected_guidance_rescale,
171
+ self.use_original_formulation,
172
+ )
173
+
174
+ if self.guidance_rescale > 0.0:
175
+ from diffusers.guiders.classifier_free_guidance import rescale_noise_cfg
176
+
177
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
178
+
179
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
180
+
181
+
182
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
183
+ def calculate_shift(
184
+ image_seq_len,
185
+ base_seq_len: int = 256,
186
+ max_seq_len: int = 4096,
187
+ base_shift: float = 0.5,
188
+ max_shift: float = 1.15,
189
+ ):
190
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
191
+ b = base_shift - m * base_seq_len
192
+ mu = image_seq_len * m + b
193
+ return mu
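As a quick check of the arithmetic, `calculate_shift` is a straight line through `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`. The sketch below restates the function with the same defaults and evaluates it at both endpoints and at the midpoint; the sequence lengths are illustrative values, not ones taken from the pipeline.

```python
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    # Slope and intercept of the line through the two anchor points.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


mu_base = calculate_shift(256)    # approximately 0.5
mu_max = calculate_shift(4096)    # approximately 1.15
mu_mid = calculate_shift(2176)    # midpoint of 256 and 4096, approximately 0.825
```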
194
+
195
+
196
+ def get_linear_quadratic_sigmas(
197
+ num_inference_steps: int,
198
+ linear_quadratic_emulating_steps: int = 250,
199
+ ) -> np.ndarray:
200
+ """
201
+ Compute a linear-quadratic sigma schedule for flow matching.
202
+
203
+ This schedule combines:
204
+ - First half: linear decrease from sigma=1.0 in steps of 1/linear_quadratic_emulating_steps (slow denoising)
205
+ - Second half: quadratic decrease from the end of the linear part down to 0 (faster denoising)
206
+
207
+ Convention:
208
+ - sigma=1.0 represents pure noise
209
+ - sigma=0.0 represents a clean sample
210
+ - Output sigmas are in descending order (1.0 → ~0)
211
+
212
+ Args:
213
+ num_inference_steps: Total number of denoising steps (must be even).
214
+ linear_quadratic_emulating_steps: Controls the slope of linear interpolation.
215
+ Higher values result in gentler slope in the first half.
216
+
217
+ Returns:
218
+ np.ndarray: Array of sigma values with shape (num_inference_steps,).
219
+ The scheduler will append a terminal 0.
220
+
221
+ Raises:
222
+ ValueError: If num_inference_steps is not even.
223
+
224
+ Reference:
225
+ Linear-quadratic timestep schedule for improved flow matching inference.
226
+ """
227
+ if num_inference_steps % 2 != 0:
228
+ raise ValueError(
229
+ f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}"
230
+ )
231
+
232
+ steps = num_inference_steps
233
+ N = linear_quadratic_emulating_steps
234
+ half_steps = steps // 2
235
+
236
+ # First half: linear interpolation from 1 toward 0
237
+ # Takes first half_steps values from linspace(1, 0, N+1)
238
+ linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps]
239
+
240
+ # Second half: quadratic interpolation
241
+ # Formula: x^2 * (half_steps/N - 1) - (half_steps/N - 1)
242
+ # = (half_steps/N - 1) * (x^2 - 1)
243
+ # This maps x=0 to (half_steps/N - 1) * (-1) = 1 - half_steps/N
244
+ # and maps x=1 to 0
245
+ x = np.linspace(0.0, 1.0, half_steps + 1)
246
+ scale_factor = half_steps / N - 1 # negative value
247
+ quadratic_part = x**2 * scale_factor - scale_factor
248
+
249
+ # Concatenate and exclude the last 0 (scheduler appends terminal 0)
250
+ sigmas = np.concatenate([linear_part, quadratic_part])
251
+ sigmas = sigmas[:-1] # Remove trailing 0, scheduler will append it
252
+
253
+ return sigmas.astype(np.float32)
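A small worked example makes the schedule concrete. The sketch below re-derives the same values in plain Python rather than calling the function above; the name `linear_quadratic_sigmas` is hypothetical, and it assumes `num_steps // 2` does not exceed `emulating_steps`.

```python
def linear_quadratic_sigmas(num_steps, emulating_steps=250):
    if num_steps % 2 != 0:
        raise ValueError("num_steps must be even")
    half = num_steps // 2
    n = emulating_steps
    # Linear part: the first `half` points of a uniform grid with step 1/n, starting at 1.0.
    linear = [1.0 - i / n for i in range(half)]
    # Quadratic part: (1 - half/n) * (1 - x^2) for x on a uniform grid over [0, 1].
    start = 1.0 - half / n
    quadratic = [start * (1.0 - (i / half) ** 2) for i in range(half + 1)]
    # Drop the trailing 0.0, since the scheduler is expected to append it.
    return (linear + quadratic)[:-1]


# 4 steps with a coarse grid of 10: approximately [1.0, 0.9, 0.8, 0.6]
sigmas = linear_quadratic_sigmas(4, emulating_steps=10)
```

The linear part ends at `1 - (half - 1)/n` and the quadratic part starts at `1 - half/n`, so the two halves join with the same step size.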
254
+
255
+
256
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps (adds the linear-quadratic schedule options)
257
+ def retrieve_timesteps(
258
+ scheduler,
259
+ num_inference_steps: Optional[int] = None,
260
+ device: Optional[Union[str, torch.device]] = None,
261
+ timesteps: Optional[List[int]] = None,
262
+ sigmas: Optional[List[float]] = None,
263
+ use_linear_quadratic_schedule: bool = False,
264
+ linear_quadratic_emulating_steps: int = 250,
265
+ **kwargs,
266
+ ):
267
+ """
268
+ Retrieve timesteps from the scheduler.
269
+
270
+ Args:
271
+ scheduler: The noise scheduler to use.
272
+ num_inference_steps: Number of denoising steps.
273
+ device: Device to place timesteps on.
274
+ timesteps: Custom timestep values (mutually exclusive with sigmas).
275
+ sigmas: Custom sigma values (mutually exclusive with timesteps).
276
+ use_linear_quadratic_schedule: If True, use linear-quadratic sigma schedule.
277
+ This overrides the default linear schedule. Requires num_inference_steps
278
+ to be even.
279
+ linear_quadratic_emulating_steps: Controls the linear portion slope.
280
+ Higher values result in gentler slope in the first half. Default: 250.
281
+ **kwargs: Additional arguments passed to scheduler.set_timesteps().
282
+
283
+ Returns:
284
+ Tuple of (timesteps, num_inference_steps).
285
+
286
+ Raises:
287
+ ValueError: If both timesteps and sigmas are provided, or if
288
+ use_linear_quadratic_schedule is True but num_inference_steps is odd.
289
+ """
290
+ if timesteps is not None and sigmas is not None:
291
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
292
+
293
+ # Handle linear-quadratic schedule: compute sigmas if flag is set
294
+ if use_linear_quadratic_schedule:
295
+ if sigmas is not None:
296
+ raise ValueError(
297
+ "Cannot use both `sigmas` and `use_linear_quadratic_schedule`. "
298
+ "The linear-quadratic schedule computes sigmas automatically."
299
+ )
300
+ if num_inference_steps is None:
301
+ raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.")
302
+ sigmas = get_linear_quadratic_sigmas(
303
+ num_inference_steps=num_inference_steps,
304
+ linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
305
+ )
306
+
307
+ if timesteps is not None:
308
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
309
+ if not accepts_timesteps:
310
+ raise ValueError(
311
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
312
+ f" timestep schedules. Please check whether you are using the correct scheduler."
313
+ )
314
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
315
+ timesteps = scheduler.timesteps
316
+ num_inference_steps = len(timesteps)
317
+ elif sigmas is not None:
318
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
319
+ if not accept_sigmas:
320
+ raise ValueError(
321
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
322
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
323
+ )
324
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
325
+ timesteps = scheduler.timesteps
326
+ num_inference_steps = len(timesteps)
327
+ else:
328
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
329
+ timesteps = scheduler.timesteps
330
+ return timesteps, num_inference_steps
331
+
332
+
333
+ def basic_clean(text):
334
+ text = ftfy.fix_text(text)
335
+ text = html.unescape(html.unescape(text))
336
+ return text.strip()
337
+
338
+
339
+ def whitespace_clean(text):
340
+ text = re.sub(r"\s+", " ", text)
341
+ text = text.strip()
342
+ return text
343
+
344
+
345
+ def prompt_clean(text):
346
+ text = whitespace_clean(basic_clean(text))
347
+ return text
348
+
349
+
350
+ class MotifVideoPipeline(DiffusionPipeline):
351
+ r"""
352
+ Pipeline for text-to-video and image-to-video generation using `MotifVideoTransformer3DModel`.
353
+
354
+ Args:
355
+ transformer ([`MotifVideoTransformer3DModel`]):
356
+ Conditional Transformer architecture to denoise the encoded video latents.
357
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
358
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
359
+ vae ([`AutoencoderKLWan`]):
360
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
361
+ text_encoder ([`T5Gemma2Model`]):
362
+ Primary text encoder for encoding text prompts into embeddings.
363
+ tokenizer ([`PreTrainedTokenizerBase`]):
364
+ Tokenizer corresponding to the primary text encoder.
365
+ guider ([`ClassifierFreeGuidance`] or [`SkipLayerGuidance`] or [`AdaptiveProjectedGuidance`] or [`VideoAdaptiveProjectedGuidance`], *optional*):
366
+ The guidance method to use. If `None`, it defaults to `ClassifierFreeGuidance()`.
367
+ """
368
+
369
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
370
+ _optional_components = ["feature_extractor"]
371
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
372
+
373
+ def __init__(
374
+ self,
375
+ scheduler: Union[
376
+ FlowMatchEulerDiscreteScheduler,
377
+ DPMSolverMultistepScheduler,
378
+ UniPCMultistepScheduler,
379
+ FlowUniPCMultistepScheduler,
380
+ ],
381
+ vae: AutoencoderKLWan,
382
+ text_encoder: T5Gemma2Model,
383
+ tokenizer: PreTrainedTokenizerBase,
384
+ transformer,
385
+ guider: Optional[
386
+ Union[ClassifierFreeGuidance, SkipLayerGuidance, AdaptiveProjectedGuidance, VideoAdaptiveProjectedGuidance]
387
+ ] = None,
388
+ feature_extractor: Optional[SiglipImageProcessor] = None,
389
+ ):
390
+ super().__init__()
391
+
392
+ self.guider = ClassifierFreeGuidance() if guider is None else guider
393
+
394
+ self.register_modules(
395
+ vae=vae,
396
+ text_encoder=text_encoder,
397
+ tokenizer=tokenizer,
398
+ transformer=transformer,
399
+ scheduler=scheduler,
400
+ feature_extractor=feature_extractor,
401
+ )
402
+
403
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
404
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
405
+
406
+ self.transformer_spatial_patch_size = (
407
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2
408
+ )
409
+ self.transformer_temporal_patch_size = (
410
+ self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
411
+ )
412
+
413
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
414
+ self.tokenizer_max_length = (
415
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512
416
+ )
417
+
418
+ def _get_default_embeds(
419
+ self,
420
+ text_encoder,
421
+ tokenizer: PreTrainedTokenizerBase,
422
+ prompt: Union[str, List[str]],
423
+ max_sequence_length: int = 512,
424
+ device: Optional[torch.device] = None,
425
+ dtype: Optional[torch.dtype] = None,
426
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
427
+ dtype = dtype or text_encoder.dtype
428
+
429
+ text_inputs = tokenizer(
430
+ prompt,
431
+ padding="max_length",
432
+ max_length=max_sequence_length,
433
+ truncation=True,
434
+ add_special_tokens=True,
435
+ return_attention_mask=True,
436
+ return_tensors="pt",
437
+ )
438
+ text_inputs = BatchEncoding(
439
+ {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()}
440
+ )
441
+
442
+ prompt_embeds = text_encoder(**text_inputs)[0]
443
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
444
+
445
+ return prompt_embeds, text_inputs.attention_mask
446
+
447
+ def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
448
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
449
+ denom = attention_mask.sum(dim=1, keepdim=True).clamp(min=1) # avoid div by zero
450
+ return last_hidden.sum(dim=1) / denom
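The masked mean in `_average_pool` can be checked by hand on a tiny input. The NumPy sketch below mirrors the same three steps (zero out padding, count real tokens, divide); the hidden states and mask are made-up values for illustration.

```python
import numpy as np

hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # [B=1, L=3, D=2]
mask = np.array([[1, 1, 0]])                                    # last token is padding

# Zero out padded positions so they do not contribute to the sum.
masked = np.where(mask[..., None].astype(bool), hidden, 0.0)
# Number of real tokens per sequence, clamped to at least 1 to avoid division by zero.
denom = np.clip(mask.sum(axis=1, keepdims=True), 1, None)
pooled = masked.sum(axis=1) / denom  # [[2.0, 3.0]]
```

The padded row of 100s has no effect on the result, which is the point of masking before averaging.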
451
+
452
+ def _get_prompt_embeds(
453
+ self,
454
+ text_encoder: T5Gemma2Model,
455
+ tokenizer: PreTrainedTokenizerBase,
456
+ prompt: Union[str, List[str]] | None = None,
457
+ num_videos_per_prompt: int = 1,
458
+ max_sequence_length: int = 512,
459
+ device: Optional[torch.device] = None,
460
+ dtype: Optional[torch.dtype] = None,
461
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
462
+ device = device or self._execution_device
463
+
464
+ prompt = [prompt] if isinstance(prompt, str) else prompt
465
+
466
+ prompt_embeds_kwargs = {
467
+ "text_encoder": text_encoder,
468
+ "tokenizer": tokenizer,
469
+ "prompt": prompt,
470
+ "max_sequence_length": max_sequence_length,
471
+ "device": device,
472
+ "dtype": dtype,
473
+ }
474
+ # T5Gemma2Model bundles encoder and decoder/LM head, while _get_default_embeds expects an encoder-only model
475
+ # (similar to T5EncoderModel/T5GemmaEncoderModel), so we pass the encoder submodule explicitly here.
476
+ if isinstance(text_encoder, T5Gemma2Model):
477
+ prompt_embeds_kwargs["text_encoder"] = text_encoder.encoder
478
+ prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs)
479
+
480
+ pooled_prompt_embeds = self._average_pool(prompt_embeds, prompt_attention_mask)
481
+
482
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
483
+
484
+ # Encodes prompts via _get_prompt_embeds and duplicates the results for each video per prompt
485
+ def encode_prompt(
486
+ self,
487
+ prompt: Union[str, List[str]],
488
+ num_videos_per_prompt: int = 1,
489
+ prompt_embeds: Optional[torch.Tensor] = None,
490
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
491
+ prompt_attention_mask: Optional[torch.Tensor] = None,
492
+ max_sequence_length: int = 512,
493
+ device: Optional[torch.device] = None,
494
+ dtype: Optional[torch.dtype] = None,
495
+ ) -> Tuple[
496
+ torch.Tensor,
497
+ torch.Tensor,
498
+ torch.Tensor,
499
+ ]:
500
+ device = device or self._execution_device
501
+
502
+ prompt = [prompt] if isinstance(prompt, str) else prompt
503
+ if prompt is not None:
504
+ batch_size = len(prompt)
505
+ else:
506
+ batch_size = prompt_embeds.shape[0]
507
+
508
+ prompt_embeds_kwargs = {
509
+ "device": device,
510
+ "dtype": dtype,
511
+ }
512
+
513
+ if prompt_embeds is None:
514
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self._get_prompt_embeds(
515
+ text_encoder=self.text_encoder,
516
+ tokenizer=self.tokenizer,
517
+ prompt=prompt,
518
+ max_sequence_length=max_sequence_length,
519
+ **prompt_embeds_kwargs,
520
+ )
521
+
522
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
523
+ seq_len = prompt_embeds.shape[1]
524
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
525
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
526
+
527
+ if pooled_prompt_embeds is not None:
528
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
529
+
530
+ # Cast the attention mask to bool and duplicate it for each video per prompt
531
+ prompt_attention_mask = prompt_attention_mask.bool()
532
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
533
+ prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
534
+
535
+ return (
536
+ prompt_embeds,
537
+ pooled_prompt_embeds,
538
+ prompt_attention_mask,
539
+ )
540
+
541
+ @property
542
+ def vision_encoder(self):
543
+ """Get the vision encoder from T5Gemma2.
544
+
545
+ T5Gemma2 has vision_tower.vision_model structure.
546
+ Will raise AttributeError if not available.
547
+ """
548
+ return self.text_encoder.encoder.vision_tower.vision_model
549
+
550
+ def encode_image(
551
+ self,
552
+ image: Image.Image,
553
+ batch_size: int = 1,
554
+ device: Optional[torch.device] = None,
555
+ dtype: Optional[torch.dtype] = None,
556
+ ) -> torch.Tensor:
557
+ """Encode image to embeddings using SigLIP vision encoder."""
558
+ device = device or self._execution_device
559
+ dtype = dtype or self.transformer.dtype
560
+
561
+ image_embeds = self._get_image_embeds(
562
+ image_encoder=self.vision_encoder,
563
+ feature_extractor=self.feature_extractor,
564
+ image=image,
565
+ device=device,
566
+ )
567
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
568
+ return image_embeds.to(device=device, dtype=dtype)
569
+
570
+ @staticmethod
571
+ def _get_image_embeds(
572
+ image_encoder,
573
+ feature_extractor: SiglipImageProcessor,
574
+ image,
575
+ device: torch.device,
576
+ ) -> torch.Tensor:
577
+ """Helper to encode single image with SigLIP.
578
+
579
+ Args:
580
+ image_encoder: The SigLIP vision encoder model.
581
+ feature_extractor: SiglipImageProcessor for preprocessing.
582
+ image: Can be either:
583
+ - PIL.Image.Image: Will be preprocessed by feature_extractor
584
+ - torch.Tensor: Assumed to be in [0, 1] range, will be normalized and passed to encoder
585
+ device: Device to place tensors on.
586
+
587
+ Returns:
588
+ Image embeddings from the vision encoder.
589
+ """
590
+ image_encoder_dtype = next(image_encoder.parameters()).dtype
591
+
592
+ if isinstance(image, torch.Tensor):
593
+ image = feature_extractor.preprocess(
594
+ images=image.float(),
595
+ do_resize=True,
596
+ do_rescale=False,
597
+ do_normalize=True,
598
+ do_convert_rgb=True,
599
+ return_tensors="pt",
600
+ )
601
+ else:
602
+ image = feature_extractor.preprocess(
603
+ images=image,
604
+ do_resize=True,
605
+ do_rescale=True,  # PIL pixels are in [0, 255]; rescale to [0, 1] before normalizing
606
+ do_normalize=True,
607
+ do_convert_rgb=True,
608
+ return_tensors="pt",
609
+ )
610
+
611
+ image = image.to(device, dtype=image_encoder_dtype)
612
+ return image_encoder(**image).last_hidden_state
613
+
614
+ @torch.compiler.disable
615
+ def _prepare_first_frame_conditioning(
616
+ self,
617
+ video: torch.Tensor,
618
+ latents: torch.Tensor,
619
+ use_conditioning: bool,
620
+ generator: Optional[torch.Generator] = None,
621
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
622
+ """Prepare first frame conditioning tensors.
623
+
624
+ This method implements batch-level conditioning where entire
625
+ batches are either I2V (all samples conditioned) or T2V (no conditioning). This
626
+ prevents mode confusion within batches.
627
+
628
+ For I2V mode:
629
+ 1. Extract and VAE-encode first frame from video
630
+ 2. Create latent_condition holding the first-frame latents at frame 0 and zeros at all later frames
631
+ 3. Create latent_mask with 1.0 at frame 0
632
+ 4. Get image_embeds from vision encoder
633
+
634
+ For T2V mode:
635
+ 1. Pad with zeros for latent_condition and latent_mask
636
+
637
+ Args:
638
+ video: Input video tensor [batch_size, frames, channels, height, width] in [-1, 1]
639
+ latents: Latents [batch_size, latent_channels, latent_num_frames, latent_height, latent_width]
640
+ use_conditioning: Whether to use first-frame conditioning (True for I2V, False for T2V)
641
+ generator: Optional random number generator for reproducibility
642
+
643
+ Returns:
644
+ Tuple of (latent_condition, latent_mask, image_embeds).
645
+ - latent_condition: [B, C, F, H, W] conditioning signal (zeros for T2V)
646
+ - latent_mask: [B, 1, F, H, W] binary mask (zeros for T2V)
647
+ - image_embeds: [B, N, D] image embeddings from vision encoder or None for T2V
648
+ """
649
+ batch_size, latent_channels, latent_num_frames, latent_height, latent_width = latents.shape
650
+ device = latents.device
651
+ dtype = latents.dtype
652
+
653
+ # First-frame conditioning only applies when there is more than one latent frame
654
+ use_conditioning = use_conditioning and (latent_num_frames > 1)
655
+
656
+ # Initialize conditioning tensors
657
+ latent_condition = torch.zeros(
658
+ batch_size, latent_channels, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype
659
+ )
660
+ latent_mask = torch.zeros(
661
+ batch_size, 1, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype
662
+ )
663
+ image_embeds = None
664
+
665
+ if use_conditioning:
666
+ with torch.no_grad():
667
+ # Encode first frame for latent_condition
668
+ first_frame_latents = self.vae.encode(
669
+ rearrange(video[:, 0:1], "b f c h w -> b c f h w")
670
+ ).latent_dist.sample(generator=generator)
671
+ first_frame_latents = self._normalize_latents(
672
+ latents=first_frame_latents,
673
+ latents_mean=self.vae.config.latents_mean,
674
+ latents_std=self.vae.config.latents_std,
675
+ )
676
+
677
+ # Repeat the first-frame latents across time, then zero every frame after frame 0
678
+ latent_condition = first_frame_latents.repeat(1, 1, latent_num_frames, 1, 1)
679
+ latent_condition[:, :, 1:, :, :] = 0
680
+
681
+ # latent_mask: 1.0 at frame 0, 0.0 elsewhere
682
+ latent_mask[:, :, 0] = 1.0
683
+
684
+ # image_embeds from vision encoder
685
+ first_frame_vision = video[:, 0] # [B, C, H, W]
686
+ first_frame_vision = ((first_frame_vision + 1) / 2).clamp(0, 1)
687
+
688
+ with torch.no_grad():
689
+ image_embeds = self._get_image_embeds(
690
+ image_encoder=self.vision_encoder,
691
+ feature_extractor=self.feature_extractor,
692
+ image=first_frame_vision,
693
+ device=device,
694
+ )
695
+
696
+ return latent_condition, latent_mask, image_embeds
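The I2V branch above produces two tensors whose layout is easy to verify on toy shapes. The NumPy sketch below uses an all-ones array as a stand-in for the normalized first-frame VAE latents (an assumption for illustration only) and builds the condition and mask the same way.

```python
import numpy as np

b, c, f, h, w = 1, 2, 3, 2, 2
first_frame = np.ones((b, c, 1, h, w), dtype=np.float32)  # stand-in for encoded first frame

# Repeat across time, then keep only frame 0.
latent_condition = np.repeat(first_frame, f, axis=2)
latent_condition[:, :, 1:] = 0.0

# Single-channel mask: 1.0 at frame 0, 0.0 elsewhere.
latent_mask = np.zeros((b, 1, f, h, w), dtype=np.float32)
latent_mask[:, :, 0] = 1.0
```

In T2V mode both arrays would simply stay all zeros.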
697
+
698
+ def check_inputs(
699
+ self,
700
+ prompt,
701
+ negative_prompt,
702
+ height,
703
+ width,
704
+ batch_size,
705
+ callback_on_step_end_tensor_inputs=None,
706
+ prompt_embeds=None,
707
+ negative_prompt_embeds=None,
708
+ prompt_attention_mask=None,
709
+ negative_prompt_attention_mask=None,
710
+ ):
711
+ # Resolution must be divisible by VAE scale factor * transformer patch size
712
+ # (e.g. 8 * 2 = 16 for default config) to avoid latent/patch dimension mismatch.
713
+ spatial_divisor = self.vae_scale_factor_spatial * self.transformer_spatial_patch_size
714
+ if height % spatial_divisor != 0 or width % spatial_divisor != 0:
715
+ raise ValueError(
716
+ f"`height` and `width` have to be divisible by {spatial_divisor} "
717
+ f"(vae_scale={self.vae_scale_factor_spatial} * patch_size={self.transformer_spatial_patch_size}) "
718
+ f"but are {height} and {width}."
719
+ )
720
+
721
+ if callback_on_step_end_tensor_inputs is not None and not all(
722
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
723
+ ):
724
+ raise ValueError(
725
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
726
+ )
727
+
728
+ if prompt is not None and prompt_embeds is not None:
729
+ raise ValueError(
730
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
731
+ " only forward one of the two."
732
+ )
733
+ elif prompt is None and prompt_embeds is None:
734
+ raise ValueError(
735
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
736
+ )
737
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
738
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
739
+
740
+ # Validate negative_prompt: must be None, str, or list with matching batch_size
741
+ if negative_prompt is not None:
742
+ if not isinstance(negative_prompt, (str, list)):
743
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
744
+ if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size:
745
+ raise ValueError(
746
+ f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})."
747
+ )
748
+
749
+ if prompt_embeds is not None and prompt_attention_mask is None:
750
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
751
+
752
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
753
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
754
+
755
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
756
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
757
+ raise ValueError(
758
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
759
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
760
+ f" {negative_prompt_embeds.shape}."
761
+ )
762
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
763
+ raise ValueError(
764
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
765
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
766
+ f" {negative_prompt_attention_mask.shape}."
767
+ )
768
+
769
+ def _prepare_negative_prompt(
770
+ self,
771
+ negative_prompt: Optional[Union[str, List[str]]],
772
+ batch_size: int,
773
+ ) -> List[str]:
774
+ """
775
+ Prepare negative_prompt to match batch_size.
776
+
777
+ Args:
778
+ negative_prompt: None, a single string, or a list of strings matching batch_size.
779
+ batch_size: The number of prompts in the batch.
780
+
781
+ Returns:
782
+ A list of negative prompts with length equal to batch_size.
783
+ """
784
+ if negative_prompt is None:
785
+ return [""] * batch_size
786
+ if isinstance(negative_prompt, str):
787
+ return [negative_prompt] * batch_size
788
+ return negative_prompt
789
+
790
+ @staticmethod
791
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
792
+ batch_size, num_channels, num_frames, height, width = latents.shape
793
+ post_patch_num_frames = num_frames // patch_size_t
794
+ post_patch_height = height // patch_size
795
+ post_patch_width = width // patch_size
796
+ latents = latents.reshape(
797
+ batch_size,
798
+ -1,
799
+ post_patch_num_frames,
800
+ patch_size_t,
801
+ post_patch_height,
802
+ patch_size,
803
+ post_patch_width,
804
+ patch_size,
805
+ )
806
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
807
+ return latents
808
+
809
+ @staticmethod
810
+ def _unpack_latents(
811
+ latents: torch.Tensor,
812
+ num_frames: int,
813
+ height: int,
814
+ width: int,
815
+ patch_size: int = 1,
816
+ patch_size_t: int = 1,
817
+ ) -> torch.Tensor:
818
+ batch_size = latents.size(0)
819
+ latents = latents.reshape(
820
+ batch_size,
821
+ num_frames,
822
+ height,
823
+ width,
824
+ -1,
825
+ patch_size_t,
826
+ patch_size,
827
+ patch_size,
828
+ )
829
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
830
+ return latents
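`_pack_latents` and `_unpack_latents` should be exact inverses of each other. The sketch below reproduces the same reshape and permute pattern in NumPy (the names `pack` and `unpack` are hypothetical) and checks the round trip on a small array. Note that `unpack` takes the post-patch frame, height, and width counts.

```python
import numpy as np


def pack(latents, p=1, pt=1):
    b, c, f, h, w = latents.shape
    x = latents.reshape(b, c, f // pt, pt, h // p, p, w // p, p)
    x = x.transpose(0, 2, 4, 6, 1, 3, 5, 7)  # [B, F', H', W', C, pt, p, p]
    return x.reshape(b, (f // pt) * (h // p) * (w // p), c * pt * p * p)


def unpack(tokens, f, h, w, p=1, pt=1):
    b = tokens.shape[0]
    x = tokens.reshape(b, f, h, w, -1, pt, p, p)
    x = x.transpose(0, 4, 1, 5, 2, 6, 3, 7)  # [B, C, F', pt, H', p, W', p]
    c = x.shape[1]
    return x.reshape(b, c, f * pt, h * p, w * p)


latents = np.arange(2 * 3 * 4 * 6 * 8, dtype=np.float32).reshape(2, 3, 4, 6, 8)
tokens = pack(latents, p=2, pt=1)                    # shape (2, 48, 12)
restored = unpack(tokens, f=4, h=3, w=4, p=2, pt=1)  # back to (2, 3, 4, 6, 8)
```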
831
+
832
+ @staticmethod
833
+ def _normalize_latents(
834
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
835
+ ) -> torch.Tensor:
836
+ # Normalize latents across the channel dimension [B, C, F, H, W]
837
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
838
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
839
+ latents = (latents - latents_mean) / latents_std
840
+ return latents
841
+
842
+ @staticmethod
843
+ def _denormalize_latents(
844
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
845
+ ) -> torch.Tensor:
846
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
847
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
848
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
849
+ latents = latents * latents_std + latents_mean
850
+ return latents
851
+
852
+ def prepare_latents(
853
+ self,
854
+ batch_size: int = 1,
855
+ num_channels_latents: int = 16,
856
+ height: int = 352,
857
+ width: int = 640,
858
+ num_frames: int = 65,
859
+ dtype: Optional[torch.dtype] = None,
860
+ device: Optional[torch.device] = None,
861
+ generator: Optional[torch.Generator] = None,
862
+ latents: Optional[torch.Tensor] = None,
863
+ ) -> torch.Tensor:
864
+ if latents is not None:
865
+ return latents.to(device=device, dtype=dtype)
866
+
867
+ shape = (
868
+ batch_size,
869
+ num_channels_latents,
870
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
871
+ height // self.vae_scale_factor_spatial,
872
+ width // self.vae_scale_factor_spatial,
873
+ )
874
+
875
+ if isinstance(generator, list) and len(generator) != batch_size:
876
+ raise ValueError(
877
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
878
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
879
+ )
880
+
881
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
882
+ return latents
883
+
884
+ @property
885
+ def num_timesteps(self):
886
+ return self._num_timesteps
887
+
888
+ @property
889
+ def current_timestep(self):
890
+ return self._current_timestep
891
+
892
+ @property
893
+ def attention_kwargs(self):
894
+ return self._attention_kwargs
895
+
896
+ @property
897
+ def interrupt(self):
898
+ return self._interrupt
899
+
900
+ @torch.no_grad()
901
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
902
+ def __call__(
903
+ self,
904
+ prompt: Union[str, List[str]] | None = None,
905
+ image=None,
906
+ negative_prompt: Optional[Union[str, List[str]]] = "text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
907
+ height: int = 736,
908
+ width: int = 1280,
909
+ num_frames: int = 121,
910
+ frame_rate: int = 24,
911
+ num_inference_steps: int = 50,
912
+ timesteps: List[int] | None = None,
913
+ use_linear_quadratic_schedule: bool = False,
914
+ linear_quadratic_emulating_steps: int = 250,
915
+ num_videos_per_prompt: Optional[int] = 1,
916
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
917
+ latents: Optional[torch.Tensor] = None,
918
+ prompt_embeds: Optional[torch.Tensor] = None,
919
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
920
+ prompt_attention_mask: Optional[torch.Tensor] = None,
921
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
922
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
923
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
924
+ output_type: Optional[str] = "pil",
925
+ return_dict: bool = True,
926
+ attention_kwargs: Optional[Dict[str, Any]] = None,
927
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
928
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
929
+ max_sequence_length: int = 512,
930
+ use_attention_mask: bool = True,
931
+ vae_batch_size: int | None = None,
932
+ ):
933
+ r"""
934
+ Function invoked when calling the pipeline for generation.
935
+
936
+ Args:
937
+ prompt (`str` or `List[str]`, *optional*):
938
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
938
+ instead.
940
+ negative_prompt (`str` or `List[str]`, *optional*):
941
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
942
+ `negative_prompt_embeds` instead. Ignored when not using guidance.
943
+ height (`int`, defaults to `736`):
944
+ The height in pixels of the generated video.
945
+ width (`int`, defaults to `1280`):
946
+ The width in pixels of the generated video.
947
+ num_frames (`int`, defaults to `121`):
948
+ The number of video frames to generate.
949
+ frame_rate (`int`, defaults to `24`):
950
+ Frame rate for the output video.
951
+ num_inference_steps (`int`, *optional*, defaults to 50):
952
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
953
+ expense of slower inference.
954
+ timesteps (`List[int]`, *optional*):
955
+ Custom timesteps to use for the denoising process.
956
+ use_linear_quadratic_schedule (`bool`, defaults to `False`):
957
+ Whether to use a linear-quadratic sigma schedule instead of the default linear schedule.
958
+ This schedule combines linear interpolation in the first half (slow denoising at high noise)
959
+ with quadratic interpolation in the second half (faster denoising toward clean image).
960
+ Requires `num_inference_steps` to be even.
961
+ linear_quadratic_emulating_steps (`int`, defaults to `250`):
962
+ Controls the slope of linear interpolation in the first half of the linear-quadratic schedule.
963
+ Higher values result in a gentler slope. Only used when `use_linear_quadratic_schedule=True`.
964
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
965
+ The number of videos to generate per prompt.
966
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
967
+ PyTorch Generator object(s) for deterministic generation.
968
+ latents (`torch.Tensor`, *optional*):
969
+ Pre-generated noisy latents.
970
+ prompt_embeds (`torch.Tensor`, *optional*):
971
+ Pre-generated text embeddings.
972
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
973
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
974
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
975
+ prompt_attention_mask (`torch.Tensor`, *optional*):
976
+ Pre-generated attention mask for text embeddings.
977
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
978
+ Pre-generated negative text embeddings.
979
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
980
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
981
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
982
+ input argument.
983
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
984
+ Pre-generated attention mask for negative text embeddings.
985
+ output_type (`str`, *optional*, defaults to `"pil"`):
986
+ The output format ("pil" or "np").
987
+ return_dict (`bool`, *optional*, defaults to `True`):
988
+ Whether to return a `MotifVideoPipelineOutput`.
989
+ attention_kwargs (`dict`, *optional*):
990
+ Arguments passed to the attention processor.
991
+ callback_on_step_end (`Callable`, *optional*):
992
+ Callback function called at the end of each step.
993
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
994
+ Tensors to include in the callback.
995
+ max_sequence_length (`int`, defaults to `512`):
996
+ Maximum sequence length for the tokenizer.
997
+
998
+ Examples:
999
+
1000
+ Returns:
1001
+ [`~pipelines.motif_video.MotifVideoPipelineOutput`] or `tuple`:
1002
+ If `return_dict` is `True`, returns [`~pipelines.motif_video.MotifVideoPipelineOutput`],
1003
+ otherwise returns a tuple where the first element is a list of generated video frames.
1004
+ """
1005
+
1006
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1007
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1008
+
1009
+ # 1. Define call parameters (batch_size needed for check_inputs)
1010
+ if prompt is not None and isinstance(prompt, str):
1011
+ batch_size = 1
1012
+ elif prompt is not None and isinstance(prompt, list):
1013
+ batch_size = len(prompt)
1014
+ else:
1015
+ batch_size = prompt_embeds.shape[0]
1016
+
1017
+ # 2. Check inputs. Raise error if not correct
1018
+ self.check_inputs(
1019
+ prompt=prompt,
1020
+ negative_prompt=negative_prompt,
1021
+ height=height,
1022
+ width=width,
1023
+ batch_size=batch_size,
1024
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1025
+ prompt_embeds=prompt_embeds,
1026
+ negative_prompt_embeds=negative_prompt_embeds,
1027
+ prompt_attention_mask=prompt_attention_mask,
1028
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
1029
+ )
1030
+
1031
+ self._attention_kwargs = attention_kwargs
1032
+ self._interrupt = False
1033
+ self._current_timestep = None
1034
+
1035
+ # Auto-upgrade AdaptiveProjectedGuidance to VideoAdaptiveProjectedGuidance
1036
+ # for video generation. Video-aware APG normalizes per-frame [C,H,W] instead
1037
+ # of collapsing the temporal axis, preserving motion quality.
1038
+ if type(self.guider) is AdaptiveProjectedGuidance:
1039
+ self.guider = VideoAdaptiveProjectedGuidance(
1040
+ guidance_scale=self.guider.guidance_scale,
1041
+ adaptive_projected_guidance_rescale=self.guider.adaptive_projected_guidance_rescale,
1042
+ adaptive_projected_guidance_momentum=self.guider.adaptive_projected_guidance_momentum,
1043
+ eta=self.guider.eta,
1044
+ use_original_formulation=self.guider.use_original_formulation,
1045
+ )
1046
+
1047
+ device = self._execution_device
1048
+
1049
+ # 3. Prepare text embeddings
1050
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
1051
+ prompt=prompt,
1052
+ num_videos_per_prompt=num_videos_per_prompt,
1053
+ prompt_embeds=prompt_embeds,
1054
+ pooled_prompt_embeds=pooled_prompt_embeds,
1055
+ prompt_attention_mask=prompt_attention_mask,
1056
+ max_sequence_length=max_sequence_length,
1057
+ device=device,
1058
+ )
1059
+
1060
+ if self.guider._enabled:
1061
+ negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
1062
+ negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
1063
+ prompt=negative_prompt,
1064
+ num_videos_per_prompt=num_videos_per_prompt,
1065
+ prompt_embeds=negative_prompt_embeds,
1066
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
1067
+ prompt_attention_mask=negative_prompt_attention_mask,
1068
+ max_sequence_length=max_sequence_length,
1069
+ device=device,
1070
+ )
1071
+
1072
+ num_channels_latents = self.vae.config.z_dim
1073
+ latents = self.prepare_latents(
1074
+ batch_size * num_videos_per_prompt,
1075
+ num_channels_latents,
1076
+ height,
1077
+ width,
1078
+ num_frames,
1079
+ self.transformer.dtype,
1080
+ device,
1081
+ generator,
1082
+ latents,
1083
+ )
1084
+
1085
+ # 4.5 Preprocess image for I2V conditioning
1086
+ if image is not None:
1087
+ from PIL import Image as PILImage
1088
+
1089
+ if isinstance(image, PILImage.Image):
1090
+ image = image.convert("RGB").resize((width, height), PILImage.LANCZOS)
1091
+ image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
1092
+ image = image * 2.0 - 1.0 # [0,1] -> [-1,1]
1093
+ image = image.unsqueeze(0) # [1, C, H, W]
1094
+ # Handle [C, H, W] -> [1, C, H, W]
1095
+ if image.dim() == 3:
1096
+ image = image.unsqueeze(0)
1097
+ # [B, C, H, W] -> [B, 1, C, H, W] for video format
1098
+ if image.dim() == 4:
1099
+ image = image.unsqueeze(1)
1100
+ image = image.to(device=device, dtype=self.vae.dtype)
1101
+
1102
+ # 5. Prepare timesteps (including mu calculation)
1103
+
1104
+ # Recalculate latent dims based on VAE for mu calculation
1105
+ latent_height = height // self.vae_scale_factor_spatial
1106
+ latent_width = width // self.vae_scale_factor_spatial
1107
+ latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
1108
+
1109
+ # Calculate sequence length based on *packed* dimensions if transformer uses packing
1110
+ # Packed dims: H/patch, W/patch, F/patch_t
1111
+ packed_latent_height = latent_height // self.transformer_spatial_patch_size
1112
+ packed_latent_width = latent_width // self.transformer_spatial_patch_size
1113
+ packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size
1114
+ video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width
1115
+
1116
+ # Compute sigmas: use linear-quadratic schedule if enabled, otherwise default linear
1117
+ _is_flow_multistep = isinstance(
1118
+ self.scheduler,
1119
+ (DPMSolverMultistepScheduler, UniPCMultistepScheduler, FlowUniPCMultistepScheduler),
1120
+ )
1121
+
1122
+ # Compute mu once, shared by both branches (required by FlowUniPCMultistepScheduler)
1123
+ mu = calculate_shift(
1124
+ video_sequence_length,
1125
+ self.scheduler.config.get("base_image_seq_len", 256),
1126
+ self.scheduler.config.get("max_image_seq_len", 4096),
1127
+ self.scheduler.config.get("base_shift", 0.5),
1128
+ self.scheduler.config.get("max_shift", 1.15),
1129
+ )
1130
+
1131
+ if _is_flow_multistep:
1132
+ # DPMSolver/UniPC manage their own sigma schedule via use_flow_sigmas + flow_shift.
1133
+ # Pass mu for dynamic shifting support (required by FlowUniPCMultistepScheduler).
1134
+ timesteps, num_inference_steps = retrieve_timesteps(
1135
+ self.scheduler,
1136
+ num_inference_steps,
1137
+ device,
1138
+ timesteps,
1139
+ mu=mu,
1140
+ )
1141
+ else:
1142
+ if use_linear_quadratic_schedule:
1143
+ # Linear-quadratic schedule computes sigmas internally in retrieve_timesteps
1144
+ sigmas = None
1145
+ else:
1146
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1147
+
1148
+ timesteps, num_inference_steps = retrieve_timesteps(
1149
+ self.scheduler,
1150
+ num_inference_steps,
1151
+ device,
1152
+ timesteps,
1153
+ sigmas=sigmas,
1154
+ use_linear_quadratic_schedule=use_linear_quadratic_schedule,
1155
+ linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
1156
+ mu=mu,
1157
+ )
1158
+
1159
+ # Get conditioning tensors
1160
+ latent_condition, latent_mask, image_embeds = self._prepare_first_frame_conditioning(
1161
+ image,
1162
+ latents,
1163
+ use_conditioning=image is not None,
1164
+ generator=generator,
1165
+ )
1166
+
1167
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1168
+ self._num_timesteps = len(timesteps)
1169
+
1170
+ # 6. Denoising loop
1171
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1172
+ for i, t in enumerate(timesteps):
1173
+ if self.interrupt:
1174
+ continue
1175
+
1176
+ self._current_timestep = t
1177
+
1178
+ # Concatenate current latents with conditioning for this timestep
1179
+ # [latents | latent_condition | latent_mask]
1180
+ hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1)
1181
+
1182
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1183
+ timestep = t.expand(latents.shape[0])
1184
+
1185
+ # Step 1: Collect model inputs needed for the guidance method
1186
+ # conditional inputs should always be first element in the tuple
1187
+ guider_inputs = {
1188
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
1189
+ }
1190
+ if use_attention_mask:
1191
+ guider_inputs["encoder_attention_mask"] = (prompt_attention_mask, negative_prompt_attention_mask)
1192
+ if self.transformer.config.pooled_projection_dim is not None:
1193
+ guider_inputs["pooled_projections"] = (pooled_prompt_embeds, negative_pooled_prompt_embeds)
1194
+ if image_embeds is not None:
1195
+ guider_inputs["image_embeds"] = (image_embeds, image_embeds)
1196
+
1197
+ # Step 2: Update guider's internal state for this denoising step
1198
+ self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
1199
+ # Sigma injection for guiders that support sigma-based gating
1200
+ # (Kynkäänniemi 2024). Must precede `prepare_inputs` because
1201
+ # `num_conditions` → `_is_cfg_enabled()` reads `_current_sigma`.
1202
+ # Duck-typed so diffusers-native guiders are unaffected; guard
1203
+ # on scheduler too since some schedulers don't expose `sigmas`.
1204
+ if hasattr(self.guider, "_current_sigma") and hasattr(self.scheduler, "sigmas"):
1205
+ self.guider._current_sigma = float(self.scheduler.sigmas[i])
1206
+
1207
+ # Step 3: Prepare batched model inputs based on the guidance method
1208
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
1209
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
1210
+ # you will get a guider_state with two batches:
1211
+ # guider_state = [
1212
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
1213
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
1214
+ # ]
1215
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
1216
+ guider_state = self.guider.prepare_inputs(guider_inputs)
1217
+
1218
+ # Step 4: Run the denoiser for each batch
1219
+ # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
1220
+ # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
1221
+ for guider_state_batch in guider_state:
1222
+ self.guider.prepare_models(self.transformer)
1223
+
1224
+ # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
1225
+ cond_kwargs = {
1226
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
1227
+ }
1228
+
1229
+ tread_disabled = getattr(self.guider, "_current_tread_disabled", False)
1230
+
1231
+ # Override TREAD selection ratio per batch if the guider provides one
1232
+ selection_ratio = getattr(self.guider, "_current_selection_ratio", None)
1233
+ tread_mixin = getattr(self.transformer, "_inference_tread_mixin", None)
1234
+ if (
1235
+ selection_ratio is not None
1236
+ and tread_mixin is not None
1237
+ and tread_mixin._tread_route is not None
1238
+ ):
1239
+ tread_mixin._tread_route["sel"] = selection_ratio
1240
+
1241
+ # e.g. "pred_cond"/"pred_uncond"
1242
+ context_name = getattr(guider_state_batch, self.guider._identifier_key)
1243
+ with self.transformer.cache_context(context_name):
1244
+ # Run denoiser and store noise prediction in this batch
1245
+
1246
+ noise_pred = self.transformer(
1247
+ hidden_states=hidden_states,
1248
+ timestep=timestep,
1249
+ attention_kwargs=self.attention_kwargs,
1250
+ return_dict=False,
1251
+ tread_disabled=tread_disabled,
1252
+ **cond_kwargs,
1253
+ )[0].clone()
1254
+
1255
+ guider_state_batch.noise_pred = noise_pred
1256
+ # Cleanup model (e.g., remove hooks)
1257
+ self.guider.cleanup_models(self.transformer)
1258
+
1259
+ # Step 5: Combine predictions using the guidance method
1260
+ # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
1261
+ # Continuing the CFG example, the guider receives:
1262
+ # guider_state = [
1263
+ # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
1264
+ # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
1265
+ # ]
1266
+ # And extracts predictions using the __guidance_identifier__:
1267
+ # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
1268
+ # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
1269
+ # Then applies CFG formula:
1270
+ # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
1271
+ # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
1272
+ noise_pred = self.guider(guider_state)[0]
1273
+
1274
+ # compute the previous noisy sample x_t -> x_t-1
1275
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1276
+
1277
+ if callback_on_step_end is not None:
1278
+ callback_kwargs = {}
1279
+ for k in callback_on_step_end_tensor_inputs:
1280
+ callback_kwargs[k] = locals()[k]
1281
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1282
+
1283
+ latents = callback_outputs.pop("latents", latents)
1284
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1285
+ # Handle negative embeds if needed by callback
1286
+ if "negative_prompt_embeds" in callback_outputs:
1287
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds")
1288
+
1289
+ # update the progress bar on the last step and after each completed scheduler-order step
1290
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1291
+ progress_bar.update()
1292
+
1293
+ if XLA_AVAILABLE:
1294
+ xm.mark_step()
1295
+
1296
+ self._current_timestep = None
1297
+
1298
+ if output_type == "latent":
1299
+ video = latents
1300
+ else:
1301
+ latents = latents.to(self.vae.dtype)
1302
+ latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std)
1303
+ if vae_batch_size is not None and latents.shape[0] > vae_batch_size:
1304
+ video_chunks = []
1305
+ for i in range(0, latents.shape[0], vae_batch_size):
1306
+ chunk = latents[i : i + vae_batch_size]
1307
+ video_chunks.append(self.vae.decode(chunk, return_dict=False)[0])
1308
+ video = torch.cat(video_chunks, dim=0)
1309
+ del video_chunks
1310
+ else:
1311
+ video = self.vae.decode(latents, return_dict=False)[0]
1312
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
1313
+
1314
+ # Offload all models
1315
+ self.maybe_free_model_hooks()
1316
+
1317
+ if not return_dict:
1318
+ return (video,)
1319
+
1320
+ # Return updated output type
1321
+ return MotifVideoPipelineOutput(frames=video)
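The shape arithmetic in `prepare_latents` and the `video_sequence_length` computation in `__call__` can be traced with plain integers. The following is a minimal sketch, not the pipeline's own code. It assumes a VAE with 8x spatial and 4x temporal compression (neither factor is stated in this diff, so both are assumptions) and takes `patch_size=2` and `patch_size_t=1` from `transformer/config.json`. The function names `latent_shape` and `video_sequence_length` are hypothetical helpers introduced only for illustration.

```python
def latent_shape(batch_size, channels, num_frames, height, width,
                 vae_spatial=8, vae_temporal=4):
    """Mirror the tuple built in prepare_latents.

    vae_spatial and vae_temporal are ASSUMED compression factors;
    the real values come from the VAE attached to the pipeline.
    """
    return (
        batch_size,
        channels,
        (num_frames - 1) // vae_temporal + 1,  # first frame kept, rest compressed
        height // vae_spatial,
        width // vae_spatial,
    )


def video_sequence_length(shape, patch_size=2, patch_size_t=1):
    """Number of transformer tokens after patchifying a latent of `shape`."""
    _, _, frames, height, width = shape
    return (frames // patch_size_t) * (height // patch_size) * (width // patch_size)


# Defaults of __call__: 121 frames at 736x1280, 16 latent channels.
shape = latent_shape(1, 16, 121, 736, 1280)
print(shape)                         # (1, 16, 31, 92, 160)
print(video_sequence_length(shape))  # 114080
```

Under these assumed factors the default call produces a latent of 31 frames at 92x160, which patchifies to 114,080 tokens; that token count is the value passed to `calculate_shift`.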
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "invert_sigmas": false,
7
+ "max_image_seq_len": 4096,
8
+ "max_shift": 1.15,
9
+ "num_train_timesteps": 1000,
10
+ "shift": 15.0,
11
+ "shift_terminal": null,
12
+ "stochastic_sampling": false,
13
+ "time_shift_type": "exponential",
14
+ "use_beta_sigmas": false,
15
+ "use_dynamic_shifting": false,
16
+ "use_exponential_sigmas": false,
17
+ "use_karras_sigmas": false
18
+ }
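The pipeline reads `base_image_seq_len`, `max_image_seq_len`, `base_shift` and `max_shift` from this config and passes them to `calculate_shift` to obtain `mu`. The body of `calculate_shift` is not shown in this section, so the sketch below assumes it is the linear interpolation used by other flow-matching pipelines in diffusers; treat it as an illustration rather than the repository's implementation. Note also that this config sets `use_dynamic_shifting` to `false`, and whether `mu` then has any effect on the sigma schedule is not established here.

```python
def calculate_shift(seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    """Linearly interpolate the shift `mu` from the token count.

    ASSUMPTION: a straight line through (base_seq_len, base_shift)
    and (max_seq_len, max_shift); defaults copied from scheduler_config.json.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return seq_len * slope + intercept


print(calculate_shift(256))   # close to base_shift
print(calculate_shift(4096))  # close to max_shift
```

The line is not clamped, so a sequence length above `max_image_seq_len` extrapolates to a `mu` larger than `max_shift`.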
text_encoder/config.json ADDED
@@ -0,0 +1,252 @@
1
+ {
2
+ "architectures": [
3
+ "T5Gemma2Model"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 2,
7
+ "classifier_dropout_rate": 0.0,
8
+ "decoder": {
9
+ "_sliding_window_pattern": 6,
10
+ "attention_bias": false,
11
+ "attention_dropout": 0.0,
12
+ "attn_logit_softcapping": null,
13
+ "dropout_rate": 0.0,
14
+ "dtype": "bfloat16",
15
+ "final_logit_softcapping": null,
16
+ "head_dim": 256,
17
+ "hidden_activation": "gelu_pytorch_tanh",
18
+ "hidden_size": 2560,
19
+ "initializer_range": 0.02,
20
+ "intermediate_size": 10240,
21
+ "layer_types": [
22
+ "sliding_attention",
23
+ "sliding_attention",
24
+ "sliding_attention",
25
+ "sliding_attention",
26
+ "sliding_attention",
27
+ "full_attention",
28
+ "sliding_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "sliding_attention",
33
+ "full_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "sliding_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "full_attention",
40
+ "sliding_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "sliding_attention",
45
+ "full_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "sliding_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "full_attention",
52
+ "sliding_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention"
56
+ ],
57
+ "max_position_embeddings": 131072,
58
+ "model_type": "t5gemma2_decoder",
59
+ "num_attention_heads": 8,
60
+ "num_hidden_layers": 34,
61
+ "num_key_value_heads": 4,
62
+ "query_pre_attn_scalar": 256,
63
+ "rms_norm_eps": 1e-06,
64
+ "rope_parameters": {
65
+ "full_attention": {
66
+ "factor": 8.0,
67
+ "rope_theta": 1000000,
68
+ "rope_type": "linear"
69
+ },
70
+ "sliding_attention": {
71
+ "rope_theta": 10000,
72
+ "rope_type": "default"
73
+ }
74
+ },
75
+ "sliding_window": 1024,
76
+ "use_bidirectional_attention": false,
77
+ "use_cache": true,
78
+ "vocab_size": 262144
79
+ },
80
+ "dropout_rate": 0.0,
81
+ "dtype": "bfloat16",
82
+ "encoder": {
83
+ "attention_dropout": 0.0,
84
+ "boi_token_index": 255999,
85
+ "dropout_rate": 0.0,
86
+ "dtype": "bfloat16",
87
+ "eoi_token_index": 256000,
88
+ "image_token_index": 256001,
89
+ "initializer_range": 0.02,
90
+ "mm_tokens_per_image": 256,
91
+ "model_type": "t5gemma2_encoder",
92
+ "text_config": {
93
+ "_name_or_path": "",
94
+ "_sliding_window_pattern": 6,
95
+ "add_cross_attention": false,
96
+ "architectures": null,
97
+ "attention_bias": false,
98
+ "attention_dropout": 0.0,
99
+ "attn_logit_softcapping": null,
100
+ "bos_token_id": 2,
101
+ "chunk_size_feed_forward": 0,
102
+ "cross_attention_hidden_size": null,
103
+ "decoder_start_token_id": null,
104
+ "dropout_rate": 0.0,
105
+ "dtype": "bfloat16",
106
+ "eos_token_id": 1,
107
+ "final_logit_softcapping": null,
108
+ "finetuning_task": null,
109
+ "head_dim": 256,
110
+ "hidden_activation": "gelu_pytorch_tanh",
111
+ "hidden_size": 2560,
112
+ "id2label": {
113
+ "0": "LABEL_0",
114
+ "1": "LABEL_1"
115
+ },
116
+ "initializer_range": 0.02,
117
+ "intermediate_size": 10240,
118
+ "is_decoder": false,
119
+ "is_encoder_decoder": false,
120
+ "label2id": {
121
+ "LABEL_0": 0,
122
+ "LABEL_1": 1
123
+ },
124
+ "layer_types": [
125
+ "sliding_attention",
126
+ "sliding_attention",
127
+ "sliding_attention",
128
+ "sliding_attention",
129
+ "sliding_attention",
130
+ "full_attention",
131
+ "sliding_attention",
132
+ "sliding_attention",
133
+ "sliding_attention",
134
+ "sliding_attention",
135
+ "sliding_attention",
136
+ "full_attention",
137
+ "sliding_attention",
138
+ "sliding_attention",
139
+ "sliding_attention",
140
+ "sliding_attention",
141
+ "sliding_attention",
142
+ "full_attention",
143
+ "sliding_attention",
144
+ "sliding_attention",
145
+ "sliding_attention",
146
+ "sliding_attention",
147
+ "sliding_attention",
148
+ "full_attention",
149
+ "sliding_attention",
150
+ "sliding_attention",
151
+ "sliding_attention",
152
+ "sliding_attention",
153
+ "sliding_attention",
154
+ "full_attention",
155
+ "sliding_attention",
156
+ "sliding_attention",
157
+ "sliding_attention",
158
+ "sliding_attention"
159
+ ],
160
+ "max_position_embeddings": 131072,
161
+ "model_type": "t5gemma2_text",
162
+ "num_attention_heads": 8,
163
+ "num_hidden_layers": 34,
164
+ "num_key_value_heads": 4,
165
+ "output_attentions": false,
166
+ "output_hidden_states": false,
167
+ "pad_token_id": 0,
168
+ "prefix": null,
169
+ "problem_type": null,
170
+ "query_pre_attn_scalar": 256,
171
+ "return_dict": true,
172
+ "rms_norm_eps": 1e-06,
173
+ "rope_parameters": {
174
+ "full_attention": {
175
+ "factor": 8.0,
176
+ "rope_theta": 1000000,
177
+ "rope_type": "linear"
178
+ },
179
+ "sliding_attention": {
180
+ "rope_theta": 10000,
181
+ "rope_type": "default"
182
+ }
183
+ },
184
+ "sep_token_id": null,
185
+ "sliding_window": 1024,
186
+ "task_specific_params": null,
187
+ "tie_encoder_decoder": false,
188
+ "tie_word_embeddings": true,
189
+ "tokenizer_class": null,
190
+ "use_bidirectional_attention": false,
191
+ "use_cache": true,
192
+ "vocab_size": 262144
193
+ },
194
+ "vision_config": {
195
+ "_name_or_path": "",
196
+ "add_cross_attention": false,
197
+ "architectures": null,
198
+ "attention_dropout": 0.0,
199
+ "bos_token_id": null,
200
+ "chunk_size_feed_forward": 0,
201
+ "cross_attention_hidden_size": null,
202
+ "decoder_start_token_id": null,
203
+ "dropout_rate": 0.0,
204
+ "dtype": "bfloat16",
205
+ "eos_token_id": null,
206
+ "finetuning_task": null,
207
+ "hidden_act": "gelu_pytorch_tanh",
208
+ "hidden_size": 1152,
209
+ "id2label": {
210
+ "0": "LABEL_0",
211
+ "1": "LABEL_1"
212
+ },
213
+ "image_size": 896,
214
+ "intermediate_size": 4304,
215
+ "is_decoder": false,
216
+ "is_encoder_decoder": false,
217
+ "label2id": {
218
+ "LABEL_0": 0,
219
+ "LABEL_1": 1
220
+ },
221
+ "layer_norm_eps": 1e-06,
222
+ "model_type": "siglip_vision_model",
223
+ "num_attention_heads": 16,
224
+ "num_channels": 3,
225
+ "num_hidden_layers": 27,
226
+ "output_attentions": false,
227
+ "output_hidden_states": false,
228
+ "pad_token_id": null,
229
+ "patch_size": 14,
230
+ "prefix": null,
231
+ "problem_type": null,
232
+ "return_dict": true,
233
+ "sep_token_id": null,
234
+ "task_specific_params": null,
235
+ "tie_encoder_decoder": false,
236
+ "tie_word_embeddings": true,
237
+ "tokenizer_class": null,
238
+ "vision_use_head": false,
239
+ "vocab_size": 262144
240
+ },
241
+ "vocab_size": 262144
242
+ },
243
+ "eoi_token_index": 256000,
244
+ "eos_token_id": 1,
245
+ "image_token_index": 256001,
246
+ "initializer_range": 0.02,
247
+ "is_encoder_decoder": true,
248
+ "model_type": "t5gemma2",
249
+ "pad_token_id": 0,
250
+ "transformers_version": "5.0.0rc1",
251
+ "vocab_size": 262144
252
+ }
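The `layer_types` lists above follow a regular pattern: with `_sliding_window_pattern` set to 6, every sixth layer uses full attention and the rest use sliding-window attention. A small sketch that reproduces the listed values, assuming this is how the list was derived (the helper name `layer_types` is hypothetical, and the derivation rule is inferred from the config values rather than confirmed by this diff):

```python
def layer_types(num_hidden_layers=34, sliding_window_pattern=6):
    """Rebuild the per-layer attention types listed in the config.

    ASSUMPTION: layer i (0-based) is full attention when (i + 1) is a
    multiple of sliding_window_pattern, sliding attention otherwise.
    """
    return [
        "full_attention" if (i + 1) % sliding_window_pattern == 0 else "sliding_attention"
        for i in range(num_hidden_layers)
    ]


types = layer_types()
full_layers = [i for i, kind in enumerate(types) if kind == "full_attention"]
print(len(types))   # 34
print(full_layers)  # [5, 11, 17, 23, 29]
```

With 34 layers this gives five full-attention layers and 29 sliding-window layers, matching both the decoder and the encoder `text_config` lists.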
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c7dd568c34c56a521475124f226983dc191e57aa9b1cac9a22a87dcc753cb57
3
+ size 16360212008
tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3220c5bec16e78ddf8e59c08fecdede7e8d31820cb5b3e69f17fed6a29a0b30c
3
+ size 33378248
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "additional_special_tokens": null,
3
+ "backend": "tokenizers",
4
+ "boi_token": "<start_of_image>",
5
+ "bos_token": "<bos>",
6
+ "clean_up_tokenization_spaces": false,
7
+ "eoi_token": "<end_of_image>",
8
+ "eos_token": "<eos>",
9
+ "image_token": "<image_soft_token>",
10
+ "is_local": false,
11
+ "mask_token": "<mask>",
12
+ "model_max_length": 1000000000000000019884624838656,
13
+ "model_specific_special_tokens": {
14
+ "boi_token": "<start_of_image>",
15
+ "eoi_token": "<end_of_image>",
16
+ "image_token": "<image_soft_token>"
17
+ },
18
+ "pad_token": "<pad>",
19
+ "padding_side": "right",
20
+ "processor_class": "Gemma3Processor",
21
+ "sp_model_kwargs": null,
22
+ "spaces_between_special_tokens": false,
23
+ "tokenizer_class": "GemmaTokenizer",
24
+ "unk_token": "<unk>",
25
+ "use_default_system_prompt": false
26
+ }
transformer/config.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "_class_name": "MotifVideoTransformer3DModel",
3
+ "_diffusers_version": "0.36.0",
4
+ "_library": "diffusers",
5
+ "attention_head_dim": 128,
6
+ "base_latent_size": null,
7
+ "image_condition_type": null,
8
+ "image_embed_dim": 1152,
9
+ "in_channels": 33,
10
+ "mlp_ratio": 4.0,
11
+ "norm_type": "layer_norm",
12
+ "num_attention_heads": 12,
13
+ "num_decoder_layers": 8,
14
+ "num_layers": 12,
15
+ "num_single_layers": 24,
16
+ "out_channels": 16,
17
+ "patch_size": 2,
18
+ "patch_size_t": 1,
19
+ "pooled_projection_dim": null,
20
+ "qk_norm": "rms_norm",
21
+ "rope_axes_dim": [
22
+ 16,
23
+ 56,
24
+ 56
25
+ ],
26
+ "rope_theta": 10000.0,
27
+ "text_embed_dim": 2560,
28
+ "enable_text_cross_attention_dual": false,
29
+ "enable_text_cross_attention_single": true
30
+ }
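A few of these values can be cross-checked against each other and against the pipeline code. The sketch below is a set of sanity checks, not part of the model. It assumes the usual convention that the hidden width is `num_attention_heads * attention_head_dim`, that the RoPE axis dimensions partition one attention head, and that `in_channels` is the channel-wise concatenation `[latents | latent_condition | latent_mask]` seen in the denoising loop; the 16 + 16 + 1 split is an inference from `out_channels` and is not stated in the diff.

```python
# Values copied from transformer/config.json.
config = {
    "attention_head_dim": 128,
    "num_attention_heads": 12,
    "in_channels": 33,
    "out_channels": 16,
    "rope_axes_dim": [16, 56, 56],
}

# ASSUMPTION: hidden width = heads * head_dim.
inner_dim = config["num_attention_heads"] * config["attention_head_dim"]

# ASSUMPTION: the RoPE axes (time, height, width) split one head exactly.
rope_total = sum(config["rope_axes_dim"])

# ASSUMPTION: latents and latent_condition both have out_channels channels,
# and whatever remains of in_channels is the conditioning mask.
latent_channels = config["out_channels"]
mask_channels = config["in_channels"] - 2 * latent_channels

print(inner_dim)      # 1536
print(rope_total)     # 128
print(mask_channels)  # 1
```

If these assumptions hold, the transformer runs at width 1536, the rotary embedding spends 16 dimensions on time and 56 each on height and width, and the conditioning mask contributes a single input channel.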
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a8b17a188d358d9d0e9b097f5fc58c094f37a82a00cbb8895e54c2bdd73f6ff
3
+ size 7849331048
transformer/transformer_motif_video.py ADDED
@@ -0,0 +1,1350 @@
1
+ # Copyright 2026 Motif Technologies. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from functools import lru_cache
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
25
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
26
+ from diffusers.models.attention import FeedForward
27
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
28
+ from diffusers.models.cache_utils import CacheMixin
29
+ from diffusers.models.embeddings import (
30
+ PixArtAlphaTextProjection,
31
+ TimestepEmbedding,
32
+ Timesteps,
33
+ )
34
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
35
+ from diffusers.models.modeling_utils import ModelMixin
36
+ from diffusers.models.normalization import (
37
+ AdaLayerNormContinuous,
38
+ AdaLayerNormZero,
39
+ AdaLayerNormZeroSingle,
40
+ )
41
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
42
+
43
+ # Stub functions for TREAD (Token Routing for Efficient Architecture-agnostic Diffusion Training).
44
+ # These stubs ensure TREAD code paths are never activated during inference
45
+ # without requiring the motif_core package.
46
+ def is_tread_start(block_idx, start, end): return False
47
+ def is_tread_end(block_idx, start, end): return False
48
+
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+ NUM_TRAIN_TIMESTEPS = 1000
53
+
54
+
55
+ def apply_rotary_emb(
56
+ x: torch.Tensor,
57
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor],
58
+ use_real: bool = True,
59
+ use_real_unbind_dim: int = -1,
60
+ ) -> torch.Tensor:
61
+ """
62
+ Apply rotary positional embeddings (RoPE) to input tensors.
63
+
64
+ This implementation supports both standard 2D RoPE tensors [L, Dh] and batched 4D RoPE
65
+ tensors [B, 1, L, Dh] for compatibility with TREAD's token-dropping mechanism where
66
+ different batches may have different token subsets.
67
+
68
+ Args:
69
+ x: Input tensor of shape [B, H, L, Dh].
70
+ freqs_cis: Tuple of (cos, sin) tensors. Supports shapes [L, Dh] or [B, 1, L, Dh].
71
+ use_real: Whether to use real-valued RoPE implementation.
72
+ use_real_unbind_dim: Dimension to unbind when using real-valued RoPE (-1 or -2).
73
+
74
+ Returns:
75
+ Tensor with rotary embeddings applied, same shape as input x.
76
+ """
77
+ if use_real:
78
+ cos, sin = freqs_cis
79
+ if cos.dim() == 2: # [L, Dh] → [1, 1, L, Dh]
80
+ cos = cos.unsqueeze(0).unsqueeze(0)
81
+ sin = sin.unsqueeze(0).unsqueeze(0)
82
+ if cos.dim() != 4 or sin.dim() != 4:
83
+ raise RuntimeError(f"RoPE must be 2D or 4D, got cos={cos.dim()}D, sin={sin.dim()}D")
84
+
85
+ cos, sin = cos.to(x.device), sin.to(x.device)
86
+
87
+ if cos.size(-2) != x.size(-2) or cos.size(-1) != x.size(-1):
88
+ raise RuntimeError(
89
+ f"RoPE shape mismatch: rope[-2:]=({cos.size(-2)},{cos.size(-1)}) vs x[-2:]=({x.size(-2)},{x.size(-1)})"
90
+ )
91
+
92
+ if use_real_unbind_dim == -1:
93
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
94
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
95
+ elif use_real_unbind_dim == -2:
96
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
97
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
98
+ else:
99
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
100
+
101
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
102
+ return out
103
+ else:
104
+ x_rot = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
105
+ freqs = freqs_cis.unsqueeze(2)
106
+ x_out = torch.view_as_real(x_rot * freqs).flatten(3)
107
+ return x_out.type_as(x)
108
+
109
+
110
+ class MotifVideoAttnProcessor2_0:
111
+ def __init__(self):
112
+ if not hasattr(F, "scaled_dot_product_attention"):
113
+ raise ImportError(
114
+ "MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
115
+ )
116
+
117
+ def __call__(
118
+ self,
119
+ attn: Attention,
120
+ hidden_states: torch.Tensor,
121
+ encoder_hidden_states: Optional[torch.Tensor] = None,
122
+ attention_mask: Optional[torch.Tensor] = None,
123
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
124
+ query_input: Optional[torch.Tensor] = None,
125
+ key_input: Optional[torch.Tensor] = None,
126
+ value_input: Optional[torch.Tensor] = None,
127
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
128
+ # Cross-attention mode: query already projected externally (cross_attn_query_proj + norm),
129
+ # skip to_q and only apply reshape + norm_q + RoPE. K/V use to_k/to_v as normal.
130
+ if query_input is not None:
131
+ query = query_input.unflatten(2, (attn.heads, -1)).transpose(1, 2)
132
+ key = attn.to_k(key_input)
133
+ value = attn.to_v(value_input)
134
+
135
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
136
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
137
+
138
+ if attn.norm_q is not None:
139
+ query = attn.norm_q(query)
140
+ if attn.norm_k is not None:
141
+ key = attn.norm_k(key)
142
+
143
+ if image_rotary_emb is not None:
144
+ query = apply_rotary_emb(query, image_rotary_emb)
145
+
146
+ hidden_states = F.scaled_dot_product_attention(
147
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
148
+ )
149
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
150
+ hidden_states = hidden_states.to(query.dtype)
151
+ return hidden_states, None
152
+
153
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
154
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
155
+
156
+ # 1. QKV projections
157
+ query = attn.to_q(hidden_states)
158
+ key = attn.to_k(hidden_states)
159
+ value = attn.to_v(hidden_states)
160
+
161
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
162
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
163
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
164
+
165
+ # 2. QK normalization
166
+ if attn.norm_q is not None:
167
+ query = attn.norm_q(query)
168
+ if attn.norm_k is not None:
169
+ key = attn.norm_k(key)
170
+
171
+ # 3. Rotary positional embeddings applied to latent stream
172
+ if image_rotary_emb is not None:
173
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
174
+ query = torch.cat(
175
+ [
176
+ apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
177
+ query[:, :, -encoder_hidden_states.shape[1] :],
178
+ ],
179
+ dim=2,
180
+ )
181
+ key = torch.cat(
182
+ [
183
+ apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
184
+ key[:, :, -encoder_hidden_states.shape[1] :],
185
+ ],
186
+ dim=2,
187
+ )
188
+ else:
189
+ query = apply_rotary_emb(query, image_rotary_emb)
190
+ key = apply_rotary_emb(key, image_rotary_emb)
191
+
192
+ # 4. Encoder condition QKV projection and normalization
193
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
194
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
195
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
196
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
197
+
198
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
199
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
200
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
201
+
202
+ if attn.norm_added_q is not None:
203
+ encoder_query = attn.norm_added_q(encoder_query)
204
+ if attn.norm_added_k is not None:
205
+ encoder_key = attn.norm_added_k(encoder_key)
206
+
207
+ query = torch.cat([query, encoder_query], dim=2)
208
+ key = torch.cat([key, encoder_key], dim=2)
209
+ value = torch.cat([value, encoder_value], dim=2)
210
+
211
+ # 5. Attention
212
+ hidden_states = F.scaled_dot_product_attention(
213
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
214
+ )
215
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
216
+ hidden_states = hidden_states.to(query.dtype)
217
+
218
+ # 6. Output projection
219
+ if encoder_hidden_states is not None:
220
+ hidden_states, encoder_hidden_states = (
221
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
222
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
223
+ )
224
+
225
+ if getattr(attn, "to_out", None) is not None:
226
+ hidden_states = attn.to_out[0](hidden_states)
227
+ hidden_states = attn.to_out[1](hidden_states)
228
+
229
+ if getattr(attn, "to_add_out", None) is not None:
230
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
231
+
232
+ return hidden_states, encoder_hidden_states
233
+
234
+
235
+ class MotifVideoPatchEmbed(nn.Module):
236
+ def __init__(
237
+ self,
238
+ patch_size: Union[int, Tuple[int, int, int]] = 16,
239
+ in_chans: int = 3,
240
+ embed_dim: int = 768,
241
+ ) -> None:
242
+ super().__init__()
243
+
244
+ patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
245
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
246
+
247
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
248
+ hidden_states = self.proj(hidden_states)
249
+ hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
250
+ return hidden_states
251
+
252
+
253
+ class MotifVideoAdaNorm(nn.Module):
254
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
255
+ super().__init__()
256
+
257
+ out_features = out_features or 2 * in_features
258
+ self.linear = nn.Linear(in_features, out_features)
259
+ self.nonlinearity = nn.SiLU()
260
+
261
+ def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
262
+ temb = self.linear(self.nonlinearity(temb))
263
+ gate_msa, gate_mlp = temb.chunk(2, dim=1)
264
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
265
+ return gate_msa, gate_mlp
266
+
267
+
268
+ class MotifVideoConditionEmbedding(nn.Module):
269
+ def __init__(
270
+ self,
271
+ embedding_dim: int,
272
+ pooled_projection_dim: int | None,
273
+ ):
274
+ super().__init__()
275
+
276
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
277
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
278
+
279
+ if isinstance(pooled_projection_dim, int):
280
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
281
+
282
+ def forward(
283
+ self,
284
+ timestep: torch.Tensor,
285
+ pooled_projection: torch.Tensor | None = None,
286
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
287
+ timesteps_proj = self.time_proj(timestep)
288
+ timestep_embedder_dtype = next(self.timestep_embedder.parameters()).dtype
289
+ conditioning = self.timestep_embedder(timesteps_proj.to(timestep_embedder_dtype)) # (N, D)
290
+ if pooled_projection is not None:
291
+ conditioning = conditioning + self.text_embedder(pooled_projection)
292
+
293
+ token_replace_emb = None
294
+
295
+ return conditioning, token_replace_emb
296
+
297
+
298
+ # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L485-L486
299
+ def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
300
+ dtype = num_rotations.dtype if isinstance(num_rotations, torch.Tensor) else torch.float32
301
+ max_pos_tensor = torch.as_tensor(max_position_embeddings, dtype=dtype)
302
+ return (dim * torch.log(max_pos_tensor / (num_rotations * 2 * math.pi))) / (
303
+ 2 * math.log(base)
304
+ ) # Inverse dim formula to find number of rotations
305
+
306
+
307
+ # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L489-L495
308
+ def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
309
+ """
310
+ Find the correction range for NTK-by-parts interpolation.
311
+ """
312
+ low = torch.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len))
313
+ high = torch.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len))
314
+ low = torch.clamp(low, min=0)
315
+ high = torch.clamp(high, max=dim - 1)
316
+ return low, high # Clamp values just in case
317
+
318
+
319
+ # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L498-L504
320
+ def linear_ramp_mask(min_val, max_val, num_dim):
321
+ if isinstance(min_val, torch.Tensor):
322
+ if (min_val == max_val).all():
323
+ max_val = max_val + 0.001
324
+ elif min_val == max_val:
325
+ max_val += 0.001
326
+
327
+ linear_func = (torch.arange(num_dim, dtype=torch.float32) - min_val) / (max_val - min_val)
328
+ ramp_func = torch.clamp(linear_func, 0, 1)
329
+ return ramp_func
330
+
331
+
332
+ # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L507-L511
333
+ def find_newbase_ntk(dim, base, scale):
334
+ """
335
+ Calculate the new base for NTK-aware scaling.
336
+ """
337
+ # Avoid division by zero when dim == 2 (or invalid smaller values).
338
+ # In these degenerate cases, fall back to the original base (no NTK adjustment).
339
+ if dim <= 2:
340
+ return base
341
+ return base * (scale ** (dim / (dim - 2)))
342
+
343
+
344
+ # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L514-L652
345
+ def get_1d_rotary_pos_embed(
346
+ dim: int,
347
+ pos: Union[np.ndarray, int],
348
+ theta: float = 10000.0,
349
+ use_real=False,
350
+ linear_factor=1.0,
351
+ ntk_factor=1.0,
352
+ repeat_interleave_real=True,
353
+ freqs_dtype=torch.float32,
354
+ yarn=False,
355
+ max_pe_len=None,
356
+ ori_max_pe_len=64,
357
+ dype=False,
358
+ current_timestep=1.0,
359
+ ):
360
+ """
361
+ Precompute the frequency tensor for complex exponentials with RoPE.
362
+ Supports YARN interpolation for vision transformers.
363
+
364
+ Args:
365
+ dim (`int`):
366
+ Dimension of the frequency tensor.
367
+ pos (`np.ndarray` or `int`):
368
+ Position indices for the frequency tensor. [S] or scalar.
369
+ theta (`float`, *optional*, defaults to 10000.0):
370
+ Scaling factor for frequency computation.
371
+ use_real (`bool`, *optional*, defaults to False):
372
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
373
+ linear_factor (`float`, *optional*, defaults to 1.0):
374
+ Scaling factor for linear interpolation.
375
+ ntk_factor (`float`, *optional*, defaults to 1.0):
376
+ Scaling factor for NTK-Aware RoPE.
377
+ repeat_interleave_real (`bool`, *optional*, defaults to True):
378
+ If True and use_real, real and imaginary parts are interleaved with themselves to reach dim.
379
+ Otherwise, they are concatenated.
380
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
381
+ Data type of the frequency tensor.
382
+ yarn (`bool`, *optional*, defaults to False):
383
+ If True, use YARN interpolation combining NTK, linear, and base methods.
384
+ max_pe_len (`int`, *optional*):
385
+ Maximum position encoding length (current patches for vision models).
386
+ ori_max_pe_len (`int`, *optional*, defaults to 64):
387
+ Original maximum position encoding length (base patches for vision models).
388
+ dype (`bool`, *optional*, defaults to False):
389
+ If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
390
+ current_timestep (`float`, *optional*, defaults to 1.0):
391
+ Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
392
+
393
+ Returns:
394
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
395
+ If use_real=True, returns tuple of (cos, sin) tensors.
396
+ """
397
+ assert dim % 2 == 0
398
+
399
+ if isinstance(pos, int):
400
+ pos = torch.arange(pos)
401
+ if isinstance(pos, np.ndarray):
402
+ pos = torch.from_numpy(pos)
403
+
404
+ device = pos.device
405
+
406
+ if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
407
+ if not isinstance(max_pe_len, torch.Tensor):
408
+ max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)
409
+
410
+ scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
411
+
412
+ beta_0 = 1.25
413
+ beta_1 = 0.75
414
+ gamma_0 = 16
415
+ gamma_1 = 2
416
+
417
+ freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
418
+
419
+ freqs_linear = 1.0 / torch.einsum(
420
+ "..., f -> ... f",
421
+ scale,
422
+ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)),
423
+ )
424
+
425
+ new_base = find_newbase_ntk(dim, theta, scale)
426
+ if new_base.dim() > 0:
427
+ new_base = new_base.view(-1, 1)
428
+ freqs_ntk = 1.0 / torch.pow(new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
429
+ if freqs_ntk.dim() > 1:
430
+ freqs_ntk = freqs_ntk.squeeze()
431
+
432
+ if dype:
433
+ beta_0 = torch.pow(beta_0, 2.0 * torch.pow(current_timestep, 2.0))
434
+ beta_1 = torch.pow(beta_1, 2.0 * torch.pow(current_timestep, 2.0))
435
+
436
+ low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
437
+ high = torch.clamp(high, max=dim // 2)
438
+
439
+ freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
440
+ freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
441
+
442
+ if dype:
443
+ gamma_0 = torch.pow(gamma_0, 2.0 * torch.pow(current_timestep, 2.0))
444
+ gamma_1 = torch.pow(gamma_1, 2.0 * torch.pow(current_timestep, 2.0))
445
+
446
+ low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
447
+ high = torch.clamp(high, max=dim // 2)
448
+
449
+ freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
450
+ freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
451
+
452
+ else:
453
+ theta_ntk = theta * ntk_factor
454
+ freqs = 1.0 / (theta_ntk ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)) / linear_factor
455
+
456
+ freqs = torch.outer(pos, freqs)
457
+
458
+ is_npu = freqs.device.type == "npu"
459
+ if is_npu:
460
+ freqs = freqs.float()
461
+
462
+ if use_real and repeat_interleave_real:
463
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
464
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
465
+
466
+ if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
467
+ mscale = torch.where(scale <= 1.0, 1.0, 0.1 * torch.log(scale) + 1.0).to(scale)
468
+ freqs_cos = freqs_cos * mscale
469
+ freqs_sin = freqs_sin * mscale
470
+
471
+ return freqs_cos, freqs_sin
472
+ elif use_real:
473
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
474
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
475
+ return freqs_cos, freqs_sin
476
+ else:
477
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
478
+ return freqs_cis
479
+
480
+
481
+ class MotifVideoRotaryPosEmbed(nn.Module):
482
+ def __init__(
483
+ self,
484
+ patch_size: int,
485
+ patch_size_t: int,
486
+ rope_dim: List[int],
487
+ theta: float = 256.0,
488
+ base_latent_size: int | None = None,
489
+ ):
490
+ """
491
+ Rotary Positional Embedding (RoPE) for video latents.
492
+
493
+ Args:
494
+ patch_size (`int`):
495
+ Spatial patch size (e.g., 2).
496
+ patch_size_t (`int`):
497
+ Temporal patch size (e.g., 1).
498
+ rope_dim (`List[int]`):
499
+ Dimensions for RoPE across [Time, Height, Width] axes.
500
+ theta (`float`, *optional*, defaults to 256.0):
501
+ Base frequency for rotary embeddings.
502
+ base_latent_size (`int`, *optional*):
503
+ The maximum spatial dimension (in latent units) seen during training,
504
+ i.e. `training_resolution / vae_scale_factor_spatial`.
505
+ For example, for 1280x1280 training images and a VAE spatial downscale
506
+ (`vae_scale_factor_spatial`) of 8, this would be 160; for a downscale
507
+ of 16, it would be 80.
508
+ """
509
+ super().__init__()
510
+
511
+ self.patch_size = patch_size
512
+ self.patch_size_t = patch_size_t
513
+ self.rope_dim = rope_dim
514
+ self.theta = theta
515
+ self.base_latent_size = base_latent_size
516
+
517
+ @lru_cache(maxsize=8)
518
+ def _get_base_patch_grid_size(self, base_latent_size: Optional[int], patch_size: int) -> Optional[int]:
519
+ return base_latent_size // patch_size if base_latent_size else None
520
+
521
+ @lru_cache(maxsize=8)
522
+ def _get_dynamic_interpolation_scale(self, h: int, w: int, base_grid_size: int) -> float:
523
+ return math.sqrt(h * w / (base_grid_size**2))
524
+
525
+ def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
526
+ if self.training:
527
+ assert self.base_latent_size is None, (
528
+ "RoPE interpolation/extrapolation logic should only be enabled for inference. "
529
+ f"During training, base_latent_size must be None, but got {self.base_latent_size!r}."
530
+ )
531
+
532
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
533
+ rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
534
+
535
+ axes_grids = []
536
+ for i in range(3):
537
+ # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
538
+ # original implementation creates it on CPU and then moves it to device. This results in numerical
539
+ # differences in layerwise debugging outputs, but visually it is the same.
540
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
541
+ axes_grids.append(grid)
542
+ grid = torch.meshgrid(*axes_grids, indexing="ij") # 3 tensors, each [T, H, W]
+ grid = torch.stack(grid, dim=0) # [3, T, H, W]
544
+
545
+ base_patch_grid_size = self._get_base_patch_grid_size(self.base_latent_size, self.patch_size)
546
+ if base_patch_grid_size is not None:
547
+ if base_patch_grid_size <= 0:
548
+ raise ValueError(f"base_patch_grid_size must be a positive number, got {base_patch_grid_size}.")
549
+ dynamic_interpolation_scale = self._get_dynamic_interpolation_scale(
550
+ rope_sizes[1], rope_sizes[2], base_patch_grid_size
551
+ )
552
+
553
+ normalized_timestep = torch.tensor(1.0)
554
+ if not self.training and timestep is not None:
555
+ normalized_timestep = timestep[0] / NUM_TRAIN_TIMESTEPS
556
+
557
+ freqs = []
558
+ for i in range(3):
559
+ common_kwargs = {
560
+ "dim": self.rope_dim[i],
561
+ "pos": grid[i].reshape(-1),
562
+ "theta": self.theta,
563
+ "use_real": True,
564
+ "freqs_dtype": torch.float64,
565
+ }
566
+
567
+ # Apply scaling only to spatial dimensions (Height and Width, i=1 and i=2)
568
+ if i > 0 and base_patch_grid_size is not None and dynamic_interpolation_scale > 1.0:
569
+ # We project the training base to the current size using the uniform scale factor.
570
+ # max_pe_len tells the RoPE logic the "new" maximum length it's dealing with.
571
+ max_pe_len = torch.tensor(
572
+ base_patch_grid_size * dynamic_interpolation_scale,
573
+ dtype=torch.float64,
574
+ device=hidden_states.device,
575
+ )
576
+
577
+ freq = get_1d_rotary_pos_embed(
578
+ **common_kwargs,
579
+ yarn=True, # Enable Yet Another RoPE extensioN (YARN) for extrapolation
580
+ max_pe_len=max_pe_len,
581
+ ori_max_pe_len=base_patch_grid_size, # The original training scale
582
+ dype=True, # Enable Dynamic Position Encoding (time-aware)
583
+ current_timestep=normalized_timestep,
584
+ )
585
+ else:
586
+ # Time dimension OR within training bounds -> Standard RoPE
587
+ freq = get_1d_rotary_pos_embed(**common_kwargs)
588
+
589
+ freqs.append(freq)
590
+
591
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (T * H * W, sum(rope_dim))
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (T * H * W, sum(rope_dim))
593
+ return freqs_cos, freqs_sin
594
+
595
+
596
+ class MotifVideoImageProjection(nn.Module):
597
+ def __init__(self, in_features: int, hidden_size: int):
598
+ super().__init__()
599
+ self.norm_in = nn.LayerNorm(in_features)
600
+ self.linear_1 = nn.Linear(in_features, in_features)
601
+ self.act_fn = nn.GELU()
602
+ self.linear_2 = nn.Linear(in_features, hidden_size)
603
+ self.norm_out = nn.LayerNorm(hidden_size)
604
+
605
+ def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
606
+ hidden_states = self.norm_in(image_embeds)
607
+ hidden_states = self.linear_1(hidden_states)
608
+ hidden_states = self.act_fn(hidden_states)
609
+ hidden_states = self.linear_2(hidden_states)
610
+ hidden_states = self.norm_out(hidden_states)
611
+ return hidden_states
612
+
613
+
614
+ class MotifVideoSingleTransformerBlock(nn.Module):
615
+ def __init__(
616
+ self,
617
+ num_attention_heads: int,
618
+ attention_head_dim: int,
619
+ mlp_ratio: float = 4.0,
620
+ qk_norm: str = "rms_norm",
621
+ norm_type: str = "layer_norm",
622
+ enable_text_cross_attention: bool = False,
623
+ ) -> None:
624
+ super().__init__()
625
+
626
+ hidden_size = num_attention_heads * attention_head_dim
627
+ mlp_dim = int(hidden_size * mlp_ratio)
628
+
629
+ self.attn = Attention(
630
+ query_dim=hidden_size,
631
+ cross_attention_dim=None,
632
+ dim_head=attention_head_dim,
633
+ heads=num_attention_heads,
634
+ out_dim=hidden_size,
635
+ bias=True,
636
+ processor=MotifVideoAttnProcessor2_0(),
637
+ qk_norm=qk_norm,
638
+ eps=1e-6,
639
+ pre_only=True,
640
+ )
641
+
642
+ self.enable_text_cross_attention = enable_text_cross_attention
643
+ if enable_text_cross_attention:
644
+ self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
645
+ self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
646
+ self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
647
+ nn.init.zeros_(self.cross_attn_out_proj.weight)
648
+ nn.init.zeros_(self.cross_attn_out_proj.bias)
649
+
650
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type)
651
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
652
+ self.act_mlp = nn.GELU(approximate="tanh")
653
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
654
+
655
+ def forward(
656
+ self,
657
+ hidden_states: torch.Tensor,
658
+ encoder_hidden_states: torch.Tensor,
659
+ temb: torch.Tensor,
660
+ attention_mask: Optional[torch.Tensor] = None,
661
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
662
+ token_replace_emb: torch.Tensor | None = None,
663
+ first_frame_num_tokens: int | None = None,
664
+ image_embed_seq_len: int = 0,
665
+ encoder_attention_mask: torch.Tensor | None = None,
666
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
667
+ text_seq_length = encoder_hidden_states.shape[1]
668
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
669
+
670
+ residual = hidden_states
671
+
672
+ # 1. Input normalization
673
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
674
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
675
+
676
+ norm_hidden_states, norm_encoder_hidden_states = (
677
+ norm_hidden_states[:, :-text_seq_length, :],
678
+ norm_hidden_states[:, -text_seq_length:, :],
679
+ )
680
+
681
+ # 2. Attention
682
+ attn_output, context_attn_output = self.attn(
683
+ hidden_states=norm_hidden_states,
684
+ encoder_hidden_states=norm_encoder_hidden_states,
685
+ attention_mask=attention_mask,
686
+ image_rotary_emb=image_rotary_emb,
687
+ )
688
+
689
+ # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
690
+ if self.enable_text_cross_attention:
691
+ txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
692
+ text_mask = None
693
+ if encoder_attention_mask is not None:
694
+ text_mask = encoder_attention_mask[:, image_embed_seq_len:]
695
+ text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
696
+ cross_q = self.cross_attn_query_proj(attn_output)
697
+ cross_output, _ = self.attn(
698
+ hidden_states=cross_q,
699
+ query_input=cross_q,
700
+ key_input=txt_kv,
701
+ value_input=txt_kv,
702
+ attention_mask=text_mask,
703
+ image_rotary_emb=image_rotary_emb,
704
+ )
705
+ attn_output = attn_output + self.cross_attn_out_proj(cross_output)
706
+
707
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
708
+
709
+ # 3. Modulation and residual connection
710
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
711
+ hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
712
+ hidden_states = hidden_states + residual
713
+
714
+ hidden_states, encoder_hidden_states = (
715
+ hidden_states[:, :-text_seq_length, :],
716
+ hidden_states[:, -text_seq_length:, :],
717
+ )
718
+ return hidden_states, encoder_hidden_states
719
+
720
+
721
+ class MotifVideoTransformerBlock(nn.Module):
722
+ def __init__(
723
+ self,
724
+ num_attention_heads: int,
725
+ attention_head_dim: int,
726
+ mlp_ratio: float,
727
+ qk_norm: str = "rms_norm",
728
+ norm_type: str = "layer_norm",
729
+ enable_text_cross_attention: bool = False,
730
+ ) -> None:
731
+ super().__init__()
732
+
733
+ hidden_size = num_attention_heads * attention_head_dim
734
+
735
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type=norm_type)
736
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type=norm_type)
737
+
738
+ self.attn = Attention(
739
+ query_dim=hidden_size,
740
+ cross_attention_dim=None,
741
+ added_kv_proj_dim=hidden_size,
742
+ dim_head=attention_head_dim,
743
+ heads=num_attention_heads,
744
+ out_dim=hidden_size,
745
+ context_pre_only=False,
746
+ bias=True,
747
+ processor=MotifVideoAttnProcessor2_0(),
748
+ qk_norm=qk_norm,
749
+ eps=1e-6,
750
+ )
751
+
752
+ self.enable_text_cross_attention = enable_text_cross_attention
753
+ if enable_text_cross_attention:
754
+ self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
755
+ self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
756
+ self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
757
+ nn.init.zeros_(self.cross_attn_out_proj.weight)
758
+ nn.init.zeros_(self.cross_attn_out_proj.bias)
759
+
760
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
761
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
762
+
763
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
764
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
765
+
766
+ def forward(
767
+ self,
768
+ hidden_states: torch.Tensor,
769
+ encoder_hidden_states: torch.Tensor,
770
+ temb: torch.Tensor,
771
+ attention_mask: Optional[torch.Tensor] = None,
772
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
773
+ token_replace_emb: torch.Tensor | None = None,
774
+ first_frame_num_tokens: int | None = None,
775
+ image_embed_seq_len: int = 0,
776
+ encoder_attention_mask: torch.Tensor | None = None,
777
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
778
+ # 1. Input normalization
779
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
780
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
781
+ encoder_hidden_states, emb=temb
782
+ )
783
+
784
+ # 2. Joint attention
785
+ attn_output, context_attn_output = self.attn(
786
+ hidden_states=norm_hidden_states,
787
+ encoder_hidden_states=norm_encoder_hidden_states,
788
+ attention_mask=attention_mask,
789
+ image_rotary_emb=image_rotary_emb,
790
+ )
791
+
792
+ # 3. Modulation and residual connection
793
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
794
+
795
+ # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
796
+ if self.enable_text_cross_attention:
797
+ txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
798
+ text_mask = None
799
+ if encoder_attention_mask is not None:
800
+ text_mask = encoder_attention_mask[:, image_embed_seq_len:]
801
+ text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
802
+ cross_q = self.cross_attn_query_proj(attn_output)
803
+ cross_output, _ = self.attn(
804
+ hidden_states=cross_q,
805
+ query_input=cross_q,
806
+ key_input=txt_kv,
807
+ value_input=txt_kv,
808
+ attention_mask=text_mask,
809
+ image_rotary_emb=image_rotary_emb,
810
+ )
811
+ hidden_states = hidden_states + self.cross_attn_out_proj(cross_output)
812
+
813
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
814
+
815
+ norm_hidden_states = self.norm2(hidden_states)
816
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
817
+
818
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
819
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
820
+
821
+ # 4. Feed-forward
822
+ ff_output = self.ff(norm_hidden_states)
823
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
824
+
825
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
826
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
827
+
828
+ return hidden_states, encoder_hidden_states
829
+
830
+
831
+ TransformerBlockRegistry.register(
832
+ model_class=MotifVideoTransformerBlock,
833
+ metadata=TransformerBlockMetadata(
834
+ return_hidden_states_index=0,
835
+ return_encoder_hidden_states_index=1,
836
+ ),
837
+ )
838
+ TransformerBlockRegistry.register(
839
+ model_class=MotifVideoSingleTransformerBlock,
840
+ metadata=TransformerBlockMetadata(
841
+ return_hidden_states_index=0,
842
+ return_encoder_hidden_states_index=1,
843
+ ),
844
+ )
845
+
846
+
847
+ class MotifVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
848
+ r"""
849
+ A Transformer model for video-like data used in [MotifVideo](https://huggingface.co/motif/motifvideo).
850
+
851
+ Args:
852
+ in_channels (`int`, defaults to `33`):
853
+ The number of channels in the input.
854
+ out_channels (`int`, defaults to `16`):
855
+ The number of channels in the output.
856
+ num_attention_heads (`int`, defaults to `24`):
857
+ The number of heads to use for multi-head attention.
858
+ attention_head_dim (`int`, defaults to `128`):
859
+ The number of channels in each head.
860
+ num_layers (`int`, defaults to `20`):
861
+ The number of layers of dual-stream blocks to use.
862
+ num_single_layers (`int`, defaults to `40`):
863
+ The number of layers of single-stream blocks to use.
864
+
865
+ mlp_ratio (`float`, defaults to `4.0`):
866
+ The ratio of the hidden layer size to the input size in the feedforward network.
867
+ patch_size (`int`, defaults to `2`):
868
+ The size of the spatial patches to use in the patch embedding layer.
869
+ patch_size_t (`int`, defaults to `1`):
870
+ The size of the temporal patches to use in the patch embedding layer.
871
+ qk_norm (`str`, defaults to `rms_norm`):
872
+ The normalization to use for the query and key projections in the attention layers.
873
+ text_embed_dim (`int`, defaults to `4096`):
874
+ Input dimension of text embeddings from the text encoder.
875
+ rope_theta (`float`, defaults to `256.0`):
876
+ The value of theta to use in the RoPE layer.
877
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
878
+ The dimensions of the axes to use in the RoPE layer.
879
+ base_latent_size (`int`, *optional*):
880
+ The maximum spatial dimension (in latent units) seen during training.
881
+ For example, if trained on 1280x1280 with a VAE downscale of 16, this is 80.
882
+ """
883
+
884
+ _supports_gradient_checkpointing = True
885
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
886
+ _no_split_modules = [
887
+ "MotifVideoTransformerBlock",
888
+ "MotifVideoSingleTransformerBlock",
889
+ "MotifVideoPatchEmbed",
890
+ ]
891
+
892
+ @register_to_config
893
+ def __init__(
894
+ self,
895
+ in_channels: int = 33,
896
+ out_channels: int = 16,
897
+ num_attention_heads: int = 24,
898
+ attention_head_dim: int = 128,
899
+ num_layers: int = 20,
900
+ num_single_layers: int = 40,
901
+ num_decoder_layers: int = 0,
902
+ mlp_ratio: float = 4.0,
903
+ patch_size: int = 2,
904
+ patch_size_t: int = 1,
905
+ qk_norm: str = "rms_norm",
906
+ norm_type: str = "layer_norm",
907
+ text_embed_dim: int = 4096,
908
+ image_embed_dim: int | None = None,
909
+ pooled_projection_dim: int | None = None,
910
+ rope_theta: float = 256.0,
911
+ rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
912
+ base_latent_size: int | None = None,
913
+ enable_text_cross_attention_dual: bool = False,
914
+ enable_text_cross_attention_single: bool = False,
915
+ ) -> None:
916
+ super().__init__()
917
+
918
+ inner_dim = num_attention_heads * attention_head_dim
919
+ out_channels = out_channels or in_channels
920
+
921
+ # 1. Latent and condition embedders
922
+ self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
923
+ self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim)
924
+
925
+ # First frame conditioning: Image conditioning embedders
926
+ self.image_embed_dim = image_embed_dim
927
+ if image_embed_dim is not None:
928
+ # Project image embeddings from vision encoder to transformer dim
929
+ self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim)
930
+
931
+ self.time_text_embed = MotifVideoConditionEmbedding(inner_dim, pooled_projection_dim)
932
+
933
+ # 2. RoPE
934
+ self.rope = MotifVideoRotaryPosEmbed(
935
+ patch_size, patch_size_t, rope_axes_dim, rope_theta, base_latent_size=base_latent_size
936
+ )
937
+
938
+ # Cross-attention config
939
+ self.enable_text_cross_attention_dual = enable_text_cross_attention_dual
940
+ self.enable_text_cross_attention_single = enable_text_cross_attention_single
941
+
942
+ # 3. Dual stream transformer blocks
943
+ self.transformer_blocks = nn.ModuleList(
944
+ [
945
+ MotifVideoTransformerBlock(
946
+ num_attention_heads,
947
+ attention_head_dim,
948
+ mlp_ratio=mlp_ratio,
949
+ qk_norm=qk_norm,
950
+ norm_type=norm_type,
951
+ enable_text_cross_attention=enable_text_cross_attention_dual,
952
+ )
953
+ for _ in range(num_layers)
954
+ ]
955
+ )
956
+
957
+ # 4. Single stream transformer blocks
958
+ # Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder)
959
+ num_encoder_single = num_single_layers - num_decoder_layers
960
+ self.single_transformer_blocks = nn.ModuleList(
961
+ [
962
+ MotifVideoSingleTransformerBlock(
963
+ num_attention_heads,
964
+ attention_head_dim,
965
+ mlp_ratio=mlp_ratio,
966
+ qk_norm=qk_norm,
967
+ norm_type=norm_type,
968
+ enable_text_cross_attention=enable_text_cross_attention_single
969
+ if i < num_encoder_single
970
+ else False,
971
+ )
972
+ for i in range(num_single_layers)
973
+ ]
974
+ )
975
+
976
+ # 5. Output projection
977
+ self.norm_out = AdaLayerNormContinuous(
978
+ inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type=norm_type
979
+ )
980
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
981
+
982
+ # Verify cross-attention config matches actual block state.
983
+ # Catches silent misconfiguration (e.g. checkpoint config with renamed keys).
984
+ for i, block in enumerate(self.transformer_blocks):
985
+ if block.enable_text_cross_attention != enable_text_cross_attention_dual:
986
+ raise ValueError(
987
+ f"transformer_blocks[{i}].enable_text_cross_attention="
988
+ f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. "
989
+ f"Check checkpoint config.json key names match __init__ parameters."
990
+ )
991
+ num_encoder_single = num_single_layers - num_decoder_layers
992
+ for i, block in enumerate(self.single_transformer_blocks):
993
+ expected = enable_text_cross_attention_single if i < num_encoder_single else False
994
+ if block.enable_text_cross_attention != expected:
995
+ raise ValueError(
996
+ f"single_transformer_blocks[{i}].enable_text_cross_attention="
997
+ f"{block.enable_text_cross_attention}, expected {expected}. "
998
+ f"Check checkpoint config.json key names match __init__ parameters."
999
+ )
1000
+
1001
+ self.gradient_checkpointing = False
1002
+ self.num_decoder_layers = num_decoder_layers
1003
+
1004
+ @property
1005
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
1006
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
1007
+ r"""
1008
+ Returns:
1009
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
1009
+ indexed by weight name.
1011
+ """
1012
+ # set recursively
1013
+ processors = {}
1014
+
1015
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
1016
+ if hasattr(module, "get_processor"):
1017
+ processors[f"{name}.processor"] = module.get_processor()
1018
+
1019
+ for sub_name, child in module.named_children():
1020
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1021
+
1022
+ return processors
1023
+
1024
+ for name, module in self.named_children():
1025
+ fn_recursive_add_processors(name, module, processors)
1026
+
1027
+ return processors
1028
+
1029
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
1030
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
1031
+ r"""
1032
+ Sets the attention processor to use to compute attention.
1033
+
1034
+ Parameters:
1035
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1036
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
1037
+ for **all** `Attention` layers.
1038
+
1039
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1040
+ processor. This is strongly recommended when setting trainable attention processors.
1041
+
1042
+ """
1043
+ count = len(self.attn_processors.keys())
1044
+
1045
+ if isinstance(processor, dict) and len(processor) != count:
1046
+ raise ValueError(
1047
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1048
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1049
+ )
1050
+
1051
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1052
+ if hasattr(module, "set_processor"):
1053
+ if not isinstance(processor, dict):
1054
+ module.set_processor(processor)
1055
+ else:
1056
+ module.set_processor(processor.pop(f"{name}.processor"))
1057
+
1058
+ for sub_name, child in module.named_children():
1059
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1060
+
1061
+ for name, module in self.named_children():
1062
+ fn_recursive_attn_processor(name, module, processor)
1063
+
1064
+ def _maybe_gradient_checkpoint_block(self, block, *args):
1065
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1066
+ return self._gradient_checkpointing_func(block, *args)
1067
+ return block(*args)
1068
+
1069
+ def _get_unwrapped_blocks(self, blocks):
1070
+ if hasattr(blocks, "_checkpoint_wrapped_module"):
1071
+ return blocks._checkpoint_wrapped_module
1072
+ elif hasattr(blocks, "module"):
1073
+ return blocks.module
1074
+ return blocks
1075
+
1076
+ def _create_attention_mask(
1077
+ self,
1078
+ hidden_states: torch.Tensor,
1079
+ encoder_attention_mask: torch.Tensor,
1080
+ ) -> torch.Tensor:
1081
+ """
1082
+ Create attention mask of shape [B, 1, 1, N] where N = L + E,
1083
+ based on latent tokens (always valid) and the encoder mask.
1084
+
1085
+ Args:
1086
+ hidden_states: [B, L, D]
1087
+ encoder_attention_mask: [B, E] (required)
1088
+
1089
+ Returns:
1090
+ attention_mask: [B, 1, 1, N]
1091
+ """
1092
+ attention_mask = F.pad(
1093
+ encoder_attention_mask.to(torch.bool),
1094
+ (hidden_states.shape[1], 0),
1095
+ value=True,
1096
+ )
1097
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, L+E]
1098
+ return attention_mask
1099
+
1100
+ def forward(
1101
+ self,
1102
+ hidden_states: torch.Tensor,
1103
+ timestep: torch.LongTensor,
1104
+ encoder_hidden_states: torch.Tensor,
1105
+ encoder_attention_mask: torch.Tensor | None = None,
1106
+ pooled_projections: torch.Tensor | None = None,
1107
+ image_embeds: torch.Tensor | None = None,
1108
+ attention_kwargs: Optional[Dict[str, Any]] = None,
1109
+ return_dict: bool = True,
1110
+ tread_mixin: Optional[Any] = None,
1111
+ tread_disabled: bool = False,
1112
+ ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
1113
+ """
1114
+ Forward pass of the MotifVideoTransformer3DModel.
1115
+
1116
+ Args:
1117
+ hidden_states: Input latent tensor [B, C, F, H, W].
1118
+ timestep: Diffusion timesteps [B].
1119
+ encoder_hidden_states: Text conditioning [B, E, D].
1120
+ encoder_attention_mask: Mask for text conditioning [B, E].
1121
+ pooled_projections: Pooled text embeddings [B, D].
1122
+ image_embeds: Optional image embeddings from vision encoder [B, N, D].
1123
+ attention_kwargs: Additional arguments for attention processors.
1124
+ return_dict: Whether to return a Transformer2DModelOutput.
1125
+ tread_mixin: Optional TreadMixin instance for token reduction.
1126
+ tread_disabled: When True, force tread_mixin to None (dense pass).
1127
+ torch.compile specializes on this bool, producing separate graphs
1128
+ for dense vs routed without attribute toggling.
1129
+
1130
+ Returns:
1131
+ Transformer2DModelOutput or tuple containing the predicted samples.
1132
+ """
1133
+ if tread_disabled:
1134
+ tread_mixin = None
1135
+ elif tread_mixin is None:
1136
+ tread_mixin = getattr(self, "_inference_tread_mixin", None)
1137
+
1138
+ if attention_kwargs is not None:
1139
+ attention_kwargs = attention_kwargs.copy()
1140
+ lora_scale = attention_kwargs.pop("scale", 1.0)
1141
+ else:
1142
+ lora_scale = 1.0
1143
+
1144
+ if USE_PEFT_BACKEND:
1145
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1146
+ scale_lora_layers(self, lora_scale)
1147
+ else:
1148
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
1149
+ logger.warning(
1150
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
1151
+ )
1152
+
1153
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
1154
+ p, p_t = self.config.patch_size, self.config.patch_size_t
1155
+ post_patch_num_frames = num_frames // p_t
1156
+ post_patch_height = height // p
1157
+ post_patch_width = width // p
1158
+ first_frame_num_tokens = 1 * post_patch_height * post_patch_width
1159
+ # 1. RoPE
1160
+ image_rotary_emb = self.rope(hidden_states, timestep=timestep)
1161
+ # 2. Conditional embeddings
1162
+ temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
1163
+ hidden_states = self.x_embedder(hidden_states)
1164
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1165
+
1166
+ # First frame conditioning: Image embeddings from vision encoder
1167
+ if image_embeds is not None:
1168
+ # image_embeds: [B, N, D_img] -> [B, N, D]
1169
+ image_embeds = self.image_embedder(image_embeds)
1170
+ encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1)
1171
+ # Extend attention mask for image tokens
1172
+ if encoder_attention_mask is not None:
1173
+ image_mask = torch.ones(
1174
+ image_embeds.shape[0],
1175
+ image_embeds.shape[1],
1176
+ device=encoder_attention_mask.device,
1177
+ dtype=encoder_attention_mask.dtype,
1178
+ )
1179
+ encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1)
1180
+
1181
+ # image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states
1182
+ image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0
1183
+
1184
+ decoder_hidden_states = hidden_states.clone()
1185
+
1186
+ if encoder_attention_mask is not None:
1187
+ attention_mask = self._create_attention_mask(
1188
+ hidden_states=hidden_states,
1189
+ encoder_attention_mask=encoder_attention_mask,
1190
+ )
1191
+ else:
1192
+ attention_mask = None
1193
+
1194
+ # TREAD state initialization: manage token reduction manually to support activation checkpointing
1195
+ tread_active = False
1196
+ current_route = None
1197
+ ids_keep = None
1198
+ x_full = None
1199
+ orig_mask = attention_mask
1200
+ orig_rope = image_rotary_emb
1201
+ latent_len = hidden_states.shape[1]
1202
+
1203
+ # 4. Dual stream transformer blocks (Encoder)
1204
+ for i, block in enumerate(self.transformer_blocks):
1205
+ # Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
1206
+ if is_tread_start(tread_mixin, tread_active, i):
1207
+ tread_active = True
1208
+ current_route = tread_mixin._tread_route
1209
+ # Reduce sequence length at the start of a TREAD route
1210
+ ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
1211
+ x_full = hidden_states.contiguous()
1212
+ hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
1213
+ attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1214
+ image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1215
+
1216
+ hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1217
+ block,
1218
+ hidden_states,
1219
+ encoder_hidden_states,
1220
+ temb,
1221
+ attention_mask,
1222
+ image_rotary_emb,
1223
+ token_replace_emb,
1224
+ first_frame_num_tokens,
1225
+ image_embed_seq_len,
1226
+ encoder_attention_mask,
1227
+ )
1228
+
1229
+ if is_tread_end(tread_mixin, tread_active, i):
1230
+ # Restore full sequence length at the end of a TREAD route
1231
+ hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
1232
+ tread_active = False
1233
+ current_route = None
1234
+ ids_keep = None
1235
+ x_full = None
1236
+ attention_mask = orig_mask
1237
+ image_rotary_emb = orig_rope
1238
+
1239
+ # We need to unwrap the blocks because CheckpointWrapper does not support len(),
1240
+ # which is required for slicing the blocks into encoder and decoder parts.
1241
+ single_transformer_blocks = self._get_unwrapped_blocks(self.single_transformer_blocks)
1242
+
1243
+ # 5. Single stream transformer blocks (Encoder)
1244
+ num_dual = len(self.transformer_blocks)
1245
+ for i, block in enumerate(
1246
+ single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]
1247
+ ):
1248
+ # Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
1249
+ abs_i = num_dual + i
1250
+ if is_tread_start(tread_mixin, tread_active, abs_i):
1251
+ tread_active = True
1252
+ current_route = tread_mixin._tread_route
1253
+ # Reduce sequence length at the start of a TREAD route
1254
+ ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
1255
+ x_full = hidden_states.contiguous()
1256
+ hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
1257
+ attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1258
+ image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1259
+
1260
+ hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1261
+ block,
1262
+ hidden_states,
1263
+ encoder_hidden_states,
1264
+ temb,
1265
+ attention_mask,
1266
+ image_rotary_emb,
1267
+ token_replace_emb,
1268
+ first_frame_num_tokens,
1269
+ image_embed_seq_len,
1270
+ encoder_attention_mask,
1271
+ )
1272
+
1273
+ if is_tread_end(tread_mixin, tread_active, abs_i):
1274
+ # Restore full sequence length at the end of a TREAD route
1275
+ hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
1276
+ tread_active = False
1277
+ current_route = None
1278
+ ids_keep = None
1279
+ x_full = None
1280
+ attention_mask = orig_mask
1281
+ image_rotary_emb = orig_rope
1282
+
1283
+ # 6. Single stream transformer blocks (Decoder)
1284
+ if self.num_decoder_layers > 0:
1285
+ encoder_hidden_states = hidden_states
1286
+ attention_mask = None
1287
+
1288
+ num_single = len(single_transformer_blocks)
1289
+
1290
+ for i, block in enumerate(single_transformer_blocks[-self.num_decoder_layers :]):
1291
+ abs_i = num_dual + (num_single - self.num_decoder_layers) + i
1292
+ if is_tread_start(tread_mixin, tread_active, abs_i):
1293
+ tread_active = True
1294
+ current_route = tread_mixin._tread_route
1295
+ # Reduce sequence length at the start of a TREAD route
1296
+ ids_keep = tread_mixin.keep_indices(decoder_hidden_states, current_route["sel"]).to(
1297
+ decoder_hidden_states.device
1298
+ )
1299
+ x_full = encoder_hidden_states.contiguous()
1300
+ x_t_full = decoder_hidden_states.contiguous()
1301
+ decoder_hidden_states = tread_mixin.gather_tokens(decoder_hidden_states, ids_keep)
1302
+ encoder_hidden_states = tread_mixin.gather_tokens(encoder_hidden_states, ids_keep)
1303
+ attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1304
+ image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1305
+
1306
+ decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1307
+ block,
1308
+ decoder_hidden_states,
1309
+ encoder_hidden_states,
1310
+ temb,
1311
+ attention_mask,
1312
+ image_rotary_emb,
1313
+ token_replace_emb,
1314
+ first_frame_num_tokens,
1315
+ )
1316
+
1317
+ if is_tread_end(tread_mixin, tread_active, abs_i):
1318
+ # Restore full sequence length at the end of a TREAD route
1319
+ decoder_hidden_states = tread_mixin.scatter_tokens(decoder_hidden_states, ids_keep, x_t_full)
1320
+ encoder_hidden_states = tread_mixin.scatter_tokens(encoder_hidden_states, ids_keep, x_full)
1321
+ tread_active = False
1322
+ current_route = None
1323
+ ids_keep = None
1324
+ x_full = None
1325
+ x_t_full = None
1326
+ attention_mask = orig_mask
1327
+ image_rotary_emb = orig_rope
1328
+
1329
+ hidden_states = decoder_hidden_states
1330
+
1331
+ # 7. Output projection
1332
+ hidden_states = self.norm_out(hidden_states, temb)
1333
+ hidden_states = self.proj_out(hidden_states)
1334
+
1335
+ hidden_states = hidden_states.reshape(
1336
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
1337
+ )
1338
+ hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
1339
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
1340
+
1341
+ if USE_PEFT_BACKEND:
1342
+ # remove `lora_scale` from each PEFT layer
1343
+ unscale_lora_layers(self, lora_scale)
1344
+
1345
+ if not return_dict:
1346
+ return (hidden_states,)
1347
+
1348
+ return Transformer2DModelOutput(
1349
+ sample=hidden_states,
1350
+ )
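The joint mask built by `_create_attention_mask` above can be illustrated without PyTorch. The following is a minimal sketch, not the model's code: it assumes plain Python lists in place of tensors, omits the two singleton dimensions of the real `[B, 1, 1, L+E]` result, and the helper name `build_joint_mask` is hypothetical. It shows only the left-padding step, in which every latent token is marked valid and the encoder mask is appended unchanged.

```python
from typing import List


def build_joint_mask(num_latent_tokens: int, encoder_mask: List[List[int]]) -> List[List[bool]]:
    """Left-pad each row of the encoder mask with True for the latent tokens.

    Hypothetical helper: mirrors the effect of F.pad(mask, (L, 0), value=True)
    on the last dimension, using nested lists instead of tensors.
    """
    joint = []
    for row in encoder_mask:
        # Latent tokens come first and are always valid; encoder tokens follow.
        joint.append([True] * num_latent_tokens + [bool(v) for v in row])
    return joint


# Toy sizes (assumed for illustration): 3 latent tokens, 4 encoder tokens,
# where the last encoder token is padding.
mask = build_joint_mask(3, [[1, 1, 1, 0]])
print(mask)
```

Each row has length `L + E`, so the padded encoder position stays masked while all latent positions remain visible to every query.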
vae/config.json ADDED
@@ -0,0 +1,64 @@
1
+ {
2
+ "_class_name": "AutoencoderKLWan",
3
+ "_diffusers_version": "0.35.2",
4
+ "_name_or_path": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
5
+ "attn_scales": [],
6
+ "base_dim": 96,
7
+ "decoder_base_dim": null,
8
+ "dim_mult": [
9
+ 1,
10
+ 2,
11
+ 4,
12
+ 4
13
+ ],
14
+ "dropout": 0.0,
15
+ "in_channels": 3,
16
+ "is_residual": false,
17
+ "latents_mean": [
18
+ -0.7571,
19
+ -0.7089,
20
+ -0.9113,
21
+ 0.1075,
22
+ -0.1745,
23
+ 0.9653,
24
+ -0.1517,
25
+ 1.5508,
26
+ 0.4134,
27
+ -0.0715,
28
+ 0.5517,
29
+ -0.3632,
30
+ -0.1922,
31
+ -0.9497,
32
+ 0.2503,
33
+ -0.2921
34
+ ],
35
+ "latents_std": [
36
+ 2.8184,
37
+ 1.4541,
38
+ 2.3275,
39
+ 2.6558,
40
+ 1.2196,
41
+ 1.7708,
42
+ 2.6052,
43
+ 2.0743,
44
+ 3.2687,
45
+ 2.1526,
46
+ 2.8652,
47
+ 1.5579,
48
+ 1.6382,
49
+ 1.1253,
50
+ 2.8251,
51
+ 1.916
52
+ ],
53
+ "num_res_blocks": 2,
54
+ "out_channels": 3,
55
+ "patch_size": null,
56
+ "scale_factor_spatial": 8,
57
+ "scale_factor_temporal": 4,
58
+ "temperal_downsample": [
59
+ false,
60
+ true,
61
+ true
62
+ ],
63
+ "z_dim": 16
64
+ }
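The `latents_mean` and `latents_std` lists in this config hold one value per latent channel (`z_dim` is 16). The sketch below assumes the common convention of standardizing each channel as `(x - mean) / std` before the transformer and inverting it before decoding. Whether this repository's pipeline applies exactly that convention is an assumption, and the function names `normalize` and `denormalize` are hypothetical. Plain lists stand in for tensors.

```python
from typing import List

# Values copied from vae/config.json above (one entry per latent channel).
LATENTS_MEAN = [
    -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
    0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
]
LATENTS_STD = [
    2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
    3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916,
]


def normalize(channels: List[float]) -> List[float]:
    """Standardize one value per channel: (x - mean) / std (assumed convention)."""
    return [(x - m) / s for x, m, s in zip(channels, LATENTS_MEAN, LATENTS_STD)]


def denormalize(channels: List[float]) -> List[float]:
    """Invert normalize(): x * std + mean."""
    return [x * s + m for x, m, s in zip(channels, LATENTS_MEAN, LATENTS_STD)]


# Round trip on an arbitrary 16-channel sample (illustrative values only).
sample = [0.1 * i for i in range(16)]
restored = denormalize(normalize(sample))
print(max(abs(a - b) for a, b in zip(sample, restored)))
```

A latent equal to the per-channel mean maps to zero under this convention, which is a quick sanity check when wiring the statistics into a pipeline.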
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6e524b3fffede1787a74e81b30976dce5400c4439ba64222168e607ed19e793
3
+ size 507591892