linoyts's picture
linoyts HF Staff
Use precompiled AOTI transformer blocks (ZeroGPU speedup, STG-capable)
dc2750b verified
Raw
History Blame
8.13 kB
import os
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")
import random
import tempfile
import numpy as np
import imageio.v3 as iio
import spaces
import torch
import gradio as gr
from PIL import Image, ImageOps
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from diffusers import LTX2InContextPipeline
from diffusers.pipelines.ltx2.pipeline_ltx2_ic_lora import LTX2ReferenceCondition
from diffusers.utils import load_video, encode_video
# --- Config -----------------------------------------------------------------
# Beard-removal IC-LoRA — non-distilled recipe: 30 steps, guidance 4.0, STG, 25 fps.

# Where the base pipeline and the LoRA file are fetched from.
BASE_MODEL = "diffusers/LTX-2.3-Diffusers"
LORA_REPO = "linoyts/LTX-2.3-loras"
LORA_FILE = "ltx-2.3-22b-ic-lora-instant-shave-0.9.safetensors"
LORA_SCALE = 1.0

# Fixed sampling recipe applied to every request.
FPS = 25  # the card recommends 25 fps
NUM_STEPS = 30
GUIDANCE = 4.0
STG_BLOCKS = [29]

# Negative prompt, assembled from its individual terms.
NEGATIVE = ", ".join((
    "beard",
    "mustache",
    "facial hair",
    "stubble",
    "worst quality",
    "inconsistent motion",
    "blurry",
    "jittery",
    "distorted",
))

# Upper bound for the seed (largest int32 value).
MAX_SEED = np.iinfo(np.int32).max
# Optional Hub token; None when the environment variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# UI label -> (width, height) of the generated video.
RES_PRESETS = {
    "Fast (768×448)": (768, 448),
    "Quality (960×544)": (960, 544),
}
# Frame counts offered in the UI.
FRAME_CHOICES = [33, 49, 73, 97, 121]
# Build the in-context pipeline once at import time: load BASE_MODEL in bfloat16
# and move it to the CUDA device.
pipe = LTX2InContextPipeline.from_pretrained(BASE_MODEL, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Enable tiling on the VAE (presumably to lower peak memory during decode — confirm).
pipe.vae.enable_tiling()
# Download the IC-LoRA file, load its state dict as the "shave" adapter, fuse it
# at LORA_SCALE, then unload the adapter. The order matters: fuse must see the
# loaded adapter (presumably the fused weights stay in the base model — confirm).
_lora_path = hf_hub_download(LORA_REPO, LORA_FILE, token=HF_TOKEN)
pipe.load_lora_weights(load_file(_lora_path), adapter_name="shave")
pipe.fuse_lora(lora_scale=LORA_SCALE)
pipe.unload_lora_weights()
# AOTI (Group C / STG): load precompiled blocks at ROOT level. The graph always runs the
# perturbation lerp; the wrapper feeds a no-op ones mask when None (non-STG blocks / main
# pass) and forces all_perturbed=False. STG (block 28, perturbed pass) still gets the real mask.
# NOTE(review): STG_BLOCKS is [29] while the note above says block 28 — confirm whether
# this is an indexing convention or a stale comment.
spaces.aoti_load(module=pipe.transformer, repo_id="ltx-community/LTX-2.3-Transformer-GroupC-STG-sm120-cu130-rb3")
# Replace each transformer block's forward with a thin wrapper around the forward
# that is installed on the block after aoti_load.
for _blk in pipe.transformer.transformer_blocks:
_compiled = _blk.forward
# `_c=_compiled` binds this block's forward as a default argument, so every wrapper
# keeps its own target instead of the loop's last value.
def _fwd(*a, _c=_compiled, **kw):
# A missing or None mask is replaced by an all-ones tensor of shape
# (hidden_states.shape[0], 1, 1) on the same device/dtype as hidden_states.
# "hidden_states" is read from kwargs, so it must be passed by keyword.
if kw.get("perturbation_mask", None) is None:
_hs = kw["hidden_states"]
kw["perturbation_mask"] = torch.ones((_hs.shape[0], 1, 1), device=_hs.device, dtype=_hs.dtype)
# Overwrite all_perturbed with False before delegating to the wrapped forward.
kw["all_perturbed"] = False
return _c(*a, **kw)
_blk.forward = _fwd
def _src_fps(path, default=FPS):
    """Return the frame rate of the video at *path*, or *default* if unusable.

    The rate is read from the ``"fps"`` entry of the pyav metadata. This is a
    best-effort probe: any failure while reading or converting the metadata
    yields *default* instead of raising.

    Args:
        path: Location of the video file to inspect.
        default: Value returned when no usable frame rate is available.

    Returns:
        The source frame rate as a positive, finite float, or *default*.
    """
    import math  # local import keeps the block self-contained

    try:
        fps = float(iio.immeta(path, plugin="pyav").get("fps", default))
    except Exception:
        # Deliberate best effort: unreadable files or odd metadata fall back.
        return default
    # Zero, negative, NaN and infinite rates cannot be used to map output
    # frames onto source frames (NaN/inf would crash int(round(...)) later).
    if not math.isfinite(fps) or fps <= 0:
        return default
    return fps
def _load_frames(path, num_frames, width, height):
    """Decode *path* and return *num_frames* frames resampled to FPS.

    Output frame ``i`` is taken from the source frame nearest to the same
    timestamp (``i / FPS`` seconds), clamped to the last decoded frame. Each
    picked frame is converted to RGB and fitted to ``(width, height)`` with
    ``ImageOps.fit`` using LANCZOS resampling.

    Returns:
        A list of *num_frames* images, or an empty list when nothing decodes.
    """
    decoded = load_video(path)
    if not decoded:
        return []

    source_fps = _src_fps(path)
    last_index = len(decoded) - 1
    target_size = (width, height)

    def _source_index(out_index):
        # Timestamp of the output frame, expressed in source-frame units.
        return min(int(round(out_index / FPS * source_fps)), last_index)

    return [
        ImageOps.fit(decoded[_source_index(n)].convert("RGB"), target_size, Image.LANCZOS)
        for n in range(num_frames)
    ]
def _pick_resolution(first_frame, preset):
    """Return the ``(width, height)`` to generate at for *preset*.

    The preset's stored pair is used as-is unless *first_frame* is taller than
    it is wide, in which case the two values are swapped.

    Raises:
        KeyError: If *preset* is not a key of RES_PRESETS.
    """
    preset_w, preset_h = RES_PRESETS[preset]
    is_portrait = first_frame.height > first_frame.width
    return (preset_h, preset_w) if is_portrait else (preset_w, preset_h)
def _build_prompt(prompt):
    """Wrap the user's description in the fixed beard-removal prompt template.

    Args:
        prompt: Free-text description from the user. ``None``, an empty string
            or whitespace-only text selects a generic default description.

    Returns:
        The full prompt: the ``REMOVEBEARD`` prefix, the (stripped)
        description, and the fixed clean-shaven / preserve-everything suffix.
    """
    # `prompt or ""` lets a missing prompt (None) behave like an empty one
    # instead of raising AttributeError on `.strip()`.
    desc = (prompt or "").strip() or "the same person, completely clean-shaven"
    return (f"REMOVEBEARD {desc}, completely smooth and clean-shaven face, bare skin, "
            f"no beard, no stubble, no facial hair; identity, expression, motion, lighting and scene unchanged.")
def _export(video_np, audio, path):
    """Encode *video_np* to *path* at FPS, muxing *audio* when it is provided.

    Args:
        video_np: Frames to encode, passed straight through to ``encode_video``.
        audio: ``None`` for a silent file; otherwise an indexable whose first
            item is converted with ``.float().cpu()`` and written as the track.
        path: Destination file path.
    """
    if audio is None:
        encode_video(video_np, fps=FPS, output_path=path)
        return

    waveform = audio[0].float().cpu()
    # The sample rate comes from the pipeline's vocoder config.
    sample_rate = pipe.vocoder.config.output_sampling_rate
    encode_video(video_np, fps=FPS, output_path=path,
                 audio=waveform, audio_sample_rate=sample_rate)
def _duration(*args, **kwargs):
    """Estimate the GPU budget for one call, given the same arguments as the handler.

    The positional arguments are scanned for the first string that is a
    RES_PRESETS key and the first int that is one of FRAME_CHOICES; when either
    is absent, "Fast" and 49 frames are assumed. Keyword arguments are ignored.

    Returns:
        ``int(120 + frames * per_frame)`` where ``per_frame`` is 4.2 for a
        "Quality" preset and 3.0 otherwise (presumably seconds — confirm).
    """
    chosen_preset = "Fast"
    for candidate in args:
        if isinstance(candidate, str) and candidate in RES_PRESETS:
            chosen_preset = candidate
            break

    frame_count = 49
    for candidate in args:
        if isinstance(candidate, int) and candidate in FRAME_CHOICES:
            frame_count = candidate
            break

    # Per-frame cost covers 30 steps + CFG + STG.
    if "Quality" in str(chosen_preset):
        per_frame = 4.2
    else:
        per_frame = 3.0
    return int(120 + int(frame_count) * per_frame)
@spaces.GPU(duration=_duration)
def shave(video, prompt, preset, num_frames, seed, randomize,
          progress=gr.Progress()):
    """Run the beard-removal pipeline on an uploaded video.

    Args:
        video: Path of the uploaded video; ``None`` when nothing was uploaded.
        prompt: Optional description of the clean-shaven subject/scene.
        preset: Key of RES_PRESETS selecting the output resolution.
        num_frames: Number of frames to generate.
        seed: Seed for the generator; ignored when *randomize* is truthy.
        randomize: When truthy, draw a fresh seed in ``[0, MAX_SEED]``.
        progress: Gradio progress tracker updated after every denoising step.

    Returns:
        ``(out_path, seed)``: path of the encoded ``.mp4`` and the seed used.

    Raises:
        gr.Error: If no video was provided or no frame could be decoded.
    """
    if video is None:
        raise gr.Error("Please upload a video of a bearded subject.")
    if randomize:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    num_frames = int(num_frames)

    # Decode once just to validate the file and read the first frame's
    # orientation for the resolution choice.
    probe = load_video(video)
    if not probe:
        raise gr.Error("Could not read any frames from that video.")
    width, height = _pick_resolution(probe[0], preset)
    # _load_frames decodes the file again; release the probe frames first so
    # two full copies of the decoded video are not alive at the same time.
    del probe

    ref = _load_frames(video, num_frames, width, height)
    full_prompt = _build_prompt(prompt)

    def _cb(p, i, t, kw):
        # Step-end hook: report progress; the empty dict overrides nothing.
        progress((i + 1) / NUM_STEPS, desc=f"Removing beard — step {i + 1}/{NUM_STEPS}")
        return {}

    video_out, audio_out = pipe(
        prompt=full_prompt, negative_prompt=NEGATIVE,
        reference_conditions=[LTX2ReferenceCondition(frames=ref, strength=1.0)],
        reference_downscale_factor=1,
        width=width, height=height, num_frames=num_frames, frame_rate=FPS,
        num_inference_steps=NUM_STEPS, guidance_scale=GUIDANCE,
        spatio_temporal_guidance_blocks=STG_BLOCKS,
        generator=torch.Generator(device="cuda").manual_seed(seed),
        output_type="np", return_dict=False, callback_on_step_end=_cb,
    )

    # Reserve a unique .mp4 path and close the handle right away; delete=False
    # keeps the (empty) file so the encoder can write to that path afterwards.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    _export(video_out[0], audio_out, out_path)
    return out_path, seed
# --- UI -----------------------------------------------------------------------
# Gradio layout: inputs and settings, then the result video and the seed used.
# Components are created inside layout contexts, so statement order and nesting
# determine where they render.
with gr.Blocks(title="LTX-2.3 Beard Removal") as demo:
# NOTE(review): the LoRA link in this header points at ltx-community/LTX-2.3-loras
# while LORA_REPO is linoyts/LTX-2.3-loras — confirm which repo is intended.
gr.Markdown(
"# 🪒 LTX-2.3 Beard Removal (Instant Shave)\n"
"Remove beard, mustache and stubble from a person in a video while preserving identity, expression and "
"motion. Using [LTX 2.3 Dev](https://huggingface.co/diffusers/LTX-2.3-Diffusers) with the "
"[Beard-Removal IC-LoRA](https://huggingface.co/ltx-community/LTX-2.3-loras), via diffusers 🧨."
)
with gr.Row():
with gr.Column():
# Inputs: the source video and an optional free-text prompt.
video_in = gr.Video(label="Video of a bearded subject")
prompt = gr.Textbox(
label="Prompt — describe the clean-shaven subject/scene and any sounds (optional)", lines=3,
placeholder="a man with a completely smooth clean-shaven face, warm indoor light, laughing; warm hearty laughter and quiet room tone",
)
# Settings (created with open=False): resolution preset, frame count and seed controls.
with gr.Accordion("Settings", open=False):
preset = gr.Dropdown(list(RES_PRESETS), value="Fast (768×448)", label="Resolution")
num_frames = gr.Dropdown(FRAME_CHOICES, value=49, label="Frames (25fps)")
randomize = gr.Checkbox(True, label="Randomize seed")
seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
run = gr.Button("Remove beard", variant="primary")
with gr.Column():
# Outputs: the generated video and the seed that produced it.
video_out = gr.Video(label="Clean-shaven result")
used_seed = gr.Number(label="Seed used", interactive=False)
# The input order here must match the positional parameters of shave().
run.click(shave, inputs=[video_in, prompt, preset, num_frames, seed, randomize],
outputs=[video_out, used_seed])
# Example rows use the same input order as the click handler and run through shave
# with cache_examples=True and cache_mode="lazy".
gr.Examples(
examples=[
["examples/beard_1.mp4",
"a man with a completely smooth clean-shaven face, no beard, no mustache, no stubble, bare skin, neutral expression in soft indoor light; quiet room ambience",
"Fast (768×448)", 73, 42, False],
["examples/beard_2.mp4",
"a man in a hooded jacket with a completely smooth clean-shaven face, no beard or stubble, standing outdoors at night with city lights behind him; cool night ambience and distant traffic",
"Fast (768×448)", 73, 42, False],
["examples/beard_3.mp4",
"a man in profile with a completely smooth clean-shaven face, no beard or stubble, smoking a pipe outdoors in a red jacket; gentle outdoor ambience and a soft breeze",
"Fast (768×448)", 73, 42, False],
],
inputs=[video_in, prompt, preset, num_frames, seed, randomize],
outputs=[video_out, used_seed], fn=shave, cache_examples=True, cache_mode="lazy",
)
# Launch only when this file is executed as a script, passing show_error=True.
if __name__ == "__main__":
demo.launch(show_error=True)