Add AOTI for speedup

#3
by linoyts HF Staff - opened
Files changed (2) hide show
  1. app.py +34 -2
  2. requirements.txt +5 -4
app.py CHANGED
@@ -6,8 +6,7 @@ import sys
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
- # memory-efficient attention
10
- subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
12
  # --- clone + install the NATIVE LTX-2 codebase at the pinned commit the working ZeroGPU spaces use ---
13
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
@@ -212,6 +211,38 @@ if not SKIP_STAGE_2:
212
  _preload_pin(getattr(pipeline, "stage_2_model_ledger", None), "stage2")
213
  print("Pipeline ready.")
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  def _duration(*args, **kwargs):
217
  nf = next((a for a in args if isinstance(a, int) and a in FRAME_CHOICES), DEFAULT_FRAMES)
@@ -260,6 +291,7 @@ with gr.Blocks(title="LTX-2.3 Water Simulation") as demo:
260
  "Using [LTX 2.3 Distilled](https://huggingface.co/Lightricks/LTX-2.3) with the "
261
  "[Water Simulation IC-LoRA](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Water-Simulation)."
262
  )
 
263
  with gr.Row():
264
  with gr.Column():
265
  video_in = gr.Video(label="Dry input video")
 
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
+ # The runtime xformers install was removed: it would pull in torch 2.8 and break the AOTI .pt2 artifact. SDPA is used instead.
 
10
 
11
  # --- clone + install the NATIVE LTX-2 codebase at the pinned commit the working ZeroGPU spaces use ---
12
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
 
211
  _preload_pin(getattr(pipeline, "stage_2_model_ledger", None), "stage2")
212
  print("Pipeline ready.")
213
 
214
# ============================ AOTI (native bf16 transformer graph) ============================
# Repo id passed to spaces.aoti_load(); the AOTI_REPO environment variable
# overrides the built-in default.
AOTI_REPO = os.environ.get(
    "AOTI_REPO",
    "linoyts/LTX-2.3-Native-Transformer-GroupA-sm120-cu130-r20",
)

import types as _types
from dataclasses import replace as _dc_replace

from ltx_core.model.transformer.transformer_args import TransformerArgs as _TA

# Field names of the TransformerArgs dataclass, in the order the dataclass
# reports them. _flatten_ta() walks this list, so it fixes the order in which
# tensors are handed to each transformer block.
_TA_FIELDS = list(_TA.__dataclass_fields__)
220
def _flatten_ta(ta):
    """Collect the tensor-valued fields of ``ta`` into one flat list.

    Fields are visited in ``_TA_FIELDS`` order. A field holding a tensor
    contributes that tensor; a field holding a non-empty tuple made up
    entirely of tensors contributes each element in order. Every other
    field (non-tensor value, empty tuple, tuple with any non-tensor
    element, list, ...) is skipped.

    Args:
        ta: Object exposing every attribute named in ``_TA_FIELDS``.

    Returns:
        A new list of tensors in field order.
    """
    flat = []
    for name in _TA_FIELDS:
        value = getattr(ta, name)
        if torch.is_tensor(value):
            flat.append(value)
            continue
        # Only non-empty, all-tensor tuples are expanded.
        if not isinstance(value, tuple) or not value:
            continue
        if all(torch.is_tensor(item) for item in value):
            flat.extend(value)
    return flat
229
def _install_aoti():
    """Load the AOTI artifact and route the transformer block loop through it.

    Steps, in order:
      1. Look up the stage-1 velocity model on the global ``pipeline``.
      2. Call ``spaces.aoti_load`` on that module with ``AOTI_REPO``.
      3. Rebind ``_process_transformer_blocks`` on that one instance so each
         block is called with the flattened video arguments followed by the
         flattened audio arguments.
      4. Log that loading and patching completed.

    Any exception raised along the way propagates to the caller.
    """
    velocity_model = pipeline.stage_1_model_ledger.transformer().velocity_model
    spaces.aoti_load(module=velocity_model, repo_id=AOTI_REPO)

    def _run_blocks(self, video, audio, perturbations):
        # ``perturbations`` is accepted but never used by this replacement.
        for block in self.transformer_blocks:
            flat_inputs = _flatten_ta(video) + _flatten_ta(audio)
            outputs = block(*flat_inputs)
            # Output 0 becomes the new video ``x``; output 1 the new audio ``x``.
            video = _dc_replace(video, x=outputs[0])
            audio = _dc_replace(audio, x=outputs[1])
        return video, audio

    # Bind as a method on this instance only; the class itself is untouched.
    velocity_model._process_transformer_blocks = _types.MethodType(_run_blocks, velocity_model)
    print(f"[AOTI] loaded {AOTI_REPO} + patched block loop", flush=True)
239
# Log the torch and CUDA versions of the running environment before the
# AOTI artifact is loaded.
print(f"[AOTI] base torch={torch.__version__} cuda={torch.version.cuda}", flush=True)
try:
    _install_aoti()
    print("[AOTI] OK", flush=True)
except Exception as _e:
    # Best effort: print the traceback and keep going instead of aborting
    # startup. The log message labels this outcome "EAGER".
    import traceback

    traceback.print_exc()
    print(f"[AOTI] FAILED ({_e!r}) -> EAGER", flush=True)
244
+ # ==============================================================================================
245
+
246
 
247
  def _duration(*args, **kwargs):
248
  nf = next((a for a in args if isinstance(a, int) and a in FRAME_CHOICES), DEFAULT_FRAMES)
 
291
  "Using [LTX 2.3 Distilled](https://huggingface.co/Lightricks/LTX-2.3) with the "
292
  "[Water Simulation IC-LoRA](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Water-Simulation)."
293
  )
294
+ gr.Markdown("⚡ **Accelerated with [AOTI](https://huggingface.co/linoyts/LTX-2.3-Native-Transformer-GroupA-sm120-cu130-r20)** — precompiled transformer for faster inference.")
295
  with gr.Row():
296
  with gr.Column():
297
  video_in = gr.Video(label="Dry input video")
requirements.txt CHANGED
@@ -1,11 +1,12 @@
 
1
  transformers==4.57.6
2
  accelerate
3
- torch==2.8.0
4
- torchaudio==2.8.0
5
  einops
6
  scipy
7
  av
8
- scikit-image>=0.25.2
9
- flashpack==0.1.2
10
  imageio[ffmpeg]
11
  pillow
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu130
2
  transformers==4.57.6
3
  accelerate
4
+ torchaudio==2.11.0
 
5
  einops
6
  scipy
7
  av
 
 
8
  imageio[ffmpeg]
9
  pillow
10
+ spaces
11
+ sentencepiece
12
+ ftfy