orik-ss committed on
Commit
81de9b1
·
1 Parent(s): cc88e05

Added siglip and siglip2 for classification

Browse files
Files changed (4) hide show
  1. app.py +22 -5
  2. dfine_jina_pipeline.py +46 -11
  3. siglip2_onnx_zeroshot.py +196 -0
  4. siglip_zeroshot.py +100 -0
app.py CHANGED
@@ -108,8 +108,15 @@ def run_detection(image, model):
108
  return out_img, det_json
109
 
110
 
111
- def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, min_display_conf=0.703, gap_threshold=0.005):
112
- """Tab 2: D-FINE first, then classify crops with Jina.
 
 
 
 
 
 
 
113
  Returns (group_crop_gallery, known_crop_gallery, status_message).
114
  """
115
  if image is None:
@@ -121,6 +128,7 @@ def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, mi
121
  return [], [], f"Refs folder not found: {refs}"
122
 
123
  dfine_model = "large" if dfine_model_choice.strip().lower() == "large" else "medium"
 
124
  group_crops, known_crops, status = run_single_image(
125
  image,
126
  refs_dir=refs,
@@ -131,6 +139,7 @@ def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, mi
131
  min_side=24,
132
  crop_dedup_iou=0.4,
133
  min_display_conf=float(min_display_conf),
 
134
  )
135
 
136
  if status is not None:
@@ -229,10 +238,12 @@ with gr.Blocks(title="Small Object Detection") as app:
229
  with gr.TabItem("D-FINE + Classify"):
230
 
231
  gr.Markdown(
232
- "**D-FINE** runs first (person/car grouping), then small-object crops are classified with **Jina**. "
 
 
233
  "Choose D-FINE model size (Medium or Large). "
234
  "Uses the **refs** folder (one subfolder per class, e.g. refs/phone/, refs/cigarette/) "
235
- "with reference images.\n\n"
236
  "**Gap** = how much the top class (e.g. gun) must beat the next-best class (e.g. phone). "
237
  "Bigger gap means the model is more sure; we only accept the label if both confidence and gap are high enough."
238
  )
@@ -247,6 +258,12 @@ with gr.Blocks(title="Small Object Detection") as app:
247
  height=IMG_HEIGHT
248
  )
249
 
 
 
 
 
 
 
250
  dfine_model_radio = gr.Radio(
251
  choices=["Medium", "Large"],
252
  value="Large",
@@ -322,7 +339,7 @@ with gr.Blocks(title="Small Object Detection") as app:
322
 
323
  btn_dfine.click(
324
  fn=run_dfine_classify,
325
- inputs=[inp_dfine, refs_path, dfine_threshold_slider, dfine_model_radio, threshold_slider, gap_slider],
326
  outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
327
  concurrency_limit=1,
328
  )
 
108
  return out_img, det_json
109
 
110
 
111
+ CLASSIFIER_MAP = {
112
+ "Jina-CLIP-v2 (few-shot)": "jina",
113
+ "SigLIP (zero-shot)": "siglip",
114
+ "SigLIP2 ONNX (zero-shot)": "siglip2_onnx",
115
+ }
116
+
117
+
118
+ def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, min_display_conf=0.703, gap_threshold=0.005, classifier_choice="Jina-CLIP-v2 (few-shot)"):
119
+ """Tab 2: D-FINE first, then classify crops.
120
  Returns (group_crop_gallery, known_crop_gallery, status_message).
121
  """
122
  if image is None:
 
128
  return [], [], f"Refs folder not found: {refs}"
129
 
130
  dfine_model = "large" if dfine_model_choice.strip().lower() == "large" else "medium"
131
+ classifier = CLASSIFIER_MAP.get(classifier_choice, "jina")
132
  group_crops, known_crops, status = run_single_image(
133
  image,
134
  refs_dir=refs,
 
139
  min_side=24,
140
  crop_dedup_iou=0.4,
141
  min_display_conf=float(min_display_conf),
142
+ classifier=classifier,
143
  )
144
 
145
  if status is not None:
 
238
  with gr.TabItem("D-FINE + Classify"):
239
 
240
  gr.Markdown(
241
+ "**D-FINE** runs first (person/car grouping), then small-object crops are classified. "
242
+ "Choose a **classifier**: Jina-CLIP-v2 (few-shot, uses reference images), "
243
+ "SigLIP (zero-shot, PyTorch), or SigLIP2 ONNX (zero-shot, larger model). "
244
  "Choose D-FINE model size (Medium or Large). "
245
  "Uses the **refs** folder (one subfolder per class, e.g. refs/phone/, refs/cigarette/) "
246
+ " Jina uses reference images; SigLIP models use only the folder names as class labels.\n\n"
247
  "**Gap** = how much the top class (e.g. gun) must beat the next-best class (e.g. phone). "
248
  "Bigger gap means the model is more sure; we only accept the label if both confidence and gap are high enough."
249
  )
 
258
  height=IMG_HEIGHT
259
  )
260
 
261
+ classifier_radio = gr.Radio(
262
+ choices=list(CLASSIFIER_MAP.keys()),
263
+ value="Jina-CLIP-v2 (few-shot)",
264
+ label="Classifier",
265
+ )
266
+
267
  dfine_model_radio = gr.Radio(
268
  choices=["Medium", "Large"],
269
  value="Large",
 
339
 
340
  btn_dfine.click(
341
  fn=run_dfine_classify,
342
+ inputs=[inp_dfine, refs_path, dfine_threshold_slider, dfine_model_radio, threshold_slider, gap_slider, classifier_radio],
343
  outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
344
  concurrency_limit=1,
345
  )
dfine_jina_pipeline.py CHANGED
@@ -499,11 +499,46 @@ def main():
499
  # -----------------------------------------------------------------------------
500
 
501
  _APP_DFINE = None # (model_id, image_processor, dfine_model, person_car_ids)
502
- _APP_JINA = None
503
- _APP_REFS_JINA = None
504
 
505
  DFINE_MODEL_IDS = {"medium": "ustc-community/dfine-medium-obj365", "large": "ustc-community/dfine-large-obj365"}
506
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
  def run_single_image(
509
  pil_image,
@@ -517,6 +552,7 @@ def run_single_image(
517
  crop_dedup_iou=0.35,
518
  squarify=True,
519
  min_display_conf=None,
 
520
  ):
521
  """
522
  Run D-FINE on one image, then classify small-object crops with Jina.
@@ -571,14 +607,14 @@ def run_single_image(
571
  grouped.sort(key=lambda x: x["conf"], reverse=True)
572
  top_groups = grouped[:10]
573
 
574
- # Load Jina encoder + refs (needed for classification)
575
- if _APP_JINA is None or _APP_REFS_JINA != str(refs_dir):
576
- jina_encoder = JinaCLIPv2Encoder(device)
577
- ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
578
- _APP_JINA = (jina_encoder, ref_labels, ref_embs)
579
- _APP_REFS_JINA = str(refs_dir)
580
 
581
- jina_encoder, ref_labels, ref_embs = _APP_JINA
582
 
583
  results_per_crop = []
584
  group_crop_images = []
@@ -660,8 +696,7 @@ def run_single_image(
660
  if squarify:
661
  bx1, by1, bx2, by2 = squarify_crop_box(bx1, by1, bx2, by2, crop_w, crop_h)
662
  small_crop = crop_pil.crop((bx1, by1, bx2, by2))
663
- q = jina_encoder.encode_images([small_crop], TRUNCATE_DIM)
664
- result = jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
665
  pred = result["prediction"] if result["prediction"] in ref_labels else f"unknown ({d['label']})"
666
  conf = result["confidence"]
667
  results_per_crop.append((gidx, (bx1, by1, bx2, by2), small_crop, pred, conf))
 
499
  # -----------------------------------------------------------------------------
500
 
501
  _APP_DFINE = None # (model_id, image_processor, dfine_model, person_car_ids)
502
+ _APP_CLASSIFIERS = {} # {classifier_name: (classifier_instance, refs_dir_str)}
 
503
 
504
  DFINE_MODEL_IDS = {"medium": "ustc-community/dfine-medium-obj365", "large": "ustc-community/dfine-large-obj365"}
505
 
506
+ CLASSIFIER_CHOICES = ["jina", "siglip", "siglip2_onnx"]
507
+
508
+
509
def _load_classifier(classifier_name, device, refs_dir):
    """Build and initialise the crop classifier selected by ``classifier_name``.

    Args:
        classifier_name: One of "jina", "siglip" or "siglip2_onnx".
        device: Device handle forwarded unchanged to the classifier constructor.
        refs_dir: Reference folder; converted to a ``Path`` before use.

    Returns:
        For "jina": the tagged tuple
        ``("jina_wrapped", encoder, ref_labels, ref_embs)``.
        For the SigLIP variants: a classifier instance on which
        ``build_refs(refs_dir)`` has already been called.

    Raises:
        ValueError: If ``classifier_name`` is not a known classifier.
    """
    refs_root = Path(refs_dir)

    if classifier_name not in ("jina", "siglip", "siglip2_onnx"):
        raise ValueError(f"Unknown classifier: {classifier_name}. Choose from {CLASSIFIER_CHOICES}")

    if classifier_name == "jina":
        # Jina is few-shot: reference embeddings are computed from the refs folder.
        encoder = JinaCLIPv2Encoder(device)
        labels, embeddings = build_refs(encoder, refs_root, TRUNCATE_DIM, 0.3, batch_size=16)
        return ("jina_wrapped", encoder, labels, embeddings)

    # The SigLIP back-ends are imported only when selected, so their modules
    # (and whatever they import) are not loaded unless actually requested.
    if classifier_name == "siglip":
        from siglip_zeroshot import SigLIPClassifier as zero_shot_cls
    else:
        from siglip2_onnx_zeroshot import SigLIP2ONNXClassifier as zero_shot_cls

    model = zero_shot_cls(device)
    model.build_refs(refs_root)
    return model
531
+
532
+
533
def _classify_crop(classifier, crop, conf_threshold, gap_threshold):
    """Classify one crop with whichever classifier object was loaded.

    ``classifier`` is either the tagged tuple
    ``("jina_wrapped", encoder, ref_labels, ref_embs)`` or any object that
    exposes ``classify_crop(crop, conf_threshold, gap_threshold)``.
    Returns whatever the underlying classify call returns.
    """
    is_jina_bundle = isinstance(classifier, tuple) and classifier[0] == "jina_wrapped"
    if not is_jina_bundle:
        return classifier.classify_crop(crop, conf_threshold, gap_threshold)

    # Jina path: embed the crop, then compare against the reference embeddings.
    encoder, labels, embeddings = classifier[1:]
    query = encoder.encode_images([crop], TRUNCATE_DIM)
    return jina_classify(query, labels, embeddings, conf_threshold, gap_threshold)
541
+
542
 
543
  def run_single_image(
544
  pil_image,
 
552
  crop_dedup_iou=0.35,
553
  squarify=True,
554
  min_display_conf=None,
555
+ classifier="jina",
556
  ):
557
  """
558
  Run D-FINE on one image, then classify small-object crops with Jina.
 
607
  grouped.sort(key=lambda x: x["conf"], reverse=True)
608
  top_groups = grouped[:10]
609
 
610
+ # Load classifier (Jina, SigLIP, or SigLIP2 ONNX)
611
+ global _APP_CLASSIFIERS
612
+ clf_key = classifier
613
+ if clf_key not in _APP_CLASSIFIERS or _APP_CLASSIFIERS[clf_key][1] != str(refs_dir):
614
+ clf_instance = _load_classifier(classifier, device, refs_dir)
615
+ _APP_CLASSIFIERS[clf_key] = (clf_instance, str(refs_dir))
616
 
617
+ clf_instance = _APP_CLASSIFIERS[clf_key][0]
618
 
619
  results_per_crop = []
620
  group_crop_images = []
 
696
  if squarify:
697
  bx1, by1, bx2, by2 = squarify_crop_box(bx1, by1, bx2, by2, crop_w, crop_h)
698
  small_crop = crop_pil.crop((bx1, by1, bx2, by2))
699
+ result = _classify_crop(clf_instance, small_crop, conf_threshold, gap_threshold)
 
700
  pred = result["prediction"] if result["prediction"] in ref_labels else f"unknown ({d['label']})"
701
  conf = result["confidence"]
702
  results_per_crop.append((gidx, (bx1, by1, bx2, by2), small_crop, pred, conf))
siglip2_onnx_zeroshot.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SigLIP2 zero-shot classifier using ONNX Runtime.
3
+ Uses onnx-community/siglip2-large-patch16-256-ONNX (separate vision + text models).
4
+ Zero-shot: text prompts only, no reference images needed (folder names used for class labels).
5
+ """
6
+
7
+ import time
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import onnxruntime as ort
12
+ from PIL import Image
13
+ from huggingface_hub import hf_hub_download
14
+ from transformers import AutoProcessor
15
+
16
+ from jina_fewshot import CLASS_PROMPTS, IMAGE_EXTS
17
+
18
+
19
+ REPO_ID = "onnx-community/siglip2-large-patch16-256-ONNX"
20
+ # Use quantized models to save memory; full fp32 text_model is 2.3GB
21
+ VISION_ONNX = "onnx/vision_model_quantized.onnx"
22
+ TEXT_ONNX = "onnx/text_model_quantized.onnx"
23
+
24
+
25
def _download(repo_id, filename):
    """Fetch ``filename`` from the Hub repo ``repo_id`` and return the path
    reported by ``hf_hub_download``, logging before and after the call."""
    print(f" Downloading {filename} from {repo_id}...")
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f" Downloaded: {local_path}")
    return local_path
30
+
31
+
32
def _make_session(onnx_path, device):
    """Create an ONNX Runtime inference session for ``onnx_path``.

    Args:
        onnx_path: Path to the ``.onnx`` model file.
        device: Requested device. Anything whose string form starts with
            "cuda" (e.g. ``"cuda"``, ``"cuda:0"``, or a device object that
            prints that way) requests the CUDA provider; everything else
            selects CPU.

    Returns:
        An ``ort.InferenceSession`` using CUDA (with CPU fallback) when CUDA was
        requested and the provider is available, otherwise CPU only.
    """
    available = ort.get_available_providers()
    # Compare on the string form: a plain `device == "cuda"` is False for
    # "cuda:0" or a non-string device object, which silently forced CPU.
    wants_cuda = str(device).lower().startswith("cuda")
    if wants_cuda and "CUDAExecutionProvider" in available:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    print(f" ONNX providers: {providers}")
    return ort.InferenceSession(onnx_path, providers=providers)
40
+
41
+
42
class SigLIP2ONNXClassifier:
    """Zero-shot crop classifier using SigLIP2 ONNX (separate vision + text encoders).

    Class labels are the sub-folder names of a refs directory. ``build_refs``
    turns each label into one text prompt and caches the prompt embeddings;
    ``classify_crop`` then scores a crop against those cached embeddings.
    """

    # Fixed prompt length used for tokenization. The Transformers SigLIP2
    # documentation says to tokenize with padding="max_length", max_length=64.
    _TEXT_MAX_LENGTH = 64

    def __init__(self, device="cuda"):
        """Download both ONNX encoders, create their sessions and run a smoke test.

        Args:
            device: Requested device, forwarded to ``_make_session``.
        """
        print("[*] Loading SigLIP2 ONNX (siglip2-large-patch16-256)...")
        t0 = time.perf_counter()

        self.device = device

        # Download and load vision model
        vision_path = _download(REPO_ID, VISION_ONNX)
        self.vision_session = _make_session(vision_path, device)

        # Download and load text model
        text_path = _download(REPO_ID, TEXT_ONNX)
        self.text_session = _make_session(text_path, device)

        # Processor handles both image preprocessing and tokenization
        self.processor = AutoProcessor.from_pretrained(REPO_ID)

        # Map I/O names so feeds can be built from whatever the graphs declare
        self._vision_input_names = [i.name for i in self.vision_session.get_inputs()]
        self._vision_output_names = [o.name for o in self.vision_session.get_outputs()]
        self._text_input_names = [i.name for i in self.text_session.get_inputs()]
        self._text_output_names = [o.name for o in self.text_session.get_outputs()]

        print(f" Vision inputs: {self._vision_input_names}")
        print(f" Vision outputs: {self._vision_output_names}")
        print(f" Text inputs: {self._text_input_names}")
        print(f" Text outputs: {self._text_output_names}")

        self.labels = []
        self._text_embeds = None

        # Sanity check: push one dummy image and one prompt through the encoders
        dummy = Image.new("RGB", (256, 256), color=(255, 0, 0))
        v_emb = self._encode_image(dummy)
        print(f" [SANITY] vision embed shape={v_emb.shape}, norm={np.linalg.norm(v_emb):.4f}")

        t_emb = self._encode_texts(["a red square"])
        print(f" [SANITY] text embed shape={t_emb.shape}, norm={np.linalg.norm(t_emb):.4f}")

        print(f"[*] SigLIP2 ONNX loaded in {time.perf_counter() - t0:.1f}s")

    @staticmethod
    def _pick_embedding(outputs, what):
        """Return the first 2-D output, else token 0 of the first 3-D output.

        Raises:
            RuntimeError: If no output is 2-D or 3-D.
        """
        for out in outputs:
            if out.ndim == 2:
                return out
        for out in outputs:
            if out.ndim == 3:
                return out[:, 0, :]
        raise RuntimeError(f"No usable {what} output. Shapes: {[o.shape for o in outputs]}")

    def _encode_image(self, image):
        """Encode a single PIL image, return [1, D] embedding.

        Raises:
            RuntimeError: If the vision graph declares no pixel input, or
                produces no 2-D/3-D output.
        """
        processed = self.processor(images=image, return_tensors="np")
        pixel_values = processed["pixel_values"].astype(np.float32)

        feeds = {
            name: pixel_values
            for name in self._vision_input_names
            if "pixel" in name.lower()
        }
        if not feeds:
            # Fail with a readable message instead of an opaque runtime error.
            raise RuntimeError(
                f"No pixel input found among vision inputs: {self._vision_input_names}"
            )

        outputs = self.vision_session.run(self._vision_output_names, feeds)
        return self._pick_embedding(outputs, "vision")

    def _encode_texts(self, texts):
        """Encode text strings, return [N, D] embeddings.

        Prompts are padded/truncated to ``_TEXT_MAX_LENGTH`` tokens rather than
        to the longest prompt in the batch, so an embedding does not depend on
        which other prompts happen to be encoded alongside it.

        Raises:
            RuntimeError: If none of the text graph inputs can be fed, or the
                graph produces no 2-D/3-D output.
        """
        processed = self.processor(
            text=texts,
            return_tensors="np",
            padding="max_length",
            max_length=self._TEXT_MAX_LENGTH,
            truncation=True,
        )

        feeds = {}
        for name in self._text_input_names:
            nl = name.lower()
            if "input_id" in nl and "input_ids" in processed:
                feeds[name] = processed["input_ids"].astype(np.int64)
            elif ("attention" in nl or "mask" in nl) and "attention_mask" in processed:
                feeds[name] = processed["attention_mask"].astype(np.int64)
        if not feeds:
            raise RuntimeError(
                f"No text input could be fed. Model inputs: {self._text_input_names}"
            )

        outputs = self.text_session.run(self._text_output_names, feeds)
        return self._pick_embedding(outputs, "text")

    def build_refs(self, refs_dir, **kwargs):
        """Extract class names from refs_dir subfolders and precompute text embeddings.

        Args:
            refs_dir: Directory whose immediate sub-folder names are the labels.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``refs_dir`` has no sub-folders.
        """
        refs_dir = Path(refs_dir)
        labels = sorted(d.name for d in refs_dir.iterdir() if d.is_dir())
        if not labels:
            raise ValueError(f"No subfolders in {refs_dir}")

        # One prompt per class: first entry of CLASS_PROMPTS, else "a {name}".
        text_prompts = [CLASS_PROMPTS.get(name, [f"a {name}"])[0] for name in labels]

        # Publish labels and embeddings together so they can never disagree
        # if encoding raises part-way through.
        text_embeds = self._encode_texts(text_prompts)
        self.labels = labels
        self._text_embeds = text_embeds

        print(f" SigLIP2 ONNX classes: {self.labels}")
        print(f" Text prompts: {text_prompts}")
        print(f" Text embeds shape: {self._text_embeds.shape}")

    @staticmethod
    def _decide(probs, labels, conf_threshold, gap_threshold):
        """Turn per-class scores into the result dict returned by ``classify_crop``.

        Args:
            probs: One score per entry of ``labels``.
            labels: Class names, same order as ``probs``.
            conf_threshold: Minimum top score for acceptance.
            gap_threshold: Minimum (top - runner-up) margin for acceptance.

        Returns:
            Dict with keys prediction, raw_prediction, confidence, gap,
            second_best, second_conf, status, all_sims. With a single class
            there is no runner-up: ``second_best`` is None, ``second_conf`` is
            0.0 and ``gap`` equals the confidence.
        """
        probs = np.asarray(probs, dtype=np.float64)
        order = np.argsort(probs)[::-1]

        best_idx = int(order[0])
        conf = float(probs[best_idx])
        if len(order) > 1:
            second_idx = int(order[1])
            second_best = labels[second_idx]
            second_conf = float(probs[second_idx])
        else:
            # Single-class refs folder: indexing order[1] would raise IndexError.
            second_best = None
            second_conf = 0.0
        gap = conf - second_conf

        conf_ok = conf >= conf_threshold
        gap_ok = gap >= gap_threshold

        if conf_ok and gap_ok:
            prediction = labels[best_idx]
            status = "accepted"
        else:
            prediction = "unknown"
            reasons = []
            if not conf_ok:
                reasons.append(f"conf {conf:.4f} < {conf_threshold}")
            if not gap_ok:
                reasons.append(f"gap {gap:.4f} < {gap_threshold}")
            status = "rejected: " + ", ".join(reasons)

        return {
            "prediction": prediction,
            "raw_prediction": labels[best_idx],
            "confidence": conf,
            "gap": gap,
            "second_best": second_best,
            "second_conf": second_conf,
            "status": status,
            "all_sims": {label: float(p) for label, p in zip(labels, probs)},
        }

    def classify_crop(self, crop, conf_threshold, gap_threshold):
        """
        Classify a single crop image using zero-shot SigLIP2.
        Computes image-text similarity via dot product + sigmoid (SigLIP style).
        Returns dict matching jina_fewshot.classify() format.

        Raises:
            RuntimeError: If ``build_refs`` has not been called yet.
        """
        if self._text_embeds is None or not self.labels:
            raise RuntimeError("build_refs() must be called before classify_crop()")

        image_emb = self._encode_image(crop)  # [1, D]
        text_emb = self._text_embeds  # [N, D]

        # NOTE(review): this is a raw dot product of the encoder outputs. The
        # reference SigLIP head L2-normalizes both embeddings and applies a
        # learned logit scale and bias before the sigmoid; confirm whether the
        # exported ONNX graphs already include that, otherwise these scores
        # are not calibrated probabilities.
        logits = (image_emb @ text_emb.T).squeeze(0).astype(np.float64)
        with np.errstate(over="ignore"):
            # exp overflow simply saturates the sigmoid to 0; no warning needed.
            probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
        probs = np.nan_to_num(probs, nan=0.0)

        return self._decide(probs, self.labels, conf_threshold, gap_threshold)
siglip_zeroshot.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SigLIP zero-shot classifier for crop classification.
3
+ Uses google/siglip-base-patch16-224 via PyTorch.
4
+ Zero-shot: text prompts only, no reference images needed (folder names used for class labels).
5
+ """
6
+
7
+ import time
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import torch
12
+ from transformers import SiglipModel, AutoProcessor
13
+
14
+ from jina_fewshot import CLASS_PROMPTS
15
+
16
+
17
class SigLIPClassifier:
    """Zero-shot crop classifier using SigLIP (PyTorch).

    Class labels are the sub-folder names of a refs directory; each label is
    mapped to one text prompt by ``build_refs``. ``classify_crop`` runs the
    model on the crop together with all prompts and applies a sigmoid to the
    image-to-text logits.
    """

    def __init__(self, device="cuda"):
        """Load the SigLIP model and processor and move the model to ``device``."""
        print("[*] Loading SigLIP (google/siglip-base-patch16-224)...")
        t0 = time.perf_counter()

        self.device = device
        self.model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
        self.model = self.model.to(device).eval()
        self.processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        self.labels = []
        self._text_prompts = []

        print(f"[*] SigLIP loaded in {time.perf_counter() - t0:.1f}s (device={device})")

    def build_refs(self, refs_dir, **kwargs):
        """Extract class names from refs_dir subfolders. No images needed.

        Args:
            refs_dir: Directory whose immediate sub-folder names are the labels.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``refs_dir`` has no sub-folders.
        """
        refs_dir = Path(refs_dir)
        labels = sorted(d.name for d in refs_dir.iterdir() if d.is_dir())
        if not labels:
            raise ValueError(f"No subfolders in {refs_dir}")

        # Build one prompt per class (first from CLASS_PROMPTS, fallback to "a {name}")
        self.labels = labels
        self._text_prompts = [CLASS_PROMPTS.get(name, [f"a {name}"])[0] for name in labels]

        print(f" SigLIP classes: {self.labels}")
        print(f" Text prompts: {self._text_prompts}")

    @staticmethod
    def _decide(probs, labels, conf_threshold, gap_threshold):
        """Turn per-class scores into the result dict returned by ``classify_crop``.

        Args:
            probs: One score per entry of ``labels``.
            labels: Class names, same order as ``probs``.
            conf_threshold: Minimum top score for acceptance.
            gap_threshold: Minimum (top - runner-up) margin for acceptance.

        Returns:
            Dict with keys prediction, raw_prediction, confidence, gap,
            second_best, second_conf, status, all_sims. With a single class
            there is no runner-up: ``second_best`` is None, ``second_conf`` is
            0.0 and ``gap`` equals the confidence.
        """
        probs = np.asarray(probs, dtype=np.float64)
        order = np.argsort(probs)[::-1]

        best_idx = int(order[0])
        conf = float(probs[best_idx])
        if len(order) > 1:
            second_idx = int(order[1])
            second_best = labels[second_idx]
            second_conf = float(probs[second_idx])
        else:
            # Single-class refs folder: indexing order[1] would raise IndexError.
            second_best = None
            second_conf = 0.0
        gap = conf - second_conf

        conf_ok = conf >= conf_threshold
        gap_ok = gap >= gap_threshold

        if conf_ok and gap_ok:
            prediction = labels[best_idx]
            status = "accepted"
        else:
            prediction = "unknown"
            reasons = []
            if not conf_ok:
                reasons.append(f"conf {conf:.4f} < {conf_threshold}")
            if not gap_ok:
                reasons.append(f"gap {gap:.4f} < {gap_threshold}")
            status = "rejected: " + ", ".join(reasons)

        return {
            "prediction": prediction,
            "raw_prediction": labels[best_idx],
            "confidence": conf,
            "gap": gap,
            "second_best": second_best,
            "second_conf": second_conf,
            "status": status,
            "all_sims": {label: float(p) for label, p in zip(labels, probs)},
        }

    def classify_crop(self, crop, conf_threshold, gap_threshold):
        """
        Classify a single crop image using zero-shot SigLIP.
        Returns dict matching jina_fewshot.classify() format.

        Raises:
            RuntimeError: If ``build_refs`` has not been called yet.
        """
        if not self.labels or not self._text_prompts:
            raise RuntimeError("build_refs() must be called before classify_crop()")

        inputs = self.processor(
            text=self._text_prompts,
            images=crop,
            return_tensors="pt",
            padding="max_length",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits_per_image
            # Drop the leading image axis (a single crop was passed in).
            probs = torch.sigmoid(logits).cpu().numpy().squeeze(0)

        probs = np.nan_to_num(probs.astype(np.float64), nan=0.0)
        return self._decide(probs, self.labels, conf_threshold, gap_threshold)