Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -123,6 +123,9 @@ class LTX23DistilledA2VPipeline:
|
|
| 123 |
def __init__(
|
| 124 |
self,
|
| 125 |
distilled_checkpoint_path: str,
|
|
|
|
|
|
|
|
|
|
| 126 |
spatial_upsampler_path: str,
|
| 127 |
gemma_root: str,
|
| 128 |
loras: tuple,
|
|
@@ -134,16 +137,31 @@ class LTX23DistilledA2VPipeline:
|
|
| 134 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 135 |
self.dtype = torch.bfloat16
|
| 136 |
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
dtype=self.dtype,
|
| 139 |
-
device=
|
| 140 |
-
checkpoint_path=
|
| 141 |
gemma_root_path=gemma_root,
|
| 142 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 143 |
-
loras=loras,
|
| 144 |
quantization=quantization,
|
| 145 |
)
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
self.pipeline_components = PipelineComponents(
|
| 148 |
dtype=self.dtype,
|
| 149 |
device=self.device,
|
|
@@ -209,8 +227,8 @@ class LTX23DistilledA2VPipeline:
|
|
| 209 |
)
|
| 210 |
encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
|
| 211 |
|
| 212 |
-
video_encoder = self.
|
| 213 |
-
transformer = self.
|
| 214 |
|
| 215 |
# Stage 1: Generate sigmas using LTX2Scheduler with user-specified steps
|
| 216 |
empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(
|
|
@@ -292,10 +310,11 @@ class LTX23DistilledA2VPipeline:
|
|
| 292 |
cleanup_memory()
|
| 293 |
|
| 294 |
# ── Upscaling ──
|
|
|
|
| 295 |
upscaled_video_latent = upsample_video(
|
| 296 |
latent=video_state.latent[:1],
|
| 297 |
video_encoder=video_encoder,
|
| 298 |
-
upsampler=self.
|
| 299 |
)
|
| 300 |
|
| 301 |
# ── Stage 2: Full resolution ──
|
|
@@ -309,6 +328,9 @@ class LTX23DistilledA2VPipeline:
|
|
| 309 |
dtype=dtype,
|
| 310 |
device=self.device,
|
| 311 |
)
|
|
|
|
|
|
|
|
|
|
| 312 |
video_state, audio_state = denoise_audio_video(
|
| 313 |
output_shape=stage_2_output_shape,
|
| 314 |
conditionings=stage_2_conditionings,
|
|
@@ -330,21 +352,22 @@ class LTX23DistilledA2VPipeline:
|
|
| 330 |
# ── Decode both video and audio ──
|
| 331 |
decoded_video = vae_decode_video(
|
| 332 |
video_state.latent,
|
| 333 |
-
self.
|
| 334 |
tiling_config,
|
| 335 |
generator,
|
| 336 |
)
|
| 337 |
decoded_audio_output = vae_decode_audio(
|
| 338 |
audio_state.latent,
|
| 339 |
-
self.
|
| 340 |
-
self.
|
| 341 |
)
|
| 342 |
|
| 343 |
return decoded_video, decoded_audio_output
|
| 344 |
|
| 345 |
# Model repos
|
| 346 |
-
LTX_MODEL_REPO = "
|
| 347 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
|
|
|
| 348 |
|
| 349 |
# Download model checkpoints
|
| 350 |
print("=" * 80)
|
|
@@ -364,10 +387,11 @@ weights_dir = Path("weights")
|
|
| 364 |
weights_dir.mkdir(exist_ok=True)
|
| 365 |
checkpoint_path = hf_hub_download(
|
| 366 |
repo_id=LTX_MODEL_REPO,
|
| 367 |
-
filename="
|
| 368 |
local_dir=str(weights_dir),
|
| 369 |
local_dir_use_symlinks=False,
|
| 370 |
)
|
|
|
|
| 371 |
spatial_upsampler_path = hf_hub_download(repo_id="Lightricks/LTX-2.3", filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
|
| 372 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
| 373 |
|
|
@@ -415,6 +439,9 @@ print(f"[Gemma] Root ready: {gemma_root}")
|
|
| 415 |
|
| 416 |
pipeline = LTX23DistilledA2VPipeline(
|
| 417 |
distilled_checkpoint_path=checkpoint_path,
|
|
|
|
|
|
|
|
|
|
| 418 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 419 |
gemma_root=gemma_root,
|
| 420 |
loras=[],
|
|
@@ -589,22 +616,31 @@ def apply_prepared_lora_state_to_pipeline():
|
|
| 589 |
|
| 590 |
# Preload all models for ZeroGPU tensor packing.
|
| 591 |
print("Preloading all models (including Gemma and audio components)...")
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
#
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
_video_encoder = _orig_video_encoder_factory()
|
| 609 |
_video_decoder = _orig_video_decoder_factory()
|
| 610 |
_audio_encoder = _orig_audio_encoder_factory()
|
|
@@ -615,18 +651,30 @@ _text_encoder = _orig_text_encoder_factory()
|
|
| 615 |
_embeddings_processor = _orig_gemma_embeddings_factory()
|
| 616 |
|
| 617 |
# Replace ledger methods with lightweight lambdas that return the cached instances.
|
|
|
|
| 618 |
# We keep the original factories above so we can call them later to rebuild components.
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
|
| 629 |
print("All models preloaded (including Gemma text encoder and audio encoder)!")
|
|
|
|
|
|
|
| 630 |
# ---- REPLACE PRELOAD BLOCK END ----
|
| 631 |
|
| 632 |
print("=" * 80)
|
|
|
|
| 123 |
def __init__(
|
| 124 |
self,
|
| 125 |
distilled_checkpoint_path: str,
|
| 126 |
+
istilled_lora_path: str,
|
| 127 |
+
distilled_lora_strength_stage_1: float,
|
| 128 |
+
distilled_lora_strength_stage_2: float,
|
| 129 |
spatial_upsampler_path: str,
|
| 130 |
gemma_root: str,
|
| 131 |
loras: tuple,
|
|
|
|
| 137 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 138 |
self.dtype = torch.bfloat16
|
| 139 |
|
| 140 |
+
distilled_lora_stage_1 = LoraPathStrengthAndSDOps(
|
| 141 |
+
path=distilled_lora_path,
|
| 142 |
+
strength=distilled_lora_strength_stage_1,
|
| 143 |
+
sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
|
| 144 |
+
)
|
| 145 |
+
distilled_lora_stage_2 = LoraPathStrengthAndSDOps(
|
| 146 |
+
path=distilled_lora_path,
|
| 147 |
+
strength=distilled_lora_strength_stage_2,
|
| 148 |
+
sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
self.stage_1_model_ledger = ModelLedger(
|
| 152 |
dtype=self.dtype,
|
| 153 |
+
device=device,
|
| 154 |
+
checkpoint_path=checkpoint_path,
|
| 155 |
gemma_root_path=gemma_root,
|
| 156 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 157 |
+
loras=(*loras, distilled_lora_stage_1),
|
| 158 |
quantization=quantization,
|
| 159 |
)
|
| 160 |
|
| 161 |
+
self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
|
| 162 |
+
loras=(*loras, distilled_lora_stage_2),
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
self.pipeline_components = PipelineComponents(
|
| 166 |
dtype=self.dtype,
|
| 167 |
device=self.device,
|
|
|
|
| 227 |
)
|
| 228 |
encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
|
| 229 |
|
| 230 |
+
video_encoder = self.stage_1_model_ledger.video_encoder()
|
| 231 |
+
transformer = self.stage_1_model_ledger.transformer()
|
| 232 |
|
| 233 |
# Stage 1: Generate sigmas using LTX2Scheduler with user-specified steps
|
| 234 |
empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(
|
|
|
|
| 310 |
cleanup_memory()
|
| 311 |
|
| 312 |
# ── Upscaling ──
|
| 313 |
+
video_encoder = self.stage_1_model_ledger.video_encoder()
|
| 314 |
upscaled_video_latent = upsample_video(
|
| 315 |
latent=video_state.latent[:1],
|
| 316 |
video_encoder=video_encoder,
|
| 317 |
+
upsampler=self.stage_2_model_ledger.spatial_upsampler(),
|
| 318 |
)
|
| 319 |
|
| 320 |
# ── Stage 2: Full resolution ──
|
|
|
|
| 328 |
dtype=dtype,
|
| 329 |
device=self.device,
|
| 330 |
)
|
| 331 |
+
|
| 332 |
+
transformer = self.stage_2_model_ledger.transformer()
|
| 333 |
+
|
| 334 |
video_state, audio_state = denoise_audio_video(
|
| 335 |
output_shape=stage_2_output_shape,
|
| 336 |
conditionings=stage_2_conditionings,
|
|
|
|
| 352 |
# ── Decode both video and audio ──
|
| 353 |
decoded_video = vae_decode_video(
|
| 354 |
video_state.latent,
|
| 355 |
+
self.stage_2_model_ledger.video_decoder(),
|
| 356 |
tiling_config,
|
| 357 |
generator,
|
| 358 |
)
|
| 359 |
decoded_audio_output = vae_decode_audio(
|
| 360 |
audio_state.latent,
|
| 361 |
+
self.stage_2_model_ledger.audio_decoder(),
|
| 362 |
+
self.stage_2_model_ledger.vocoder(),
|
| 363 |
)
|
| 364 |
|
| 365 |
return decoded_video, decoded_audio_output
|
| 366 |
|
| 367 |
# Model repos
|
| 368 |
+
LTX_MODEL_REPO = "TenStrip/LTX2.3-10Eros"
|
| 369 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 370 |
+
DISTILLED_LORA_REPO = "TenStrip/LTX2.3_Distilled_Lora_1.1_Experiments"
|
| 371 |
|
| 372 |
# Download model checkpoints
|
| 373 |
print("=" * 80)
|
|
|
|
| 387 |
weights_dir.mkdir(exist_ok=True)
|
| 388 |
checkpoint_path = hf_hub_download(
|
| 389 |
repo_id=LTX_MODEL_REPO,
|
| 390 |
+
filename="10Eros_v1_bf16.safetensors",
|
| 391 |
local_dir=str(weights_dir),
|
| 392 |
local_dir_use_symlinks=False,
|
| 393 |
)
|
| 394 |
+
distilled_lora_path = hf_hub_download(repo_id=DISTILLED_LORA_REPO, filename="ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors")
|
| 395 |
spatial_upsampler_path = hf_hub_download(repo_id="Lightricks/LTX-2.3", filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
|
| 396 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
| 397 |
|
|
|
|
| 439 |
|
| 440 |
pipeline = LTX23DistilledA2VPipeline(
|
| 441 |
distilled_checkpoint_path=checkpoint_path,
|
| 442 |
+
distilled_lora=distilled_lora_path,
|
| 443 |
+
distilled_lora_strength_stage_1=1.0,
|
| 444 |
+
distilled_lora_strength_stage_2=0.5,
|
| 445 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 446 |
gemma_root=gemma_root,
|
| 447 |
loras=[],
|
|
|
|
| 616 |
|
| 617 |
# Preload all models for ZeroGPU tensor packing.
|
| 618 |
print("Preloading all models (including Gemma and audio components)...")
|
| 619 |
+
|
| 620 |
+
# We now have two ledgers — stage 1 (distilled LoRA @ 1.0) and stage 2 (distilled LoRA @ 0.5).
|
| 621 |
+
# Both share the same dev checkpoint and spatial upsampler; only the transformer differs.
|
| 622 |
+
ledger_s1 = pipeline.stage_1_model_ledger
|
| 623 |
+
ledger_s2 = pipeline.stage_2_model_ledger
|
| 624 |
+
|
| 625 |
+
# Save the original factory methods from BOTH ledgers so we can rebuild individual components later.
|
| 626 |
+
_orig_transformer_factory_s1 = ledger_s1.transformer
|
| 627 |
+
_orig_transformer_factory_s2 = ledger_s2.transformer
|
| 628 |
+
_orig_video_encoder_factory = ledger_s1.video_encoder
|
| 629 |
+
_orig_video_decoder_factory = ledger_s1.video_decoder
|
| 630 |
+
_orig_audio_encoder_factory = ledger_s1.audio_encoder
|
| 631 |
+
_orig_audio_decoder_factory = ledger_s1.audio_decoder
|
| 632 |
+
_orig_vocoder_factory = ledger_s1.vocoder
|
| 633 |
+
_orig_spatial_upsampler_factory = ledger_s1.spatial_upsampler
|
| 634 |
+
_orig_text_encoder_factory = ledger_s1.text_encoder
|
| 635 |
+
_orig_gemma_embeddings_factory = ledger_s1.gemma_embeddings_processor
|
| 636 |
+
|
| 637 |
+
# Call the factories to create cached instances.
|
| 638 |
+
# Stage 1 transformer: dev checkpoint + distilled LoRA @ strength 1.0 (baked in at build time)
|
| 639 |
+
_transformer_s1 = _orig_transformer_factory_s1()
|
| 640 |
+
# Stage 2 transformer: dev checkpoint + distilled LoRA @ strength 0.5 (baked in at build time)
|
| 641 |
+
_transformer_s2 = _orig_transformer_factory_s2()
|
| 642 |
+
|
| 643 |
+
# Shared components — only need one copy since both ledgers use the same VAE/encoder paths.
|
| 644 |
_video_encoder = _orig_video_encoder_factory()
|
| 645 |
_video_decoder = _orig_video_decoder_factory()
|
| 646 |
_audio_encoder = _orig_audio_encoder_factory()
|
|
|
|
| 651 |
_embeddings_processor = _orig_gemma_embeddings_factory()
|
| 652 |
|
| 653 |
# Replace ledger methods with lightweight lambdas that return the cached instances.
|
| 654 |
+
# Both ledgers point to the same shared model instances (except transformer).
|
| 655 |
# We keep the original factories above so we can call them later to rebuild components.
|
| 656 |
+
ledger_s1.transformer = lambda: _transformer_s1
|
| 657 |
+
ledger_s2.transformer = lambda: _transformer_s2
|
| 658 |
+
ledger_s1.video_encoder = lambda: _video_encoder
|
| 659 |
+
ledger_s2.video_encoder = lambda: _video_encoder
|
| 660 |
+
ledger_s1.video_decoder = lambda: _video_decoder
|
| 661 |
+
ledger_s2.video_decoder = lambda: _video_decoder
|
| 662 |
+
ledger_s1.audio_encoder = lambda: _audio_encoder
|
| 663 |
+
ledger_s2.audio_encoder = lambda: _audio_encoder
|
| 664 |
+
ledger_s1.audio_decoder = lambda: _audio_decoder
|
| 665 |
+
ledger_s2.audio_decoder = lambda: _audio_decoder
|
| 666 |
+
ledger_s1.vocoder = lambda: _vocoder
|
| 667 |
+
ledger_s2.vocoder = lambda: _vocoder
|
| 668 |
+
ledger_s1.spatial_upsampler = lambda: _spatial_upsampler
|
| 669 |
+
ledger_s2.spatial_upsampler = lambda: _spatial_upsampler
|
| 670 |
+
ledger_s1.text_encoder = lambda: _text_encoder
|
| 671 |
+
ledger_s2.text_encoder = lambda: _text_encoder
|
| 672 |
+
ledger_s1.gemma_embeddings_processor = lambda: _embeddings_processor
|
| 673 |
+
ledger_s2.gemma_embeddings_processor = lambda: _embeddings_processor
|
| 674 |
|
| 675 |
print("All models preloaded (including Gemma text encoder and audio encoder)!")
|
| 676 |
+
print(f" Stage 1 transformer: {_transformer_s1.__class__.__name__} (distilled LoRA @ 1.0 baked)")
|
| 677 |
+
print(f" Stage 2 transformer: {_transformer_s2.__class__.__name__} (distilled LoRA @ 0.5 baked)")
|
| 678 |
# ---- REPLACE PRELOAD BLOCK END ----
|
| 679 |
|
| 680 |
print("=" * 80)
|