# Provenance: uploaded by "jiuhai" with the upload-large-folder tool
# (commit 4e1e978, verified); file size 8.21 kB.
import logging
from collections.abc import Iterator
import torch
from ltx_core.components.diffusion_steps import EulerDiffusionStep
from ltx_core.components.noisers import GaussianNoiser
from ltx_core.components.protocols import DiffusionStepProtocol
from ltx_core.loader import LoraPathStrengthAndSDOps
from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
from ltx_core.model.upsampler import upsample_video
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_core.model.video_vae import decode_video as vae_decode_video
from ltx_core.quantization import QuantizationPolicy
from ltx_core.types import Audio, LatentState, VideoPixelShape
from ltx_pipelines.utils import ModelLedger, euler_denoising_loop
from ltx_pipelines.utils.args import (
ImageConditioningInput,
default_2_stage_distilled_arg_parser,
detect_checkpoint_path,
)
from ltx_pipelines.utils.constants import (
DISTILLED_SIGMA_VALUES,
STAGE_2_DISTILLED_SIGMA_VALUES,
detect_params,
)
from ltx_pipelines.utils.helpers import (
assert_resolution,
cleanup_memory,
combined_image_conditionings,
denoise_audio_video,
encode_prompts,
get_device,
simple_denoising_func,
)
from ltx_pipelines.utils.media_io import encode_video
from ltx_pipelines.utils.types import PipelineComponents
# Resolved once at import time; used as the default `device` argument of DistilledPipeline.__init__.
device = get_device()
class DistilledPipeline:
    """
    Two-stage distilled video generation pipeline.

    Stage 1 generates video at half of the target resolution, then Stage 2 upsamples
    by 2x and refines with additional denoising steps for higher quality output.
    """

    def __init__(
        self,
        distilled_checkpoint_path: str,
        gemma_root: str,
        spatial_upsampler_path: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device = device,
        quantization: QuantizationPolicy | None = None,
    ):
        """Build the model ledger and the shared pipeline components.

        Args:
            distilled_checkpoint_path: Forwarded to ``ModelLedger`` as ``checkpoint_path``.
            gemma_root: Forwarded to ``ModelLedger`` as ``gemma_root_path``.
            spatial_upsampler_path: Forwarded to ``ModelLedger`` as ``spatial_upsampler_path``.
            loras: LoRA entries forwarded to ``ModelLedger``.
            device: Device used for the ledger, the components, the random generator and
                the sigma schedules. Defaults to the module-level ``device``.
            quantization: Optional quantization policy forwarded to ``ModelLedger``.
        """
        self.device = device
        # Single source of truth for the working precision; __call__ reads it back.
        self.dtype = torch.bfloat16
        self.model_ledger = ModelLedger(
            dtype=self.dtype,
            device=device,
            checkpoint_path=distilled_checkpoint_path,
            spatial_upsampler_path=spatial_upsampler_path,
            gemma_root_path=gemma_root,
            loras=loras,
            quantization=quantization,
        )
        self.pipeline_components = PipelineComponents(
            dtype=self.dtype,
            device=device,
        )

    @staticmethod
    def _synchronize_cuda() -> None:
        """Wait for pending CUDA work; do nothing when CUDA is not available.

        ``torch.cuda.synchronize()`` raises on hosts without CUDA, so the call is
        guarded to keep the pipeline usable on whatever device ``get_device()`` picked.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def __call__(
        self,
        prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[ImageConditioningInput],
        tiling_config: TilingConfig | None = None,
        enhance_prompt: bool = False,
    ) -> tuple[Iterator[torch.Tensor], Audio]:
        """Generate video and audio for ``prompt`` in two denoising stages.

        Args:
            prompt: Text prompt passed to ``encode_prompts``.
            seed: Seed for the ``torch.Generator`` that drives the noiser and video decoding.
            height: Target output height; stage 1 runs at ``height // 2``.
            width: Target output width; stage 1 runs at ``width // 2``.
            num_frames: Number of frames, used for both stages.
            frame_rate: Frames per second recorded in both stage output shapes.
            images: Image conditioning inputs applied in both stages. When non-empty,
                ``images[0][0]`` is also handed to prompt enhancement.
            tiling_config: Optional tiling configuration forwarded to the video decoder call.
            enhance_prompt: Forwarded to ``encode_prompts`` as ``enhance_first_prompt``.

        Returns:
            ``(decoded_video, decoded_audio)`` as produced by ``vae_decode_video`` and
            ``vae_decode_audio``.
        """
        # Resolution check for the two-stage layout (presumably raises on invalid sizes).
        assert_resolution(height=height, width=width, is_two_stage=True)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        stepper = EulerDiffusionStep()
        # Use the precision chosen in __init__ rather than hard-coding it a second time.
        dtype = self.dtype

        (ctx_p,) = encode_prompts(
            [prompt],
            self.model_ledger,
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_image=images[0][0] if images else None,
        )
        video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding

        # Stage 1: Initial low resolution video generation.
        video_encoder = self.model_ledger.video_encoder()
        transformer = self.model_ledger.transformer()
        stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)

        def denoising_loop(
            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
        ) -> tuple[LatentState, LatentState]:
            # Shared by both stages; closes over the prompt contexts and the transformer.
            return euler_denoising_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=stepper,
                denoise_fn=simple_denoising_func(
                    video_context=video_context,
                    audio_context=audio_context,
                    transformer=transformer,  # noqa: F821
                ),
            )

        stage_1_output_shape = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )
        stage_1_conditionings = combined_image_conditionings(
            images=images,
            height=stage_1_output_shape.height,
            width=stage_1_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_1_output_shape,
            conditionings=stage_1_conditionings,
            noiser=noiser,
            sigmas=stage_1_sigmas,
            stepper=stepper,
            denoising_loop_fn=denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
        )

        # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
        upscaled_video_latent = upsample_video(
            latent=video_state.latent[:1], video_encoder=video_encoder, upsampler=self.model_ledger.spatial_upsampler()
        )
        self._synchronize_cuda()
        cleanup_memory()

        stage_2_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
        stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        stage_2_conditionings = combined_image_conditionings(
            images=images,
            height=stage_2_output_shape.height,
            width=stage_2_output_shape.width,
            video_encoder=video_encoder,
            dtype=dtype,
            device=self.device,
        )
        # Stage 2 starts from the upscaled stage-1 video latent and the stage-1 audio latent,
        # with the first stage-2 sigma passed as the noise scale.
        video_state, audio_state = denoise_audio_video(
            output_shape=stage_2_output_shape,
            conditionings=stage_2_conditionings,
            noiser=noiser,
            sigmas=stage_2_sigmas,
            stepper=stepper,
            denoising_loop_fn=denoising_loop,
            components=self.pipeline_components,
            dtype=dtype,
            device=self.device,
            noise_scale=stage_2_sigmas[0],
            initial_video_latent=upscaled_video_latent,
            initial_audio_latent=audio_state.latent,
        )

        self._synchronize_cuda()
        # Drop the local model references before cleanup so they no longer pin memory
        # while the decoders are requested from the ledger.
        del transformer
        del video_encoder
        cleanup_memory()

        decoded_video = vae_decode_video(
            video_state.latent, self.model_ledger.video_decoder(), tiling_config, generator
        )
        decoded_audio = vae_decode_audio(
            audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()
        )
        return decoded_video, decoded_audio
@torch.inference_mode()
def main() -> None:
    """Command-line entry point: parse arguments, run the pipeline, write the output file.

    The whole function executes under ``torch.inference_mode()``.
    """
    logging.getLogger().setLevel(logging.INFO)

    # The detected checkpoint determines the parameters handed to the argument parser.
    detected_params = detect_params(detect_checkpoint_path(distilled=True))
    cli = default_2_stage_distilled_arg_parser(params=detected_params).parse_args()

    distilled_pipeline = DistilledPipeline(
        distilled_checkpoint_path=cli.distilled_checkpoint_path,
        spatial_upsampler_path=cli.spatial_upsampler_path,
        gemma_root=cli.gemma_root,
        # An absent or empty --lora value becomes an empty tuple.
        loras=tuple(cli.lora or ()),
        quantization=cli.quantization,
    )

    # The same tiling configuration drives both the chunk count and the decode step.
    tiling = TilingConfig.default()
    chunk_count = get_video_chunks_number(cli.num_frames, tiling)

    generation_kwargs = {
        "prompt": cli.prompt,
        "seed": cli.seed,
        "height": cli.height,
        "width": cli.width,
        "num_frames": cli.num_frames,
        "frame_rate": cli.frame_rate,
        "images": cli.images,
        "tiling_config": tiling,
        "enhance_prompt": cli.enhance_prompt,
    }
    frames, soundtrack = distilled_pipeline(**generation_kwargs)

    encode_video(
        video=frames,
        fps=cli.frame_rate,
        audio=soundtrack,
        output_path=cli.output_path,
        video_chunks_number=chunk_count,
    )
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()