""" LTX-2.3 inference pipeline for BFS head-swap. Model sources: Transformer : SulphurAI/Sulphur-2-base — sulphur_dev_fp8mixed.safetensors Video VAE : Kijai/LTX2.3_comfy — vae/LTX23_video_vae_bf16.safetensors Pipeline cfg: Lightricks/LTX-Video (text encoder, scheduler, tokenizer via from_pretrained) LoRA distill: SulphurAI/Sulphur-2-base — distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors LoRA BFS : Alissonerdx/BFS-Best-Face-Swap-Video — ltx-2.3/head_swap_v3_rank_adaptive_fro_098.safetensors """ from __future__ import annotations import gc import os import tempfile from pathlib import Path from typing import Callable import numpy as np import torch from huggingface_hub import hf_hub_download from PIL import Image # --------------------------------------------------------------------------- # Model file specs # --------------------------------------------------------------------------- _HF_CACHE = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")) MODELS = { "transformer": ("SulphurAI/Sulphur-2-base", "sulphur_dev_fp8mixed.safetensors"), "video_vae": ("Kijai/LTX2.3_comfy", "vae/LTX23_video_vae_bf16.safetensors"), "lora_motion": ("SulphurAI/Sulphur-2-base", "distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors"), "lora_bfs": ("Alissonerdx/BFS-Best-Face-Swap-Video", "ltx-2.3/head_swap_v3_rank_adaptive_fro_098.safetensors"), } # Distilled sigmas from the workflow (BasicScheduler bong_tangent, 8 steps) DISTILLED_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] NEGATIVE_PROMPT = ( "pc game, console game, video game, cartoon, childish, ugly, " "artifacts, low resolution, blurry, jagged edges" ) # --------------------------------------------------------------------------- # Model download helpers # --------------------------------------------------------------------------- def _download(key: str, token: str | None = None) -> str: repo, filename = MODELS[key] return hf_hub_download(repo_id=repo, filename=filename, token=token) def _maybe_download_all( token: str | None = None, progress_cb: Callable[[str], None] | None = None, ) -> dict[str, str]: paths = {} for key in MODELS: if progress_cb: progress_cb(f"Downloading {key}…") paths[key] = _download(key, token=token) return paths # --------------------------------------------------------------------------- # Pipeline construction # --------------------------------------------------------------------------- def load_pipeline( device: str = "cuda", token: str | None = None, progress_cb: Callable[[str], None] | None = None, ) -> dict: """ Download (if needed) and load all model components. Returns a dict of loaded objects cached for reuse. LoRAs are loaded but NOT fused so that lora_strength can be adjusted per-request in run_inference() via set_adapters(). """ from diffusers import ( AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel, ) effective_token = token or os.environ.get("HF_TOKEN") paths = _maybe_download_all(token=effective_token, progress_cb=progress_cb) if progress_cb: progress_cb("Loading transformer…") # fp8_e4m3fn weights are loaded into bfloat16 compute precision to maximise # diffusers compatibility; the weight file on disk is fp8-quantised. transformer = LTXVideoTransformer3DModel.from_single_file( paths["transformer"], torch_dtype=torch.bfloat16, ).to(device) if progress_cb: progress_cb("Loading video VAE…") video_vae = AutoencoderKLLTXVideo.from_single_file( paths["video_vae"], torch_dtype=torch.bfloat16, ).to(device) if progress_cb: progress_cb("Building pipeline…") pipe = LTXImageToVideoPipeline.from_pretrained( "Lightricks/LTX-Video", transformer=transformer, vae=video_vae, torch_dtype=torch.bfloat16, token=effective_token, ).to(device) if progress_cb: progress_cb("Loading LoRAs…") pipe.load_lora_weights(paths["lora_motion"], adapter_name="motion") pipe.load_lora_weights(paths["lora_bfs"], adapter_name="bfs") # Default weights — overridden per-request in run_inference() pipe.set_adapters(["motion", "bfs"], adapter_weights=[1.0, 1.0]) return {"pipe": pipe, "device": device} # --------------------------------------------------------------------------- # Inference # --------------------------------------------------------------------------- def run_inference( state: dict, composed_frames: np.ndarray, prompt: str, fps: float = 24.0, lora_strength: float = 1.0, seed: int = 42, num_inference_steps: int = 8, guidance_scale: float = 1.0, region_size_px: int = 256, progress_cb: Callable[[str], None] | None = None, ) -> np.ndarray: """ Run LTX-2.3 image-to-video inference on the composed frames. Args: state: dict returned by load_pipeline() composed_frames: uint8 [N, H, W, 3] with chroma strip added prompt: text prompt (head_swap: format) fps: target frame rate lora_strength: multiplier applied on top of default LoRA weights seed: RNG seed region_size_px: strip width to set guide frame crop Returns: uint8 [N, H, W, 3] — generated frames (strip still present, call composer.crop_reserved_region() to remove it) """ pipe = state["pipe"] device = state["device"] N, H, W, _ = composed_frames.shape first_frame = Image.fromarray(composed_frames[0]) # Always set adapter weights — LoRAs are not fused, so strength is dynamic. # motion LoRA is kept at 1.0; only the BFS identity LoRA is user-adjustable. pipe.set_adapters(["motion", "bfs"], adapter_weights=[1.0, lora_strength]) generator = torch.Generator(device=device).manual_seed(seed) if progress_cb: progress_cb("Running diffusion…") with torch.inference_mode(): result = pipe( image=first_frame, prompt=prompt, negative_prompt=NEGATIVE_PROMPT, width=W, height=H, num_frames=N, frame_rate=fps, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, decode_timestep=0.05, decode_noise_scale=0.025, output_type="pt", ) # result.frames is [1, N, C, H, W] float in [0,1] frames_pt = result.frames[0] # [N, C, H, W] frames_np = (frames_pt.permute(0, 2, 3, 1).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8) gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return frames_np