Text-to-Video
Diffusers
Safetensors
English
MotifVideoPipeline
image-to-video
video-generation
diffusion-transformer
Instructions to use Nishant2414/Motif-Video-2B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Nishant2414/Motif-Video-2B with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("Nishant2414/Motif-Video-2B", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Commit ·
ef79f27
0
Parent(s):
Duplicate from Motif-Technologies/Motif-Video-2B
Browse filesCo-authored-by: Dongpin oh <odb9402@users.noreply.huggingface.co>
- .gitattributes +42 -0
- .gitignore +29 -0
- README.md +314 -0
- _fm_solvers_unipc.py +759 -0
- assets/architecture.png +3 -0
- assets/banner.png +3 -0
- assets/showcase_i2v.png +3 -0
- assets/showcase_t2v.png +3 -0
- feature_extractor/preprocessor_config.json +23 -0
- inference.py +119 -0
- model_index.json +28 -0
- motif-video-technical-report.pdf +3 -0
- pipeline_motif_video.py +1321 -0
- scheduler/scheduler_config.json +18 -0
- text_encoder/config.json +252 -0
- text_encoder/model.safetensors +3 -0
- tokenizer/tokenizer.json +3 -0
- tokenizer/tokenizer_config.json +26 -0
- transformer/config.json +30 -0
- transformer/diffusion_pytorch_model.safetensors +3 -0
- transformer/transformer_motif_video.py +1350 -0
- vae/config.json +64 -0
- vae/diffusion_pytorch_model.safetensors +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/architecture.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/banner.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/showcase_i2v.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/showcase_t2v.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/i2v_sample.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
motif-video-technical-report.pdf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Claude Code / Codex
|
| 2 |
+
.claude/
|
| 3 |
+
.codex/
|
| 4 |
+
.codex-review-latest.md
|
| 5 |
+
.hook-state
|
| 6 |
+
.hook-state.lock
|
| 7 |
+
.plans/
|
| 8 |
+
.manuals/
|
| 9 |
+
CLAUDE.md
|
| 10 |
+
_old_claude_files/
|
| 11 |
+
|
| 12 |
+
# Internal / test
|
| 13 |
+
test_local.py
|
| 14 |
+
|
| 15 |
+
# Environment
|
| 16 |
+
.env
|
| 17 |
+
tmp/
|
| 18 |
+
|
| 19 |
+
# Python
|
| 20 |
+
*.pyc
|
| 21 |
+
__pycache__/
|
| 22 |
+
|
| 23 |
+
# Experiments
|
| 24 |
+
results/
|
| 25 |
+
experiments/**/outputs/
|
| 26 |
+
experiments/**/checkpoints/
|
| 27 |
+
experiments/**/*.pt
|
| 28 |
+
experiments/**/*.ckpt
|
| 29 |
+
assets/i2v_sample.jpg
|
README.md
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
tags:
|
| 6 |
+
- text-to-video
|
| 7 |
+
- image-to-video
|
| 8 |
+
- video-generation
|
| 9 |
+
- diffusion-transformer
|
| 10 |
+
pipeline_tag: text-to-video
|
| 11 |
+
library_name: diffusers
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
<img src="assets/banner.png" width="100%" alt="Motif-Video 2B teaser"/>
|
| 16 |
+
</p>
|
| 17 |
+
|
| 18 |
+
<p align="center">
|
| 19 |
+
<h1 align="center">Motif-Video 2B</h1>
|
| 20 |
+
</p>
|
| 21 |
+
|
| 22 |
+
<p align="center">
|
| 23 |
+
<b>A micro-budget text-to-video diffusion transformer from Motif Technologies</b>
|
| 24 |
+
</p>
|
| 25 |
+
|
| 26 |
+
<p align="center">
|
| 27 |
+
📑 <a href="https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf">Technical Report</a> |
|
| 28 |
+
🤗 <a href="">Hugging Face</a> |
|
| 29 |
+
🌐 <a href="https://motiftech.io/videoshowcase">Project Page</a>
|
| 30 |
+
</p>
|
| 31 |
+
|
| 32 |
+
---
|
| 33 |
+
|
| 34 |
+
## 🔥 News
|
| 35 |
+
|
| 36 |
+
- **[2026-04-14]** We release **Motif-Video 2B**, our 2B-parameter text-to-video and image-to-video diffusion transformer, together with the full [technical report](https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf).
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## 📖 Introduction
|
| 41 |
+
|
| 42 |
+
Training strong video generation models usually requires massive datasets, large parameter counts, and substantial compute. **Motif-Video 2B** asks whether competitive text-to-video quality is reachable at a much smaller budget — fewer than **10M training clips** and under **100,000 H200 GPU hours** — and shows that the answer is yes, provided the model design explicitly separates objectives that scaling would otherwise leave entangled.
|
| 43 |
+
|
| 44 |
+
Our central observation is that prompt alignment, temporal consistency, and fine-detail recovery interfere with one another when handled through the same pathway. Motif-Video 2B addresses this **objective interference** architecturally rather than relying on scale alone, through two contributions:
|
| 45 |
+
|
| 46 |
+
- **Shared Cross-Attention.** A residual cross-attention mechanism that reuses self-attention K/V weights to stabilize text–video alignment under long-context token sparsity, where standard joint attention dilutes text influence as the video token sequence grows.
|
| 47 |
+
- **Three-stage DDT-style backbone.** 12 dual-stream + 16 single-stream + 8 DDT decoder layers, separating early modality fusion, joint representation learning, and high-frequency detail reconstruction into dedicated components. Per-block attention analysis shows that the DDT decoder spontaneously develops inter-frame attention structure absent from the encoder layers.
|
| 48 |
+
|
| 49 |
+
These are paired with a micro-budget training recipe combining **TREAD token routing** and early-phase **REPA** with a frozen **V-JEPA** teacher — to our knowledge, the first time this combination has been applied to text-to-video training.
|
| 50 |
+
|
| 51 |
+
On VBench, Motif-Video 2B reaches **83.76%**, the highest Total Score among open-source models we evaluate, surpassing Wan2.1-14B at **7× fewer parameters** and roughly an order of magnitude less training data.
|
| 52 |
+
|
| 53 |
+
<!--
|
| 54 |
+
Architecture figure — replace with Figure 2 from the technical report
|
| 55 |
+
(the three-stage backbone + Shared Cross-Attention diagram).
|
| 56 |
+
-->
|
| 57 |
+
<p align="center">
|
| 58 |
+
<img src="assets/architecture.png" width="90%" alt="Motif-Video 2B architecture"/>
|
| 59 |
+
</p>
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
|
| 63 |
+
## ✨ Highlights
|
| 64 |
+
|
| 65 |
+
- **Two tasks, one set of weights.** A single checkpoint handles both **text-to-video (T2V)** and **image-to-video (I2V)** generation, trained jointly without a learnable task-type embedding.
|
| 66 |
+
- **Up to 720p, 121 frames.** The final model generates 720p video at 121 frames under the standard rectified flow-matching sampler.
|
| 67 |
+
- **Architectural specialization over brute-force scale.** Three-stage backbone with role-separated dual-stream / single-stream / DDT decoder layers.
|
| 68 |
+
- **Shared Cross-Attention.** Stabilizes text alignment under long video-token sequences by grounding cross-attention K/V in the self-attention manifold.
|
| 69 |
+
- **Micro-budget recipe.** TREAD token routing (≈27% per-step FLOP reduction) + early-phase REPA with V-JEPA teacher + offline bucket-balanced sampler (≈90% data utilization, up from ≈20% baseline).
|
| 70 |
+
- **Open and reproducible.** Trained on ~64×H200 GPUs with FSDP2, full curriculum and recipe documented in the technical report.
|
| 71 |
+
|
| 72 |
+
---
|
| 73 |
+
|
| 74 |
+
## 🏗️ Architecture
|
| 75 |
+
|
| 76 |
+
Motif-Video 2B is a flow-matching diffusion transformer organized around a single principle: each component is assigned a well-defined responsibility, and components with conflicting objectives are not asked to share capacity.
|
| 77 |
+
|
| 78 |
+
| Component | Choice |
|
| 79 |
+
|---|---|
|
| 80 |
+
| Text encoder | T5Gemma2 (encoder–decoder, UL2-adapted Gemma 3) |
|
| 81 |
+
| Video tokenizer | Wan2.1 VAE (8×8 spatial, 4× temporal compression), 2×2×1 patchify |
|
| 82 |
+
| Backbone | 12 dual-stream + 16 single-stream + 8 DDT decoder layers |
|
| 83 |
+
| Hidden dim / heads | 1536 / 12 heads × 128 |
|
| 84 |
+
| Normalization | QK-normalization throughout |
|
| 85 |
+
| Position encoding | RoPE |
|
| 86 |
+
| Cross-attention | **Shared Cross-Attention** in the single-stream stage |
|
| 87 |
+
| Objective | Rectified flow matching (velocity prediction) |
|
| 88 |
+
| I2V conditioning | First-frame latent + SigLIP image embeddings, with timestep-aware blur |
|
| 89 |
+
|
| 90 |
+
A high-level walkthrough of the role separation:
|
| 91 |
+
|
| 92 |
+
1. **Dual-stream stage (12 layers).** Text and video tokens are processed through separate self-attention pathways, exchanging information via cross-attention. This prevents premature feature entanglement before either modality has formed coherent representations.
|
| 93 |
+
2. **Single-stream stage (16 layers).** Text and video tokens attend freely in a joint sequence. **Shared Cross-Attention** is attached here to repair the text-attention dilution that emerges as the video token sequence grows.
|
| 94 |
+
3. **DDT decoder (8 layers).** A dedicated velocity decoder atop the 28-layer encoder, freeing the encoder from high-frequency detail reconstruction. Per-block attention analysis shows that the DDT decoder develops inter-frame attention structure that single-stream layers do not.
|
| 95 |
+
|
| 96 |
+
For the full derivation of why Shared Cross-Attention shares K/V but not Q, and why this is necessary in addition to standard zero-init of W_O, see Section 3.3 of the [technical report](https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf).
|
| 97 |
+
|
| 98 |
+
<!--
|
| 99 |
+
Optional: insert Figure 3 (attention heatmaps across the three stages)
|
| 100 |
+
here as a secondary architecture figure. It is the strongest visual
|
| 101 |
+
evidence for the role-separation argument.
|
| 102 |
+
-->
|
| 103 |
+
|
| 104 |
+
---
|
| 105 |
+
|
| 106 |
+
## 🚀 Quickstart / Usage
|
| 107 |
+
|
| 108 |
+
### Requirements
|
| 109 |
+
|
| 110 |
+
- Python 3.10+
|
| 111 |
+
- CUDA-capable GPU with **24GB+ VRAM** (e.g., A100, H100, RTX 4090)
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
pip install "diffusers>=0.35.2" "transformers>=5.0.0" torch accelerate ftfy einops sentencepiece regex Pillow
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
### Text-to-Video (T2V)
|
| 118 |
+
|
| 119 |
+
```python
|
| 120 |
+
import torch
|
| 121 |
+
from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
|
| 122 |
+
from diffusers.utils import export_to_video
|
| 123 |
+
|
| 124 |
+
guider = AdaptiveProjectedGuidance(
|
| 125 |
+
guidance_scale=8.0,
|
| 126 |
+
adaptive_projected_guidance_rescale=12.0,
|
| 127 |
+
adaptive_projected_guidance_momentum=0.1,
|
| 128 |
+
use_original_formulation=True,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 132 |
+
"Motif-Technologies/Motif-Video-2B",
|
| 133 |
+
custom_pipeline="pipeline_motif_video",
|
| 134 |
+
trust_remote_code=True,
|
| 135 |
+
torch_dtype=torch.bfloat16,
|
| 136 |
+
guider=guider,
|
| 137 |
+
)
|
| 138 |
+
pipe = pipe.to("cuda")
|
| 139 |
+
|
| 140 |
+
output = pipe(
|
| 141 |
+
prompt="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
|
| 142 |
+
height=736,
|
| 143 |
+
width=1280,
|
| 144 |
+
num_frames=121,
|
| 145 |
+
num_inference_steps=50,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
export_to_video(output.frames[0], "output.mp4", fps=24)
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
### Image-to-Video (I2V)
|
| 152 |
+
|
| 153 |
+
```python
|
| 154 |
+
import torch
|
| 155 |
+
from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
|
| 156 |
+
from diffusers.utils import export_to_video, load_image
|
| 157 |
+
|
| 158 |
+
guider = AdaptiveProjectedGuidance(
|
| 159 |
+
guidance_scale=8.0,
|
| 160 |
+
adaptive_projected_guidance_rescale=12.0,
|
| 161 |
+
adaptive_projected_guidance_momentum=0.1,
|
| 162 |
+
use_original_formulation=True,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 166 |
+
"Motif-Technologies/Motif-Video-2B",
|
| 167 |
+
custom_pipeline="pipeline_motif_video",
|
| 168 |
+
trust_remote_code=True,
|
| 169 |
+
torch_dtype=torch.bfloat16,
|
| 170 |
+
guider=guider,
|
| 171 |
+
)
|
| 172 |
+
pipe = pipe.to("cuda")
|
| 173 |
+
|
| 174 |
+
image = load_image("https://huggingface.co/Motif-Technologies/Motif-Video-2B/resolve/main/assets/i2v_sample.jpg")
|
| 175 |
+
|
| 176 |
+
output = pipe(
|
| 177 |
+
prompt="Three friends stride through a sun-bleached meadow as a warm breeze ripples the tall dry grass around their legs. The woman on the left turns her head to share a quiet laugh, the woman in the center pushes a loose curl behind her ear, and the man on the right tilts his face toward the sky. The camera drifts gently alongside them at walking pace, handheld, with soft overcast light.",
|
| 178 |
+
image=image,
|
| 179 |
+
height=736,
|
| 180 |
+
width=1280,
|
| 181 |
+
num_frames=121,
|
| 182 |
+
num_inference_steps=50,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
export_to_video(output.frames[0], "output.mp4", fps=24)
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
### CLI Inference
|
| 189 |
+
|
| 190 |
+
```bash
|
| 191 |
+
# Text-to-Video
|
| 192 |
+
python inference.py \
|
| 193 |
+
--prompt "A time-lapse of a flower blooming in a dark room, dramatic lighting" \
|
| 194 |
+
--output t2v_output.mp4
|
| 195 |
+
|
| 196 |
+
# Image-to-Video
|
| 197 |
+
python inference.py \
|
| 198 |
+
--image assets/i2v_sample.jpg \
|
| 199 |
+
--prompt "Three friends stride through a meadow as a warm breeze ripples the tall grass" \
|
| 200 |
+
--output i2v_output.mp4
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
See `inference.py` for all available options (`--help`).
|
| 204 |
+
|
| 205 |
+
### Recommended Settings
|
| 206 |
+
|
| 207 |
+
| Parameter | Default | Notes |
|
| 208 |
+
|---|---|---|
|
| 209 |
+
| Resolution | 1280x736 | 720p, best quality |
|
| 210 |
+
| Frames | 121 | ~5 seconds at 24fps |
|
| 211 |
+
| Guidance scale | 8.0 | |
|
| 212 |
+
| Scheduler shift | 15.0 | Pre-configured in scheduler config |
|
| 213 |
+
| Inference steps | 50 | |
|
| 214 |
+
| dtype | bfloat16 | Recommended for H100/A100 |
|
| 215 |
+
|
| 216 |
+
---
|
| 217 |
+
|
| 218 |
+
## 📊 Performance
|
| 219 |
+
|
| 220 |
+
### VBench
|
| 221 |
+
|
| 222 |
+
Motif-Video 2B achieves the highest **Total Score** among open-source models we evaluate.
|
| 223 |
+
|
| 224 |
+
| Model | Params | Total | Quality | Semantic |
|
| 225 |
+
|---|---|---|---|---|
|
| 226 |
+
| Wan2.2-T2V (prompt-opt.) | A14B | 84.23 | 85.42 | 79.50 |
|
| 227 |
+
| **Motif-Video 2B (Ours)** | **2B** | **83.76** | **84.59** | **80.44** |
|
| 228 |
+
| SANA-Video | 2B | 83.71 | 84.35 | 81.35 |
|
| 229 |
+
| Wan2.1-T2V | 14B | 83.69 | 85.59 | 76.11 |
|
| 230 |
+
| OpenSora 2.0 (T2I2V) | 11B | 83.60 | 84.40 | 80.30 |
|
| 231 |
+
| Wan2.1-T2V | 1.3B | 83.31 | 85.23 | 75.65 |
|
| 232 |
+
| HunyuanVideo | 13B | 83.24 | 85.09 | 75.82 |
|
| 233 |
+
| CogVideoX1.5-5B (prompt-opt.) | 5B | 82.17 | 82.78 | 79.76 |
|
| 234 |
+
| Step-Video-T2V | 30B | 81.83 | 84.46 | 71.28 |
|
| 235 |
+
| LTX-Video | 2B | 80.00 | 82.30 | 70.79 |
|
| 236 |
+
|
| 237 |
+
Notable per-dimension highlights for Motif-Video 2B (open-source):
|
| 238 |
+
|
| 239 |
+
- **Spatial Relationship: 83.02%** — best among open-source models
|
| 240 |
+
- **Semantic Score: 80.44%** — highest among open-source models reporting per-dimension results
|
| 241 |
+
- **Object Class: 92.93%**, **Multiple Objects: 77.29%**, **Imaging Quality: 70.50%** — second-best in their categories
|
| 242 |
+
|
| 243 |
+
The full 16-dimension breakdown is in Table 3 of the [technical report](https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf).
|
| 244 |
+
|
| 245 |
+
> **A note on VBench vs. perceptual quality.** Motif-Video 2B leads on VBench Total Score, but in our internal side-by-side comparisons against Wan2.1-T2V-14B we observe a perceptual gap in favor of the larger model on temporal stability and fine human anatomy. We discuss the sources of this gap (uniform dimension weighting, near-correct semantic credit) in Section 7 of the report. We report the gap explicitly rather than smoothing it over.
|
| 246 |
+
|
| 247 |
+
### Human evaluation
|
| 248 |
+
|
| 249 |
+
In a blind pairwise study against six contemporaneous open-source baselines (SANA-Video, LTX-Video 2, Wan2.1-14B, Wan2.1-1.3B, Wan2.2-5B, CogVideoX-5B) on 40 LLM-generated prompts, Motif-Video 2B is preferred over both **SANA-Video** (similar parameter count) and **Wan2.1-1.3B** (similar parameter count, larger training corpus) on prompt-following and video-fidelity axes. Wan2.1-14B remains the preferred model overall, consistent with its 7× larger parameter count and substantially larger training data.
|
| 250 |
+
|
| 251 |
+
---
|
| 252 |
+
|
| 253 |
+
## 🎬 Showcase
|
| 254 |
+
|
| 255 |
+
<!--
|
| 256 |
+
Insert the qualitative grids from the technical report here:
|
| 257 |
+
- Figure 1 / Figure 12: T2V multi-prompt frame strips
|
| 258 |
+
- Figure 13: I2V example (input image + generated frames)
|
| 259 |
+
Use full-width or 2-column layout, matching Wan2.1's "Showcase" section.
|
| 260 |
+
-->
|
| 261 |
+
|
| 262 |
+
### Text-to-Video
|
| 263 |
+
|
| 264 |
+
<p align="center">
|
| 265 |
+
<img src="assets/showcase_t2v.png" width="100%" alt="Motif-Video 2B T2V samples"/>
|
| 266 |
+
</p>
|
| 267 |
+
|
| 268 |
+
### Image-to-Video
|
| 269 |
+
|
| 270 |
+
<p align="center">
|
| 271 |
+
<img src="assets/showcase_i2v.png" width="100%" alt="Motif-Video 2B I2V samples"/>
|
| 272 |
+
</p>
|
| 273 |
+
|
| 274 |
+
---
|
| 275 |
+
|
| 276 |
+
## ⚠️ Limitations
|
| 277 |
+
|
| 278 |
+
We report limitations as the boundary conditions under which the design decisions in this report should be interpreted, not as caveats.
|
| 279 |
+
|
| 280 |
+
- **Micro-scale semantic distortion.** Motif-Video 2B occasionally produces sub-object-level artifacts that leave the category label intact but break perceptual plausibility — distorted hands on close-up human subjects, degraded body structure under high-displacement motion, and attribute leakage between visually similar co-present subjects. We attribute these primarily to data coverage rather than backbone design.
|
| 281 |
+
- **Temporal failures.** Three distinct modes that frame-level metrics do not surface: (i) physically implausible liquid / cloth / collision dynamics, (ii) coherence loss under high scene complexity (multi-agent crowds), and (iii) unintended mid-clip scene transitions in long sequences.
|
| 282 |
+
- **Recipe components are evaluated jointly, not in isolation.** We do not present per-component ablations for Shared Cross-Attention, the DDT decoder, REPA phasing, or TREAD routing at full scale. Readers should interpret our results as evidence that the *composed* recipe works at 2B, not as a marginal-contribution claim about any single component.
|
| 283 |
+
|
| 284 |
+
We view temporal stability and data coverage — not architectural depth — as the primary remaining ceilings on this model. Both are the most natural axes for a future iteration that the current architecture is built to absorb.
|
| 285 |
+
|
| 286 |
+
---
|
| 287 |
+
|
| 288 |
+
## 📚 Citation
|
| 289 |
+
|
| 290 |
+
If you find Motif-Video 2B useful in your research, please cite:
|
| 291 |
+
|
| 292 |
+
```bibtex
|
| 293 |
+
@techreport{motifvideo2b2026,
|
| 294 |
+
title = {Motif-Video 2B: Technical Report},
|
| 295 |
+
author = {Motif Technologies},
|
| 296 |
+
year = {2026},
|
| 297 |
+
institution = {Motif Technologies},
|
| 298 |
+
url = {https://huggingface.co/Motif-Technologies/Motif-Video-2B/blob/main/motif-video-technical-report.pdf}
|
| 299 |
+
}
|
| 300 |
+
```
|
| 301 |
+
|
| 302 |
+
---
|
| 303 |
+
|
| 304 |
+
## 🙏 Acknowledgements
|
| 305 |
+
|
| 306 |
+
We build on a number of excellent open-source projects, including the **Wan2.1 VAE** [Wan Team, 2025], **T5Gemma / Gemma 3** [Google], **TREAD** [Krause et al., 2025], **REPA** with the **V-JEPA** family of visual encoders [Bardes et al.], **DDT** [Wang et al.], and the broader **diffusers** and **Accelerate** ecosystems. Compute was provisioned on Microsoft Azure and orchestrated with **SkyPilot** on Kubernetes.
|
| 307 |
+
|
| 308 |
+
---
|
| 309 |
+
|
| 310 |
+
## 📄 License
|
| 311 |
+
|
| 312 |
+
<!-- TODO: confirm final license — apache-2.0 placeholder above. -->
|
| 313 |
+
|
| 314 |
+
This model is released under the Apache 2.0 License. See `LICENSE` for details.
|
_fm_solvers_unipc.py
ADDED
|
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
|
| 2 |
+
# Convert unipc for flow matching
|
| 3 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 11 |
+
from diffusers.schedulers.scheduling_utils import (
|
| 12 |
+
KarrasDiffusionSchedulers,
|
| 13 |
+
SchedulerMixin,
|
| 14 |
+
SchedulerOutput,
|
| 15 |
+
)
|
| 16 |
+
from diffusers.utils import deprecate, is_scipy_available
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if is_scipy_available():
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 24 |
+
"""
|
| 25 |
+
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
|
| 26 |
+
|
| 27 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 28 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 32 |
+
The number of diffusion steps to train the model.
|
| 33 |
+
solver_order (`int`, default `2`):
|
| 34 |
+
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
|
| 35 |
+
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
|
| 36 |
+
unconditional sampling.
|
| 37 |
+
prediction_type (`str`, defaults to "flow_prediction"):
|
| 38 |
+
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
|
| 39 |
+
the flow of the diffusion process.
|
| 40 |
+
thresholding (`bool`, defaults to `False`):
|
| 41 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
| 42 |
+
as Stable Diffusion.
|
| 43 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 44 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 45 |
+
sample_max_value (`float`, defaults to 1.0):
|
| 46 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
|
| 47 |
+
predict_x0 (`bool`, defaults to `True`):
|
| 48 |
+
Whether to use the updating algorithm on the predicted x0.
|
| 49 |
+
solver_type (`str`, default `bh2`):
|
| 50 |
+
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
|
| 51 |
+
otherwise.
|
| 52 |
+
lower_order_final (`bool`, default `True`):
|
| 53 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 54 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 55 |
+
disable_corrector (`list`, default `[]`):
|
| 56 |
+
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
|
| 57 |
+
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
|
| 58 |
+
usually disabled during the first few steps.
|
| 59 |
+
solver_p (`SchedulerMixin`, default `None`):
|
| 60 |
+
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
|
| 61 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
| 62 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
| 63 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
| 64 |
+
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
| 65 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
| 66 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 67 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 68 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 69 |
+
steps_offset (`int`, defaults to 0):
|
| 70 |
+
An offset added to the inference steps, as required by some model families.
|
| 71 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 72 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 73 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 77 |
+
order = 1
|
| 78 |
+
|
| 79 |
+
@register_to_config
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
num_train_timesteps: int = 1000,
|
| 83 |
+
solver_order: int = 2,
|
| 84 |
+
prediction_type: str = "flow_prediction",
|
| 85 |
+
shift: Optional[float] = 1.0,
|
| 86 |
+
use_dynamic_shifting=False,
|
| 87 |
+
thresholding: bool = False,
|
| 88 |
+
dynamic_thresholding_ratio: float = 0.995,
|
| 89 |
+
sample_max_value: float = 1.0,
|
| 90 |
+
predict_x0: bool = True,
|
| 91 |
+
solver_type: str = "bh2",
|
| 92 |
+
lower_order_final: bool = True,
|
| 93 |
+
disable_corrector: List[int] = [],
|
| 94 |
+
solver_p: Optional[SchedulerMixin] = None,
|
| 95 |
+
timestep_spacing: str = "linspace",
|
| 96 |
+
steps_offset: int = 0,
|
| 97 |
+
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
| 98 |
+
):
|
| 99 |
+
if solver_type not in ["bh1", "bh2"]:
|
| 100 |
+
if solver_type in ["midpoint", "heun", "logrho"]:
|
| 101 |
+
self.register_to_config(solver_type="bh2")
|
| 102 |
+
else:
|
| 103 |
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
| 104 |
+
|
| 105 |
+
self.predict_x0 = predict_x0
|
| 106 |
+
# setable values
|
| 107 |
+
self.num_inference_steps = None
|
| 108 |
+
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
|
| 109 |
+
sigmas = 1.0 - alphas
|
| 110 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
|
| 111 |
+
|
| 112 |
+
if not use_dynamic_shifting:
|
| 113 |
+
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
| 114 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
|
| 115 |
+
|
| 116 |
+
self.sigmas = sigmas
|
| 117 |
+
self.timesteps = sigmas * num_train_timesteps
|
| 118 |
+
|
| 119 |
+
self.model_outputs = [None] * solver_order
|
| 120 |
+
self.timestep_list = [None] * solver_order
|
| 121 |
+
self.lower_order_nums = 0
|
| 122 |
+
self.disable_corrector = disable_corrector
|
| 123 |
+
self.solver_p = solver_p
|
| 124 |
+
self.last_sample = None
|
| 125 |
+
self._step_index = None
|
| 126 |
+
self._begin_index = None
|
| 127 |
+
|
| 128 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 129 |
+
self.sigma_min = self.sigmas[-1].item()
|
| 130 |
+
self.sigma_max = self.sigmas[0].item()
|
| 131 |
+
|
| 132 |
+
@property
|
| 133 |
+
def step_index(self):
|
| 134 |
+
"""
|
| 135 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
| 136 |
+
"""
|
| 137 |
+
return self._step_index
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def begin_index(self):
|
| 141 |
+
"""
|
| 142 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
| 143 |
+
"""
|
| 144 |
+
return self._begin_index
|
| 145 |
+
|
| 146 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
| 147 |
+
def set_begin_index(self, begin_index: int = 0):
|
| 148 |
+
"""
|
| 149 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
begin_index (`int`):
|
| 153 |
+
The begin index for the scheduler.
|
| 154 |
+
"""
|
| 155 |
+
self._begin_index = begin_index
|
| 156 |
+
|
| 157 |
+
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
|
| 158 |
+
def set_timesteps(
|
| 159 |
+
self,
|
| 160 |
+
num_inference_steps: Union[int, None] = None,
|
| 161 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 162 |
+
sigmas: Optional[List[float]] = None,
|
| 163 |
+
mu: Optional[Union[float, None]] = None,
|
| 164 |
+
shift: Optional[Union[float, None]] = None,
|
| 165 |
+
):
|
| 166 |
+
"""
|
| 167 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 168 |
+
Args:
|
| 169 |
+
num_inference_steps (`int`):
|
| 170 |
+
Total number of the spacing of the time steps.
|
| 171 |
+
device (`str` or `torch.device`, *optional*):
|
| 172 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
if self.config.use_dynamic_shifting and mu is None:
|
| 176 |
+
raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
| 177 |
+
|
| 178 |
+
if sigmas is None:
|
| 179 |
+
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore
|
| 180 |
+
|
| 181 |
+
if self.config.use_dynamic_shifting:
|
| 182 |
+
sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
|
| 183 |
+
else:
|
| 184 |
+
if shift is None:
|
| 185 |
+
shift = self.config.shift
|
| 186 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
|
| 187 |
+
|
| 188 |
+
if self.config.final_sigmas_type == "sigma_min":
|
| 189 |
+
sigma_last = self.config.sigma_min
|
| 190 |
+
elif self.config.final_sigmas_type == "zero":
|
| 191 |
+
sigma_last = 0
|
| 192 |
+
else:
|
| 193 |
+
raise ValueError(
|
| 194 |
+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
| 198 |
+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
|
| 199 |
+
|
| 200 |
+
self.sigmas = torch.from_numpy(sigmas)
|
| 201 |
+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
| 202 |
+
|
| 203 |
+
self.num_inference_steps = len(timesteps)
|
| 204 |
+
|
| 205 |
+
self.model_outputs = [
|
| 206 |
+
None,
|
| 207 |
+
] * self.config.solver_order
|
| 208 |
+
self.lower_order_nums = 0
|
| 209 |
+
self.last_sample = None
|
| 210 |
+
if self.solver_p:
|
| 211 |
+
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
| 212 |
+
|
| 213 |
+
# add an index counter for schedulers that allow duplicated timesteps
|
| 214 |
+
self._step_index = None
|
| 215 |
+
self._begin_index = None
|
| 216 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 217 |
+
|
| 218 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
| 219 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
| 220 |
+
"""
|
| 221 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
| 222 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
| 223 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
| 224 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
| 225 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
| 226 |
+
|
| 227 |
+
https://arxiv.org/abs/2205.11487
|
| 228 |
+
"""
|
| 229 |
+
dtype = sample.dtype
|
| 230 |
+
batch_size, channels, *remaining_dims = sample.shape
|
| 231 |
+
|
| 232 |
+
if dtype not in (torch.float32, torch.float64):
|
| 233 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
| 234 |
+
|
| 235 |
+
# Flatten sample for doing quantile calculation along each image
|
| 236 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
| 237 |
+
|
| 238 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
| 239 |
+
|
| 240 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
| 241 |
+
s = torch.clamp(
|
| 242 |
+
s, min=1, max=self.config.sample_max_value
|
| 243 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
| 244 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
| 245 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
| 246 |
+
|
| 247 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
| 248 |
+
sample = sample.to(dtype)
|
| 249 |
+
|
| 250 |
+
return sample
|
| 251 |
+
|
| 252 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
|
| 253 |
+
def _sigma_to_t(self, sigma):
|
| 254 |
+
return sigma * self.config.num_train_timesteps
|
| 255 |
+
|
| 256 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
| 257 |
+
return 1 - sigma, sigma
|
| 258 |
+
|
| 259 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
|
| 260 |
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
| 261 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 262 |
+
|
| 263 |
+
def convert_model_output(
|
| 264 |
+
self,
|
| 265 |
+
model_output: torch.Tensor,
|
| 266 |
+
*args,
|
| 267 |
+
sample: Optional[torch.Tensor] = None,
|
| 268 |
+
**kwargs,
|
| 269 |
+
) -> torch.Tensor:
|
| 270 |
+
r"""
|
| 271 |
+
Convert the model output to the corresponding type the UniPC algorithm needs.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
model_output (`torch.Tensor`):
|
| 275 |
+
The direct output from the learned diffusion model.
|
| 276 |
+
timestep (`int`):
|
| 277 |
+
The current discrete timestep in the diffusion chain.
|
| 278 |
+
sample (`torch.Tensor`):
|
| 279 |
+
A current instance of a sample created by the diffusion process.
|
| 280 |
+
|
| 281 |
+
Returns:
|
| 282 |
+
`torch.Tensor`:
|
| 283 |
+
The converted model output.
|
| 284 |
+
"""
|
| 285 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
| 286 |
+
if sample is None:
|
| 287 |
+
if len(args) > 1:
|
| 288 |
+
sample = args[1]
|
| 289 |
+
else:
|
| 290 |
+
raise ValueError("missing `sample` as a required keyward argument")
|
| 291 |
+
if timestep is not None:
|
| 292 |
+
deprecate(
|
| 293 |
+
"timesteps",
|
| 294 |
+
"1.0.0",
|
| 295 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
sigma = self.sigmas[self.step_index]
|
| 299 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 300 |
+
|
| 301 |
+
if self.predict_x0:
|
| 302 |
+
if self.config.prediction_type == "flow_prediction":
|
| 303 |
+
sigma_t = self.sigmas[self.step_index]
|
| 304 |
+
x0_pred = sample - sigma_t * model_output
|
| 305 |
+
else:
|
| 306 |
+
raise ValueError(
|
| 307 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
|
| 308 |
+
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if self.config.thresholding:
|
| 312 |
+
x0_pred = self._threshold_sample(x0_pred)
|
| 313 |
+
|
| 314 |
+
return x0_pred
|
| 315 |
+
else:
|
| 316 |
+
if self.config.prediction_type == "flow_prediction":
|
| 317 |
+
sigma_t = self.sigmas[self.step_index]
|
| 318 |
+
epsilon = sample - (1 - sigma_t) * model_output
|
| 319 |
+
else:
|
| 320 |
+
raise ValueError(
|
| 321 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
|
| 322 |
+
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
if self.config.thresholding:
|
| 326 |
+
sigma_t = self.sigmas[self.step_index]
|
| 327 |
+
x0_pred = sample - sigma_t * model_output
|
| 328 |
+
x0_pred = self._threshold_sample(x0_pred)
|
| 329 |
+
epsilon = model_output + x0_pred
|
| 330 |
+
|
| 331 |
+
return epsilon
|
| 332 |
+
|
| 333 |
+
def multistep_uni_p_bh_update(
|
| 334 |
+
self,
|
| 335 |
+
model_output: torch.Tensor,
|
| 336 |
+
*args,
|
| 337 |
+
sample: Optional[torch.Tensor] = None,
|
| 338 |
+
order: Optional[int] = None,
|
| 339 |
+
**kwargs,
|
| 340 |
+
) -> torch.Tensor:
|
| 341 |
+
"""
|
| 342 |
+
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
model_output (`torch.Tensor`):
|
| 346 |
+
The direct output from the learned diffusion model at the current timestep.
|
| 347 |
+
prev_timestep (`int`):
|
| 348 |
+
The previous discrete timestep in the diffusion chain.
|
| 349 |
+
sample (`torch.Tensor`):
|
| 350 |
+
A current instance of a sample created by the diffusion process.
|
| 351 |
+
order (`int`):
|
| 352 |
+
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
|
| 353 |
+
|
| 354 |
+
Returns:
|
| 355 |
+
`torch.Tensor`:
|
| 356 |
+
The sample tensor at the previous timestep.
|
| 357 |
+
"""
|
| 358 |
+
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
| 359 |
+
if sample is None:
|
| 360 |
+
if len(args) > 1:
|
| 361 |
+
sample = args[1]
|
| 362 |
+
else:
|
| 363 |
+
raise ValueError(" missing `sample` as a required keyward argument")
|
| 364 |
+
if order is None:
|
| 365 |
+
if len(args) > 2:
|
| 366 |
+
order = args[2]
|
| 367 |
+
else:
|
| 368 |
+
raise ValueError(" missing `order` as a required keyward argument")
|
| 369 |
+
if prev_timestep is not None:
|
| 370 |
+
deprecate(
|
| 371 |
+
"prev_timestep",
|
| 372 |
+
"1.0.0",
|
| 373 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 374 |
+
)
|
| 375 |
+
model_output_list = self.model_outputs
|
| 376 |
+
|
| 377 |
+
s0 = self.timestep_list[-1]
|
| 378 |
+
m0 = model_output_list[-1]
|
| 379 |
+
x = sample
|
| 380 |
+
|
| 381 |
+
if self.solver_p:
|
| 382 |
+
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
| 383 |
+
return x_t
|
| 384 |
+
|
| 385 |
+
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore
|
| 386 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 387 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
| 388 |
+
|
| 389 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 390 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
| 391 |
+
|
| 392 |
+
h = lambda_t - lambda_s0
|
| 393 |
+
device = sample.device
|
| 394 |
+
|
| 395 |
+
rks = []
|
| 396 |
+
D1s = []
|
| 397 |
+
for i in range(1, order):
|
| 398 |
+
si = self.step_index - i # pyright: ignore
|
| 399 |
+
mi = model_output_list[-(i + 1)]
|
| 400 |
+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
| 401 |
+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
| 402 |
+
rk = (lambda_si - lambda_s0) / h
|
| 403 |
+
rks.append(rk)
|
| 404 |
+
D1s.append((mi - m0) / rk) # pyright: ignore
|
| 405 |
+
|
| 406 |
+
rks.append(1.0)
|
| 407 |
+
rks = torch.tensor(rks, device=device)
|
| 408 |
+
|
| 409 |
+
R = []
|
| 410 |
+
b = []
|
| 411 |
+
|
| 412 |
+
hh = -h if self.predict_x0 else h
|
| 413 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 414 |
+
h_phi_k = h_phi_1 / hh - 1
|
| 415 |
+
|
| 416 |
+
factorial_i = 1
|
| 417 |
+
|
| 418 |
+
if self.config.solver_type == "bh1":
|
| 419 |
+
B_h = hh
|
| 420 |
+
elif self.config.solver_type == "bh2":
|
| 421 |
+
B_h = torch.expm1(hh)
|
| 422 |
+
else:
|
| 423 |
+
raise NotImplementedError()
|
| 424 |
+
|
| 425 |
+
for i in range(1, order + 1):
|
| 426 |
+
R.append(torch.pow(rks, i - 1))
|
| 427 |
+
b.append(h_phi_k * factorial_i / B_h)
|
| 428 |
+
factorial_i *= i + 1
|
| 429 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 430 |
+
|
| 431 |
+
R = torch.stack(R)
|
| 432 |
+
b = torch.tensor(b, device=device)
|
| 433 |
+
|
| 434 |
+
if len(D1s) > 0:
|
| 435 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
| 436 |
+
# for order 2, we use a simplified version
|
| 437 |
+
if order == 2:
|
| 438 |
+
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
| 439 |
+
else:
|
| 440 |
+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
|
| 441 |
+
else:
|
| 442 |
+
D1s = None
|
| 443 |
+
|
| 444 |
+
if self.predict_x0:
|
| 445 |
+
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
| 446 |
+
if D1s is not None:
|
| 447 |
+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
|
| 448 |
+
else:
|
| 449 |
+
pred_res = 0
|
| 450 |
+
x_t = x_t_ - alpha_t * B_h * pred_res
|
| 451 |
+
else:
|
| 452 |
+
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
| 453 |
+
if D1s is not None:
|
| 454 |
+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
|
| 455 |
+
else:
|
| 456 |
+
pred_res = 0
|
| 457 |
+
x_t = x_t_ - sigma_t * B_h * pred_res
|
| 458 |
+
|
| 459 |
+
x_t = x_t.to(x.dtype)
|
| 460 |
+
return x_t
|
| 461 |
+
|
| 462 |
+
def multistep_uni_c_bh_update(
|
| 463 |
+
self,
|
| 464 |
+
this_model_output: torch.Tensor,
|
| 465 |
+
*args,
|
| 466 |
+
last_sample: Optional[torch.Tensor] = None,
|
| 467 |
+
this_sample: Optional[torch.Tensor] = None,
|
| 468 |
+
order: Optional[int] = None,
|
| 469 |
+
**kwargs,
|
| 470 |
+
) -> torch.Tensor:
|
| 471 |
+
"""
|
| 472 |
+
One step for the UniC (B(h) version).
|
| 473 |
+
|
| 474 |
+
Args:
|
| 475 |
+
this_model_output (`torch.Tensor`):
|
| 476 |
+
The model outputs at `x_t`.
|
| 477 |
+
this_timestep (`int`):
|
| 478 |
+
The current timestep `t`.
|
| 479 |
+
last_sample (`torch.Tensor`):
|
| 480 |
+
The generated sample before the last predictor `x_{t-1}`.
|
| 481 |
+
this_sample (`torch.Tensor`):
|
| 482 |
+
The generated sample after the last predictor `x_{t}`.
|
| 483 |
+
order (`int`):
|
| 484 |
+
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
|
| 485 |
+
|
| 486 |
+
Returns:
|
| 487 |
+
`torch.Tensor`:
|
| 488 |
+
The corrected sample tensor at the current timestep.
|
| 489 |
+
"""
|
| 490 |
+
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
| 491 |
+
if last_sample is None:
|
| 492 |
+
if len(args) > 1:
|
| 493 |
+
last_sample = args[1]
|
| 494 |
+
else:
|
| 495 |
+
raise ValueError(" missing`last_sample` as a required keyward argument")
|
| 496 |
+
if this_sample is None:
|
| 497 |
+
if len(args) > 2:
|
| 498 |
+
this_sample = args[2]
|
| 499 |
+
else:
|
| 500 |
+
raise ValueError(" missing`this_sample` as a required keyward argument")
|
| 501 |
+
if order is None:
|
| 502 |
+
if len(args) > 3:
|
| 503 |
+
order = args[3]
|
| 504 |
+
else:
|
| 505 |
+
raise ValueError(" missing`order` as a required keyward argument")
|
| 506 |
+
if this_timestep is not None:
|
| 507 |
+
deprecate(
|
| 508 |
+
"this_timestep",
|
| 509 |
+
"1.0.0",
|
| 510 |
+
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
model_output_list = self.model_outputs
|
| 514 |
+
|
| 515 |
+
m0 = model_output_list[-1]
|
| 516 |
+
x = last_sample
|
| 517 |
+
x_t = this_sample
|
| 518 |
+
model_t = this_model_output
|
| 519 |
+
|
| 520 |
+
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore
|
| 521 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 522 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
| 523 |
+
|
| 524 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 525 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
| 526 |
+
|
| 527 |
+
h = lambda_t - lambda_s0
|
| 528 |
+
device = this_sample.device
|
| 529 |
+
|
| 530 |
+
rks = []
|
| 531 |
+
D1s = []
|
| 532 |
+
for i in range(1, order):
|
| 533 |
+
si = self.step_index - (i + 1) # pyright: ignore
|
| 534 |
+
mi = model_output_list[-(i + 1)]
|
| 535 |
+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
| 536 |
+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
| 537 |
+
rk = (lambda_si - lambda_s0) / h
|
| 538 |
+
rks.append(rk)
|
| 539 |
+
D1s.append((mi - m0) / rk) # pyright: ignore
|
| 540 |
+
|
| 541 |
+
rks.append(1.0)
|
| 542 |
+
rks = torch.tensor(rks, device=device)
|
| 543 |
+
|
| 544 |
+
R = []
|
| 545 |
+
b = []
|
| 546 |
+
|
| 547 |
+
hh = -h if self.predict_x0 else h
|
| 548 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 549 |
+
h_phi_k = h_phi_1 / hh - 1
|
| 550 |
+
|
| 551 |
+
factorial_i = 1
|
| 552 |
+
|
| 553 |
+
if self.config.solver_type == "bh1":
|
| 554 |
+
B_h = hh
|
| 555 |
+
elif self.config.solver_type == "bh2":
|
| 556 |
+
B_h = torch.expm1(hh)
|
| 557 |
+
else:
|
| 558 |
+
raise NotImplementedError()
|
| 559 |
+
|
| 560 |
+
for i in range(1, order + 1):
|
| 561 |
+
R.append(torch.pow(rks, i - 1))
|
| 562 |
+
b.append(h_phi_k * factorial_i / B_h)
|
| 563 |
+
factorial_i *= i + 1
|
| 564 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 565 |
+
|
| 566 |
+
R = torch.stack(R)
|
| 567 |
+
b = torch.tensor(b, device=device)
|
| 568 |
+
|
| 569 |
+
if len(D1s) > 0:
|
| 570 |
+
D1s = torch.stack(D1s, dim=1)
|
| 571 |
+
else:
|
| 572 |
+
D1s = None
|
| 573 |
+
|
| 574 |
+
# for order 1, we use a simplified version
|
| 575 |
+
if order == 1:
|
| 576 |
+
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
| 577 |
+
else:
|
| 578 |
+
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
|
| 579 |
+
|
| 580 |
+
if self.predict_x0:
|
| 581 |
+
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
| 582 |
+
if D1s is not None:
|
| 583 |
+
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
| 584 |
+
else:
|
| 585 |
+
corr_res = 0
|
| 586 |
+
D1_t = model_t - m0
|
| 587 |
+
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
| 588 |
+
else:
|
| 589 |
+
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
| 590 |
+
if D1s is not None:
|
| 591 |
+
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
| 592 |
+
else:
|
| 593 |
+
corr_res = 0
|
| 594 |
+
D1_t = model_t - m0
|
| 595 |
+
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
| 596 |
+
x_t = x_t.to(x.dtype)
|
| 597 |
+
return x_t
|
| 598 |
+
|
| 599 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 600 |
+
if schedule_timesteps is None:
|
| 601 |
+
schedule_timesteps = self.timesteps
|
| 602 |
+
|
| 603 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
| 604 |
+
|
| 605 |
+
# The sigma index that is taken for the **very** first `step`
|
| 606 |
+
# is always the second index (or the last index if there is only 1)
|
| 607 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
| 608 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
| 609 |
+
pos = 1 if len(indices) > 1 else 0
|
| 610 |
+
|
| 611 |
+
return indices[pos].item()
|
| 612 |
+
|
| 613 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
| 614 |
+
def _init_step_index(self, timestep):
|
| 615 |
+
"""
|
| 616 |
+
Initialize the step_index counter for the scheduler.
|
| 617 |
+
"""
|
| 618 |
+
|
| 619 |
+
if self.begin_index is None:
|
| 620 |
+
if isinstance(timestep, torch.Tensor):
|
| 621 |
+
timestep = timestep.to(self.timesteps.device)
|
| 622 |
+
self._step_index = self.index_for_timestep(timestep)
|
| 623 |
+
else:
|
| 624 |
+
self._step_index = self._begin_index
|
| 625 |
+
|
| 626 |
+
def step(
|
| 627 |
+
self,
|
| 628 |
+
model_output: torch.Tensor,
|
| 629 |
+
timestep: Union[int, torch.Tensor],
|
| 630 |
+
sample: torch.Tensor,
|
| 631 |
+
return_dict: bool = True,
|
| 632 |
+
generator=None,
|
| 633 |
+
) -> Union[SchedulerOutput, Tuple]:
|
| 634 |
+
"""
|
| 635 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
|
| 636 |
+
the multistep UniPC.
|
| 637 |
+
|
| 638 |
+
Args:
|
| 639 |
+
model_output (`torch.Tensor`):
|
| 640 |
+
The direct output from learned diffusion model.
|
| 641 |
+
timestep (`int`):
|
| 642 |
+
The current discrete timestep in the diffusion chain.
|
| 643 |
+
sample (`torch.Tensor`):
|
| 644 |
+
A current instance of a sample created by the diffusion process.
|
| 645 |
+
return_dict (`bool`):
|
| 646 |
+
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
| 647 |
+
|
| 648 |
+
Returns:
|
| 649 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
| 650 |
+
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
| 651 |
+
tuple is returned where the first element is the sample tensor.
|
| 652 |
+
|
| 653 |
+
"""
|
| 654 |
+
if self.num_inference_steps is None:
|
| 655 |
+
raise ValueError(
|
| 656 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
if self.step_index is None:
|
| 660 |
+
self._init_step_index(timestep)
|
| 661 |
+
|
| 662 |
+
use_corrector = (
|
| 663 |
+
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None # pyright: ignore
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
model_output_convert = self.convert_model_output(model_output, sample=sample)
|
| 667 |
+
if use_corrector:
|
| 668 |
+
sample = self.multistep_uni_c_bh_update(
|
| 669 |
+
this_model_output=model_output_convert,
|
| 670 |
+
last_sample=self.last_sample,
|
| 671 |
+
this_sample=sample,
|
| 672 |
+
order=self.this_order,
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
for i in range(self.config.solver_order - 1):
|
| 676 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
| 677 |
+
self.timestep_list[i] = self.timestep_list[i + 1]
|
| 678 |
+
|
| 679 |
+
self.model_outputs[-1] = model_output_convert
|
| 680 |
+
self.timestep_list[-1] = timestep # pyright: ignore
|
| 681 |
+
|
| 682 |
+
if self.config.lower_order_final:
|
| 683 |
+
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore
|
| 684 |
+
else:
|
| 685 |
+
this_order = self.config.solver_order
|
| 686 |
+
|
| 687 |
+
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
|
| 688 |
+
assert self.this_order > 0
|
| 689 |
+
|
| 690 |
+
self.last_sample = sample
|
| 691 |
+
prev_sample = self.multistep_uni_p_bh_update(
|
| 692 |
+
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
| 693 |
+
sample=sample,
|
| 694 |
+
order=self.this_order,
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
if self.lower_order_nums < self.config.solver_order:
|
| 698 |
+
self.lower_order_nums += 1
|
| 699 |
+
|
| 700 |
+
# upon completion increase step index by one
|
| 701 |
+
self._step_index += 1 # pyright: ignore
|
| 702 |
+
|
| 703 |
+
if not return_dict:
|
| 704 |
+
return (prev_sample,)
|
| 705 |
+
|
| 706 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
| 707 |
+
|
| 708 |
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 709 |
+
"""
|
| 710 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
| 711 |
+
current timestep.
|
| 712 |
+
|
| 713 |
+
Args:
|
| 714 |
+
sample (`torch.Tensor`):
|
| 715 |
+
The input sample.
|
| 716 |
+
|
| 717 |
+
Returns:
|
| 718 |
+
`torch.Tensor`:
|
| 719 |
+
A scaled input sample.
|
| 720 |
+
"""
|
| 721 |
+
return sample
|
| 722 |
+
|
| 723 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
| 724 |
+
def add_noise(
|
| 725 |
+
self,
|
| 726 |
+
original_samples: torch.Tensor,
|
| 727 |
+
noise: torch.Tensor,
|
| 728 |
+
timesteps: torch.IntTensor,
|
| 729 |
+
) -> torch.Tensor:
|
| 730 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
| 731 |
+
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 732 |
+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
| 733 |
+
# mps does not support float64
|
| 734 |
+
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
| 735 |
+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
| 736 |
+
else:
|
| 737 |
+
schedule_timesteps = self.timesteps.to(original_samples.device)
|
| 738 |
+
timesteps = timesteps.to(original_samples.device)
|
| 739 |
+
|
| 740 |
+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
| 741 |
+
if self.begin_index is None:
|
| 742 |
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
| 743 |
+
elif self.step_index is not None:
|
| 744 |
+
# add_noise is called after first denoising step (for inpainting)
|
| 745 |
+
step_indices = [self.step_index] * timesteps.shape[0]
|
| 746 |
+
else:
|
| 747 |
+
# add noise is called before first denoising step to create initial latent(img2img)
|
| 748 |
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
| 749 |
+
|
| 750 |
+
sigma = sigmas[step_indices].flatten()
|
| 751 |
+
while len(sigma.shape) < len(original_samples.shape):
|
| 752 |
+
sigma = sigma.unsqueeze(-1)
|
| 753 |
+
|
| 754 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 755 |
+
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
| 756 |
+
return noisy_samples
|
| 757 |
+
|
| 758 |
+
def __len__(self):
|
| 759 |
+
return self.config.num_train_timesteps
|
assets/architecture.png
ADDED
|
Git LFS Details
|
assets/banner.png
ADDED
|
Git LFS Details
|
assets/showcase_i2v.png
ADDED
|
Git LFS Details
|
assets/showcase_t2v.png
ADDED
|
Git LFS Details
|
feature_extractor/preprocessor_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_convert_rgb": true,
|
| 3 |
+
"do_normalize": true,
|
| 4 |
+
"do_rescale": false,
|
| 5 |
+
"do_resize": true,
|
| 6 |
+
"image_mean": [
|
| 7 |
+
0.5,
|
| 8 |
+
0.5,
|
| 9 |
+
0.5
|
| 10 |
+
],
|
| 11 |
+
"image_processor_type": "SiglipImageProcessor",
|
| 12 |
+
"image_std": [
|
| 13 |
+
0.5,
|
| 14 |
+
0.5,
|
| 15 |
+
0.5
|
| 16 |
+
],
|
| 17 |
+
"resample": 3,
|
| 18 |
+
"rescale_factor": 0.00392156862745098,
|
| 19 |
+
"size": {
|
| 20 |
+
"height": 896,
|
| 21 |
+
"width": 896
|
| 22 |
+
}
|
| 23 |
+
}
|
inference.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Motif-Video 2B — Text-to-Video & Image-to-Video inference.
|
| 3 |
+
|
| 4 |
+
GPU requirements: ~24GB VRAM for 720p (1280x736, 121 frames).
|
| 5 |
+
Tested with: torch>=2.0, diffusers>=0.35.2, transformers>=5.0.0
|
| 6 |
+
|
| 7 |
+
Uses Adaptive Projected Guidance (APG) by default for best quality.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
|
| 14 |
+
from diffusers.utils import export_to_video
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def parse_args():
|
| 18 |
+
parser = argparse.ArgumentParser(description="Motif-Video 2B Inference (T2V / I2V)")
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--model-path",
|
| 21 |
+
type=str,
|
| 22 |
+
default="Motif-Technologies/Motif-Video-2B",
|
| 23 |
+
help="HuggingFace model ID or local checkpoint path (uses trust_remote_code=True)",
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--prompt",
|
| 27 |
+
type=str,
|
| 28 |
+
default="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
|
| 29 |
+
help="Text prompt for video generation",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--image",
|
| 33 |
+
type=str,
|
| 34 |
+
default=None,
|
| 35 |
+
help="Path to input image for I2V mode (omit for T2V)",
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--negative-prompt",
|
| 39 |
+
type=str,
|
| 40 |
+
default=None,
|
| 41 |
+
help="Negative prompt (default: built-in pipeline default)",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument("--output", type=str, default="output.mp4", help="Output video file path")
|
| 44 |
+
parser.add_argument("--num-frames", type=int, default=121, help="Number of frames to generate (121 = ~5s at 24fps)")
|
| 45 |
+
parser.add_argument("--height", type=int, default=736, help="Video height in pixels")
|
| 46 |
+
parser.add_argument("--width", type=int, default=1280, help="Video width in pixels")
|
| 47 |
+
parser.add_argument("--guidance-scale", type=float, default=8.0, help="Classifier-free guidance scale")
|
| 48 |
+
parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of denoising steps")
|
| 49 |
+
parser.add_argument("--fps", type=int, default=24, help="Output video frame rate")
|
| 50 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--dtype",
|
| 53 |
+
type=str,
|
| 54 |
+
default="bfloat16",
|
| 55 |
+
choices=["float16", "bfloat16", "float32"],
|
| 56 |
+
help="Model dtype",
|
| 57 |
+
)
|
| 58 |
+
return parser.parse_args()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main():
|
| 62 |
+
args = parse_args()
|
| 63 |
+
|
| 64 |
+
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
|
| 65 |
+
torch_dtype = dtype_map[args.dtype]
|
| 66 |
+
|
| 67 |
+
mode = "I2V" if args.image else "T2V"
|
| 68 |
+
print(f"[{mode}] Loading model from: {args.model_path}")
|
| 69 |
+
|
| 70 |
+
guider = AdaptiveProjectedGuidance(
|
| 71 |
+
guidance_scale=args.guidance_scale,
|
| 72 |
+
adaptive_projected_guidance_rescale=12.0,
|
| 73 |
+
adaptive_projected_guidance_momentum=0.1,
|
| 74 |
+
eta=0.0,
|
| 75 |
+
use_original_formulation=True,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 79 |
+
args.model_path,
|
| 80 |
+
custom_pipeline="pipeline_motif_video",
|
| 81 |
+
trust_remote_code=True,
|
| 82 |
+
torch_dtype=torch_dtype,
|
| 83 |
+
guider=guider,
|
| 84 |
+
)
|
| 85 |
+
pipe = pipe.to("cuda")
|
| 86 |
+
|
| 87 |
+
generator = torch.Generator(device="cuda").manual_seed(args.seed)
|
| 88 |
+
|
| 89 |
+
# Load image for I2V mode
|
| 90 |
+
image = None
|
| 91 |
+
if args.image:
|
| 92 |
+
from PIL import Image
|
| 93 |
+
|
| 94 |
+
image = Image.open(args.image).convert("RGB")
|
| 95 |
+
print(f"[I2V] Input image: {args.image} ({image.size[0]}x{image.size[1]})")
|
| 96 |
+
|
| 97 |
+
print(f"Generating video: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps")
|
| 98 |
+
pipe_kwargs = dict(
|
| 99 |
+
prompt=args.prompt,
|
| 100 |
+
image=image,
|
| 101 |
+
height=args.height,
|
| 102 |
+
width=args.width,
|
| 103 |
+
num_frames=args.num_frames,
|
| 104 |
+
num_inference_steps=args.num_inference_steps,
|
| 105 |
+
generator=generator,
|
| 106 |
+
frame_rate=args.fps,
|
| 107 |
+
)
|
| 108 |
+
if args.negative_prompt is not None:
|
| 109 |
+
pipe_kwargs["negative_prompt"] = args.negative_prompt
|
| 110 |
+
|
| 111 |
+
output = pipe(**pipe_kwargs)
|
| 112 |
+
|
| 113 |
+
video_frames = output.frames[0]
|
| 114 |
+
export_to_video(video_frames, args.output, fps=args.fps)
|
| 115 |
+
print(f"Video saved to: {args.output}")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
|
| 119 |
+
main()
|
model_index.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "MotifVideoPipeline",
|
| 3 |
+
"_diffusers_version": "0.35.2",
|
| 4 |
+
"scheduler": [
|
| 5 |
+
"diffusers",
|
| 6 |
+
"FlowMatchEulerDiscreteScheduler"
|
| 7 |
+
],
|
| 8 |
+
"text_encoder": [
|
| 9 |
+
"transformers",
|
| 10 |
+
"T5Gemma2Model"
|
| 11 |
+
],
|
| 12 |
+
"tokenizer": [
|
| 13 |
+
"transformers",
|
| 14 |
+
"GemmaTokenizer"
|
| 15 |
+
],
|
| 16 |
+
"transformer": [
|
| 17 |
+
"transformer_motif_video",
|
| 18 |
+
"MotifVideoTransformer3DModel"
|
| 19 |
+
],
|
| 20 |
+
"vae": [
|
| 21 |
+
"diffusers",
|
| 22 |
+
"AutoencoderKLWan"
|
| 23 |
+
],
|
| 24 |
+
"feature_extractor": [
|
| 25 |
+
"transformers",
|
| 26 |
+
"SiglipImageProcessor"
|
| 27 |
+
]
|
| 28 |
+
}
|
motif-video-technical-report.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f931b222d303bee5b62053b9734a021bb9310388cf575fbab83ffdcceca378ab
|
| 3 |
+
size 17260596
|
pipeline_motif_video.py
ADDED
|
@@ -0,0 +1,1321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Motif Technologies, Inc. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import html
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import ftfy
|
| 21 |
+
import numpy as np
|
| 22 |
+
import regex as re
|
| 23 |
+
import torch
|
| 24 |
+
from diffusers import (
|
| 25 |
+
AdaptiveProjectedGuidance,
|
| 26 |
+
AutoencoderKLWan,
|
| 27 |
+
ClassifierFreeGuidance,
|
| 28 |
+
DiffusionPipeline,
|
| 29 |
+
DPMSolverMultistepScheduler,
|
| 30 |
+
FlowMatchEulerDiscreteScheduler,
|
| 31 |
+
SkipLayerGuidance,
|
| 32 |
+
UniPCMultistepScheduler,
|
| 33 |
+
)
|
| 34 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 35 |
+
from diffusers.utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
|
| 36 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 37 |
+
from diffusers.video_processor import VideoProcessor
|
| 38 |
+
from einops import rearrange
|
| 39 |
+
from PIL import Image
|
| 40 |
+
from torch import Tensor
|
| 41 |
+
|
| 42 |
+
from diffusers.guiders.adaptive_projected_guidance import MomentumBuffer
|
| 43 |
+
from diffusers.guiders.guider_utils import GuiderOutput
|
| 44 |
+
from ._fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 45 |
+
from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Model
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if is_torch_xla_available():
|
| 49 |
+
import torch_xla.core.xla_model as xm
|
| 50 |
+
|
| 51 |
+
XLA_AVAILABLE = True
|
| 52 |
+
else:
|
| 53 |
+
XLA_AVAILABLE = False
|
| 54 |
+
|
| 55 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 56 |
+
|
| 57 |
+
EXAMPLE_DOC_STRING = """
|
| 58 |
+
Examples:
|
| 59 |
+
```py
|
| 60 |
+
>>> import torch
|
| 61 |
+
>>> from diffusers import MotifVideoPipeline
|
| 62 |
+
>>> from diffusers.utils import export_to_video
|
| 63 |
+
|
| 64 |
+
>>> # Load the Motif Video pipeline
|
| 65 |
+
>>> motif_video_model_id = "MotifTechnologies/Motif-Video"
|
| 66 |
+
>>> pipe = MotifVideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16)
|
| 67 |
+
>>> pipe.to("cuda")
|
| 68 |
+
|
| 69 |
+
>>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
| 70 |
+
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
| 71 |
+
|
| 72 |
+
>>> video = pipe(
|
| 73 |
+
... prompt=prompt,
|
| 74 |
+
... negative_prompt=negative_prompt,
|
| 75 |
+
... width=640,
|
| 76 |
+
... height=352,
|
| 77 |
+
... num_frames=65,
|
| 78 |
+
... num_inference_steps=50,
|
| 79 |
+
... ).frames[0]
|
| 80 |
+
>>> export_to_video(video, "output.mp4", fps=16)
|
| 81 |
+
```
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclass
|
| 86 |
+
class MotifVideoPipelineOutput(BaseOutput):
|
| 87 |
+
r"""
|
| 88 |
+
Output class for Motif Video pipelines.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 92 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 93 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 94 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
frames: torch.Tensor
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
"""Video-aware Adaptive Projected Guidance (APG).
|
| 101 |
+
|
| 102 |
+
Standard APG normalizes over all spatial dimensions [C, T, H, W], which collapses
|
| 103 |
+
temporal variation. This module normalizes over [C, H, W] only, preserving
|
| 104 |
+
per-frame independence.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def video_normalized_guidance(
|
| 109 |
+
pred_cond: torch.Tensor,
|
| 110 |
+
pred_uncond: torch.Tensor,
|
| 111 |
+
guidance_scale: float,
|
| 112 |
+
momentum_buffer: MomentumBuffer | None = None,
|
| 113 |
+
eta: float = 1.0,
|
| 114 |
+
norm_threshold: float = 0.0,
|
| 115 |
+
use_original_formulation: bool = False,
|
| 116 |
+
) -> torch.Tensor:
|
| 117 |
+
"""APG with video-aware normalization: normalize over [C, H, W], exclude T.
|
| 118 |
+
|
| 119 |
+
For 5D input [B, C, T, H, W], dim=[-1, -2, -4] normalizes per-frame (W, H, C),
|
| 120 |
+
keeping the T dimension independent. For 4D input [B, C, H, W], falls back to
|
| 121 |
+
standard [-1, -2, -3] behavior.
|
| 122 |
+
"""
|
| 123 |
+
diff = pred_cond - pred_uncond
|
| 124 |
+
|
| 125 |
+
if len(diff.shape) == 5:
|
| 126 |
+
# [B, C, T, H, W] → normalize over W(-1), H(-2), C(-4), skip T(-3)
|
| 127 |
+
dim = [-1, -2, -4]
|
| 128 |
+
else:
|
| 129 |
+
# [B, C, H, W] → standard behavior
|
| 130 |
+
dim = [-i for i in range(1, len(diff.shape))]
|
| 131 |
+
|
| 132 |
+
if momentum_buffer is not None:
|
| 133 |
+
momentum_buffer.update(diff)
|
| 134 |
+
diff = momentum_buffer.running_average
|
| 135 |
+
|
| 136 |
+
if norm_threshold > 0:
|
| 137 |
+
ones = torch.ones_like(diff)
|
| 138 |
+
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
| 139 |
+
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
| 140 |
+
diff = diff * scale_factor
|
| 141 |
+
|
| 142 |
+
v0, v1 = diff.double(), pred_cond.double()
|
| 143 |
+
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
| 144 |
+
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
| 145 |
+
v0_orthogonal = v0 - v0_parallel
|
| 146 |
+
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
| 147 |
+
normalized_update = diff_orthogonal + eta * diff_parallel
|
| 148 |
+
|
| 149 |
+
pred = pred_cond if use_original_formulation else pred_uncond
|
| 150 |
+
pred = pred + guidance_scale * normalized_update
|
| 151 |
+
|
| 152 |
+
return pred
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class VideoAdaptiveProjectedGuidance(AdaptiveProjectedGuidance):
|
| 156 |
+
"""APG variant that normalizes over [C, H, W] per frame, excluding the T dimension."""
|
| 157 |
+
|
| 158 |
+
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 159 |
+
pred = None
|
| 160 |
+
|
| 161 |
+
if not self._is_apg_enabled():
|
| 162 |
+
pred = pred_cond
|
| 163 |
+
else:
|
| 164 |
+
pred = video_normalized_guidance(
|
| 165 |
+
pred_cond,
|
| 166 |
+
pred_uncond,
|
| 167 |
+
self.guidance_scale,
|
| 168 |
+
self.momentum_buffer,
|
| 169 |
+
self.eta,
|
| 170 |
+
self.adaptive_projected_guidance_rescale,
|
| 171 |
+
self.use_original_formulation,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if self.guidance_rescale > 0.0:
|
| 175 |
+
from diffusers.guiders.classifier_free_guidance import rescale_noise_cfg
|
| 176 |
+
|
| 177 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 178 |
+
|
| 179 |
+
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
| 183 |
+
def calculate_shift(
|
| 184 |
+
image_seq_len,
|
| 185 |
+
base_seq_len: int = 256,
|
| 186 |
+
max_seq_len: int = 4096,
|
| 187 |
+
base_shift: float = 0.5,
|
| 188 |
+
max_shift: float = 1.15,
|
| 189 |
+
):
|
| 190 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 191 |
+
b = base_shift - m * base_seq_len
|
| 192 |
+
mu = image_seq_len * m + b
|
| 193 |
+
return mu
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_linear_quadratic_sigmas(
|
| 197 |
+
num_inference_steps: int,
|
| 198 |
+
linear_quadratic_emulating_steps: int = 250,
|
| 199 |
+
) -> np.ndarray:
|
| 200 |
+
"""
|
| 201 |
+
Compute a linear-quadratic sigma schedule for flow matching.
|
| 202 |
+
|
| 203 |
+
This schedule combines:
|
| 204 |
+
- First half: Linear interpolation from high noise to medium noise (slow denoising)
|
| 205 |
+
- Second half: Quadratic interpolation from medium noise to clean (faster denoising)
|
| 206 |
+
|
| 207 |
+
Convention:
|
| 208 |
+
- sigma=1.0 represents pure noise
|
| 209 |
+
- sigma=0.0 represents clean image
|
| 210 |
+
- Output sigmas are in descending order (1.0 → ~0)
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
num_inference_steps: Total number of denoising steps (must be even).
|
| 214 |
+
linear_quadratic_emulating_steps: Controls the slope of linear interpolation.
|
| 215 |
+
Higher values result in gentler slope in the first half.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
np.ndarray: Array of sigma values with shape (num_inference_steps,).
|
| 219 |
+
The scheduler will append a terminal 0.
|
| 220 |
+
|
| 221 |
+
Raises:
|
| 222 |
+
ValueError: If num_inference_steps is not even.
|
| 223 |
+
|
| 224 |
+
Reference:
|
| 225 |
+
Linear-quadratic timestep schedule for improved flow matching inference.
|
| 226 |
+
"""
|
| 227 |
+
if num_inference_steps % 2 != 0:
|
| 228 |
+
raise ValueError(
|
| 229 |
+
f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}"
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
steps = num_inference_steps
|
| 233 |
+
N = linear_quadratic_emulating_steps
|
| 234 |
+
half_steps = steps // 2
|
| 235 |
+
|
| 236 |
+
# First half: linear interpolation from 1 toward 0
|
| 237 |
+
# Takes first half_steps values from linspace(1, 0, N+1)
|
| 238 |
+
linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps]
|
| 239 |
+
|
| 240 |
+
# Second half: quadratic interpolation
|
| 241 |
+
# Formula: x^2 * (half_steps/N - 1) - (half_steps/N - 1)
|
| 242 |
+
# = (half_steps/N - 1) * (x^2 - 1)
|
| 243 |
+
# This maps x=0 to (half_steps/N - 1) * (-1) = 1 - half_steps/N
|
| 244 |
+
# and maps x=1 to 0
|
| 245 |
+
x = np.linspace(0.0, 1.0, half_steps + 1)
|
| 246 |
+
scale_factor = half_steps / N - 1 # negative value
|
| 247 |
+
quadratic_part = x**2 * scale_factor - scale_factor
|
| 248 |
+
|
| 249 |
+
# Concatenate and exclude the last 0 (scheduler appends terminal 0)
|
| 250 |
+
sigmas = np.concatenate([linear_part, quadratic_part])
|
| 251 |
+
sigmas = sigmas[:-1] # Remove trailing 0, scheduler will append it
|
| 252 |
+
|
| 253 |
+
return sigmas.astype(np.float32)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 257 |
+
def retrieve_timesteps(
|
| 258 |
+
scheduler,
|
| 259 |
+
num_inference_steps: Optional[int] = None,
|
| 260 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 261 |
+
timesteps: Optional[List[int]] = None,
|
| 262 |
+
sigmas: Optional[List[float]] = None,
|
| 263 |
+
use_linear_quadratic_schedule: bool = False,
|
| 264 |
+
linear_quadratic_emulating_steps: int = 250,
|
| 265 |
+
**kwargs,
|
| 266 |
+
):
|
| 267 |
+
"""
|
| 268 |
+
Retrieve timesteps from the scheduler.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
scheduler: The noise scheduler to use.
|
| 272 |
+
num_inference_steps: Number of denoising steps.
|
| 273 |
+
device: Device to place timesteps on.
|
| 274 |
+
timesteps: Custom timestep values (mutually exclusive with sigmas).
|
| 275 |
+
sigmas: Custom sigma values (mutually exclusive with timesteps).
|
| 276 |
+
use_linear_quadratic_schedule: If True, use linear-quadratic sigma schedule.
|
| 277 |
+
This overrides the default linear schedule. Requires num_inference_steps
|
| 278 |
+
to be even.
|
| 279 |
+
linear_quadratic_emulating_steps: Controls the linear portion slope.
|
| 280 |
+
Higher values result in gentler slope in the first half. Default: 250.
|
| 281 |
+
**kwargs: Additional arguments passed to scheduler.set_timesteps().
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
Tuple of (timesteps, num_inference_steps).
|
| 285 |
+
|
| 286 |
+
Raises:
|
| 287 |
+
ValueError: If both timesteps and sigmas are provided, or if
|
| 288 |
+
use_linear_quadratic_schedule is True but num_inference_steps is odd.
|
| 289 |
+
"""
|
| 290 |
+
if timesteps is not None and sigmas is not None:
|
| 291 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 292 |
+
|
| 293 |
+
# Handle linear-quadratic schedule: compute sigmas if flag is set
|
| 294 |
+
if use_linear_quadratic_schedule:
|
| 295 |
+
if sigmas is not None:
|
| 296 |
+
raise ValueError(
|
| 297 |
+
"Cannot use both `sigmas` and `use_linear_quadratic_schedule`. "
|
| 298 |
+
"The linear-quadratic schedule computes sigmas automatically."
|
| 299 |
+
)
|
| 300 |
+
if num_inference_steps is None:
|
| 301 |
+
raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.")
|
| 302 |
+
sigmas = get_linear_quadratic_sigmas(
|
| 303 |
+
num_inference_steps=num_inference_steps,
|
| 304 |
+
linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
if timesteps is not None:
|
| 308 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 309 |
+
if not accepts_timesteps:
|
| 310 |
+
raise ValueError(
|
| 311 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 312 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 313 |
+
)
|
| 314 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 315 |
+
timesteps = scheduler.timesteps
|
| 316 |
+
num_inference_steps = len(timesteps)
|
| 317 |
+
elif sigmas is not None:
|
| 318 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 319 |
+
if not accept_sigmas:
|
| 320 |
+
raise ValueError(
|
| 321 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 322 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 323 |
+
)
|
| 324 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 325 |
+
timesteps = scheduler.timesteps
|
| 326 |
+
num_inference_steps = len(timesteps)
|
| 327 |
+
else:
|
| 328 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 329 |
+
timesteps = scheduler.timesteps
|
| 330 |
+
return timesteps, num_inference_steps
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def basic_clean(text):
|
| 334 |
+
text = ftfy.fix_text(text)
|
| 335 |
+
text = html.unescape(html.unescape(text))
|
| 336 |
+
return text.strip()
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def whitespace_clean(text):
|
| 340 |
+
text = re.sub(r"\s+", " ", text)
|
| 341 |
+
text = text.strip()
|
| 342 |
+
return text
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def prompt_clean(text):
|
| 346 |
+
text = whitespace_clean(basic_clean(text))
|
| 347 |
+
return text
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
class MotifVideoPipeline(DiffusionPipeline):
|
| 351 |
+
r"""
|
| 352 |
+
Pipeline for text-to-video generation using MotifVideoTransformer.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
transformer ([`MotifVideoTransformer3DModel`]):
|
| 356 |
+
Conditional Transformer architecture to denoise the encoded video latents.
|
| 357 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 358 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 359 |
+
vae ([`AutoencoderKLWan`]):
|
| 360 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 361 |
+
text_encoder ([`T5Gemma2Model`]):
|
| 362 |
+
Primary text encoder for encoding text prompts into embeddings.
|
| 363 |
+
tokenizer ([`PreTrainedTokenizerBase`]):
|
| 364 |
+
Tokenizer corresponding to the primary text encoder.
|
| 365 |
+
guider ([`ClassifierFreeGuidance`] or [`SkipLayerGuidance`] or [`AdaptiveProjectedGuidance`] or [`VideoAdaptiveProjectedGuidance`], *optional*):
|
| 366 |
+
The guidance method to use. If `None`, it defaults to `ClassifierFreeGuidance()`.
|
| 367 |
+
"""
|
| 368 |
+
|
| 369 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 370 |
+
_optional_components = ["feature_extractor"]
|
| 371 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 372 |
+
|
| 373 |
+
def __init__(
|
| 374 |
+
self,
|
| 375 |
+
scheduler: Union[
|
| 376 |
+
FlowMatchEulerDiscreteScheduler,
|
| 377 |
+
DPMSolverMultistepScheduler,
|
| 378 |
+
UniPCMultistepScheduler,
|
| 379 |
+
FlowUniPCMultistepScheduler,
|
| 380 |
+
],
|
| 381 |
+
vae: AutoencoderKLWan,
|
| 382 |
+
text_encoder: T5Gemma2Model,
|
| 383 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 384 |
+
transformer,
|
| 385 |
+
guider: Optional[
|
| 386 |
+
Union[ClassifierFreeGuidance, SkipLayerGuidance, AdaptiveProjectedGuidance, VideoAdaptiveProjectedGuidance]
|
| 387 |
+
] = None,
|
| 388 |
+
feature_extractor: Optional[SiglipImageProcessor] = None,
|
| 389 |
+
):
|
| 390 |
+
super().__init__()
|
| 391 |
+
|
| 392 |
+
self.guider = ClassifierFreeGuidance() if guider is None else guider
|
| 393 |
+
|
| 394 |
+
self.register_modules(
|
| 395 |
+
vae=vae,
|
| 396 |
+
text_encoder=text_encoder,
|
| 397 |
+
tokenizer=tokenizer,
|
| 398 |
+
transformer=transformer,
|
| 399 |
+
scheduler=scheduler,
|
| 400 |
+
feature_extractor=feature_extractor,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
| 404 |
+
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
| 405 |
+
|
| 406 |
+
self.transformer_spatial_patch_size = (
|
| 407 |
+
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2
|
| 408 |
+
)
|
| 409 |
+
self.transformer_temporal_patch_size = (
|
| 410 |
+
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 414 |
+
self.tokenizer_max_length = (
|
| 415 |
+
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
def _get_default_embeds(
|
| 419 |
+
self,
|
| 420 |
+
text_encoder,
|
| 421 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 422 |
+
prompt: Union[str, List[str]],
|
| 423 |
+
max_sequence_length: int = 512,
|
| 424 |
+
device: Optional[torch.device] = None,
|
| 425 |
+
dtype: Optional[torch.dtype] = None,
|
| 426 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 427 |
+
dtype = dtype or text_encoder.dtype
|
| 428 |
+
|
| 429 |
+
text_inputs = tokenizer(
|
| 430 |
+
prompt,
|
| 431 |
+
padding="max_length",
|
| 432 |
+
max_length=max_sequence_length,
|
| 433 |
+
truncation=True,
|
| 434 |
+
add_special_tokens=True,
|
| 435 |
+
return_attention_mask=True,
|
| 436 |
+
return_tensors="pt",
|
| 437 |
+
)
|
| 438 |
+
text_inputs = BatchEncoding(
|
| 439 |
+
{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()}
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
prompt_embeds = text_encoder(**text_inputs)[0]
|
| 443 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 444 |
+
|
| 445 |
+
return prompt_embeds, text_inputs.attention_mask
|
| 446 |
+
|
| 447 |
+
def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
| 448 |
+
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 449 |
+
denom = attention_mask.sum(dim=1, keepdim=True).clamp(min=1) # avoid div by zero
|
| 450 |
+
return last_hidden.sum(dim=1) / denom
|
| 451 |
+
|
| 452 |
+
def _get_prompt_embeds(
|
| 453 |
+
self,
|
| 454 |
+
text_encoder: T5Gemma2Model,
|
| 455 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 456 |
+
prompt: Union[str, List[str]] | None = None,
|
| 457 |
+
num_videos_per_prompt: int = 1,
|
| 458 |
+
max_sequence_length: int = 512,
|
| 459 |
+
device: Optional[torch.device] = None,
|
| 460 |
+
dtype: Optional[torch.dtype] = None,
|
| 461 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 462 |
+
device = device or self._execution_device
|
| 463 |
+
|
| 464 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 465 |
+
|
| 466 |
+
prompt_embeds_kwargs = {
|
| 467 |
+
"text_encoder": text_encoder,
|
| 468 |
+
"tokenizer": tokenizer,
|
| 469 |
+
"prompt": prompt,
|
| 470 |
+
"max_sequence_length": max_sequence_length,
|
| 471 |
+
"device": device,
|
| 472 |
+
"dtype": dtype,
|
| 473 |
+
}
|
| 474 |
+
# T5Gemma2Model bundles encoder and decoder/LM head, while _get_default_embeds expects an encoder-only model
|
| 475 |
+
# (similar to T5EncoderModel/T5GemmaEncoderModel), so we pass the encoder submodule explicitly here.
|
| 476 |
+
if isinstance(text_encoder, T5Gemma2Model):
|
| 477 |
+
prompt_embeds_kwargs["text_encoder"] = text_encoder.encoder
|
| 478 |
+
prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs)
|
| 479 |
+
|
| 480 |
+
pooled_prompt_embeds = self._average_pool(prompt_embeds, prompt_attention_mask)
|
| 481 |
+
|
| 482 |
+
return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
|
| 483 |
+
|
| 484 |
+
# Keep encode_prompt structure, uses _get_prompt_embeds internally
|
| 485 |
+
def encode_prompt(
|
| 486 |
+
self,
|
| 487 |
+
prompt: Union[str, List[str]],
|
| 488 |
+
num_videos_per_prompt: int = 1,
|
| 489 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 490 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 491 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 492 |
+
max_sequence_length: int = 512,
|
| 493 |
+
device: Optional[torch.device] = None,
|
| 494 |
+
dtype: Optional[torch.dtype] = None,
|
| 495 |
+
) -> Tuple[
|
| 496 |
+
torch.Tensor,
|
| 497 |
+
torch.Tensor,
|
| 498 |
+
torch.Tensor,
|
| 499 |
+
]:
|
| 500 |
+
device = device or self._execution_device
|
| 501 |
+
|
| 502 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 503 |
+
if prompt is not None:
|
| 504 |
+
batch_size = len(prompt)
|
| 505 |
+
else:
|
| 506 |
+
batch_size = prompt_embeds.shape[0]
|
| 507 |
+
|
| 508 |
+
prompt_embeds_kwargs = {
|
| 509 |
+
"device": device,
|
| 510 |
+
"dtype": dtype,
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
if prompt_embeds is None:
|
| 514 |
+
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self._get_prompt_embeds(
|
| 515 |
+
text_encoder=self.text_encoder,
|
| 516 |
+
tokenizer=self.tokenizer,
|
| 517 |
+
prompt=prompt,
|
| 518 |
+
max_sequence_length=max_sequence_length,
|
| 519 |
+
**prompt_embeds_kwargs,
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 523 |
+
seq_len = prompt_embeds.shape[1]
|
| 524 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 525 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 526 |
+
|
| 527 |
+
if pooled_prompt_embeds is not None:
|
| 528 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
|
| 529 |
+
|
| 530 |
+
# Keep attention mask handling
|
| 531 |
+
prompt_attention_mask = prompt_attention_mask.bool()
|
| 532 |
+
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
| 533 |
+
prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
|
| 534 |
+
|
| 535 |
+
return (
|
| 536 |
+
prompt_embeds,
|
| 537 |
+
pooled_prompt_embeds,
|
| 538 |
+
prompt_attention_mask,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
@property
|
| 542 |
+
def vision_encoder(self):
|
| 543 |
+
"""Get the vision encoder from T5Gemma2.
|
| 544 |
+
|
| 545 |
+
T5Gemma2 has vision_tower.vision_model structure.
|
| 546 |
+
Will raise AttributeError if not available.
|
| 547 |
+
"""
|
| 548 |
+
return self.text_encoder.encoder.vision_tower.vision_model
|
| 549 |
+
|
| 550 |
+
def encode_image(
|
| 551 |
+
self,
|
| 552 |
+
image: Image.Image,
|
| 553 |
+
batch_size: int = 1,
|
| 554 |
+
device: Optional[torch.device] = None,
|
| 555 |
+
dtype: Optional[torch.dtype] = None,
|
| 556 |
+
) -> torch.Tensor:
|
| 557 |
+
"""Encode image to embeddings using SigLIP vision encoder."""
|
| 558 |
+
device = device or self._execution_device
|
| 559 |
+
dtype = dtype or self.transformer.dtype
|
| 560 |
+
|
| 561 |
+
image_embeds = self._get_image_embeds(
|
| 562 |
+
image_encoder=self.vision_encoder,
|
| 563 |
+
feature_extractor=self.feature_extractor,
|
| 564 |
+
image=image,
|
| 565 |
+
device=device,
|
| 566 |
+
)
|
| 567 |
+
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
| 568 |
+
return image_embeds.to(device=device, dtype=dtype)
|
| 569 |
+
|
| 570 |
+
@staticmethod
|
| 571 |
+
def _get_image_embeds(
|
| 572 |
+
image_encoder,
|
| 573 |
+
feature_extractor: SiglipImageProcessor,
|
| 574 |
+
image,
|
| 575 |
+
device: torch.device,
|
| 576 |
+
) -> torch.Tensor:
|
| 577 |
+
"""Helper to encode single image with SigLIP.
|
| 578 |
+
|
| 579 |
+
Args:
|
| 580 |
+
image_encoder: The SigLIP vision encoder model.
|
| 581 |
+
feature_extractor: SiglipImageProcessor for preprocessing.
|
| 582 |
+
image: Can be either:
|
| 583 |
+
- PIL.Image.Image: Will be preprocessed by feature_extractor
|
| 584 |
+
- torch.Tensor: Assumed to be in [0, 1] range, will be normalized and passed to encoder
|
| 585 |
+
device: Device to place tensors on.
|
| 586 |
+
|
| 587 |
+
Returns:
|
| 588 |
+
Image embeddings from the vision encoder.
|
| 589 |
+
"""
|
| 590 |
+
image_encoder_dtype = next(image_encoder.parameters()).dtype
|
| 591 |
+
|
| 592 |
+
if isinstance(image, torch.Tensor):
|
| 593 |
+
image = feature_extractor.preprocess(
|
| 594 |
+
images=image.float(),
|
| 595 |
+
do_resize=True,
|
| 596 |
+
do_rescale=False,
|
| 597 |
+
do_normalize=True,
|
| 598 |
+
do_convert_rgb=True,
|
| 599 |
+
return_tensors="pt",
|
| 600 |
+
)
|
| 601 |
+
else:
|
| 602 |
+
image = feature_extractor.preprocess(
|
| 603 |
+
images=image,
|
| 604 |
+
do_resize=True,
|
| 605 |
+
do_rescale=False,
|
| 606 |
+
do_normalize=True,
|
| 607 |
+
do_convert_rgb=True,
|
| 608 |
+
return_tensors="pt",
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
image = image.to(device, dtype=image_encoder_dtype)
|
| 612 |
+
return image_encoder(**image).last_hidden_state
|
| 613 |
+
|
| 614 |
+
@torch.compiler.disable
|
| 615 |
+
def _prepare_first_frame_conditioning(
|
| 616 |
+
self,
|
| 617 |
+
video: torch.Tensor,
|
| 618 |
+
latents: torch.Tensor,
|
| 619 |
+
use_conditioning: bool,
|
| 620 |
+
generator: Optional[torch.Generator] = None,
|
| 621 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
| 622 |
+
"""Prepare first frame conditioning tensors.
|
| 623 |
+
|
| 624 |
+
This method implements batch-level conditioning where entire
|
| 625 |
+
batches are either I2V (all samples conditioned) or T2V (no conditioning). This
|
| 626 |
+
prevents mode confusion within batches.
|
| 627 |
+
|
| 628 |
+
For I2V mode:
|
| 629 |
+
1. Extract and VAE-encode first frame from video
|
| 630 |
+
2. Create latent_condition by repeating first frame across time (frame 0 only)
|
| 631 |
+
3. Create latent_mask with 1.0 at frame 0
|
| 632 |
+
4. Get image_embeds from vision encoder
|
| 633 |
+
|
| 634 |
+
For T2V mode:
|
| 635 |
+
1. Pad with zeros for latent_condition and latent_mask
|
| 636 |
+
|
| 637 |
+
Args:
|
| 638 |
+
video: Input video tensor [batch_size, frames, channels, height, width] in [-1, 1]
|
| 639 |
+
latents: Latents [batch_size, lantent_channels, latent_num_frames, latent_height, latent_width]
|
| 640 |
+
use_conditioning: Whether to use first-frame conditioning (True for I2V, False for T2V)
|
| 641 |
+
generator: Optional random number generator for reproducibility
|
| 642 |
+
|
| 643 |
+
Returns:
|
| 644 |
+
Tuple of (latent_condition, latent_mask, image_embeds).
|
| 645 |
+
- latent_condition: [B, C, F, H, W] conditioning signal (zeros for T2V)
|
| 646 |
+
- latent_mask: [B, 1, F, H, W] binary mask (zeros for T2V)
|
| 647 |
+
- image_embeds: [B, N, D] image embeddings from vision encoder or None for T2V
|
| 648 |
+
"""
|
| 649 |
+
batch_size, lantent_channels, latent_num_frames, latent_height, latent_width = latents.shape
|
| 650 |
+
device = latents.device
|
| 651 |
+
dtype = latents.dtype
|
| 652 |
+
|
| 653 |
+
# Determine if we should use conditioning
|
| 654 |
+
use_conditioning = use_conditioning and (latent_num_frames > 1)
|
| 655 |
+
|
| 656 |
+
# Initialize conditioning tensors
|
| 657 |
+
latent_condition = torch.zeros(
|
| 658 |
+
batch_size, lantent_channels, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype
|
| 659 |
+
)
|
| 660 |
+
latent_mask = torch.zeros(
|
| 661 |
+
batch_size, 1, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype
|
| 662 |
+
)
|
| 663 |
+
image_embeds = None
|
| 664 |
+
|
| 665 |
+
if use_conditioning:
|
| 666 |
+
with torch.no_grad():
|
| 667 |
+
# Encode first frame for latent_condition
|
| 668 |
+
first_frame_latents = self.vae.encode(
|
| 669 |
+
rearrange(video[:, 0:1], "b f c h w -> b c f h w")
|
| 670 |
+
).latent_dist.sample(generator=generator)
|
| 671 |
+
first_frame_latents = self._normalize_latents(
|
| 672 |
+
latents=first_frame_latents,
|
| 673 |
+
latents_mean=self.vae.config.latents_mean,
|
| 674 |
+
latents_std=self.vae.config.latents_std,
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
# Create latent_condition by repeating first frame across time
|
| 678 |
+
latent_condition = first_frame_latents.repeat(1, 1, latent_num_frames, 1, 1)
|
| 679 |
+
latent_condition[:, :, 1:, :, :] = 0
|
| 680 |
+
|
| 681 |
+
# latent_mask: 1.0 at frame 0, 0.0 elsewhere
|
| 682 |
+
latent_mask[:, :, 0] = 1.0
|
| 683 |
+
|
| 684 |
+
# image_embeds from vision encoder
|
| 685 |
+
first_frame_vision = video[:, 0] # [B, C, H, W]
|
| 686 |
+
first_frame_vision = ((first_frame_vision + 1) / 2).clamp(0, 1)
|
| 687 |
+
|
| 688 |
+
with torch.no_grad():
|
| 689 |
+
image_embeds = self._get_image_embeds(
|
| 690 |
+
image_encoder=self.vision_encoder,
|
| 691 |
+
feature_extractor=self.feature_extractor,
|
| 692 |
+
image=first_frame_vision,
|
| 693 |
+
device=device,
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
return latent_condition, latent_mask, image_embeds
|
| 697 |
+
|
| 698 |
+
def check_inputs(
|
| 699 |
+
self,
|
| 700 |
+
prompt,
|
| 701 |
+
negative_prompt,
|
| 702 |
+
height,
|
| 703 |
+
width,
|
| 704 |
+
batch_size,
|
| 705 |
+
callback_on_step_end_tensor_inputs=None,
|
| 706 |
+
prompt_embeds=None,
|
| 707 |
+
negative_prompt_embeds=None,
|
| 708 |
+
prompt_attention_mask=None,
|
| 709 |
+
negative_prompt_attention_mask=None,
|
| 710 |
+
):
|
| 711 |
+
# Resolution must be divisible by VAE scale factor * transformer patch size
|
| 712 |
+
# (e.g. 8 * 2 = 16 for default config) to avoid latent/patch dimension mismatch.
|
| 713 |
+
spatial_divisor = self.vae_scale_factor_spatial * self.transformer_spatial_patch_size
|
| 714 |
+
if height % spatial_divisor != 0 or width % spatial_divisor != 0:
|
| 715 |
+
raise ValueError(
|
| 716 |
+
f"`height` and `width` have to be divisible by {spatial_divisor} "
|
| 717 |
+
f"(vae_scale={self.vae_scale_factor_spatial} * patch_size={self.transformer_spatial_patch_size}) "
|
| 718 |
+
f"but are {height} and {width}."
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 722 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 723 |
+
):
|
| 724 |
+
raise ValueError(
|
| 725 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 726 |
+
)
|
| 727 |
+
|
| 728 |
+
if prompt is not None and prompt_embeds is not None:
|
| 729 |
+
raise ValueError(
|
| 730 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 731 |
+
" only forward one of the two."
|
| 732 |
+
)
|
| 733 |
+
elif prompt is None and prompt_embeds is None:
|
| 734 |
+
raise ValueError(
|
| 735 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 736 |
+
)
|
| 737 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 738 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 739 |
+
|
| 740 |
+
# Validate negative_prompt: must be None, str, or list with matching batch_size
|
| 741 |
+
if negative_prompt is not None:
|
| 742 |
+
if not isinstance(negative_prompt, (str, list)):
|
| 743 |
+
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
| 744 |
+
if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size:
|
| 745 |
+
raise ValueError(
|
| 746 |
+
f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})."
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
| 750 |
+
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
| 751 |
+
|
| 752 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
| 753 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
| 754 |
+
|
| 755 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 756 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 757 |
+
raise ValueError(
|
| 758 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 759 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 760 |
+
f" {negative_prompt_embeds.shape}."
|
| 761 |
+
)
|
| 762 |
+
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
| 763 |
+
raise ValueError(
|
| 764 |
+
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
| 765 |
+
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
| 766 |
+
f" {negative_prompt_attention_mask.shape}."
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
def _prepare_negative_prompt(
|
| 770 |
+
self,
|
| 771 |
+
negative_prompt: Optional[Union[str, List[str]]],
|
| 772 |
+
batch_size: int,
|
| 773 |
+
) -> List[str]:
|
| 774 |
+
"""
|
| 775 |
+
Prepare negative_prompt to match batch_size.
|
| 776 |
+
|
| 777 |
+
Args:
|
| 778 |
+
negative_prompt: None, a single string, or a list of strings matching batch_size.
|
| 779 |
+
batch_size: The number of prompts in the batch.
|
| 780 |
+
|
| 781 |
+
Returns:
|
| 782 |
+
A list of negative prompts with length equal to batch_size.
|
| 783 |
+
"""
|
| 784 |
+
if negative_prompt is None:
|
| 785 |
+
return [""] * batch_size
|
| 786 |
+
if isinstance(negative_prompt, str):
|
| 787 |
+
return [negative_prompt] * batch_size
|
| 788 |
+
return negative_prompt
|
| 789 |
+
|
| 790 |
+
@staticmethod
|
| 791 |
+
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
| 792 |
+
batch_size, num_channels, num_frames, height, width = latents.shape
|
| 793 |
+
post_patch_num_frames = num_frames // patch_size_t
|
| 794 |
+
post_patch_height = height // patch_size
|
| 795 |
+
post_patch_width = width // patch_size
|
| 796 |
+
latents = latents.reshape(
|
| 797 |
+
batch_size,
|
| 798 |
+
-1,
|
| 799 |
+
post_patch_num_frames,
|
| 800 |
+
patch_size_t,
|
| 801 |
+
post_patch_height,
|
| 802 |
+
patch_size,
|
| 803 |
+
post_patch_width,
|
| 804 |
+
patch_size,
|
| 805 |
+
)
|
| 806 |
+
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
| 807 |
+
return latents
|
| 808 |
+
|
| 809 |
+
@staticmethod
|
| 810 |
+
def _unpack_latents(
|
| 811 |
+
latents: torch.Tensor,
|
| 812 |
+
num_frames: int,
|
| 813 |
+
height: int,
|
| 814 |
+
width: int,
|
| 815 |
+
patch_size: int = 1,
|
| 816 |
+
patch_size_t: int = 1,
|
| 817 |
+
) -> torch.Tensor:
|
| 818 |
+
batch_size = latents.size(0)
|
| 819 |
+
latents = latents.reshape(
|
| 820 |
+
batch_size,
|
| 821 |
+
num_frames,
|
| 822 |
+
height,
|
| 823 |
+
width,
|
| 824 |
+
-1,
|
| 825 |
+
patch_size_t,
|
| 826 |
+
patch_size,
|
| 827 |
+
patch_size,
|
| 828 |
+
)
|
| 829 |
+
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
| 830 |
+
return latents
|
| 831 |
+
|
| 832 |
+
@staticmethod
|
| 833 |
+
def _normalize_latents(
|
| 834 |
+
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
|
| 835 |
+
) -> torch.Tensor:
|
| 836 |
+
# Normalize latents across the channel dimension [B, C, F, H, W]
|
| 837 |
+
latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
| 838 |
+
latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
| 839 |
+
latents = (latents - latents_mean) / latents_std
|
| 840 |
+
return latents
|
| 841 |
+
|
| 842 |
+
@staticmethod
|
| 843 |
+
def _denormalize_latents(
|
| 844 |
+
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
|
| 845 |
+
) -> torch.Tensor:
|
| 846 |
+
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
| 847 |
+
latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
| 848 |
+
latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
| 849 |
+
latents = latents * latents_std + latents_mean
|
| 850 |
+
return latents
|
| 851 |
+
|
| 852 |
+
def prepare_latents(
|
| 853 |
+
self,
|
| 854 |
+
batch_size: int = 1,
|
| 855 |
+
num_channels_latents: int = 16,
|
| 856 |
+
height: int = 352,
|
| 857 |
+
width: int = 640,
|
| 858 |
+
num_frames: int = 65,
|
| 859 |
+
dtype: Optional[torch.dtype] = None,
|
| 860 |
+
device: Optional[torch.device] = None,
|
| 861 |
+
generator: Optional[torch.Generator] = None,
|
| 862 |
+
latents: Optional[torch.Tensor] = None,
|
| 863 |
+
) -> torch.Tensor:
|
| 864 |
+
if latents is not None:
|
| 865 |
+
return latents.to(device=device, dtype=dtype)
|
| 866 |
+
|
| 867 |
+
shape = (
|
| 868 |
+
batch_size,
|
| 869 |
+
num_channels_latents,
|
| 870 |
+
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
| 871 |
+
height // self.vae_scale_factor_spatial,
|
| 872 |
+
width // self.vae_scale_factor_spatial,
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 876 |
+
raise ValueError(
|
| 877 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 878 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 879 |
+
)
|
| 880 |
+
|
| 881 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 882 |
+
return latents
|
| 883 |
+
|
| 884 |
+
@property
|
| 885 |
+
def num_timesteps(self):
|
| 886 |
+
return self._num_timesteps
|
| 887 |
+
|
| 888 |
+
@property
|
| 889 |
+
def current_timestep(self):
|
| 890 |
+
return self._current_timestep
|
| 891 |
+
|
| 892 |
+
@property
|
| 893 |
+
def attention_kwargs(self):
|
| 894 |
+
return self._attention_kwargs
|
| 895 |
+
|
| 896 |
+
@property
|
| 897 |
+
def interrupt(self):
|
| 898 |
+
return self._interrupt
|
| 899 |
+
|
| 900 |
+
@torch.no_grad()
|
| 901 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 902 |
+
def __call__(
|
| 903 |
+
self,
|
| 904 |
+
prompt: Union[str, List[str]] | None = None,
|
| 905 |
+
image=None,
|
| 906 |
+
negative_prompt: Optional[Union[str, List[str]]] = "text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
|
| 907 |
+
height: int = 736,
|
| 908 |
+
width: int = 1280,
|
| 909 |
+
num_frames: int = 121,
|
| 910 |
+
frame_rate: int = 24,
|
| 911 |
+
num_inference_steps: int = 50,
|
| 912 |
+
timesteps: List[int] | None = None,
|
| 913 |
+
use_linear_quadratic_schedule: bool = False,
|
| 914 |
+
linear_quadratic_emulating_steps: int = 250,
|
| 915 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 916 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 917 |
+
latents: Optional[torch.Tensor] = None,
|
| 918 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 919 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 920 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 921 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 922 |
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 923 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 924 |
+
output_type: Optional[str] = "pil",
|
| 925 |
+
return_dict: bool = True,
|
| 926 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 927 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 928 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 929 |
+
max_sequence_length: int = 512,
|
| 930 |
+
use_attention_mask: bool = True,
|
| 931 |
+
vae_batch_size: int | None = None,
|
| 932 |
+
):
|
| 933 |
+
r"""
|
| 934 |
+
Function invoked when calling the pipeline for generation.
|
| 935 |
+
|
| 936 |
+
Args:
|
| 937 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 938 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 939 |
+
instead.
|
| 940 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 941 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 942 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance.
|
| 943 |
+
height (`int`, defaults to `352`):
|
| 944 |
+
The height in pixels of the generated image.
|
| 945 |
+
width (`int`, defaults to `640`):
|
| 946 |
+
The width in pixels of the generated image.
|
| 947 |
+
num_frames (`int`, defaults to `65`):
|
| 948 |
+
The number of video frames to generate
|
| 949 |
+
frame_rate (`int`, defaults to `25`):
|
| 950 |
+
Frame rate for the output video.
|
| 951 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 952 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
| 953 |
+
expense of slower inference.
|
| 954 |
+
timesteps (`List[int]`, *optional*):
|
| 955 |
+
Custom timesteps to use for the denoising process.
|
| 956 |
+
use_linear_quadratic_schedule (`bool`, defaults to `True`):
|
| 957 |
+
Whether to use a linear-quadratic sigma schedule instead of the default linear schedule.
|
| 958 |
+
This schedule combines linear interpolation in the first half (slow denoising at high noise)
|
| 959 |
+
with quadratic interpolation in the second half (faster denoising toward clean image).
|
| 960 |
+
Requires `num_inference_steps` to be even.
|
| 961 |
+
linear_quadratic_emulating_steps (`int`, defaults to `250`):
|
| 962 |
+
Controls the slope of linear interpolation in the first half of the linear-quadratic schedule.
|
| 963 |
+
Higher values result in a gentler slope. Only used when `use_linear_quadratic_schedule=True`.
|
| 964 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 965 |
+
The number of videos to generate per prompt.
|
| 966 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 967 |
+
PyTorch Generator object(s) for deterministic generation.
|
| 968 |
+
latents (`torch.Tensor`, *optional*):
|
| 969 |
+
Pre-generated noisy latents.
|
| 970 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 971 |
+
Pre-generated text embeddings.
|
| 972 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 973 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 974 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 975 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
| 976 |
+
Pre-generated attention mask for text embeddings.
|
| 977 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 978 |
+
Pre-generated negative text embeddings.
|
| 979 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 980 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 981 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 982 |
+
input argument.
|
| 983 |
+
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
| 984 |
+
Pre-generated attention mask for negative text embeddings.
|
| 985 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 986 |
+
The output format ("pil" or "np").
|
| 987 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 988 |
+
Whether to return a `MotifVideoPipelineOutput`.
|
| 989 |
+
attention_kwargs (`dict`, *optional*):
|
| 990 |
+
Arguments passed to the attention processor.
|
| 991 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 992 |
+
Callback function called at the end of each step.
|
| 993 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 994 |
+
Tensors to include in the callback.
|
| 995 |
+
max_sequence_length (`int` defaults to `512`):
|
| 996 |
+
Maximum sequence length for the tokenizer.
|
| 997 |
+
|
| 998 |
+
Examples:
|
| 999 |
+
|
| 1000 |
+
Returns:
|
| 1001 |
+
[`~pipelines.motif_video.MotifVideoPipelineOutput`] or `tuple`:
|
| 1002 |
+
If `return_dict` is `True`, returns [`~pipelines.motif_video.MotifVideoPipelineOutput`],
|
| 1003 |
+
otherwise returns a tuple where the first element is a list of generated video frames.
|
| 1004 |
+
"""
|
| 1005 |
+
|
| 1006 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 1007 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 1008 |
+
|
| 1009 |
+
# 1. Define call parameters (batch_size needed for check_inputs)
|
| 1010 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1011 |
+
batch_size = 1
|
| 1012 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1013 |
+
batch_size = len(prompt)
|
| 1014 |
+
else:
|
| 1015 |
+
batch_size = prompt_embeds.shape[0]
|
| 1016 |
+
|
| 1017 |
+
# 2. Check inputs. Raise error if not correct
|
| 1018 |
+
self.check_inputs(
|
| 1019 |
+
prompt=prompt,
|
| 1020 |
+
negative_prompt=negative_prompt,
|
| 1021 |
+
height=height,
|
| 1022 |
+
width=width,
|
| 1023 |
+
batch_size=batch_size,
|
| 1024 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 1025 |
+
prompt_embeds=prompt_embeds,
|
| 1026 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1027 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 1028 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
self._attention_kwargs = attention_kwargs
|
| 1032 |
+
self._interrupt = False
|
| 1033 |
+
self._current_timestep = None
|
| 1034 |
+
|
| 1035 |
+
# Auto-upgrade AdaptiveProjectedGuidance to VideoAdaptiveProjectedGuidance
|
| 1036 |
+
# for video generation. Video-aware APG normalizes per-frame [C,H,W] instead
|
| 1037 |
+
# of collapsing the temporal axis, preserving motion quality.
|
| 1038 |
+
if type(self.guider) is AdaptiveProjectedGuidance:
|
| 1039 |
+
self.guider = VideoAdaptiveProjectedGuidance(
|
| 1040 |
+
guidance_scale=self.guider.guidance_scale,
|
| 1041 |
+
adaptive_projected_guidance_rescale=self.guider.adaptive_projected_guidance_rescale,
|
| 1042 |
+
adaptive_projected_guidance_momentum=self.guider.adaptive_projected_guidance_momentum,
|
| 1043 |
+
eta=self.guider.eta,
|
| 1044 |
+
use_original_formulation=self.guider.use_original_formulation,
|
| 1045 |
+
)
|
| 1046 |
+
|
| 1047 |
+
device = self._execution_device
|
| 1048 |
+
|
| 1049 |
+
# 3. Prepare text embeddings
|
| 1050 |
+
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
|
| 1051 |
+
prompt=prompt,
|
| 1052 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 1053 |
+
prompt_embeds=prompt_embeds,
|
| 1054 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1055 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 1056 |
+
max_sequence_length=max_sequence_length,
|
| 1057 |
+
device=device,
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
if self.guider._enabled:
|
| 1061 |
+
negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
|
| 1062 |
+
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
|
| 1063 |
+
prompt=negative_prompt,
|
| 1064 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 1065 |
+
prompt_embeds=negative_prompt_embeds,
|
| 1066 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1067 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
| 1068 |
+
max_sequence_length=max_sequence_length,
|
| 1069 |
+
device=device,
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
num_channels_latents = self.vae.config.z_dim
|
| 1073 |
+
latents = self.prepare_latents(
|
| 1074 |
+
batch_size * num_videos_per_prompt,
|
| 1075 |
+
num_channels_latents,
|
| 1076 |
+
height,
|
| 1077 |
+
width,
|
| 1078 |
+
num_frames,
|
| 1079 |
+
self.transformer.dtype,
|
| 1080 |
+
device,
|
| 1081 |
+
generator,
|
| 1082 |
+
latents,
|
| 1083 |
+
)
|
| 1084 |
+
|
| 1085 |
+
# 4.5 Preprocess image for I2V conditioning
|
| 1086 |
+
if image is not None:
|
| 1087 |
+
from PIL import Image as PILImage
|
| 1088 |
+
|
| 1089 |
+
if isinstance(image, PILImage.Image):
|
| 1090 |
+
image = image.convert("RGB").resize((width, height), PILImage.LANCZOS)
|
| 1091 |
+
image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
|
| 1092 |
+
image = image * 2.0 - 1.0 # [0,1] -> [-1,1]
|
| 1093 |
+
image = image.unsqueeze(0) # [1, C, H, W]
|
| 1094 |
+
# Handle [C, H, W] -> [1, C, H, W]
|
| 1095 |
+
if image.dim() == 3:
|
| 1096 |
+
image = image.unsqueeze(0)
|
| 1097 |
+
# [B, C, H, W] -> [B, 1, C, H, W] for video format
|
| 1098 |
+
if image.dim() == 4:
|
| 1099 |
+
image = image.unsqueeze(1)
|
| 1100 |
+
image = image.to(device=device, dtype=self.vae.dtype)
|
| 1101 |
+
|
| 1102 |
+
# 5. Prepare timesteps (including mu calculation)
|
| 1103 |
+
|
| 1104 |
+
# Recalculate latent dims based on VAE for mu calculation
|
| 1105 |
+
latent_height = height // self.vae_scale_factor_spatial
|
| 1106 |
+
latent_width = width // self.vae_scale_factor_spatial
|
| 1107 |
+
latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 1108 |
+
|
| 1109 |
+
# Calculate sequence length based on *packed* dimensions if transformer uses packing
|
| 1110 |
+
# Packed dims: H/patch, W/patch, F/patch_t
|
| 1111 |
+
packed_latent_height = latent_height // self.transformer_spatial_patch_size
|
| 1112 |
+
packed_latent_width = latent_width // self.transformer_spatial_patch_size
|
| 1113 |
+
packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size
|
| 1114 |
+
video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width
|
| 1115 |
+
|
| 1116 |
+
# Compute sigmas: use linear-quadratic schedule if enabled, otherwise default linear
|
| 1117 |
+
_is_flow_multistep = isinstance(
|
| 1118 |
+
self.scheduler,
|
| 1119 |
+
(DPMSolverMultistepScheduler, UniPCMultistepScheduler, FlowUniPCMultistepScheduler),
|
| 1120 |
+
)
|
| 1121 |
+
|
| 1122 |
+
# Compute mu once, shared by both branches (required by FlowUniPCMultistepScheduler)
|
| 1123 |
+
mu = calculate_shift(
|
| 1124 |
+
video_sequence_length,
|
| 1125 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 1126 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 1127 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 1128 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
if _is_flow_multistep:
|
| 1132 |
+
# DPMSolver/UniPC manage their own sigma schedule via use_flow_sigmas + flow_shift.
|
| 1133 |
+
# Pass mu for dynamic shifting support (required by FlowUniPCMultistepScheduler).
|
| 1134 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1135 |
+
self.scheduler,
|
| 1136 |
+
num_inference_steps,
|
| 1137 |
+
device,
|
| 1138 |
+
timesteps,
|
| 1139 |
+
mu=mu,
|
| 1140 |
+
)
|
| 1141 |
+
else:
|
| 1142 |
+
if use_linear_quadratic_schedule:
|
| 1143 |
+
# Linear-quadratic schedule computes sigmas internally in retrieve_timesteps
|
| 1144 |
+
sigmas = None
|
| 1145 |
+
else:
|
| 1146 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 1147 |
+
|
| 1148 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1149 |
+
self.scheduler,
|
| 1150 |
+
num_inference_steps,
|
| 1151 |
+
device,
|
| 1152 |
+
timesteps,
|
| 1153 |
+
sigmas=sigmas,
|
| 1154 |
+
use_linear_quadratic_schedule=use_linear_quadratic_schedule,
|
| 1155 |
+
linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
|
| 1156 |
+
mu=mu,
|
| 1157 |
+
)
|
| 1158 |
+
|
| 1159 |
+
# Get conditioning tensors
|
| 1160 |
+
latent_condition, latent_mask, image_embeds = self._prepare_first_frame_conditioning(
|
| 1161 |
+
image,
|
| 1162 |
+
latents,
|
| 1163 |
+
use_conditioning=image is not None,
|
| 1164 |
+
generator=generator,
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1168 |
+
self._num_timesteps = len(timesteps)
|
| 1169 |
+
|
| 1170 |
+
# 6. Denoising loop
|
| 1171 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1172 |
+
for i, t in enumerate(timesteps):
|
| 1173 |
+
if self.interrupt:
|
| 1174 |
+
continue
|
| 1175 |
+
|
| 1176 |
+
self._current_timestep = t
|
| 1177 |
+
|
| 1178 |
+
# Concatenate current latents with conditioning for this timestep
|
| 1179 |
+
# [latents | latent_condition | latent_mask]
|
| 1180 |
+
hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1)
|
| 1181 |
+
|
| 1182 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1183 |
+
timestep = t.expand(latents.shape[0])
|
| 1184 |
+
|
| 1185 |
+
# Step 1: Collect model inputs needed for the guidance method
|
| 1186 |
+
# conditional inputs should always be first element in the tuple
|
| 1187 |
+
guider_inputs = {
|
| 1188 |
+
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
|
| 1189 |
+
}
|
| 1190 |
+
if use_attention_mask:
|
| 1191 |
+
guider_inputs["encoder_attention_mask"] = (prompt_attention_mask, negative_prompt_attention_mask)
|
| 1192 |
+
if self.transformer.config.pooled_projection_dim is not None:
|
| 1193 |
+
guider_inputs["pooled_projections"] = (pooled_prompt_embeds, negative_pooled_prompt_embeds)
|
| 1194 |
+
if image_embeds is not None:
|
| 1195 |
+
guider_inputs["image_embeds"] = (image_embeds, image_embeds)
|
| 1196 |
+
|
| 1197 |
+
# Step 2: Update guider's internal state for this denoising step
|
| 1198 |
+
self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
| 1199 |
+
# Sigma injection for guiders that support sigma-based gating
|
| 1200 |
+
# (Kynkäänniemi 2024). Must precede `prepare_inputs` because
|
| 1201 |
+
# `num_conditions` → `_is_cfg_enabled()` reads `_current_sigma`.
|
| 1202 |
+
# Duck-typed so diffusers-native guiders are unaffected; guard
|
| 1203 |
+
# on scheduler too since some schedulers don't expose `sigmas`.
|
| 1204 |
+
if hasattr(self.guider, "_current_sigma") and hasattr(self.scheduler, "sigmas"):
|
| 1205 |
+
self.guider._current_sigma = float(self.scheduler.sigmas[i])
|
| 1206 |
+
|
| 1207 |
+
# Step 3: Prepare batched model inputs based on the guidance method
|
| 1208 |
+
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
| 1209 |
+
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
| 1210 |
+
# you will get a guider_state with two batches:
|
| 1211 |
+
# guider_state = [
|
| 1212 |
+
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
| 1213 |
+
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
| 1214 |
+
# ]
|
| 1215 |
+
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
| 1216 |
+
guider_state = self.guider.prepare_inputs(guider_inputs)
|
| 1217 |
+
|
| 1218 |
+
# Step 4: Run the denoiser for each batch
|
| 1219 |
+
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
|
| 1220 |
+
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
|
| 1221 |
+
for guider_state_batch in guider_state:
|
| 1222 |
+
self.guider.prepare_models(self.transformer)
|
| 1223 |
+
|
| 1224 |
+
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
|
| 1225 |
+
cond_kwargs = {
|
| 1226 |
+
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
|
| 1227 |
+
}
|
| 1228 |
+
|
| 1229 |
+
tread_disabled = getattr(self.guider, "_current_tread_disabled", False)
|
| 1230 |
+
|
| 1231 |
+
# Override TREAD selection ratio per batch if the guider provides one
|
| 1232 |
+
selection_ratio = getattr(self.guider, "_current_selection_ratio", None)
|
| 1233 |
+
tread_mixin = getattr(self.transformer, "_inference_tread_mixin", None)
|
| 1234 |
+
if (
|
| 1235 |
+
selection_ratio is not None
|
| 1236 |
+
and tread_mixin is not None
|
| 1237 |
+
and tread_mixin._tread_route is not None
|
| 1238 |
+
):
|
| 1239 |
+
tread_mixin._tread_route["sel"] = selection_ratio
|
| 1240 |
+
|
| 1241 |
+
# e.g. "pred_cond"/"pred_uncond"
|
| 1242 |
+
context_name = getattr(guider_state_batch, self.guider._identifier_key)
|
| 1243 |
+
with self.transformer.cache_context(context_name):
|
| 1244 |
+
# Run denoiser and store noise prediction in this batch
|
| 1245 |
+
|
| 1246 |
+
noise_pred = self.transformer(
|
| 1247 |
+
hidden_states=hidden_states,
|
| 1248 |
+
timestep=timestep,
|
| 1249 |
+
attention_kwargs=self.attention_kwargs,
|
| 1250 |
+
return_dict=False,
|
| 1251 |
+
tread_disabled=tread_disabled,
|
| 1252 |
+
**cond_kwargs,
|
| 1253 |
+
)[0].clone()
|
| 1254 |
+
|
| 1255 |
+
guider_state_batch.noise_pred = noise_pred
|
| 1256 |
+
# Cleanup model (e.g., remove hooks)
|
| 1257 |
+
self.guider.cleanup_models(self.transformer)
|
| 1258 |
+
|
| 1259 |
+
# Step 5: Combine predictions using the guidance method
|
| 1260 |
+
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
|
| 1261 |
+
# Continuing the CFG example, the guider receives:
|
| 1262 |
+
# guider_state = [
|
| 1263 |
+
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
|
| 1264 |
+
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
|
| 1265 |
+
# ]
|
| 1266 |
+
# And extracts predictions using the __guidance_identifier__:
|
| 1267 |
+
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
|
| 1268 |
+
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
|
| 1269 |
+
# Then applies CFG formula:
|
| 1270 |
+
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
| 1271 |
+
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 1272 |
+
noise_pred = self.guider(guider_state)[0]
|
| 1273 |
+
|
| 1274 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1275 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1276 |
+
|
| 1277 |
+
if callback_on_step_end is not None:
|
| 1278 |
+
callback_kwargs = {}
|
| 1279 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1280 |
+
callback_kwargs[k] = locals()[k]
|
| 1281 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1282 |
+
|
| 1283 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1284 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1285 |
+
# Handle negative embeds if needed by callback
|
| 1286 |
+
if "negative_prompt_embeds" in callback_outputs:
|
| 1287 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds")
|
| 1288 |
+
|
| 1289 |
+
# call the callback, if provided
|
| 1290 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1291 |
+
progress_bar.update()
|
| 1292 |
+
|
| 1293 |
+
if XLA_AVAILABLE:
|
| 1294 |
+
xm.mark_step()
|
| 1295 |
+
|
| 1296 |
+
self._current_timestep = None
|
| 1297 |
+
|
| 1298 |
+
if output_type == "latent":
|
| 1299 |
+
video = latents
|
| 1300 |
+
else:
|
| 1301 |
+
latents = latents.to(self.vae.dtype)
|
| 1302 |
+
latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std)
|
| 1303 |
+
if vae_batch_size is not None and latents.shape[0] > vae_batch_size:
|
| 1304 |
+
video_chunks = []
|
| 1305 |
+
for i in range(0, latents.shape[0], vae_batch_size):
|
| 1306 |
+
chunk = latents[i : i + vae_batch_size]
|
| 1307 |
+
video_chunks.append(self.vae.decode(chunk, return_dict=False)[0])
|
| 1308 |
+
video = torch.cat(video_chunks, dim=0)
|
| 1309 |
+
del video_chunks
|
| 1310 |
+
else:
|
| 1311 |
+
video = self.vae.decode(latents, return_dict=False)[0]
|
| 1312 |
+
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
| 1313 |
+
|
| 1314 |
+
# Offload all models
|
| 1315 |
+
self.maybe_free_model_hooks()
|
| 1316 |
+
|
| 1317 |
+
if not return_dict:
|
| 1318 |
+
return (video,)
|
| 1319 |
+
|
| 1320 |
+
# Return updated output type
|
| 1321 |
+
return MotifVideoPipelineOutput(frames=video)
|
scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"base_image_seq_len": 256,
|
| 5 |
+
"base_shift": 0.5,
|
| 6 |
+
"invert_sigmas": false,
|
| 7 |
+
"max_image_seq_len": 4096,
|
| 8 |
+
"max_shift": 1.15,
|
| 9 |
+
"num_train_timesteps": 1000,
|
| 10 |
+
"shift": 15.0,
|
| 11 |
+
"shift_terminal": null,
|
| 12 |
+
"stochastic_sampling": false,
|
| 13 |
+
"time_shift_type": "exponential",
|
| 14 |
+
"use_beta_sigmas": false,
|
| 15 |
+
"use_dynamic_shifting": false,
|
| 16 |
+
"use_exponential_sigmas": false,
|
| 17 |
+
"use_karras_sigmas": false
|
| 18 |
+
}
|
text_encoder/config.json
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"T5Gemma2Model"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"bos_token_id": 2,
|
| 7 |
+
"classifier_dropout_rate": 0.0,
|
| 8 |
+
"decoder": {
|
| 9 |
+
"_sliding_window_pattern": 6,
|
| 10 |
+
"attention_bias": false,
|
| 11 |
+
"attention_dropout": 0.0,
|
| 12 |
+
"attn_logit_softcapping": null,
|
| 13 |
+
"dropout_rate": 0.0,
|
| 14 |
+
"dtype": "bfloat16",
|
| 15 |
+
"final_logit_softcapping": null,
|
| 16 |
+
"head_dim": 256,
|
| 17 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
| 18 |
+
"hidden_size": 2560,
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"intermediate_size": 10240,
|
| 21 |
+
"layer_types": [
|
| 22 |
+
"sliding_attention",
|
| 23 |
+
"sliding_attention",
|
| 24 |
+
"sliding_attention",
|
| 25 |
+
"sliding_attention",
|
| 26 |
+
"sliding_attention",
|
| 27 |
+
"full_attention",
|
| 28 |
+
"sliding_attention",
|
| 29 |
+
"sliding_attention",
|
| 30 |
+
"sliding_attention",
|
| 31 |
+
"sliding_attention",
|
| 32 |
+
"sliding_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"sliding_attention",
|
| 35 |
+
"sliding_attention",
|
| 36 |
+
"sliding_attention",
|
| 37 |
+
"sliding_attention",
|
| 38 |
+
"sliding_attention",
|
| 39 |
+
"full_attention",
|
| 40 |
+
"sliding_attention",
|
| 41 |
+
"sliding_attention",
|
| 42 |
+
"sliding_attention",
|
| 43 |
+
"sliding_attention",
|
| 44 |
+
"sliding_attention",
|
| 45 |
+
"full_attention",
|
| 46 |
+
"sliding_attention",
|
| 47 |
+
"sliding_attention",
|
| 48 |
+
"sliding_attention",
|
| 49 |
+
"sliding_attention",
|
| 50 |
+
"sliding_attention",
|
| 51 |
+
"full_attention",
|
| 52 |
+
"sliding_attention",
|
| 53 |
+
"sliding_attention",
|
| 54 |
+
"sliding_attention",
|
| 55 |
+
"sliding_attention"
|
| 56 |
+
],
|
| 57 |
+
"max_position_embeddings": 131072,
|
| 58 |
+
"model_type": "t5gemma2_decoder",
|
| 59 |
+
"num_attention_heads": 8,
|
| 60 |
+
"num_hidden_layers": 34,
|
| 61 |
+
"num_key_value_heads": 4,
|
| 62 |
+
"query_pre_attn_scalar": 256,
|
| 63 |
+
"rms_norm_eps": 1e-06,
|
| 64 |
+
"rope_parameters": {
|
| 65 |
+
"full_attention": {
|
| 66 |
+
"factor": 8.0,
|
| 67 |
+
"rope_theta": 1000000,
|
| 68 |
+
"rope_type": "linear"
|
| 69 |
+
},
|
| 70 |
+
"sliding_attention": {
|
| 71 |
+
"rope_theta": 10000,
|
| 72 |
+
"rope_type": "default"
|
| 73 |
+
}
|
| 74 |
+
},
|
| 75 |
+
"sliding_window": 1024,
|
| 76 |
+
"use_bidirectional_attention": false,
|
| 77 |
+
"use_cache": true,
|
| 78 |
+
"vocab_size": 262144
|
| 79 |
+
},
|
| 80 |
+
"dropout_rate": 0.0,
|
| 81 |
+
"dtype": "bfloat16",
|
| 82 |
+
"encoder": {
|
| 83 |
+
"attention_dropout": 0.0,
|
| 84 |
+
"boi_token_index": 255999,
|
| 85 |
+
"dropout_rate": 0.0,
|
| 86 |
+
"dtype": "bfloat16",
|
| 87 |
+
"eoi_token_index": 256000,
|
| 88 |
+
"image_token_index": 256001,
|
| 89 |
+
"initializer_range": 0.02,
|
| 90 |
+
"mm_tokens_per_image": 256,
|
| 91 |
+
"model_type": "t5gemma2_encoder",
|
| 92 |
+
"text_config": {
|
| 93 |
+
"_name_or_path": "",
|
| 94 |
+
"_sliding_window_pattern": 6,
|
| 95 |
+
"add_cross_attention": false,
|
| 96 |
+
"architectures": null,
|
| 97 |
+
"attention_bias": false,
|
| 98 |
+
"attention_dropout": 0.0,
|
| 99 |
+
"attn_logit_softcapping": null,
|
| 100 |
+
"bos_token_id": 2,
|
| 101 |
+
"chunk_size_feed_forward": 0,
|
| 102 |
+
"cross_attention_hidden_size": null,
|
| 103 |
+
"decoder_start_token_id": null,
|
| 104 |
+
"dropout_rate": 0.0,
|
| 105 |
+
"dtype": "bfloat16",
|
| 106 |
+
"eos_token_id": 1,
|
| 107 |
+
"final_logit_softcapping": null,
|
| 108 |
+
"finetuning_task": null,
|
| 109 |
+
"head_dim": 256,
|
| 110 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
| 111 |
+
"hidden_size": 2560,
|
| 112 |
+
"id2label": {
|
| 113 |
+
"0": "LABEL_0",
|
| 114 |
+
"1": "LABEL_1"
|
| 115 |
+
},
|
| 116 |
+
"initializer_range": 0.02,
|
| 117 |
+
"intermediate_size": 10240,
|
| 118 |
+
"is_decoder": false,
|
| 119 |
+
"is_encoder_decoder": false,
|
| 120 |
+
"label2id": {
|
| 121 |
+
"LABEL_0": 0,
|
| 122 |
+
"LABEL_1": 1
|
| 123 |
+
},
|
| 124 |
+
"layer_types": [
|
| 125 |
+
"sliding_attention",
|
| 126 |
+
"sliding_attention",
|
| 127 |
+
"sliding_attention",
|
| 128 |
+
"sliding_attention",
|
| 129 |
+
"sliding_attention",
|
| 130 |
+
"full_attention",
|
| 131 |
+
"sliding_attention",
|
| 132 |
+
"sliding_attention",
|
| 133 |
+
"sliding_attention",
|
| 134 |
+
"sliding_attention",
|
| 135 |
+
"sliding_attention",
|
| 136 |
+
"full_attention",
|
| 137 |
+
"sliding_attention",
|
| 138 |
+
"sliding_attention",
|
| 139 |
+
"sliding_attention",
|
| 140 |
+
"sliding_attention",
|
| 141 |
+
"sliding_attention",
|
| 142 |
+
"full_attention",
|
| 143 |
+
"sliding_attention",
|
| 144 |
+
"sliding_attention",
|
| 145 |
+
"sliding_attention",
|
| 146 |
+
"sliding_attention",
|
| 147 |
+
"sliding_attention",
|
| 148 |
+
"full_attention",
|
| 149 |
+
"sliding_attention",
|
| 150 |
+
"sliding_attention",
|
| 151 |
+
"sliding_attention",
|
| 152 |
+
"sliding_attention",
|
| 153 |
+
"sliding_attention",
|
| 154 |
+
"full_attention",
|
| 155 |
+
"sliding_attention",
|
| 156 |
+
"sliding_attention",
|
| 157 |
+
"sliding_attention",
|
| 158 |
+
"sliding_attention"
|
| 159 |
+
],
|
| 160 |
+
"max_position_embeddings": 131072,
|
| 161 |
+
"model_type": "t5gemma2_text",
|
| 162 |
+
"num_attention_heads": 8,
|
| 163 |
+
"num_hidden_layers": 34,
|
| 164 |
+
"num_key_value_heads": 4,
|
| 165 |
+
"output_attentions": false,
|
| 166 |
+
"output_hidden_states": false,
|
| 167 |
+
"pad_token_id": 0,
|
| 168 |
+
"prefix": null,
|
| 169 |
+
"problem_type": null,
|
| 170 |
+
"query_pre_attn_scalar": 256,
|
| 171 |
+
"return_dict": true,
|
| 172 |
+
"rms_norm_eps": 1e-06,
|
| 173 |
+
"rope_parameters": {
|
| 174 |
+
"full_attention": {
|
| 175 |
+
"factor": 8.0,
|
| 176 |
+
"rope_theta": 1000000,
|
| 177 |
+
"rope_type": "linear"
|
| 178 |
+
},
|
| 179 |
+
"sliding_attention": {
|
| 180 |
+
"rope_theta": 10000,
|
| 181 |
+
"rope_type": "default"
|
| 182 |
+
}
|
| 183 |
+
},
|
| 184 |
+
"sep_token_id": null,
|
| 185 |
+
"sliding_window": 1024,
|
| 186 |
+
"task_specific_params": null,
|
| 187 |
+
"tie_encoder_decoder": false,
|
| 188 |
+
"tie_word_embeddings": true,
|
| 189 |
+
"tokenizer_class": null,
|
| 190 |
+
"use_bidirectional_attention": false,
|
| 191 |
+
"use_cache": true,
|
| 192 |
+
"vocab_size": 262144
|
| 193 |
+
},
|
| 194 |
+
"vision_config": {
|
| 195 |
+
"_name_or_path": "",
|
| 196 |
+
"add_cross_attention": false,
|
| 197 |
+
"architectures": null,
|
| 198 |
+
"attention_dropout": 0.0,
|
| 199 |
+
"bos_token_id": null,
|
| 200 |
+
"chunk_size_feed_forward": 0,
|
| 201 |
+
"cross_attention_hidden_size": null,
|
| 202 |
+
"decoder_start_token_id": null,
|
| 203 |
+
"dropout_rate": 0.0,
|
| 204 |
+
"dtype": "bfloat16",
|
| 205 |
+
"eos_token_id": null,
|
| 206 |
+
"finetuning_task": null,
|
| 207 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 208 |
+
"hidden_size": 1152,
|
| 209 |
+
"id2label": {
|
| 210 |
+
"0": "LABEL_0",
|
| 211 |
+
"1": "LABEL_1"
|
| 212 |
+
},
|
| 213 |
+
"image_size": 896,
|
| 214 |
+
"intermediate_size": 4304,
|
| 215 |
+
"is_decoder": false,
|
| 216 |
+
"is_encoder_decoder": false,
|
| 217 |
+
"label2id": {
|
| 218 |
+
"LABEL_0": 0,
|
| 219 |
+
"LABEL_1": 1
|
| 220 |
+
},
|
| 221 |
+
"layer_norm_eps": 1e-06,
|
| 222 |
+
"model_type": "siglip_vision_model",
|
| 223 |
+
"num_attention_heads": 16,
|
| 224 |
+
"num_channels": 3,
|
| 225 |
+
"num_hidden_layers": 27,
|
| 226 |
+
"output_attentions": false,
|
| 227 |
+
"output_hidden_states": false,
|
| 228 |
+
"pad_token_id": null,
|
| 229 |
+
"patch_size": 14,
|
| 230 |
+
"prefix": null,
|
| 231 |
+
"problem_type": null,
|
| 232 |
+
"return_dict": true,
|
| 233 |
+
"sep_token_id": null,
|
| 234 |
+
"task_specific_params": null,
|
| 235 |
+
"tie_encoder_decoder": false,
|
| 236 |
+
"tie_word_embeddings": true,
|
| 237 |
+
"tokenizer_class": null,
|
| 238 |
+
"vision_use_head": false,
|
| 239 |
+
"vocab_size": 262144
|
| 240 |
+
},
|
| 241 |
+
"vocab_size": 262144
|
| 242 |
+
},
|
| 243 |
+
"eoi_token_index": 256000,
|
| 244 |
+
"eos_token_id": 1,
|
| 245 |
+
"image_token_index": 256001,
|
| 246 |
+
"initializer_range": 0.02,
|
| 247 |
+
"is_encoder_decoder": true,
|
| 248 |
+
"model_type": "t5gemma2",
|
| 249 |
+
"pad_token_id": 0,
|
| 250 |
+
"transformers_version": "5.0.0rc1",
|
| 251 |
+
"vocab_size": 262144
|
| 252 |
+
}
|
text_encoder/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c7dd568c34c56a521475124f226983dc191e57aa9b1cac9a22a87dcc753cb57
|
| 3 |
+
size 16360212008
|
tokenizer/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3220c5bec16e78ddf8e59c08fecdede7e8d31820cb5b3e69f17fed6a29a0b30c
|
| 3 |
+
size 33378248
|
tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": null,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"boi_token": "<start_of_image>",
|
| 5 |
+
"bos_token": "<bos>",
|
| 6 |
+
"clean_up_tokenization_spaces": false,
|
| 7 |
+
"eoi_token": "<end_of_image>",
|
| 8 |
+
"eos_token": "<eos>",
|
| 9 |
+
"image_token": "<image_soft_token>",
|
| 10 |
+
"is_local": false,
|
| 11 |
+
"mask_token": "<mask>",
|
| 12 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 13 |
+
"model_specific_special_tokens": {
|
| 14 |
+
"boi_token": "<start_of_image>",
|
| 15 |
+
"eoi_token": "<end_of_image>",
|
| 16 |
+
"image_token": "<image_soft_token>"
|
| 17 |
+
},
|
| 18 |
+
"pad_token": "<pad>",
|
| 19 |
+
"padding_side": "right",
|
| 20 |
+
"processor_class": "Gemma3Processor",
|
| 21 |
+
"sp_model_kwargs": null,
|
| 22 |
+
"spaces_between_special_tokens": false,
|
| 23 |
+
"tokenizer_class": "GemmaTokenizer",
|
| 24 |
+
"unk_token": "<unk>",
|
| 25 |
+
"use_default_system_prompt": false
|
| 26 |
+
}
|
transformer/config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "MotifVideoTransformer3DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_library": "diffusers",
|
| 5 |
+
"attention_head_dim": 128,
|
| 6 |
+
"base_latent_size": null,
|
| 7 |
+
"image_condition_type": null,
|
| 8 |
+
"image_embed_dim": 1152,
|
| 9 |
+
"in_channels": 33,
|
| 10 |
+
"mlp_ratio": 4.0,
|
| 11 |
+
"norm_type": "layer_norm",
|
| 12 |
+
"num_attention_heads": 12,
|
| 13 |
+
"num_decoder_layers": 8,
|
| 14 |
+
"num_layers": 12,
|
| 15 |
+
"num_single_layers": 24,
|
| 16 |
+
"out_channels": 16,
|
| 17 |
+
"patch_size": 2,
|
| 18 |
+
"patch_size_t": 1,
|
| 19 |
+
"pooled_projection_dim": null,
|
| 20 |
+
"qk_norm": "rms_norm",
|
| 21 |
+
"rope_axes_dim": [
|
| 22 |
+
16,
|
| 23 |
+
56,
|
| 24 |
+
56
|
| 25 |
+
],
|
| 26 |
+
"rope_theta": 10000.0,
|
| 27 |
+
"text_embed_dim": 2560,
|
| 28 |
+
"enable_text_cross_attention_dual": false,
|
| 29 |
+
"enable_text_cross_attention_single": true
|
| 30 |
+
}
|
transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3a8b17a188d358d9d0e9b097f5fc58c094f37a82a00cbb8895e54c2bdd73f6ff
|
| 3 |
+
size 7849331048
|
transformer/transformer_motif_video.py
ADDED
|
@@ -0,0 +1,1350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Motif Technologies. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from functools import lru_cache
|
| 17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
|
| 25 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 26 |
+
from diffusers.models.attention import FeedForward
|
| 27 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
| 28 |
+
from diffusers.models.cache_utils import CacheMixin
|
| 29 |
+
from diffusers.models.embeddings import (
|
| 30 |
+
PixArtAlphaTextProjection,
|
| 31 |
+
TimestepEmbedding,
|
| 32 |
+
Timesteps,
|
| 33 |
+
)
|
| 34 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 35 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 36 |
+
from diffusers.models.normalization import (
|
| 37 |
+
AdaLayerNormContinuous,
|
| 38 |
+
AdaLayerNormZero,
|
| 39 |
+
AdaLayerNormZeroSingle,
|
| 40 |
+
)
|
| 41 |
+
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 42 |
+
|
| 43 |
+
# Stub functions for TREAD (Token REduction with Approximated Distillation).
|
| 44 |
+
# These stubs ensure TREAD code paths are never activated during inference
|
| 45 |
+
# without requiring the motif_core package.
|
| 46 |
+
def is_tread_start(block_idx, start, end): return False
|
| 47 |
+
def is_tread_end(block_idx, start, end): return False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 51 |
+
|
| 52 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def apply_rotary_emb(
|
| 56 |
+
x: torch.Tensor,
|
| 57 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
|
| 58 |
+
use_real: bool = True,
|
| 59 |
+
use_real_unbind_dim: int = -1,
|
| 60 |
+
) -> torch.Tensor:
|
| 61 |
+
"""
|
| 62 |
+
Apply rotary positional embeddings (RoPE) to input tensors.
|
| 63 |
+
|
| 64 |
+
This implementation supports both standard 2D RoPE tensors [L, Dh] and batched 4D RoPE
|
| 65 |
+
tensors [B, 1, L, Dh] for compatibility with TREAD's token-dropping mechanism where
|
| 66 |
+
different batches may have different token subsets.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
x: Input tensor of shape [B, H, L, Dh].
|
| 70 |
+
freqs_cis: Tuple of (cos, sin) tensors. Supports shapes [L, Dh] or [B, 1, L, Dh].
|
| 71 |
+
use_real: Whether to use real-valued RoPE implementation.
|
| 72 |
+
use_real_unbind_dim: Dimension to unbind when using real-valued RoPE (-1 or -2).
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Tensor with rotary embeddings applied, same shape as input x.
|
| 76 |
+
"""
|
| 77 |
+
if use_real:
|
| 78 |
+
cos, sin = freqs_cis
|
| 79 |
+
if cos.dim() == 2: # [L, Dh] → [1, 1, L, Dh]
|
| 80 |
+
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 81 |
+
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 82 |
+
if cos.dim() != 4 or sin.dim() != 4:
|
| 83 |
+
raise RuntimeError(f"RoPE must be 2D or 4D, got cos={cos.dim()}D, sin={sin.dim()}D")
|
| 84 |
+
|
| 85 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 86 |
+
|
| 87 |
+
if cos.size(-2) != x.size(-2) or cos.size(-1) != x.size(-1):
|
| 88 |
+
raise RuntimeError(
|
| 89 |
+
f"RoPE shape mismatch: rope[-2:]=({cos.size(-2)},{cos.size(-1)}) vs x[-2:]=({x.size(-2)},{x.size(-1)})"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
if use_real_unbind_dim == -1:
|
| 93 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
| 94 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 95 |
+
elif use_real_unbind_dim == -2:
|
| 96 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
|
| 97 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 98 |
+
else:
|
| 99 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 100 |
+
|
| 101 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 102 |
+
return out
|
| 103 |
+
else:
|
| 104 |
+
x_rot = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 105 |
+
freqs = freqs_cis.unsqueeze(2)
|
| 106 |
+
x_out = torch.view_as_real(x_rot * freqs).flatten(3)
|
| 107 |
+
return x_out.type_as(x)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class MotifVideoAttnProcessor2_0:
|
| 111 |
+
def __init__(self):
|
| 112 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 113 |
+
raise ImportError(
|
| 114 |
+
"MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def __call__(
|
| 118 |
+
self,
|
| 119 |
+
attn: Attention,
|
| 120 |
+
hidden_states: torch.Tensor,
|
| 121 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 122 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 123 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 124 |
+
query_input: Optional[torch.Tensor] = None,
|
| 125 |
+
key_input: Optional[torch.Tensor] = None,
|
| 126 |
+
value_input: Optional[torch.Tensor] = None,
|
| 127 |
+
) -> torch.Tensor:
|
| 128 |
+
# Cross-attention mode: query already projected externally (cross_attn_query_proj + norm),
|
| 129 |
+
# skip to_q and only apply reshape + norm_q + RoPE. K/V use to_k/to_v as normal.
|
| 130 |
+
if query_input is not None:
|
| 131 |
+
query = query_input.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 132 |
+
key = attn.to_k(key_input)
|
| 133 |
+
value = attn.to_v(value_input)
|
| 134 |
+
|
| 135 |
+
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 136 |
+
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 137 |
+
|
| 138 |
+
if attn.norm_q is not None:
|
| 139 |
+
query = attn.norm_q(query)
|
| 140 |
+
if attn.norm_k is not None:
|
| 141 |
+
key = attn.norm_k(key)
|
| 142 |
+
|
| 143 |
+
if image_rotary_emb is not None:
|
| 144 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 145 |
+
|
| 146 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 147 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 148 |
+
)
|
| 149 |
+
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
| 150 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 151 |
+
return hidden_states, None
|
| 152 |
+
|
| 153 |
+
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 154 |
+
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 155 |
+
|
| 156 |
+
# 1. QKV projections
|
| 157 |
+
query = attn.to_q(hidden_states)
|
| 158 |
+
key = attn.to_k(hidden_states)
|
| 159 |
+
value = attn.to_v(hidden_states)
|
| 160 |
+
|
| 161 |
+
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 162 |
+
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 163 |
+
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 164 |
+
|
| 165 |
+
# 2. QK normalization
|
| 166 |
+
if attn.norm_q is not None:
|
| 167 |
+
query = attn.norm_q(query)
|
| 168 |
+
if attn.norm_k is not None:
|
| 169 |
+
key = attn.norm_k(key)
|
| 170 |
+
|
| 171 |
+
# 3. Rotational positional embeddings applied to latent stream
|
| 172 |
+
if image_rotary_emb is not None:
|
| 173 |
+
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 174 |
+
query = torch.cat(
|
| 175 |
+
[
|
| 176 |
+
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 177 |
+
query[:, :, -encoder_hidden_states.shape[1] :],
|
| 178 |
+
],
|
| 179 |
+
dim=2,
|
| 180 |
+
)
|
| 181 |
+
key = torch.cat(
|
| 182 |
+
[
|
| 183 |
+
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 184 |
+
key[:, :, -encoder_hidden_states.shape[1] :],
|
| 185 |
+
],
|
| 186 |
+
dim=2,
|
| 187 |
+
)
|
| 188 |
+
else:
|
| 189 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 190 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 191 |
+
|
| 192 |
+
# 4. Encoder condition QKV projection and normalization
|
| 193 |
+
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
| 194 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 195 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 196 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 197 |
+
|
| 198 |
+
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 199 |
+
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 200 |
+
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 201 |
+
|
| 202 |
+
if attn.norm_added_q is not None:
|
| 203 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 204 |
+
if attn.norm_added_k is not None:
|
| 205 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 206 |
+
|
| 207 |
+
query = torch.cat([query, encoder_query], dim=2)
|
| 208 |
+
key = torch.cat([key, encoder_key], dim=2)
|
| 209 |
+
value = torch.cat([value, encoder_value], dim=2)
|
| 210 |
+
|
| 211 |
+
# 5. Attention
|
| 212 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 213 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 214 |
+
)
|
| 215 |
+
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
| 216 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 217 |
+
|
| 218 |
+
# 6. Output projection
|
| 219 |
+
if encoder_hidden_states is not None:
|
| 220 |
+
hidden_states, encoder_hidden_states = (
|
| 221 |
+
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
| 222 |
+
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
if getattr(attn, "to_out", None) is not None:
|
| 226 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 227 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 228 |
+
|
| 229 |
+
if getattr(attn, "to_add_out", None) is not None:
|
| 230 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 231 |
+
|
| 232 |
+
return hidden_states, encoder_hidden_states
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class MotifVideoPatchEmbed(nn.Module):
|
| 236 |
+
def __init__(
|
| 237 |
+
self,
|
| 238 |
+
patch_size: Union[int, Tuple[int, int, int]] = 16,
|
| 239 |
+
in_chans: int = 3,
|
| 240 |
+
embed_dim: int = 768,
|
| 241 |
+
) -> None:
|
| 242 |
+
super().__init__()
|
| 243 |
+
|
| 244 |
+
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
|
| 245 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 246 |
+
|
| 247 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 248 |
+
hidden_states = self.proj(hidden_states)
|
| 249 |
+
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
| 250 |
+
return hidden_states
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class MotifVideoAdaNorm(nn.Module):
|
| 254 |
+
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
| 255 |
+
super().__init__()
|
| 256 |
+
|
| 257 |
+
out_features = out_features or 2 * in_features
|
| 258 |
+
self.linear = nn.Linear(in_features, out_features)
|
| 259 |
+
self.nonlinearity = nn.SiLU()
|
| 260 |
+
|
| 261 |
+
def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 262 |
+
temb = self.linear(self.nonlinearity(temb))
|
| 263 |
+
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
| 264 |
+
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
| 265 |
+
return gate_msa, gate_mlp
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class MotifVideoConditionEmbedding(nn.Module):
|
| 269 |
+
def __init__(
|
| 270 |
+
self,
|
| 271 |
+
embedding_dim: int,
|
| 272 |
+
pooled_projection_dim: int | None,
|
| 273 |
+
):
|
| 274 |
+
super().__init__()
|
| 275 |
+
|
| 276 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
| 277 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
| 278 |
+
|
| 279 |
+
if isinstance(pooled_projection_dim, int):
|
| 280 |
+
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
| 281 |
+
|
| 282 |
+
def forward(
|
| 283 |
+
self,
|
| 284 |
+
timestep: torch.Tensor,
|
| 285 |
+
pooled_projection: torch.Tensor | None = None,
|
| 286 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 287 |
+
timesteps_proj = self.time_proj(timestep)
|
| 288 |
+
timestep_embedder_dtype = next(self.timestep_embedder.parameters()).dtype
|
| 289 |
+
conditioning = self.timestep_embedder(timesteps_proj.to(timestep_embedder_dtype)) # (N, D)
|
| 290 |
+
if pooled_projection is not None:
|
| 291 |
+
conditioning = conditioning + self.text_embedder(pooled_projection)
|
| 292 |
+
|
| 293 |
+
token_replace_emb = None
|
| 294 |
+
|
| 295 |
+
return conditioning, token_replace_emb
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L485-L486
|
| 299 |
+
def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
|
| 300 |
+
dtype = num_rotations.dtype if isinstance(num_rotations, torch.Tensor) else torch.float32
|
| 301 |
+
max_pos_tensor = torch.as_tensor(max_position_embeddings, dtype=dtype)
|
| 302 |
+
return (dim * torch.log(max_pos_tensor / (num_rotations * 2 * math.pi))) / (
|
| 303 |
+
2 * math.log(base)
|
| 304 |
+
) # Inverse dim formula to find number of rotations
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L489-L495
|
| 308 |
+
def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
|
| 309 |
+
"""
|
| 310 |
+
Find the correction range for NTK-by-parts interpolation.
|
| 311 |
+
"""
|
| 312 |
+
low = torch.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len))
|
| 313 |
+
high = torch.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len))
|
| 314 |
+
low = torch.clamp(low, min=0)
|
| 315 |
+
high = torch.clamp(high, max=dim - 1)
|
| 316 |
+
return low, high # Clamp values just in case
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L498-L504
|
| 320 |
+
def linear_ramp_mask(min_val, max_val, num_dim):
|
| 321 |
+
if isinstance(min_val, torch.Tensor):
|
| 322 |
+
if (min_val == max_val).all():
|
| 323 |
+
max_val = max_val + 0.001
|
| 324 |
+
elif min_val == max_val:
|
| 325 |
+
max_val += 0.001
|
| 326 |
+
|
| 327 |
+
linear_func = (torch.arange(num_dim, dtype=torch.float32) - min_val) / (max_val - min_val)
|
| 328 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 329 |
+
return ramp_func
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L507-L511
|
| 333 |
+
def find_newbase_ntk(dim, base, scale):
|
| 334 |
+
"""
|
| 335 |
+
Calculate the new base for NTK-aware scaling.
|
| 336 |
+
"""
|
| 337 |
+
# Avoid division by zero when dim == 2 (or invalid smaller values).
|
| 338 |
+
# In these degenerate cases, fall back to the original base (no NTK adjustment).
|
| 339 |
+
if dim <= 2:
|
| 340 |
+
return base
|
| 341 |
+
return base * (scale ** (dim / (dim - 2)))
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L514-L652
|
| 345 |
+
def get_1d_rotary_pos_embed(
|
| 346 |
+
dim: int,
|
| 347 |
+
pos: Union[np.ndarray, int],
|
| 348 |
+
theta: float = 10000.0,
|
| 349 |
+
use_real=False,
|
| 350 |
+
linear_factor=1.0,
|
| 351 |
+
ntk_factor=1.0,
|
| 352 |
+
repeat_interleave_real=True,
|
| 353 |
+
freqs_dtype=torch.float32,
|
| 354 |
+
yarn=False,
|
| 355 |
+
max_pe_len=None,
|
| 356 |
+
ori_max_pe_len=64,
|
| 357 |
+
dype=False,
|
| 358 |
+
current_timestep=1.0,
|
| 359 |
+
):
|
| 360 |
+
"""
|
| 361 |
+
Precompute the frequency tensor for complex exponentials with RoPE.
|
| 362 |
+
Supports YARN interpolation for vision transformers.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
dim (`int`):
|
| 366 |
+
Dimension of the frequency tensor.
|
| 367 |
+
pos (`np.ndarray` or `int`):
|
| 368 |
+
Position indices for the frequency tensor. [S] or scalar.
|
| 369 |
+
theta (`float`, *optional*, defaults to 10000.0):
|
| 370 |
+
Scaling factor for frequency computation.
|
| 371 |
+
use_real (`bool`, *optional*, defaults to False):
|
| 372 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
| 373 |
+
linear_factor (`float`, *optional*, defaults to 1.0):
|
| 374 |
+
Scaling factor for linear interpolation.
|
| 375 |
+
ntk_factor (`float`, *optional*, defaults to 1.0):
|
| 376 |
+
Scaling factor for NTK-Aware RoPE.
|
| 377 |
+
repeat_interleave_real (`bool`, *optional*, defaults to True):
|
| 378 |
+
If True and use_real, real and imaginary parts are interleaved with themselves to reach dim.
|
| 379 |
+
Otherwise, they are concatenated.
|
| 380 |
+
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
| 381 |
+
Data type of the frequency tensor.
|
| 382 |
+
yarn (`bool`, *optional*, defaults to False):
|
| 383 |
+
If True, use YARN interpolation combining NTK, linear, and base methods.
|
| 384 |
+
max_pe_len (`int`, *optional*):
|
| 385 |
+
Maximum position encoding length (current patches for vision models).
|
| 386 |
+
ori_max_pe_len (`int`, *optional*, defaults to 64):
|
| 387 |
+
Original maximum position encoding length (base patches for vision models).
|
| 388 |
+
dype (`bool`, *optional*, defaults to False):
|
| 389 |
+
If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
|
| 390 |
+
current_timestep (`float`, *optional*, defaults to 1.0):
|
| 391 |
+
Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
|
| 392 |
+
|
| 393 |
+
Returns:
|
| 394 |
+
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
| 395 |
+
If use_real=True, returns tuple of (cos, sin) tensors.
|
| 396 |
+
"""
|
| 397 |
+
assert dim % 2 == 0
|
| 398 |
+
|
| 399 |
+
if isinstance(pos, int):
|
| 400 |
+
pos = torch.arange(pos)
|
| 401 |
+
if isinstance(pos, np.ndarray):
|
| 402 |
+
pos = torch.from_numpy(pos)
|
| 403 |
+
|
| 404 |
+
device = pos.device
|
| 405 |
+
|
| 406 |
+
if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
|
| 407 |
+
if not isinstance(max_pe_len, torch.Tensor):
|
| 408 |
+
max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)
|
| 409 |
+
|
| 410 |
+
scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
|
| 411 |
+
|
| 412 |
+
beta_0 = 1.25
|
| 413 |
+
beta_1 = 0.75
|
| 414 |
+
gamma_0 = 16
|
| 415 |
+
gamma_1 = 2
|
| 416 |
+
|
| 417 |
+
freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
|
| 418 |
+
|
| 419 |
+
freqs_linear = 1.0 / torch.einsum(
|
| 420 |
+
"..., f -> ... f",
|
| 421 |
+
scale,
|
| 422 |
+
(theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)),
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
new_base = find_newbase_ntk(dim, theta, scale)
|
| 426 |
+
if new_base.dim() > 0:
|
| 427 |
+
new_base = new_base.view(-1, 1)
|
| 428 |
+
freqs_ntk = 1.0 / torch.pow(new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
|
| 429 |
+
if freqs_ntk.dim() > 1:
|
| 430 |
+
freqs_ntk = freqs_ntk.squeeze()
|
| 431 |
+
|
| 432 |
+
if dype:
|
| 433 |
+
beta_0 = torch.pow(beta_0, 2.0 * torch.pow(current_timestep, 2.0))
|
| 434 |
+
beta_1 = torch.pow(beta_1, 2.0 * torch.pow(current_timestep, 2.0))
|
| 435 |
+
|
| 436 |
+
low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
|
| 437 |
+
high = torch.clamp(high, max=dim // 2)
|
| 438 |
+
|
| 439 |
+
freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
|
| 440 |
+
freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
|
| 441 |
+
|
| 442 |
+
if dype:
|
| 443 |
+
gamma_0 = torch.pow(gamma_0, 2.0 * torch.pow(current_timestep, 2.0))
|
| 444 |
+
gamma_1 = torch.pow(gamma_1, 2.0 * torch.pow(current_timestep, 2.0))
|
| 445 |
+
|
| 446 |
+
low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
|
| 447 |
+
high = torch.clamp(high, max=dim // 2)
|
| 448 |
+
|
| 449 |
+
freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
|
| 450 |
+
freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
|
| 451 |
+
|
| 452 |
+
else:
|
| 453 |
+
theta_ntk = theta * ntk_factor
|
| 454 |
+
freqs = 1.0 / (theta_ntk ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)) / linear_factor
|
| 455 |
+
|
| 456 |
+
freqs = torch.outer(pos, freqs)
|
| 457 |
+
|
| 458 |
+
is_npu = freqs.device.type == "npu"
|
| 459 |
+
if is_npu:
|
| 460 |
+
freqs = freqs.float()
|
| 461 |
+
|
| 462 |
+
if use_real and repeat_interleave_real:
|
| 463 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
|
| 464 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
|
| 465 |
+
|
| 466 |
+
if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
|
| 467 |
+
mscale = torch.where(scale <= 1.0, 1.0, 0.1 * torch.log(scale) + 1.0).to(scale)
|
| 468 |
+
freqs_cos = freqs_cos * mscale
|
| 469 |
+
freqs_sin = freqs_sin * mscale
|
| 470 |
+
|
| 471 |
+
return freqs_cos, freqs_sin
|
| 472 |
+
elif use_real:
|
| 473 |
+
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
|
| 474 |
+
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
|
| 475 |
+
return freqs_cos, freqs_sin
|
| 476 |
+
else:
|
| 477 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 478 |
+
return freqs_cis
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class MotifVideoRotaryPosEmbed(nn.Module):
|
| 482 |
+
def __init__(
|
| 483 |
+
self,
|
| 484 |
+
patch_size: int,
|
| 485 |
+
patch_size_t: int,
|
| 486 |
+
rope_dim: List[int],
|
| 487 |
+
theta: float = 256.0,
|
| 488 |
+
base_latent_size: int | None = None,
|
| 489 |
+
):
|
| 490 |
+
"""
|
| 491 |
+
Rotary Positional Embedding (RoPE) for video latents.
|
| 492 |
+
|
| 493 |
+
Args:
|
| 494 |
+
patch_size (`int`):
|
| 495 |
+
Spatial patch size (e.g., 2).
|
| 496 |
+
patch_size_t (`int`):
|
| 497 |
+
Temporal patch size (e.g., 1).
|
| 498 |
+
rope_dim (`List[int]`):
|
| 499 |
+
Dimensions for RoPE across [Time, Height, Width] axes.
|
| 500 |
+
theta (`float`, *optional*, defaults to 256.0):
|
| 501 |
+
Base frequency for rotary embeddings.
|
| 502 |
+
base_latent_size (`int`, *optional*):
|
| 503 |
+
The maximum spatial dimension (in latent units) seen during training,
|
| 504 |
+
i.e. `training_resolution / vae_scale_factor_spatial`.
|
| 505 |
+
For example, for 1280x1280 training images and a VAE spatial downscale
|
| 506 |
+
(`vae_scale_factor_spatial`) of 8, this would be 160; for a downscale
|
| 507 |
+
of 16, it would be 80.
|
| 508 |
+
"""
|
| 509 |
+
super().__init__()
|
| 510 |
+
|
| 511 |
+
self.patch_size = patch_size
|
| 512 |
+
self.patch_size_t = patch_size_t
|
| 513 |
+
self.rope_dim = rope_dim
|
| 514 |
+
self.theta = theta
|
| 515 |
+
self.base_latent_size = base_latent_size
|
| 516 |
+
|
| 517 |
+
@lru_cache(maxsize=8)
|
| 518 |
+
def _get_base_patch_grid_size(self, base_latent_size: Optional[int], patch_size: int) -> Optional[int]:
|
| 519 |
+
return base_latent_size // patch_size if base_latent_size else None
|
| 520 |
+
|
| 521 |
+
@lru_cache(maxsize=8)
|
| 522 |
+
def _get_dynamic_interpolation_scale(self, h: int, w: int, base_grid_size: int) -> float:
|
| 523 |
+
return math.sqrt(h * w / (base_grid_size**2))
|
| 524 |
+
|
| 525 |
+
def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 526 |
+
if self.training:
|
| 527 |
+
assert self.base_latent_size is None, (
|
| 528 |
+
"RoPE interpolation/extrapolation logic should only be enabled for inference. "
|
| 529 |
+
f"During training, base_latent_size must be None, but got {self.base_latent_size!r}."
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 533 |
+
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
|
| 534 |
+
|
| 535 |
+
axes_grids = []
|
| 536 |
+
for i in range(3):
|
| 537 |
+
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
|
| 538 |
+
# original implementation creates it on CPU and then moves it to device. This results in numerical
|
| 539 |
+
# differences in layerwise debugging outputs, but visually it is the same.
|
| 540 |
+
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
| 541 |
+
axes_grids.append(grid)
|
| 542 |
+
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
|
| 543 |
+
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
| 544 |
+
|
| 545 |
+
base_patch_grid_size = self._get_base_patch_grid_size(self.base_latent_size, self.patch_size)
|
| 546 |
+
if base_patch_grid_size is not None:
|
| 547 |
+
if base_patch_grid_size <= 0:
|
| 548 |
+
raise ValueError(f"base_patch_grid_size must be a positive number, got {base_patch_grid_size}.")
|
| 549 |
+
dynamic_interpolation_scale = self._get_dynamic_interpolation_scale(
|
| 550 |
+
rope_sizes[1], rope_sizes[2], base_patch_grid_size
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
normalized_timestep = torch.tensor(1.0)
|
| 554 |
+
if not self.training and timestep is not None:
|
| 555 |
+
normalized_timestep = timestep[0] / NUM_TRAIN_TIMESTEPS
|
| 556 |
+
|
| 557 |
+
freqs = []
|
| 558 |
+
for i in range(3):
|
| 559 |
+
common_kwargs = {
|
| 560 |
+
"dim": self.rope_dim[i],
|
| 561 |
+
"pos": grid[i].reshape(-1),
|
| 562 |
+
"theta": self.theta,
|
| 563 |
+
"use_real": True,
|
| 564 |
+
"freqs_dtype": torch.float64,
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
# Apply scaling only to spatial dimensions (Height and Width, i=1 and i=2)
|
| 568 |
+
if i > 0 and base_patch_grid_size is not None and dynamic_interpolation_scale > 1.0:
|
| 569 |
+
# We project the training base to the current size using the uniform scale factor.
|
| 570 |
+
# max_pe_len tells the RoPE logic the "new" maximum length it's dealing with.
|
| 571 |
+
max_pe_len = torch.tensor(
|
| 572 |
+
base_patch_grid_size * dynamic_interpolation_scale,
|
| 573 |
+
dtype=torch.float64,
|
| 574 |
+
device=hidden_states.device,
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
freq = get_1d_rotary_pos_embed(
|
| 578 |
+
**common_kwargs,
|
| 579 |
+
yarn=True, # Enable Yet Another RoPE extensioN (YARN) for extrapolation
|
| 580 |
+
max_pe_len=max_pe_len,
|
| 581 |
+
ori_max_pe_len=base_patch_grid_size, # The original training scale
|
| 582 |
+
dype=True, # Enable Dynamic Position Encoding (time-aware)
|
| 583 |
+
current_timestep=normalized_timestep,
|
| 584 |
+
)
|
| 585 |
+
else:
|
| 586 |
+
# Time dimension OR within training bounds -> Standard RoPE
|
| 587 |
+
freq = get_1d_rotary_pos_embed(**common_kwargs)
|
| 588 |
+
|
| 589 |
+
freqs.append(freq)
|
| 590 |
+
|
| 591 |
+
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 592 |
+
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 593 |
+
return freqs_cos, freqs_sin
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
class MotifVideoImageProjection(nn.Module):
|
| 597 |
+
def __init__(self, in_features: int, hidden_size: int):
|
| 598 |
+
super().__init__()
|
| 599 |
+
self.norm_in = nn.LayerNorm(in_features)
|
| 600 |
+
self.linear_1 = nn.Linear(in_features, in_features)
|
| 601 |
+
self.act_fn = nn.GELU()
|
| 602 |
+
self.linear_2 = nn.Linear(in_features, hidden_size)
|
| 603 |
+
self.norm_out = nn.LayerNorm(hidden_size)
|
| 604 |
+
|
| 605 |
+
def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
|
| 606 |
+
hidden_states = self.norm_in(image_embeds)
|
| 607 |
+
hidden_states = self.linear_1(hidden_states)
|
| 608 |
+
hidden_states = self.act_fn(hidden_states)
|
| 609 |
+
hidden_states = self.linear_2(hidden_states)
|
| 610 |
+
hidden_states = self.norm_out(hidden_states)
|
| 611 |
+
return hidden_states
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class MotifVideoSingleTransformerBlock(nn.Module):
|
| 615 |
+
def __init__(
|
| 616 |
+
self,
|
| 617 |
+
num_attention_heads: int,
|
| 618 |
+
attention_head_dim: int,
|
| 619 |
+
mlp_ratio: float = 4.0,
|
| 620 |
+
qk_norm: str = "rms_norm",
|
| 621 |
+
norm_type: str = "layer_norm",
|
| 622 |
+
enable_text_cross_attention: bool = False,
|
| 623 |
+
) -> None:
|
| 624 |
+
super().__init__()
|
| 625 |
+
|
| 626 |
+
hidden_size = num_attention_heads * attention_head_dim
|
| 627 |
+
mlp_dim = int(hidden_size * mlp_ratio)
|
| 628 |
+
|
| 629 |
+
self.attn = Attention(
|
| 630 |
+
query_dim=hidden_size,
|
| 631 |
+
cross_attention_dim=None,
|
| 632 |
+
dim_head=attention_head_dim,
|
| 633 |
+
heads=num_attention_heads,
|
| 634 |
+
out_dim=hidden_size,
|
| 635 |
+
bias=True,
|
| 636 |
+
processor=MotifVideoAttnProcessor2_0(),
|
| 637 |
+
qk_norm=qk_norm,
|
| 638 |
+
eps=1e-6,
|
| 639 |
+
pre_only=True,
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
self.enable_text_cross_attention = enable_text_cross_attention
|
| 643 |
+
if enable_text_cross_attention:
|
| 644 |
+
self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
|
| 645 |
+
self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
|
| 646 |
+
self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
|
| 647 |
+
nn.init.zeros_(self.cross_attn_out_proj.weight)
|
| 648 |
+
nn.init.zeros_(self.cross_attn_out_proj.bias)
|
| 649 |
+
|
| 650 |
+
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type)
|
| 651 |
+
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
| 652 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 653 |
+
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
| 654 |
+
|
| 655 |
+
def forward(
|
| 656 |
+
self,
|
| 657 |
+
hidden_states: torch.Tensor,
|
| 658 |
+
encoder_hidden_states: torch.Tensor,
|
| 659 |
+
temb: torch.Tensor,
|
| 660 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 661 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 662 |
+
token_replace_emb: torch.Tensor | None = None,
|
| 663 |
+
first_frame_num_tokens: int | None = None,
|
| 664 |
+
image_embed_seq_len: int = 0,
|
| 665 |
+
encoder_attention_mask: torch.Tensor | None = None,
|
| 666 |
+
) -> torch.Tensor:
|
| 667 |
+
text_seq_length = encoder_hidden_states.shape[1]
|
| 668 |
+
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 669 |
+
|
| 670 |
+
residual = hidden_states
|
| 671 |
+
|
| 672 |
+
# 1. Input normalization
|
| 673 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 674 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 675 |
+
|
| 676 |
+
norm_hidden_states, norm_encoder_hidden_states = (
|
| 677 |
+
norm_hidden_states[:, :-text_seq_length, :],
|
| 678 |
+
norm_hidden_states[:, -text_seq_length:, :],
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
# 2. Attention
|
| 682 |
+
attn_output, context_attn_output = self.attn(
|
| 683 |
+
hidden_states=norm_hidden_states,
|
| 684 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 685 |
+
attention_mask=attention_mask,
|
| 686 |
+
image_rotary_emb=image_rotary_emb,
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
# Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
|
| 690 |
+
if self.enable_text_cross_attention:
|
| 691 |
+
txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
|
| 692 |
+
text_mask = None
|
| 693 |
+
if encoder_attention_mask is not None:
|
| 694 |
+
text_mask = encoder_attention_mask[:, image_embed_seq_len:]
|
| 695 |
+
text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
|
| 696 |
+
cross_q = self.cross_attn_query_proj(attn_output)
|
| 697 |
+
cross_output, _ = self.attn(
|
| 698 |
+
hidden_states=cross_q,
|
| 699 |
+
query_input=cross_q,
|
| 700 |
+
key_input=txt_kv,
|
| 701 |
+
value_input=txt_kv,
|
| 702 |
+
attention_mask=text_mask,
|
| 703 |
+
image_rotary_emb=image_rotary_emb,
|
| 704 |
+
)
|
| 705 |
+
attn_output = attn_output + self.cross_attn_out_proj(cross_output)
|
| 706 |
+
|
| 707 |
+
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
| 708 |
+
|
| 709 |
+
# 3. Modulation and residual connection
|
| 710 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 711 |
+
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
|
| 712 |
+
hidden_states = hidden_states + residual
|
| 713 |
+
|
| 714 |
+
hidden_states, encoder_hidden_states = (
|
| 715 |
+
hidden_states[:, :-text_seq_length, :],
|
| 716 |
+
hidden_states[:, -text_seq_length:, :],
|
| 717 |
+
)
|
| 718 |
+
return hidden_states, encoder_hidden_states
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
class MotifVideoTransformerBlock(nn.Module):
|
| 722 |
+
def __init__(
|
| 723 |
+
self,
|
| 724 |
+
num_attention_heads: int,
|
| 725 |
+
attention_head_dim: int,
|
| 726 |
+
mlp_ratio: float,
|
| 727 |
+
qk_norm: str = "rms_norm",
|
| 728 |
+
norm_type: str = "layer_norm",
|
| 729 |
+
enable_text_cross_attention: bool = False,
|
| 730 |
+
) -> None:
|
| 731 |
+
super().__init__()
|
| 732 |
+
|
| 733 |
+
hidden_size = num_attention_heads * attention_head_dim
|
| 734 |
+
|
| 735 |
+
self.norm1 = AdaLayerNormZero(hidden_size, norm_type=norm_type)
|
| 736 |
+
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type=norm_type)
|
| 737 |
+
|
| 738 |
+
self.attn = Attention(
|
| 739 |
+
query_dim=hidden_size,
|
| 740 |
+
cross_attention_dim=None,
|
| 741 |
+
added_kv_proj_dim=hidden_size,
|
| 742 |
+
dim_head=attention_head_dim,
|
| 743 |
+
heads=num_attention_heads,
|
| 744 |
+
out_dim=hidden_size,
|
| 745 |
+
context_pre_only=False,
|
| 746 |
+
bias=True,
|
| 747 |
+
processor=MotifVideoAttnProcessor2_0(),
|
| 748 |
+
qk_norm=qk_norm,
|
| 749 |
+
eps=1e-6,
|
| 750 |
+
)
|
| 751 |
+
|
| 752 |
+
self.enable_text_cross_attention = enable_text_cross_attention
|
| 753 |
+
if enable_text_cross_attention:
|
| 754 |
+
self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
|
| 755 |
+
self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
|
| 756 |
+
self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
|
| 757 |
+
nn.init.zeros_(self.cross_attn_out_proj.weight)
|
| 758 |
+
nn.init.zeros_(self.cross_attn_out_proj.bias)
|
| 759 |
+
|
| 760 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 761 |
+
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 762 |
+
|
| 763 |
+
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 764 |
+
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 765 |
+
|
| 766 |
+
def forward(
|
| 767 |
+
self,
|
| 768 |
+
hidden_states: torch.Tensor,
|
| 769 |
+
encoder_hidden_states: torch.Tensor,
|
| 770 |
+
temb: torch.Tensor,
|
| 771 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 772 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 773 |
+
token_replace_emb: torch.Tensor | None = None,
|
| 774 |
+
first_frame_num_tokens: int | None = None,
|
| 775 |
+
image_embed_seq_len: int = 0,
|
| 776 |
+
encoder_attention_mask: torch.Tensor | None = None,
|
| 777 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 778 |
+
# 1. Input normalization
|
| 779 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 780 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 781 |
+
encoder_hidden_states, emb=temb
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
# 2. Joint attention
|
| 785 |
+
attn_output, context_attn_output = self.attn(
|
| 786 |
+
hidden_states=norm_hidden_states,
|
| 787 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 788 |
+
attention_mask=attention_mask,
|
| 789 |
+
image_rotary_emb=image_rotary_emb,
|
| 790 |
+
)
|
| 791 |
+
|
| 792 |
+
# 3. Modulation and residual connection
|
| 793 |
+
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
| 794 |
+
|
| 795 |
+
# Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
|
| 796 |
+
if self.enable_text_cross_attention:
|
| 797 |
+
txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
|
| 798 |
+
text_mask = None
|
| 799 |
+
if encoder_attention_mask is not None:
|
| 800 |
+
text_mask = encoder_attention_mask[:, image_embed_seq_len:]
|
| 801 |
+
text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
|
| 802 |
+
cross_q = self.cross_attn_query_proj(attn_output)
|
| 803 |
+
cross_output, _ = self.attn(
|
| 804 |
+
hidden_states=cross_q,
|
| 805 |
+
query_input=cross_q,
|
| 806 |
+
key_input=txt_kv,
|
| 807 |
+
value_input=txt_kv,
|
| 808 |
+
attention_mask=text_mask,
|
| 809 |
+
image_rotary_emb=image_rotary_emb,
|
| 810 |
+
)
|
| 811 |
+
hidden_states = hidden_states + self.cross_attn_out_proj(cross_output)
|
| 812 |
+
|
| 813 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
| 814 |
+
|
| 815 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 816 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 817 |
+
|
| 818 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 819 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 820 |
+
|
| 821 |
+
# 4. Feed-forward
|
| 822 |
+
ff_output = self.ff(norm_hidden_states)
|
| 823 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 824 |
+
|
| 825 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
| 826 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 827 |
+
|
| 828 |
+
return hidden_states, encoder_hidden_states
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
TransformerBlockRegistry.register(
|
| 832 |
+
model_class=MotifVideoTransformerBlock,
|
| 833 |
+
metadata=TransformerBlockMetadata(
|
| 834 |
+
return_hidden_states_index=0,
|
| 835 |
+
return_encoder_hidden_states_index=1,
|
| 836 |
+
),
|
| 837 |
+
)
|
| 838 |
+
TransformerBlockRegistry.register(
|
| 839 |
+
model_class=MotifVideoSingleTransformerBlock,
|
| 840 |
+
metadata=TransformerBlockMetadata(
|
| 841 |
+
return_hidden_states_index=0,
|
| 842 |
+
return_encoder_hidden_states_index=1,
|
| 843 |
+
),
|
| 844 |
+
)
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
class MotifVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
| 848 |
+
r"""
|
| 849 |
+
A Transformer model for video-like data used in [MotifVideo](https://huggingface.co/motif/motifvideo).
|
| 850 |
+
|
| 851 |
+
Args:
|
| 852 |
+
in_channels (`int`, defaults to `16`):
|
| 853 |
+
The number of channels in the input.
|
| 854 |
+
out_channels (`int`, defaults to `16`):
|
| 855 |
+
The number of channels in the output.
|
| 856 |
+
num_attention_heads (`int`, defaults to `24`):
|
| 857 |
+
The number of heads to use for multi-head attention.
|
| 858 |
+
attention_head_dim (`int`, defaults to `128`):
|
| 859 |
+
The number of channels in each head.
|
| 860 |
+
num_layers (`int`, defaults to `20`):
|
| 861 |
+
The number of layers of dual-stream blocks to use.
|
| 862 |
+
num_single_layers (`int`, defaults to `40`):
|
| 863 |
+
The number of layers of single-stream blocks to use.
|
| 864 |
+
|
| 865 |
+
mlp_ratio (`float`, defaults to `4.0`):
|
| 866 |
+
The ratio of the hidden layer size to the input size in the feedforward network.
|
| 867 |
+
patch_size (`int`, defaults to `2`):
|
| 868 |
+
The size of the spatial patches to use in the patch embedding layer.
|
| 869 |
+
patch_size_t (`int`, defaults to `1`):
|
| 870 |
+
The size of the temporal patches to use in the patch embedding layer.
|
| 871 |
+
qk_norm (`str`, defaults to `rms_norm`):
|
| 872 |
+
The normalization to use for the query and key projections in the attention layers.
|
| 873 |
+
text_embed_dim (`int`, defaults to `4096`):
|
| 874 |
+
Input dimension of text embeddings from the text encoder.
|
| 875 |
+
rope_theta (`float`, defaults to `256.0`):
|
| 876 |
+
The value of theta to use in the RoPE layer.
|
| 877 |
+
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
| 878 |
+
The dimensions of the axes to use in the RoPE layer.
|
| 879 |
+
base_latent_size (`int`, *optional*):
|
| 880 |
+
The maximum spatial dimension (in latent units) seen during training.
|
| 881 |
+
For example, if trained on 1280x1280 with a VAE downscale of 16, this is 80.
|
| 882 |
+
"""
|
| 883 |
+
|
| 884 |
+
_supports_gradient_checkpointing = True
|
| 885 |
+
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
| 886 |
+
_no_split_modules = [
|
| 887 |
+
"MotifVideoTransformerBlock",
|
| 888 |
+
"MotifVideoSingleTransformerBlock",
|
| 889 |
+
"MotifVideoPatchEmbed",
|
| 890 |
+
]
|
| 891 |
+
|
| 892 |
+
@register_to_config
|
| 893 |
+
def __init__(
|
| 894 |
+
self,
|
| 895 |
+
in_channels: int = 33,
|
| 896 |
+
out_channels: int = 16,
|
| 897 |
+
num_attention_heads: int = 24,
|
| 898 |
+
attention_head_dim: int = 128,
|
| 899 |
+
num_layers: int = 20,
|
| 900 |
+
num_single_layers: int = 40,
|
| 901 |
+
num_decoder_layers: int = 0,
|
| 902 |
+
mlp_ratio: float = 4.0,
|
| 903 |
+
patch_size: int = 2,
|
| 904 |
+
patch_size_t: int = 1,
|
| 905 |
+
qk_norm: str = "rms_norm",
|
| 906 |
+
norm_type: str = "layer_norm",
|
| 907 |
+
text_embed_dim: int = 4096,
|
| 908 |
+
image_embed_dim: int | None = None,
|
| 909 |
+
pooled_projection_dim: int | None = None,
|
| 910 |
+
rope_theta: float = 256.0,
|
| 911 |
+
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
| 912 |
+
base_latent_size: int | None = None,
|
| 913 |
+
enable_text_cross_attention_dual: bool = False,
|
| 914 |
+
enable_text_cross_attention_single: bool = False,
|
| 915 |
+
) -> None:
|
| 916 |
+
super().__init__()
|
| 917 |
+
|
| 918 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 919 |
+
out_channels = out_channels or in_channels
|
| 920 |
+
|
| 921 |
+
# 1. Latent and condition embedders
|
| 922 |
+
self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
|
| 923 |
+
self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim)
|
| 924 |
+
|
| 925 |
+
# First frame conditioning: Image conditioning embedders
|
| 926 |
+
self.image_embed_dim = image_embed_dim
|
| 927 |
+
if image_embed_dim is not None:
|
| 928 |
+
# Project image embeddings from vision encoder to transformer dim
|
| 929 |
+
self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim)
|
| 930 |
+
|
| 931 |
+
self.time_text_embed = MotifVideoConditionEmbedding(inner_dim, pooled_projection_dim)
|
| 932 |
+
|
| 933 |
+
# 2. RoPE
|
| 934 |
+
self.rope = MotifVideoRotaryPosEmbed(
|
| 935 |
+
patch_size, patch_size_t, rope_axes_dim, rope_theta, base_latent_size=base_latent_size
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
# Cross-attention config
|
| 939 |
+
self.enable_text_cross_attention_dual = enable_text_cross_attention_dual
|
| 940 |
+
self.enable_text_cross_attention_single = enable_text_cross_attention_single
|
| 941 |
+
|
| 942 |
+
# 3. Dual stream transformer blocks
|
| 943 |
+
self.transformer_blocks = nn.ModuleList(
|
| 944 |
+
[
|
| 945 |
+
MotifVideoTransformerBlock(
|
| 946 |
+
num_attention_heads,
|
| 947 |
+
attention_head_dim,
|
| 948 |
+
mlp_ratio=mlp_ratio,
|
| 949 |
+
qk_norm=qk_norm,
|
| 950 |
+
norm_type=norm_type,
|
| 951 |
+
enable_text_cross_attention=enable_text_cross_attention_dual,
|
| 952 |
+
)
|
| 953 |
+
for _ in range(num_layers)
|
| 954 |
+
]
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
# 4. Single stream transformer blocks
|
| 958 |
+
# Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder)
|
| 959 |
+
num_encoder_single = num_single_layers - num_decoder_layers
|
| 960 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 961 |
+
[
|
| 962 |
+
MotifVideoSingleTransformerBlock(
|
| 963 |
+
num_attention_heads,
|
| 964 |
+
attention_head_dim,
|
| 965 |
+
mlp_ratio=mlp_ratio,
|
| 966 |
+
qk_norm=qk_norm,
|
| 967 |
+
norm_type=norm_type,
|
| 968 |
+
enable_text_cross_attention=enable_text_cross_attention_single
|
| 969 |
+
if i < num_encoder_single
|
| 970 |
+
else False,
|
| 971 |
+
)
|
| 972 |
+
for i in range(num_single_layers)
|
| 973 |
+
]
|
| 974 |
+
)
|
| 975 |
+
|
| 976 |
+
# 5. Output projection
|
| 977 |
+
self.norm_out = AdaLayerNormContinuous(
|
| 978 |
+
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type=norm_type
|
| 979 |
+
)
|
| 980 |
+
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
|
| 981 |
+
|
| 982 |
+
# Verify cross-attention config matches actual block state.
|
| 983 |
+
# Catches silent misconfiguration (e.g. checkpoint config with renamed keys).
|
| 984 |
+
for i, block in enumerate(self.transformer_blocks):
|
| 985 |
+
if block.enable_text_cross_attention != enable_text_cross_attention_dual:
|
| 986 |
+
raise ValueError(
|
| 987 |
+
f"transformer_blocks[{i}].enable_text_cross_attention="
|
| 988 |
+
f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. "
|
| 989 |
+
f"Check checkpoint config.json key names match __init__ parameters."
|
| 990 |
+
)
|
| 991 |
+
num_encoder_single = num_single_layers - num_decoder_layers
|
| 992 |
+
for i, block in enumerate(self.single_transformer_blocks):
|
| 993 |
+
expected = enable_text_cross_attention_single if i < num_encoder_single else False
|
| 994 |
+
if block.enable_text_cross_attention != expected:
|
| 995 |
+
raise ValueError(
|
| 996 |
+
f"single_transformer_blocks[{i}].enable_text_cross_attention="
|
| 997 |
+
f"{block.enable_text_cross_attention}, expected {expected}. "
|
| 998 |
+
f"Check checkpoint config.json key names match __init__ parameters."
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
self.gradient_checkpointing = False
|
| 1002 |
+
self.num_decoder_layers = num_decoder_layers
|
| 1003 |
+
|
| 1004 |
+
@property
|
| 1005 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 1006 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 1007 |
+
r"""
|
| 1008 |
+
Returns:
|
| 1009 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 1010 |
+
indexed by its weight name.
|
| 1011 |
+
"""
|
| 1012 |
+
# set recursively
|
| 1013 |
+
processors = {}
|
| 1014 |
+
|
| 1015 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 1016 |
+
if hasattr(module, "get_processor"):
|
| 1017 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 1018 |
+
|
| 1019 |
+
for sub_name, child in module.named_children():
|
| 1020 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 1021 |
+
|
| 1022 |
+
return processors
|
| 1023 |
+
|
| 1024 |
+
for name, module in self.named_children():
|
| 1025 |
+
fn_recursive_add_processors(name, module, processors)
|
| 1026 |
+
|
| 1027 |
+
return processors
|
| 1028 |
+
|
| 1029 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 1030 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 1031 |
+
r"""
|
| 1032 |
+
Sets the attention processor to use to compute attention.
|
| 1033 |
+
|
| 1034 |
+
Parameters:
|
| 1035 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 1036 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 1037 |
+
for **all** `Attention` layers.
|
| 1038 |
+
|
| 1039 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 1040 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 1041 |
+
|
| 1042 |
+
"""
|
| 1043 |
+
count = len(self.attn_processors.keys())
|
| 1044 |
+
|
| 1045 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 1046 |
+
raise ValueError(
|
| 1047 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 1048 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 1052 |
+
if hasattr(module, "set_processor"):
|
| 1053 |
+
if not isinstance(processor, dict):
|
| 1054 |
+
module.set_processor(processor)
|
| 1055 |
+
else:
|
| 1056 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 1057 |
+
|
| 1058 |
+
for sub_name, child in module.named_children():
|
| 1059 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 1060 |
+
|
| 1061 |
+
for name, module in self.named_children():
|
| 1062 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 1063 |
+
|
| 1064 |
+
def _maybe_gradient_checkpoint_block(self, block, *args):
|
| 1065 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1066 |
+
return self._gradient_checkpointing_func(block, *args)
|
| 1067 |
+
return block(*args)
|
| 1068 |
+
|
| 1069 |
+
def _get_unwrapped_blocks(self, blocks):
|
| 1070 |
+
if hasattr(blocks, "_checkpoint_wrapped_module"):
|
| 1071 |
+
return blocks._checkpoint_wrapped_module
|
| 1072 |
+
elif hasattr(blocks, "module"):
|
| 1073 |
+
return blocks.module
|
| 1074 |
+
return blocks
|
| 1075 |
+
|
| 1076 |
+
def _create_attention_mask(
|
| 1077 |
+
self,
|
| 1078 |
+
hidden_states: torch.Tensor,
|
| 1079 |
+
encoder_attention_mask: torch.Tensor,
|
| 1080 |
+
) -> torch.Tensor:
|
| 1081 |
+
"""
|
| 1082 |
+
Create attention mask of shape [B, 1, 1, N] where N = L + E,
|
| 1083 |
+
based on latent tokens (always valid) and the encoder mask.
|
| 1084 |
+
|
| 1085 |
+
Args:
|
| 1086 |
+
hidden_states: [B, L, D]
|
| 1087 |
+
encoder_attention_mask: [B, E] (required)
|
| 1088 |
+
|
| 1089 |
+
Returns:
|
| 1090 |
+
attention_mask: [B, 1, 1, N]
|
| 1091 |
+
"""
|
| 1092 |
+
attention_mask = F.pad(
|
| 1093 |
+
encoder_attention_mask.to(torch.bool),
|
| 1094 |
+
(hidden_states.shape[1], 0),
|
| 1095 |
+
value=True,
|
| 1096 |
+
)
|
| 1097 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, L+E]
|
| 1098 |
+
return attention_mask
|
| 1099 |
+
|
| 1100 |
+
def forward(
|
| 1101 |
+
self,
|
| 1102 |
+
hidden_states: torch.Tensor,
|
| 1103 |
+
timestep: torch.LongTensor,
|
| 1104 |
+
encoder_hidden_states: torch.Tensor,
|
| 1105 |
+
encoder_attention_mask: torch.Tensor | None = None,
|
| 1106 |
+
pooled_projections: torch.Tensor | None = None,
|
| 1107 |
+
image_embeds: torch.Tensor | None = None,
|
| 1108 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1109 |
+
return_dict: bool = True,
|
| 1110 |
+
tread_mixin: Optional[Any] = None,
|
| 1111 |
+
tread_disabled: bool = False,
|
| 1112 |
+
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 1113 |
+
"""
|
| 1114 |
+
Forward pass of the MotifVideoTransformer3DModel.
|
| 1115 |
+
|
| 1116 |
+
Args:
|
| 1117 |
+
hidden_states: Input latent tensor [B, C, F, H, W].
|
| 1118 |
+
timestep: Diffusion timesteps [B].
|
| 1119 |
+
encoder_hidden_states: Text conditioning [B, E, D].
|
| 1120 |
+
encoder_attention_mask: Mask for text conditioning [B, E].
|
| 1121 |
+
pooled_projections: Pooled text embeddings [B, D].
|
| 1122 |
+
image_embeds: Optional image embeddings from vision encoder [B, N, D].
|
| 1123 |
+
attention_kwargs: Additional arguments for attention processors.
|
| 1124 |
+
return_dict: Whether to return a Transformer2DModelOutput.
|
| 1125 |
+
tread_mixin: Optional TreadMixin instance for token reduction.
|
| 1126 |
+
tread_disabled: When True, force tread_mixin to None (dense pass).
|
| 1127 |
+
torch.compile specializes on this bool, producing separate graphs
|
| 1128 |
+
for dense vs routed without attribute toggling.
|
| 1129 |
+
|
| 1130 |
+
Returns:
|
| 1131 |
+
Transformer2DModelOutput or tuple containing the predicted samples.
|
| 1132 |
+
"""
|
| 1133 |
+
if tread_disabled:
|
| 1134 |
+
tread_mixin = None
|
| 1135 |
+
elif tread_mixin is None:
|
| 1136 |
+
tread_mixin = getattr(self, "_inference_tread_mixin", None)
|
| 1137 |
+
|
| 1138 |
+
if attention_kwargs is not None:
|
| 1139 |
+
attention_kwargs = attention_kwargs.copy()
|
| 1140 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 1141 |
+
else:
|
| 1142 |
+
lora_scale = 1.0
|
| 1143 |
+
|
| 1144 |
+
if USE_PEFT_BACKEND:
|
| 1145 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 1146 |
+
scale_lora_layers(self, lora_scale)
|
| 1147 |
+
else:
|
| 1148 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 1149 |
+
logger.warning(
|
| 1150 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 1154 |
+
p, p_t = self.config.patch_size, self.config.patch_size_t
|
| 1155 |
+
post_patch_num_frames = num_frames // p_t
|
| 1156 |
+
post_patch_height = height // p
|
| 1157 |
+
post_patch_width = width // p
|
| 1158 |
+
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
|
| 1159 |
+
# 1. RoPE
|
| 1160 |
+
image_rotary_emb = self.rope(hidden_states, timestep=timestep)
|
| 1161 |
+
# 2. Conditional embeddings
|
| 1162 |
+
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
|
| 1163 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 1164 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 1165 |
+
|
| 1166 |
+
# First frame conditioning: Image embeddings from vision encoder
|
| 1167 |
+
if image_embeds is not None:
|
| 1168 |
+
# image_embeds: [B, N, D_img] -> [B, N, D]
|
| 1169 |
+
image_embeds = self.image_embedder(image_embeds)
|
| 1170 |
+
encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1)
|
| 1171 |
+
# Extend attention mask for image tokens
|
| 1172 |
+
if encoder_attention_mask is not None:
|
| 1173 |
+
image_mask = torch.ones(
|
| 1174 |
+
image_embeds.shape[0],
|
| 1175 |
+
image_embeds.shape[1],
|
| 1176 |
+
device=encoder_attention_mask.device,
|
| 1177 |
+
dtype=encoder_attention_mask.dtype,
|
| 1178 |
+
)
|
| 1179 |
+
encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1)
|
| 1180 |
+
|
| 1181 |
+
# image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states
|
| 1182 |
+
image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0
|
| 1183 |
+
|
| 1184 |
+
decoder_hidden_states = hidden_states.clone()
|
| 1185 |
+
|
| 1186 |
+
if encoder_attention_mask is not None:
|
| 1187 |
+
attention_mask = self._create_attention_mask(
|
| 1188 |
+
hidden_states=hidden_states,
|
| 1189 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1190 |
+
)
|
| 1191 |
+
else:
|
| 1192 |
+
attention_mask = None
|
| 1193 |
+
|
| 1194 |
+
# TREAD state initialization: manage token reduction manually to support activation checkpointing
|
| 1195 |
+
tread_active = False
|
| 1196 |
+
current_route = None
|
| 1197 |
+
ids_keep = None
|
| 1198 |
+
x_full = None
|
| 1199 |
+
orig_mask = attention_mask
|
| 1200 |
+
orig_rope = image_rotary_emb
|
| 1201 |
+
latent_len = hidden_states.shape[1]
|
| 1202 |
+
|
| 1203 |
+
# 4. Dual stream transformer blocks (Encoder)
|
| 1204 |
+
for i, block in enumerate(self.transformer_blocks):
|
| 1205 |
+
# Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
|
| 1206 |
+
if is_tread_start(tread_mixin, tread_active, i):
|
| 1207 |
+
tread_active = True
|
| 1208 |
+
current_route = tread_mixin._tread_route
|
| 1209 |
+
# Reduce sequence length at the start of a TREAD route
|
| 1210 |
+
ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
|
| 1211 |
+
x_full = hidden_states.contiguous()
|
| 1212 |
+
hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
|
| 1213 |
+
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1214 |
+
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1215 |
+
|
| 1216 |
+
hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1217 |
+
block,
|
| 1218 |
+
hidden_states,
|
| 1219 |
+
encoder_hidden_states,
|
| 1220 |
+
temb,
|
| 1221 |
+
attention_mask,
|
| 1222 |
+
image_rotary_emb,
|
| 1223 |
+
token_replace_emb,
|
| 1224 |
+
first_frame_num_tokens,
|
| 1225 |
+
image_embed_seq_len,
|
| 1226 |
+
encoder_attention_mask,
|
| 1227 |
+
)
|
| 1228 |
+
|
| 1229 |
+
if is_tread_end(tread_mixin, tread_active, i):
|
| 1230 |
+
# Restore full sequence length at the end of a TREAD route
|
| 1231 |
+
hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
|
| 1232 |
+
tread_active = False
|
| 1233 |
+
current_route = None
|
| 1234 |
+
ids_keep = None
|
| 1235 |
+
x_full = None
|
| 1236 |
+
attention_mask = orig_mask
|
| 1237 |
+
image_rotary_emb = orig_rope
|
| 1238 |
+
|
| 1239 |
+
# We need to unwrap the blocks because CheckpointWrapper does not support len(),
|
| 1240 |
+
# which is required for slicing the blocks into encoder and decoder parts.
|
| 1241 |
+
single_transformer_blocks = self.single_transformer_blocks
|
| 1242 |
+
|
| 1243 |
+
# 5. Single stream transformer blocks (Encoder)
|
| 1244 |
+
num_dual = len(self.transformer_blocks)
|
| 1245 |
+
for i, block in enumerate(
|
| 1246 |
+
single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]
|
| 1247 |
+
):
|
| 1248 |
+
# Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
|
| 1249 |
+
abs_i = num_dual + i
|
| 1250 |
+
if is_tread_start(tread_mixin, tread_active, abs_i):
|
| 1251 |
+
tread_active = True
|
| 1252 |
+
current_route = tread_mixin._tread_route
|
| 1253 |
+
# Reduce sequence length at the start of a TREAD route
|
| 1254 |
+
ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
|
| 1255 |
+
x_full = hidden_states.contiguous()
|
| 1256 |
+
hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
|
| 1257 |
+
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1258 |
+
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1259 |
+
|
| 1260 |
+
hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1261 |
+
block,
|
| 1262 |
+
hidden_states,
|
| 1263 |
+
encoder_hidden_states,
|
| 1264 |
+
temb,
|
| 1265 |
+
attention_mask,
|
| 1266 |
+
image_rotary_emb,
|
| 1267 |
+
token_replace_emb,
|
| 1268 |
+
first_frame_num_tokens,
|
| 1269 |
+
image_embed_seq_len,
|
| 1270 |
+
encoder_attention_mask,
|
| 1271 |
+
)
|
| 1272 |
+
|
| 1273 |
+
if is_tread_end(tread_mixin, tread_active, abs_i):
|
| 1274 |
+
# Restore full sequence length at the end of a TREAD route
|
| 1275 |
+
hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
|
| 1276 |
+
tread_active = False
|
| 1277 |
+
current_route = None
|
| 1278 |
+
ids_keep = None
|
| 1279 |
+
x_full = None
|
| 1280 |
+
attention_mask = orig_mask
|
| 1281 |
+
image_rotary_emb = orig_rope
|
| 1282 |
+
|
| 1283 |
+
# 6. Single stream transformer blocks (Decoder)
|
| 1284 |
+
if self.num_decoder_layers > 0:
|
| 1285 |
+
encoder_hidden_states = hidden_states
|
| 1286 |
+
attention_mask = None
|
| 1287 |
+
|
| 1288 |
+
num_single = len(single_transformer_blocks)
|
| 1289 |
+
|
| 1290 |
+
for i, block in enumerate(single_transformer_blocks[-self.num_decoder_layers :]):
|
| 1291 |
+
abs_i = num_dual + (num_single - self.num_decoder_layers) + i
|
| 1292 |
+
if is_tread_start(tread_mixin, tread_active, abs_i):
|
| 1293 |
+
tread_active = True
|
| 1294 |
+
current_route = tread_mixin._tread_route
|
| 1295 |
+
# Reduce sequence length at the start of a TREAD route
|
| 1296 |
+
ids_keep = tread_mixin.keep_indices(decoder_hidden_states, current_route["sel"]).to(
|
| 1297 |
+
decoder_hidden_states.device
|
| 1298 |
+
)
|
| 1299 |
+
x_full = encoder_hidden_states.contiguous()
|
| 1300 |
+
x_t_full = decoder_hidden_states.contiguous()
|
| 1301 |
+
decoder_hidden_states = tread_mixin.gather_tokens(decoder_hidden_states, ids_keep)
|
| 1302 |
+
encoder_hidden_states = tread_mixin.gather_tokens(encoder_hidden_states, ids_keep)
|
| 1303 |
+
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1304 |
+
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1305 |
+
|
| 1306 |
+
decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1307 |
+
block,
|
| 1308 |
+
decoder_hidden_states,
|
| 1309 |
+
encoder_hidden_states,
|
| 1310 |
+
temb,
|
| 1311 |
+
attention_mask,
|
| 1312 |
+
image_rotary_emb,
|
| 1313 |
+
token_replace_emb,
|
| 1314 |
+
first_frame_num_tokens,
|
| 1315 |
+
)
|
| 1316 |
+
|
| 1317 |
+
if is_tread_end(tread_mixin, tread_active, abs_i):
|
| 1318 |
+
# Restore full sequence length at the end of a TREAD route
|
| 1319 |
+
decoder_hidden_states = tread_mixin.scatter_tokens(decoder_hidden_states, ids_keep, x_t_full)
|
| 1320 |
+
encoder_hidden_states = tread_mixin.scatter_tokens(encoder_hidden_states, ids_keep, x_full)
|
| 1321 |
+
tread_active = False
|
| 1322 |
+
current_route = None
|
| 1323 |
+
ids_keep = None
|
| 1324 |
+
x_full = None
|
| 1325 |
+
x_t_full = None
|
| 1326 |
+
attention_mask = orig_mask
|
| 1327 |
+
image_rotary_emb = orig_rope
|
| 1328 |
+
|
| 1329 |
+
hidden_states = decoder_hidden_states
|
| 1330 |
+
|
| 1331 |
+
# 7. Output projection
|
| 1332 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 1333 |
+
hidden_states = self.proj_out(hidden_states)
|
| 1334 |
+
|
| 1335 |
+
hidden_states = hidden_states.reshape(
|
| 1336 |
+
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
|
| 1337 |
+
)
|
| 1338 |
+
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
| 1339 |
+
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
| 1340 |
+
|
| 1341 |
+
if USE_PEFT_BACKEND:
|
| 1342 |
+
# remove `lora_scale` from each PEFT layer
|
| 1343 |
+
unscale_lora_layers(self, lora_scale)
|
| 1344 |
+
|
| 1345 |
+
if not return_dict:
|
| 1346 |
+
return (hidden_states,)
|
| 1347 |
+
|
| 1348 |
+
return Transformer2DModelOutput(
|
| 1349 |
+
sample=hidden_states,
|
| 1350 |
+
)
|
vae/config.json
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKLWan",
|
| 3 |
+
"_diffusers_version": "0.35.2",
|
| 4 |
+
"_name_or_path": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
|
| 5 |
+
"attn_scales": [],
|
| 6 |
+
"base_dim": 96,
|
| 7 |
+
"decoder_base_dim": null,
|
| 8 |
+
"dim_mult": [
|
| 9 |
+
1,
|
| 10 |
+
2,
|
| 11 |
+
4,
|
| 12 |
+
4
|
| 13 |
+
],
|
| 14 |
+
"dropout": 0.0,
|
| 15 |
+
"in_channels": 3,
|
| 16 |
+
"is_residual": false,
|
| 17 |
+
"latents_mean": [
|
| 18 |
+
-0.7571,
|
| 19 |
+
-0.7089,
|
| 20 |
+
-0.9113,
|
| 21 |
+
0.1075,
|
| 22 |
+
-0.1745,
|
| 23 |
+
0.9653,
|
| 24 |
+
-0.1517,
|
| 25 |
+
1.5508,
|
| 26 |
+
0.4134,
|
| 27 |
+
-0.0715,
|
| 28 |
+
0.5517,
|
| 29 |
+
-0.3632,
|
| 30 |
+
-0.1922,
|
| 31 |
+
-0.9497,
|
| 32 |
+
0.2503,
|
| 33 |
+
-0.2921
|
| 34 |
+
],
|
| 35 |
+
"latents_std": [
|
| 36 |
+
2.8184,
|
| 37 |
+
1.4541,
|
| 38 |
+
2.3275,
|
| 39 |
+
2.6558,
|
| 40 |
+
1.2196,
|
| 41 |
+
1.7708,
|
| 42 |
+
2.6052,
|
| 43 |
+
2.0743,
|
| 44 |
+
3.2687,
|
| 45 |
+
2.1526,
|
| 46 |
+
2.8652,
|
| 47 |
+
1.5579,
|
| 48 |
+
1.6382,
|
| 49 |
+
1.1253,
|
| 50 |
+
2.8251,
|
| 51 |
+
1.916
|
| 52 |
+
],
|
| 53 |
+
"num_res_blocks": 2,
|
| 54 |
+
"out_channels": 3,
|
| 55 |
+
"patch_size": null,
|
| 56 |
+
"scale_factor_spatial": 8,
|
| 57 |
+
"scale_factor_temporal": 4,
|
| 58 |
+
"temperal_downsample": [
|
| 59 |
+
false,
|
| 60 |
+
true,
|
| 61 |
+
true
|
| 62 |
+
],
|
| 63 |
+
"z_dim": 16
|
| 64 |
+
}
|
vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d6e524b3fffede1787a74e81b30976dce5400c4439ba64222168e607ed19e793
|
| 3 |
+
size 507591892
|