File size: 7,006 Bytes
1405c30 19f6b88 1405c30 19f6b88 1405c30 19f6b88 1405c30 19f6b88 1405c30 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 | """
LTX-2.3 inference pipeline for BFS head-swap.
Model sources:
Transformer : SulphurAI/Sulphur-2-base — sulphur_dev_fp8mixed.safetensors
Video VAE : Kijai/LTX2.3_comfy — vae/LTX23_video_vae_bf16.safetensors
Pipeline cfg: Lightricks/LTX-Video (text encoder, scheduler, tokenizer via from_pretrained)
LoRA distill: SulphurAI/Sulphur-2-base — distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors
LoRA BFS : Alissonerdx/BFS-Best-Face-Swap-Video — ltx-2.3/head_swap_v3_rank_adaptive_fro_098.safetensors
"""
from __future__ import annotations
import gc
import os
import tempfile
from pathlib import Path
from typing import Callable
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
# ---------------------------------------------------------------------------
# Model file specs
# ---------------------------------------------------------------------------
# Hugging Face home directory, falling back to the conventional default.
# NOTE(review): nothing in this module reads this value; hf_hub_download
# resolves its own cache location, so treat it as informational — confirm
# whether it can be dropped.
_HF_CACHE = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

# Component key -> (Hugging Face repo id, file path inside that repo).
MODELS = {
    "transformer": (
        "SulphurAI/Sulphur-2-base",
        "sulphur_dev_fp8mixed.safetensors",
    ),
    "video_vae": (
        "Kijai/LTX2.3_comfy",
        "vae/LTX23_video_vae_bf16.safetensors",
    ),
    "lora_motion": (
        "SulphurAI/Sulphur-2-base",
        "distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors",
    ),
    "lora_bfs": (
        "Alissonerdx/BFS-Best-Face-Swap-Video",
        "ltx-2.3/head_swap_v3_rank_adaptive_fro_098.safetensors",
    ),
}

# Noise schedule copied from the reference workflow (BasicScheduler,
# bong_tangent, 8 steps => 9 sigma values ending at 0.0).
# NOTE(review): not passed to the pipeline anywhere in this module;
# run_inference() only forwards num_inference_steps — confirm intent.
DISTILLED_SIGMAS = [
    1.0,
    0.99375,
    0.9875,
    0.98125,
    0.975,
    0.909375,
    0.725,
    0.421875,
    0.0,
]

# Negative prompt sent with every generation request.
NEGATIVE_PROMPT = (
    "pc game, console game, video game, cartoon, childish, ugly, "
    "artifacts, low resolution, blurry, jagged edges"
)
# ---------------------------------------------------------------------------
# Model download helpers
# ---------------------------------------------------------------------------
def _download(key: str, token: str | None = None) -> str:
    """Resolve the MODELS entry for *key* and fetch it via hf_hub_download.

    Args:
        key: one of the keys of MODELS.
        token: optional Hugging Face access token forwarded to the hub.

    Returns:
        The local file path returned by hf_hub_download.

    Raises:
        KeyError: if *key* is not registered in MODELS.
    """
    repo_id, filename = MODELS[key]
    local_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    return local_path
def _maybe_download_all(
    token: str | None = None,
    progress_cb: Callable[[str], None] | None = None,
) -> dict[str, str]:
    """Fetch every file listed in MODELS, in declaration order.

    Args:
        token: optional Hugging Face access token forwarded to each download.
        progress_cb: optional callback told which component is being fetched
            before each download starts.

    Returns:
        Mapping from MODELS key to the local file path.
    """

    def fetch(name: str) -> str:
        # Report first, then download, so the message precedes any delay.
        if progress_cb:
            progress_cb(f"Downloading {name}…")
        return _download(name, token=token)

    return {name: fetch(name) for name in MODELS}
# ---------------------------------------------------------------------------
# Pipeline construction
# ---------------------------------------------------------------------------
def load_pipeline(
    device: str = "cuda",
    token: str | None = None,
    progress_cb: Callable[[str], None] | None = None,
) -> dict:
    """
    Download any missing weights and assemble the LTX image-to-video pipeline.

    The transformer and video VAE come from single-file checkpoints; the rest
    of the pipeline (text encoder, scheduler, tokenizer) is pulled from
    "Lightricks/LTX-Video". Both LoRAs are registered as named adapters
    ("motion" and "bfs") and deliberately left unfused, so run_inference()
    can change their weights per request through set_adapters().

    Args:
        device: torch device every component is moved to.
        token: Hugging Face token; falls back to the HF_TOKEN env variable.
        progress_cb: optional callback receiving human-readable status text.

    Returns:
        {"pipe": <pipeline>, "device": device}, intended to be cached and
        passed to run_inference().
    """
    from diffusers import (
        AutoencoderKLLTXVideo,
        LTXImageToVideoPipeline,
        LTXVideoTransformer3DModel,
    )

    def report(message: str) -> None:
        if progress_cb:
            progress_cb(message)

    hf_token = token or os.getenv("HF_TOKEN")
    weight_paths = _maybe_download_all(token=hf_token, progress_cb=progress_cb)

    report("Loading transformer…")
    # The checkpoint on disk is fp8-quantised; it is loaded with a bfloat16
    # compute dtype to stay compatible with diffusers.
    transformer = LTXVideoTransformer3DModel.from_single_file(
        weight_paths["transformer"], torch_dtype=torch.bfloat16
    ).to(device)

    report("Loading video VAE…")
    vae = AutoencoderKLLTXVideo.from_single_file(
        weight_paths["video_vae"], torch_dtype=torch.bfloat16
    ).to(device)

    report("Building pipeline…")
    pipe = LTXImageToVideoPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        transformer=transformer,
        vae=vae,
        torch_dtype=torch.bfloat16,
        token=hf_token,
    ).to(device)

    report("Loading LoRAs…")
    for adapter_name, weight_key in (("motion", "lora_motion"), ("bfs", "lora_bfs")):
        pipe.load_lora_weights(weight_paths[weight_key], adapter_name=adapter_name)
    # Start both adapters at full strength; run_inference() re-applies weights.
    pipe.set_adapters(["motion", "bfs"], adapter_weights=[1.0, 1.0])

    return {"pipe": pipe, "device": device}
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def run_inference(
    state: dict,
    composed_frames: np.ndarray,
    prompt: str,
    fps: float = 24.0,
    lora_strength: float = 1.0,
    seed: int = 42,
    num_inference_steps: int = 8,
    guidance_scale: float = 1.0,
    region_size_px: int = 256,
    progress_cb: Callable[[str], None] | None = None,
) -> np.ndarray:
    """
    Run LTX-2.3 image-to-video inference on the composed frames.

    Only the first composed frame is given to the pipeline, as the
    conditioning image. N, H and W of the array set the requested output size.

    Args:
        state: dict returned by load_pipeline() (reads "pipe" and "device").
        composed_frames: uint8 [N, H, W, 3] with chroma strip added.
        prompt: text prompt (head_swap: format).
        fps: target frame rate, forwarded to the pipeline as ``frame_rate``.
        lora_strength: weight assigned to the "bfs" LoRA adapter. The
            "motion" adapter is always set to 1.0.
        seed: seed for the torch.Generator created on ``state["device"]``.
        num_inference_steps: number of denoising steps passed to the pipeline.
        guidance_scale: guidance scale passed to the pipeline.
        region_size_px: strip width intended for the guide-frame crop.
            NOTE(review): not currently used by this function.
        progress_cb: optional callback that receives status messages.

    Returns:
        uint8 [N, H, W, 3] — generated frames (strip still present, call
        composer.crop_reserved_region() to remove it).

    Raises:
        ValueError: if composed_frames is not 4-D or contains no frames.
    """
    # Fail early with a clear message instead of an IndexError on frame 0.
    if composed_frames.ndim != 4 or composed_frames.shape[0] == 0:
        raise ValueError(
            "composed_frames must be a non-empty [N, H, W, 3] array; "
            f"got shape {composed_frames.shape}"
        )

    pipe = state["pipe"]
    device = state["device"]
    N, H, W, _ = composed_frames.shape
    first_frame = Image.fromarray(composed_frames[0])

    # Always set adapter weights — LoRAs are not fused, so strength is dynamic.
    # motion LoRA is kept at 1.0; only the BFS identity LoRA is user-adjustable.
    pipe.set_adapters(["motion", "bfs"], adapter_weights=[1.0, lora_strength])
    generator = torch.Generator(device=device).manual_seed(seed)
    if progress_cb:
        progress_cb("Running diffusion…")

    # NOTE(review): LTX pipelines usually require H and W to be multiples of 32
    # and N to be of the form 8k + 1; if the composer does not guarantee this,
    # the returned frame count may differ from N — confirm.
    result = frames_pt = None
    try:
        with torch.inference_mode():
            result = pipe(
                image=first_frame,
                prompt=prompt,
                negative_prompt=NEGATIVE_PROMPT,
                width=W,
                height=H,
                num_frames=N,
                frame_rate=fps,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                decode_timestep=0.05,
                decode_noise_scale=0.025,
                output_type="pt",
            )
        # result.frames is [1, N, C, H, W] float in [0, 1]
        frames_pt = result.frames[0]  # [N, C, H, W]
        frames_hwc = frames_pt.permute(0, 2, 3, 1).cpu().float().numpy()
        # Round rather than truncate, so e.g. 0.999 maps to 255 instead of 254.
        frames_np = np.clip(np.rint(frames_hwc * 255.0), 0, 255).astype(np.uint8)
    finally:
        # Drop references to the pipeline output (which may still live on the
        # accelerator) before collecting, otherwise empty_cache() cannot return
        # that memory. Runs on failure too, e.g. after an out-of-memory error.
        result = frames_pt = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return frames_np
|