Buckets:
| import logging | |
| from collections.abc import Iterator | |
| import torch | |
| from ltx_core.components.noisers import GaussianNoiser | |
| from ltx_core.loader import LoraPathStrengthAndSDOps | |
| from ltx_core.loader.registry import Registry | |
| from ltx_core.model.transformer.compiling import CompilationConfig | |
| from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number | |
| from ltx_core.quantization import QuantizationPolicy | |
| from ltx_core.types import Audio | |
| from ltx_pipelines.utils.args import ( | |
| ImageConditioningInput, | |
| default_2_stage_distilled_arg_parser, | |
| detect_checkpoint_path, | |
| ) | |
| from ltx_pipelines.utils.blocks import ( | |
| AudioDecoder, | |
| DiffusionStage, | |
| ImageConditioner, | |
| PromptEncoder, | |
| VideoDecoder, | |
| VideoUpsampler, | |
| ) | |
| from ltx_pipelines.utils.constants import ( | |
| DISTILLED_SIGMAS, | |
| STAGE_2_DISTILLED_SIGMAS, | |
| detect_params, | |
| ) | |
| from ltx_pipelines.utils.denoisers import SimpleDenoiser | |
| from ltx_pipelines.utils.helpers import ( | |
| assert_resolution, | |
| combined_image_conditionings, | |
| get_device, | |
| ) | |
| from ltx_pipelines.utils.media_io import encode_video | |
| from ltx_pipelines.utils.types import ModalitySpec, OffloadMode | |
class DistilledPipeline:
    """
    Two-stage distilled video generation pipeline.

    Stage 1 generates video (and audio) at half of the target resolution. Stage 2
    upsamples the stage-1 video latent by 2x and refines it with additional denoising
    steps for higher quality output; the stage-1 audio latent is re-noised and
    refined alongside it.
    """

    def __init__(
        self,
        distilled_checkpoint_path: str,
        gemma_root: str,
        spatial_upsampler_path: str,
        loras: list[LoraPathStrengthAndSDOps] | tuple[LoraPathStrengthAndSDOps, ...],
        device: torch.device | None = None,
        quantization: QuantizationPolicy | None = None,
        registry: Registry | None = None,
        compilation_config: CompilationConfig | None = None,
        offload_mode: OffloadMode = OffloadMode.NONE,
    ):
        """
        Build every pipeline component from a single distilled checkpoint.

        Args:
            distilled_checkpoint_path: Checkpoint shared by the prompt encoder, image
                conditioner, diffusion stage, upsampler and both decoders.
            gemma_root: Root path forwarded to the prompt encoder (presumably the
                Gemma text-encoder directory — confirm against ``PromptEncoder``).
            spatial_upsampler_path: Weights for the upsampler applied to the stage-1
                video latent.
            loras: LoRA specs for the diffusion stage; a list or tuple is accepted
                and converted to a tuple.
            device: Target device; ``get_device()`` is used when omitted.
            quantization: Optional quantization policy for the diffusion stage.
            registry: Optional registry forwarded to every component.
            compilation_config: Optional compilation config for the diffusion stage.
            offload_mode: Offload strategy for the prompt encoder and diffusion stage.
        """
        self.device = device if device is not None else get_device()
        # Single source of truth for the working dtype; __call__ reuses it so
        # conditionings always match the components built here.
        self.dtype = torch.bfloat16
        self.prompt_encoder = PromptEncoder(
            distilled_checkpoint_path,
            gemma_root,
            self.dtype,
            self.device,
            registry=registry,
            offload_mode=offload_mode,
        )
        self.image_conditioner = ImageConditioner(distilled_checkpoint_path, self.dtype, self.device, registry=registry)
        self.stage = DiffusionStage(
            distilled_checkpoint_path,
            self.dtype,
            self.device,
            loras=tuple(loras),
            quantization=quantization,
            registry=registry,
            compilation_config=compilation_config,
            offload_mode=offload_mode,
        )
        self.upsampler = VideoUpsampler(
            distilled_checkpoint_path, spatial_upsampler_path, self.dtype, self.device, registry=registry
        )
        self.video_decoder = VideoDecoder(distilled_checkpoint_path, self.dtype, self.device, registry=registry)
        self.audio_decoder = AudioDecoder(distilled_checkpoint_path, self.dtype, self.device, registry=registry)

    def _image_conditionings(
        self,
        images: list[ImageConditioningInput],
        height: int,
        width: int,
    ):
        """Run ``combined_image_conditionings`` for ``images`` at ``height`` x ``width``.

        The image conditioner supplies the video encoder passed as ``video_encoder``.
        """
        return self.image_conditioner(
            lambda enc: combined_image_conditionings(
                images=images,
                height=height,
                width=width,
                video_encoder=enc,
                dtype=self.dtype,
                device=self.device,
            )
        )

    def __call__(  # noqa: PLR0913
        self,
        prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[ImageConditioningInput],
        tiling_config: TilingConfig | None = None,
        enhance_prompt: bool = False,
        stage_1_sigmas: torch.Tensor = DISTILLED_SIGMAS,
        stage_2_sigmas: torch.Tensor = STAGE_2_DISTILLED_SIGMAS,
    ) -> tuple[Iterator[torch.Tensor], Audio]:
        """
        Generate video and audio for ``prompt``.

        Args:
            prompt: Text prompt.
            seed: Seed for the generator shared by the noiser and the video decoder.
            height: Target output height, validated by ``assert_resolution`` for the
                two-stage setup. Stage 1 runs at ``height // 2``.
            width: Target output width. Stage 1 runs at ``width // 2``.
            num_frames: Number of frames to generate.
            frame_rate: Frame rate passed to both diffusion stages.
            images: Image conditionings applied in both stages; may be empty. The
                first entry's first field is also given to the prompt enhancer.
            tiling_config: Optional tiling config forwarded to the video decoder.
            enhance_prompt: Whether the prompt encoder enhances the prompt.
            stage_1_sigmas: Sigma schedule for Stage 1 (cast to float32 on device).
            stage_2_sigmas: Sigma schedule for Stage 2 (cast to float32 on device).
                Its first value is used as ``noise_scale`` for both modalities.

        Returns:
            ``(decoded_video, decoded_audio)`` from the video and audio decoders.
        """
        assert_resolution(height=height, width=width, is_two_stage=True)
        # One generator is shared by the noiser (both stages) and the video decoder.
        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)

        (ctx_p,) = self.prompt_encoder(
            [prompt],
            enhance_first_prompt=enhance_prompt,
            # NOTE(review): assumes field 0 of ImageConditioningInput is the image.
            enhance_prompt_image=images[0][0] if images else None,
        )
        video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding

        # Stage 1: initial generation at half of the target resolution.
        stage_1_sigmas = stage_1_sigmas.to(dtype=torch.float32, device=self.device)
        stage_1_w, stage_1_h = width // 2, height // 2
        stage_1_conditionings = self._image_conditionings(images, height=stage_1_h, width=stage_1_w)
        video_state, audio_state = self.stage(
            denoiser=SimpleDenoiser(video_context, audio_context),
            sigmas=stage_1_sigmas,
            noiser=noiser,
            width=stage_1_w,
            height=stage_1_h,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(context=video_context, conditionings=stage_1_conditionings),
            audio=ModalitySpec(context=audio_context),
        )

        # Stage 2: upsample the stage-1 video latent and refine it at the target
        # resolution, using the same diffusion stage as Stage 1.
        # Only the first batch entry of the stage-1 latent is upsampled ([:1]).
        upscaled_video_latent = self.upsampler(video_state.latent[:1])
        stage_2_sigmas = stage_2_sigmas.to(dtype=torch.float32, device=self.device)
        stage_2_conditionings = self._image_conditionings(images, height=height, width=width)
        # Both modalities start Stage 2 from the first stage-2 sigma.
        stage_2_noise_scale = stage_2_sigmas[0].item()
        video_state, audio_state = self.stage(
            denoiser=SimpleDenoiser(video_context, audio_context),
            sigmas=stage_2_sigmas,
            noiser=noiser,
            width=width,
            height=height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(
                context=video_context,
                conditionings=stage_2_conditionings,
                noise_scale=stage_2_noise_scale,
                initial_latent=upscaled_video_latent,
            ),
            audio=ModalitySpec(
                context=audio_context,
                noise_scale=stage_2_noise_scale,
                initial_latent=audio_state.latent,
            ),
        )

        decoded_video = self.video_decoder(video_state.latent, tiling_config, generator)
        decoded_audio = self.audio_decoder(audio_state.latent)
        return decoded_video, decoded_audio
def main() -> None:
    """Command-line entry point for the two-stage distilled pipeline.

    Parses arguments, builds a ``DistilledPipeline``, generates video and audio,
    and encodes the result to ``args.output_path``.
    """
    logging.basicConfig(level=logging.INFO)

    # The argument parser is parameterized by values detected from the
    # distilled checkpoint found on this machine.
    parser = default_2_stage_distilled_arg_parser(
        params=detect_params(detect_checkpoint_path(distilled=True))
    )
    args = parser.parse_args()

    if args.lora:
        lora_specs = tuple(args.lora)
    else:
        lora_specs = ()

    pipeline = DistilledPipeline(
        distilled_checkpoint_path=args.distilled_checkpoint_path,
        spatial_upsampler_path=args.spatial_upsampler_path,
        gemma_root=args.gemma_root,
        loras=lora_specs,
        quantization=args.quantization,
        compilation_config=args.compile,
        offload_mode=args.offload_mode,
    )

    # The same tiling config drives decoding and the chunk count for encoding.
    tiling = TilingConfig.default()
    chunk_count = get_video_chunks_number(args.num_frames, tiling)

    video, audio = pipeline(
        prompt=args.prompt,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        frame_rate=args.frame_rate,
        images=args.images,
        tiling_config=tiling,
        enhance_prompt=args.enhance_prompt,
    )

    encode_video(
        video=video,
        fps=args.frame_rate,
        audio=audio,
        output_path=args.output_path,
        video_chunks_number=chunk_count,
    )


if __name__ == "__main__":
    main()
Xet Storage Details
- Size: 7.84 kB
- Xet hash: 7376ba9f287116f069c30d4bfa1d5272ada2303ef0700736f66a8a65a6f4f371
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.