File size: 15,925 Bytes
8efc0db
462ac67
 
8efc0db
462ac67
 
 
8efc0db
24fae4d
462ac67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8efc0db
 
 
 
 
462ac67
 
8efc0db
462ac67
 
 
 
8efc0db
462ac67
 
 
 
 
 
 
 
 
 
 
 
8efc0db
462ac67
 
 
 
 
 
 
 
 
8efc0db
462ac67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5627944
8efc0db
 
462ac67
 
9369f9f
462ac67
8efc0db
462ac67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8efc0db
462ac67
 
 
 
 
8efc0db
 
 
 
 
 
 
 
 
462ac67
 
 
 
 
8efc0db
 
462ac67
 
 
 
 
 
 
 
 
8efc0db
 
462ac67
8efc0db
462ac67
 
 
 
 
 
8efc0db
 
 
462ac67
 
 
 
 
 
8efc0db
462ac67
 
 
 
 
 
 
 
 
 
 
8efc0db
462ac67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8efc0db
24fae4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8efc0db
 
462ac67
 
8efc0db
 
 
462ac67
 
8efc0db
462ac67
 
 
 
8efc0db
462ac67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f99044
8efc0db
 
462ac67
 
 
 
 
8efc0db
 
 
1d51f06
462ac67
 
8efc0db
24fae4d
8efc0db
 
 
 
 
 
 
 
9369f9f
8efc0db
 
 
 
 
 
 
462ac67
76e191b
8efc0db
 
 
462ac67
 
 
8efc0db
462ac67
76e191b
8efc0db
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import os
import subprocess
import sys

# ZeroGPU: torch.compile / dynamo unsupported — disable before any torch import.
# (These env vars are read by torch at import time, so they must be set before `import torch` below.)
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# (removed runtime xformers install -> would pull torch 2.8 and break the AOTI .pt2; SDPA used)

# --- clone + install the NATIVE LTX-2 codebase at the pinned commit the working ZeroGPU spaces use ---
# Runs at import time, before the ltx_core / ltx_pipelines imports further down in this file.
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
# Checkout lives in an "LTX-2" directory next to this file.
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
LTX_COMMIT = "ae855f8538843825f9015a419cf4ba5edaf5eec2"
# Clone + pin only when the directory is missing.
# NOTE(review): an existing LTX-2/ directory (e.g. left by an interrupted clone, or at another
# commit) is reused as-is without re-checking out LTX_COMMIT — confirm that is acceptable.
if not os.path.exists(LTX_REPO_DIR):
    subprocess.run(["git", "clone", LTX_REPO_URL, LTX_REPO_DIR], check=True)
    subprocess.run(["git", "-C", LTX_REPO_DIR, "checkout", LTX_COMMIT], check=True)
# Editable reinstall of both packages on every start. --no-deps stops pip from installing any
# dependencies, so the already-installed torch stack is left untouched.
subprocess.run([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps",
                "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-core"),
                "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines")], check=True)
# Also put both src trees on sys.path directly so the imports below resolve in this running
# interpreter (presumably because the fresh editable install is not picked up by the already-
# running process — TODO confirm). ltx-core is inserted last, so it ends up first on sys.path.
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))

import logging
import random
import tempfile

import numpy as np
import imageio.v3 as iio
from PIL import Image, ImageOps

import torch
# Belt-and-braces alongside the TORCHDYNAMO_DISABLE env var: make dynamo swallow its own errors
# (falling back to eager) and switch it off entirely.
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download

# Import LTX modules in the proven order — importing ltx_core.quantization/loader FIRST hits a
# circular import (fp8_cast <-> loader.fuse_loras). Importing the model modules first forces the
# correct init order (mirrors the working reference Space).
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as _vae_decode_video  # noqa: F401
from ltx_core.model.upsampler import upsample_video as _upsample_video  # noqa: F401
from ltx_core.model.audio_vae import encode_audio as _vae_encode_audio  # noqa: F401
from ltx_core.quantization import QuantizationPolicy
from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
from ltx_pipelines.ic_lora import ICLoraPipeline
from ltx_pipelines.utils.media_io import encode_video

# --- ZeroGPU loader patch -------------------------------------------------------------
# The native loader opens safetensors directly on the CUDA device
# (safe_open(path, device="cuda")), doing the host->device copy in safetensors' own C++
# (cudaMemcpy) — bypassing torch.Tensor.to, the call ZeroGPU patches to virtualise + pack
# weights at module scope. Result: "No CUDA GPUs are available" at startup, nothing packs.
# Patch it to open on CPU then move via torch.Tensor.to (ZeroGPU-virtualisable).
import safetensors as _safetensors
import ltx_core.loader.sft_loader as _sft
from ltx_core.loader.primitives import StateDict as _StateDict

def _zerogpu_safe_load(self, path, sd_ops, device=None):
    """Drop-in ``SafetensorsStateDictLoader.load`` that never opens a file on CUDA.

    Every shard is opened with ``device="cpu"``; each kept tensor is then moved with
    ``Tensor.to(device=...)`` so the host->device copy goes through torch.

    Args:
        self: the loader instance (unused; present because this replaces a method).
        path: a single safetensors path or a list of shard paths.
        sd_ops: optional key/value transform. ``apply_to_key`` returning ``None`` drops a
            tensor. ``apply_to_key_value`` may turn one entry into several key/value pairs.
        device: target device for the tensors; falls back to CPU.

    Returns:
        A ``StateDict`` holding the tensors, the target device, their total byte size and
        the set of dtypes seen.
    """
    target = device or torch.device("cpu")
    tensors, total_bytes, dtypes = {}, 0, set()
    shards = path if isinstance(path, list) else [path]
    for shard in shards:
        with _safetensors.safe_open(shard, framework="pt", device="cpu") as handle:
            for raw_key in handle.keys():
                key = raw_key if sd_ops is None else sd_ops.apply_to_key(raw_key)
                if key is None:
                    continue
                # torch-side copy: this is the call ZeroGPU can virtualise.
                tensor = handle.get_tensor(raw_key).to(device=target)
                pairs = ((key, tensor),) if sd_ops is None else sd_ops.apply_to_key_value(key, tensor)
                for out_key, out_val in pairs:
                    total_bytes += out_val.nbytes
                    dtypes.add(out_val.dtype)
                    tensors[out_key] = out_val
    return _StateDict(sd=tensors, device=target, size=total_bytes, dtype=dtypes)


# Replace the method on the class, so every loader instance uses the CPU-open path.
_sft.SafetensorsStateDictLoader.load = _zerogpu_safe_load
print("[PATCH] safetensors loader -> CPU-open + torch.to (ZeroGPU-virtualisable)")
# --------------------------------------------------------------------------------------

# --- attention backend patch (FA3 crashes on Blackwell ZeroGPU; route xformers-style attention calls through SDPA) ---
import torch.nn.functional as F
from ltx_core.model.transformer import attention as _attn_mod

def _sdpa_as_mea(query, key, value, attn_bias=None, scale=None, **kwargs):
    q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    return F.scaled_dot_product_attention(q, k, v, scale=scale).transpose(1, 2)

# IMPORTANT (ZeroGPU): never query CUDA at module scope. SDPA works on every GPU (incl.
# Blackwell ZeroGPU, where FA3 crashes), so patch it unconditionally.
# This rebinds the module-level name inside ltx_core's attention module. It only affects code
# that looks `memory_efficient_attention` up at call time, not references captured earlier.
_attn_mod.memory_efficient_attention = _sdpa_as_mea
print("[ATTN] SDPA (patched at module scope, no CUDA query)")

# Raise the root logger from its default WARNING to INFO.
logging.getLogger().setLevel(logging.INFO)

# =========================== PER-LORA CONFIG (beard removal / Instant Shave) ===========================
# NOTE(review): TITLE appears unused in this file (gr.Blocks is given its own title literal).
TITLE = "LTX-2.3 Beard Removal (native LTX-2)"
# IC-LoRA weights (Hub repo + file), fused into the transformer at strength LORA_SCALE
# when the pipeline is built.
LORA_REPO = "Lightricks/LTX-2.3-22b-IC-LoRA-Instant-Shave"
LORA_FILE = "ltx-2.3-22b-ic-lora-instant-shave-0.9.safetensors"
LORA_SCALE = 1.0
# True -> single-stage generation. shave() doubles the requested dims to compensate,
# and only the stage-1 ledger is preloaded.
SKIP_STAGE_2 = True
# True -> _prep_reference turns reference frames grayscale (kept 3-channel RGB).
GRAYSCALE_REF = False
# Landscape (width, height) presets; _pick_resolution swaps them for portrait input.
RES_PRESETS = {"960×544 (recommended)": (960, 544), "768×448 (fast)": (768, 448)}
DEFAULT_PRESET = "960×544 (recommended)"
# Selectable frame counts, all of the form 8k+1 (presumably required by the video VAE's
# temporal compression — confirm).
FRAME_CHOICES = [33, 49, 73, 97, 121]
# Default frame count; also _duration's fallback when no valid frame count is found.
DEFAULT_FRAMES = 49

def build_prompt(p):
    """Wrap the user's description in the Instant-Shave prompt template.

    The result is the ``REMOVEBEARD`` prefix (presumably the LoRA's trigger word), the
    whitespace-stripped user text, and a fixed suffix. The suffix asks for a clean-shaven
    face with identity, motion, lighting and scene left unchanged.
    """
    subject = p.strip()
    head = "REMOVEBEARD "
    tail = (
        ", completely smooth and clean-shaven face, bare skin, "
        "no beard, no stubble, no facial hair; identity, expression, motion, lighting and scene unchanged."
    )
    return head + subject + tail

# Example rows in shave()'s input order: [video, prompt, preset, num_frames, seed, randomize].
# NOTE(review): EXAMPLES appears unused — the gr.Examples in the UI builds its own inline list
# with longer prompts. Delete this or wire it up.
EXAMPLES = [
    ["examples/beard_1.mp4",
     "the same man with a completely smooth clean-shaven face, no beard or stubble, bare skin, relaxed expression in soft indoor light; a quiet room ambience",
     "960×544 (recommended)", 49, 42, False],
    ["examples/beard_2.mp4",
     "the same man in a hooded jacket with a completely smooth clean-shaven face, no beard or stubble, outdoors at night with city lights behind him; cool night ambience",
     "960×544 (recommended)", 49, 42, False],
]
# =================================================================================

FPS = 25.0  # output frame rate; reference clips are resampled to this rate as well
MAX_SEED = np.iinfo(np.int32).max  # upper bound for random seeds and the seed slider
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the env var is unset
LTX_MODEL_REPO = "Lightricks/LTX-2.3"  # source of the distilled checkpoint and the x2 spatial upscaler
GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"  # snapshot passed as gemma_root (presumably the text encoder)


def _src_fps(path, default=FPS):
    """Return the frame rate of the video at ``path``, or ``default`` if it is unusable.

    Reads the ``fps`` entry from imageio's pyav metadata. Falls back to ``default`` when
    any of these hold:
      * the metadata cannot be read, or the key is missing or ``None``;
      * the value is not a finite, strictly positive number.

    A NaN/inf rate would crash the ``round``/``int`` frame-index arithmetic in
    ``_prep_reference``. A negative rate would silently index frames from the end of the clip.

    Args:
        path: video file path.
        default: rate returned when the metadata is missing or invalid.

    Returns:
        A finite positive frames-per-second value.
    """
    import math

    try:
        fps = float(iio.immeta(path, plugin="pyav").get("fps", default))
    except Exception:  # best-effort: unreadable/odd metadata -> assume the default rate
        return default
    return fps if math.isfinite(fps) and fps > 0 else default


def _prep_reference(path, width, height, num_frames):
    """Build the reference clip that conditions the IC-LoRA pipeline.

    Steps:
      * Resample the video at ``path`` to ``FPS`` (25 fps, not 24) by nearest-frame lookup.
      * Aspect-fit and center-crop each frame to ``width`` x ``height`` with LANCZOS.
      * Optionally convert to grayscale (``GRAYSCALE_REF``).
      * Write exactly ``num_frames`` frames to a new temporary H.264 mp4.

    If the source is shorter than ``num_frames / FPS`` seconds, its last frame is repeated.

    Args:
        path: input video path.
        width, height: output frame size in pixels.
        num_frames: number of frames to write.

    Returns:
        Path of the new temporary mp4. The caller owns the file and should delete it.

    Raises:
        gr.Error: if the source video decodes to zero frames.
    """
    vid = iio.imread(path, plugin="pyav")
    n = len(vid)
    if n == 0:
        # Otherwise the clamp below gives index -1 and vid[-1] raises an opaque IndexError.
        raise gr.Error("The uploaded video contains no decodable frames.")
    src_fps = _src_fps(path)
    out = []
    for i in range(num_frames):
        # Output frame i sits at t = i / FPS seconds -> nearest source frame, clamped to the last one.
        idx = min(int(round(i / FPS * src_fps)), n - 1)
        im = Image.fromarray(vid[idx]).convert("RGB")
        im = ImageOps.fit(im, (width, height), Image.LANCZOS)
        if GRAYSCALE_REF:
            # Gray, but kept 3-channel so downstream still receives RGB frames.
            im = im.convert("L").convert("RGB")
        out.append(np.array(im))
    # mkstemp creates the file atomically. tempfile.mktemp only picks a name, is deprecated,
    # and is race-prone. Close the OS handle so imageio can open the path itself.
    fd, tmp = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    iio.imwrite(tmp, np.stack(out), fps=FPS, plugin="pyav", codec="libx264")
    return tmp


def _pick_resolution(path, preset):
    """Return the ``(width, height)`` to generate at for ``preset``.

    Presets are stored landscape. When the first decoded frame of ``path`` is taller than
    it is wide, the two values are swapped. If the frame cannot be read or inspected, the
    preset is used unchanged.
    """
    width, height = RES_PRESETS[preset]
    try:
        first_frame = iio.imread(path, plugin="pyav", index=0)
        is_portrait = first_frame.shape[0] > first_frame.shape[1]
    except Exception:
        is_portrait = False
    if is_portrait:
        return height, width
    return width, height


# --- Load native pipeline + IC-LoRA once at module scope (ZeroGPU packs weights here) ---
# All four downloads run at import time and go through the local Hugging Face cache.
print("Downloading checkpoints…")
checkpoint_path = hf_hub_download(LTX_MODEL_REPO, "ltx-2.3-22b-distilled-1.1.safetensors", token=HF_TOKEN)  # distilled 22B checkpoint
spatial_upsampler_path = hf_hub_download(LTX_MODEL_REPO, "ltx-2.3-spatial-upscaler-x2-1.1.safetensors", token=HF_TOKEN)  # x2 upscaler (presumably stage-2 only)
gemma_root = snapshot_download(GEMMA_REPO, token=HF_TOKEN)  # whole repo directory, passed as gemma_root
lora_path = hf_hub_download(LORA_REPO, LORA_FILE, token=HF_TOKEN)  # Instant-Shave IC-LoRA weights

print("Building ICLoraPipeline…")
pipeline = ICLoraPipeline(
    distilled_checkpoint_path=checkpoint_path,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    # LoRA at strength LORA_SCALE. LTXV_LORA_COMFY_RENAMING_MAP presumably maps ComfyUI-style
    # LoRA key names onto ltx_core's module names — confirm.
    loras=[LoraPathStrengthAndSDOps(lora_path, LORA_SCALE, LTXV_LORA_COMFY_RENAMING_MAP)],
    # bf16 (NOT fp8): the IC-LoRA is fused into the transformer at MODULE SCOPE (the GPU
    # worker can't re-open the checkpoint file). fp8_cast()'s fusion runs a custom CUDA kernel
    # that can't be ZeroGPU-virtualised; the bf16 fuse rule is pure torch -> virtualisable.
    quantization=None,
)


def _preload_pin(ledger, tag):
    """Build each known model once and freeze the ledger accessor to return it.

    For every accessor name below that exists on ``ledger`` and is callable:
      * call it once;
      * replace it on the ledger instance with a zero-argument lambda returning that same
        object, so later ``ledger.<name>()`` calls reuse it instead of rebuilding.

    Failures are printed and leave that accessor untouched. ``ledger=None`` is a no-op.

    Args:
        ledger: a model ledger exposing factory methods, or ``None``.
        tag: label used in the progress printout.
    """
    if ledger is None:
        return
    accessor_names = (
        "transformer",
        "video_encoder",
        "video_decoder",
        "audio_encoder",
        "audio_decoder",
        "vocoder",
        "spatial_upsampler",
        "text_encoder",
        "gemma_embeddings_processor",
    )
    for attr in accessor_names:
        factory = getattr(ledger, attr, None)
        if not callable(factory):
            continue
        try:
            built = factory()
            # Default-arg binding captures *this* object (avoids the late-binding closure trap).
            setattr(ledger, attr, (lambda o=built: o))
            print(f"[preload {tag}] {attr} ✓")
        except Exception as exc:
            print(f"[preload {tag}] {attr} skipped: {exc}")

# Preload stage 1 always; preload stage 2 only when two-stage is used (skip_stage_2=False).
# Eagerly pinning both ledgers materializes TWO ~46GB transformers — too big for the ZeroGPU pack.
# getattr(..., None): if the pipeline lacks the ledger attribute, _preload_pin(None, ...) is a no-op.
_preload_pin(getattr(pipeline, "stage_1_model_ledger", None), "stage1")
if not SKIP_STAGE_2:
    _preload_pin(getattr(pipeline, "stage_2_model_ledger", None), "stage2")
print("Pipeline ready.")

# ============================ AOTI (native bf16 transformer graph) ============================
# Hub repo with the ahead-of-time-compiled transformer; overridable via the AOTI_REPO env var.
# NOTE(review): the default name suggests an sm120 / CUDA 13.0 build — confirm it matches the runtime.
AOTI_REPO = os.environ.get("AOTI_REPO", "linoyts/LTX-2.3-Native-Transformer-GroupA-sm120-cu130-r20")
import types as _types
from dataclasses import replace as _dc_replace
from ltx_core.model.transformer.transformer_args import TransformerArgs as _TA
# TransformerArgs field names in declaration order. _flatten_ta emits tensors in this order, which
# presumably must match the input order the blocks were compiled with — confirm.
_TA_FIELDS = list(_TA.__dataclass_fields__.keys())
def _flatten_ta(ta):
    """Flatten a TransformerArgs into the positional tensor list the compiled blocks take.

    Walks ``_TA_FIELDS`` in order:
      * a tensor field contributes itself;
      * a non-empty tuple made only of tensors contributes its elements, in order;
      * every other field (``None``, scalars, empty or mixed tuples) is skipped.
    """
    flat = []
    for field_name in _TA_FIELDS:
        val = getattr(ta, field_name)
        if torch.is_tensor(val):
            flat.append(val)
            continue
        is_tensor_tuple = isinstance(val, tuple) and bool(val) and all(map(torch.is_tensor, val))
        if is_tensor_tuple:
            flat.extend(val)
    return flat
def _install_aoti():
    """Load the AOTI-compiled transformer blocks and reroute the block loop through them.

    Steps:
      * Fetch the stage-1 transformer's ``velocity_model``. The accessor may have been pinned
        by ``_preload_pin``, in which case this is the already-built instance.
      * Call ``spaces.aoti_load`` on it with ``AOTI_REPO``.
      * Replace that instance's ``_process_transformer_blocks`` with a loop that calls each
        block on the flattened video + audio TransformerArgs tensors.

    Any exception propagates. The module-level caller prints it and continues (the EAGER
    fallback message).
    """
    velocity = pipeline.stage_1_model_ledger.transformer().velocity_model
    spaces.aoti_load(module=velocity, repo_id=AOTI_REPO)
    def _proc(self, video, audio, perturbations):
        # NOTE(review): `perturbations` is accepted for signature compatibility but never used —
        # confirm no perturbation/guidance path is needed with the compiled blocks.
        for blk in self.transformer_blocks:
            # Positional inputs: all video tensors, then all audio tensors, each in _TA_FIELDS order.
            o = blk(*(_flatten_ta(video) + _flatten_ta(audio)))
            # o[0] / o[1] become the new `x` of the video / audio args (other fields unchanged).
            video = _dc_replace(video, x=o[0]); audio = _dc_replace(audio, x=o[1])
        return video, audio
    # Bound to this instance only; the class method is left as-is.
    velocity._process_transformer_blocks = _types.MethodType(_proc, velocity)
    print(f"[AOTI] loaded {AOTI_REPO} + patched block loop", flush=True)
print(f"[AOTI] base torch={torch.__version__} cuda={torch.version.cuda}", flush=True)
# Try to install the compiled blocks. Any exception is reported (with traceback) and startup
# continues; the log labels this the EAGER fallback.
try:
    _install_aoti()
    print("[AOTI] OK", flush=True)
except Exception as _e:
    import traceback

    traceback.print_exc()
    print(f"[AOTI] FAILED ({_e!r}) -> EAGER", flush=True)
# ==============================================================================================


def _duration(*args, **kwargs):
    """ZeroGPU time budget in seconds for a ``shave`` call: 60 s base + 1.2 s per frame.

    ``spaces.GPU`` is expected to call this with the same arguments as ``shave``. The frame
    count is read from the ``num_frames`` keyword, or else from positional index 3 (its
    place in ``shave``'s signature), then coerced with ``int()``. A missing, non-numeric,
    or out-of-``FRAME_CHOICES`` value falls back to ``DEFAULT_FRAMES``.

    Looking the value up by position means another int argument (e.g. a seed that happens
    to equal a frame choice) can no longer be mistaken for the frame count when
    ``num_frames`` arrives as a string or float.
    """
    raw = kwargs.get("num_frames", args[3] if len(args) > 3 else DEFAULT_FRAMES)
    try:
        nf = int(raw)
    except (TypeError, ValueError, OverflowError):
        nf = DEFAULT_FRAMES
    if nf not in FRAME_CHOICES:
        nf = DEFAULT_FRAMES
    return int(60 + nf * 1.2)


@spaces.GPU(duration=_duration)
@torch.inference_mode()
def shave(video, prompt, preset, num_frames, seed, randomize, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: remove facial hair from ``video`` with the Instant-Shave IC-LoRA.

    Runs under ``spaces.GPU`` with a time budget from ``_duration``. The clip is
    resampled and cropped into a temporary reference mp4 (``_prep_reference``). That
    reference conditions the native ``ICLoraPipeline`` at strength 1.0, and the
    generated video + audio are encoded to a new mp4.

    Args:
        video: path of the uploaded clip (``None`` when nothing was uploaded).
        prompt: description of the clean-shaven result; wrapped by ``build_prompt``.
        preset: key of ``RES_PRESETS``.
        num_frames: number of frames to generate.
        seed: seed used when ``randomize`` is false.
        randomize: draw a fresh seed in ``[0, MAX_SEED]`` instead of using ``seed``.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars in the UI.

    Returns:
        ``(out_path, seed)``: path of the generated mp4 and the seed actually used.

    Raises:
        gr.Error: if no video was uploaded or the prompt is blank.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    # `prompt or ""` also tolerates a None value from the Textbox.
    if not (prompt or "").strip():
        raise gr.Error(
            "Describe the clean-shaven result (e.g. 'the same man with a completely smooth "
            "clean-shaven face, warm indoor light; quiet room tone')."
        )
    seed = random.randint(0, MAX_SEED) if randomize else int(seed)
    num_frames = int(num_frames)
    width, height = _pick_resolution(video, preset)
    ref_path = _prep_reference(video, width, height, num_frames)
    try:
        tiling = TilingConfig.default()
        # skip_stage_2 outputs at half the passed dims -> pass 2x so output matches the preset.
        gen_w, gen_h = (width * 2, height * 2) if SKIP_STAGE_2 else (width, height)
        video_out, audio_out = pipeline(
            prompt=build_prompt(prompt),
            seed=seed, height=gen_h, width=gen_w,
            num_frames=num_frames, frame_rate=FPS,
            images=[], video_conditioning=[(ref_path, 1.0)],
            skip_stage_2=SKIP_STAGE_2, tiling_config=tiling,
        )
        # mkstemp creates the file atomically (tempfile.mktemp only picks a name and is
        # race-prone). The handle is closed so encode_video can open the path itself.
        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        encode_video(video=video_out, fps=FPS, audio=audio_out, output_path=out_path,
                     video_chunks_number=get_video_chunks_number(num_frames, tiling))
    finally:
        # ref_path is only an input to pipeline() above. It is removed after encoding, in case
        # video_out is consumed lazily by encode_video. Without this, every request leaked a
        # temp mp4. Best-effort: a failed delete must not mask the result or the original error.
        try:
            os.remove(ref_path)
        except OSError:
            pass
    return out_path, seed


# --- UI config (match the public Space exactly) ---
# RES_PRESETS and FRAME_CHOICES are defined once, in the per-LoRA config section near the top of
# the file. That single definition serves the UI below and the handlers (_pick_resolution,
# _duration); a second identical copy here could silently drift from it.


with gr.Blocks(title="LTX-2.3 Beard Removal") as demo:
    # Header: what the Space does and which model / LoRA it uses.
    gr.Markdown(
        "# 🪒 LTX-2.3 Beard Removal (Instant Shave)\n"
        "Remove beard, mustache and stubble from a person in a video while preserving identity, expression and "
        "motion. Using [LTX 2.3 Dev](https://huggingface.co/Lightricks/LTX-2.3) with the "
        "[Beard-Removal IC-LoRA](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Instant-Shave)."
    )
    # The link follows AOTI_REPO so an env-var override is reflected in the banner.
    gr.Markdown(
        f"⚡ **Accelerated with [AOTI](https://huggingface.co/{AOTI_REPO})** — "
        "precompiled transformer for faster inference."
    )
    with gr.Row():
        with gr.Column():
            video_in = gr.Video(label="Video of a bearded subject")
            # shave() rejects a blank prompt, so the label no longer advertises it as optional.
            prompt = gr.Textbox(
                label="Prompt — describe the clean-shaven subject/scene and any sounds", lines=3,
                placeholder="a man with a completely smooth clean-shaven face, warm indoor light, laughing; warm hearty laughter and quiet room tone",
            )
            with gr.Accordion("Settings", open=False):
                # Defaults come from the config section so UI and handlers cannot disagree.
                preset = gr.Dropdown(list(RES_PRESETS), value=DEFAULT_PRESET, label="Resolution")
                num_frames = gr.Dropdown(FRAME_CHOICES, value=DEFAULT_FRAMES, label=f"Frames ({FPS:g}fps)")
                randomize = gr.Checkbox(True, label="Randomize seed")
                seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
            run = gr.Button("Remove beard", variant="primary")
        with gr.Column():
            video_out = gr.Video(label="Clean-shaven result")

    # Outputs: the generated clip, plus the seed actually used (written back to the slider).
    run.click(shave, inputs=[video_in, prompt, preset, num_frames, seed, randomize],
              outputs=[video_out, seed])

    # cache_mode="lazy": each example is generated through shave() the first time it is
    # selected, then served from the cache.
    gr.Examples(
        examples=[
            ['examples/beard_1.mp4', 'the same man with a completely smooth, clean-shaven face — no beard, no mustache, no stubble, bare clear skin revealing his jawline and natural skin texture — a relaxed neutral expression in soft, even indoor light; a quiet, intimate room ambience', '960×544 (recommended)', 49, 42, False],
            ['examples/beard_2.mp4', 'the same man in a hooded jacket with a completely smooth, clean-shaven face — no beard, no stubble, bare skin along his jaw — standing outdoors at night, cool city lights and soft bokeh glowing behind him, the light catching the clean planes of his face; a cool night-time ambience with distant traffic and a soft breeze', '960×544 (recommended)', 49, 42, False],
            ['examples/beard_3.mp4', 'the same man seen in profile with a completely smooth, clean-shaven face — no beard, no stubble, clean bare cheeks and jaw — smoking a pipe outdoors in a red jacket, a thin curl of smoke drifting past him; a gentle outdoor ambience, a soft breeze and faint distant birdsong', '960×544 (recommended)', 49, 42, False],
        ],
        inputs=[video_in, prompt, preset, num_frames, seed, randomize],
        outputs=[video_out, seed], fn=shave, cache_examples=True, cache_mode="lazy",
    )

if __name__ == "__main__":
    # show_error=True displays handler exceptions (including gr.Error messages) in the browser UI.
    demo.launch(show_error=True)