| import logging |
| from collections.abc import Iterator |
|
|
| import torch |
|
|
| from ltx_core.components.diffusion_steps import EulerDiffusionStep |
| from ltx_core.components.noisers import GaussianNoiser |
| from ltx_core.components.protocols import DiffusionStepProtocol |
| from ltx_core.loader import LoraPathStrengthAndSDOps |
| from ltx_core.model.audio_vae import decode_audio as vae_decode_audio |
| from ltx_core.model.upsampler import upsample_video |
| from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number |
| from ltx_core.model.video_vae import decode_video as vae_decode_video |
| from ltx_core.quantization import QuantizationPolicy |
| from ltx_core.types import Audio, LatentState, VideoPixelShape |
| from ltx_pipelines.utils import ModelLedger, euler_denoising_loop |
| from ltx_pipelines.utils.args import ( |
| ImageConditioningInput, |
| default_2_stage_distilled_arg_parser, |
| detect_checkpoint_path, |
| ) |
| from ltx_pipelines.utils.constants import ( |
| DISTILLED_SIGMA_VALUES, |
| STAGE_2_DISTILLED_SIGMA_VALUES, |
| detect_params, |
| ) |
| from ltx_pipelines.utils.helpers import ( |
| assert_resolution, |
| cleanup_memory, |
| combined_image_conditionings, |
| denoise_audio_video, |
| encode_prompts, |
| get_device, |
| simple_denoising_func, |
| ) |
| from ltx_pipelines.utils.media_io import encode_video |
| from ltx_pipelines.utils.types import PipelineComponents |
|
|
# Module-level device, resolved once at import time via get_device().
# It is used below as the default value of DistilledPipeline.__init__'s `device` argument.
device = get_device()
|
|
|
|
class DistilledPipeline:
    """
    Two-stage distilled video generation pipeline.
    Stage 1 generates video at half of the target resolution, then Stage 2 upsamples
    by 2x and refines with additional denoising steps for higher quality output.

    All models (video encoder, transformer, spatial upsampler, decoders, vocoder)
    are requested from the ``ModelLedger`` created in ``__init__``. Calling the
    instance runs both stages and returns the decoded video and audio.
    """

    def __init__(
        self,
        distilled_checkpoint_path: str,
        gemma_root: str,
        spatial_upsampler_path: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device = device,
        quantization: QuantizationPolicy | None = None,
    ):
        """Create the model ledger and the shared pipeline components.

        Args:
            distilled_checkpoint_path: Forwarded to ``ModelLedger`` as ``checkpoint_path``.
            gemma_root: Forwarded to ``ModelLedger`` as ``gemma_root_path``.
            spatial_upsampler_path: Forwarded to ``ModelLedger`` unchanged.
            loras: LoRA specifications forwarded to ``ModelLedger``.
            device: Device used for the ledger, the pipeline components, the random
                generator and the sigma schedules. Defaults to the module-level ``device``.
            quantization: Optional quantization policy forwarded to ``ModelLedger``.
        """
        self.device = device
        # Single source of truth for the working dtype; __call__ reads it from here.
        self.dtype = torch.bfloat16

        self.model_ledger = ModelLedger(
            dtype=self.dtype,
            device=device,
            checkpoint_path=distilled_checkpoint_path,
            spatial_upsampler_path=spatial_upsampler_path,
            gemma_root_path=gemma_root,
            loras=loras,
            quantization=quantization,
        )

        self.pipeline_components = PipelineComponents(
            dtype=self.dtype,
            device=device,
        )

    def _synchronize_device(self) -> None:
        """Wait for queued CUDA work on the pipeline device; do nothing on non-CUDA devices."""
        target = torch.device(self.device)
        # torch.cuda.synchronize() raises when CUDA is unavailable, so only call it
        # when the pipeline is actually running on a CUDA device.
        if target.type == "cuda":
            torch.cuda.synchronize(target)

    def __call__(
        self,
        prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[ImageConditioningInput],
        tiling_config: TilingConfig | None = None,
        enhance_prompt: bool = False,
    ) -> tuple[Iterator[torch.Tensor], Audio]:
        """Generate video and audio for ``prompt`` using two denoising stages.

        Args:
            prompt: Text prompt passed to ``encode_prompts``.
            seed: Seed for the ``torch.Generator`` that drives the noiser; the same
                generator is also handed to the video decoder.
            height: Target output height. Validated together with ``width`` by
                ``assert_resolution(..., is_two_stage=True)``; stage 1 uses ``height // 2``.
            width: Target output width; stage 1 uses ``width // 2``.
            num_frames: Frame count used for both stage shapes.
            frame_rate: Frames per second stored in both stage shapes.
            images: Image conditioning inputs, applied at both stages. When non-empty,
                ``images[0][0]`` is also passed to prompt enhancement.
            tiling_config: Optional tiling configuration forwarded to the video decoder.
            enhance_prompt: Forwarded to ``encode_prompts`` as ``enhance_first_prompt``.

        Returns:
            A ``(decoded_video, decoded_audio)`` tuple: the results of
            ``vae_decode_video`` and ``vae_decode_audio`` on the stage-2 latents.
        """
        assert_resolution(height=height, width=width, is_two_stage=True)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        stepper = EulerDiffusionStep()
        # Reuse the dtype configured in __init__ so the two can never disagree.
        dtype = self.dtype

        (ctx_p,) = encode_prompts(
            [prompt],
            self.model_ledger,
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_image=images[0][0] if images else None,
        )
        video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding

        # Stage 1: denoise at half of the target resolution.
        video_encoder = self.model_ledger.video_encoder()
        transformer = self.model_ledger.transformer()
        stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)

        def denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            """Run the Euler denoising loop with this call's prompt contexts and transformer."""
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=simple_denoising_func(
                    video_context=video_context,
                    audio_context=audio_context,
                    transformer=transformer,
                ),
            )

        stage_1_output_shape = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        stage_1_conditionings = combined_image_conditionings(
            images=images,
            height=stage_1_output_shape.height,
            width=stage_1_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )

        video_state, audio_state = denoise_audio_video(
            output_shape=stage_1_output_shape,
            conditionings=stage_1_conditionings,
            noiser=noiser,
            sigmas=stage_1_sigmas,
            stepper=stepper,
            denoising_loop_fn=denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
        )

        # Stage 2: upsample the stage-1 video latent, then refine at full resolution.
        # Only the first entry along the leading dimension of the latent is upsampled.
        upscaled_video_latent = upsample_video(
            latent=video_state.latent[:1], video_encoder=video_encoder, upsampler=self.model_ledger.spatial_upsampler()
        )

        self._synchronize_device()
        cleanup_memory()

        stage_2_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
        stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        stage_2_conditionings = combined_image_conditionings(
            images=images,
            height=stage_2_output_shape.height,
            width=stage_2_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )
        # Stage 2 starts from the upsampled video latent and the stage-1 audio latent.
        # The first stage-2 sigma is passed as noise_scale (presumably the level to
        # which those initial latents are re-noised -- confirm in denoise_audio_video).
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_2_output_shape,
            conditionings=stage_2_conditionings,
            noiser=noiser,
            sigmas=stage_2_sigmas,
            stepper=stepper,
            denoising_loop_fn=denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
            noise_scale=stage_2_sigmas[0],
            initial_video_latent=upscaled_video_latent,
            initial_audio_latent=audio_state.latent,
        )

        self._synchronize_device()
        # Drop the local references to the denoising models before cleanup_memory(),
        # so they are no longer held by this frame or the closure while decoding.
        del transformer
        del video_encoder
        cleanup_memory()

        decoded_video = vae_decode_video(
            video_state.latent, self.model_ledger.video_decoder(), tiling_config, generator
        )
        decoded_audio = vae_decode_audio(
            audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()
        )
        return decoded_video, decoded_audio
|
|
|
|
@torch.inference_mode()
def main() -> None:
    """Command-line entry point: parse arguments, run the distilled pipeline, write the result.

    The checkpoint path is auto-detected first so that its parameters can be given
    to the argument parser. The parsed arguments then configure the pipeline, the
    generation call and the final video encoding step.
    """
    logging.getLogger().setLevel(logging.INFO)

    # Checkpoint detection -> params -> parser -> parsed CLI arguments.
    cli = default_2_stage_distilled_arg_parser(
        params=detect_params(detect_checkpoint_path(distilled=True))
    ).parse_args()

    generate = DistilledPipeline(
        distilled_checkpoint_path=cli.distilled_checkpoint_path,
        spatial_upsampler_path=cli.spatial_upsampler_path,
        gemma_root=cli.gemma_root,
        # A missing or empty --lora value becomes an empty tuple.
        loras=tuple(cli.lora or ()),
        quantization=cli.quantization,
    )

    # The same tiling configuration drives both decoding and the chunk count
    # that encode_video receives.
    tiling = TilingConfig.default()
    chunk_count = get_video_chunks_number(cli.num_frames, tiling)

    frames, soundtrack = generate(
        prompt=cli.prompt,
        seed=cli.seed,
        height=cli.height,
        width=cli.width,
        num_frames=cli.num_frames,
        frame_rate=cli.frame_rate,
        images=cli.images,
        tiling_config=tiling,
        enhance_prompt=cli.enhance_prompt,
    )

    encode_video(
        video=frames,
        fps=cli.frame_rate,
        audio=soundtrack,
        output_path=cli.output_path,
        video_chunks_number=chunk_count,
    )
|
|
|
|
# Run the CLI entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|