Add AOTI for speed-up

#3
by linoyts HF Staff - opened
Files changed (2) hide show
  1. app.py +34 -2
  2. requirements.txt +5 -4
app.py CHANGED
@@ -6,8 +6,7 @@ import sys
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
- # memory-efficient attention
10
- subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
12
  # --- clone + install the NATIVE LTX-2 codebase at the pinned commit the working ZeroGPU spaces use ---
13
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
@@ -211,6 +210,38 @@ if not SKIP_STAGE_2:
211
  _preload_pin(getattr(pipeline, "stage_2_model_ledger", None), "stage2")
212
  print("Pipeline ready.")
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  def _duration(*args, **kwargs):
216
  nf = next((a for a in args if isinstance(a, int) and a in FRAME_CHOICES), DEFAULT_FRAMES)
@@ -256,6 +287,7 @@ with gr.Blocks(title="LTX-2.3 Beard Removal") as demo:
256
  "motion. Using [LTX 2.3 Dev](https://huggingface.co/Lightricks/LTX-2.3) with the "
257
  "[Beard-Removal IC-LoRA](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Instant-Shave)."
258
  )
 
259
  with gr.Row():
260
  with gr.Column():
261
  video_in = gr.Video(label="Video of a bearded subject")
 
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
+ # Runtime xformers install removed: it would pull in torch 2.8 and break the AOTI .pt2 artifact; SDPA is used instead.
 
10
 
11
  # --- clone + install the NATIVE LTX-2 codebase at the pinned commit the working ZeroGPU spaces use ---
12
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
 
210
  _preload_pin(getattr(pipeline, "stage_2_model_ledger", None), "stage2")
211
  print("Pipeline ready.")
212
 
213
+ # ============================ AOTI (native bf16 transformer graph) ============================
214
+ AOTI_REPO = os.environ.get("AOTI_REPO", "linoyts/LTX-2.3-Native-Transformer-GroupA-sm120-cu130-r20")
215
+ import types as _types
216
+ from dataclasses import replace as _dc_replace
217
+ from ltx_core.model.transformer.transformer_args import TransformerArgs as _TA
218
+ _TA_FIELDS = list(_TA.__dataclass_fields__.keys())
219
def _flatten_ta(ta):
    """Return the tensors held by ``ta`` as one flat list.

    Fields are visited in the order given by ``_TA_FIELDS``. A field that is
    itself a tensor contributes that tensor; a field that is a non-empty tuple
    made up only of tensors contributes each element in order. Every other
    field (``None``, scalars, empty or mixed tuples, ...) is skipped.
    """
    tensors = []
    for name in _TA_FIELDS:
        value = getattr(ta, name)
        if torch.is_tensor(value):
            tensors.append(value)
            continue
        # Only non-empty, all-tensor tuples are expanded; anything else is dropped.
        if not isinstance(value, tuple) or len(value) == 0:
            continue
        if all(torch.is_tensor(item) for item in value):
            tensors.extend(value)
    return tensors
228
def _install_aoti():
    """Load the AOTI artifact into the stage-1 velocity model and patch its block loop.

    ``spaces.aoti_load`` is pointed at ``AOTI_REPO`` for the velocity model, then
    the model's ``_process_transformer_blocks`` is replaced with a loop that
    calls every block with flat positional tensors instead of argument objects.
    """
    velocity = pipeline.stage_1_model_ledger.transformer().velocity_model
    spaces.aoti_load(module=velocity, repo_id=AOTI_REPO)

    def _proc(self, video, audio, perturbations):
        # NOTE(review): `perturbations` is accepted but never forwarded to the
        # blocks -- confirm the compiled graph is only used where that is safe.
        for block in self.transformer_blocks:
            flat_inputs = [*_flatten_ta(video), *_flatten_ta(audio)]
            outputs = block(*flat_inputs)
            # Output 0 becomes the new video `x`, output 1 the new audio `x`.
            video = _dc_replace(video, x=outputs[0])
            audio = _dc_replace(audio, x=outputs[1])
        return video, audio

    # Bind as a method so `self` inside `_proc` is the velocity model.
    velocity._process_transformer_blocks = _types.MethodType(_proc, velocity)
    print(f"[AOTI] loaded {AOTI_REPO} + patched block loop", flush=True)
238
+ print(f"[AOTI] base torch={torch.__version__} cuda={torch.version.cuda}", flush=True)
239
# Best-effort install: any failure is logged with its traceback and start-up
# continues (the log message marks this as the "EAGER" path).
try:
    _install_aoti()
    print("[AOTI] OK", flush=True)
except Exception as _e:
    import traceback

    traceback.print_exc()
    print(f"[AOTI] FAILED ({_e!r}) -> EAGER", flush=True)
243
+ # ==============================================================================================
244
+
245
 
246
  def _duration(*args, **kwargs):
247
  nf = next((a for a in args if isinstance(a, int) and a in FRAME_CHOICES), DEFAULT_FRAMES)
 
287
  "motion. Using [LTX 2.3 Dev](https://huggingface.co/Lightricks/LTX-2.3) with the "
288
  "[Beard-Removal IC-LoRA](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Instant-Shave)."
289
  )
290
+ gr.Markdown("⚡ **Accelerated with [AOTI](https://huggingface.co/linoyts/LTX-2.3-Native-Transformer-GroupA-sm120-cu130-r20)** — precompiled transformer for faster inference.")
291
  with gr.Row():
292
  with gr.Column():
293
  video_in = gr.Video(label="Video of a bearded subject")
requirements.txt CHANGED
@@ -1,11 +1,12 @@
 
1
  transformers==4.57.6
2
  accelerate
3
- torch==2.8.0
4
- torchaudio==2.8.0
5
  einops
6
  scipy
7
  av
8
- scikit-image>=0.25.2
9
- flashpack==0.1.2
10
  imageio[ffmpeg]
11
  pillow
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu130
2
  transformers==4.57.6
3
  accelerate
4
+ torchaudio==2.11.0
 
5
  einops
6
  scipy
7
  av
 
 
8
  imageio[ffmpeg]
9
  pillow
10
+ spaces
11
+ sentencepiece
12
+ ftfy