ismailkattakath's picture
Switch base model to SulphurAI/Sulphur-2-base (fp8mixed transformer + Sulphur distill LoRA)
19f6b88 verified
Raw
History Blame Contribute Delete
7.01 kB
"""
LTX-2.3 inference pipeline for BFS (Best Face Swap) video head-swapping.
Model sources:
Transformer : SulphurAI/Sulphur-2-base — sulphur_dev_fp8mixed.safetensors
Video VAE : Kijai/LTX2.3_comfy — vae/LTX23_video_vae_bf16.safetensors
Pipeline cfg: Lightricks/LTX-Video (text encoder, scheduler, tokenizer via from_pretrained)
LoRA distill: SulphurAI/Sulphur-2-base — distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors
LoRA BFS : Alissonerdx/BFS-Best-Face-Swap-Video — ltx-2.3/head_swap_v3_rank_adaptive_fro_098.safetensors
"""
from __future__ import annotations
import gc
import os
import tempfile
from pathlib import Path
from typing import Callable
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
# ---------------------------------------------------------------------------
# Model file specs
# ---------------------------------------------------------------------------
# Hugging Face cache root: $HF_HOME when it is set, otherwise ~/.cache/huggingface.
_HF_CACHE = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

# Component key -> (Hub repo id, path of the file inside that repo).
# Insertion order is the order in which files are fetched.
MODELS = {
    "transformer": (
        "SulphurAI/Sulphur-2-base",
        "sulphur_dev_fp8mixed.safetensors",
    ),
    "video_vae": (
        "Kijai/LTX2.3_comfy",
        "vae/LTX23_video_vae_bf16.safetensors",
    ),
    "lora_motion": (
        "SulphurAI/Sulphur-2-base",
        "distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors",
    ),
    "lora_bfs": (
        "Alissonerdx/BFS-Best-Face-Swap-Video",
        "ltx-2.3/head_swap_v3_rank_adaptive_fro_098.safetensors",
    ),
}

# Sigma schedule copied from the reference workflow (BasicScheduler,
# bong_tangent): 8 denoising steps, i.e. 9 boundary values ending at 0.0.
# NOTE(review): not referenced by run_inference() as visible here — confirm
# whether the pipeline is meant to receive this schedule.
DISTILLED_SIGMAS = [
    1.0, 0.99375, 0.9875, 0.98125, 0.975,
    0.909375, 0.725, 0.421875, 0.0,
]

# Negative prompt passed to every diffusion call.
NEGATIVE_PROMPT = (
    "pc game, console game, video game, cartoon, childish, ugly, "
    "artifacts, low resolution, blurry, jagged edges"
)
# ---------------------------------------------------------------------------
# Model download helpers
# ---------------------------------------------------------------------------
def _download(key: str, token: str | None = None) -> str:
    """Fetch the file registered under ``key`` in MODELS.

    Args:
        key: a key of MODELS (e.g. "transformer").
        token: optional Hugging Face token forwarded to hf_hub_download.

    Returns:
        The local filesystem path returned by hf_hub_download.
    """
    repo_id, file_in_repo = MODELS[key]
    local_path = hf_hub_download(token=token, filename=file_in_repo, repo_id=repo_id)
    return local_path
def _maybe_download_all(
    token: str | None = None,
    progress_cb: Callable[[str], None] | None = None,
) -> dict[str, str]:
    """Make every component listed in MODELS available locally.

    Components are fetched one at a time in MODELS order. When
    ``progress_cb`` is provided it is called once per component, just before
    that component is fetched.

    Returns:
        Mapping of MODELS key -> local file path.
    """

    def fetch(component: str) -> str:
        if progress_cb:
            progress_cb(f"Downloading {component}…")
        return _download(component, token=token)

    return {component: fetch(component) for component in MODELS}
# ---------------------------------------------------------------------------
# Pipeline construction
# ---------------------------------------------------------------------------
def load_pipeline(
    device: str = "cuda",
    token: str | None = None,
    progress_cb: Callable[[str], None] | None = None,
) -> dict:
    """
    Download (if needed) and instantiate every model component.

    Order of work: fetch all files listed in MODELS, load the transformer and
    the video VAE from their single-file checkpoints, wrap them in an
    LTXImageToVideoPipeline built from the Lightricks/LTX-Video repo, then
    register both LoRAs as named adapters ("motion" and "bfs").

    The LoRAs are intentionally left unfused so run_inference() can change
    their weights per request via set_adapters().

    Args:
        device: torch device string every component is moved to.
        token: Hugging Face token; falls back to the HF_TOKEN env variable.
        progress_cb: optional callback that receives status messages.

    Returns:
        ``{"pipe": pipeline, "device": device}``. Nothing is cached here;
        keep the returned dict to reuse the loaded models.
    """
    from diffusers import (
        AutoencoderKLLTXVideo,
        LTXImageToVideoPipeline,
        LTXVideoTransformer3DModel,
    )

    def report(message: str) -> None:
        if progress_cb:
            progress_cb(message)

    hf_token = token or os.environ.get("HF_TOKEN")
    files = _maybe_download_all(token=hf_token, progress_cb=progress_cb)

    # The checkpoint on disk is fp8-quantised; it is materialised in bfloat16
    # compute precision to maximise diffusers compatibility.
    report("Loading transformer…")
    transformer = LTXVideoTransformer3DModel.from_single_file(
        files["transformer"], torch_dtype=torch.bfloat16
    ).to(device)

    report("Loading video VAE…")
    vae = AutoencoderKLLTXVideo.from_single_file(
        files["video_vae"], torch_dtype=torch.bfloat16
    ).to(device)

    # The remaining components (text encoder, tokenizer, scheduler) come from
    # the reference repo; the transformer and VAE above replace its defaults.
    report("Building pipeline…")
    pipe = LTXImageToVideoPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        transformer=transformer,
        vae=vae,
        torch_dtype=torch.bfloat16,
        token=hf_token,
    ).to(device)

    report("Loading LoRAs…")
    for adapter_name, model_key in (("motion", "lora_motion"), ("bfs", "lora_bfs")):
        pipe.load_lora_weights(files[model_key], adapter_name=adapter_name)
    # Starting weights only; run_inference() resets them on every call.
    pipe.set_adapters(["motion", "bfs"], adapter_weights=[1.0, 1.0])

    return {"pipe": pipe, "device": device}
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def _frames_to_uint8(frames: torch.Tensor) -> np.ndarray:
    """Convert float video frames in [0, 1] to a contiguous host uint8 array.

    Args:
        frames: tensor shaped [N, C, H, W], any float dtype, any device.

    Returns:
        uint8 numpy array shaped [N, H, W, C].
    """
    # Clamp first, then round rather than truncate: truncation maps e.g. 0.999
    # to 254 instead of 255 and biases every pixel slightly darker.
    as_bytes = (frames.float().clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
    # Reorder to channels-last before the single host copy; .contiguous()
    # keeps the returned array C-contiguous, as the original conversion did.
    return as_bytes.permute(0, 2, 3, 1).contiguous().cpu().numpy()


def run_inference(
    state: dict,
    composed_frames: np.ndarray,
    prompt: str,
    fps: float = 24.0,
    lora_strength: float = 1.0,
    seed: int = 42,
    num_inference_steps: int = 8,
    guidance_scale: float = 1.0,
    region_size_px: int = 256,
    progress_cb: Callable[[str], None] | None = None,
) -> np.ndarray:
    """
    Run LTX-2.3 image-to-video inference on the composed frames.

    Only the first composed frame is given to the pipeline as the
    conditioning image; the requested frame count and spatial size are taken
    from ``composed_frames``.

    Args:
        state: dict returned by load_pipeline() (reads "pipe" and "device").
        composed_frames: uint8 [N, H, W, 3] with chroma strip added.
        prompt: text prompt (head_swap: format).
        fps: target frame rate, forwarded as the pipeline's ``frame_rate``.
        lora_strength: adapter weight for the BFS LoRA. The motion/distill
            LoRA is always kept at 1.0.
        seed: RNG seed for a torch.Generator created on ``state["device"]``.
        num_inference_steps: number of denoising steps passed to the pipeline.
        guidance_scale: guidance scale passed to the pipeline.
        region_size_px: currently unused by this function; kept for API
            compatibility.
        progress_cb: optional callback receiving status messages.

    Returns:
        uint8 [N, H, W, 3] — generated frames (strip still present, call
        composer.crop_reserved_region() to remove it).
        NOTE(review): the frame count is whatever the pipeline returns;
        confirm it equals N for every N callers pass.

    Raises:
        ValueError: if ``composed_frames`` is not 4-D or contains no frames.
    """
    pipe = state["pipe"]
    device = state["device"]

    if composed_frames.ndim != 4:
        raise ValueError(
            f"composed_frames must be [N, H, W, 3]; got shape {composed_frames.shape}"
        )
    num_frames, height, width, _ = composed_frames.shape
    if num_frames == 0:
        raise ValueError("composed_frames must contain at least one frame")

    first_frame = Image.fromarray(composed_frames[0])

    # LoRAs are not fused, so their strengths can change on every call.
    # The motion LoRA is kept at 1.0; only the BFS identity LoRA is user-adjustable.
    pipe.set_adapters(["motion", "bfs"], adapter_weights=[1.0, lora_strength])
    generator = torch.Generator(device=device).manual_seed(seed)

    if progress_cb:
        progress_cb("Running diffusion…")

    # NOTE(review): DISTILLED_SIGMAS is not passed here; only
    # num_inference_steps is. Confirm that is intended.
    try:
        with torch.inference_mode():
            result = pipe(
                image=first_frame,
                prompt=prompt,
                negative_prompt=NEGATIVE_PROMPT,
                width=width,
                height=height,
                num_frames=num_frames,
                frame_rate=fps,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                decode_timestep=0.05,
                decode_noise_scale=0.025,
                output_type="pt",
            )
            # result.frames is [1, N, C, H, W] float in [0, 1]
            frames_np = _frames_to_uint8(result.frames[0])
        # Drop the pipeline output (presumably still on `device`) before
        # emptying the CUDA cache; empty_cache() cannot release blocks that a
        # live tensor still references.
        del result
    finally:
        # Runs on failure too (e.g. CUDA OOM) so cached blocks are returned.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return frames_np