# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from typing import List, Tuple import hashlib import os import abc import numpy as np import torch from transformers import AutoTokenizer, AutoModel # Assume that models are already cached os.environ["HF_HUB_OFFLINE"] = "1" # Use a deterministic hash function for strings def string_hash(s: str) -> int: return int(hashlib.md5(s.encode()).hexdigest(), 16) class CLIPTextEncoderInterace(abc.ABC): model: torch.nn.Module CHANNEL_DIM: int def __post_init__(self): self.freeze_encoder() def freeze_encoder(self): for params in self.model.parameters(): params.requires_grad = False @abc.abstractmethod def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor: raise NotImplementedError @torch.inference_mode() def get_unique_text_embedding( self, list_of_texts: List[str] | List[List[str]], normalize: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get unique embeddings for a list of texts. Args: list_of_texts: List[str] | List[List[str]] List of texts or list of list of texts to get unique embeddings for. Total number of texts is N. Returns: embeddings: torch.Tensor, shape (M, D) Unique embeddings for the list of texts. from_unique_indices: torch.Tensor, shape (N,) Indices of the texts in the original list. to_unique_indices: torch.Tensor, shape (M,) Indices of the unique texts in the flattened list. """ # Flatten the list of texts if isinstance(list_of_texts, list) and isinstance(list_of_texts[0], list): # list of lists list_of_texts = [item for sublist in list_of_texts for item in sublist] # cchoy: Get unique texts using hash. Using string directly is not deterministic due to python string object not using the string values only for hashing. flat_caption_hash = [string_hash(caption) for caption in list_of_texts] _, to_unique_indices, from_unique_indices = np.unique( flat_caption_hash, return_index=True, return_inverse=True ) # Get unique texts unique_texts = [list_of_texts[i] for i in to_unique_indices] # Get embeddings embeddings = self(unique_texts, normalize=normalize) # Return embeddings and indices return embeddings, torch.tensor(from_unique_indices), torch.tensor(to_unique_indices) def get_text_encoder( model_type: str, device: str, **kwargs, ) -> CLIPTextEncoderInterace: if model_type == "siglip2": return Siglip2TextEncoder(device=device, **kwargs) elif model_type == "openclip": # Recap CLIP is also openclip return OpenCLIPTextEncoder(device=device, **kwargs) else: raise ValueError(f"Model type {model_type} not supported") class OpenCLIPTextEncoder(CLIPTextEncoderInterace): CHANNEL_DIM = None def __init__( self, model_id: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16, context_length: int = None, **kwargs, ): # This is a not a required dependency, so we need to import it here try: from open_clip import create_model_from_pretrained, get_tokenizer except ImportError: raise ImportError( "open_clip is not installed. Please install it with `pip install open-clip`" ) self.prepare_data(model_id) self.tokenizer = get_tokenizer(model_id) precision = {torch.float16: "fp16", torch.bfloat16: "bf16"}[torch_dtype] self.model, _ = create_model_from_pretrained( model_id, device=device, precision=precision, ) self.device = device # Set context_length: use provided value, or infer from model, or use default if context_length is not None: self.context_length = context_length elif hasattr(self.model, "context_length"): self.context_length = self.model.context_length elif hasattr(self.model, "text") and hasattr(self.model.text, "context_length"): self.context_length = self.model.text.context_length else: # Default to 77 for standard CLIP models self.context_length = 77 print( f"Warning: Could not infer context_length from model, using default: {self.context_length}" ) def prepare_data(self, model_id: str): from open_clip.factory import download_pretrained_from_hf # Remove hf-hub: prefix if it exists model_id = model_id[len("hf-hub:") :] if model_id.startswith("hf-hub:") else model_id ckpt_path = download_pretrained_from_hf( model_id, cache_dir=os.environ.get("HF_HUB_CACHE", os.path.expanduser("~/.cache/")) ) return ckpt_path @torch.inference_mode() @torch.amp.autocast(enabled=True, device_type="cuda") def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor: text_tokens = self.tokenizer(list_of_texts, context_length=self.context_length).to( self.device ) embeddings = self.model.encode_text(text_tokens) if normalize: embeddings = torch.nn.functional.normalize(embeddings, dim=-1) return embeddings class Siglip2TextEncoder(CLIPTextEncoderInterace): CHANNEL_DIM = 1152 def __init__( self, model_id: str = "google/siglip2-so400m-patch16-384", device: str = "cuda", attn_implementation: str = "flash_attention_2", torch_dtype: torch.dtype = torch.bfloat16, **kwargs, ): # Disable tokenizer parallelism os.environ["TOKENIZERS_PARALLELISM"] = "false" # Try loading from local cache first to avoid 429 errors try: self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True) self.model = AutoModel.from_pretrained( model_id, attn_implementation=attn_implementation, torch_dtype=torch_dtype, device_map=device, local_files_only=True, ) print(f"Successfully loaded {model_id} from local cache.") except OSError: print( f"Model {model_id} not found locally. Downloading/Updating from Hugging Face Hub..." ) # Fallback to downloading if not found locally # This might still hit 429 if many ranks try it, but it's the standard fallback. # Ideally verify downloading on rank 0 only in a multi-node setup if this persists. self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.model = AutoModel.from_pretrained( model_id, attn_implementation=attn_implementation, torch_dtype=torch_dtype, device_map=device, ) self.model.vision_model = None # Remove vision model self.device = device @torch.inference_mode() @torch.amp.autocast(enabled=True, device_type="cuda") def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor: # Length is 64 https://huggingface.co/docs/transformers/main/model_doc/siglip2 text_inputs = self.tokenizer( list_of_texts, padding="max_length", truncation=True, max_length=64, return_tensors="pt", ).to(self.device) outputs = self.model.get_text_features(**text_inputs) # In newer transformers, get_text_features may return a # BaseModelOutputWithPooling instead of a plain tensor. if not isinstance(outputs, torch.Tensor): outputs = outputs.pooler_output if normalize: outputs = torch.nn.functional.normalize(outputs, dim=-1) return outputs