dagloop5 committed on
Commit
db53cec
·
verified ·
1 Parent(s): 83f442f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -27
app.py CHANGED
@@ -6,11 +6,6 @@ import sys
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
- subprocess.run(
10
- [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
11
- check=False,
12
- )
13
-
14
  # Clone LTX-2 repo and install packages
15
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
16
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
@@ -83,34 +78,22 @@ from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
83
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
84
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
85
 
86
- from flash_attn import flash_attn_func
87
 
 
88
  try:
89
- from ltx_core.model.transformer import attention as _attn_mod
90
- _attn_mod.memory_efficient_attention = flash_attn_func
91
- except Exception as e:
92
- print(f"[ATTN] FA3 import failed: {e}, using xformers")
93
  from xformers.ops import memory_efficient_attention as _mea
94
- _attn_mod.memory_efficient_attention = _mea
95
 
96
- # Attention kernels: try FlashAttention-3 first, fall back to xformers
97
- from ltx_core.model.transformer import attention as _attn_mod
98
- print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
 
99
 
100
- try:
101
- from kernels import get_kernel
102
- fa3_kernel = get_kernel("kernels-community/vllm-flash-attn3")
103
- _attn_mod.memory_efficient_attention = fa3_kernel
104
- print(f"[ATTN] FA3 kernel applied: {_attn_mod.memory_efficient_attention}")
105
  except Exception as e:
106
- print(f"[ATTN] FA3 not available ({type(e).__name__}: {e}), trying xformers...")
107
-
108
- try:
109
- from xformers.ops import memory_efficient_attention as _mea
110
- _attn_mod.memory_efficient_attention = _mea
111
- print(f"[ATTN] xformers applied: {_attn_mod.memory_efficient_attention}")
112
- except Exception as e2:
113
- print(f"[ATTN] xformers also FAILED: {type(e2).__name__}: {e2}")
114
 
115
  logging.getLogger().setLevel(logging.INFO)
116
 
 
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
 
 
 
 
 
9
  # Clone LTX-2 repo and install packages
10
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
11
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
 
78
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
79
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
80
 
81
+ from ltx_core.model.transformer import attention as _attn_mod
82
 
83
+ print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
84
  try:
 
 
 
 
85
  from xformers.ops import memory_efficient_attention as _mea
86
+ from xformers.ops.fmha import cutlass
87
 
88
+ def _cutlass_memory_efficient_attention(*args, **kwargs):
89
+ # Force CUTLASS and avoid FlashAttention paths that are crashing.
90
+ kwargs["op"] = (cutlass.FwOp, cutlass.BwOp)
91
+ return _mea(*args, **kwargs)
92
 
93
+ _attn_mod.memory_efficient_attention = _cutlass_memory_efficient_attention
94
+ print(f"[ATTN] After patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
 
 
 
95
  except Exception as e:
96
+ print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
 
 
 
 
 
 
 
97
 
98
  logging.getLogger().setLevel(logging.INFO)
99