linoyts HF Staff commited on
Commit
dc2750b
·
verified ·
1 Parent(s): 4f0b478

Use precompiled AOTI transformer blocks (ZeroGPU speedup, STG-capable)

Browse files

Fuses the LoRA + loads the precompiled AOTI transformer blocks at the **root module level** (per ZeroGPU docs), using the **STG-capable** Group C repo (`ltx-community/LTX-2.3-Transformer-GroupC-STG-sm120-cu130-rb3`): the perturbation path is compiled as always-on tensor math, and a small per-block wrapper feeds a no-op ones mask when STG is off so one graph serves both. Keeps STG on (default stg_scale). Public graph-only repo → no token. Validated end-to-end on a day-to-night duplicate (coherent output, STG passes run clean).

Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -42,7 +42,21 @@ pipe.to("cuda")
42
  pipe.vae.enable_tiling()
43
  _lora_path = hf_hub_download(LORA_REPO, LORA_FILE, token=HF_TOKEN)
44
  pipe.load_lora_weights(load_file(_lora_path), adapter_name="shave")
45
- pipe.set_adapters("shave", LORA_SCALE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  def _src_fps(path, default=FPS):
 
42
  pipe.vae.enable_tiling()
43
  _lora_path = hf_hub_download(LORA_REPO, LORA_FILE, token=HF_TOKEN)
44
  pipe.load_lora_weights(load_file(_lora_path), adapter_name="shave")
45
+ pipe.fuse_lora(lora_scale=LORA_SCALE)
46
+ pipe.unload_lora_weights()
47
+ # AOTI (Group C / STG): load precompiled blocks at ROOT level. The graph always runs the
48
+ # perturbation lerp; the wrapper feeds a no-op ones mask when None (non-STG blocks / main
49
+ # pass) and forces all_perturbed=False. STG (block 28, perturbed pass) still gets the real mask.
50
+ spaces.aoti_load(module=pipe.transformer, repo_id="ltx-community/LTX-2.3-Transformer-GroupC-STG-sm120-cu130-rb3")
51
+ for _blk in pipe.transformer.transformer_blocks:
52
+ _compiled = _blk.forward
53
+ def _fwd(*a, _c=_compiled, **kw):
54
+ if kw.get("perturbation_mask", None) is None:
55
+ _hs = kw["hidden_states"]
56
+ kw["perturbation_mask"] = torch.ones((_hs.shape[0], 1, 1), device=_hs.device, dtype=_hs.dtype)
57
+ kw["all_perturbed"] = False
58
+ return _c(*a, **kw)
59
+ _blk.forward = _fwd
60
 
61
 
62
  def _src_fps(path, default=FPS):