Spaces:
Running on Zero
| import os | |
| os.environ.setdefault("TORCH_COMPILE_DISABLE", "1") | |
| os.environ.setdefault("TORCHDYNAMO_DISABLE", "1") | |
| import random | |
| import tempfile | |
| import numpy as np | |
| import imageio.v3 as iio | |
| import spaces | |
| import torch | |
| import gradio as gr | |
| from PIL import Image, ImageOps | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from diffusers import LTX2InContextPipeline | |
| from diffusers.pipelines.ltx2.pipeline_ltx2_ic_lora import LTX2ReferenceCondition | |
| from diffusers.utils import load_video, encode_video | |
# --- Config -----------------------------------------------------------------
# Beard-removal IC-LoRA — non-distilled recipe: 30 steps, guidance 4.0, STG, 25 fps.

# Weights: LTX-2.3 base pipeline plus the beard-removal IC-LoRA that gets fused into it.
BASE_MODEL = "diffusers/LTX-2.3-Diffusers"
LORA_REPO = "linoyts/LTX-2.3-loras"
LORA_FILE = "ltx-2.3-22b-ic-lora-instant-shave-0.9.safetensors"
LORA_SCALE = 1.0

# Sampling recipe.
FPS = 25  # output frame rate; the card recommends 25 fps
NUM_STEPS = 30
GUIDANCE = 4.0
STG_BLOCKS = [29]  # passed to the pipeline as spatio_temporal_guidance_blocks
NEGATIVE = ", ".join([
    "beard", "mustache", "facial hair", "stubble", "worst quality",
    "inconsistent motion", "blurry", "jittery", "distorted",
])

# Runtime / UI settings.
MAX_SEED = np.iinfo(np.int32).max  # upper bound for seeds (slider max and random draws)
HF_TOKEN = os.getenv("HF_TOKEN")  # optional token used for the LoRA download
RES_PRESETS = {
    "Fast (768×448)": (768, 448),
    "Quality (960×544)": (960, 544),
}
FRAME_CHOICES = [33, 49, 73, 97, 121]  # all of the form 8k + 1
# Build the LTX-2.3 in-context (reference-conditioned) pipeline in bfloat16 and move it to the GPU.
pipe = LTX2InContextPipeline.from_pretrained(BASE_MODEL, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Tiled VAE decoding — presumably to bound VAE memory for longer / larger clips.
pipe.vae.enable_tiling()
# Fetch the beard-removal IC-LoRA, fuse it into the base weights at LORA_SCALE, then
# unload the adapter so inference runs on the fused weights with no LoRA layers attached.
_lora_path = hf_hub_download(LORA_REPO, LORA_FILE, token=HF_TOKEN)
pipe.load_lora_weights(load_file(_lora_path), adapter_name="shave")
pipe.fuse_lora(lora_scale=LORA_SCALE)
pipe.unload_lora_weights()
# AOTI (Group C / STG): load precompiled graphs from the Hub repo below, attached at ROOT
# level (module=pipe.transformer).
# The compiled graph always runs the perturbation lerp, so every transformer block's
# forward is wrapped. When the caller supplies no `perturbation_mask` (non-STG blocks and
# the main, unperturbed pass), the wrapper feeds an all-ones mask of shape (batch, 1, 1),
# intended as a no-op, and forces all_perturbed=False. Calls that DO supply a mask (the STG
# perturbed pass) are forwarded unchanged.
# NOTE(review): an earlier wording of this comment named "block 28" as the STG block,
# whereas STG_BLOCKS is [29]; confirm whether that was an indexing slip.
spaces.aoti_load(module=pipe.transformer, repo_id="ltx-community/LTX-2.3-Transformer-GroupC-STG-sm120-cu130-rb3")
for _blk in pipe.transformer.transformer_blocks:
    # Capture this block's current forward (presumably the AOTI-loaded one) before replacing it.
    _compiled = _blk.forward

    def _fwd(*a, _c=_compiled, **kw):
        # `_c` is bound as a default argument so each wrapper keeps ITS block's forward;
        # a plain closure over `_compiled` would late-bind to the last block's forward.
        if kw.get("perturbation_mask", None) is None:
            # Assumes hidden_states is passed by keyword; a positional call would KeyError here.
            _hs = kw["hidden_states"]
            kw["perturbation_mask"] = torch.ones((_hs.shape[0], 1, 1), device=_hs.device, dtype=_hs.dtype)
            kw["all_perturbed"] = False
        return _c(*a, **kw)

    _blk.forward = _fwd
def _src_fps(path, default=FPS):
    """Return the source frame rate of the video at *path*, or *default*.

    Metadata is read with imageio's ``pyav`` plugin. This is a best-effort probe:
    any read/parse failure falls back to *default*. A rate that is not a finite
    positive number also falls back to *default*. That check matters because
    ``_load_frames`` turns the rate into frame indices via
    ``int(round(i / FPS * fps))``. NaN or inf would raise there, and a negative
    rate would yield negative indices that silently wrap to the end of the clip.
    """
    try:
        fps = float(iio.immeta(path, plugin="pyav").get("fps", default))
    except Exception:  # deliberate best-effort: a metadata probe must never abort a run
        return default
    # One comparison rejects 0, negatives, NaN (all comparisons are False) and +inf.
    return fps if 0 < fps < float("inf") else default
def _load_frames(path, num_frames, width, height):
    """Decode *path* and resample it to ``num_frames`` frames on the FPS output clock.

    Output frame ``i`` takes the source frame nearest to time ``i / FPS``. The index
    is clamped to the last decoded frame, so a clip shorter than the request holds
    its final frame. Each selected frame is converted to RGB and fitted (crop and
    resize via ``ImageOps.fit``) to ``(width, height)``.

    Returns an empty list when nothing could be decoded.
    """
    decoded = load_video(path)
    if not decoded:
        return []
    src_rate = _src_fps(path)
    last_index = len(decoded) - 1
    target_size = (width, height)

    def _frame_at(step):
        # Source index for output step `step`; same float expression as before, clamped to the end.
        source_index = min(int(round(step / FPS * src_rate)), last_index)
        return ImageOps.fit(decoded[source_index].convert("RGB"), target_size, Image.LANCZOS)

    return [_frame_at(step) for step in range(num_frames)]
def _pick_resolution(first_frame, preset):
    """Return ``(width, height)`` for *preset*, transposed for portrait sources.

    A frame taller than it is wide swaps the preset's dimensions. Square and
    landscape frames use the preset as-is.
    """
    preset_w, preset_h = RES_PRESETS[preset]
    is_portrait = first_frame.height > first_frame.width
    return (preset_h, preset_w) if is_portrait else (preset_w, preset_h)
| def _build_prompt(prompt): | |
| desc = prompt.strip() or "the same person, completely clean-shaven" | |
| return (f"REMOVEBEARD {desc}, completely smooth and clean-shaven face, bare skin, " | |
| f"no beard, no stubble, no facial hair; identity, expression, motion, lighting and scene unchanged.") | |
def _export(video_np, audio, path):
    """Encode the generated frames (plus audio, when present) to an MP4 at *path*.

    The frame rate is FPS. When *audio* is given, its first entry is cast to float on
    the CPU and muxed in at the vocoder's output sampling rate.
    """
    extra = {}
    if audio is not None:
        extra["audio"] = audio[0].float().cpu()
        extra["audio_sample_rate"] = pipe.vocoder.config.output_sampling_rate
    encode_video(video_np, fps=FPS, output_path=path, **extra)
def _duration(*args, **kwargs):
    """Estimate the GPU time budget, in seconds, for one ``shave`` call.

    Intended as the ZeroGPU ``duration`` callable for ``shave``, so it receives the
    same arguments.

    ``preset`` and ``num_frames`` passed by keyword take precedence. Otherwise the
    values are recovered by scanning positional arguments. Keyword-first lookup
    avoids mistaking another int (e.g. a ``seed`` of 49) for the frame count when
    ``num_frames`` is not positional.

    Returns:
        A fixed 120 s overhead plus a per-frame cost that is higher for the Quality preset.
    """
    preset = kwargs.get("preset")
    if preset not in RES_PRESETS:
        preset = next((a for a in args if isinstance(a, str) and a in RES_PRESETS), "Fast")
    num_frames = kwargs.get("num_frames")
    if num_frames not in FRAME_CHOICES:
        num_frames = next((a for a in args if isinstance(a, int) and a in FRAME_CHOICES), 49)
    per_frame = 4.2 if "Quality" in str(preset) else 3.0  # 30 steps + CFG + STG
    return int(120 + int(num_frames) * per_frame)
@spaces.GPU(duration=_duration)
def shave(video, prompt, preset, num_frames, seed, randomize,
          progress=gr.Progress()):
    """Remove facial hair from an uploaded video and return the re-rendered clip.

    Runs inside a ZeroGPU allocation whose length is estimated by ``_duration``.
    Without the decorator, the GPU work would run outside any allocation on a
    ZeroGPU Space.

    Args:
        video: The uploaded video as delivered by ``gr.Video`` (presumably a file path), or None.
        prompt: Optional description of the clean-shaven subject/scene (see ``_build_prompt``).
        preset: A key of ``RES_PRESETS``; its dimensions are swapped for portrait sources.
        num_frames: Number of output frames.
        seed: Sampling seed; replaced by a random one when ``randomize`` is true.
        randomize: Draw a fresh seed in ``[0, MAX_SEED]``.
        progress: Gradio progress tracker, advanced once per denoising step.

    Returns:
        ``(out_path, seed)``: path to the encoded MP4 and the seed actually used.

    Raises:
        gr.Error: if no video was uploaded or no frames could be decoded from it.
    """
    if video is None:
        raise gr.Error("Please upload a video of a bearded subject.")
    if randomize:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    num_frames = int(num_frames)

    # Decode once only to learn the orientation. Then release the frames, so two full
    # decoded copies are not alive while _load_frames decodes it again, and none during generation.
    probe = load_video(video)
    if not probe:
        raise gr.Error("Could not read any frames from that video.")
    width, height = _pick_resolution(probe[0], preset)
    del probe
    ref = _load_frames(video, num_frames, width, height)
    full_prompt = _build_prompt(prompt)

    def _cb(p, i, t, kw):
        # Step-end callback: report progress. Returning {} leaves the pipeline's tensors untouched.
        progress((i + 1) / NUM_STEPS, desc=f"Removing beard — step {i + 1}/{NUM_STEPS}")
        return {}

    video_out, audio_out = pipe(
        prompt=full_prompt, negative_prompt=NEGATIVE,
        reference_conditions=[LTX2ReferenceCondition(frames=ref, strength=1.0)],
        reference_downscale_factor=1,
        width=width, height=height, num_frames=num_frames, frame_rate=FPS,
        num_inference_steps=NUM_STEPS, guidance_scale=GUIDANCE,
        spatio_temporal_guidance_blocks=STG_BLOCKS,
        generator=torch.Generator(device="cuda").manual_seed(seed),
        output_type="np", return_dict=False, callback_on_step_end=_cb,
    )

    # Reserve a persistent temp path and close our handle right away; the encoder writes the file.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    _export(video_out[0], audio_out, out_path)
    return out_path, seed
with gr.Blocks(title="LTX-2.3 Beard Removal") as demo:
    gr.Markdown(
        "# 🪒 LTX-2.3 Beard Removal (Instant Shave)\n"
        "Remove beard, mustache and stubble from a person in a video while preserving identity, expression and "
        "motion. Using [LTX 2.3 Dev](https://huggingface.co/diffusers/LTX-2.3-Diffusers) with the "
        "[Beard-Removal IC-LoRA](https://huggingface.co/ltx-community/LTX-2.3-loras), via diffusers 🧨."
    )

    with gr.Row():
        # Left column: source clip, prompt and generation settings.
        with gr.Column():
            video_in = gr.Video(label="Video of a bearded subject")
            prompt = gr.Textbox(
                label="Prompt — describe the clean-shaven subject/scene and any sounds (optional)",
                lines=3,
                placeholder="a man with a completely smooth clean-shaven face, warm indoor light, laughing; warm hearty laughter and quiet room tone",
            )
            with gr.Accordion("Settings", open=False):
                preset = gr.Dropdown(list(RES_PRESETS), value="Fast (768×448)", label="Resolution")
                num_frames = gr.Dropdown(FRAME_CHOICES, value=49, label="Frames (25fps)")
                randomize = gr.Checkbox(True, label="Randomize seed")
                seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
            run = gr.Button("Remove beard", variant="primary")
        # Right column: result clip and the seed that produced it.
        with gr.Column():
            video_out = gr.Video(label="Clean-shaven result")
            used_seed = gr.Number(label="Seed used", interactive=False)

    # The button and the example gallery share identical input/output wiring.
    _io_inputs = [video_in, prompt, preset, num_frames, seed, randomize]
    _io_outputs = [video_out, used_seed]
    run.click(shave, inputs=_io_inputs, outputs=_io_outputs)

    # Bundled example clips; all run at the Fast preset, 73 frames, fixed seed 42.
    _example_clips = [
        ("examples/beard_1.mp4",
         "a man with a completely smooth clean-shaven face, no beard, no mustache, no stubble, bare skin, neutral expression in soft indoor light; quiet room ambience"),
        ("examples/beard_2.mp4",
         "a man in a hooded jacket with a completely smooth clean-shaven face, no beard or stubble, standing outdoors at night with city lights behind him; cool night ambience and distant traffic"),
        ("examples/beard_3.mp4",
         "a man in profile with a completely smooth clean-shaven face, no beard or stubble, smoking a pipe outdoors in a red jacket; gentle outdoor ambience and a soft breeze"),
    ]
    gr.Examples(
        examples=[[clip, text, "Fast (768×448)", 73, 42, False] for clip, text in _example_clips],
        inputs=_io_inputs,
        outputs=_io_outputs,
        fn=shave,
        cache_examples=True,
        cache_mode="lazy",
    )

if __name__ == "__main__":
    # show_error=True surfaces exceptions (including gr.Error) in the browser UI.
    demo.launch(show_error=True)