Spaces:
Running
Running
Update nomic_fewshot.py
Browse files- nomic_fewshot.py +258 -81
nomic_fewshot.py
CHANGED
|
@@ -1,147 +1,324 @@
|
|
| 1 |
"""
|
| 2 |
-
Few-shot object classification using Nomic embed-vision-v1.5 + embed-text-v1.5.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
|
|
|
| 7 |
import time
|
| 8 |
from pathlib import Path
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
-
import
|
| 12 |
-
import torch.nn.functional as F
|
| 13 |
from PIL import Image
|
| 14 |
-
from
|
| 15 |
-
from transformers import
|
| 16 |
|
| 17 |
from jina_fewshot import CLASS_PROMPTS, IMAGE_EXTS
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
-
def
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
-
class
|
| 41 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
def __init__(self, device="cuda"):
|
| 44 |
self.device = device
|
| 45 |
-
|
|
|
|
|
|
|
| 46 |
t0 = time.perf_counter()
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s\n")
|
| 61 |
|
| 62 |
def encode_texts(self, texts: list[str]) -> np.ndarray:
|
| 63 |
prefixed = [f"classification: {t}" for t in texts]
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
self.device = device
|
| 78 |
-
|
|
|
|
|
|
|
| 79 |
t0 = time.perf_counter()
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s\n")
|
| 95 |
|
| 96 |
-
def encode_images(self, images: list) -> np.ndarray:
|
| 97 |
-
""
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
def build_refs_nomic(
|
| 108 |
-
encoder:
|
| 109 |
refs_dir: Path,
|
| 110 |
batch_size: int = 16,
|
| 111 |
-
text_encoder:
|
| 112 |
text_weight: float = 0.3,
|
| 113 |
):
|
| 114 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
|
| 116 |
if not class_dirs:
|
| 117 |
raise ValueError(f"No subfolders in {refs_dir}")
|
|
|
|
| 118 |
labels = []
|
| 119 |
embeddings = []
|
|
|
|
| 120 |
if text_encoder is not None:
|
| 121 |
print(f" Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")
|
|
|
|
| 122 |
for d in class_dirs:
|
| 123 |
name = d.name
|
| 124 |
paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
|
| 125 |
if not paths:
|
| 126 |
continue
|
|
|
|
| 127 |
all_embs = []
|
| 128 |
for i in range(0, len(paths), batch_size):
|
| 129 |
-
batch = [Image.open(p).convert("RGB") for p in paths[i
|
| 130 |
all_embs.append(encoder.encode_images(batch))
|
|
|
|
| 131 |
img_embs = np.concatenate(all_embs, axis=0)
|
| 132 |
-
img_avg = img_embs.mean(axis=0)
|
|
|
|
|
|
|
| 133 |
if text_encoder is not None:
|
| 134 |
prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
|
| 135 |
text_embs = text_encoder.encode_texts(prompts)
|
| 136 |
-
text_avg = text_embs.mean(axis=0)
|
|
|
|
|
|
|
| 137 |
combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
|
|
|
|
| 138 |
combined = combined / (np.linalg.norm(combined) + 1e-12)
|
|
|
|
| 139 |
labels.append(name)
|
| 140 |
embeddings.append(combined)
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
else:
|
| 143 |
-
img_avg = img_avg / (np.linalg.norm(img_avg) + 1e-12)
|
| 144 |
labels.append(name)
|
| 145 |
embeddings.append(img_avg)
|
| 146 |
print(f" {name:<14}: {len(paths)} imgs")
|
| 147 |
-
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Few-shot object classification using Nomic embed-vision-v1.5 + embed-text-v1.5 via ONNX Runtime.
|
| 3 |
+
Same treatment as current PyTorch version:
|
| 4 |
+
- vision refs -> average image embeddings
|
| 5 |
+
- text prompts -> average text embeddings
|
| 6 |
+
- combine with text_weight
|
| 7 |
|
| 8 |
+
This version uses:
|
| 9 |
+
- nomic-ai/nomic-embed-text-v1.5 -> ONNX
|
| 10 |
+
- nomic-ai/nomic-embed-vision-v1.5 -> ONNX
|
| 11 |
+
|
| 12 |
+
Transformers is used only for preprocessing:
|
| 13 |
+
- AutoTokenizer
|
| 14 |
+
- AutoImageProcessor
|
| 15 |
"""
|
| 16 |
+
|
| 17 |
import time
|
| 18 |
from pathlib import Path
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
+
import onnxruntime as ort
|
|
|
|
| 22 |
from PIL import Image
|
| 23 |
+
from huggingface_hub import hf_hub_download
|
| 24 |
+
from transformers import AutoImageProcessor, AutoTokenizer
|
| 25 |
|
| 26 |
from jina_fewshot import CLASS_PROMPTS, IMAGE_EXTS
|
| 27 |
|
| 28 |
|
| 29 |
+
def _l2_normalize(x: np.ndarray, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
|
| 30 |
+
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
|
| 31 |
+
norms = np.linalg.norm(x, axis=axis, keepdims=True)
|
| 32 |
+
norms = np.maximum(norms, eps)
|
| 33 |
+
return (x / norms).astype(np.float32)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
|
| 37 |
+
"""
|
| 38 |
+
last_hidden_state: [B, T, D]
|
| 39 |
+
attention_mask: [B, T]
|
| 40 |
+
"""
|
| 41 |
+
mask = attention_mask.astype(np.float32)[..., None] # [B, T, 1]
|
| 42 |
+
summed = np.sum(last_hidden_state * mask, axis=1)
|
| 43 |
+
denom = np.clip(np.sum(mask, axis=1), 1e-9, None)
|
| 44 |
+
return summed / denom
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _pick_output(outputs: list[np.ndarray], output_names: list[str], kind: str) -> np.ndarray:
|
| 48 |
+
"""
|
| 49 |
+
Try to find the main embedding tensor robustly.
|
| 50 |
+
For both text and vision Nomic ONNX exports, we expect a 3D tensor [B, T, D]
|
| 51 |
+
or sometimes a 2D tensor [B, D].
|
| 52 |
+
"""
|
| 53 |
+
# Prefer names that look like hidden states / embeddings
|
| 54 |
+
preferred_keywords = [
|
| 55 |
+
"last_hidden_state",
|
| 56 |
+
"hidden_state",
|
| 57 |
+
"sentence_embedding",
|
| 58 |
+
"embedding",
|
| 59 |
+
"embeddings",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
for kw in preferred_keywords:
|
| 63 |
+
for i, name in enumerate(output_names):
|
| 64 |
+
if kw in name.lower():
|
| 65 |
+
arr = outputs[i]
|
| 66 |
+
if arr.ndim in (2, 3):
|
| 67 |
+
return arr
|
| 68 |
|
| 69 |
+
# Fallback: first 3D output, then first 2D output
|
| 70 |
+
for arr in outputs:
|
| 71 |
+
if arr.ndim == 3:
|
| 72 |
+
return arr
|
| 73 |
+
for arr in outputs:
|
| 74 |
+
if arr.ndim == 2:
|
| 75 |
+
return arr
|
| 76 |
|
| 77 |
+
raise RuntimeError(
|
| 78 |
+
f"Could not identify a usable {kind} ONNX output. "
|
| 79 |
+
f"Output names={output_names}, shapes={[getattr(o, 'shape', None) for o in outputs]}"
|
| 80 |
+
)
|
| 81 |
|
| 82 |
|
| 83 |
+
def _download_onnx_model(repo_id: str, filename: str = "onnx/model.onnx") -> str:
    """Fetch ``filename`` from the Hugging Face repo ``repo_id`` and return its local path.

    Progress is reported on stdout before and after the download call.
    """
    print(f" Downloading ONNX model from {repo_id} ...")
    local_file = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f" Downloaded: {local_file}")
    return local_file
|
| 91 |
|
| 92 |
|
| 93 |
+
class NomicTextEncoderONNX:
    """
    Nomic embed-text-v1.5 ONNX:
    text -> token embeddings / hidden states -> mean pool -> L2 normalize

    The ONNX graph is downloaded from ``self.repo_id``; the matching
    tokenizer is loaded through ``AutoTokenizer``. Graph input names are
    discovered at load time and mapped onto input_ids / attention_mask /
    token_type_ids by substring match.
    """

    def __init__(self, device: str = "cuda"):
        """Download the model, open an ONNX Runtime session and run a sanity encode.

        ``device="cuda"`` selects the CUDA execution provider when ONNX
        Runtime reports it as available; otherwise the session runs on CPU.

        Raises:
            RuntimeError: if the graph exposes no ``input_ids``-like input
                (raised by the sanity-check encode).
        """
        self.device = device
        self.repo_id = "nomic-ai/nomic-embed-text-v1.5"

        print("[*] Loading nomic-embed-text-v1.5 (ONNX)...")
        t0 = time.perf_counter()

        onnx_path = _download_onnx_model(self.repo_id)

        available = ort.get_available_providers()
        if "CUDAExecutionProvider" in available and device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        print(f" ONNX providers: {providers}")

        self.session = ort.InferenceSession(onnx_path, providers=providers)
        self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id, trust_remote_code=True)

        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

        print(f" ONNX inputs: {self.input_names}")
        print(f" ONNX outputs: {self.output_names}")

        self._ids_name = None
        self._mask_name = None
        self._token_type_name = None

        # Substring match so that prefixed/suffixed export names still map.
        for name in self.input_names:
            nl = name.lower()
            if "input_ids" in nl:
                self._ids_name = name
            elif "attention" in nl:
                self._mask_name = name
            elif "token_type" in nl:
                self._token_type_name = name

        print(
            f" Mapped: input_ids={self._ids_name}, "
            f"attention_mask={self._mask_name}, token_type_ids={self._token_type_name}"
        )

        # Sanity check: a unit-norm embedding confirms the whole pipeline runs.
        test = self.encode_texts(["a red square"])
        nrm = float(np.linalg.norm(test[0]))
        print(f" [SANITY] text embed norm={nrm:.4f}")
        print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s\n")

    def encode_texts(self, texts: list[str]) -> np.ndarray:
        """Embed ``texts`` and return L2-normalized float32 vectors, one row per text.

        Each text is prefixed with ``"classification: "`` before tokenization.
        A rank-3 model output is mean-pooled with the attention mask; a
        rank-2 output is used as-is.

        Raises:
            RuntimeError: if no ``input_ids`` graph input was mapped, or the
                selected output has an unexpected rank.
        """
        # Fail with a clear message instead of letting ONNX Runtime reject an
        # incomplete feed dict (mirrors the pixel-input check of the vision encoder).
        if self._ids_name is None:
            raise RuntimeError(
                f"Could not find input_ids input in ONNX inputs: {self.input_names}"
            )

        prefixed = [f"classification: {t}" for t in texts]
        tokens = self.tokenizer(
            prefixed,
            padding=True,
            truncation=True,
            return_tensors="np",
            max_length=512,
        )

        input_ids = np.asarray(tokens["input_ids"], dtype=np.int64)
        attention_mask = np.asarray(tokens["attention_mask"], dtype=np.int64)

        # Only feed the inputs the graph actually declares.
        feeds = {self._ids_name: input_ids}
        if self._mask_name is not None:
            feeds[self._mask_name] = attention_mask
        if self._token_type_name is not None:
            feeds[self._token_type_name] = np.zeros_like(input_ids, dtype=np.int64)

        outputs = self.session.run(self.output_names, feeds)
        main_out = _pick_output(outputs, self.output_names, kind="text")

        # Current PyTorch behavior: mean-pool last_hidden_state
        if main_out.ndim == 3:
            embs = _mean_pool(main_out, attention_mask)
        elif main_out.ndim == 2:
            embs = main_out
        else:
            raise RuntimeError(f"Unexpected text output rank: {main_out.ndim}")

        return _l2_normalize(embs, axis=1)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class NomicVisionEncoderONNX:
    """
    Nomic embed-vision-v1.5 ONNX:
    image -> hidden states -> CLS token -> L2 normalize

    Preprocessing is delegated to the repo's ``AutoImageProcessor``; the
    ONNX graph itself runs through an ONNX Runtime session.
    """

    def __init__(self, device: str = "cuda"):
        """Download the model, open the session and encode one dummy image as a sanity check."""
        self.device = device
        self.repo_id = "nomic-ai/nomic-embed-vision-v1.5"

        print("[*] Loading nomic-embed-vision-v1.5 (ONNX)...")
        started = time.perf_counter()

        model_file = _download_onnx_model(self.repo_id)

        # CUDA is used only when requested *and* reported as available.
        available = ort.get_available_providers()
        wants_cuda = "CUDAExecutionProvider" in available and device == "cuda"
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if wants_cuda
            else ["CPUExecutionProvider"]
        )
        print(f" ONNX providers: {providers}")

        self.session = ort.InferenceSession(model_file, providers=providers)
        self.processor = AutoImageProcessor.from_pretrained(self.repo_id, trust_remote_code=True)

        self.input_names = [node.name for node in self.session.get_inputs()]
        self.output_names = [node.name for node in self.session.get_outputs()]

        print(f" ONNX inputs: {self.input_names}")
        print(f" ONNX outputs: {self.output_names}")

        # First graph input whose name mentions "pixel", or None if there is none.
        self._pixel_name = next(
            (candidate for candidate in self.input_names if "pixel" in candidate.lower()),
            None,
        )

        print(f" Mapped: pixel_values={self._pixel_name}")

        # Sanity check on a solid red 224x224 image.
        probe = Image.new("RGB", (224, 224), color=(255, 0, 0))
        probe_emb = self.encode_images([probe])
        nrm = float(np.linalg.norm(probe_emb[0]))
        print(f" [SANITY] image embed norm={nrm:.4f}")
        print(f"[*] Loaded in {time.perf_counter() - started:.1f}s\n")

    def encode_images(self, images: list[Image.Image]) -> np.ndarray:
        """Embed ``images`` and return L2-normalized float32 vectors, one row per image.

        A rank-3 model output contributes its first token (index 0 on the
        sequence axis); a rank-2 output is used as-is.

        Raises:
            RuntimeError: if the processor yields no ``pixel_values``, if no
                pixel input was found on the graph, or if the selected output
                has an unexpected rank.
        """
        converted = [picture.convert("RGB") for picture in images]
        processed = self.processor(images=converted, return_tensors="np")

        if "pixel_values" not in processed:
            raise RuntimeError(f"Processor did not return pixel_values. Keys={list(processed.keys())}")

        # Accept both tensor-like objects exposing .numpy() and plain arrays.
        raw_pixels = processed["pixel_values"]
        if hasattr(raw_pixels, "numpy"):
            pixel_values = raw_pixels.numpy().astype(np.float32)
        else:
            pixel_values = np.asarray(raw_pixels, dtype=np.float32)

        if self._pixel_name is None:
            raise RuntimeError(f"Could not find pixel input in ONNX inputs: {self.input_names}")

        outputs = self.session.run(self.output_names, {self._pixel_name: pixel_values})
        main_out = _pick_output(outputs, self.output_names, kind="vision")

        # Current PyTorch behavior: CLS token from last_hidden_state
        if main_out.ndim not in (2, 3):
            raise RuntimeError(f"Unexpected vision output rank: {main_out.ndim}")
        embs = main_out[:, 0, :] if main_out.ndim == 3 else main_out

        return _l2_normalize(embs, axis=1)
|
| 260 |
|
| 261 |
|
| 262 |
def _open_rgb(path: str) -> Image.Image:
    """Load ``path`` as an RGB image and release the file handle immediately."""
    # Image.open keeps the file open; convert() returns a separate in-memory
    # image, so the handle can be closed as soon as it has been produced.
    with Image.open(path) as img:
        return img.convert("RGB")


def build_refs_nomic(
    encoder: NomicVisionEncoderONNX,
    refs_dir: Path,
    batch_size: int = 16,
    text_encoder: NomicTextEncoderONNX | None = None,
    text_weight: float = 0.3,
):
    """
    Build one ref embedding per class.

    Each sub-folder of ``refs_dir`` is one class (the folder name is the
    label); sub-folders without any image file are skipped.

    Same treatment as Jina:
    - average reference image embeddings
    - average class prompt text embeddings
    - combine with text_weight

    Args:
        encoder: vision encoder providing ``encode_images``.
        refs_dir: directory containing one sub-folder per class.
        batch_size: number of images encoded per ``encode_images`` call.
        text_encoder: optional text encoder; when given, class prompts are
            blended into the reference vector.
        text_weight: weight of the text average; the image average gets
            ``1 - text_weight``.

    Returns:
        ``(labels, embeddings)`` where ``embeddings`` is a float32 array with
        one unit-norm row per label.

    Raises:
        ValueError: if ``refs_dir`` has no sub-folders, or none of them
            contains an image.
    """
    class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subfolders in {refs_dir}")

    labels = []
    embeddings = []

    if text_encoder is not None:
        print(f" Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")

    for d in class_dirs:
        name = d.name
        paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if not paths:
            continue

        all_embs = []
        for i in range(0, len(paths), batch_size):
            batch = [_open_rgb(p) for p in paths[i:i + batch_size]]
            all_embs.append(encoder.encode_images(batch))

        img_embs = np.concatenate(all_embs, axis=0)
        img_avg = np.nan_to_num(img_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)
        img_avg = img_avg / (np.linalg.norm(img_avg) + 1e-12)

        if text_encoder is not None:
            # Classes without dedicated prompts fall back to two generic ones.
            prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
            text_embs = text_encoder.encode_texts(prompts)
            text_avg = np.nan_to_num(text_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)
            text_avg = text_avg / (np.linalg.norm(text_avg) + 1e-12)

            # Both averages are unit-norm here, so text_weight is a true mixing ratio.
            combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
            combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
            combined = combined / (np.linalg.norm(combined) + 1e-12)

            labels.append(name)
            embeddings.append(combined)

            sim = float(np.dot(img_avg, text_avg))
            print(
                f" {name:<14}: {len(paths)} imgs + {len(prompts)} prompts | "
                f"img-text sim: {sim:.4f}"
            )
        else:
            labels.append(name)
            embeddings.append(img_avg)
            print(f" {name:<14}: {len(paths)} imgs")

    # np.stack([]) would fail with an unrelated-looking message; be explicit.
    if not embeddings:
        raise ValueError(f"No reference images found in any subfolder of {refs_dir}")

    return labels, np.stack(embeddings).astype(np.float32)
|