Spaces:
Running on Zero
Running on Zero
Use precompiled AOTI transformer blocks (ZeroGPU speedup, STG-capable)
Browse filesFuses the LoRA + loads the precompiled AOTI transformer blocks at the **root module level** (per ZeroGPU docs), using the **STG-capable** Group C repo (`ltx-community/LTX-2.3-Transformer-GroupC-STG-sm120-cu130-rb3`): the perturbation path is compiled as always-on tensor math, and a small per-block wrapper feeds a no-op ones mask when STG is off so one graph serves both. Keeps STG on (default stg_scale). Public graph-only repo → no token. Validated end-to-end on a day-to-night duplicate (coherent output, STG passes run clean).
app.py
CHANGED
|
@@ -42,7 +42,21 @@ pipe.to("cuda")
|
|
| 42 |
pipe.vae.enable_tiling()
|
| 43 |
_lora_path = hf_hub_download(LORA_REPO, LORA_FILE, token=HF_TOKEN)
|
| 44 |
pipe.load_lora_weights(load_file(_lora_path), adapter_name="shave")
|
| 45 |
-
pipe.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
def _src_fps(path, default=FPS):
|
|
|
|
| 42 |
pipe.vae.enable_tiling()
|
| 43 |
_lora_path = hf_hub_download(LORA_REPO, LORA_FILE, token=HF_TOKEN)
|
| 44 |
pipe.load_lora_weights(load_file(_lora_path), adapter_name="shave")
|
| 45 |
+
pipe.fuse_lora(lora_scale=LORA_SCALE)
|
| 46 |
+
pipe.unload_lora_weights()
|
| 47 |
+
# AOTI (Group C / STG): load precompiled blocks at ROOT level. The graph always runs the
|
| 48 |
+
# perturbation lerp; the wrapper feeds a no-op ones mask when None (non-STG blocks / main
|
| 49 |
+
# pass) and forces all_perturbed=False. STG (block 28, perturbed pass) still gets the real mask.
|
| 50 |
+
spaces.aoti_load(module=pipe.transformer, repo_id="ltx-community/LTX-2.3-Transformer-GroupC-STG-sm120-cu130-rb3")
|
| 51 |
+
for _blk in pipe.transformer.transformer_blocks:
|
| 52 |
+
_compiled = _blk.forward
|
| 53 |
+
def _fwd(*a, _c=_compiled, **kw):
|
| 54 |
+
if kw.get("perturbation_mask", None) is None:
|
| 55 |
+
_hs = kw["hidden_states"]
|
| 56 |
+
kw["perturbation_mask"] = torch.ones((_hs.shape[0], 1, 1), device=_hs.device, dtype=_hs.dtype)
|
| 57 |
+
kw["all_perturbed"] = False
|
| 58 |
+
return _c(*a, **kw)
|
| 59 |
+
_blk.forward = _fwd
|
| 60 |
|
| 61 |
|
| 62 |
def _src_fps(path, default=FPS):
|