dagloop5 committed on
Commit
8614379
·
verified ·
1 Parent(s): 2327914

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -37
app.py CHANGED
@@ -123,6 +123,9 @@ class LTX23DistilledA2VPipeline:
123
  def __init__(
124
  self,
125
  distilled_checkpoint_path: str,
 
 
 
126
  spatial_upsampler_path: str,
127
  gemma_root: str,
128
  loras: tuple,
@@ -134,16 +137,31 @@ class LTX23DistilledA2VPipeline:
134
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
135
  self.dtype = torch.bfloat16
136
 
137
- self.model_ledger = ModelLedger(
 
 
 
 
 
 
 
 
 
 
 
138
  dtype=self.dtype,
139
- device=self.device,
140
- checkpoint_path=distilled_checkpoint_path,
141
  gemma_root_path=gemma_root,
142
  spatial_upsampler_path=spatial_upsampler_path,
143
- loras=loras,
144
  quantization=quantization,
145
  )
146
 
 
 
 
 
147
  self.pipeline_components = PipelineComponents(
148
  dtype=self.dtype,
149
  device=self.device,
@@ -209,8 +227,8 @@ class LTX23DistilledA2VPipeline:
209
  )
210
  encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
211
 
212
- video_encoder = self.model_ledger.video_encoder()
213
- transformer = self.model_ledger.transformer()
214
 
215
  # Stage 1: Generate sigmas using LTX2Scheduler with user-specified steps
216
  empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(
@@ -292,10 +310,11 @@ class LTX23DistilledA2VPipeline:
292
  cleanup_memory()
293
 
294
  # ── Upscaling ──
 
295
  upscaled_video_latent = upsample_video(
296
  latent=video_state.latent[:1],
297
  video_encoder=video_encoder,
298
- upsampler=self.model_ledger.spatial_upsampler(),
299
  )
300
 
301
  # ── Stage 2: Full resolution ──
@@ -309,6 +328,9 @@ class LTX23DistilledA2VPipeline:
309
  dtype=dtype,
310
  device=self.device,
311
  )
 
 
 
312
  video_state, audio_state = denoise_audio_video(
313
  output_shape=stage_2_output_shape,
314
  conditionings=stage_2_conditionings,
@@ -330,21 +352,22 @@ class LTX23DistilledA2VPipeline:
330
  # ── Decode both video and audio ──
331
  decoded_video = vae_decode_video(
332
  video_state.latent,
333
- self.model_ledger.video_decoder(),
334
  tiling_config,
335
  generator,
336
  )
337
  decoded_audio_output = vae_decode_audio(
338
  audio_state.latent,
339
- self.model_ledger.audio_decoder(),
340
- self.model_ledger.vocoder(),
341
  )
342
 
343
  return decoded_video, decoded_audio_output
344
 
345
  # Model repos
346
- LTX_MODEL_REPO = "SulphurAI/Sulphur-2-base"
347
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
 
348
 
349
  # Download model checkpoints
350
  print("=" * 80)
@@ -364,10 +387,11 @@ weights_dir = Path("weights")
364
  weights_dir.mkdir(exist_ok=True)
365
  checkpoint_path = hf_hub_download(
366
  repo_id=LTX_MODEL_REPO,
367
- filename="sulphur_distil_bf16.safetensors",
368
  local_dir=str(weights_dir),
369
  local_dir_use_symlinks=False,
370
  )
 
371
  spatial_upsampler_path = hf_hub_download(repo_id="Lightricks/LTX-2.3", filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
372
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
373
 
@@ -415,6 +439,9 @@ print(f"[Gemma] Root ready: {gemma_root}")
415
 
416
  pipeline = LTX23DistilledA2VPipeline(
417
  distilled_checkpoint_path=checkpoint_path,
 
 
 
418
  spatial_upsampler_path=spatial_upsampler_path,
419
  gemma_root=gemma_root,
420
  loras=[],
@@ -589,22 +616,31 @@ def apply_prepared_lora_state_to_pipeline():
589
 
590
  # Preload all models for ZeroGPU tensor packing.
591
  print("Preloading all models (including Gemma and audio components)...")
592
- ledger = pipeline.model_ledger
593
-
594
- # Save the original factory methods so we can rebuild individual components later.
595
- # These are bound callables on ledger that will call the builder when invoked.
596
- _orig_transformer_factory = ledger.transformer
597
- _orig_video_encoder_factory = ledger.video_encoder
598
- _orig_video_decoder_factory = ledger.video_decoder
599
- _orig_audio_encoder_factory = ledger.audio_encoder
600
- _orig_audio_decoder_factory = ledger.audio_decoder
601
- _orig_vocoder_factory = ledger.vocoder
602
- _orig_spatial_upsampler_factory = ledger.spatial_upsampler
603
- _orig_text_encoder_factory = ledger.text_encoder
604
- _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
605
-
606
- # Call the original factories once to create the cached instances we will serve by default.
607
- _transformer = _orig_transformer_factory()
 
 
 
 
 
 
 
 
 
608
  _video_encoder = _orig_video_encoder_factory()
609
  _video_decoder = _orig_video_decoder_factory()
610
  _audio_encoder = _orig_audio_encoder_factory()
@@ -615,18 +651,30 @@ _text_encoder = _orig_text_encoder_factory()
615
  _embeddings_processor = _orig_gemma_embeddings_factory()
616
 
617
  # Replace ledger methods with lightweight lambdas that return the cached instances.
 
618
  # We keep the original factories above so we can call them later to rebuild components.
619
- ledger.transformer = lambda: _transformer
620
- ledger.video_encoder = lambda: _video_encoder
621
- ledger.video_decoder = lambda: _video_decoder
622
- ledger.audio_encoder = lambda: _audio_encoder
623
- ledger.audio_decoder = lambda: _audio_decoder
624
- ledger.vocoder = lambda: _vocoder
625
- ledger.spatial_upsampler = lambda: _spatial_upsampler
626
- ledger.text_encoder = lambda: _text_encoder
627
- ledger.gemma_embeddings_processor = lambda: _embeddings_processor
 
 
 
 
 
 
 
 
 
628
 
629
  print("All models preloaded (including Gemma text encoder and audio encoder)!")
 
 
630
  # ---- REPLACE PRELOAD BLOCK END ----
631
 
632
  print("=" * 80)
 
123
  def __init__(
124
  self,
125
  distilled_checkpoint_path: str,
126
+ istilled_lora_path: str,
127
+ distilled_lora_strength_stage_1: float,
128
+ distilled_lora_strength_stage_2: float,
129
  spatial_upsampler_path: str,
130
  gemma_root: str,
131
  loras: tuple,
 
137
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
138
  self.dtype = torch.bfloat16
139
 
140
+ distilled_lora_stage_1 = LoraPathStrengthAndSDOps(
141
+ path=distilled_lora_path,
142
+ strength=distilled_lora_strength_stage_1,
143
+ sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
144
+ )
145
+ distilled_lora_stage_2 = LoraPathStrengthAndSDOps(
146
+ path=distilled_lora_path,
147
+ strength=distilled_lora_strength_stage_2,
148
+ sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
149
+ )
150
+
151
+ self.stage_1_model_ledger = ModelLedger(
152
  dtype=self.dtype,
153
+ device=device,
154
+ checkpoint_path=checkpoint_path,
155
  gemma_root_path=gemma_root,
156
  spatial_upsampler_path=spatial_upsampler_path,
157
+ loras=(*loras, distilled_lora_stage_1),
158
  quantization=quantization,
159
  )
160
 
161
+ self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
162
+ loras=(*loras, distilled_lora_stage_2),
163
+ )
164
+
165
  self.pipeline_components = PipelineComponents(
166
  dtype=self.dtype,
167
  device=self.device,
 
227
  )
228
  encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
229
 
230
+ video_encoder = self.stage_1_model_ledger.video_encoder()
231
+ transformer = self.stage_1_model_ledger.transformer()
232
 
233
  # Stage 1: Generate sigmas using LTX2Scheduler with user-specified steps
234
  empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(
 
310
  cleanup_memory()
311
 
312
  # ── Upscaling ──
313
+ video_encoder = self.stage_1_model_ledger.video_encoder()
314
  upscaled_video_latent = upsample_video(
315
  latent=video_state.latent[:1],
316
  video_encoder=video_encoder,
317
+ upsampler=self.stage_2_model_ledger.spatial_upsampler(),
318
  )
319
 
320
  # ── Stage 2: Full resolution ──
 
328
  dtype=dtype,
329
  device=self.device,
330
  )
331
+
332
+ transformer = self.stage_2_model_ledger.transformer()
333
+
334
  video_state, audio_state = denoise_audio_video(
335
  output_shape=stage_2_output_shape,
336
  conditionings=stage_2_conditionings,
 
352
  # ── Decode both video and audio ──
353
  decoded_video = vae_decode_video(
354
  video_state.latent,
355
+ self.stage_2_model_ledger.video_decoder(),
356
  tiling_config,
357
  generator,
358
  )
359
  decoded_audio_output = vae_decode_audio(
360
  audio_state.latent,
361
+ self.stage_2_model_ledger.audio_decoder(),
362
+ self.stage_2_model_ledger.vocoder(),
363
  )
364
 
365
  return decoded_video, decoded_audio_output
366
 
367
  # Model repos
368
+ LTX_MODEL_REPO = "TenStrip/LTX2.3-10Eros"
369
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
370
+ DISTILLED_LORA_REPO = "TenStrip/LTX2.3_Distilled_Lora_1.1_Experiments"
371
 
372
  # Download model checkpoints
373
  print("=" * 80)
 
387
  weights_dir.mkdir(exist_ok=True)
388
  checkpoint_path = hf_hub_download(
389
  repo_id=LTX_MODEL_REPO,
390
+ filename="10Eros_v1_bf16.safetensors",
391
  local_dir=str(weights_dir),
392
  local_dir_use_symlinks=False,
393
  )
394
+ distilled_lora_path = hf_hub_download(repo_id=DISTILLED_LORA_REPO, filename="ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors")
395
  spatial_upsampler_path = hf_hub_download(repo_id="Lightricks/LTX-2.3", filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
396
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
397
 
 
439
 
440
  pipeline = LTX23DistilledA2VPipeline(
441
  distilled_checkpoint_path=checkpoint_path,
442
+ distilled_lora=distilled_lora_path,
443
+ distilled_lora_strength_stage_1=1.0,
444
+ distilled_lora_strength_stage_2=0.5,
445
  spatial_upsampler_path=spatial_upsampler_path,
446
  gemma_root=gemma_root,
447
  loras=[],
 
616
 
617
  # Preload all models for ZeroGPU tensor packing.
618
  print("Preloading all models (including Gemma and audio components)...")
619
+
620
+ # We now have two ledgers — stage 1 (distilled LoRA @ 1.0) and stage 2 (distilled LoRA @ 0.5).
621
+ # Both share the same dev checkpoint and spatial upsampler; only the transformer differs.
622
+ ledger_s1 = pipeline.stage_1_model_ledger
623
+ ledger_s2 = pipeline.stage_2_model_ledger
624
+
625
+ # Save the original factory methods from BOTH ledgers so we can rebuild individual components later.
626
+ _orig_transformer_factory_s1 = ledger_s1.transformer
627
+ _orig_transformer_factory_s2 = ledger_s2.transformer
628
+ _orig_video_encoder_factory = ledger_s1.video_encoder
629
+ _orig_video_decoder_factory = ledger_s1.video_decoder
630
+ _orig_audio_encoder_factory = ledger_s1.audio_encoder
631
+ _orig_audio_decoder_factory = ledger_s1.audio_decoder
632
+ _orig_vocoder_factory = ledger_s1.vocoder
633
+ _orig_spatial_upsampler_factory = ledger_s1.spatial_upsampler
634
+ _orig_text_encoder_factory = ledger_s1.text_encoder
635
+ _orig_gemma_embeddings_factory = ledger_s1.gemma_embeddings_processor
636
+
637
+ # Call the factories to create cached instances.
638
+ # Stage 1 transformer: dev checkpoint + distilled LoRA @ strength 1.0 (baked in at build time)
639
+ _transformer_s1 = _orig_transformer_factory_s1()
640
+ # Stage 2 transformer: dev checkpoint + distilled LoRA @ strength 0.5 (baked in at build time)
641
+ _transformer_s2 = _orig_transformer_factory_s2()
642
+
643
+ # Shared components — only need one copy since both ledgers use the same VAE/encoder paths.
644
  _video_encoder = _orig_video_encoder_factory()
645
  _video_decoder = _orig_video_decoder_factory()
646
  _audio_encoder = _orig_audio_encoder_factory()
 
651
  _embeddings_processor = _orig_gemma_embeddings_factory()
652
 
653
  # Replace ledger methods with lightweight lambdas that return the cached instances.
654
+ # Both ledgers point to the same shared model instances (except transformer).
655
  # We keep the original factories above so we can call them later to rebuild components.
656
+ ledger_s1.transformer = lambda: _transformer_s1
657
+ ledger_s2.transformer = lambda: _transformer_s2
658
+ ledger_s1.video_encoder = lambda: _video_encoder
659
+ ledger_s2.video_encoder = lambda: _video_encoder
660
+ ledger_s1.video_decoder = lambda: _video_decoder
661
+ ledger_s2.video_decoder = lambda: _video_decoder
662
+ ledger_s1.audio_encoder = lambda: _audio_encoder
663
+ ledger_s2.audio_encoder = lambda: _audio_encoder
664
+ ledger_s1.audio_decoder = lambda: _audio_decoder
665
+ ledger_s2.audio_decoder = lambda: _audio_decoder
666
+ ledger_s1.vocoder = lambda: _vocoder
667
+ ledger_s2.vocoder = lambda: _vocoder
668
+ ledger_s1.spatial_upsampler = lambda: _spatial_upsampler
669
+ ledger_s2.spatial_upsampler = lambda: _spatial_upsampler
670
+ ledger_s1.text_encoder = lambda: _text_encoder
671
+ ledger_s2.text_encoder = lambda: _text_encoder
672
+ ledger_s1.gemma_embeddings_processor = lambda: _embeddings_processor
673
+ ledger_s2.gemma_embeddings_processor = lambda: _embeddings_processor
674
 
675
  print("All models preloaded (including Gemma text encoder and audio encoder)!")
676
+ print(f" Stage 1 transformer: {_transformer_s1.__class__.__name__} (distilled LoRA @ 1.0 baked)")
677
+ print(f" Stage 2 transformer: {_transformer_s2.__class__.__name__} (distilled LoRA @ 0.5 baked)")
678
  # ---- REPLACE PRELOAD BLOCK END ----
679
 
680
  print("=" * 80)