linoyts HF Staff commited on
Commit
0b57385
·
1 Parent(s): 54c9edd

add aoti for speed up (#3)

Browse files

- AOTI: faster transformer (native bf16, ZeroGPU) (0eefdce628b0753e615259ad7ba6b6d13e0ebbfc)
- add aoti for speed up + mention in description (c3afb17beab38be7629916fb6e7bc434f082959d)

Files changed (2) hide show
  1. app.py +34 -2
  2. requirements.txt +5 -4
app.py CHANGED
@@ -6,8 +6,7 @@ import sys
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
- # memory-efficient attention
10
- subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
12
  # --- clone + install the NATIVE LTX-2 codebase at the pinned commit the working ZeroGPU spaces use ---
13
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
@@ -213,6 +212,38 @@ if not SKIP_STAGE_2:
213
  _preload_pin(getattr(pipeline, "stage_2_model_ledger", None), "stage2")
214
  print("Pipeline ready.")
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  def _duration(*args, **kwargs):
218
  nf = next((a for a in args if isinstance(a, int) and a in FRAME_CHOICES), DEFAULT_FRAMES)
@@ -259,6 +290,7 @@ with gr.Blocks(title="LTX-2.3 Deblur") as demo:
259
  "[LTX 2.3 Distilled](https://huggingface.co/Lightricks/LTX-2.3) with the "
260
  "[Deblur IC-LoRA](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Deblur)."
261
  )
 
262
  with gr.Row():
263
  with gr.Column():
264
  video_in = gr.Video(label="Out-of-focus video")
 
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
+ # (removed runtime xformers install -> would pull torch 2.8 and break the AOTI .pt2; SDPA used)
 
10
 
11
  # --- clone + install the NATIVE LTX-2 codebase at the pinned commit the working ZeroGPU spaces use ---
12
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
 
212
  _preload_pin(getattr(pipeline, "stage_2_model_ledger", None), "stage2")
213
  print("Pipeline ready.")
214
 
215
+ # ============================ AOTI (native bf16 transformer graph) ============================
216
+ AOTI_REPO = os.environ.get("AOTI_REPO", "linoyts/LTX-2.3-Native-Transformer-GroupA-sm120-cu130-r20")
217
+ import types as _types
218
+ from dataclasses import replace as _dc_replace
219
+ from ltx_core.model.transformer.transformer_args import TransformerArgs as _TA
220
+ _TA_FIELDS = list(_TA.__dataclass_fields__.keys())
221
+ def _flatten_ta(ta):
222
+ out = []
223
+ for f in _TA_FIELDS:
224
+ v = getattr(ta, f)
225
+ if torch.is_tensor(v):
226
+ out.append(v)
227
+ elif isinstance(v, tuple) and len(v) > 0 and all(torch.is_tensor(x) for x in v):
228
+ out.extend(v)
229
+ return out
230
+ def _install_aoti():
231
+ velocity = pipeline.stage_1_model_ledger.transformer().velocity_model
232
+ spaces.aoti_load(module=velocity, repo_id=AOTI_REPO)
233
+ def _proc(self, video, audio, perturbations):
234
+ for blk in self.transformer_blocks:
235
+ o = blk(*(_flatten_ta(video) + _flatten_ta(audio)))
236
+ video = _dc_replace(video, x=o[0]); audio = _dc_replace(audio, x=o[1])
237
+ return video, audio
238
+ velocity._process_transformer_blocks = _types.MethodType(_proc, velocity)
239
+ print(f"[AOTI] loaded {AOTI_REPO} + patched block loop", flush=True)
240
+ print(f"[AOTI] base torch={torch.__version__} cuda={torch.version.cuda}", flush=True)
241
+ try:
242
+ _install_aoti(); print("[AOTI] OK", flush=True)
243
+ except Exception as _e:
244
+ import traceback; traceback.print_exc(); print(f"[AOTI] FAILED ({_e!r}) -> EAGER", flush=True)
245
+ # ==============================================================================================
246
+
247
 
248
  def _duration(*args, **kwargs):
249
  nf = next((a for a in args if isinstance(a, int) and a in FRAME_CHOICES), DEFAULT_FRAMES)
 
290
  "[LTX 2.3 Distilled](https://huggingface.co/Lightricks/LTX-2.3) with the "
291
  "[Deblur IC-LoRA](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Deblur)."
292
  )
293
+ gr.Markdown("⚡ **Accelerated with [AOTI](https://huggingface.co/linoyts/LTX-2.3-Native-Transformer-GroupA-sm120-cu130-r20)** — precompiled transformer for faster inference.")
294
  with gr.Row():
295
  with gr.Column():
296
  video_in = gr.Video(label="Out-of-focus video")
requirements.txt CHANGED
@@ -1,11 +1,12 @@
 
1
  transformers==4.57.6
2
  accelerate
3
- torch==2.8.0
4
- torchaudio==2.8.0
5
  einops
6
  scipy
7
  av
8
- scikit-image>=0.25.2
9
- flashpack==0.1.2
10
  imageio[ffmpeg]
11
  pillow
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu130
2
  transformers==4.57.6
3
  accelerate
4
+ torchaudio==2.11.0
 
5
  einops
6
  scipy
7
  av
 
 
8
  imageio[ffmpeg]
9
  pillow
10
+ spaces
11
+ sentencepiece
12
+ ftfy