Switch base model to SulphurAI/Sulphur-2-base (fp8mixed transformer + Sulphur distill LoRA)
19f6b88 verified | """ | |
| LTX-2.3 inference pipeline for BFS head-swap. | |
| Model sources: | |
| Transformer : SulphurAI/Sulphur-2-base — sulphur_dev_fp8mixed.safetensors | |
| Video VAE : Kijai/LTX2.3_comfy — vae/LTX23_video_vae_bf16.safetensors | |
| Pipeline cfg: Lightricks/LTX-Video (text encoder, scheduler, tokenizer via from_pretrained) | |
| LoRA distill: SulphurAI/Sulphur-2-base — distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors | |
| LoRA BFS : Alissonerdx/BFS-Best-Face-Swap-Video — ltx-2.3/head_swap_v3_rank_adaptive_fro_098.safetensors | |
| """ | |
| from __future__ import annotations | |
| import gc | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Callable | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| # --------------------------------------------------------------------------- | |
| # Model file specs | |
| # --------------------------------------------------------------------------- | |
# Hugging Face cache root: $HF_HOME when set, else the conventional default.
# NOTE(review): not referenced by any code visible in this part of the file.
_HF_CACHE = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

# Logical component name -> (Hugging Face repo id, file path inside that repo).
MODELS = dict(
    transformer=(
        "SulphurAI/Sulphur-2-base",
        "sulphur_dev_fp8mixed.safetensors",
    ),
    video_vae=(
        "Kijai/LTX2.3_comfy",
        "vae/LTX23_video_vae_bf16.safetensors",
    ),
    lora_motion=(
        "SulphurAI/Sulphur-2-base",
        "distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors",
    ),
    lora_bfs=(
        "Alissonerdx/BFS-Best-Face-Swap-Video",
        "ltx-2.3/head_swap_v3_rank_adaptive_fro_098.safetensors",
    ),
)

# Sigma schedule copied from the reference workflow (BasicScheduler,
# bong_tangent, 8 steps): nine boundaries for eight denoising steps.
# NOTE(review): not passed to the pipeline by any code visible here.
DISTILLED_SIGMAS = [
    1.0,
    0.99375,
    0.9875,
    0.98125,
    0.975,
    0.909375,
    0.725,
    0.421875,
    0.0,
]

# Negative prompt sent with every generation request.
NEGATIVE_PROMPT = (
    "pc game, console game, video game, cartoon, childish, ugly, "
    "artifacts, low resolution, blurry, jagged edges"
)
| # --------------------------------------------------------------------------- | |
| # Model download helpers | |
| # --------------------------------------------------------------------------- | |
def _download(key: str, token: str | None = None) -> str:
    """Fetch the model file registered under *key* and return its local path.

    The ``(repo, filename)`` pair is looked up in ``MODELS`` and handed to
    ``hf_hub_download`` together with the optional access *token*.  A key
    that is not present in ``MODELS`` raises ``KeyError``.
    """
    repo_id, file_name = MODELS[key]
    local_path = hf_hub_download(repo_id=repo_id, filename=file_name, token=token)
    return local_path
def _maybe_download_all(
    token: str | None = None,
    progress_cb: Callable[[str], None] | None = None,
) -> dict[str, str]:
    """Resolve every entry of ``MODELS`` to a local file path.

    Entries are handled one at a time in ``MODELS`` order.  Just before each
    one is fetched, *progress_cb* (when provided) receives a short status
    message naming the entry.

    Returns a mapping from each ``MODELS`` key to the path returned by
    ``_download`` for that key.
    """

    def _fetch(name: str) -> str:
        # Announce first, then fetch, so the message precedes the slow part.
        if progress_cb:
            progress_cb(f"Downloading {name}…")
        return _download(name, token=token)

    return {name: _fetch(name) for name in MODELS}
| # --------------------------------------------------------------------------- | |
| # Pipeline construction | |
| # --------------------------------------------------------------------------- | |
def load_pipeline(
    device: str = "cuda",
    token: str | None = None,
    progress_cb: Callable[[str], None] | None = None,
) -> dict:
    """
    Fetch all model files and assemble the image-to-video pipeline.

    Args:
        device: torch device string every component is moved to.
        token: Hugging Face access token; when falsy, the HF_TOKEN
            environment variable is used instead (may also be unset).
        progress_cb: optional callable receiving short status messages.

    Returns:
        ``{"pipe": <pipeline>, "device": device}`` — intended to be kept
        around and passed to run_inference().

    Both LoRAs are registered as named adapters ("motion", "bfs") and are
    NOT fused, so their weights can be changed per request through
    set_adapters().
    """
    # Imported lazily so that importing this module does not pull in diffusers.
    from diffusers import (
        AutoencoderKLLTXVideo,
        LTXImageToVideoPipeline,
        LTXVideoTransformer3DModel,
    )

    def _report(message: str) -> None:
        if progress_cb:
            progress_cb(message)

    hf_token = token or os.environ.get("HF_TOKEN")
    model_paths = _maybe_download_all(token=hf_token, progress_cb=progress_cb)

    # The transformer checkpoint on disk is fp8-quantised; bfloat16 is
    # requested at load time to maximise diffusers compatibility.
    compute_dtype = torch.bfloat16

    _report("Loading transformer…")
    transformer = LTXVideoTransformer3DModel.from_single_file(
        model_paths["transformer"],
        torch_dtype=compute_dtype,
    )
    transformer = transformer.to(device)

    _report("Loading video VAE…")
    video_vae = AutoencoderKLLTXVideo.from_single_file(
        model_paths["video_vae"],
        torch_dtype=compute_dtype,
    )
    video_vae = video_vae.to(device)

    _report("Building pipeline…")
    # The remaining components come from the Lightricks/LTX-Video repo; the
    # transformer and VAE loaded above are supplied explicitly.
    pipe = LTXImageToVideoPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        transformer=transformer,
        vae=video_vae,
        torch_dtype=compute_dtype,
        token=hf_token,
    )
    pipe = pipe.to(device)

    _report("Loading LoRAs…")
    for path_key, adapter in (("lora_motion", "motion"), ("lora_bfs", "bfs")):
        pipe.load_lora_weights(model_paths[path_key], adapter_name=adapter)

    # Initial weights only — run_inference() sets them again on every call.
    pipe.set_adapters(["motion", "bfs"], adapter_weights=[1.0, 1.0])

    return {"pipe": pipe, "device": device}
| # --------------------------------------------------------------------------- | |
| # Inference | |
| # --------------------------------------------------------------------------- | |
def run_inference(
    state: dict,
    composed_frames: np.ndarray,
    prompt: str,
    fps: float = 24.0,
    lora_strength: float = 1.0,
    seed: int = 42,
    num_inference_steps: int = 8,
    guidance_scale: float = 1.0,
    region_size_px: int = 256,
    progress_cb: Callable[[str], None] | None = None,
) -> np.ndarray:
    """
    Run LTX-2.3 image-to-video inference on the composed frames.

    Only the first frame is given to the pipeline as the conditioning image;
    the remaining frames contribute just the clip length and resolution.

    Args:
        state: dict returned by load_pipeline() ("pipe" and "device" keys)
        composed_frames: uint8 [N, H, W, 3] with chroma strip added
        prompt: text prompt (head_swap: format)
        fps: target frame rate, forwarded as ``frame_rate``
        lora_strength: weight of the "bfs" LoRA adapter; the "motion"
            adapter is always kept at 1.0
        seed: RNG seed
        num_inference_steps: number of denoising steps
        guidance_scale: guidance scale forwarded to the pipeline
        region_size_px: currently unused by this function; kept so existing
            callers that pass it keep working
        progress_cb: optional callable receiving short status messages

    Returns:
        uint8 [N, H, W, 3] — generated frames (strip still present, call
        composer.crop_reserved_region() to remove it)
        NOTE(review): the pipeline may impose its own constraints on
        H / W / N — confirm the output shape matches for arbitrary inputs.

    Raises:
        ValueError: if composed_frames is not a non-empty [N, H, W, 3] array.
    """
    # Fail early with a clear message instead of a cryptic unpack/index error.
    if composed_frames.ndim != 4 or composed_frames.shape[-1] != 3:
        raise ValueError(
            "composed_frames must have shape [N, H, W, 3], "
            f"got {tuple(composed_frames.shape)}"
        )
    if composed_frames.shape[0] == 0:
        raise ValueError("composed_frames must contain at least one frame")

    pipe = state["pipe"]
    device = state["device"]
    N, H, W, _ = composed_frames.shape
    first_frame = Image.fromarray(composed_frames[0])

    # Always set adapter weights — LoRAs are not fused, so strength is dynamic.
    # motion LoRA is kept at 1.0; only the BFS identity LoRA is user-adjustable.
    pipe.set_adapters(["motion", "bfs"], adapter_weights=[1.0, lora_strength])

    generator = torch.Generator(device=device).manual_seed(seed)

    if progress_cb:
        progress_cb("Running diffusion…")

    try:
        with torch.inference_mode():
            result = pipe(
                image=first_frame,
                prompt=prompt,
                negative_prompt=NEGATIVE_PROMPT,
                width=W,
                height=H,
                num_frames=N,
                frame_rate=fps,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                decode_timestep=0.05,
                decode_noise_scale=0.025,
                output_type="pt",
            )

        # result.frames is [1, N, C, H, W] float in [0,1]
        frames_pt = result.frames[0]  # [N, C, H, W]
        frames_hwc = frames_pt.permute(0, 2, 3, 1).cpu().float().numpy()
        # Round before the integer cast: plain astype() truncates, which
        # biases every value downwards by up to one level.
        frames_np = (frames_hwc * 255).round().clip(0, 255).astype(np.uint8)
    finally:
        # Reclaim memory on the failure path too (e.g. after a CUDA OOM), so
        # a failed request does not leave the cached pipeline starved.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return frames_np