import logging
from collections.abc import Iterator

import torch

from ltx_core.components.noisers import GaussianNoiser
from ltx_core.conditioning import ConditioningItem
from ltx_core.loader import LoraPathStrengthAndSDOps
from ltx_core.loader.registry import Registry
from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunks_number
from ltx_core.quantization import QuantizationPolicy
from ltx_core.types import Audio, VideoPixelShape
from ltx_pipelines.iclora_utils import (
    append_ic_lora_reference_video_conditionings,
    read_lora_reference_downscale_factor,
)
from ltx_pipelines.utils.args import (
    ImageConditioningInput,
    VideoConditioningAction,
    VideoMaskConditioningAction,
    default_2_stage_distilled_arg_parser,
    detect_checkpoint_path,
)
from ltx_pipelines.utils.blocks import (
    AudioDecoder,
    DiffusionStage,
    ImageConditioner,
    PromptEncoder,
    VideoDecoder,
    VideoUpsampler,
)
from ltx_pipelines.utils.constants import (
    DISTILLED_SIGMAS,
    STAGE_2_DISTILLED_SIGMAS,
    detect_params,
)
from ltx_pipelines.utils.denoisers import SimpleDenoiser
from ltx_pipelines.utils.helpers import assert_resolution, combined_image_conditionings, get_device
from ltx_pipelines.utils.media_io import decode_video_by_frame, encode_video, video_preprocess
from ltx_pipelines.utils.types import ModalitySpec, OffloadMode


class ICLoraPipeline:
    """
    Two-stage video generation pipeline with In-Context (IC) LoRA support.

    Allows conditioning the generated video on control signals such as depth maps,
    human pose, or image edges via the video_conditioning parameter.
    The specific IC-LoRA model should be provided via the loras parameter.

    Stage 1 generates video at half of the target resolution, then Stage 2 upsamples
    by 2x and refines with additional denoising steps for higher quality output.
    Both stages use distilled models for efficiency. The IC-LoRAs are loaded into the
    Stage 1 model only; the Stage 2 model is built without any LoRAs.
    """

    def __init__(
        self,
        distilled_checkpoint_path: str,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device | None = None,
        quantization: QuantizationPolicy | None = None,
        registry: Registry | None = None,
        torch_compile: bool = False,
        offload_mode: OffloadMode = OffloadMode.NONE,
    ):
        """
        Build all pipeline components from a distilled checkpoint.

        Args:
            distilled_checkpoint_path: Path to the distilled model checkpoint shared by
                every component (prompt encoder, diffusion stages, VAE, decoders).
            spatial_upsampler_path: Path to the spatial upsampler weights used between stages.
            gemma_root: Root directory of the Gemma text encoder.
            loras: IC-LoRAs to load into the Stage 1 model. Any iterable is accepted;
                it is materialized into a tuple once.
            device: Target device; defaults to ``get_device()``.
            quantization: Optional quantization policy for the diffusion stages.
            registry: Optional component registry forwarded to every block.
            torch_compile: Whether the diffusion stages should use torch.compile.
            offload_mode: Offload strategy for the prompt encoder and diffusion stages.

        Raises:
            ValueError: If two LoRAs declare different non-1 reference downscale factors.
        """
        # Materialize once: ``loras`` is consumed twice below (Stage 1 model and the
        # metadata scan), so a one-shot iterable would otherwise leave the second pass empty.
        loras = tuple(loras)

        self.device = device or get_device()
        self.dtype = torch.bfloat16

        self.prompt_encoder = PromptEncoder(
            distilled_checkpoint_path,
            gemma_root,
            self.dtype,
            self.device,
            registry=registry,
            offload_mode=offload_mode,
        )
        self.image_conditioner = ImageConditioner(distilled_checkpoint_path, self.dtype, self.device, registry=registry)
        # Stage 1 carries the IC-LoRA weights; Stage 2 runs the plain distilled model.
        self.stage_1 = DiffusionStage(
            distilled_checkpoint_path,
            self.dtype,
            self.device,
            loras=loras,
            quantization=quantization,
            registry=registry,
            torch_compile=torch_compile,
            offload_mode=offload_mode,
        )
        self.stage_2 = DiffusionStage(
            distilled_checkpoint_path,
            self.dtype,
            self.device,
            loras=(),
            quantization=quantization,
            registry=registry,
            torch_compile=torch_compile,
            offload_mode=offload_mode,
        )
        self.upsampler = VideoUpsampler(
            distilled_checkpoint_path, spatial_upsampler_path, self.dtype, self.device, registry=registry
        )
        self.video_decoder = VideoDecoder(distilled_checkpoint_path, self.dtype, self.device, registry=registry)
        self.audio_decoder = AudioDecoder(distilled_checkpoint_path, self.dtype, self.device, registry=registry)

        # Read reference downscale factor from LoRA metadata.
        # IC-LoRAs trained with low-resolution reference videos store this factor
        # so inference can resize reference videos to match training conditions.
        # A factor of 1 means "no preference"; all non-1 factors must agree.
        self.reference_downscale_factor = 1
        for lora in loras:
            scale = read_lora_reference_downscale_factor(lora.path)
            if scale != 1:
                if self.reference_downscale_factor not in (1, scale):
                    raise ValueError(
                        f"Conflicting reference_downscale_factor values in LoRAs: "
                        f"already have {self.reference_downscale_factor}, but {lora.path} "
                        f"specifies {scale}. Cannot combine LoRAs with different reference scales."
                    )
                self.reference_downscale_factor = scale

    def __call__(  # noqa: PLR0913
        self,
        prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[ImageConditioningInput],
        video_conditioning: list[tuple[str, float]],
        enhance_prompt: bool = False,
        tiling_config: TilingConfig | None = None,
        conditioning_attention_strength: float = 1.0,
        skip_stage_2: bool = False,
        conditioning_attention_mask: torch.Tensor | None = None,
        stage_1_sigmas: torch.Tensor = DISTILLED_SIGMAS,
        stage_2_sigmas: torch.Tensor = STAGE_2_DISTILLED_SIGMAS,
    ) -> tuple[Iterator[torch.Tensor], Audio]:
        """
        Generate video with IC-LoRA conditioning.

        Args:
            prompt: Text prompt for video generation.
            seed: Random seed for reproducibility.
            height: Output video height in pixels (must be divisible by 64).
            width: Output video width in pixels (must be divisible by 64).
            num_frames: Number of frames to generate.
            frame_rate: Output video frame rate.
            images: List of (path, frame_idx, strength) tuples for image conditioning.
            video_conditioning: List of (path, strength) tuples for IC-LoRA video conditioning.
            enhance_prompt: Whether to enhance the prompt using the text encoder.
            tiling_config: Optional tiling configuration for VAE decoding.
            conditioning_attention_strength: Scale factor for IC-LoRA conditioning attention.
                Controls how strongly the conditioning video influences the output.
                0.0 = ignore conditioning, 1.0 = full conditioning influence. Default 1.0.
                When conditioning_attention_mask is provided, the mask is multiplied by
                this strength before being passed to the conditioning items.
            skip_stage_2: If True, skip Stage 2 upsampling and refinement. Output will be
                at half resolution (height//2, width//2). Default is False.
            conditioning_attention_mask: Optional pixel-space attention mask with the same
                spatial-temporal dimensions as the input reference video. Shape should be
                (B, 1, F, H, W) or (1, 1, F, H, W) where F, H, W match the reference video's
                pixel dimensions. Values in [0, 1]. The mask is downsampled to latent space
                using VAE scale factors (with causal temporal handling for the first frame),
                then multiplied by conditioning_attention_strength.
                When None (default): scalar conditioning_attention_strength is used directly.
            stage_1_sigmas: Noise schedule for Stage 1 (moved to float32 on self.device).
            stage_2_sigmas: Noise schedule for Stage 2 (moved to float32 on self.device);
                its first value is also used as the Stage 2 noise scale.

        Returns:
            Tuple of (video_iterator, audio_tensor).

        Raises:
            ValueError: If conditioning_attention_strength is outside [0.0, 1.0].
        """
        assert_resolution(height=height, width=width, is_two_stage=True)
        if not (0.0 <= conditioning_attention_strength <= 1.0):
            raise ValueError(
                f"conditioning_attention_strength must be in [0.0, 1.0], got {conditioning_attention_strength}"
            )

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)

        (ctx_p,) = self.prompt_encoder(
            [prompt],
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_image=images[0][0] if images else None,
            enhance_prompt_seed=seed,
        )
        video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding

        # Stage 1: Initial low resolution video generation.
        stage_1_output_shape = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        # Encode conditionings using the video encoder block. The IC-LoRA reference videos
        # are attached only here, since Stage 1 is the only stage built with the LoRAs.
        stage_1_conditionings = self.image_conditioner(
            lambda enc: self._create_conditionings(
                images=images,
                video_conditioning=video_conditioning,
                height=stage_1_output_shape.height,
                width=stage_1_output_shape.width,
                video_encoder=enc,
                num_frames=num_frames,
                conditioning_attention_strength=conditioning_attention_strength,
                conditioning_attention_mask=conditioning_attention_mask,
            )
        )
        stage_1_sigmas = stage_1_sigmas.to(dtype=torch.float32, device=self.device)
        video_state, audio_state = self.stage_1(
            denoiser=SimpleDenoiser(video_context, audio_context),
            sigmas=stage_1_sigmas,
            noiser=noiser,
            width=stage_1_output_shape.width,
            height=stage_1_output_shape.height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(
                context=video_context,
                conditionings=stage_1_conditionings,
            ),
            audio=ModalitySpec(
                context=audio_context,
            ),
        )

        if skip_stage_2:
            # Skip Stage 2: Decode directly from Stage 1 output at half resolution
            logging.info("[IC-LoRA] Skipping Stage 2 (--skip-stage-2 enabled)")
            decoded_video = self.video_decoder(video_state.latent, tiling_config, generator)
            decoded_audio = self.audio_decoder(audio_state.latent)
            return decoded_video, decoded_audio

        # Stage 2: Upsample the Stage 1 latent and refine it at full resolution. This stage
        # runs the distilled model without the IC-LoRAs (built with loras=()), so only the
        # image conditionings are re-encoded at the full resolution.
        # Only the first batch element of the Stage 1 latent is upsampled.
        upscaled_video_latent = self.upsampler(video_state.latent[:1])
        stage_2_sigmas = stage_2_sigmas.to(dtype=torch.float32, device=self.device)
        stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        stage_2_conditionings = self.image_conditioner(
            lambda enc: combined_image_conditionings(
                images=images,
                height=stage_2_output_shape.height,
                width=stage_2_output_shape.width,
                video_encoder=enc,
                dtype=self.dtype,
                device=self.device,
            )
        )
        # Refinement starts from the upsampled video latent and the Stage 1 audio latent,
        # re-noised to the first Stage 2 sigma.
        video_state, audio_state = self.stage_2(
            denoiser=SimpleDenoiser(video_context, audio_context),
            sigmas=stage_2_sigmas,
            noiser=noiser,
            width=width,
            height=height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(
                context=video_context,
                conditionings=stage_2_conditionings,
                noise_scale=stage_2_sigmas[0].item(),
                initial_latent=upscaled_video_latent,
            ),
            audio=ModalitySpec(
                context=audio_context,
                noise_scale=stage_2_sigmas[0].item(),
                initial_latent=audio_state.latent,
            ),
        )

        decoded_video = self.video_decoder(video_state.latent, tiling_config, generator)
        decoded_audio = self.audio_decoder(audio_state.latent)
        return decoded_video, decoded_audio

    def _create_conditionings(
        self,
        images: list[ImageConditioningInput],
        video_conditioning: list[tuple[str, float]],
        height: int,
        width: int,
        num_frames: int,
        video_encoder: VideoEncoder,
        conditioning_attention_strength: float = 1.0,
        conditioning_attention_mask: torch.Tensor | None = None,
    ) -> list[ConditioningItem]:
        """
        Create conditioning items for video generation.

        Args:
            images: Image conditionings, encoded at (height, width).
            video_conditioning: (path, strength) pairs for the IC-LoRA reference videos.
            height: Pixel height the conditionings are prepared for.
            width: Pixel width the conditionings are prepared for.
            num_frames: Number of frames of the generated video.
            video_encoder: VAE encoder supplied by the image conditioner block.
            conditioning_attention_strength: Scalar attention weight in [0, 1].
                If conditioning_attention_mask is also provided, the downsampled mask is
                multiplied by this strength. Otherwise this scalar is passed directly as
                the attention mask.
            conditioning_attention_mask: Optional pixel-space attention mask with shape
                (B, 1, F_pixel, H_pixel, W_pixel) matching the reference video's pixel
                dimensions. Downsampled to latent space with causal temporal handling,
                then multiplied by conditioning_attention_strength.

        Returns:
            List of conditioning items. IC-LoRA conditionings are appended last.
        """
        conditionings = combined_image_conditionings(
            images=images,
            height=height,
            width=width,
            video_encoder=video_encoder,
            dtype=self.dtype,
            device=self.device,
        )

        # Appends the IC-LoRA reference-video items to ``conditionings`` in place.
        append_ic_lora_reference_video_conditionings(
            conditionings,
            video_conditioning,
            height=height,
            width=width,
            num_frames=num_frames,
            video_encoder=video_encoder,
            dtype=self.dtype,
            device=self.device,
            reference_downscale_factor=self.reference_downscale_factor,
            conditioning_attention_strength=conditioning_attention_strength,
            conditioning_attention_mask=conditioning_attention_mask,
            tiling_config=None,
        )

        if video_conditioning:
            logging.info("[IC-LoRA] Added %d video conditioning(s)", len(video_conditioning))

        return conditionings


@torch.inference_mode()
def main() -> None:
    """CLI entry point: parse arguments, run the IC-LoRA pipeline and write the output video."""
    logging.getLogger().setLevel(logging.INFO)
    checkpoint_path = detect_checkpoint_path(distilled=True)
    params = detect_params(checkpoint_path)
    parser = default_2_stage_distilled_arg_parser(params=params)
    parser.add_argument(
        "--video-conditioning",
        action=VideoConditioningAction,
        nargs=2,
        metavar=("PATH", "STRENGTH"),
        required=True,
    )
    parser.add_argument(
        "--conditioning-attention-mask",
        action=VideoMaskConditioningAction,
        nargs=2,
        metavar=("MASK_PATH", "STRENGTH"),
        default=None,
        help=(
            "Optional spatial attention mask: path to a grayscale mask video and "
            "attention strength. The mask video pixel values in [0,1] control "
            "per-region conditioning attention strength. The strength scalar is "
            "multiplied with the spatial mask. "
            "0.0 = ignore IC-LoRA conditioning, 1.0 = full conditioning influence. "
            "When not provided, full conditioning strength (1.0) is used. "
            "Example: --conditioning-attention-mask path/to/mask.mp4 0.5"
        ),
    )
    parser.add_argument(
        "--skip-stage-2",
        action="store_true",
        help=(
            "Skip Stage 2 upsampling and refinement. Output will be at half resolution "
            "(height//2, width//2). Useful for faster iteration or when GPU memory is limited."
        ),
    )
    args = parser.parse_args()

    # Load mask video if provided via --conditioning-attention-mask
    conditioning_attention_mask = None
    conditioning_attention_strength = 1.0
    if args.conditioning_attention_mask is not None:
        mask_path, mask_strength = args.conditioning_attention_mask
        conditioning_attention_strength = mask_strength
        # NOTE(review): the mask is sized to the Stage 1 resolution; confirm this still
        # matches the reference video's pixel size when the LoRA declares a
        # reference_downscale_factor other than 1.
        conditioning_attention_mask = _load_mask_video(
            mask_path=mask_path,
            height=args.height // 2,  # Stage 1 operates at half resolution
            width=args.width // 2,
            num_frames=args.num_frames,
        )

    pipeline = ICLoraPipeline(
        distilled_checkpoint_path=args.distilled_checkpoint_path,
        spatial_upsampler_path=args.spatial_upsampler_path,
        gemma_root=args.gemma_root,
        loras=tuple(args.lora) if args.lora else (),
        quantization=args.quantization,
        torch_compile=args.compile,
        offload_mode=args.offload_mode,
    )
    tiling_config = TilingConfig.default()
    video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
    video, audio = pipeline(
        prompt=args.prompt,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        frame_rate=args.frame_rate,
        images=args.images,
        video_conditioning=args.video_conditioning,
        tiling_config=tiling_config,
        conditioning_attention_strength=conditioning_attention_strength,
        skip_stage_2=args.skip_stage_2,
        conditioning_attention_mask=conditioning_attention_mask,
    )

    encode_video(
        video=video,
        fps=args.frame_rate,
        audio=audio,
        output_path=args.output_path,
        video_chunks_number=video_chunks_number,
    )


def _load_mask_video(
    mask_path: str,
    height: int,
    width: int,
    num_frames: int,
) -> torch.Tensor:
    """Load a mask video and return a pixel-space tensor of shape (1, 1, F, H, W).

    The mask video is loaded, resized to (height, width), converted to
    grayscale, and normalised to [0, 1].

    Args:
        mask_path: Path to the mask video file.
        height: Target height in pixels.
        width: Target width in pixels.
        num_frames: Maximum number of frames to load.

    Returns:
        Tensor of shape ``(1, 1, F, H, W)`` with values in ``[0, 1]``.
    """
    device = get_device()
    frame_gen = decode_video_by_frame(path=mask_path, frame_cap=num_frames, device=device)
    mask_video = video_preprocess(frame_gen, height, width, torch.bfloat16, device)
    # mask_video shape: (1, C, F, H, W) — take mean over channels for grayscale
    mask = mask_video.mean(dim=1, keepdim=True)  # (1, 1, F, H, W)
    # Normalise to [0, 1] — video_preprocess applies normalize_latent,
    # so undo that: values are in [-1, 1], remap to [0, 1]
    mask = (mask + 1.0) / 2.0
    return mask.clamp(0.0, 1.0)


if __name__ == "__main__":
    main()