SpaCeFormer / demo /text_encoder.py
chrischoy's picture
Merge SpaceFormer demo (viser + CLI + Gradio) under demo/
a8e8155 verified
Raw
History Blame Contribute Delete
8.26 kB
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import List, Tuple
import hashlib
import os
import abc
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
# The demo expects every model to already be present in the local Hugging Face
# cache, so request offline mode for the Hugging Face libraries.
# NOTE(review): this runs after `transformers` has been imported above; some
# huggingface_hub / transformers versions read HF_HUB_OFFLINE only once at
# import time, in which case this may not take effect — confirm.
os.environ.update(HF_HUB_OFFLINE="1")
# Use a deterministic hash function for strings
def string_hash(s: str) -> int:
return int(hashlib.md5(s.encode()).hexdigest(), 16)
class CLIPTextEncoderInterace(abc.ABC):
model: torch.nn.Module
CHANNEL_DIM: int
def __post_init__(self):
self.freeze_encoder()
def freeze_encoder(self):
for params in self.model.parameters():
params.requires_grad = False
@abc.abstractmethod
def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
raise NotImplementedError
@torch.inference_mode()
def get_unique_text_embedding(
self,
list_of_texts: List[str] | List[List[str]],
normalize: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get unique embeddings for a list of texts.
Args:
list_of_texts: List[str] | List[List[str]]
List of texts or list of list of texts to get unique embeddings for.
Total number of texts is N.
Returns:
embeddings: torch.Tensor, shape (M, D)
Unique embeddings for the list of texts.
from_unique_indices: torch.Tensor, shape (N,)
Indices of the texts in the original list.
to_unique_indices: torch.Tensor, shape (M,)
Indices of the unique texts in the flattened list.
"""
# Flatten the list of texts
if isinstance(list_of_texts, list) and isinstance(list_of_texts[0], list):
# list of lists
list_of_texts = [item for sublist in list_of_texts for item in sublist]
# cchoy: Get unique texts using hash. Using string directly is not deterministic due to python string object not using the string values only for hashing.
flat_caption_hash = [string_hash(caption) for caption in list_of_texts]
_, to_unique_indices, from_unique_indices = np.unique(
flat_caption_hash, return_index=True, return_inverse=True
)
# Get unique texts
unique_texts = [list_of_texts[i] for i in to_unique_indices]
# Get embeddings
embeddings = self(unique_texts, normalize=normalize)
# Return embeddings and indices
return embeddings, torch.tensor(from_unique_indices), torch.tensor(to_unique_indices)
def get_text_encoder(
model_type: str,
device: str,
**kwargs,
) -> CLIPTextEncoderInterace:
if model_type == "siglip2":
return Siglip2TextEncoder(device=device, **kwargs)
elif model_type == "openclip": # Recap CLIP is also openclip
return OpenCLIPTextEncoder(device=device, **kwargs)
else:
raise ValueError(f"Model type {model_type} not supported")
class OpenCLIPTextEncoder(CLIPTextEncoderInterace):
    """Frozen text encoder backed by an ``open_clip`` model.

    ``get_text_encoder`` routes ``"openclip"`` here, which also covers Recap
    CLIP checkpoints.
    """

    # open_clip models differ in width, so no fixed embedding size is declared.
    CHANNEL_DIM = None

    # torch dtype -> open_clip ``precision`` string.
    _PRECISION_BY_DTYPE = {
        torch.float32: "fp32",
        torch.float16: "fp16",
        torch.bfloat16: "bf16",
    }

    def __init__(
        self,
        model_id: str,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.bfloat16,
        context_length: int | None = None,
        **kwargs,
    ):
        """Load the open_clip tokenizer and model for ``model_id`` and freeze it.

        Args:
            model_id: open_clip model identifier (e.g. ``"hf-hub:<repo>"``).
            device: Device the model is created on and inputs are moved to.
            torch_dtype: ``torch.float32``, ``torch.float16`` or
                ``torch.bfloat16``; translated to open_clip's ``precision``.
            context_length: Token length passed to the tokenizer. If None, it is
                read from the model (``model.context_length`` or
                ``model.text.context_length``), falling back to 77.
            **kwargs: Accepted and ignored.

        Raises:
            ImportError: if ``open_clip`` is not installed.
            ValueError: if ``torch_dtype`` is not supported.
        """
        # open_clip is an optional dependency, so it is imported lazily here.
        try:
            from open_clip import create_model_from_pretrained, get_tokenizer
        except ImportError as err:
            raise ImportError(
                "open_clip is not installed. Please install it with `pip install open_clip_torch`"
            ) from err

        # Validate the dtype before any download so a bad value fails fast.
        if torch_dtype not in self._PRECISION_BY_DTYPE:
            supported = ", ".join(str(dtype) for dtype in self._PRECISION_BY_DTYPE)
            raise ValueError(
                f"Unsupported torch_dtype {torch_dtype} for open_clip; expected one of: {supported}"
            )
        precision = self._PRECISION_BY_DTYPE[torch_dtype]

        # Fetch (or locate in the cache) the checkpoint before building the model.
        self.prepare_data(model_id)
        self.tokenizer = get_tokenizer(model_id)
        self.model, _ = create_model_from_pretrained(
            model_id,
            device=device,
            precision=precision,
        )
        self.device = device

        # Set context_length: use provided value, or infer from model, or use default
        if context_length is not None:
            self.context_length = context_length
        elif hasattr(self.model, "context_length"):
            self.context_length = self.model.context_length
        elif hasattr(self.model, "text") and hasattr(self.model.text, "context_length"):
            self.context_length = self.model.text.context_length
        else:
            # Default to 77 for standard CLIP models
            self.context_length = 77
            print(
                f"Warning: Could not infer context_length from model, using default: {self.context_length}"
            )

        # The base-class freeze hook is never invoked automatically (the base is
        # not a dataclass), so freeze the weights explicitly.
        self.freeze_encoder()

    def prepare_data(self, model_id: str):
        """Fetch the checkpoint for ``model_id`` via open_clip's HF helper.

        Args:
            model_id: Model id, with or without the ``hf-hub:`` prefix.

        Returns:
            The path returned by ``download_pretrained_from_hf``.
        """
        from open_clip.factory import download_pretrained_from_hf

        # download_pretrained_from_hf expects a bare repo id.
        repo_id = model_id.removeprefix("hf-hub:")
        # NOTE(review): when HF_HUB_CACHE is unset this uses "~/.cache/", not the
        # Hugging Face default hub cache, while create_model_from_pretrained is
        # called without a cache_dir — confirm both resolve to the same cache.
        cache_dir = os.environ.get("HF_HUB_CACHE", os.path.expanduser("~/.cache/"))
        return download_pretrained_from_hf(repo_id, cache_dir=cache_dir)

    @torch.inference_mode()
    @torch.amp.autocast(enabled=True, device_type="cuda")
    def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
        """Encode ``list_of_texts`` with the open_clip text tower.

        Texts are tokenized to ``self.context_length`` tokens. When
        ``normalize`` is True, the embeddings are L2-normalized along the last
        dimension.
        """
        text_tokens = self.tokenizer(list_of_texts, context_length=self.context_length)
        text_tokens = text_tokens.to(self.device)
        embeddings = self.model.encode_text(text_tokens)
        if normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        return embeddings
class Siglip2TextEncoder(CLIPTextEncoderInterace):
    """Frozen SigLIP2 text encoder loaded through Hugging Face ``transformers``.

    Only the text tower is used; the vision tower is dropped after loading.
    """

    # NOTE(review): fixed width, presumably matching the default so400m
    # checkpoint; a different ``model_id`` may produce a different width.
    CHANNEL_DIM = 1152

    def __init__(
        self,
        model_id: str = "google/siglip2-so400m-patch16-384",
        device: str = "cuda",
        attn_implementation: str = "flash_attention_2",
        torch_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        """Load the tokenizer and model, preferring the local cache, then freeze it.

        Args:
            model_id: Hugging Face repo id of the SigLIP2 checkpoint.
            device: Passed as ``device_map`` when loading; inputs are moved here.
            attn_implementation: Attention backend forwarded to ``from_pretrained``.
            torch_dtype: Weight dtype forwarded to ``from_pretrained``.
            **kwargs: Accepted and ignored.
        """
        # Disable tokenizer parallelism
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        def _load(local_files_only: bool):
            # Load tokenizer + model, optionally restricted to the local cache.
            tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=local_files_only)
            model = AutoModel.from_pretrained(
                model_id,
                attn_implementation=attn_implementation,
                torch_dtype=torch_dtype,
                device_map=device,
                local_files_only=local_files_only,
            )
            return tokenizer, model

        # Try the local cache first to avoid HTTP 429 (rate-limit) errors.
        try:
            self.tokenizer, self.model = _load(local_files_only=True)
            print(f"Successfully loaded {model_id} from local cache.")
        except OSError:
            print(
                f"Model {model_id} not found locally. Downloading/Updating from Hugging Face Hub..."
            )
            # Fall back to downloading from the Hub. With many ranks this can still
            # hit 429s; if it persists in a multi-node setup, download on rank 0 only
            # and have the other ranks wait.
            # NOTE(review): HF_HUB_OFFLINE is set to "1" at module import, which may
            # block this download — confirm the fallback is actually reachable.
            self.tokenizer, self.model = _load(local_files_only=False)

        # Drop the vision tower; only text features are needed.
        self.model.vision_model = None
        self.device = device

        # The base-class freeze hook is never invoked automatically (the base is
        # not a dataclass), so freeze the weights explicitly.
        self.freeze_encoder()

    @torch.inference_mode()
    @torch.amp.autocast(enabled=True, device_type="cuda")
    def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
        """Encode ``list_of_texts`` with the SigLIP2 text tower.

        Each text is padded/truncated to exactly 64 tokens. When ``normalize`` is
        True, the features are L2-normalized along the last dimension.
        """
        # Length is 64 https://huggingface.co/docs/transformers/main/model_doc/siglip2
        text_inputs = self.tokenizer(
            list_of_texts,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.model.get_text_features(**text_inputs)
        # In newer transformers, get_text_features may return a
        # BaseModelOutputWithPooling instead of a plain tensor.
        if not isinstance(outputs, torch.Tensor):
            outputs = outputs.pooler_output
        if normalize:
            outputs = torch.nn.functional.normalize(outputs, dim=-1)
        return outputs