"""Text / image / video embedding wrapper around a vision-language backbone.

Each input is turned into a chat-style conversation, run through the backbone,
and pooled at its last attended token to produce one embedding per input.
"""

import logging
import unicodedata
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModel, AutoProcessor

from vl_utils.vision_process import process_vision_info

logger = logging.getLogger(__name__)

# Token budget handed to the processor (truncation=True, max_length=...).
MAX_LENGTH = 8192
IMAGE_BASE_FACTOR = 16
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
# Per-image pixel bounds.
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_PIXELS = 1800 * IMAGE_FACTOR * IMAGE_FACTOR
# Video sampling defaults.
FPS = 1
MAX_FRAMES = 64
FRAME_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR
# Total pixel budget attached to every video entry.
MAX_TOTAL_PIXELS = 10 * FRAME_MAX_PIXELS


def _resolve_weights_path(checkpoint_dir: str, weights_path: Optional[str] = None) -> str:
    """Return a local path to the safetensors weights file.

    ``weights_path`` defaults to ``<checkpoint_dir>/model.safetensors``. If that
    path does not exist locally, ``checkpoint_dir`` is treated as a Hugging Face
    Hub repo id and the file with the same basename is downloaded from it.
    """
    weights_path = weights_path or str(Path(checkpoint_dir) / "model.safetensors")
    if Path(weights_path).exists():
        return weights_path
    # Imported only on this path, so local checkpoints never touch huggingface_hub.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=checkpoint_dir, filename=Path(weights_path).name)


def _load_backbone(
    checkpoint_dir: str,
    weights_path: Optional[str] = None,
    **kwargs,
) -> torch.nn.Module:
    """Load Eddy weights from ``model.safetensors`` in the checkpoint folder.

    The architecture is built from the checkpoint's config, then the
    safetensors state dict (with a leading ``model.`` prefix stripped from each
    key) is loaded with ``strict=True``.

    Args:
        checkpoint_dir: Local directory or Hub repo id holding the config.
        weights_path: Optional explicit path to the weights file.
        **kwargs: Only ``torch_dtype`` (or its alias ``dtype``) is used; when
            given, the loaded model is cast to it. Other keys are ignored.

    Returns:
        The backbone module with weights loaded.
    """
    checkpoint_dir = str(checkpoint_dir)
    weights_path = _resolve_weights_path(checkpoint_dir, weights_path)
    # Both keys are popped; ``torch_dtype`` wins when both are supplied.
    dtype = kwargs.pop("torch_dtype", kwargs.pop("dtype", None))
    config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=False)
    model = AutoModel.from_config(config)
    state_dict = {
        key.removeprefix("model."): value
        for key, value in load_file(weights_path).items()
    }
    model.load_state_dict(state_dict, strict=True)
    if dtype is not None:
        model = model.to(dtype=dtype)
    return model


def sample_frames(
    frames: List[Union[str, Image.Image]],
    num_segments: Optional[int],
    max_segments: Optional[int],
) -> List[Union[str, Image.Image]]:
    """Uniformly sample frames from a pre-extracted frame sequence.

    Args:
        frames: Frame paths or PIL images, in temporal order.
        num_segments: Number of frames to sample. ``None`` keeps every frame.
            Values larger than ``len(frames)`` repeat frames (upsampling).
        max_segments: Upper bound on the number of returned frames, or ``None``
            for no bound. The bound is applied *before* sampling so the result
            still spans the whole sequence rather than only its beginning.

    Returns:
        The sampled frames; empty if ``frames`` is empty or the target count
        is not positive.
    """
    if not frames:
        return []
    target = len(frames) if num_segments is None else num_segments
    if max_segments is not None:
        target = min(target, max_segments)
    if target <= 0:
        return []
    # Evenly spaced indices in [0, len(frames) - 1]; always valid positions.
    indices = np.linspace(0, len(frames) - 1, target, dtype=int).tolist()
    return [frames[idx] for idx in indices]


class VLEmbedder:
    """Embed text, images and videos with a vision-language backbone."""

    def __init__(
        self,
        model_name_or_path: str,
        weights_path: Optional[str] = None,
        max_length: int = MAX_LENGTH,
        min_pixels: int = MIN_PIXELS,
        max_pixels: int = MAX_PIXELS,
        total_pixels: int = MAX_TOTAL_PIXELS,
        fps: float = FPS,
        num_frames: int = MAX_FRAMES,
        max_frames: int = MAX_FRAMES,
        default_instruction: str = "Represent the user's input.",
        **kwargs,
    ):
        """Load the backbone and processor.

        Args:
            model_name_or_path: Local checkpoint directory or Hub repo id.
            weights_path: Optional explicit safetensors path.
            max_length: Token truncation length for the processor.
            min_pixels: Lower pixel bound attached to each image entry.
            max_pixels: Upper pixel bound attached to each image entry.
            total_pixels: Pixel budget attached to each video entry.
            fps: Default sampling rate for video files/URLs.
            num_frames: Frames uniformly sampled from frame-list videos.
            max_frames: Default cap on frames for any video input.
            default_instruction: System prompt used when none is supplied.
            **kwargs: Forwarded to ``_load_backbone`` (``torch_dtype``/``dtype``).
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = max_length
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.total_pixels = total_pixels
        self.fps = fps
        self.num_frames = num_frames
        self.max_frames = max_frames
        self.default_instruction = default_instruction
        self.model = _load_backbone(
            model_name_or_path,
            weights_path=weights_path,
            **kwargs,
        ).to(device)
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            padding_side="right",
        )
        self.model.eval()

    @torch.no_grad()
    def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Run the backbone without gradients.

        Returns:
            ``last_hidden_state`` from the backbone output and the
            ``attention_mask`` from ``inputs`` (``None`` if absent).
        """
        outputs = self.model(**inputs)
        return {
            "last_hidden_state": outputs.last_hidden_state,
            "attention_mask": inputs.get("attention_mask"),
        }

    def _truncate_tokens(self, token_ids: List[int], max_length: int) -> List[int]:
        """Shorten ``token_ids`` to ``max_length`` while keeping special tokens.

        Every special token is retained in place; non-special tokens are kept
        in order until the remaining budget is used up. If the special tokens
        alone exceed ``max_length``, the result is still longer than it.
        """
        if len(token_ids) <= max_length:
            return token_ids
        special_ids = set(self.processor.tokenizer.all_special_ids)
        num_special = sum(1 for tok in token_ids if tok in special_ids)
        budget = max_length - num_special
        kept: List[int] = []
        kept_regular = 0
        for tok in token_ids:
            if tok in special_ids:
                kept.append(tok)
            elif kept_regular < budget:
                kept.append(tok)
                kept_regular += 1
        return kept

    def format_model_input(
        self,
        text: Optional[str] = None,
        image: Optional[Union[str, Image.Image]] = None,
        video: Optional[Union[str, List[Union[str, Image.Image]]]] = None,
        instruction: Optional[str] = None,
        fps: Optional[float] = None,
        max_frames: Optional[int] = None,
    ) -> List[Dict]:
        """Build a two-turn conversation (system + user) for one input.

        The user turn contains, in this order, an optional video, image and
        text entry. Inputs with none of the three become the literal text
        ``"NULL"``.

        Args:
            text: Optional text.
            image: PIL image, URL (``http``/``oss`` prefix) or local path.
            video: Video file/URL, or a list of frame paths / PIL images.
            instruction: System prompt; a period is appended when it does not
                already end in punctuation. Falls back to the default.
            fps: Sampling rate for video files/URLs (default ``self.fps``).
            max_frames: Frame cap for any video (default ``self.max_frames``).

        Returns:
            The conversation as a list of message dicts.

        Raises:
            TypeError: If ``video`` or ``image`` has an unsupported type.
        """
        if instruction:
            instruction = instruction.strip()
        if instruction and not unicodedata.category(instruction[-1]).startswith("P"):
            instruction = instruction + "."

        content: List[Dict[str, Any]] = []
        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": instruction or self.default_instruction}],
            },
            {"role": "user", "content": content},
        ]

        if not text and not image and not video:
            content.append({"type": "text", "text": "NULL"})
            return conversation

        if video:
            # The pixel budget applies to every video entry, whatever its form.
            video_kwargs: Dict[str, Any] = {"total_pixels": self.total_pixels}
            if isinstance(video, list):
                frame_cap = max_frames or self.max_frames
                frames = video
                if self.num_frames is not None or frame_cap is not None:
                    frames = sample_frames(frames, self.num_frames, frame_cap)
                video_content = [
                    "file://" + frame if isinstance(frame, str) else frame
                    for frame in frames
                ]
            elif isinstance(video, str):
                video_content = (
                    video if video.startswith(("http://", "https://")) else "file://" + video
                )
                video_kwargs["fps"] = fps or self.fps
                video_kwargs["max_frames"] = max_frames or self.max_frames
            else:
                raise TypeError(f"Unrecognized video type: {type(video)}")
            if video_content:
                content.append({"type": "video", "video": video_content, **video_kwargs})

        if image:
            if isinstance(image, Image.Image):
                image_content = image
            elif isinstance(image, str):
                # NOTE(review): the bare "http" prefix also matches local paths
                # such as "httpdocs/x.png"; kept as-is for compatibility.
                image_content = image if image.startswith(("http", "oss")) else "file://" + image
            else:
                raise TypeError(f"Unrecognized image type: {type(image)}")
            if image_content:
                content.append(
                    {
                        "type": "image",
                        "image": image_content,
                        "min_pixels": self.min_pixels,
                        "max_pixels": self.max_pixels,
                    }
                )

        if text:
            content.append({"type": "text", "text": text})
        return conversation

    def _preprocess_inputs(self, conversations: List[List[Dict]]) -> Dict[str, torch.Tensor]:
        """Render, tokenise and pixel-process a batch of conversations.

        Vision preprocessing is best-effort: if ``process_vision_info`` raises,
        the error is logged and every item in the batch is replaced by a
        text-only ``"NULL"`` prompt. One placeholder is rendered per input so
        the batch size (and thus the number of embeddings) is preserved.
        """
        text = self.processor.apply_chat_template(
            conversations, add_generation_prompt=True, tokenize=False
        )
        try:
            images, video_inputs, video_kwargs = process_vision_info(
                conversations,
                image_patch_size=IMAGE_BASE_FACTOR,
                return_video_metadata=True,
                return_video_kwargs=True,
            )
        except Exception as e:
            logger.exception("Error in processing vision info: %s", e)
            images = None
            video_inputs = None
            video_kwargs = {"do_sample_frames": False}
            placeholders = [
                [{"role": "user", "content": [{"type": "text", "text": "NULL"}]}]
                for _ in conversations
            ]
            text = self.processor.apply_chat_template(
                placeholders,
                add_generation_prompt=True,
                tokenize=False,
            )

        if video_inputs is not None:
            # Each entry pairs a video with its metadata; split into two lists.
            videos, video_metadata = zip(*video_inputs)
            videos = list(videos)
            video_metadata = list(video_metadata)
        else:
            videos, video_metadata = None, None

        return self.processor(
            text=text,
            images=images,
            videos=videos,
            video_metadata=video_metadata,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            do_resize=False,
            return_tensors="pt",
            **video_kwargs,
        )

    @staticmethod
    def _pooling_last(
        hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Pick each row's hidden state at its last attended position.

        The last 1 in each mask row is found by flipping the mask and taking
        ``argmax`` (first maximum), so this works for left or right padding.
        A row whose mask is all zeros selects the final position.
        Assumes ``hidden_state`` is (batch, seq, dim) and ``attention_mask`` is
        (batch, seq) — TODO confirm for all backbones.
        """
        flipped = attention_mask.flip(dims=[1])
        last_one_positions = flipped.argmax(dim=1)
        col = attention_mask.shape[1] - last_one_positions - 1
        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
        return hidden_state[row, col]

    def process(self, inputs: List[Dict[str, Any]], normalize: bool = True) -> torch.Tensor:
        """Embed a batch of inputs.

        Args:
            inputs: Dicts with optional keys ``text``, ``image``, ``video``,
                ``instruction``, ``fps`` and ``max_frames`` (see
                ``format_model_input``).
            normalize: L2-normalise each embedding when True.

        Returns:
            Tensor with one row per input.

        Raises:
            ValueError: If ``inputs`` is empty.
        """
        if not inputs:
            raise ValueError("`inputs` must contain at least one item.")
        conversations = [
            self.format_model_input(
                text=ele.get("text"),
                image=ele.get("image"),
                video=ele.get("video"),
                instruction=ele.get("instruction"),
                fps=ele.get("fps"),
                max_frames=ele.get("max_frames"),
            )
            for ele in inputs
        ]
        processed_inputs = self._preprocess_inputs(conversations)
        processed_inputs = {k: v.to(self.model.device) for k, v in processed_inputs.items()}
        outputs = self.forward(processed_inputs)
        embeddings = self._pooling_last(
            outputs["last_hidden_state"], outputs["attention_mask"]
        )
        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=-1)
        return embeddings