"""
FLOAT Lipsync - Direct Integration (no subprocess/daemon)

Simplified from float_lipsync_daemon.py for single-process Docker deployment.
Loads FLOAT model once at startup, then generates lipsync videos on demand.
"""

import sys
import os
import types
import time
import datetime
import importlib
import subprocess
import tempfile
import logging

logger = logging.getLogger(__name__)

# ---- Add FLOAT repo to Python path ----
FLOAT_REPO_PATH = "/app/float_repo"
if FLOAT_REPO_PATH not in sys.path:
    sys.path.insert(0, FLOAT_REPO_PATH)

import torch
import cv2
import numpy as np
import librosa
import face_alignment
import albumentations as A
import albumentations.pytorch.transforms as A_pytorch
from transformers import Wav2Vec2FeatureExtractor


# Workaround: transformers now requires torch>=2.6 for torch.load safety (CVE-2025-32434)
# but vLLM 0.7.3 pins torch 2.5.1. Disable the check since we only load trusted local checkpoints.
def _noop():
    """Replacement for transformers' check_torch_load_is_safe(); never raises."""
    return None


def _disable_torch_load_safety_check():
    """Replace ``check_torch_load_is_safe`` with :func:`_noop` in transformers.

    ``import_utils`` defines the function, and ``modeling_utils`` holds its own
    reference (imported via ``from ... import``), so both names must be replaced.
    Each module is handled independently: failing to import one does not stop
    the other from being patched.

    Returns:
        The list of module names whose ``check_torch_load_is_safe`` was replaced
        (empty if none had the attribute or none could be imported).
    """
    patched = []
    for module_name in ("transformers.utils.import_utils", "transformers.modeling_utils"):
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            logger.warning(f"[FLOAT] Failed to patch torch.load safety check: {e}")
            continue
        if hasattr(module, "check_torch_load_is_safe"):
            module.check_torch_load_is_safe = _noop
            patched.append(module_name)
    return patched


# Only report success when something was actually replaced.
if _disable_torch_load_safety_check():
    logger.info("[FLOAT] Patched transformers torch.load safety check for torch 2.5 compat")
else:
    logger.info("[FLOAT] No transformers torch.load safety check found to patch")

# Workaround: newer transformers defaults wav2vec2 to SDPA attention which doesn't support
# output_attentions=True (needed by FLOAT's wav2vec2.py). Force eager attention.
try:
    from transformers import Wav2Vec2Config as _W2VConfig

    _orig_init = _W2VConfig.__init__

    def _patched_init(self, *args, **kwargs):
        """Run the stock Wav2Vec2Config.__init__, then force eager attention."""
        _orig_init(self, *args, **kwargs)
        # SDPA cannot return attention weights; FLOAT's wav2vec2.py asks for them.
        self._attn_implementation = "eager"

    _W2VConfig.__init__ = _patched_init
    logger.info("[FLOAT] Patched Wav2Vec2Config to use eager attention")
except Exception as e:
    logger.warning(f"[FLOAT] Failed to patch Wav2Vec2Config: {e}")

# Import FLOAT model from cloned repo (the repo root must already be on sys.path).
from models.float.FLOAT import FLOAT

# ---- Paths (fixed container locations) ----
# The model/asset paths in DEFAULT_CONFIG can be overridden through the config
# dict passed to FloatLipsync.initialize(); LIPSYNC_OUTPUT_DIR cannot.
CHECKPOINTS_DIR = "/app/checkpoints"
LIPSYNC_OUTPUT_DIR = "/tmp/lipsync_output"
REF_IMAGE_DIR = "/app/assets"

# Default config
DEFAULT_CONFIG = {
    "ref_path": os.path.join(REF_IMAGE_DIR, "ref.png"),
    "ckpt_path": os.path.join(CHECKPOINTS_DIR, "float.pth"),
    "wav2vec_model_path": os.path.join(CHECKPOINTS_DIR, "wav2vec2-base-960h"),
    "audio2emotion_path": os.path.join(CHECKPOINTS_DIR, "wav2vec-english-speech-emotion-recognition"),
    "seed": 15,
    "a_cfg_scale": 2.0,
    "e_cfg_scale": 1.0,
    "r_cfg_scale": 1.0,
    "no_crop": False,
    "nfe": 7,
    "fps": 25.0,
}


class FloatLipsync:
    """
    Direct FLOAT lipsync generator.

    Call initialize() once at startup, then generate() for each audio clip.
    """

    def __init__(self):
        self.model = None
        self.opt = None
        self.device = None
        self.fa = None
        self.wav2vec_preprocessor = None
        self.transform = None
        # Reference frame fed to the model as 's'; replaced when chaining frames.
        self.preprocessed_ref_image = None
        # Copy of the reference loaded by the most recent _preload_reference_image()
        # call; reset_reference() restores from it.
        self._original_ref_image = None
        self.ready = False

    def initialize(self, config: dict = None):
        """
        Load all models and preprocess the reference image. Call this once at startup.

        Args:
            config: Optional overrides for DEFAULT_CONFIG keys. Missing keys fall
                back to DEFAULT_CONFIG.

        Raises:
            FileNotFoundError: If the checkpoint or reference image is missing.
            RuntimeError / IOError: If the reference image cannot be read or no
                face is detected in it (when cropping is enabled).
        """
        if config is None:
            config = DEFAULT_CONFIG

        # Mark not-ready while (re)loading so a failure part-way through does not
        # leave a half-initialised model that generate() would happily use.
        self.ready = False

        init_start = time.time()
        logger.info("[FLOAT] Initializing FLOAT lipsync system...")

        # 1. Build options namespace (matches daemon's opt structure)
        self.opt = self._build_options(config)
        os.makedirs(LIPSYNC_OUTPUT_DIR, exist_ok=True)

        # 2. Set device
        torch.cuda.empty_cache()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info(f"[FLOAT] Using device: {self.device}")
        if not torch.cuda.is_available():
            logger.warning("[FLOAT] WARNING: No CUDA GPU detected! FLOAT will be very slow on CPU.")

        # 3. Load face alignment model
        start = time.time()
        self.fa = face_alignment.FaceAlignment(
            face_alignment.LandmarksType.TWO_D,
            flip_input=False,
            device=str(self.device),
        )
        logger.info(f"[FLOAT] Face alignment loaded: {time.time() - start:.2f}s")

        # 4. Load wav2vec preprocessor
        start = time.time()
        self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            self.opt.wav2vec_model_path, local_files_only=True
        )
        logger.info(f"[FLOAT] Wav2Vec2 preprocessor loaded: {time.time() - start:.2f}s")

        # 5. Image transform: resize -> scale pixels to [-1, 1] -> CHW tensor
        self.transform = A.Compose([
            A.Resize(height=self.opt.input_size, width=self.opt.input_size, interpolation=cv2.INTER_AREA),
            A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            A_pytorch.ToTensorV2(),
        ])

        # 6. Load FLOAT model
        logger.info("[FLOAT] Loading FLOAT model architecture...")
        start = time.time()
        self.model = FLOAT(self.opt)
        logger.info(f"[FLOAT] Architecture created: {time.time() - start:.2f}s")

        # 7. Load checkpoint weights
        logger.info("[FLOAT] Loading checkpoint weights...")
        start = time.time()
        self._load_weights(self.opt.ckpt_path)
        logger.info(f"[FLOAT] Weights loaded: {time.time() - start:.2f}s")

        # 8. Move to device and eval mode
        start = time.time()
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"[FLOAT] Model to device + eval: {time.time() - start:.2f}s")

        # 9. Preprocess reference image
        logger.info(f"[FLOAT] Preprocessing reference image: {self.opt.ref_path}")
        self._preload_reference_image(self.opt.ref_path)

        self.ready = True
        logger.info(f"[FLOAT] ✓ TOTAL INIT TIME: {time.time() - init_start:.2f}s")
        logger.info("[FLOAT] ✓ System ready to generate.")

    def _build_options(self, config: dict) -> types.SimpleNamespace:
        """Build the options namespace matching FLOAT's expected format.

        User-tunable values come from ``config`` (falling back to DEFAULT_CONFIG);
        architecture values are fixed constants.
        """
        opt = types.SimpleNamespace()

        # Paths
        opt.ref_path = config.get("ref_path", DEFAULT_CONFIG["ref_path"])
        opt.ckpt_path = config.get("ckpt_path", DEFAULT_CONFIG["ckpt_path"])
        opt.wav2vec_model_path = config.get("wav2vec_model_path", DEFAULT_CONFIG["wav2vec_model_path"])
        opt.audio2emotion_path = config.get("audio2emotion_path", DEFAULT_CONFIG["audio2emotion_path"])
        opt.res_dir = LIPSYNC_OUTPUT_DIR

        # Generation params (defaults come from DEFAULT_CONFIG, the single source of truth)
        opt.seed = config.get("seed", DEFAULT_CONFIG["seed"])
        opt.a_cfg_scale = config.get("a_cfg_scale", DEFAULT_CONFIG["a_cfg_scale"])
        opt.e_cfg_scale = config.get("e_cfg_scale", DEFAULT_CONFIG["e_cfg_scale"])
        opt.r_cfg_scale = config.get("r_cfg_scale", DEFAULT_CONFIG["r_cfg_scale"])
        opt.no_crop = config.get("no_crop", DEFAULT_CONFIG["no_crop"])
        opt.nfe = config.get("nfe", DEFAULT_CONFIG["nfe"])
        opt.fps = config.get("fps", DEFAULT_CONFIG["fps"])

        # Device / system
        opt.rank = 0
        opt.ngpus = 1
        opt.emo = 'S2E'

        # Model architecture params (from FLOAT defaults)
        opt.fix_noise_seed = False
        opt.input_size = 512
        opt.input_nc = 3
        opt.sampling_rate = 16000
        opt.audio_marcing = 2
        opt.wav2vec_sec = 2.0
        opt.attention_window = 2
        opt.only_last_features = False
        opt.average_emotion = False
        opt.audio_dropout_prob = 0.1
        opt.ref_dropout_prob = 0.1
        opt.emotion_dropout_prob = 0.1
        opt.style_dim = 512
        opt.dim_a = 512
        opt.dim_w = 512
        opt.dim_h = 1024
        opt.dim_m = 20
        opt.dim_e = 7
        opt.fmt_depth = 8
        opt.num_heads = 8
        opt.mlp_ratio = 4.0
        opt.no_learned_pe = False
        opt.num_prev_frames = 10
        opt.max_grad_norm = 1.0
        opt.ode_atol = 1e-5
        opt.ode_rtol = 1e-5
        opt.torchdiffeq_ode_method = 'euler'
        opt.n_diff_steps = 500
        opt.diff_schedule = 'cosine'
        opt.diffusion_mode = 'sample'

        return opt

    def _load_weights(self, checkpoint_path: str):
        """Load checkpoint weights into the model.

        Copies every model parameter whose name appears in the checkpoint. Missing
        parameters are logged, except wav2vec2 ones, which are not expected in the
        FLOAT checkpoint.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        param_count = 0
        # NOTE(review): only named_parameters() are copied; buffers stored in the
        # checkpoint (if any) are ignored — confirm this matches upstream FLOAT.
        with torch.no_grad():
            for model_name, model_param in self.model.named_parameters():
                if model_name in state_dict:
                    # copy_() performs any device transfer itself. initialize()
                    # calls this before moving the model to self.device, so
                    # staging each tensor on the device first was a wasted copy.
                    model_param.copy_(state_dict[model_name])
                    param_count += 1
                elif "wav2vec2" not in model_name:
                    logger.warning(f"[FLOAT] Weight not in checkpoint: {model_name}")

        del state_dict
        logger.info(f"[FLOAT] Loaded {param_count} parameters from checkpoint")

    def _preload_reference_image(self, ref_path: str):
        """Load and preprocess the reference image.

        Sets both ``preprocessed_ref_image`` and the ``_original_ref_image`` copy
        used by reset_reference(). State is only modified after every step succeeds.

        Raises:
            FileNotFoundError: If ``ref_path`` does not exist.
            IOError: If OpenCV cannot decode the file.
            RuntimeError: If cropping is enabled and no face is detected.
        """
        if not os.path.exists(ref_path):
            raise FileNotFoundError(f"Reference image not found: {ref_path}")

        img = cv2.imread(ref_path)
        if img is None:
            raise IOError(f"Could not read image: {ref_path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if not self.opt.no_crop:
            img = self._crop_face(img)

        self.preprocessed_ref_image = self.transform(image=img)['image'].unsqueeze(0).to(self.device)
        self._original_ref_image = self.preprocessed_ref_image.clone()
        logger.info(f"[FLOAT] Reference image preprocessed: {self.preprocessed_ref_image.shape}")

    def _crop_face(self, img: np.ndarray) -> np.ndarray:
        """Detect the face in an RGB image and return a square crop around it.

        Detection runs on a copy rescaled to 360px height; the first detection with
        score > 0.95 is used. The crop is 1.6x the larger half-side of the box on each
        side of the box centre, with edge pixels replicated where it leaves the
        image, then resized to ``opt.input_size`` x ``opt.input_size``.

        Raises:
            RuntimeError: If no confident face is found or the box is degenerate.
        """
        mult = 360.0 / img.shape[0]
        detect_interp = cv2.INTER_AREA if mult < 1.0 else cv2.INTER_CUBIC
        resized = cv2.resize(img, dsize=(0, 0), fx=mult, fy=mult, interpolation=detect_interp)

        detections = self.fa.face_detector.detect_from_image(resized)
        if detections is None:
            detections = []

        # Map boxes back to original-image coordinates, keeping confident ones only.
        bboxes = [
            (int(x1 / mult), int(y1 / mult), int(x2 / mult), int(y2 / mult), score)
            for (x1, y1, x2, y2, score) in detections
            if score > 0.95
        ]
        if not bboxes:
            raise RuntimeError("No face detected in reference image")

        x1, y1, x2, y2, _ = bboxes[0]
        bsy = int((y2 - y1) / 2)
        bsx = int((x2 - x1) / 2)
        my = int((y1 + y2) / 2)
        mx = int((x1 + x2) / 2)
        bs = int(max(bsy, bsx) * 1.6)
        if bs <= 0:
            # An empty crop would otherwise surface as an opaque cv2.resize error.
            raise RuntimeError("Detected face bounding box is empty")

        # Pad by bs on every side so the square crop never runs off the image.
        padded = cv2.copyMakeBorder(img, bs, bs, bs, bs, cv2.BORDER_REPLICATE)
        my, mx = my + bs, mx + bs
        crop = padded[my - bs:my + bs, mx - bs:mx + bs]

        # Choose interpolation from the crop -> output scale: INTER_AREA only when
        # shrinking; it degrades to near nearest-neighbour when enlarging.
        size = self.opt.input_size
        out_interp = cv2.INTER_AREA if crop.shape[0] > size else cv2.INTER_CUBIC
        return cv2.resize(crop, (size, size), interpolation=out_interp)

    def update_reference_image(self, ref_path: str) -> bool:
        """Swap to a different reference image at runtime.

        The new image also becomes the target of reset_reference().

        Returns:
            True on success; False (after logging) if the image could not be loaded.
        """
        try:
            self._preload_reference_image(ref_path)
            self.opt.ref_path = ref_path
            logger.info(f"[FLOAT] Reference updated: {ref_path}")
            return True
        except Exception as e:
            logger.error(f"[FLOAT] Failed to update reference: {e}")
            return False

    @torch.no_grad()
    def generate(self, audio_path: str, output_path: str = None, emo: str = 'S2E',
                 chain_frames: bool = False) -> str:
        """
        Generate a lipsync video from an audio file.

        Args:
            audio_path: Path to the input WAV audio file
            output_path: Where to save the output video (auto-generated under
                LIPSYNC_OUTPUT_DIR if None; parent directory is created if needed)
            emo: Emotion mode ('S2E' for speech-to-emotion auto-detect)
            chain_frames: If True, use the last generated frame as the reference
                for the next call (applied only after the video is saved)

        Returns:
            Path to the generated video file

        Raises:
            RuntimeError: If initialize() has not completed, or ffmpeg fails.
        """
        if not self.ready:
            raise RuntimeError("FLOAT system not initialized. Call initialize() first.")

        inference_start = time.time()
        logger.info(f"[FLOAT] Generating lipsync for: {os.path.basename(audio_path)}")

        # Auto-generate output path
        if output_path is None:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            output_path = os.path.join(LIPSYNC_OUTPUT_DIR, f"lipsync_{timestamp}.mp4")
        else:
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

        # 1. Load and preprocess audio
        speech_array, sr = librosa.load(audio_path, sr=self.opt.sampling_rate)
        # Pad with 0.5s silence so FLOAT eases the face back to neutral at the end.
        # NOTE: _save_video muxes the original (unpadded) audio with -shortest, so
        # the padded tail frames are trimmed from the saved video; the last of them
        # is still what chain_frames picks up.
        pad_samples = int(0.5 * sr)
        speech_array = np.concatenate([speech_array, np.zeros(pad_samples, dtype=speech_array.dtype)])
        processed_audio = self.wav2vec_preprocessor(
            speech_array, sampling_rate=sr, return_tensors='pt'
        ).input_values[0].unsqueeze(0).to(self.device)
        logger.info(f"[FLOAT] Audio preprocessed: {processed_audio.shape}")

        # 2. Run FLOAT inference
        data = {
            's': self.preprocessed_ref_image,
            'a': processed_audio,
            'p': None,
            'e': None,
        }

        gen_start = time.time()
        d_hat = self.model.inference(
            data=data,
            a_cfg_scale=self.opt.a_cfg_scale,
            r_cfg_scale=self.opt.r_cfg_scale,
            e_cfg_scale=self.opt.e_cfg_scale,
            emo=emo,
            nfe=self.opt.nfe,
            seed=self.opt.seed,
        )['d_hat']
        logger.info(f"[FLOAT] Model inference: {time.time() - gen_start:.2f}s")

        # 3. Save video with audio
        self._save_video(d_hat, output_path, audio_path)

        # 4. Chain frames: use last frame as reference for next chunk. Done after the
        #    save so a failed save does not silently advance the chain.
        if chain_frames and d_hat.shape[0] > 0:
            # Assumes d_hat is (T, C, H, W) — the layout _save_video permutes from.
            # Clamp to [-1, 1], the range of the normalised reference image.
            self.preprocessed_ref_image = d_hat[-1:].clamp(-1, 1).clone().to(self.device)
            logger.info("[FLOAT] Chained last frame as next reference")

        logger.info(f"[FLOAT] ✓ Total generation: {time.time() - inference_start:.2f}s")
        logger.info(f"[FLOAT] ✓ Output: {output_path}")
        return output_path

    def reset_reference(self):
        """Restore the reference to the image loaded by the most recent
        initialize() / update_reference_image() call, undoing frame chaining."""
        if getattr(self, '_original_ref_image', None) is not None:
            self.preprocessed_ref_image = self._original_ref_image.clone()
            logger.info("[FLOAT] Reference reset to original")

    def _save_video(self, vid_tensor: torch.Tensor, video_path: str, audio_path: str):
        """Encode frames to an H.264 mp4 via ffmpeg, muxing in audio when available.

        Frames are clamped to [-1, 1], mapped to uint8 RGB and piped to ffmpeg as
        raw video at ``opt.fps``. If ``audio_path`` exists it is encoded as AAC and
        the output is cut to the shorter stream.

        Raises:
            RuntimeError: If ffmpeg exits non-zero.
            OSError: If ffmpeg cannot be launched.
        """
        # Prepare frames: (T, C, H, W) -> (T, H, W, C) uint8
        vid = vid_tensor.permute(0, 2, 3, 1).detach().clamp(-1, 1).contiguous()
        vid = ((vid + 1) / 2 * 255).to(torch.uint8).cpu().numpy()
        height, width = vid.shape[1], vid.shape[2]
        logger.info(f"[FLOAT] Saving video: {vid.shape[0]} frames, {width}x{height}")

        has_audio = bool(audio_path) and os.path.exists(audio_path)

        # Use CPU h264 encoding
        ffmpeg_cmd = [
            'ffmpeg', '-y',
            '-f', 'rawvideo', '-vcodec', 'rawvideo',
            '-s', f'{width}x{height}', '-pix_fmt', 'rgb24',
            '-r', str(self.opt.fps),
            '-i', 'pipe:0',
        ]
        if has_audio:
            ffmpeg_cmd += ['-i', audio_path]
        ffmpeg_cmd += [
            '-c:v', 'libx264', '-preset', 'fast', '-crf', '23',
            '-pix_fmt', 'yuv420p',
        ]
        if has_audio:
            ffmpeg_cmd += ['-c:a', 'aac', '-b:a', '128k', '-shortest']
        ffmpeg_cmd.append(video_path)

        try:
            process = subprocess.Popen(
                ffmpeg_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            # communicate() feeds stdin while draining stdout/stderr, avoiding pipe deadlocks.
            _stdout, stderr = process.communicate(input=vid.tobytes())

            if process.returncode != 0:
                # ffmpeg prints its version banner first and the actual error last,
                # so log the tail; decode leniently so a bad byte can't mask the error.
                err_tail = stderr.decode(errors='replace')[-2000:]
                logger.error(f"[FLOAT] ffmpeg error (rc={process.returncode}): {err_tail}")
                raise RuntimeError("ffmpeg encoding failed")

            file_size = os.path.getsize(video_path) if os.path.exists(video_path) else 0
            logger.info(f"[FLOAT] Video saved: {video_path} ({file_size / 1024:.1f} KB)")
        except Exception as e:
            logger.error(f"[FLOAT] Video save failed: {e}")
            raise


# ---- Module-level singleton ----
_lipsync_instance = None


def get_lipsync() -> FloatLipsync:
    """Get or create the singleton FloatLipsync instance (not initialised)."""
    global _lipsync_instance
    if _lipsync_instance is None:
        _lipsync_instance = FloatLipsync()
    return _lipsync_instance