# ltx-2.3-deblur / app.py
# Last commit 0b57385 by linoyts (HF Staff): "add aoti for speed up (#3)". File size: 16.3 kB.
import os
import subprocess
import sys
# ZeroGPU cannot run torch.compile / TorchDynamo, so both are switched off through
# environment variables. This must happen before torch is imported anywhere.
os.environ.update(
    TORCH_COMPILE_DISABLE="1",
    TORCHDYNAMO_DISABLE="1",
)
# NOTE: xformers is intentionally NOT installed at runtime: installing it would pull in
# torch 2.8 and break the AOTI .pt2 artifact. Attention goes through PyTorch SDPA instead.
# --- Clone + install the NATIVE LTX-2 codebase at the pinned commit that the working
# ZeroGPU Spaces use. ---
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
LTX_COMMIT = "ae855f8538843825f9015a419cf4ba5edaf5eec2"

if not os.path.exists(LTX_REPO_DIR):
    # First start only: fetch the repository, then pin it to LTX_COMMIT.
    for _git_cmd in (
        ["git", "clone", LTX_REPO_URL, LTX_REPO_DIR],
        ["git", "-C", LTX_REPO_DIR, "checkout", LTX_COMMIT],
    ):
        subprocess.run(_git_cmd, check=True)

# Editable, dependency-free (re)install of the two in-repo packages on every start.
_ltx_pkgs = ("ltx-core", "ltx-pipelines")
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps",
     *(arg for pkg in _ltx_pkgs for arg in ("-e", os.path.join(LTX_REPO_DIR, "packages", pkg)))],
    check=True,
)
# Also put each package's src/ directly on sys.path (presumably because the running
# interpreter will not pick up the freshly installed editable packages); ltx-core ends
# up first, ahead of ltx-pipelines.
for _pkg in reversed(_ltx_pkgs):
    sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", _pkg, "src"))
import logging
import random
import tempfile
import numpy as np
import imageio.v3 as iio
from PIL import Image, ImageOps
import torch
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
# Import LTX modules in the proven order — importing ltx_core.quantization/loader FIRST hits a
# circular import (fp8_cast <-> loader.fuse_loras). Importing the model modules first forces the
# correct init order (mirrors the working reference Space).
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as _vae_decode_video # noqa: F401
from ltx_core.model.upsampler import upsample_video as _upsample_video # noqa: F401
from ltx_core.model.audio_vae import encode_audio as _vae_encode_audio # noqa: F401
from ltx_core.quantization import QuantizationPolicy
from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
from ltx_pipelines.ic_lora import ICLoraPipeline
from ltx_pipelines.utils.media_io import encode_video
# --- ZeroGPU loader patch -------------------------------------------------------------
# The native loader opens safetensors directly on the CUDA device
# (safe_open(path, device="cuda")), doing the host->device copy in safetensors' own C++
# (cudaMemcpy) — bypassing torch.Tensor.to, the call ZeroGPU patches to virtualise and pack
# weights at module scope. Result: "No CUDA GPUs are available" at startup, nothing packs.
# Patch it to open on CPU then move via torch.Tensor.to (ZeroGPU-virtualisable).
import safetensors as _safetensors
import ltx_core.loader.sft_loader as _sft
from ltx_core.loader.primitives import StateDict as _StateDict
def _zerogpu_safe_load(self, path, sd_ops, device=None):
    """Replacement for ``SafetensorsStateDictLoader.load`` that never opens files on CUDA.

    Each safetensors shard is opened on the CPU, and every kept tensor is moved to
    ``device`` through ``torch.Tensor.to``, the call ZeroGPU intercepts. Safetensors is
    never allowed to copy straight to the GPU.

    Args:
        self: the loader instance (unused; present because this is installed as a method).
        path: a single shard path, or a list of shard paths.
        sd_ops: optional key-renaming/transform ops; ``None`` keeps every key unchanged.
            Keys for which ``sd_ops.apply_to_key`` returns ``None`` are dropped.
        device: target device for the tensors; CPU when not given.

    Returns:
        ``StateDict`` holding the tensors, the target device, their total byte size and
        the set of dtypes encountered.
    """
    target = device or torch.device("cpu")
    shards = path if isinstance(path, list) else [path]
    tensors, total_bytes, dtypes = {}, 0, set()
    for shard in shards:
        with _safetensors.safe_open(shard, framework="pt", device="cpu") as handle:
            for raw_key in handle.keys():
                mapped_key = raw_key if sd_ops is None else sd_ops.apply_to_key(raw_key)
                if mapped_key is None:
                    continue  # sd_ops filtered this tensor out
                # Host -> device copy via torch, so ZeroGPU can virtualise it.
                tensor = handle.get_tensor(raw_key).to(device=target)
                if sd_ops is None:
                    pairs = ((mapped_key, tensor),)
                else:
                    pairs = sd_ops.apply_to_key_value(mapped_key, tensor)
                for out_key, out_val in pairs:
                    total_bytes += out_val.nbytes
                    dtypes.add(out_val.dtype)
                    tensors[out_key] = out_val
    return _StateDict(sd=tensors, device=target, size=total_bytes, dtype=dtypes)


# Monkey-patch the native loader so every safetensors load takes the CPU path above.
_sft.SafetensorsStateDictLoader.load = _zerogpu_safe_load
print("[PATCH] safetensors loader -> CPU-open + torch.to (ZeroGPU-virtualisable)")
# --------------------------------------------------------------------------------------
# --- attention backend patch (FA3 crashes on Blackwell ZeroGPU; serve the xformers entry point with PyTorch SDPA) ---
import torch.nn.functional as F
from ltx_core.model.transformer import attention as _attn_mod
def _sdpa_as_mea(query, key, value, attn_bias=None, scale=None, **kwargs):
q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
return F.scaled_dot_product_attention(q, k, v, scale=scale).transpose(1, 2)
# IMPORTANT (ZeroGPU): CUDA must not be queried at module scope, so there is no
# capability check here. SDPA runs on every GPU (Blackwell ZeroGPU included, where FA3
# crashes), which makes it safe to install the shim unconditionally.
setattr(_attn_mod, "memory_efficient_attention", _sdpa_as_mea)
print("[ATTN] SDPA (patched at module scope, no CUDA query)")
logging.getLogger().setLevel(logging.INFO)

# =========================== PER-LORA CONFIG (deblur) ===========================
# NOTE(review): TITLE is not used by the UI, which hard-codes its own title.
TITLE = "LTX-2.3 Deblur (native LTX-2)"
LORA_REPO = "Lightricks/LTX-2.3-22b-IC-LoRA-Deblur"
LORA_FILE = "ltx-2.3-22b-ic-lora-deblur-0.9.safetensors"
LORA_SCALE = 1.0  # strength the IC-LoRA is applied with
SKIP_STAGE_2 = True  # True: skip the pipeline's second stage (single-stage generation)
GRAYSCALE_REF = False  # True: convert the reference clip to grayscale before use
# Resolution presets: label -> (width, height), stored landscape; portrait uploads get
# the pair swapped. Kept identical to the table the UI section rebinds before building
# the app, so both definitions agree.
RES_PRESETS = {
    "960Γ—544 (recommended)": (960, 544),
    "1216Γ—704 (high)": (1216, 704),
    "768Γ—448 (fast)": (768, 448),
}
DEFAULT_PRESET = "960Γ—544 (recommended)"
FRAME_CHOICES = [49, 73, 97, 121]  # selectable clip lengths, in frames
DEFAULT_FRAMES = 121
def build_prompt(p):
    """Wrap the user's scene description in the fixed Deblur IC-LoRA prompt template.

    Args:
        p: free-text description of the sharp result. Surrounding whitespace and any
            trailing periods are removed so the template's own ". " separator is not
            doubled ("a cat." no longer yields "DEBLUR a cat.. "). ``None`` counts as
            an empty description, which produces "DEBLUR. " rather than "DEBLUR . ".

    Returns:
        The full prompt: the reference/edited framing sentences, the ``DEBLUR`` clause
        and an identity-preservation sentence.
    """
    desc = (p or "").strip().rstrip(".").rstrip()
    deblur_clause = f"DEBLUR {desc}. " if desc else "DEBLUR. "
    return (
        "Reference shows the same scene, heavily out of focus with soft defocused blur and no fine detail. "
        "Edited shows the same scene in sharp focus with crisp detail and clean edges. "
        f"{deblur_clause}"
        "Subject identity, framing and background geometry are identical to the reference; only focus and sharpness differ."
    )
# Example rows in deblur()'s argument order:
# [clip, prompt, resolution preset, frames, seed, randomize-seed].
# NOTE(review): not referenced elsewhere in this file; the UI's gr.Examples lists its
# own rows.
_EXAMPLE_SETTINGS = ("960Γ—544 (recommended)", 121, 42, False)
EXAMPLES = [
    [clip, description, *_EXAMPLE_SETTINGS]
    for clip, description in (
        ("examples/landscape_blur.mp4",
         "a pin-sharp misty green mountain landscape mirrored in calm still water, crisp pines and rippling reflections; gentle wind over water and distant birdsong"),
        ("examples/man_laughing_blur.mp4",
         "a crisp close-up of a man laughing warmly, sharp detail in his skin, hair and eyes; hearty laughter and a quiet room ambience"),
    )
]
# =================================================================================
FPS = 24.0  # frame rate of the prepared reference clip and of the output video
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider / random seeds
HF_TOKEN = os.getenv("HF_TOKEN")  # optional Hub token; None when the variable is unset
LTX_MODEL_REPO, GEMMA_REPO = (
    "Lightricks/LTX-2.3",
    "google/gemma-3-12b-it-qat-q4_0-unquantized",
)
def _src_fps(path, default=FPS):
    """Return the frame rate in a video's metadata, or ``default`` when it is unusable.

    Best-effort by design: ``default`` is returned in any of these cases:
    - reading the metadata raises;
    - the metadata has no ``fps`` entry;
    - the value is not a finite positive number (0, negative, NaN, inf).
    Before, NaN/inf/negative values slipped through and broke the frame-index arithmetic
    in the caller.
    """
    try:
        fps = float(iio.immeta(path, plugin="pyav").get("fps", default))
    except Exception:  # deliberate: unreadable metadata -> default
        return default
    # NaN fails both comparisons, so it is rejected together with 0, negatives and inf.
    return fps if 0.0 < fps < float("inf") else default
def _read_frames_at(path, indices):
    """Decode only the frames of ``path`` needed for ``indices``, one RGB array per index.

    Frames are streamed and decoding stops at the highest requested index, so a long
    upload is never held in memory as a whole. Indices past the end of the clip resolve
    to its last frame.

    Raises:
        gr.Error: if the video yields no frames at all.
    """
    wanted = set(indices)
    stop_at = max(wanted, default=-1)
    kept, last_idx, last_frame = {}, -1, None
    for frame_idx, frame in enumerate(iio.imiter(path, plugin="pyav")):
        last_idx, last_frame = frame_idx, frame
        if frame_idx in wanted:
            kept[frame_idx] = frame
        if frame_idx >= stop_at:
            break  # everything still needed has been decoded
    if last_frame is None:
        raise gr.Error("Could not read any frames from the uploaded video.")
    return [kept[i] if i <= last_idx else last_frame for i in indices]


def _prep_reference(path, width, height, num_frames):
    """Build the conditioning clip: resample to FPS, fit/crop to WxH, keep num_frames frames.

    For output frame i, the source frame nearest to time i / FPS is used. If the upload
    is shorter than needed, its last frame is repeated. Frames are converted to grayscale
    when GRAYSCALE_REF is set. The clip is written as H.264 mp4 to a new temporary file.
    Its path is returned, and the caller is responsible for deleting it.
    """
    src_fps = _src_fps(path)
    # Clamp at 0 so a bogus (negative) fps can never produce negative frame indices.
    indices = [max(int(round(i / FPS * src_fps)), 0) for i in range(num_frames)]
    out = []
    for frame in _read_frames_at(path, indices):
        im = ImageOps.fit(Image.fromarray(frame).convert("RGB"), (width, height), Image.LANCZOS)
        if GRAYSCALE_REF:
            im = im.convert("L").convert("RGB")
        out.append(np.array(im))
    # mkstemp creates the file atomically; tempfile.mktemp is deprecated and racy.
    fd, tmp = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    try:
        iio.imwrite(tmp, np.stack(out), fps=FPS, plugin="pyav", codec="libx264")
    except Exception:
        os.remove(tmp)  # don't leave an empty/partial temp file behind
        raise
    return tmp
def _pick_resolution(path, preset):
    """Return the (width, height) to generate for ``preset``, oriented like the upload.

    Presets are stored landscape. When the clip's first frame has more rows than
    columns, the pair is swapped to portrait. If that frame cannot be read or
    inspected, the preset is used unchanged.
    """
    width, height = RES_PRESETS[preset]
    try:
        first = iio.imread(path, plugin="pyav", index=0)
        is_portrait = first.shape[0] > first.shape[1]
    except Exception:
        is_portrait = False
    return (height, width) if is_portrait else (width, height)
# --- Load the native pipeline + IC-LoRA once at module scope (ZeroGPU packs weights here) ---
print("Downloading checkpoints…")
checkpoint_path, spatial_upsampler_path = (
    hf_hub_download(LTX_MODEL_REPO, filename, token=HF_TOKEN)
    for filename in (
        "ltx-2.3-22b-distilled-1.1.safetensors",
        "ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
    )
)
gemma_root = snapshot_download(GEMMA_REPO, token=HF_TOKEN)
lora_path = hf_hub_download(LORA_REPO, LORA_FILE, token=HF_TOKEN)
print("Building ICLoraPipeline…")
# The Deblur IC-LoRA at LORA_SCALE; LTXV_LORA_COMFY_RENAMING_MAP presumably maps
# ComfyUI-style LoRA keys onto the native names (TODO confirm).
_ic_lora = LoraPathStrengthAndSDOps(lora_path, LORA_SCALE, LTXV_LORA_COMFY_RENAMING_MAP)
pipeline = ICLoraPipeline(
    distilled_checkpoint_path=checkpoint_path,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    loras=[_ic_lora],
    # bf16, deliberately NOT fp8. The IC-LoRA is fused into the transformer at MODULE
    # SCOPE because the GPU worker cannot re-open the checkpoint file. fp8_cast()'s
    # fusion runs a custom CUDA kernel that ZeroGPU cannot virtualise; the bf16 fuse
    # rule is plain torch and therefore virtualisable.
    quantization=None,
)
def _preload_pin(ledger, tag):
    """Eagerly build each known component of a model ledger and pin the result.

    Each component accessor that exists on ``ledger`` and is callable is called once.
    It is then replaced on the instance by a zero-argument lambda that returns that same
    object, so later calls reuse it instead of invoking the original accessor again.
    Failures are printed and skipped (best-effort). ``ledger=None`` is a no-op.
    """
    if ledger is None:
        return
    components = (
        "transformer", "video_encoder", "video_decoder", "audio_encoder",
        "audio_decoder", "vocoder", "spatial_upsampler", "text_encoder",
        "gemma_embeddings_processor",
    )
    for component in components:
        accessor = getattr(ledger, component, None)
        if not callable(accessor):
            continue
        try:
            built = accessor()
            # Bind `built` as a default argument so each lambda keeps its own object.
            setattr(ledger, component, lambda o=built: o)
            print(f"[preload {tag}] {component} βœ“")
        except Exception as err:
            print(f"[preload {tag}] {component} skipped: {err}")
# Always preload the stage-1 ledger. The stage-2 ledger is preloaded only when the
# two-stage path is used (SKIP_STAGE_2 is False); pinning both eagerly would
# materialise TWO ~46GB transformers — too big for the ZeroGPU pack.
_ledgers_to_pin = [("stage_1_model_ledger", "stage1")]
if not SKIP_STAGE_2:
    _ledgers_to_pin.append(("stage_2_model_ledger", "stage2"))
for _ledger_attr, _ledger_tag in _ledgers_to_pin:
    _preload_pin(getattr(pipeline, _ledger_attr, None), _ledger_tag)
print("Pipeline ready.")
# ============================ AOTI (native bf16 transformer graph) ============================
# Hub repo with the precompiled transformer artifacts; can be overridden via $AOTI_REPO.
AOTI_REPO = os.getenv("AOTI_REPO", "linoyts/LTX-2.3-Native-Transformer-GroupA-sm120-cu130-r20")
import types as _types
from dataclasses import replace as _dc_replace
from ltx_core.model.transformer.transformer_args import TransformerArgs as _TA
# TransformerArgs field names in declaration order. This order fixes how args are
# flattened into positional tensors for the compiled blocks.
_TA_FIELDS = [*_TA.__dataclass_fields__]
def _flatten_ta(ta):
    """Flatten a TransformerArgs into the positional tensor list fed to the AOTI blocks.

    Fields are visited in _TA_FIELDS order:
    - a tensor field contributes itself;
    - a non-empty tuple made only of tensors contributes each element in order;
    - every other field (None, scalars, mixed tuples, ...) is skipped.
    """
    flat = []
    for field_name in _TA_FIELDS:
        val = getattr(ta, field_name)
        if torch.is_tensor(val):
            flat.append(val)
        elif isinstance(val, tuple) and val and all(map(torch.is_tensor, val)):
            flat.extend(val)
    return flat
def _install_aoti():
    """Load the AOTI artifacts into the stage-1 velocity model and reroute its block loop.

    Steps:
    1. ``spaces.aoti_load`` receives the stage-1 transformer's velocity model and
       AOTI_REPO (presumably swapping in the compiled block implementations).
    2. ``_process_transformer_blocks`` is replaced by a loop that calls each block with
       the flattened video + audio args and writes the two outputs back into ``x``.
    Exceptions propagate to the caller.
    """
    velocity = pipeline.stage_1_model_ledger.transformer().velocity_model
    spaces.aoti_load(module=velocity, repo_id=AOTI_REPO)

    def _run_blocks(self, video, audio, perturbations):
        # NOTE(review): `perturbations` is accepted for signature compatibility but is
        # not forwarded to the compiled blocks.
        for layer in self.transformer_blocks:
            outputs = layer(*_flatten_ta(video), *_flatten_ta(audio))
            video = _dc_replace(video, x=outputs[0])
            audio = _dc_replace(audio, x=outputs[1])
        return video, audio

    velocity._process_transformer_blocks = _types.MethodType(_run_blocks, velocity)
    print(f"[AOTI] loaded {AOTI_REPO} + patched block loop", flush=True)
print(f"[AOTI] base torch={torch.__version__} cuda={torch.version.cuda}", flush=True)
# Best-effort: on any failure print the traceback and carry on with the eager path.
try:
    _install_aoti()
except Exception as _e:
    import traceback
    traceback.print_exc()
    print(f"[AOTI] FAILED ({_e!r}) -> EAGER", flush=True)
else:
    print("[AOTI] OK", flush=True)
# ==============================================================================================
def _duration(*args, **kwargs):
    """ZeroGPU time budget in seconds for one deblur() call: 60 s + 1.2 s per frame.

    Receives the same arguments as deblur(). ``num_frames`` is taken from the keyword
    argument when given, otherwise from position 3 (deblur's 4th parameter; keep this
    in sync with its signature). A value that cannot be converted to a positive int
    falls back to DEFAULT_FRAMES.
    """
    # Reading by name/position (rather than scanning for "the first int in
    # FRAME_CHOICES") also handles keyword calls and numeric strings from the API.
    raw = kwargs.get("num_frames", args[3] if len(args) > 3 else None)
    try:
        nf = int(raw)
    except (TypeError, ValueError):
        nf = DEFAULT_FRAMES
    if nf <= 0:
        nf = DEFAULT_FRAMES
    return int(60 + nf * 1.2)
@spaces.GPU(duration=_duration)
@torch.inference_mode()
def deblur(video, prompt, preset, num_frames, seed, randomize, progress=gr.Progress(track_tqdm=True)):
    """Sharpen an out-of-focus clip with the Deblur IC-LoRA pipeline.

    Args:
        video: path of the uploaded clip.
        prompt: description of the sharp result; must be non-blank.
        preset: key of RES_PRESETS (landscape size, swapped for portrait uploads).
        num_frames: number of frames to generate at FPS.
        seed: seed used when ``randomize`` is False.
        randomize: draw a fresh seed in [0, MAX_SEED] instead of using ``seed``.
        progress: Gradio progress tracker (mirrors tqdm bars inside the pipeline).

    Returns:
        (path of the encoded result mp4, seed actually used).

    Raises:
        gr.Error: if no video was uploaded or the prompt is blank/None.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    if not (prompt or "").strip():
        raise gr.Error("Describe the result (e.g. 'a brown rabbit on grey rocks, soft birdsong').")
    seed = random.randint(0, MAX_SEED) if randomize else int(seed)
    num_frames = int(num_frames)
    width, height = _pick_resolution(video, preset)
    ref_path = _prep_reference(video, width, height, num_frames)
    try:
        tiling = TilingConfig.default()
        # skip_stage_2 outputs at half the passed dims -> pass 2x so output matches the preset.
        gen_w, gen_h = (width * 2, height * 2) if SKIP_STAGE_2 else (width, height)
        video_out, audio_out = pipeline(
            prompt=build_prompt(prompt),
            seed=seed, height=gen_h, width=gen_w,
            num_frames=num_frames, frame_rate=FPS,
            images=[], video_conditioning=[(ref_path, 1.0)],
            skip_stage_2=SKIP_STAGE_2, tiling_config=tiling,
        )
        # mkstemp creates the file atomically; tempfile.mktemp is deprecated and racy.
        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        encode_video(video=video_out, fps=FPS, audio=audio_out, output_path=out_path,
                     video_chunks_number=get_video_chunks_number(num_frames, tiling))
    finally:
        # The reference clip is a per-request temp file. Remove it so repeated requests
        # don't pile up files on disk; a failed delete is not fatal.
        try:
            os.remove(ref_path)
        except OSError:
            pass
    return out_path, seed
# --- UI config (match the public Space exactly) ---
# Rebinds the presets/frames at module level. _pick_resolution() looks RES_PRESETS up at
# call time, so this table is the one used at runtime.
RES_PRESETS = {"960Γ—544 (recommended)": (960, 544), "1216Γ—704 (high)": (1216, 704), "768Γ—448 (fast)": (768, 448)}
FRAME_CHOICES = [49, 73, 97, 121]
with gr.Blocks(title="LTX-2.3 Deblur") as demo:
    gr.Markdown(
        "# πŸ”Ž LTX-2.3 Video Deblur\n"
        "Restore sharpness to out-of-focus / defocused footage while keeping subject, framing and geometry "
        "identity (spatial defocus, not motion blur). Using "
        "[LTX 2.3 Distilled](https://huggingface.co/Lightricks/LTX-2.3) with the "
        "[Deblur IC-LoRA](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Deblur)."
    )
    # Link the AOTI repo that is actually loaded (AOTI_REPO can be overridden by env var).
    gr.Markdown(f"⚑ **Accelerated with [AOTI](https://huggingface.co/{AOTI_REPO})** β€” precompiled transformer for faster inference.")
    with gr.Row():
        with gr.Column():
            video_in = gr.Video(label="Out-of-focus video")
            # NOTE(review): the label says "optional", but deblur() rejects a blank prompt.
            prompt = gr.Textbox(
                label="Prompt β€” describe the scene and any sounds (optional; the clip does most of the work)", lines=3,
                placeholder="a city street at dusk with passing cars and glowing neon signs; city ambience, passing traffic and distant chatter",
            )
            with gr.Accordion("Settings", open=False):
                preset = gr.Dropdown(list(RES_PRESETS), value=DEFAULT_PRESET, label="Resolution")
                num_frames = gr.Dropdown(FRAME_CHOICES, value=DEFAULT_FRAMES, label="Frames (24fps)")
                randomize = gr.Checkbox(True, label="Randomize seed")
                seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
            run = gr.Button("Deblur", variant="primary")
        with gr.Column():
            video_out = gr.Video(label="Sharpened result")
    run.click(deblur, inputs=[video_in, prompt, preset, num_frames, seed, randomize],
              outputs=[video_out, seed])
    gr.Examples(
        examples=[
            ['examples/man_laughing_blur.mp4', 'a pin-sharp close-up portrait of a middle-aged man laughing warmly, deep smile lines crinkling around his eyes, individual strands of hair and fine stubble crisply resolved, soft natural window light modeling the texture of his skin with bright catchlights in his eyes; hearty, genuine laughter rising and falling, with a quiet intimate room ambience', '960Γ—544 (recommended)', 121, 42, False],
            ['examples/slicing_veggie_blur.mp4', 'a razor-sharp close-up of hands slicing a fresh green zucchini into thin even rounds on a pale wooden cutting board β€” the glossy green skin, the pale seeded interior, beads of water on the blade and the fine grain of the wood all crisply resolved, a stainless-steel knife edge glinting under warm kitchen light; crisp rhythmic chopping against the board and a gentle kitchen ambience', '960Γ—544 (recommended)', 121, 42, False],
            ['examples/landscape_blur.mp4', 'a pin-sharp misty green mountain landscape mirrored in calm still water β€” individual pines on the slopes, drifting layers of fog and crisp rippling reflections all resolving into clean detail under soft cool morning light; a gentle wind moving over the water, distant birdsong and the faint lap of ripples', '960Γ—544 (recommended)', 121, 42, False],
        ],
        inputs=[video_in, prompt, preset, num_frames, seed, randomize],
        outputs=[video_out, seed], fn=deblur, cache_examples=True, cache_mode="lazy",
    )

if __name__ == "__main__":
    demo.launch(show_error=True)