Spaces:
Running
Running
| """ | |
| SigLIP2 zero-shot classifier using ONNX Runtime. | |
| Uses onnx-community/siglip2-large-patch16-256-ONNX (separate vision + text models). | |
| Zero-shot: text prompts only, no reference images needed (folder names used for class labels). | |
| """ | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import onnxruntime as ort | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoProcessor | |
| from jina_fewshot import IMAGE_EXTS | |
# Hugging Face Hub repository that hosts both the ONNX encoders and the
# processor configuration; passed to hf_hub_download and AutoProcessor below.
REPO_ID = "onnx-community/siglip2-large-patch16-256-ONNX"
# Use quantized models to save memory; full fp32 text_model is 2.3GB
# File names (relative to the repo root) of the vision and text encoders.
VISION_ONNX = "onnx/vision_model_quantized.onnx"
TEXT_ONNX = "onnx/text_model_quantized.onnx"
def _download(repo_id, filename):
    """Fetch ``filename`` from the Hub repo ``repo_id`` via ``hf_hub_download``.

    Progress is echoed to stdout before and after the call.  The value
    returned by ``hf_hub_download`` is passed straight back to the caller.
    """
    print(f" Downloading {filename} from {repo_id}...")
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f" Downloaded: {local_path}")
    return local_path
def _make_session(onnx_path, device):
    """Create an ONNX Runtime session for ``onnx_path``.

    CUDA is placed ahead of CPU in the provider list only when the caller
    asked for ``device == "cuda"`` *and* ONNX Runtime lists the CUDA
    provider as available; in every other case the session is CPU-only.
    The chosen provider list is printed before the session is built.
    """
    cuda_available = "CUDAExecutionProvider" in ort.get_available_providers()
    providers = ["CPUExecutionProvider"]
    if cuda_available and device == "cuda":
        # Keep CPU as the second entry so it remains in the provider list.
        providers.insert(0, "CUDAExecutionProvider")
    print(f" ONNX providers: {providers}")
    return ort.InferenceSession(onnx_path, providers=providers)
class SigLIP2ONNXClassifier:
    """Zero-shot crop classifier using SigLIP2 ONNX (separate vision + text encoders).

    Class labels are the sub-folder names of a reference directory (see
    ``build_refs``).  Each label string is encoded once by the text model;
    every crop is then encoded by the vision model and scored against the
    cached text embeddings.
    """

    # Token length text prompts are padded/truncated to.  The Hugging Face
    # SigLIP2 usage notes call for padding="max_length" with max_length=64
    # because the text tower was trained on fixed-length inputs.
    TEXT_MAX_LENGTH = 64

    def __init__(self, device="cuda"):
        """Download both encoders, create their sessions and run a smoke test.

        Args:
            device: Forwarded to ``_make_session`` to choose the ONNX Runtime
                execution providers.
        """
        print("[*] Loading SigLIP2 ONNX (siglip2-large-patch16-256)...")
        t0 = time.perf_counter()
        self.device = device

        # Download and load vision model
        vision_path = _download(REPO_ID, VISION_ONNX)
        self.vision_session = _make_session(vision_path, device)

        # Download and load text model
        text_path = _download(REPO_ID, TEXT_ONNX)
        self.text_session = _make_session(text_path, device)

        # Processor handles both image preprocessing and tokenization
        self.processor = AutoProcessor.from_pretrained(REPO_ID, use_fast=False)

        # Cache the graph I/O names so feeds can be built without re-querying.
        self._vision_input_names = [i.name for i in self.vision_session.get_inputs()]
        self._vision_output_names = [o.name for o in self.vision_session.get_outputs()]
        self._text_input_names = [i.name for i in self.text_session.get_inputs()]
        self._text_output_names = [o.name for o in self.text_session.get_outputs()]
        print(f" Vision inputs: {self._vision_input_names}")
        print(f" Vision outputs: {self._vision_output_names}")
        print(f" Text inputs: {self._text_input_names}")
        print(f" Text outputs: {self._text_output_names}")

        # Populated by build_refs(); classify_crop() refuses to run before that.
        self.labels = []
        self._text_embeds = None

        # Sanity check: push one dummy image and one prompt through each
        # encoder so a broken model or I/O mapping fails at load time.
        dummy = Image.new("RGB", (256, 256), color=(255, 0, 0))
        v_emb = self._encode_image(dummy)
        print(f" [SANITY] vision embed shape={v_emb.shape}, norm={np.linalg.norm(v_emb):.4f}")
        t_emb = self._encode_texts(["a red square"])
        print(f" [SANITY] text embed shape={t_emb.shape}, norm={np.linalg.norm(t_emb):.4f}")
        print(f"[*] SigLIP2 ONNX loaded in {time.perf_counter() - t0:.1f}s")

    @staticmethod
    def _pick_embedding(outputs, kind):
        """Return the embedding array from a list of ONNX outputs.

        The first 2-D output is preferred.  If there is none, the first
        3-D output is reduced by taking index 0 along its second axis.
        NOTE(review): that fallback assumes position 0 is a usable pooled
        token for this model — confirm against the exported graph.

        Raises:
            RuntimeError: if no output is 2-D or 3-D.
        """
        for out in outputs:
            if out.ndim == 2:
                return out
        for out in outputs:
            if out.ndim == 3:
                return out[:, 0, :]
        raise RuntimeError(f"No usable {kind} output. Shapes: {[o.shape for o in outputs]}")

    def _encode_image(self, image):
        """Encode a single PIL image, return [1, D] embedding.

        Raises:
            RuntimeError: if a vision-model input cannot be fed, or the
                model produces no 2-D/3-D output.
        """
        processed = self.processor(images=image, return_tensors="np")
        pixel_values = processed["pixel_values"].astype(np.float32)
        feeds = {
            name: pixel_values
            for name in self._vision_input_names
            if "pixel" in name.lower()
        }
        # Fail early with a readable message instead of an opaque runtime error.
        missing = [n for n in self._vision_input_names if n not in feeds]
        if missing:
            raise RuntimeError(f"Cannot feed vision model inputs: {missing}")
        outputs = self.vision_session.run(self._vision_output_names, feeds)
        return self._pick_embedding(outputs, "vision")

    def _encode_texts(self, texts):
        """Encode text strings, return [N, D] embeddings.

        Prompts are padded/truncated to ``TEXT_MAX_LENGTH`` tokens rather
        than to the longest prompt in the batch.

        Raises:
            RuntimeError: if a text-model input cannot be fed, or the model
                produces no 2-D/3-D output.
        """
        processed = self.processor(
            text=texts,
            return_tensors="np",
            padding="max_length",
            max_length=self.TEXT_MAX_LENGTH,
            truncation=True,
        )
        feeds = {}
        for name in self._text_input_names:
            lowered = name.lower()
            if "input_id" in lowered and "input_ids" in processed:
                feeds[name] = processed["input_ids"].astype(np.int64)
            elif ("attention" in lowered or "mask" in lowered) and "attention_mask" in processed:
                feeds[name] = processed["attention_mask"].astype(np.int64)
        # Fail early with a readable message instead of an opaque runtime error.
        missing = [n for n in self._text_input_names if n not in feeds]
        if missing:
            raise RuntimeError(f"Cannot feed text model inputs: {missing}")
        outputs = self.text_session.run(self._text_output_names, feeds)
        return self._pick_embedding(outputs, "text")

    def build_refs(self, refs_dir, **kwargs):
        """Extract class names from refs_dir subfolders and precompute text embeddings.

        Args:
            refs_dir: Directory whose immediate sub-folder names become the
                class labels (sorted alphabetically).
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: if ``refs_dir`` contains no sub-folders.
        """
        refs_dir = Path(refs_dir)
        labels = sorted(d.name for d in refs_dir.iterdir() if d.is_dir())
        if not labels:
            raise ValueError(f"No subfolders in {refs_dir}")
        text_embeds = self._encode_texts(labels)
        # Commit both together so labels and embeddings never get out of sync
        # if validation or encoding fails.
        self.labels = labels
        self._text_embeds = text_embeds
        print(f" SigLIP2 ONNX labels: {self.labels}")
        print(f" Text embeds shape: {self._text_embeds.shape}")

    def classify_crop(self, crop, conf_threshold, gap_threshold):
        """
        Classify a single crop image using zero-shot SigLIP2.
        Computes image-text similarity via dot product + sigmoid (SigLIP style).
        Returns dict matching jina_fewshot.classify() format.

        Args:
            crop: Image handed to the vision encoder.
            conf_threshold: Minimum top score for the prediction to be
                accepted; below it the prediction is "unknown".
            gap_threshold: Accepted for call compatibility; it does not take
                part in the accept/reject decision here.

        Returns:
            dict with keys prediction, raw_prediction, confidence, gap,
            second_best, second_conf, status and all_sims.  With a single
            label there is no runner-up: second_best is None, second_conf is
            0.0 and gap equals confidence.

        Raises:
            RuntimeError: if ``build_refs`` has not been called successfully.
        """
        if self._text_embeds is None or not self.labels:
            raise RuntimeError("No labels loaded; call build_refs() before classify_crop().")

        image_emb = self._encode_image(crop)  # [1, D]
        text_emb = self._text_embeds  # [N, D]

        # NOTE(review): the reference SigLIP head L2-normalizes both embeddings
        # and applies a learned logit scale and bias before the sigmoid; here
        # the sigmoid is taken over the raw dot product.  Confirm the exported
        # ONNX outputs make that appropriate before trusting absolute scores.
        logits = (image_emb @ text_emb.T).squeeze(0).astype(np.float64)
        # Numerically stable sigmoid: exp(-log(1 + exp(-x))) never overflows.
        probs = np.exp(-np.logaddexp(0.0, -logits))
        probs = np.nan_to_num(probs, nan=0.0)

        order = np.argsort(probs)[::-1]
        best_idx = int(order[0])
        conf = float(probs[best_idx])

        if len(order) > 1:
            second_idx = int(order[1])
            second_best = self.labels[second_idx]
            second_conf = float(probs[second_idx])
        else:
            # Single-label case: there is no runner-up to compare against.
            second_best = None
            second_conf = 0.0
        gap = conf - second_conf

        if conf >= conf_threshold:
            prediction = self.labels[best_idx]
            status = "accepted"
        else:
            prediction = "unknown"
            status = f"rejected: conf {conf:.4f} < {conf_threshold}"

        return {
            "prediction": prediction,
            "raw_prediction": self.labels[best_idx],
            "confidence": conf,
            "gap": gap,
            "second_best": second_best,
            "second_conf": second_conf,
            "status": status,
            "all_sims": {label: float(p) for label, p in zip(self.labels, probs)},
        }