Napron committed on
Commit
2e3c33d
·
verified ·
1 Parent(s): 09b81c6

Update nomic_fewshot.py

Browse files
Files changed (1) hide show
  1. nomic_fewshot.py +258 -81
nomic_fewshot.py CHANGED
@@ -1,147 +1,324 @@
1
  """
2
- Few-shot object classification using Nomic embed-vision-v1.5 + embed-text-v1.5.
 
 
 
 
3
 
4
- Same treatment as Jina: image refs + text prompts, combined with text_weight (default 0.3).
5
- Used by dfine_jina_pipeline.py and tune_thresholds.py for Nomic crop classification.
 
 
 
 
 
6
  """
 
7
  import time
8
  from pathlib import Path
9
 
10
  import numpy as np
11
- import torch
12
- import torch.nn.functional as F
13
  from PIL import Image
14
- from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
15
- from transformers import modeling_utils
16
 
17
  from jina_fewshot import CLASS_PROMPTS, IMAGE_EXTS
18
 
19
 
20
def _patch_tied_weights_for_nomic():
    """Make ``PreTrainedModel.mark_tied_weights_as_initialized`` tolerate models
    that only define ``_tied_weights_keys``.

    The wrapper sets ``all_tied_weights_keys`` on the model (from
    ``_tied_weights_keys``, or an empty dict) when the attribute is missing,
    then delegates to the original method.

    Nothing is patched when the method does not exist on ``PreTrainedModel``.
    The patch is applied at most once: calling this function repeatedly (for
    example once per encoder instance) does not stack wrappers.
    """
    base = modeling_utils.PreTrainedModel
    if not hasattr(base, "mark_tied_weights_as_initialized"):
        return

    _orig = base.mark_tied_weights_as_initialized
    # Marker set on our wrapper below; its presence means we already patched.
    if getattr(_orig, "_nomic_tied_weights_patch", False):
        return

    def _patched(self, loading_info):
        if not hasattr(self, "all_tied_weights_keys"):
            self.all_tied_weights_keys = getattr(self, "_tied_weights_keys", None) or {}
        return _orig(self, loading_info)

    _patched._nomic_tied_weights_patch = True
    base.mark_tied_weights_as_initialized = _patched
 
 
 
33
 
34
 
35
def _nomic_mean_pool(last_hidden_state, attention_mask):
    """Average the hidden states over the positions selected by the mask.

    The mask is expanded to the hidden-state shape and used as per-token
    weights. The per-row weight total is clamped to at least 1e-9 so an
    all-zero mask yields zeros instead of a division by zero.
    """
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    weighted_total = torch.sum(last_hidden_state * weights, 1)
    token_count = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_total / token_count
 
 
 
 
 
38
 
39
 
40
class NomicTextEncoder:
    """Text encoder backed by nomic-ai/nomic-embed-text-v1.5.

    Strings are prefixed with ``classification: ``, tokenized, run through the
    model, mean-pooled over the attention mask and L2-normalized.
    """

    def __init__(self, device="cuda"):
        self.device = device
        print("[*] Loading nomic-embed-text-v1.5...")
        t0 = time.perf_counter()
        self.tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

        def _pin_default_device_to_cpu():
            # Only newer torch builds expose set_default_device.
            if hasattr(torch, "set_default_device"):
                torch.set_default_device("cpu")

        # Weights are loaded with the default device pinned to CPU, then moved.
        _pin_default_device_to_cpu()
        try:
            self.model = AutoModel.from_pretrained(
                "nomic-ai/nomic-embed-text-v1.5",
                trust_remote_code=True,
                low_cpu_mem_usage=False,
            )
        finally:
            _pin_default_device_to_cpu()
        self.model = self.model.to(device).eval()
        print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s\n")

    def encode_texts(self, texts: list[str]) -> np.ndarray:
        """Return one L2-normalized float32 embedding row per input string."""
        prefixed = [f"classification: {t}" for t in texts]
        batch = self.tokenizer(prefixed, padding=True, truncation=True, return_tensors="pt", max_length=512)
        batch = {key: tensor.to(self.device) for key, tensor in batch.items()}
        with torch.no_grad():
            hidden = self.model(**batch).last_hidden_state
            pooled = _nomic_mean_pool(hidden, batch["attention_mask"])
            unit = F.normalize(pooled, p=2, dim=1)
        return unit.cpu().float().numpy()
 
 
 
71
 
 
 
 
 
 
 
 
72
 
73
class NomicVisionEncoder:
    """Image encoder backed by nomic-ai/nomic-embed-vision-v1.5.

    Images are preprocessed, run through the model, and the first token of
    ``last_hidden_state`` is L2-normalized and returned.
    """

    def __init__(self, device="cuda"):
        self.device = device
        print("[*] Loading nomic-embed-vision-v1.5...")
        t0 = time.perf_counter()
        self.processor = AutoImageProcessor.from_pretrained("nomic-ai/nomic-embed-vision-v1.5")
        _patch_tied_weights_for_nomic()

        def _pin_default_device_to_cpu():
            # Only newer torch builds expose set_default_device.
            if hasattr(torch, "set_default_device"):
                torch.set_default_device("cpu")

        # Weights are loaded with the default device pinned to CPU, then moved.
        _pin_default_device_to_cpu()
        try:
            self.model = AutoModel.from_pretrained(
                "nomic-ai/nomic-embed-vision-v1.5",
                trust_remote_code=True,
                low_cpu_mem_usage=False,
            )
        finally:
            _pin_default_device_to_cpu()
        self.model = self.model.to(device).eval()
        print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s\n")

    def encode_images(self, images: list) -> np.ndarray:
        """Encode images to L2-normalized embeddings (first/CLS token)."""
        batch = self.processor(images=images, return_tensors="pt")
        batch = {key: tensor.to(self.device) for key, tensor in batch.items()}
        with torch.no_grad():
            hidden = self.model(**batch).last_hidden_state
            # Token 0 is taken as the image representation, then normalized.
            unit = F.normalize(hidden[:, 0], p=2, dim=1)
        return unit.cpu().float().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
 
107
def build_refs_nomic(
    encoder: NomicVisionEncoder,
    refs_dir: Path,
    batch_size: int = 16,
    text_encoder: NomicTextEncoder | None = None,
    text_weight: float = 0.3,
):
    """Build one reference embedding per class folder under ``refs_dir``.

    Each subfolder is a class; its images are encoded in batches and averaged.
    When ``text_encoder`` is given, the class prompts (``CLASS_PROMPTS`` or two
    generic fallbacks) are encoded and averaged, and the result is
    ``(1 - text_weight) * image_avg + text_weight * text_avg``, L2-normalized.
    Without a text encoder the image average alone is L2-normalized.

    Returns:
        ``(labels, embeddings)`` where ``embeddings`` has one row per label.

    Raises:
        ValueError: if ``refs_dir`` has no subfolders, or if none of the
            subfolders contains a file with an extension in ``IMAGE_EXTS``.
    """
    class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subfolders in {refs_dir}")

    labels = []
    embeddings = []

    if text_encoder is not None:
        print(f" Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")

    for d in class_dirs:
        name = d.name
        paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if not paths:
            continue

        all_embs = []
        for i in range(0, len(paths), batch_size):
            batch = []
            for p in paths[i : i + batch_size]:
                # convert() returns an independent image, so the file handle
                # can be released as soon as the conversion is done.
                with Image.open(p) as img:
                    batch.append(img.convert("RGB"))
            all_embs.append(encoder.encode_images(batch))

        img_embs = np.concatenate(all_embs, axis=0)
        img_avg = img_embs.mean(axis=0)

        if text_encoder is not None:
            prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
            text_embs = text_encoder.encode_texts(prompts)
            text_avg = text_embs.mean(axis=0)

            combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
            combined = combined / (np.linalg.norm(combined) + 1e-12)

            labels.append(name)
            embeddings.append(combined)
            print(f" {name:<14}: {len(paths)} imgs + {len(prompts)} prompts")
        else:
            img_avg = img_avg / (np.linalg.norm(img_avg) + 1e-12)
            labels.append(name)
            embeddings.append(img_avg)
            print(f" {name:<14}: {len(paths)} imgs")

    if not embeddings:
        # np.stack([]) would fail with an unhelpful "need at least one array".
        raise ValueError(f"No reference images found in any class folder under {refs_dir}")

    return labels, np.stack(embeddings)
 
 
1
  """
2
+ Few-shot object classification using Nomic embed-vision-v1.5 + embed-text-v1.5 via ONNX Runtime.
3
+ Same treatment as current PyTorch version:
4
+ - vision refs -> average image embeddings
5
+ - text prompts -> average text embeddings
6
+ - combine with text_weight
7
 
8
+ This version uses:
9
+ - nomic-ai/nomic-embed-text-v1.5 -> ONNX
10
+ - nomic-ai/nomic-embed-vision-v1.5 -> ONNX
11
+
12
+ Transformers is used only for preprocessing:
13
+ - AutoTokenizer
14
+ - AutoImageProcessor
15
  """
16
+
17
  import time
18
  from pathlib import Path
19
 
20
  import numpy as np
21
+ import onnxruntime as ort
 
22
  from PIL import Image
23
+ from huggingface_hub import hf_hub_download
24
+ from transformers import AutoImageProcessor, AutoTokenizer
25
 
26
  from jina_fewshot import CLASS_PROMPTS, IMAGE_EXTS
27
 
28
 
29
def _l2_normalize(x: np.ndarray, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
    """Scale ``x`` to unit L2 norm along ``axis`` and return it as float32.

    NaN and +/-inf entries are replaced by 0 first. Norms smaller than ``eps``
    are raised to ``eps``, so all-zero vectors stay zero instead of dividing
    by zero.
    """
    cleaned = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    lengths = np.maximum(np.linalg.norm(cleaned, axis=axis, keepdims=True), eps)
    unit = cleaned / lengths
    return unit.astype(np.float32)
34
+
35
+
36
def _mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Mask-weighted mean over the token axis.

    last_hidden_state: [B, T, D]
    attention_mask:    [B, T]

    The per-row mask total is floored at 1e-9 so a fully masked row produces
    zeros rather than a division by zero.
    """
    weights = attention_mask.astype(np.float32)[..., None]  # [B, T, 1]
    weighted_total = (last_hidden_state * weights).sum(axis=1)
    token_count = np.clip(weights.sum(axis=1), 1e-9, None)
    return weighted_total / token_count
45
+
46
+
47
def _pick_output(outputs: list[np.ndarray], output_names: list[str], kind: str) -> np.ndarray:
    """Select the tensor to treat as the model's embedding output.

    Selection order:
      1. For each preferred keyword in turn, the first output whose name
         contains that keyword (case-insensitive) and whose rank is 2 or 3.
      2. Otherwise the first rank-3 output, then the first rank-2 output.

    ``kind`` is only used in the error message.

    Raises:
        RuntimeError: if no rank-2 or rank-3 output exists.
    """
    # Ordered by preference: hidden-state style names first, embeddings after.
    preferred_keywords = (
        "last_hidden_state",
        "hidden_state",
        "sentence_embedding",
        "embedding",
        "embeddings",
    )

    for keyword in preferred_keywords:
        for idx, out_name in enumerate(output_names):
            if keyword in out_name.lower() and outputs[idx].ndim in (2, 3):
                return outputs[idx]

    # No name matched: fall back on shape alone, preferring [B, T, D] over [B, D].
    for rank in (3, 2):
        for candidate in outputs:
            if candidate.ndim == rank:
                return candidate

    raise RuntimeError(
        f"Could not identify a usable {kind} ONNX output. "
        f"Output names={output_names}, shapes={[getattr(o, 'shape', None) for o in outputs]}"
    )
81
 
82
 
83
def _download_onnx_model(repo_id: str, filename: str = "onnx/model.onnx") -> str:
    """Fetch ``filename`` from the Hugging Face repo ``repo_id`` via
    ``hf_hub_download`` and return the path it reports, logging progress."""
    print(f" Downloading ONNX model from {repo_id} ...")
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f" Downloaded: {local_path}")
    return local_path
91
 
92
 
93
class NomicTextEncoderONNX:
    """
    Nomic embed-text-v1.5 ONNX:
    text -> token embeddings / hidden states -> mean pool -> L2 normalize

    The constructor downloads the ONNX file, creates an ONNX Runtime session
    (CUDA provider when requested and available, otherwise CPU), loads the
    tokenizer, maps the session's input names to input_ids / attention_mask /
    token_type_ids, and runs a one-string sanity encode.
    """

    def __init__(self, device: str = "cuda"):
        self.device = device
        self.repo_id = "nomic-ai/nomic-embed-text-v1.5"

        print("[*] Loading nomic-embed-text-v1.5 (ONNX)...")
        t0 = time.perf_counter()

        onnx_path = _download_onnx_model(self.repo_id)

        available = ort.get_available_providers()
        if "CUDAExecutionProvider" in available and device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        print(f" ONNX providers: {providers}")

        self.session = ort.InferenceSession(onnx_path, providers=providers)
        self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id, trust_remote_code=True)

        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

        print(f" ONNX inputs: {self.input_names}")
        print(f" ONNX outputs: {self.output_names}")

        self._ids_name = None
        self._mask_name = None
        self._token_type_name = None

        # Substring match so differently prefixed export names still map.
        for name in self.input_names:
            nl = name.lower()
            if "input_ids" in nl:
                self._ids_name = name
            elif "attention" in nl:
                self._mask_name = name
            elif "token_type" in nl:
                self._token_type_name = name

        print(
            f" Mapped: input_ids={self._ids_name}, "
            f"attention_mask={self._mask_name}, token_type_ids={self._token_type_name}"
        )

        # Sanity check
        test = self.encode_texts(["a red square"])
        nrm = float(np.linalg.norm(test[0]))
        print(f" [SANITY] text embed norm={nrm:.4f}")
        print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s\n")

    def encode_texts(self, texts: list[str]) -> np.ndarray:
        """Return one L2-normalized float32 embedding row per input string.

        Each string is prefixed with ``classification: `` before tokenizing.
        A rank-3 model output is mean-pooled over the attention mask; a rank-2
        output is used as is.

        Raises:
            RuntimeError: if the session has no input that maps to input_ids,
                or if the selected output is neither rank 2 nor rank 3.
        """
        # Fail with a clear message, like the vision encoder does for its
        # pixel input, instead of handing the session an incomplete feed.
        if self._ids_name is None:
            raise RuntimeError(f"Could not find input_ids input in ONNX inputs: {self.input_names}")

        prefixed = [f"classification: {t}" for t in texts]
        tokens = self.tokenizer(
            prefixed,
            padding=True,
            truncation=True,
            return_tensors="np",
            max_length=512,
        )

        input_ids = np.asarray(tokens["input_ids"], dtype=np.int64)
        attention_mask = np.asarray(tokens["attention_mask"], dtype=np.int64)

        feeds = {self._ids_name: input_ids}
        if self._mask_name is not None:
            feeds[self._mask_name] = attention_mask
        if self._token_type_name is not None:
            # Every token is fed segment id 0.
            feeds[self._token_type_name] = np.zeros_like(input_ids, dtype=np.int64)

        outputs = self.session.run(self.output_names, feeds)
        main_out = _pick_output(outputs, self.output_names, kind="text")

        # Current PyTorch behavior: mean-pool last_hidden_state
        if main_out.ndim == 3:
            embs = _mean_pool(main_out, attention_mask)
        elif main_out.ndim == 2:
            embs = main_out
        else:
            raise RuntimeError(f"Unexpected text output rank: {main_out.ndim}")

        return _l2_normalize(embs, axis=1)
181
+
182
+
183
class NomicVisionEncoderONNX:
    """
    ONNX Runtime wrapper around nomic-embed-vision-v1.5.

    Pipeline: image -> processor -> ONNX session -> first token of a rank-3
    output (or a rank-2 output as is) -> L2 normalize.
    """

    def __init__(self, device: str = "cuda"):
        self.device = device
        self.repo_id = "nomic-ai/nomic-embed-vision-v1.5"

        print("[*] Loading nomic-embed-vision-v1.5 (ONNX)...")
        started = time.perf_counter()

        model_file = _download_onnx_model(self.repo_id)

        # CUDA is used only when it is both installed and requested.
        use_cuda = "CUDAExecutionProvider" in ort.get_available_providers() and device == "cuda"
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_cuda else ["CPUExecutionProvider"]
        print(f" ONNX providers: {providers}")

        self.session = ort.InferenceSession(model_file, providers=providers)
        self.processor = AutoImageProcessor.from_pretrained(self.repo_id, trust_remote_code=True)

        self.input_names = [node.name for node in self.session.get_inputs()]
        self.output_names = [node.name for node in self.session.get_outputs()]

        print(f" ONNX inputs: {self.input_names}")
        print(f" ONNX outputs: {self.output_names}")

        # First session input whose name mentions "pixel", or None.
        self._pixel_name = next(
            (candidate for candidate in self.input_names if "pixel" in candidate.lower()),
            None,
        )

        print(f" Mapped: pixel_values={self._pixel_name}")

        # Sanity check: encode a solid red 224x224 image and report its norm.
        probe = Image.new("RGB", (224, 224), color=(255, 0, 0))
        probe_emb = self.encode_images([probe])
        probe_norm = float(np.linalg.norm(probe_emb[0]))
        print(f" [SANITY] image embed norm={probe_norm:.4f}")
        print(f"[*] Loaded in {time.perf_counter() - started:.1f}s\n")

    def encode_images(self, images: list[Image.Image]) -> np.ndarray:
        """Return one L2-normalized float32 embedding row per image.

        Raises:
            RuntimeError: if the processor returns no ``pixel_values``, if no
                pixel input was found among the session inputs, or if the
                selected output is neither rank 2 nor rank 3.
        """
        processed = self.processor(images=[im.convert("RGB") for im in images], return_tensors="np")

        if "pixel_values" not in processed:
            raise RuntimeError(f"Processor did not return pixel_values. Keys={list(processed.keys())}")

        raw = processed["pixel_values"]
        # Objects exposing .numpy() are converted through it; anything else
        # goes through np.asarray.
        if hasattr(raw, "numpy"):
            pixel_values = raw.numpy().astype(np.float32)
        else:
            pixel_values = np.asarray(raw, dtype=np.float32)

        if self._pixel_name is None:
            raise RuntimeError(f"Could not find pixel input in ONNX inputs: {self.input_names}")

        outputs = self.session.run(self.output_names, {self._pixel_name: pixel_values})
        main_out = _pick_output(outputs, self.output_names, kind="vision")

        # Current PyTorch behavior: CLS token from last_hidden_state
        rank = main_out.ndim
        if rank == 3:
            embs = main_out[:, 0, :]
        elif rank == 2:
            embs = main_out
        else:
            raise RuntimeError(f"Unexpected vision output rank: {main_out.ndim}")

        return _l2_normalize(embs, axis=1)
260
 
261
 
262
def build_refs_nomic(
    encoder: NomicVisionEncoderONNX,
    refs_dir: Path,
    batch_size: int = 16,
    text_encoder: NomicTextEncoderONNX | None = None,
    text_weight: float = 0.3,
):
    """
    Build one ref embedding per class.

    Each subfolder of ``refs_dir`` is a class:
    - its images are encoded in batches, averaged and L2-normalized
    - with a ``text_encoder``, the class prompts (``CLASS_PROMPTS`` or two
      generic fallbacks) are encoded, averaged and L2-normalized
    - both are blended as ``(1 - text_weight) * img + text_weight * text``
      and L2-normalized again

    Non-finite values are replaced by 0 before each normalization.

    Returns:
        ``(labels, embeddings)`` with ``embeddings`` a float32 array holding
        one row per label.

    Raises:
        ValueError: if ``refs_dir`` has no subfolders, or if none of the
            subfolders contains a file with an extension in ``IMAGE_EXTS``.
    """
    class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subfolders in {refs_dir}")

    labels = []
    embeddings = []

    if text_encoder is not None:
        print(f" Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")

    for d in class_dirs:
        name = d.name
        paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if not paths:
            continue

        all_embs = []
        for i in range(0, len(paths), batch_size):
            batch = []
            for p in paths[i:i + batch_size]:
                # convert() returns an independent image, so the file handle
                # can be released as soon as the conversion is done.
                with Image.open(p) as img:
                    batch.append(img.convert("RGB"))
            all_embs.append(encoder.encode_images(batch))

        img_embs = np.concatenate(all_embs, axis=0)
        img_avg = np.nan_to_num(img_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)
        img_avg = img_avg / (np.linalg.norm(img_avg) + 1e-12)

        if text_encoder is not None:
            prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
            text_embs = text_encoder.encode_texts(prompts)
            text_avg = np.nan_to_num(text_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)
            text_avg = text_avg / (np.linalg.norm(text_avg) + 1e-12)

            combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
            combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
            combined = combined / (np.linalg.norm(combined) + 1e-12)

            labels.append(name)
            embeddings.append(combined)

            # Diagnostic only: dot product of the two normalized averages.
            sim = float(np.dot(img_avg, text_avg))
            print(
                f" {name:<14}: {len(paths)} imgs + {len(prompts)} prompts | "
                f"img-text sim: {sim:.4f}"
            )
        else:
            labels.append(name)
            embeddings.append(img_avg)
            print(f" {name:<14}: {len(paths)} imgs")

    if not embeddings:
        # np.stack([]) would fail with an unhelpful "need at least one array".
        raise ValueError(f"No reference images found in any class folder under {refs_dir}")

    return labels, np.stack(embeddings).astype(np.float32)