# Copyright 2026 Motif Technologies, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import html import inspect from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union import ftfy import numpy as np import regex as re import torch from diffusers import ( AdaptiveProjectedGuidance, AutoencoderKLWan, ClassifierFreeGuidance, DiffusionPipeline, DPMSolverMultistepScheduler, FlowMatchEulerDiscreteScheduler, SkipLayerGuidance, UniPCMultistepScheduler, ) from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor from einops import rearrange from PIL import Image from torch import Tensor from diffusers.guiders.adaptive_projected_guidance import MomentumBuffer from diffusers.guiders.guider_utils import GuiderOutput from ._fm_solvers_unipc import FlowUniPCMultistepScheduler from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Model if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import MotifVideoPipeline >>> from diffusers.utils import export_to_video >>> # Load the Motif Video pipeline >>> motif_video_model_id = "MotifTechnologies/Motif-Video" >>> pipe = MotifVideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" >>> video = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, ... width=640, ... height=352, ... num_frames=65, ... num_inference_steps=50, ... ).frames[0] >>> export_to_video(video, "output.mp4", fps=16) ``` """ @dataclass class MotifVideoPipelineOutput(BaseOutput): r""" Output class for Motif Video pipelines. Args: frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape `(batch_size, num_frames, channels, height, width)`. """ frames: torch.Tensor """Video-aware Adaptive Projected Guidance (APG). Standard APG normalizes over all spatial dimensions [C, T, H, W], which collapses temporal variation. This module normalizes over [C, H, W] only, preserving per-frame independence. """ def video_normalized_guidance( pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, momentum_buffer: MomentumBuffer | None = None, eta: float = 1.0, norm_threshold: float = 0.0, use_original_formulation: bool = False, ) -> torch.Tensor: """APG with video-aware normalization: normalize over [C, H, W], exclude T. For 5D input [B, C, T, H, W], dim=[-1, -2, -4] normalizes per-frame (W, H, C), keeping the T dimension independent. For 4D input [B, C, H, W], falls back to standard [-1, -2, -3] behavior. """ diff = pred_cond - pred_uncond if len(diff.shape) == 5: # [B, C, T, H, W] → normalize over W(-1), H(-2), C(-4), skip T(-3) dim = [-1, -2, -4] else: # [B, C, H, W] → standard behavior dim = [-i for i in range(1, len(diff.shape))] if momentum_buffer is not None: momentum_buffer.update(diff) diff = momentum_buffer.running_average if norm_threshold > 0: ones = torch.ones_like(diff) diff_norm = diff.norm(p=2, dim=dim, keepdim=True) scale_factor = torch.minimum(ones, norm_threshold / diff_norm) diff = diff * scale_factor v0, v1 = diff.double(), pred_cond.double() v1 = torch.nn.functional.normalize(v1, dim=dim) v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) normalized_update = diff_orthogonal + eta * diff_parallel pred = pred_cond if use_original_formulation else pred_uncond pred = pred + guidance_scale * normalized_update return pred class VideoAdaptiveProjectedGuidance(AdaptiveProjectedGuidance): """APG variant that normalizes over [C, H, W] per frame, excluding the T dimension.""" def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput: pred = None if not self._is_apg_enabled(): pred = pred_cond else: pred = video_normalized_guidance( pred_cond, pred_uncond, self.guidance_scale, self.momentum_buffer, self.eta, self.adaptive_projected_guidance_rescale, self.use_original_formulation, ) if self.guidance_rescale > 0.0: from diffusers.guiders.classifier_free_guidance import rescale_noise_cfg pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15, ): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu def get_linear_quadratic_sigmas( num_inference_steps: int, linear_quadratic_emulating_steps: int = 250, ) -> np.ndarray: """ Compute a linear-quadratic sigma schedule for flow matching. This schedule combines: - First half: Linear interpolation from high noise to medium noise (slow denoising) - Second half: Quadratic interpolation from medium noise to clean (faster denoising) Convention: - sigma=1.0 represents pure noise - sigma=0.0 represents clean image - Output sigmas are in descending order (1.0 → ~0) Args: num_inference_steps: Total number of denoising steps (must be even). linear_quadratic_emulating_steps: Controls the slope of linear interpolation. Higher values result in gentler slope in the first half. Returns: np.ndarray: Array of sigma values with shape (num_inference_steps,). The scheduler will append a terminal 0. Raises: ValueError: If num_inference_steps is not even. Reference: Linear-quadratic timestep schedule for improved flow matching inference. """ if num_inference_steps % 2 != 0: raise ValueError( f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}" ) steps = num_inference_steps N = linear_quadratic_emulating_steps half_steps = steps // 2 # First half: linear interpolation from 1 toward 0 # Takes first half_steps values from linspace(1, 0, N+1) linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps] # Second half: quadratic interpolation # Formula: x^2 * (half_steps/N - 1) - (half_steps/N - 1) # = (half_steps/N - 1) * (x^2 - 1) # This maps x=0 to (half_steps/N - 1) * (-1) = 1 - half_steps/N # and maps x=1 to 0 x = np.linspace(0.0, 1.0, half_steps + 1) scale_factor = half_steps / N - 1 # negative value quadratic_part = x**2 * scale_factor - scale_factor # Concatenate and exclude the last 0 (scheduler appends terminal 0) sigmas = np.concatenate([linear_part, quadratic_part]) sigmas = sigmas[:-1] # Remove trailing 0, scheduler will append it return sigmas.astype(np.float32) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, use_linear_quadratic_schedule: bool = False, linear_quadratic_emulating_steps: int = 250, **kwargs, ): """ Retrieve timesteps from the scheduler. Args: scheduler: The noise scheduler to use. num_inference_steps: Number of denoising steps. device: Device to place timesteps on. timesteps: Custom timestep values (mutually exclusive with sigmas). sigmas: Custom sigma values (mutually exclusive with timesteps). use_linear_quadratic_schedule: If True, use linear-quadratic sigma schedule. This overrides the default linear schedule. Requires num_inference_steps to be even. linear_quadratic_emulating_steps: Controls the linear portion slope. Higher values result in gentler slope in the first half. Default: 250. **kwargs: Additional arguments passed to scheduler.set_timesteps(). Returns: Tuple of (timesteps, num_inference_steps). Raises: ValueError: If both timesteps and sigmas are provided, or if use_linear_quadratic_schedule is True but num_inference_steps is odd. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") # Handle linear-quadratic schedule: compute sigmas if flag is set if use_linear_quadratic_schedule: if sigmas is not None: raise ValueError( "Cannot use both `sigmas` and `use_linear_quadratic_schedule`. " "The linear-quadratic schedule computes sigmas automatically." ) if num_inference_steps is None: raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.") sigmas = get_linear_quadratic_sigmas( num_inference_steps=num_inference_steps, linear_quadratic_emulating_steps=linear_quadratic_emulating_steps, ) if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def whitespace_clean(text): text = re.sub(r"\s+", " ", text) text = text.strip() return text def prompt_clean(text): text = whitespace_clean(basic_clean(text)) return text class MotifVideoPipeline(DiffusionPipeline): r""" Pipeline for text-to-video generation using MotifVideoTransformer. Args: transformer ([`MotifVideoTransformer3DModel`]): Conditional Transformer architecture to denoise the encoded video latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded video latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. text_encoder ([`T5Gemma2Model`]): Primary text encoder for encoding text prompts into embeddings. tokenizer ([`PreTrainedTokenizerBase`]): Tokenizer corresponding to the primary text encoder. guider ([`ClassifierFreeGuidance`] or [`SkipLayerGuidance`] or [`AdaptiveProjectedGuidance`] or [`VideoAdaptiveProjectedGuidance`], *optional*): The guidance method to use. If `None`, it defaults to `ClassifierFreeGuidance()`. """ model_cpu_offload_seq = "text_encoder->transformer->vae" _optional_components = ["feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, scheduler: Union[ FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler, FlowUniPCMultistepScheduler, ], vae: AutoencoderKLWan, text_encoder: T5Gemma2Model, tokenizer: PreTrainedTokenizerBase, transformer, guider: Optional[ Union[ClassifierFreeGuidance, SkipLayerGuidance, AdaptiveProjectedGuidance, VideoAdaptiveProjectedGuidance] ] = None, feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() self.guider = ClassifierFreeGuidance() if guider is None else guider self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, feature_extractor=feature_extractor, ) self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.transformer_spatial_patch_size = ( self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2 ) self.transformer_temporal_patch_size = ( self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.tokenizer_max_length = ( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512 ) def _get_default_embeds( self, text_encoder, tokenizer: PreTrainedTokenizerBase, prompt: Union[str, List[str]], max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: dtype = dtype or text_encoder.dtype text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, return_tensors="pt", ) text_inputs = BatchEncoding( {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()} ) prompt_embeds = text_encoder(**text_inputs)[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) return prompt_embeds, text_inputs.attention_mask def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) denom = attention_mask.sum(dim=1, keepdim=True).clamp(min=1) # avoid div by zero return last_hidden.sum(dim=1) / denom def _get_prompt_embeds( self, text_encoder: T5Gemma2Model, tokenizer: PreTrainedTokenizerBase, prompt: Union[str, List[str]] | None = None, num_videos_per_prompt: int = 1, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt prompt_embeds_kwargs = { "text_encoder": text_encoder, "tokenizer": tokenizer, "prompt": prompt, "max_sequence_length": max_sequence_length, "device": device, "dtype": dtype, } # T5Gemma2Model bundles encoder and decoder/LM head, while _get_default_embeds expects an encoder-only model # (similar to T5EncoderModel/T5GemmaEncoderModel), so we pass the encoder submodule explicitly here. if isinstance(text_encoder, T5Gemma2Model): prompt_embeds_kwargs["text_encoder"] = text_encoder.encoder prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs) pooled_prompt_embeds = self._average_pool(prompt_embeds, prompt_attention_mask) return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds # Keep encode_prompt structure, uses _get_prompt_embeds internally def encode_prompt( self, prompt: Union[str, List[str]], num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, ]: device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] prompt_embeds_kwargs = { "device": device, "dtype": dtype, } if prompt_embeds is None: prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self._get_prompt_embeds( text_encoder=self.text_encoder, tokenizer=self.tokenizer, prompt=prompt, max_sequence_length=max_sequence_length, **prompt_embeds_kwargs, ) # duplicate text embeddings for each generation per prompt, using mps friendly method seq_len = prompt_embeds.shape[1] prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) if pooled_prompt_embeds is not None: pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) # Keep attention mask handling prompt_attention_mask = prompt_attention_mask.bool() prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0) return ( prompt_embeds, pooled_prompt_embeds, prompt_attention_mask, ) @property def vision_encoder(self): """Get the vision encoder from T5Gemma2. T5Gemma2 has vision_tower.vision_model structure. Will raise AttributeError if not available. """ return self.text_encoder.encoder.vision_tower.vision_model def encode_image( self, image: Image.Image, batch_size: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: """Encode image to embeddings using SigLIP vision encoder.""" device = device or self._execution_device dtype = dtype or self.transformer.dtype image_embeds = self._get_image_embeds( image_encoder=self.vision_encoder, feature_extractor=self.feature_extractor, image=image, device=device, ) image_embeds = image_embeds.repeat(batch_size, 1, 1) return image_embeds.to(device=device, dtype=dtype) @staticmethod def _get_image_embeds( image_encoder, feature_extractor: SiglipImageProcessor, image, device: torch.device, ) -> torch.Tensor: """Helper to encode single image with SigLIP. Args: image_encoder: The SigLIP vision encoder model. feature_extractor: SiglipImageProcessor for preprocessing. image: Can be either: - PIL.Image.Image: Will be preprocessed by feature_extractor - torch.Tensor: Assumed to be in [0, 1] range, will be normalized and passed to encoder device: Device to place tensors on. Returns: Image embeddings from the vision encoder. """ image_encoder_dtype = next(image_encoder.parameters()).dtype if isinstance(image, torch.Tensor): image = feature_extractor.preprocess( images=image.float(), do_resize=True, do_rescale=False, do_normalize=True, do_convert_rgb=True, return_tensors="pt", ) else: image = feature_extractor.preprocess( images=image, do_resize=True, do_rescale=False, do_normalize=True, do_convert_rgb=True, return_tensors="pt", ) image = image.to(device, dtype=image_encoder_dtype) return image_encoder(**image).last_hidden_state @torch.compiler.disable def _prepare_first_frame_conditioning( self, video: torch.Tensor, latents: torch.Tensor, use_conditioning: bool, generator: Optional[torch.Generator] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Prepare first frame conditioning tensors. This method implements batch-level conditioning where entire batches are either I2V (all samples conditioned) or T2V (no conditioning). This prevents mode confusion within batches. For I2V mode: 1. Extract and VAE-encode first frame from video 2. Create latent_condition by repeating first frame across time (frame 0 only) 3. Create latent_mask with 1.0 at frame 0 4. Get image_embeds from vision encoder For T2V mode: 1. Pad with zeros for latent_condition and latent_mask Args: video: Input video tensor [batch_size, frames, channels, height, width] in [-1, 1] latents: Latents [batch_size, lantent_channels, latent_num_frames, latent_height, latent_width] use_conditioning: Whether to use first-frame conditioning (True for I2V, False for T2V) generator: Optional random number generator for reproducibility Returns: Tuple of (latent_condition, latent_mask, image_embeds). - latent_condition: [B, C, F, H, W] conditioning signal (zeros for T2V) - latent_mask: [B, 1, F, H, W] binary mask (zeros for T2V) - image_embeds: [B, N, D] image embeddings from vision encoder or None for T2V """ batch_size, lantent_channels, latent_num_frames, latent_height, latent_width = latents.shape device = latents.device dtype = latents.dtype # Determine if we should use conditioning use_conditioning = use_conditioning and (latent_num_frames > 1) # Initialize conditioning tensors latent_condition = torch.zeros( batch_size, lantent_channels, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype ) latent_mask = torch.zeros( batch_size, 1, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype ) image_embeds = None if use_conditioning: with torch.no_grad(): # Encode first frame for latent_condition first_frame_latents = self.vae.encode( rearrange(video[:, 0:1], "b f c h w -> b c f h w") ).latent_dist.sample(generator=generator) first_frame_latents = self._normalize_latents( latents=first_frame_latents, latents_mean=self.vae.config.latents_mean, latents_std=self.vae.config.latents_std, ) # Create latent_condition by repeating first frame across time latent_condition = first_frame_latents.repeat(1, 1, latent_num_frames, 1, 1) latent_condition[:, :, 1:, :, :] = 0 # latent_mask: 1.0 at frame 0, 0.0 elsewhere latent_mask[:, :, 0] = 1.0 # image_embeds from vision encoder first_frame_vision = video[:, 0] # [B, C, H, W] first_frame_vision = ((first_frame_vision + 1) / 2).clamp(0, 1) with torch.no_grad(): image_embeds = self._get_image_embeds( image_encoder=self.vision_encoder, feature_extractor=self.feature_extractor, image=first_frame_vision, device=device, ) return latent_condition, latent_mask, image_embeds def check_inputs( self, prompt, negative_prompt, height, width, batch_size, callback_on_step_end_tensor_inputs=None, prompt_embeds=None, negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, ): # Resolution must be divisible by VAE scale factor * transformer patch size # (e.g. 8 * 2 = 16 for default config) to avoid latent/patch dimension mismatch. spatial_divisor = self.vae_scale_factor_spatial * self.transformer_spatial_patch_size if height % spatial_divisor != 0 or width % spatial_divisor != 0: raise ValueError( f"`height` and `width` have to be divisible by {spatial_divisor} " f"(vae_scale={self.vae_scale_factor_spatial} * patch_size={self.transformer_spatial_patch_size}) " f"but are {height} and {width}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") # Validate negative_prompt: must be None, str, or list with matching batch_size if negative_prompt is not None: if not isinstance(negative_prompt, (str, list)): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size: raise ValueError( f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})." ) if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: raise ValueError( "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" f" {negative_prompt_attention_mask.shape}." ) def _prepare_negative_prompt( self, negative_prompt: Optional[Union[str, List[str]]], batch_size: int, ) -> List[str]: """ Prepare negative_prompt to match batch_size. Args: negative_prompt: None, a single string, or a list of strings matching batch_size. batch_size: The number of prompts in the batch. Returns: A list of negative prompts with length equal to batch_size. """ if negative_prompt is None: return [""] * batch_size if isinstance(negative_prompt, str): return [negative_prompt] * batch_size return negative_prompt @staticmethod def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = latents.shape post_patch_num_frames = num_frames // patch_size_t post_patch_height = height // patch_size post_patch_width = width // patch_size latents = latents.reshape( batch_size, -1, post_patch_num_frames, patch_size_t, post_patch_height, patch_size, post_patch_width, patch_size, ) latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) return latents @staticmethod def _unpack_latents( latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1, ) -> torch.Tensor: batch_size = latents.size(0) latents = latents.reshape( batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size, ) latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents @staticmethod def _normalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor ) -> torch.Tensor: # Normalize latents across the channel dimension [B, C, F, H, W] latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents = (latents - latents_mean) / latents_std return latents @staticmethod def _denormalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor ) -> torch.Tensor: # Denormalize latents across the channel dimension [B, C, F, H, W] latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) latents = latents * latents_std + latents_mean return latents def prepare_latents( self, batch_size: int = 1, num_channels_latents: int = 16, height: int = 352, width: int = 640, num_frames: int = 65, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: return latents.to(device=device, dtype=dtype) shape = ( batch_size, num_channels_latents, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents @property def num_timesteps(self): return self._num_timesteps @property def current_timestep(self): return self._current_timestep @property def attention_kwargs(self): return self._attention_kwargs @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] | None = None, image=None, negative_prompt: Optional[Union[str, List[str]]] = "text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift", height: int = 736, width: int = 1280, num_frames: int = 121, frame_rate: int = 24, num_inference_steps: int = 50, timesteps: List[int] | None = None, use_linear_quadratic_schedule: bool = False, linear_quadratic_emulating_steps: int = 250, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, use_attention_mask: bool = True, vae_batch_size: int | None = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance. height (`int`, defaults to `352`): The height in pixels of the generated image. width (`int`, defaults to `640`): The width in pixels of the generated image. num_frames (`int`, defaults to `65`): The number of video frames to generate frame_rate (`int`, defaults to `25`): Frame rate for the output video. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. use_linear_quadratic_schedule (`bool`, defaults to `True`): Whether to use a linear-quadratic sigma schedule instead of the default linear schedule. This schedule combines linear interpolation in the first half (slow denoising at high noise) with quadratic interpolation in the second half (faster denoising toward clean image). Requires `num_inference_steps` to be even. linear_quadratic_emulating_steps (`int`, defaults to `250`): Controls the slope of linear interpolation in the first half of the linear-quadratic schedule. Higher values result in a gentler slope. Only used when `use_linear_quadratic_schedule=True`. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): PyTorch Generator object(s) for deterministic generation. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format ("pil" or "np"). return_dict (`bool`, *optional*, defaults to `True`): Whether to return a `MotifVideoPipelineOutput`. attention_kwargs (`dict`, *optional*): Arguments passed to the attention processor. callback_on_step_end (`Callable`, *optional*): Callback function called at the end of each step. callback_on_step_end_tensor_inputs (`List`, *optional*): Tensors to include in the callback. max_sequence_length (`int` defaults to `512`): Maximum sequence length for the tokenizer. Examples: Returns: [`~pipelines.motif_video.MotifVideoPipelineOutput`] or `tuple`: If `return_dict` is `True`, returns [`~pipelines.motif_video.MotifVideoPipelineOutput`], otherwise returns a tuple where the first element is a list of generated video frames. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 1. Define call parameters (batch_size needed for check_inputs) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # 2. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, batch_size=batch_size, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, ) self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None # Auto-upgrade AdaptiveProjectedGuidance to VideoAdaptiveProjectedGuidance # for video generation. Video-aware APG normalizes per-frame [C,H,W] instead # of collapsing the temporal axis, preserving motion quality. if type(self.guider) is AdaptiveProjectedGuidance: self.guider = VideoAdaptiveProjectedGuidance( guidance_scale=self.guider.guidance_scale, adaptive_projected_guidance_rescale=self.guider.adaptive_projected_guidance_rescale, adaptive_projected_guidance_momentum=self.guider.adaptive_projected_guidance_momentum, eta=self.guider.eta, use_original_formulation=self.guider.use_original_formulation, ) device = self._execution_device # 3. Prepare text embeddings prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, prompt_attention_mask=prompt_attention_mask, max_sequence_length=max_sequence_length, device=device, ) if self.guider._enabled: negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size) negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=negative_pooled_prompt_embeds, prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, device=device, ) num_channels_latents = self.vae.config.z_dim latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, width, num_frames, self.transformer.dtype, device, generator, latents, ) # 4.5 Preprocess image for I2V conditioning if image is not None: from PIL import Image as PILImage if isinstance(image, PILImage.Image): image = image.convert("RGB").resize((width, height), PILImage.LANCZOS) image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0 image = image * 2.0 - 1.0 # [0,1] -> [-1,1] image = image.unsqueeze(0) # [1, C, H, W] # Handle [C, H, W] -> [1, C, H, W] if image.dim() == 3: image = image.unsqueeze(0) # [B, C, H, W] -> [B, 1, C, H, W] for video format if image.dim() == 4: image = image.unsqueeze(1) image = image.to(device=device, dtype=self.vae.dtype) # 5. Prepare timesteps (including mu calculation) # Recalculate latent dims based on VAE for mu calculation latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 # Calculate sequence length based on *packed* dimensions if transformer uses packing # Packed dims: H/patch, W/patch, F/patch_t packed_latent_height = latent_height // self.transformer_spatial_patch_size packed_latent_width = latent_width // self.transformer_spatial_patch_size packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width # Compute sigmas: use linear-quadratic schedule if enabled, otherwise default linear _is_flow_multistep = isinstance( self.scheduler, (DPMSolverMultistepScheduler, UniPCMultistepScheduler, FlowUniPCMultistepScheduler), ) # Compute mu once, shared by both branches (required by FlowUniPCMultistepScheduler) mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) if _is_flow_multistep: # DPMSolver/UniPC manage their own sigma schedule via use_flow_sigmas + flow_shift. # Pass mu for dynamic shifting support (required by FlowUniPCMultistepScheduler). timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, mu=mu, ) else: if use_linear_quadratic_schedule: # Linear-quadratic schedule computes sigmas internally in retrieve_timesteps sigmas = None else: sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, use_linear_quadratic_schedule=use_linear_quadratic_schedule, linear_quadratic_emulating_steps=linear_quadratic_emulating_steps, mu=mu, ) # Get conditioning tensors latent_condition, latent_mask, image_embeds = self._prepare_first_frame_conditioning( image, latents, use_conditioning=image is not None, generator=generator, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue self._current_timestep = t # Concatenate current latents with conditioning for this timestep # [latents | latent_condition | latent_mask] hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) # Step 1: Collect model inputs needed for the guidance method # conditional inputs should always be first element in the tuple guider_inputs = { "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), } if use_attention_mask: guider_inputs["encoder_attention_mask"] = (prompt_attention_mask, negative_prompt_attention_mask) if self.transformer.config.pooled_projection_dim is not None: guider_inputs["pooled_projections"] = (pooled_prompt_embeds, negative_pooled_prompt_embeds) if image_embeds is not None: guider_inputs["image_embeds"] = (image_embeds, image_embeds) # Step 2: Update guider's internal state for this denoising step self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) # Sigma injection for guiders that support sigma-based gating # (Kynkäänniemi 2024). Must precede `prepare_inputs` because # `num_conditions` → `_is_cfg_enabled()` reads `_current_sigma`. # Duck-typed so diffusers-native guiders are unaffected; guard # on scheduler too since some schedulers don't expose `sigmas`. if hasattr(self.guider, "_current_sigma") and hasattr(self.scheduler, "sigmas"): self.guider._current_sigma = float(self.scheduler.sigmas[i]) # Step 3: Prepare batched model inputs based on the guidance method # The guider splits model inputs into separate batches for conditional/unconditional predictions. # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: # you will get a guider_state with two batches: # guider_state = [ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch # ] # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). guider_state = self.guider.prepare_inputs(guider_inputs) # Step 4: Run the denoiser for each batch # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. for guider_state_batch in guider_state: self.guider.prepare_models(self.transformer) # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) cond_kwargs = { input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() } tread_disabled = getattr(self.guider, "_current_tread_disabled", False) # Override TREAD selection ratio per batch if the guider provides one selection_ratio = getattr(self.guider, "_current_selection_ratio", None) tread_mixin = getattr(self.transformer, "_inference_tread_mixin", None) if ( selection_ratio is not None and tread_mixin is not None and tread_mixin._tread_route is not None ): tread_mixin._tread_route["sel"] = selection_ratio # e.g. "pred_cond"/"pred_uncond" context_name = getattr(guider_state_batch, self.guider._identifier_key) with self.transformer.cache_context(context_name): # Run denoiser and store noise prediction in this batch noise_pred = self.transformer( hidden_states=hidden_states, timestep=timestep, attention_kwargs=self.attention_kwargs, return_dict=False, tread_disabled=tread_disabled, **cond_kwargs, )[0].clone() guider_state_batch.noise_pred = noise_pred # Cleanup model (e.g., remove hooks) self.guider.cleanup_models(self.transformer) # Step 5: Combine predictions using the guidance method # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. # Continuing the CFG example, the guider receives: # guider_state = [ # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 # ] # And extracts predictions using the __guidance_identifier__: # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond # Then applies CFG formula: # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) noise_pred = self.guider(guider_state)[0] # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) # Handle negative embeds if needed by callback if "negative_prompt_embeds" in callback_outputs: negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds") # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() self._current_timestep = None if output_type == "latent": video = latents else: latents = latents.to(self.vae.dtype) latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std) if vae_batch_size is not None and latents.shape[0] > vae_batch_size: video_chunks = [] for i in range(0, latents.shape[0], vae_batch_size): chunk = latents[i : i + vae_batch_size] video_chunks.append(self.vae.decode(chunk, return_dict=False)[0]) video = torch.cat(video_chunks, dim=0) del video_chunks else: video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (video,) # Return updated output type return MotifVideoPipelineOutput(frames=video)