Spaces:
Running
Running
Added siglip and siglip2 for classification
Browse files- app.py +22 -5
- dfine_jina_pipeline.py +46 -11
- siglip2_onnx_zeroshot.py +196 -0
- siglip_zeroshot.py +100 -0
app.py
CHANGED
|
@@ -108,8 +108,15 @@ def run_detection(image, model):
|
|
| 108 |
return out_img, det_json
|
| 109 |
|
| 110 |
|
| 111 |
-
|
| 112 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
Returns (group_crop_gallery, known_crop_gallery, status_message).
|
| 114 |
"""
|
| 115 |
if image is None:
|
|
@@ -121,6 +128,7 @@ def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, mi
|
|
| 121 |
return [], [], f"Refs folder not found: {refs}"
|
| 122 |
|
| 123 |
dfine_model = "large" if dfine_model_choice.strip().lower() == "large" else "medium"
|
|
|
|
| 124 |
group_crops, known_crops, status = run_single_image(
|
| 125 |
image,
|
| 126 |
refs_dir=refs,
|
|
@@ -131,6 +139,7 @@ def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, mi
|
|
| 131 |
min_side=24,
|
| 132 |
crop_dedup_iou=0.4,
|
| 133 |
min_display_conf=float(min_display_conf),
|
|
|
|
| 134 |
)
|
| 135 |
|
| 136 |
if status is not None:
|
|
@@ -229,10 +238,12 @@ with gr.Blocks(title="Small Object Detection") as app:
|
|
| 229 |
with gr.TabItem("D-FINE + Classify"):
|
| 230 |
|
| 231 |
gr.Markdown(
|
| 232 |
-
"**D-FINE** runs first (person/car grouping), then small-object crops are classified
|
|
|
|
|
|
|
| 233 |
"Choose D-FINE model size (Medium or Large). "
|
| 234 |
"Uses the **refs** folder (one subfolder per class, e.g. refs/phone/, refs/cigarette/) "
|
| 235 |
-
"
|
| 236 |
"**Gap** = how much the top class (e.g. gun) must beat the next-best class (e.g. phone). "
|
| 237 |
"Bigger gap means the model is more sure; we only accept the label if both confidence and gap are high enough."
|
| 238 |
)
|
|
@@ -247,6 +258,12 @@ with gr.Blocks(title="Small Object Detection") as app:
|
|
| 247 |
height=IMG_HEIGHT
|
| 248 |
)
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
dfine_model_radio = gr.Radio(
|
| 251 |
choices=["Medium", "Large"],
|
| 252 |
value="Large",
|
|
@@ -322,7 +339,7 @@ with gr.Blocks(title="Small Object Detection") as app:
|
|
| 322 |
|
| 323 |
btn_dfine.click(
|
| 324 |
fn=run_dfine_classify,
|
| 325 |
-
inputs=[inp_dfine, refs_path, dfine_threshold_slider, dfine_model_radio, threshold_slider, gap_slider],
|
| 326 |
outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
|
| 327 |
concurrency_limit=1,
|
| 328 |
)
|
|
|
|
| 108 |
return out_img, det_json
|
| 109 |
|
| 110 |
|
| 111 |
+
CLASSIFIER_MAP = {
|
| 112 |
+
"Jina-CLIP-v2 (few-shot)": "jina",
|
| 113 |
+
"SigLIP (zero-shot)": "siglip",
|
| 114 |
+
"SigLIP2 ONNX (zero-shot)": "siglip2_onnx",
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice, min_display_conf=0.703, gap_threshold=0.005, classifier_choice="Jina-CLIP-v2 (few-shot)"):
|
| 119 |
+
"""Tab 2: D-FINE first, then classify crops.
|
| 120 |
Returns (group_crop_gallery, known_crop_gallery, status_message).
|
| 121 |
"""
|
| 122 |
if image is None:
|
|
|
|
| 128 |
return [], [], f"Refs folder not found: {refs}"
|
| 129 |
|
| 130 |
dfine_model = "large" if dfine_model_choice.strip().lower() == "large" else "medium"
|
| 131 |
+
classifier = CLASSIFIER_MAP.get(classifier_choice, "jina")
|
| 132 |
group_crops, known_crops, status = run_single_image(
|
| 133 |
image,
|
| 134 |
refs_dir=refs,
|
|
|
|
| 139 |
min_side=24,
|
| 140 |
crop_dedup_iou=0.4,
|
| 141 |
min_display_conf=float(min_display_conf),
|
| 142 |
+
classifier=classifier,
|
| 143 |
)
|
| 144 |
|
| 145 |
if status is not None:
|
|
|
|
| 238 |
with gr.TabItem("D-FINE + Classify"):
|
| 239 |
|
| 240 |
gr.Markdown(
|
| 241 |
+
"**D-FINE** runs first (person/car grouping), then small-object crops are classified. "
|
| 242 |
+
"Choose a **classifier**: Jina-CLIP-v2 (few-shot, uses reference images), "
|
| 243 |
+
"SigLIP (zero-shot, PyTorch), or SigLIP2 ONNX (zero-shot, larger model). "
|
| 244 |
"Choose D-FINE model size (Medium or Large). "
|
| 245 |
"Uses the **refs** folder (one subfolder per class, e.g. refs/phone/, refs/cigarette/) "
|
| 246 |
+
"— Jina uses reference images; SigLIP models use only the folder names as class labels.\n\n"
|
| 247 |
"**Gap** = how much the top class (e.g. gun) must beat the next-best class (e.g. phone). "
|
| 248 |
"Bigger gap means the model is more sure; we only accept the label if both confidence and gap are high enough."
|
| 249 |
)
|
|
|
|
| 258 |
height=IMG_HEIGHT
|
| 259 |
)
|
| 260 |
|
| 261 |
+
classifier_radio = gr.Radio(
|
| 262 |
+
choices=list(CLASSIFIER_MAP.keys()),
|
| 263 |
+
value="Jina-CLIP-v2 (few-shot)",
|
| 264 |
+
label="Classifier",
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
dfine_model_radio = gr.Radio(
|
| 268 |
choices=["Medium", "Large"],
|
| 269 |
value="Large",
|
|
|
|
| 339 |
|
| 340 |
btn_dfine.click(
|
| 341 |
fn=run_dfine_classify,
|
| 342 |
+
inputs=[inp_dfine, refs_path, dfine_threshold_slider, dfine_model_radio, threshold_slider, gap_slider, classifier_radio],
|
| 343 |
outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
|
| 344 |
concurrency_limit=1,
|
| 345 |
)
|
dfine_jina_pipeline.py
CHANGED
|
@@ -499,11 +499,46 @@ def main():
|
|
| 499 |
# -----------------------------------------------------------------------------
|
| 500 |
|
| 501 |
_APP_DFINE = None # (model_id, image_processor, dfine_model, person_car_ids)
|
| 502 |
-
|
| 503 |
-
_APP_REFS_JINA = None
|
| 504 |
|
| 505 |
DFINE_MODEL_IDS = {"medium": "ustc-community/dfine-medium-obj365", "large": "ustc-community/dfine-large-obj365"}
|
| 506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
def run_single_image(
|
| 509 |
pil_image,
|
|
@@ -517,6 +552,7 @@ def run_single_image(
|
|
| 517 |
crop_dedup_iou=0.35,
|
| 518 |
squarify=True,
|
| 519 |
min_display_conf=None,
|
|
|
|
| 520 |
):
|
| 521 |
"""
|
| 522 |
Run D-FINE on one image, then classify small-object crops with Jina.
|
|
@@ -571,14 +607,14 @@ def run_single_image(
|
|
| 571 |
grouped.sort(key=lambda x: x["conf"], reverse=True)
|
| 572 |
top_groups = grouped[:10]
|
| 573 |
|
| 574 |
-
# Load Jina
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
|
| 581 |
-
|
| 582 |
|
| 583 |
results_per_crop = []
|
| 584 |
group_crop_images = []
|
|
@@ -660,8 +696,7 @@ def run_single_image(
|
|
| 660 |
if squarify:
|
| 661 |
bx1, by1, bx2, by2 = squarify_crop_box(bx1, by1, bx2, by2, crop_w, crop_h)
|
| 662 |
small_crop = crop_pil.crop((bx1, by1, bx2, by2))
|
| 663 |
-
|
| 664 |
-
result = jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
|
| 665 |
pred = result["prediction"] if result["prediction"] in ref_labels else f"unknown ({d['label']})"
|
| 666 |
conf = result["confidence"]
|
| 667 |
results_per_crop.append((gidx, (bx1, by1, bx2, by2), small_crop, pred, conf))
|
|
|
|
| 499 |
# -----------------------------------------------------------------------------
|
| 500 |
|
| 501 |
_APP_DFINE = None # (model_id, image_processor, dfine_model, person_car_ids)
|
| 502 |
+
_APP_CLASSIFIERS = {} # {classifier_name: (classifier_instance, refs_dir_str)}
|
|
|
|
| 503 |
|
| 504 |
DFINE_MODEL_IDS = {"medium": "ustc-community/dfine-medium-obj365", "large": "ustc-community/dfine-large-obj365"}
|
| 505 |
|
| 506 |
+
CLASSIFIER_CHOICES = ["jina", "siglip", "siglip2_onnx"]
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _load_classifier(classifier_name, device, refs_dir):
|
| 510 |
+
"""Factory: load and initialize a classifier by name."""
|
| 511 |
+
refs_dir = Path(refs_dir)
|
| 512 |
+
|
| 513 |
+
if classifier_name == "jina":
|
| 514 |
+
jina_encoder = JinaCLIPv2Encoder(device)
|
| 515 |
+
ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
|
| 516 |
+
return ("jina_wrapped", jina_encoder, ref_labels, ref_embs)
|
| 517 |
+
|
| 518 |
+
if classifier_name == "siglip":
|
| 519 |
+
from siglip_zeroshot import SigLIPClassifier
|
| 520 |
+
clf = SigLIPClassifier(device)
|
| 521 |
+
clf.build_refs(refs_dir)
|
| 522 |
+
return clf
|
| 523 |
+
|
| 524 |
+
if classifier_name == "siglip2_onnx":
|
| 525 |
+
from siglip2_onnx_zeroshot import SigLIP2ONNXClassifier
|
| 526 |
+
clf = SigLIP2ONNXClassifier(device)
|
| 527 |
+
clf.build_refs(refs_dir)
|
| 528 |
+
return clf
|
| 529 |
+
|
| 530 |
+
raise ValueError(f"Unknown classifier: {classifier_name}. Choose from {CLASSIFIER_CHOICES}")
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def _classify_crop(classifier, crop, conf_threshold, gap_threshold):
|
| 534 |
+
"""Unified classify call that works for both Jina (tuple) and SigLIP-style classifiers."""
|
| 535 |
+
if isinstance(classifier, tuple) and classifier[0] == "jina_wrapped":
|
| 536 |
+
_, jina_encoder, ref_labels, ref_embs = classifier
|
| 537 |
+
q = jina_encoder.encode_images([crop], TRUNCATE_DIM)
|
| 538 |
+
return jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
|
| 539 |
+
else:
|
| 540 |
+
return classifier.classify_crop(crop, conf_threshold, gap_threshold)
|
| 541 |
+
|
| 542 |
|
| 543 |
def run_single_image(
|
| 544 |
pil_image,
|
|
|
|
| 552 |
crop_dedup_iou=0.35,
|
| 553 |
squarify=True,
|
| 554 |
min_display_conf=None,
|
| 555 |
+
classifier="jina",
|
| 556 |
):
|
| 557 |
"""
|
| 558 |
Run D-FINE on one image, then classify small-object crops with Jina.
|
|
|
|
| 607 |
grouped.sort(key=lambda x: x["conf"], reverse=True)
|
| 608 |
top_groups = grouped[:10]
|
| 609 |
|
| 610 |
+
# Load classifier (Jina, SigLIP, or SigLIP2 ONNX)
|
| 611 |
+
global _APP_CLASSIFIERS
|
| 612 |
+
clf_key = classifier
|
| 613 |
+
if clf_key not in _APP_CLASSIFIERS or _APP_CLASSIFIERS[clf_key][1] != str(refs_dir):
|
| 614 |
+
clf_instance = _load_classifier(classifier, device, refs_dir)
|
| 615 |
+
_APP_CLASSIFIERS[clf_key] = (clf_instance, str(refs_dir))
|
| 616 |
|
| 617 |
+
clf_instance = _APP_CLASSIFIERS[clf_key][0]
|
| 618 |
|
| 619 |
results_per_crop = []
|
| 620 |
group_crop_images = []
|
|
|
|
| 696 |
if squarify:
|
| 697 |
bx1, by1, bx2, by2 = squarify_crop_box(bx1, by1, bx2, by2, crop_w, crop_h)
|
| 698 |
small_crop = crop_pil.crop((bx1, by1, bx2, by2))
|
| 699 |
+
result = _classify_crop(clf_instance, small_crop, conf_threshold, gap_threshold)
|
|
|
|
| 700 |
pred = result["prediction"] if result["prediction"] in ref_labels else f"unknown ({d['label']})"
|
| 701 |
conf = result["confidence"]
|
| 702 |
results_per_crop.append((gidx, (bx1, by1, bx2, by2), small_crop, pred, conf))
|
siglip2_onnx_zeroshot.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SigLIP2 zero-shot classifier using ONNX Runtime.
|
| 3 |
+
Uses onnx-community/siglip2-large-patch16-256-ONNX (separate vision + text models).
|
| 4 |
+
Zero-shot: text prompts only, no reference images needed (folder names used for class labels).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import onnxruntime as ort
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from huggingface_hub import hf_hub_download
|
| 14 |
+
from transformers import AutoProcessor
|
| 15 |
+
|
| 16 |
+
from jina_fewshot import CLASS_PROMPTS, IMAGE_EXTS
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
REPO_ID = "onnx-community/siglip2-large-patch16-256-ONNX"
|
| 20 |
+
# Use quantized models to save memory; full fp32 text_model is 2.3GB
|
| 21 |
+
VISION_ONNX = "onnx/vision_model_quantized.onnx"
|
| 22 |
+
TEXT_ONNX = "onnx/text_model_quantized.onnx"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _download(repo_id, filename):
|
| 26 |
+
print(f" Downloading {filename} from {repo_id}...")
|
| 27 |
+
path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 28 |
+
print(f" Downloaded: {path}")
|
| 29 |
+
return path
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _make_session(onnx_path, device):
|
| 33 |
+
available = ort.get_available_providers()
|
| 34 |
+
if "CUDAExecutionProvider" in available and device == "cuda":
|
| 35 |
+
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 36 |
+
else:
|
| 37 |
+
providers = ["CPUExecutionProvider"]
|
| 38 |
+
print(f" ONNX providers: {providers}")
|
| 39 |
+
return ort.InferenceSession(onnx_path, providers=providers)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SigLIP2ONNXClassifier:
|
| 43 |
+
"""Zero-shot crop classifier using SigLIP2 ONNX (separate vision + text encoders)."""
|
| 44 |
+
|
| 45 |
+
def __init__(self, device="cuda"):
|
| 46 |
+
print("[*] Loading SigLIP2 ONNX (siglip2-large-patch16-256)...")
|
| 47 |
+
t0 = time.perf_counter()
|
| 48 |
+
|
| 49 |
+
self.device = device
|
| 50 |
+
|
| 51 |
+
# Download and load vision model
|
| 52 |
+
vision_path = _download(REPO_ID, VISION_ONNX)
|
| 53 |
+
self.vision_session = _make_session(vision_path, device)
|
| 54 |
+
|
| 55 |
+
# Download and load text model
|
| 56 |
+
text_path = _download(REPO_ID, TEXT_ONNX)
|
| 57 |
+
self.text_session = _make_session(text_path, device)
|
| 58 |
+
|
| 59 |
+
# Processor handles both image preprocessing and tokenization
|
| 60 |
+
self.processor = AutoProcessor.from_pretrained(REPO_ID)
|
| 61 |
+
|
| 62 |
+
# Map I/O names
|
| 63 |
+
self._vision_input_names = [i.name for i in self.vision_session.get_inputs()]
|
| 64 |
+
self._vision_output_names = [o.name for o in self.vision_session.get_outputs()]
|
| 65 |
+
self._text_input_names = [i.name for i in self.text_session.get_inputs()]
|
| 66 |
+
self._text_output_names = [o.name for o in self.text_session.get_outputs()]
|
| 67 |
+
|
| 68 |
+
print(f" Vision inputs: {self._vision_input_names}")
|
| 69 |
+
print(f" Vision outputs: {self._vision_output_names}")
|
| 70 |
+
print(f" Text inputs: {self._text_input_names}")
|
| 71 |
+
print(f" Text outputs: {self._text_output_names}")
|
| 72 |
+
|
| 73 |
+
self.labels = []
|
| 74 |
+
self._text_embeds = None
|
| 75 |
+
|
| 76 |
+
# Sanity check
|
| 77 |
+
dummy = Image.new("RGB", (256, 256), color=(255, 0, 0))
|
| 78 |
+
v_emb = self._encode_image(dummy)
|
| 79 |
+
print(f" [SANITY] vision embed shape={v_emb.shape}, norm={np.linalg.norm(v_emb):.4f}")
|
| 80 |
+
|
| 81 |
+
t_emb = self._encode_texts(["a red square"])
|
| 82 |
+
print(f" [SANITY] text embed shape={t_emb.shape}, norm={np.linalg.norm(t_emb):.4f}")
|
| 83 |
+
|
| 84 |
+
print(f"[*] SigLIP2 ONNX loaded in {time.perf_counter() - t0:.1f}s")
|
| 85 |
+
|
| 86 |
+
def _encode_image(self, image):
|
| 87 |
+
"""Encode a single PIL image, return [1, D] embedding."""
|
| 88 |
+
processed = self.processor(images=image, return_tensors="np")
|
| 89 |
+
pixel_values = processed["pixel_values"].astype(np.float32)
|
| 90 |
+
|
| 91 |
+
feeds = {}
|
| 92 |
+
for name in self._vision_input_names:
|
| 93 |
+
if "pixel" in name.lower():
|
| 94 |
+
feeds[name] = pixel_values
|
| 95 |
+
|
| 96 |
+
outputs = self.vision_session.run(self._vision_output_names, feeds)
|
| 97 |
+
|
| 98 |
+
# Pick the pooler_output or last_hidden_state[:,0,:] — typically first 2D output
|
| 99 |
+
for out in outputs:
|
| 100 |
+
if out.ndim == 2:
|
| 101 |
+
return out
|
| 102 |
+
# Fallback: CLS token from 3D
|
| 103 |
+
for out in outputs:
|
| 104 |
+
if out.ndim == 3:
|
| 105 |
+
return out[:, 0, :]
|
| 106 |
+
|
| 107 |
+
raise RuntimeError(f"No usable vision output. Shapes: {[o.shape for o in outputs]}")
|
| 108 |
+
|
| 109 |
+
def _encode_texts(self, texts):
|
| 110 |
+
"""Encode text strings, return [N, D] embeddings."""
|
| 111 |
+
processed = self.processor(text=texts, return_tensors="np", padding=True, truncation=True)
|
| 112 |
+
|
| 113 |
+
feeds = {}
|
| 114 |
+
for name in self._text_input_names:
|
| 115 |
+
nl = name.lower()
|
| 116 |
+
if "input_id" in nl and "input_ids" in processed:
|
| 117 |
+
feeds[name] = processed["input_ids"].astype(np.int64)
|
| 118 |
+
elif ("attention" in nl or "mask" in nl) and "attention_mask" in processed:
|
| 119 |
+
feeds[name] = processed["attention_mask"].astype(np.int64)
|
| 120 |
+
|
| 121 |
+
outputs = self.text_session.run(self._text_output_names, feeds)
|
| 122 |
+
|
| 123 |
+
# Pick pooler_output (2D) or CLS from 3D
|
| 124 |
+
for out in outputs:
|
| 125 |
+
if out.ndim == 2:
|
| 126 |
+
return out
|
| 127 |
+
for out in outputs:
|
| 128 |
+
if out.ndim == 3:
|
| 129 |
+
return out[:, 0, :]
|
| 130 |
+
|
| 131 |
+
raise RuntimeError(f"No usable text output. Shapes: {[o.shape for o in outputs]}")
|
| 132 |
+
|
| 133 |
+
def build_refs(self, refs_dir, **kwargs):
|
| 134 |
+
"""Extract class names from refs_dir subfolders and precompute text embeddings."""
|
| 135 |
+
refs_dir = Path(refs_dir)
|
| 136 |
+
self.labels = sorted(d.name for d in refs_dir.iterdir() if d.is_dir())
|
| 137 |
+
if not self.labels:
|
| 138 |
+
raise ValueError(f"No subfolders in {refs_dir}")
|
| 139 |
+
|
| 140 |
+
text_prompts = []
|
| 141 |
+
for name in self.labels:
|
| 142 |
+
prompts = CLASS_PROMPTS.get(name, [f"a {name}"])
|
| 143 |
+
text_prompts.append(prompts[0])
|
| 144 |
+
|
| 145 |
+
self._text_embeds = self._encode_texts(text_prompts)
|
| 146 |
+
|
| 147 |
+
print(f" SigLIP2 ONNX classes: {self.labels}")
|
| 148 |
+
print(f" Text prompts: {text_prompts}")
|
| 149 |
+
print(f" Text embeds shape: {self._text_embeds.shape}")
|
| 150 |
+
|
| 151 |
+
def classify_crop(self, crop, conf_threshold, gap_threshold):
|
| 152 |
+
"""
|
| 153 |
+
Classify a single crop image using zero-shot SigLIP2.
|
| 154 |
+
Computes image-text similarity via dot product + sigmoid (SigLIP style).
|
| 155 |
+
Returns dict matching jina_fewshot.classify() format.
|
| 156 |
+
"""
|
| 157 |
+
image_emb = self._encode_image(crop) # [1, D]
|
| 158 |
+
text_emb = self._text_embeds # [N, D]
|
| 159 |
+
|
| 160 |
+
# SigLIP2 uses sigmoid on logits (dot product scaled by model)
|
| 161 |
+
logits = (image_emb @ text_emb.T).squeeze(0).astype(np.float64)
|
| 162 |
+
probs = 1.0 / (1.0 + np.exp(-logits)) # sigmoid
|
| 163 |
+
probs = np.nan_to_num(probs, nan=0.0)
|
| 164 |
+
|
| 165 |
+
sorted_idx = np.argsort(probs)[::-1]
|
| 166 |
+
|
| 167 |
+
best_idx = sorted_idx[0]
|
| 168 |
+
second_idx = sorted_idx[1]
|
| 169 |
+
conf = float(probs[best_idx])
|
| 170 |
+
gap = float(probs[best_idx] - probs[second_idx])
|
| 171 |
+
|
| 172 |
+
conf_ok = conf >= conf_threshold
|
| 173 |
+
gap_ok = gap >= gap_threshold
|
| 174 |
+
|
| 175 |
+
if conf_ok and gap_ok:
|
| 176 |
+
prediction = self.labels[best_idx]
|
| 177 |
+
status = "accepted"
|
| 178 |
+
else:
|
| 179 |
+
prediction = "unknown"
|
| 180 |
+
reasons = []
|
| 181 |
+
if not conf_ok:
|
| 182 |
+
reasons.append(f"conf {conf:.4f} < {conf_threshold}")
|
| 183 |
+
if not gap_ok:
|
| 184 |
+
reasons.append(f"gap {gap:.4f} < {gap_threshold}")
|
| 185 |
+
status = "rejected: " + ", ".join(reasons)
|
| 186 |
+
|
| 187 |
+
return {
|
| 188 |
+
"prediction": prediction,
|
| 189 |
+
"raw_prediction": self.labels[best_idx],
|
| 190 |
+
"confidence": conf,
|
| 191 |
+
"gap": gap,
|
| 192 |
+
"second_best": self.labels[second_idx],
|
| 193 |
+
"second_conf": float(probs[second_idx]),
|
| 194 |
+
"status": status,
|
| 195 |
+
"all_sims": {self.labels[j]: float(probs[j]) for j in range(len(self.labels))},
|
| 196 |
+
}
|
siglip_zeroshot.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SigLIP zero-shot classifier for crop classification.
|
| 3 |
+
Uses google/siglip-base-patch16-224 via PyTorch.
|
| 4 |
+
Zero-shot: text prompts only, no reference images needed (folder names used for class labels).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import SiglipModel, AutoProcessor
|
| 13 |
+
|
| 14 |
+
from jina_fewshot import CLASS_PROMPTS
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SigLIPClassifier:
|
| 18 |
+
"""Zero-shot crop classifier using SigLIP (PyTorch)."""
|
| 19 |
+
|
| 20 |
+
def __init__(self, device="cuda"):
|
| 21 |
+
print("[*] Loading SigLIP (google/siglip-base-patch16-224)...")
|
| 22 |
+
t0 = time.perf_counter()
|
| 23 |
+
|
| 24 |
+
self.device = device
|
| 25 |
+
self.model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
|
| 26 |
+
self.model = self.model.to(device).eval()
|
| 27 |
+
self.processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
|
| 28 |
+
|
| 29 |
+
self.labels = []
|
| 30 |
+
self._text_prompts = []
|
| 31 |
+
|
| 32 |
+
print(f"[*] SigLIP loaded in {time.perf_counter() - t0:.1f}s (device={device})")
|
| 33 |
+
|
| 34 |
+
def build_refs(self, refs_dir, **kwargs):
|
| 35 |
+
"""Extract class names from refs_dir subfolders. No images needed."""
|
| 36 |
+
refs_dir = Path(refs_dir)
|
| 37 |
+
self.labels = sorted(d.name for d in refs_dir.iterdir() if d.is_dir())
|
| 38 |
+
if not self.labels:
|
| 39 |
+
raise ValueError(f"No subfolders in {refs_dir}")
|
| 40 |
+
|
| 41 |
+
# Build one prompt per class (first from CLASS_PROMPTS, fallback to "a {name}")
|
| 42 |
+
self._text_prompts = []
|
| 43 |
+
for name in self.labels:
|
| 44 |
+
prompts = CLASS_PROMPTS.get(name, [f"a {name}"])
|
| 45 |
+
self._text_prompts.append(prompts[0])
|
| 46 |
+
|
| 47 |
+
print(f" SigLIP classes: {self.labels}")
|
| 48 |
+
print(f" Text prompts: {self._text_prompts}")
|
| 49 |
+
|
| 50 |
+
def classify_crop(self, crop, conf_threshold, gap_threshold):
|
| 51 |
+
"""
|
| 52 |
+
Classify a single crop image using zero-shot SigLIP.
|
| 53 |
+
Returns dict matching jina_fewshot.classify() format.
|
| 54 |
+
"""
|
| 55 |
+
inputs = self.processor(
|
| 56 |
+
text=self._text_prompts,
|
| 57 |
+
images=crop,
|
| 58 |
+
return_tensors="pt",
|
| 59 |
+
padding="max_length",
|
| 60 |
+
)
|
| 61 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 62 |
+
|
| 63 |
+
with torch.no_grad():
|
| 64 |
+
outputs = self.model(**inputs)
|
| 65 |
+
logits = outputs.logits_per_image
|
| 66 |
+
probs = torch.sigmoid(logits).cpu().numpy().squeeze(0)
|
| 67 |
+
|
| 68 |
+
probs = np.nan_to_num(probs.astype(np.float64), nan=0.0)
|
| 69 |
+
sorted_idx = np.argsort(probs)[::-1]
|
| 70 |
+
|
| 71 |
+
best_idx = sorted_idx[0]
|
| 72 |
+
second_idx = sorted_idx[1]
|
| 73 |
+
conf = float(probs[best_idx])
|
| 74 |
+
gap = float(probs[best_idx] - probs[second_idx])
|
| 75 |
+
|
| 76 |
+
conf_ok = conf >= conf_threshold
|
| 77 |
+
gap_ok = gap >= gap_threshold
|
| 78 |
+
|
| 79 |
+
if conf_ok and gap_ok:
|
| 80 |
+
prediction = self.labels[best_idx]
|
| 81 |
+
status = "accepted"
|
| 82 |
+
else:
|
| 83 |
+
prediction = "unknown"
|
| 84 |
+
reasons = []
|
| 85 |
+
if not conf_ok:
|
| 86 |
+
reasons.append(f"conf {conf:.4f} < {conf_threshold}")
|
| 87 |
+
if not gap_ok:
|
| 88 |
+
reasons.append(f"gap {gap:.4f} < {gap_threshold}")
|
| 89 |
+
status = "rejected: " + ", ".join(reasons)
|
| 90 |
+
|
| 91 |
+
return {
|
| 92 |
+
"prediction": prediction,
|
| 93 |
+
"raw_prediction": self.labels[best_idx],
|
| 94 |
+
"confidence": conf,
|
| 95 |
+
"gap": gap,
|
| 96 |
+
"second_best": self.labels[second_idx],
|
| 97 |
+
"second_conf": float(probs[second_idx]),
|
| 98 |
+
"status": status,
|
| 99 |
+
"all_sims": {self.labels[j]: float(probs[j]) for j in range(len(self.labels))},
|
| 100 |
+
}
|