orik-ss committed on
Commit
80cacd4
·
1 Parent(s): 2455309

Removed object detection tab

Browse files
Files changed (3) hide show
  1. app.py +88 -269
  2. dfine_jina_pipeline.py +26 -18
  3. siglip_zeroshot.py +14 -6
app.py CHANGED
@@ -1,131 +1,32 @@
1
- """ Gradio app: Tab 1 = Object Detection (YOLO models/v1), Tab 2 = D-FINE + SigLIP Classify. """
2
 
3
  import os
4
- os.environ["YOLO_CONFIG_DIR"] = os.environ.get("YOLO_CONFIG_DIR", "/tmp")
5
-
6
- import json
7
- import numpy as np
8
  import gradio as gr
9
- from ultralytics import YOLO
10
  from pathlib import Path
11
 
12
- # Tab 2: D-FINE runs first, then SigLIP for crop classification
13
  from dfine_jina_pipeline import run_single_image
14
 
15
-
16
- # --- Object Detection (Tab 1) ---
17
-
18
- PERSON_CLASS = 0
19
- CAR_CLASS = 2
20
- KNIFE_CLASS = 80
21
- WEAPON_CLASS = 81
22
-
23
- DRAW_CLASSES = [PERSON_CLASS, CAR_CLASS, KNIFE_CLASS, WEAPON_CLASS]
24
-
25
- CLASS_NAMES = {
26
- PERSON_CLASS: "person",
27
- CAR_CLASS: "car",
28
- KNIFE_CLASS: "knife",
29
- WEAPON_CLASS: "weapon",
30
- }
31
-
32
- CONF = 0.25
33
- IMGSZ = 640
34
-
35
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
36
- MODELS_DIR = os.path.join(BASE_DIR, "models")
37
- REFS_DIR = os.path.join(BASE_DIR, "refs")
38
-
39
-
40
- def _load_model(version: str):
41
- path = os.path.join(MODELS_DIR, version, "best.pt")
42
- if not os.path.isfile(path):
43
- raise FileNotFoundError(f"Model not found: {path}")
44
- return YOLO(path)
45
-
46
-
47
- MODELS = {"v1": _load_model("v1")}
48
-
49
- MODEL_CLASSES = {
50
- "v1": ["person", "car", "knife", "weapon"]
51
- }
52
-
53
-
54
- def run_detection(image, model):
55
-
56
- if image is None:
57
- return None, "{}"
58
-
59
- img = image if isinstance(image, np.ndarray) else np.array(image)
60
-
61
- if img.ndim == 2:
62
- img = np.stack([img] * 3, axis=-1)
63
-
64
- results = model.predict(
65
- source=img,
66
- imgsz=IMGSZ,
67
- conf=CONF,
68
- device="cpu",
69
- verbose=False,
70
- )
71
-
72
- r = results[0]
73
-
74
- if r.boxes is None or len(r.boxes) == 0:
75
- return image, json.dumps({"detections": []}, indent=2)
76
-
77
- clss = r.boxes.cls.cpu().numpy()
78
- confs = r.boxes.conf.cpu().numpy()
79
-
80
- keep = [
81
- i for i in range(len(r.boxes))
82
- if int(clss[i]) in DRAW_CLASSES
83
- ]
84
 
85
- if not keep:
86
- return image, json.dumps({"detections": []}, indent=2)
87
 
88
- detections = []
89
 
90
- for i in keep:
91
- cls_id = int(clss[i])
92
-
93
- detections.append({
94
- "class": CLASS_NAMES.get(cls_id, str(cls_id)),
95
- "confidence": round(float(confs[i]), 3),
96
- "bbox": r.boxes.xyxy[i].cpu().numpy().tolist(),
97
- })
98
-
99
- r.boxes = r.boxes[keep]
100
-
101
- out_img = r.plot()
102
-
103
- det_json = json.dumps(
104
- {"detections": detections},
105
- indent=2
106
- )
107
-
108
- return out_img, det_json
109
-
110
-
111
- def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, siglip_threshold):
112
- """Tab 2: D-FINE first, then classify crops with SigLIP.
113
  Returns (group_crop_gallery, known_crop_gallery, status_message).
114
  """
115
  if image is None:
116
  return [], [], "Upload an image."
117
 
118
- refs = Path(refs_path.strip()) if refs_path and refs_path.strip() else Path(REFS_DIR)
119
-
120
- if not refs.is_dir():
121
- return [], [], f"Refs folder not found: {refs}"
122
 
123
- dfine_model = dfine_model_choice.strip().lower() if dfine_model_choice else "large-obj365"
124
  conf_thresh = float(siglip_threshold)
125
 
126
  group_crops, known_crops, status = run_single_image(
127
  image,
128
- refs_dir=refs,
129
  dfine_model=dfine_model,
130
  det_threshold=float(dfine_threshold),
131
  conf_threshold=conf_thresh,
@@ -134,6 +35,7 @@ def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, si
134
  crop_dedup_iou=0.4,
135
  min_display_conf=conf_thresh,
136
  classifier="siglip",
 
137
  )
138
 
139
  return [(g, None) for g in (group_crops or [])], [(k, None) for k in (known_crops or [])], status or ""
@@ -142,187 +44,104 @@ def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, si
142
  IMG_HEIGHT = 400
143
 
144
 
145
- TAB_STYLE = """
146
- <style>
147
-
148
- [data-testid="tabs"] > div:first-child,
149
- .gr-tabs > div:first-child,
150
- div[class*="tabs"] > div:first-child {
151
- display: flex !important;
152
- width: 100% !important;
153
- }
154
-
155
- [data-testid="tabs"] button,
156
- .gr-tabs button,
157
- div[class*="tabs"] > div:first-child button {
158
- flex: 1 !important;
159
- min-width: 0 !important;
160
- min-height: 40px !important;
161
- color: white !important;
162
- font-weight: 700 !important;
163
- font-size: 1rem !important;
164
- text-align: center !important;
165
- justify-content: center !important;
166
- }
167
-
168
- [data-testid="tabs"] button:not([aria-selected="true"]),
169
- .gr-tabs button:not([aria-selected="true"]),
170
- div[class*="tabs"] > div:first-child button:not([aria-selected="true"]) {
171
- background: #6b7280 !important;
172
- border-color: #6b7280 !important;
173
- }
174
-
175
- [data-testid="tabs"] button[aria-selected="true"],
176
- .gr-tabs button[aria-selected="true"],
177
- div[class*="tabs"] > div:first-child button[aria-selected="true"] {
178
- background: var(--primary-500, #f97316) !important;
179
- border-color: var(--primary-500, #f97316) !important;
180
- }
181
-
182
- </style>
183
- """
184
-
185
-
186
  with gr.Blocks(title="Small Object Detection") as app:
187
 
188
- gr.HTML(TAB_STYLE)
189
-
190
  gr.Markdown("# Small Object Detection")
191
 
192
- with gr.Tabs():
 
 
 
193
 
194
- with gr.TabItem("Object Detection"):
195
 
196
- gr.Markdown(
197
- "**Classes:** " + ", ".join(MODEL_CLASSES["v1"])
198
- )
199
 
200
- with gr.Row():
 
 
 
 
201
 
202
- with gr.Column(scale=1):
 
 
 
 
 
 
 
 
203
 
204
- inp_det = gr.Image(
205
- label="Input image",
206
- height=IMG_HEIGHT
207
- )
 
 
 
208
 
209
- btn_det = gr.Button(
210
- "Detect",
211
- variant="primary"
212
- )
 
 
 
 
 
 
 
 
213
 
214
- out_img_det = gr.Image(
215
- label="Output",
216
- height=IMG_HEIGHT
217
- )
 
 
 
218
 
219
- det_output = gr.JSON(
220
- label="Detections"
221
- )
 
 
222
 
223
- btn_det.click(
224
- fn=lambda img: run_detection(img, MODELS["v1"]),
225
- inputs=inp_det,
226
- outputs=[out_img_det, det_output],
227
  )
228
 
229
- with gr.TabItem("D-FINE + Classify"):
 
 
 
 
 
 
 
230
 
231
- gr.Markdown(
232
- "**D-FINE** runs first (person/car grouping), then small-object crops are classified with **SigLIP** (zero-shot). "
233
- "Choose a D-FINE model (obj365, coco, or obj2coco variants in small/medium/large). "
234
- "Uses the **refs** folder names as class labels (e.g. refs/phone/, refs/cigarette/)."
 
235
  )
236
 
237
- with gr.Row():
238
-
239
- with gr.Column(scale=1):
240
-
241
- inp_dfine = gr.Image(
242
- type="pil",
243
- label="Input image",
244
- height=IMG_HEIGHT
245
- )
246
-
247
- dfine_model_radio = gr.Dropdown(
248
- choices=[
249
- "small-obj365", "medium-obj365", "large-obj365",
250
- "small-coco", "medium-coco", "large-coco",
251
- "small-obj2coco", "medium-obj2coco", "large-obj2coco",
252
- ],
253
- value="large-obj365",
254
- label="D-FINE model",
255
- )
256
-
257
- dfine_threshold_slider = gr.Slider(
258
- minimum=0.05,
259
- maximum=0.5,
260
- value=0.2,
261
- step=0.05,
262
- label="D-FINE detection threshold (applied to chosen model)",
263
- )
264
-
265
- def update_dfine_threshold_default(choice):
266
- if not choice:
267
- return gr.update(value=0.15)
268
- size = choice.strip().lower().split("-")[0]
269
- defaults = {"large": 0.2, "medium": 0.15, "small": 0.1}
270
- return gr.update(value=defaults.get(size, 0.15))
271
-
272
- dfine_model_radio.change(
273
- fn=update_dfine_threshold_default,
274
- inputs=[dfine_model_radio],
275
- outputs=[dfine_threshold_slider],
276
- )
277
-
278
- siglip_threshold_slider = gr.Slider(
279
- minimum=0.001,
280
- maximum=0.1,
281
- value=0.01,
282
- step=0.001,
283
- label="SigLIP: min confidence threshold",
284
- )
285
-
286
- refs_path = gr.Textbox(
287
- label="Refs folder path",
288
- value=REFS_DIR,
289
- placeholder="e.g. refs or /path/to/refs",
290
- )
291
-
292
- btn_dfine = gr.Button(
293
- "Run D-FINE + Classify",
294
- variant="primary"
295
- )
296
-
297
- with gr.Column(scale=1):
298
-
299
- out_gallery_dfine = gr.Gallery(
300
- label="Person/car crops (all D-FINE objects inside drawn with label + score)",
301
- height=IMG_HEIGHT,
302
- columns=2,
303
- object_fit="contain",
304
- )
305
-
306
- out_gallery_known = gr.Gallery(
307
- label="Known objects (class + score above each crop)",
308
- height=IMG_HEIGHT,
309
- columns=4,
310
- object_fit="contain",
311
- )
312
-
313
- out_status_dfine = gr.Textbox(
314
- label="Classification details",
315
- lines=8,
316
- interactive=False,
317
- )
318
-
319
- btn_dfine.click(
320
- fn=run_dfine_classify,
321
- inputs=[inp_dfine, refs_path, dfine_threshold_slider, dfine_model_radio, siglip_threshold_slider],
322
- outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
323
- concurrency_limit=1,
324
  )
325
 
 
 
 
 
 
 
 
326
 
327
  app.launch(
328
  server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
@@ -332,4 +151,4 @@ app.launch(
332
  os.environ.get("GRADIO_SERVER_PORT", 7860)
333
  )
334
  ),
335
- )
 
1
+ """ Gradio app: D-FINE + SigLIP Classify. """
2
 
3
  import os
 
 
 
 
4
  import gradio as gr
 
5
  from pathlib import Path
6
 
 
7
  from dfine_jina_pipeline import run_single_image
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ DEFAULT_LABELS = "gun, knife, cigarette, phone"
 
12
 
 
13
 
14
+ def run_dfine_classify(image, dfine_threshold, dfine_model_choice, siglip_threshold, labels_text):
15
+ """D-FINE first, then classify crops with SigLIP.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  Returns (group_crop_gallery, known_crop_gallery, status_message).
17
  """
18
  if image is None:
19
  return [], [], "Upload an image."
20
 
21
+ labels = [l.strip() for l in labels_text.split(",") if l.strip()]
22
+ if not labels:
23
+ return [], [], "Enter at least one label."
 
24
 
25
+ dfine_model = dfine_model_choice.strip().lower() if dfine_model_choice else "medium-obj2coco"
26
  conf_thresh = float(siglip_threshold)
27
 
28
  group_crops, known_crops, status = run_single_image(
29
  image,
 
30
  dfine_model=dfine_model,
31
  det_threshold=float(dfine_threshold),
32
  conf_threshold=conf_thresh,
 
35
  crop_dedup_iou=0.4,
36
  min_display_conf=conf_thresh,
37
  classifier="siglip",
38
+ labels=labels,
39
  )
40
 
41
  return [(g, None) for g in (group_crops or [])], [(k, None) for k in (known_crops or [])], status or ""
 
44
  IMG_HEIGHT = 400
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  with gr.Blocks(title="Small Object Detection") as app:
48
 
 
 
49
  gr.Markdown("# Small Object Detection")
50
 
51
+ gr.Markdown(
52
+ "**D-FINE** detects persons/cars, then small-object crops are classified with **SigLIP** (zero-shot). "
53
+ "Choose a D-FINE model and enter comma-separated class labels for SigLIP."
54
+ )
55
 
56
+ with gr.Row():
57
 
58
+ with gr.Column(scale=1):
 
 
59
 
60
+ inp_dfine = gr.Image(
61
+ type="pil",
62
+ label="Input image",
63
+ height=IMG_HEIGHT
64
+ )
65
 
66
+ dfine_model_radio = gr.Dropdown(
67
+ choices=[
68
+ "small-obj365", "medium-obj365", "large-obj365",
69
+ "small-coco", "medium-coco", "large-coco",
70
+ "small-obj2coco", "medium-obj2coco", "large-obj2coco",
71
+ ],
72
+ value="medium-obj2coco",
73
+ label="D-FINE model",
74
+ )
75
 
76
+ dfine_threshold_slider = gr.Slider(
77
+ minimum=0.05,
78
+ maximum=0.5,
79
+ value=0.15,
80
+ step=0.05,
81
+ label="D-FINE detection threshold",
82
+ )
83
 
84
+ def update_dfine_threshold_default(choice):
85
+ if not choice:
86
+ return gr.update(value=0.15)
87
+ size = choice.strip().lower().split("-")[0]
88
+ defaults = {"large": 0.2, "medium": 0.15, "small": 0.1}
89
+ return gr.update(value=defaults.get(size, 0.15))
90
+
91
+ dfine_model_radio.change(
92
+ fn=update_dfine_threshold_default,
93
+ inputs=[dfine_model_radio],
94
+ outputs=[dfine_threshold_slider],
95
+ )
96
 
97
+ siglip_threshold_slider = gr.Slider(
98
+ minimum=0.001,
99
+ maximum=0.1,
100
+ value=0.005,
101
+ step=0.001,
102
+ label="SigLIP: min confidence threshold",
103
+ )
104
 
105
+ labels_input = gr.Textbox(
106
+ label="Labels (comma-separated)",
107
+ value=DEFAULT_LABELS,
108
+ placeholder="e.g. gun, knife, cigarette, phone",
109
+ )
110
 
111
+ btn_dfine = gr.Button(
112
+ "Run D-FINE + Classify",
113
+ variant="primary"
 
114
  )
115
 
116
+ with gr.Column(scale=1):
117
+
118
+ out_gallery_dfine = gr.Gallery(
119
+ label="Person/car crops (all D-FINE objects inside drawn with label + score)",
120
+ height=IMG_HEIGHT,
121
+ columns=2,
122
+ object_fit="contain",
123
+ )
124
 
125
+ out_gallery_known = gr.Gallery(
126
+ label="Known objects (class + score above each crop)",
127
+ height=IMG_HEIGHT,
128
+ columns=4,
129
+ object_fit="contain",
130
  )
131
 
132
+ out_status_dfine = gr.Textbox(
133
+ label="Classification details",
134
+ lines=8,
135
+ interactive=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  )
137
 
138
+ btn_dfine.click(
139
+ fn=run_dfine_classify,
140
+ inputs=[inp_dfine, dfine_threshold_slider, dfine_model_radio, siglip_threshold_slider, labels_input],
141
+ outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
142
+ concurrency_limit=1,
143
+ )
144
+
145
 
146
  app.launch(
147
  server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
 
151
  os.environ.get("GRADIO_SERVER_PORT", 7860)
152
  )
153
  ),
154
+ )
dfine_jina_pipeline.py CHANGED
@@ -519,9 +519,10 @@ DFINE_MODEL_IDS = {
519
  CLASSIFIER_CHOICES = ["jina", "siglip", "siglip2_onnx"]
520
 
521
 
522
- def _load_classifier(classifier_name, device, refs_dir):
523
  """Factory: load and initialize a classifier by name."""
524
- refs_dir = Path(refs_dir)
 
525
 
526
  if classifier_name == "jina":
527
  jina_encoder = JinaCLIPv2Encoder(device)
@@ -531,13 +532,13 @@ def _load_classifier(classifier_name, device, refs_dir):
531
  if classifier_name == "siglip":
532
  from siglip_zeroshot import SigLIPClassifier
533
  clf = SigLIPClassifier(device)
534
- clf.build_refs(refs_dir)
535
  return clf
536
 
537
  if classifier_name == "siglip2_onnx":
538
  from siglip2_onnx_zeroshot import SigLIP2ONNXClassifier
539
  clf = SigLIP2ONNXClassifier(device)
540
- clf.build_refs(refs_dir)
541
  return clf
542
 
543
  raise ValueError(f"Unknown classifier: {classifier_name}. Choose from {CLASSIFIER_CHOICES}")
@@ -555,7 +556,7 @@ def _classify_crop(classifier, crop, conf_threshold, gap_threshold):
555
 
556
  def run_single_image(
557
  pil_image,
558
- refs_dir,
559
  device=None,
560
  dfine_model="large",
561
  det_threshold=0.3,
@@ -565,13 +566,15 @@ def run_single_image(
565
  crop_dedup_iou=0.35,
566
  squarify=True,
567
  min_display_conf=None,
568
- classifier="jina",
 
569
  ):
570
  """
571
- Run D-FINE on one image, then classify small-object crops with Jina.
572
 
573
- refs_dir: path to refs folder (str or Path).
574
- dfine_model: "medium" or "large".
 
575
 
576
  Returns (group_crop_images, known_crop_composites, status_message).
577
  """
@@ -581,11 +584,15 @@ def run_single_image(
581
  min_display_conf = MIN_DISPLAY_CONF
582
  from PIL import Image
583
 
584
- global _APP_DFINE, _APP_JINA, _APP_REFS_JINA
585
 
586
- refs_dir = Path(refs_dir)
587
- if not refs_dir.is_dir():
588
- return [], [], f"Refs folder not found: {refs_dir}"
 
 
 
 
589
 
590
  dfine_model = (dfine_model or "large-obj365").strip().lower()
591
  if dfine_model not in DFINE_MODEL_IDS:
@@ -620,12 +627,13 @@ def run_single_image(
620
  grouped.sort(key=lambda x: x["conf"], reverse=True)
621
  top_groups = grouped[:10]
622
 
623
- # Load classifier (Jina, SigLIP, or SigLIP2 ONNX)
624
  global _APP_CLASSIFIERS
 
625
  clf_key = classifier
626
- if clf_key not in _APP_CLASSIFIERS or _APP_CLASSIFIERS[clf_key][1] != str(refs_dir):
627
- clf_instance = _load_classifier(classifier, device, refs_dir)
628
- _APP_CLASSIFIERS[clf_key] = (clf_instance, str(refs_dir))
629
 
630
  clf_instance = _APP_CLASSIFIERS[clf_key][0]
631
 
@@ -745,7 +753,7 @@ def run_single_image(
745
  # Build known-only gallery: only objects with conf >= min_display_conf
746
  known_crop_composites = []
747
  for (_gidx, _box, crop_pil, pred, conf) in results_per_crop:
748
- if pred not in KNOWN_DISPLAY_CLASSES or conf < min_display_conf:
749
  continue
750
  composite = draw_label_on_image(crop_pil, pred, conf)
751
  known_crop_composites.append(np.array(composite))
 
519
  CLASSIFIER_CHOICES = ["jina", "siglip", "siglip2_onnx"]
520
 
521
 
522
+ def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
523
  """Factory: load and initialize a classifier by name."""
524
+ if refs_dir:
525
+ refs_dir = Path(refs_dir)
526
 
527
  if classifier_name == "jina":
528
  jina_encoder = JinaCLIPv2Encoder(device)
 
532
  if classifier_name == "siglip":
533
  from siglip_zeroshot import SigLIPClassifier
534
  clf = SigLIPClassifier(device)
535
+ clf.build_refs(refs_dir=refs_dir, labels=labels)
536
  return clf
537
 
538
  if classifier_name == "siglip2_onnx":
539
  from siglip2_onnx_zeroshot import SigLIP2ONNXClassifier
540
  clf = SigLIP2ONNXClassifier(device)
541
+ clf.build_refs(refs_dir=refs_dir, labels=labels)
542
  return clf
543
 
544
  raise ValueError(f"Unknown classifier: {classifier_name}. Choose from {CLASSIFIER_CHOICES}")
 
556
 
557
  def run_single_image(
558
  pil_image,
559
+ refs_dir=None,
560
  device=None,
561
  dfine_model="large",
562
  det_threshold=0.3,
 
566
  crop_dedup_iou=0.35,
567
  squarify=True,
568
  min_display_conf=None,
569
+ classifier="siglip",
570
+ labels=None,
571
  ):
572
  """
573
+ Run D-FINE on one image, then classify small-object crops.
574
 
575
+ refs_dir: path to refs folder (str or Path), optional if labels provided.
576
+ labels: list of class label strings for zero-shot classifiers.
577
+ dfine_model: key from DFINE_MODEL_IDS.
578
 
579
  Returns (group_crop_images, known_crop_composites, status_message).
580
  """
 
584
  min_display_conf = MIN_DISPLAY_CONF
585
  from PIL import Image
586
 
587
+ global _APP_DFINE
588
 
589
+ if refs_dir:
590
+ refs_dir = Path(refs_dir)
591
+ if not refs_dir.is_dir():
592
+ return [], [], f"Refs folder not found: {refs_dir}"
593
+
594
+ if not refs_dir and not labels:
595
+ return [], [], "Provide either refs_dir or labels."
596
 
597
  dfine_model = (dfine_model or "large-obj365").strip().lower()
598
  if dfine_model not in DFINE_MODEL_IDS:
 
627
  grouped.sort(key=lambda x: x["conf"], reverse=True)
628
  top_groups = grouped[:10]
629
 
630
+ # Load classifier
631
  global _APP_CLASSIFIERS
632
+ cache_key = str(labels) if labels else str(refs_dir)
633
  clf_key = classifier
634
+ if clf_key not in _APP_CLASSIFIERS or _APP_CLASSIFIERS[clf_key][1] != cache_key:
635
+ clf_instance = _load_classifier(classifier, device, refs_dir=refs_dir, labels=labels)
636
+ _APP_CLASSIFIERS[clf_key] = (clf_instance, cache_key)
637
 
638
  clf_instance = _APP_CLASSIFIERS[clf_key][0]
639
 
 
753
  # Build known-only gallery: only objects with conf >= min_display_conf
754
  known_crop_composites = []
755
  for (_gidx, _box, crop_pil, pred, conf) in results_per_crop:
756
+ if pred.startswith("unknown") or conf < min_display_conf:
757
  continue
758
  composite = draw_label_on_image(crop_pil, pred, conf)
759
  known_crop_composites.append(np.array(composite))
siglip_zeroshot.py CHANGED
@@ -27,15 +27,23 @@ class SigLIPClassifier:
27
 
28
  print(f"[*] SigLIP loaded in {time.perf_counter() - t0:.1f}s (device={device})")
29
 
30
- def build_refs(self, refs_dir, **kwargs):
31
- """Extract class names from refs_dir subfolders as plain labels."""
32
- refs_dir = Path(refs_dir)
33
- self.labels = sorted(d.name for d in refs_dir.iterdir() if d.is_dir())
34
  if not self.labels:
35
- raise ValueError(f"No subfolders in {refs_dir}")
36
-
37
  print(f" SigLIP labels: {self.labels}")
38
 
 
 
 
 
 
 
 
 
 
 
39
  def classify_crop(self, crop, conf_threshold, gap_threshold):
40
  """
41
  Classify a single crop image using zero-shot SigLIP.
 
27
 
28
  print(f"[*] SigLIP loaded in {time.perf_counter() - t0:.1f}s (device={device})")
29
 
30
+ def set_labels(self, labels):
31
+ """Set class labels directly from a list of strings."""
32
+ self.labels = list(labels)
 
33
  if not self.labels:
34
+ raise ValueError("No labels provided")
 
35
  print(f" SigLIP labels: {self.labels}")
36
 
37
+ def build_refs(self, refs_dir=None, labels=None, **kwargs):
38
+ """Set labels from a list or extract from refs_dir subfolders."""
39
+ if labels:
40
+ self.set_labels(labels)
41
+ elif refs_dir:
42
+ refs_dir = Path(refs_dir)
43
+ self.set_labels(sorted(d.name for d in refs_dir.iterdir() if d.is_dir()))
44
+ else:
45
+ raise ValueError("Provide either labels or refs_dir")
46
+
47
  def classify_crop(self, crop, conf_threshold, gap_threshold):
48
  """
49
  Classify a single crop image using zero-shot SigLIP.