Instructions to use Shriramnag/ShivAI-Image-to-Video with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Shriramnag/ShivAI-Image-to-Video with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("Shriramnag/ShivAI-Image-to-Video", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
File size: 18,469 Bytes
a69cd16 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 | import logging
from collections.abc import Iterator
import torch
from ltx_core.components.noisers import GaussianNoiser
from ltx_core.conditioning import ConditioningItem
from ltx_core.loader import LoraPathStrengthAndSDOps
from ltx_core.loader.registry import Registry
from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunks_number
from ltx_core.quantization import QuantizationPolicy
from ltx_core.types import Audio, VideoPixelShape
from ltx_pipelines.iclora_utils import (
append_ic_lora_reference_video_conditionings,
read_lora_reference_downscale_factor,
)
from ltx_pipelines.utils.args import (
ImageConditioningInput,
VideoConditioningAction,
VideoMaskConditioningAction,
default_2_stage_distilled_arg_parser,
detect_checkpoint_path,
)
from ltx_pipelines.utils.blocks import (
AudioDecoder,
DiffusionStage,
ImageConditioner,
PromptEncoder,
VideoDecoder,
VideoUpsampler,
)
from ltx_pipelines.utils.constants import (
DISTILLED_SIGMAS,
STAGE_2_DISTILLED_SIGMAS,
detect_params,
)
from ltx_pipelines.utils.denoisers import SimpleDenoiser
from ltx_pipelines.utils.helpers import assert_resolution, combined_image_conditionings, get_device
from ltx_pipelines.utils.media_io import decode_video_by_frame, encode_video, video_preprocess
from ltx_pipelines.utils.types import ModalitySpec, OffloadMode
class ICLoraPipeline:
    """Distilled two-stage video generator steered by In-Context (IC) LoRA weights.

    A reference ("control") video -- e.g. a depth map, a pose track or an edge
    map -- is passed through ``video_conditioning`` and guides generation; the
    matching IC-LoRA checkpoint is passed through ``loras``.

    Stage 1 denoises at ``(height // 2, width // 2)`` with the LoRAs applied.
    Stage 2 spatially upsamples the Stage 1 latent and refines it at the
    requested resolution; its diffusion stage is built without LoRAs.
    Both stages load the same distilled checkpoint.
    """

    def __init__(
        self,
        distilled_checkpoint_path: str,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: list[LoraPathStrengthAndSDOps],
        device: torch.device | None = None,
        quantization: QuantizationPolicy | None = None,
        registry: Registry | None = None,
        torch_compile: bool = False,
        offload_mode: OffloadMode = OffloadMode.NONE,
    ):
        """Build every pipeline component and resolve the reference downscale factor.

        Args:
            distilled_checkpoint_path: Checkpoint handed to every component.
            spatial_upsampler_path: Checkpoint handed to the video upsampler.
            gemma_root: Location handed to the prompt encoder.
            loras: LoRAs applied to Stage 1 only; each one's ``path`` is also
                inspected for a reference downscale factor.
            device: Target device; ``get_device()`` is used when falsy.
            quantization: Forwarded to both diffusion stages.
            registry: Forwarded to every component.
            torch_compile: Forwarded to both diffusion stages.
            offload_mode: Forwarded to the prompt encoder and both diffusion stages.

        Raises:
            ValueError: If two LoRAs declare different non-1 reference
                downscale factors.
        """
        ckpt = distilled_checkpoint_path
        self.device = device or get_device()
        self.dtype = torch.bfloat16
        # Every component takes (dtype, device) as consecutive positional arguments.
        placement = (self.dtype, self.device)

        self.prompt_encoder = PromptEncoder(
            ckpt, gemma_root, *placement, registry=registry, offload_mode=offload_mode
        )
        self.image_conditioner = ImageConditioner(ckpt, *placement, registry=registry)

        # The two diffusion stages differ only in which LoRAs they receive.
        stage_options = dict(
            quantization=quantization,
            registry=registry,
            torch_compile=torch_compile,
            offload_mode=offload_mode,
        )
        self.stage_1 = DiffusionStage(ckpt, *placement, loras=tuple(loras), **stage_options)
        self.stage_2 = DiffusionStage(ckpt, *placement, loras=(), **stage_options)

        self.upsampler = VideoUpsampler(ckpt, spatial_upsampler_path, *placement, registry=registry)
        self.video_decoder = VideoDecoder(ckpt, *placement, registry=registry)
        self.audio_decoder = AudioDecoder(ckpt, *placement, registry=registry)

        # IC-LoRAs trained on low-resolution reference videos record the
        # downscale factor in their metadata so inference can resize the
        # reference video the same way. A factor of 1 means "no preference";
        # any other value must be shared by all LoRAs that declare one.
        self.reference_downscale_factor = 1
        for lora in loras:
            scale = read_lora_reference_downscale_factor(lora.path)
            if scale == 1:
                continue
            if self.reference_downscale_factor not in (1, scale):
                raise ValueError(
                    f"Conflicting reference_downscale_factor values in LoRAs: "
                    f"already have {self.reference_downscale_factor}, but {lora.path} "
                    f"specifies {scale}. Cannot combine LoRAs with different reference scales."
                )
            self.reference_downscale_factor = scale

    def __call__(  # noqa: PLR0913
        self,
        prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[ImageConditioningInput],
        video_conditioning: list[tuple[str, float]],
        enhance_prompt: bool = False,
        tiling_config: TilingConfig | None = None,
        conditioning_attention_strength: float = 1.0,
        skip_stage_2: bool = False,
        conditioning_attention_mask: torch.Tensor | None = None,
        stage_1_sigmas: torch.Tensor = DISTILLED_SIGMAS,
        stage_2_sigmas: torch.Tensor = STAGE_2_DISTILLED_SIGMAS,
    ) -> tuple[Iterator[torch.Tensor], Audio]:
        """Run the pipeline and return the decoded video and audio.

        Args:
            prompt: Text prompt describing the video.
            seed: Seed for the noise generator (also used for prompt enhancement).
            height: Output height in pixels (must be divisible by 64).
            width: Output width in pixels (must be divisible by 64).
            num_frames: Number of frames to generate.
            frame_rate: Frame rate of the generated video.
            images: Image conditioning inputs, described upstream as
                (path, frame_idx, strength) entries. The first field of the
                first entry is also handed to the prompt enhancer.
            video_conditioning: (path, strength) pairs for the IC-LoRA
                reference videos.
            enhance_prompt: Whether the prompt encoder should enhance the prompt.
            tiling_config: Optional tiling configuration passed to the video decoder.
            conditioning_attention_strength: Scalar in [0.0, 1.0] controlling how
                strongly the reference video influences the result (0.0 ignores
                it, 1.0 is full influence). When a mask is given, the mask is
                multiplied by this value.
            skip_stage_2: When True, decode the Stage 1 result directly; the
                output is then at (height // 2, width // 2).
            conditioning_attention_mask: Optional pixel-space mask with values in
                [0, 1], shaped (B, 1, F, H, W) or (1, 1, F, H, W) to match the
                reference video. It is downsampled to latent space and scaled by
                ``conditioning_attention_strength``. When None, the scalar
                strength is used on its own.
            stage_1_sigmas: Noise schedule for Stage 1.
            stage_2_sigmas: Noise schedule for Stage 2; its first entry is also
                used as the Stage 2 noise scale.

        Returns:
            Tuple of (decoded video, decoded audio).

        Raises:
            ValueError: If ``conditioning_attention_strength`` is not within
                [0.0, 1.0].
        """
        assert_resolution(height=height, width=width, is_two_stage=True)
        if not 0.0 <= conditioning_attention_strength <= 1.0:
            raise ValueError(
                f"conditioning_attention_strength must be in [0.0, 1.0], got {conditioning_attention_strength}"
            )

        # One generator feeds the noiser for both stages and the video decoder.
        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)

        if len(images) > 0:
            enhancement_image = images[0][0]
        else:
            enhancement_image = None
        encoded_prompts = self.prompt_encoder(
            [prompt],
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_image=enhancement_image,
            enhance_prompt_seed=seed,
        )
        # Exactly one prompt went in, so exactly one encoding must come out.
        (prompt_encoding,) = encoded_prompts
        video_context = prompt_encoding.video_encoding
        audio_context = prompt_encoding.audio_encoding

        # ---- Stage 1: half-resolution generation with IC-LoRA conditioning ----
        stage_1_output_shape = VideoPixelShape(
            batch=1,
            frames=num_frames,
            width=width // 2,
            height=height // 2,
            fps=frame_rate,
        )

        def build_stage_1_conditionings(enc):
            # Called by the image conditioner with its video encoder.
            return self._create_conditionings(
                images=images,
                video_conditioning=video_conditioning,
                height=stage_1_output_shape.height,
                width=stage_1_output_shape.width,
                video_encoder=enc,
                num_frames=num_frames,
                conditioning_attention_strength=conditioning_attention_strength,
                conditioning_attention_mask=conditioning_attention_mask,
            )

        stage_1_conditionings = self.image_conditioner(build_stage_1_conditionings)

        stage_1_sigmas = stage_1_sigmas.to(dtype=torch.float32, device=self.device)
        stage_1_video = ModalitySpec(
            context=video_context,
            conditionings=stage_1_conditionings,
        )
        stage_1_audio = ModalitySpec(
            context=audio_context,
        )
        video_state, audio_state = self.stage_1(
            denoiser=SimpleDenoiser(video_context, audio_context),
            sigmas=stage_1_sigmas,
            noiser=noiser,
            width=stage_1_output_shape.width,
            height=stage_1_output_shape.height,
            frames=num_frames,
            fps=frame_rate,
            video=stage_1_video,
            audio=stage_1_audio,
        )

        if skip_stage_2:
            # The Stage 1 latents are decoded as-is, at half resolution.
            logging.info("[IC-LoRA] Skipping Stage 2 (--skip-stage-2 enabled)")
        else:
            # ---- Stage 2: upsample the Stage 1 latent and refine at full size ----
            # This stage was built with loras=(), so only image conditionings
            # (no reference-video conditionings) are prepared for it.
            upscaled_video_latent = self.upsampler(video_state.latent[:1])
            stage_2_sigmas = stage_2_sigmas.to(dtype=torch.float32, device=self.device)
            stage_2_output_shape = VideoPixelShape(
                batch=1, frames=num_frames, width=width, height=height, fps=frame_rate
            )

            def build_stage_2_conditionings(enc):
                # Called by the image conditioner with its video encoder.
                return combined_image_conditionings(
                    images=images,
                    height=stage_2_output_shape.height,
                    width=stage_2_output_shape.width,
                    video_encoder=enc,
                    dtype=self.dtype,
                    device=self.device,
                )

            stage_2_conditionings = self.image_conditioner(build_stage_2_conditionings)

            # Both modalities are re-noised to the first sigma of the schedule.
            initial_noise_scale = stage_2_sigmas[0].item()
            stage_2_video = ModalitySpec(
                context=video_context,
                conditionings=stage_2_conditionings,
                noise_scale=initial_noise_scale,
                initial_latent=upscaled_video_latent,
            )
            stage_2_audio = ModalitySpec(
                context=audio_context,
                noise_scale=initial_noise_scale,
                initial_latent=audio_state.latent,
            )
            video_state, audio_state = self.stage_2(
                denoiser=SimpleDenoiser(video_context, audio_context),
                sigmas=stage_2_sigmas,
                noiser=noiser,
                width=width,
                height=height,
                frames=num_frames,
                fps=frame_rate,
                video=stage_2_video,
                audio=stage_2_audio,
            )

        # Shared exit: decode whichever stage produced the final latents.
        decoded_video = self.video_decoder(video_state.latent, tiling_config, generator)
        decoded_audio = self.audio_decoder(audio_state.latent)
        return decoded_video, decoded_audio

    def _create_conditionings(
        self,
        images: list[ImageConditioningInput],
        video_conditioning: list[tuple[str, float]],
        height: int,
        width: int,
        num_frames: int,
        video_encoder: VideoEncoder,
        conditioning_attention_strength: float = 1.0,
        conditioning_attention_mask: torch.Tensor | None = None,
    ) -> list[ConditioningItem]:
        """Assemble the image and IC-LoRA reference-video conditioning items.

        Args:
            images: Image conditioning inputs.
            video_conditioning: (path, strength) pairs for the reference videos.
            height: Target height in pixels for the encoded conditionings.
            width: Target width in pixels for the encoded conditionings.
            num_frames: Number of frames being generated.
            video_encoder: Encoder used to encode the conditioning media.
            conditioning_attention_strength: Scalar attention weight in [0, 1].
                Used directly when no mask is given; otherwise it multiplies the
                downsampled mask.
            conditioning_attention_mask: Optional pixel-space mask shaped
                (B, 1, F_pixel, H_pixel, W_pixel) matching the reference video.

        Returns:
            The image conditionings followed by the IC-LoRA reference-video
            conditionings.
        """
        # Arguments common to both conditioning builders.
        encode_options = dict(
            height=height,
            width=width,
            video_encoder=video_encoder,
            dtype=self.dtype,
            device=self.device,
        )

        items = combined_image_conditionings(images=images, **encode_options)

        # The return value is not used; the helper is relied on to extend
        # ``items`` in place.
        append_ic_lora_reference_video_conditionings(
            items,
            video_conditioning,
            num_frames=num_frames,
            reference_downscale_factor=self.reference_downscale_factor,
            conditioning_attention_strength=conditioning_attention_strength,
            conditioning_attention_mask=conditioning_attention_mask,
            tiling_config=None,
            **encode_options,
        )

        if video_conditioning:
            logging.info("[IC-LoRA] Added %d video conditioning(s)", len(video_conditioning))
        return items
@torch.inference_mode()
def main() -> None:
    """Command-line entry point: parse arguments, generate, and write the video.

    Extends the default two-stage distilled argument parser with the IC-LoRA
    options ``--video-conditioning`` (required), ``--conditioning-attention-mask``
    and ``--skip-stage-2``, runs :class:`ICLoraPipeline`, and encodes the result
    to ``args.output_path``.
    """
    logging.getLogger().setLevel(logging.INFO)

    cli = default_2_stage_distilled_arg_parser(
        params=detect_params(detect_checkpoint_path(distilled=True)),
    )
    cli.add_argument(
        "--video-conditioning",
        action=VideoConditioningAction,
        nargs=2,
        metavar=("PATH", "STRENGTH"),
        required=True,
    )
    cli.add_argument(
        "--conditioning-attention-mask",
        action=VideoMaskConditioningAction,
        nargs=2,
        metavar=("MASK_PATH", "STRENGTH"),
        default=None,
        help=(
            "Optional spatial attention mask: path to a grayscale mask video and "
            "attention strength. The mask video pixel values in [0,1] control "
            "per-region conditioning attention strength. The strength scalar is "
            "multiplied with the spatial mask. "
            "0.0 = ignore IC-LoRA conditioning, 1.0 = full conditioning influence. "
            "When not provided, full conditioning strength (1.0) is used. "
            "Example: --conditioning-attention-mask path/to/mask.mp4 0.5"
        ),
    )
    cli.add_argument(
        "--skip-stage-2",
        action="store_true",
        help=(
            "Skip Stage 2 upsampling and refinement. Output will be at half resolution "
            "(height//2, width//2). Useful for faster iteration or when GPU memory is limited."
        ),
    )
    args = cli.parse_args()

    # Without a mask option the pipeline runs at full scalar strength and no mask.
    mask_option = args.conditioning_attention_mask
    if mask_option is None:
        conditioning_attention_mask = None
        conditioning_attention_strength = 1.0
    else:
        mask_path, conditioning_attention_strength = mask_option
        # Stage 1 operates at half resolution, so the mask is loaded at that size.
        conditioning_attention_mask = _load_mask_video(
            mask_path=mask_path,
            height=args.height // 2,
            width=args.width // 2,
            num_frames=args.num_frames,
        )

    pipeline = ICLoraPipeline(
        distilled_checkpoint_path=args.distilled_checkpoint_path,
        spatial_upsampler_path=args.spatial_upsampler_path,
        gemma_root=args.gemma_root,
        loras=tuple(args.lora) if args.lora else (),
        quantization=args.quantization,
        torch_compile=args.compile,
        offload_mode=args.offload_mode,
    )

    # The same tiling configuration drives decoding and the encoder's chunk count.
    tiling_config = TilingConfig.default()
    video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)

    video, audio = pipeline(
        prompt=args.prompt,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        frame_rate=args.frame_rate,
        images=args.images,
        video_conditioning=args.video_conditioning,
        tiling_config=tiling_config,
        conditioning_attention_strength=conditioning_attention_strength,
        skip_stage_2=args.skip_stage_2,
        conditioning_attention_mask=conditioning_attention_mask,
    )

    encode_video(
        video=video,
        fps=args.frame_rate,
        audio=audio,
        output_path=args.output_path,
        video_chunks_number=video_chunks_number,
    )
def _load_mask_video(
    mask_path: str,
    height: int,
    width: int,
    num_frames: int,
) -> torch.Tensor:
    """Read a mask video from disk as a single-channel tensor with values in [0, 1].

    The frames are decoded, preprocessed to ``(height, width)``, averaged over
    the channel axis to obtain a grayscale mask, and remapped to [0, 1].

    Args:
        mask_path: Path to the mask video file.
        height: Target height in pixels.
        width: Target width in pixels.
        num_frames: Maximum number of frames to decode.

    Returns:
        Tensor of shape ``(1, 1, F, H, W)`` with values in ``[0, 1]``.
    """
    target_device = get_device()
    frames = decode_video_by_frame(path=mask_path, frame_cap=num_frames, device=target_device)
    # Per the upstream note, the preprocessed video is laid out as (1, C, F, H, W).
    preprocessed = video_preprocess(frames, height, width, torch.bfloat16, target_device)

    # Average over the channel axis to get one grayscale channel: (1, 1, F, H, W).
    grayscale = preprocessed.mean(dim=1, keepdim=True)

    # Per the upstream note, preprocessing leaves values in [-1, 1]; shift and
    # halve to reach [0, 1], then clamp away any overshoot.
    unit_range = (grayscale + 1.0) / 2.0
    return unit_range.clamp(min=0.0, max=1.0)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|