""" SigLIP zero-shot classifier for crop classification. Uses google/siglip-base-patch16-224 via PyTorch. Zero-shot: text prompts only, no reference images needed (folder names used for class labels). """ import time from pathlib import Path import numpy as np import torch from transformers import SiglipModel, AutoProcessor SIGLIP_MODELS = { "siglip-224": "google/siglip-base-patch16-224", "siglip-256": "google/siglip-base-patch16-256", "siglip-384": "google/siglip-base-patch16-384", } class SigLIPClassifier: """Zero-shot crop classifier using SigLIP (PyTorch).""" def __init__(self, device="cuda", model_key="siglip-224"): model_id = SIGLIP_MODELS.get(model_key, model_key) print(f"[*] Loading SigLIP ({model_id})...") t0 = time.perf_counter() self.device = device self.model_key = model_key self.model = SiglipModel.from_pretrained(model_id) self.model = self.model.to(device).eval() self.processor = AutoProcessor.from_pretrained(model_id) self.labels = [] print(f"[*] SigLIP loaded in {time.perf_counter() - t0:.1f}s (device={device})") def set_labels(self, labels): """Set class labels directly from a list of strings.""" self.labels = list(labels) if not self.labels: raise ValueError("No labels provided") print(f" SigLIP labels: {self.labels}") def build_refs(self, refs_dir=None, labels=None, **kwargs): """Set labels from a list or extract from refs_dir subfolders.""" if labels: self.set_labels(labels) elif refs_dir: refs_dir = Path(refs_dir) self.set_labels(sorted(d.name for d in refs_dir.iterdir() if d.is_dir())) else: raise ValueError("Provide either labels or refs_dir") def classify_crop(self, crop, conf_threshold, gap_threshold): """ Classify a single crop image using zero-shot SigLIP. Returns dict matching jina_fewshot.classify() format. """ inputs = self.processor( text=self.labels, images=crop, return_tensors="pt", padding="max_length", ) inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.no_grad(): outputs = self.model(**inputs) logits = outputs.logits_per_image probs = torch.sigmoid(logits).cpu().numpy().squeeze(0) probs = np.nan_to_num(probs.astype(np.float64), nan=0.0) sorted_idx = np.argsort(probs)[::-1] best_idx = sorted_idx[0] second_idx = sorted_idx[1] conf = float(probs[best_idx]) gap = float(probs[best_idx] - probs[second_idx]) if conf >= conf_threshold: prediction = self.labels[best_idx] status = "accepted" else: prediction = "unknown" status = f"rejected: conf {conf:.4f} < {conf_threshold}" return { "prediction": prediction, "raw_prediction": self.labels[best_idx], "confidence": conf, "gap": gap, "second_best": self.labels[second_idx], "second_conf": float(probs[second_idx]), "status": status, "all_sims": {self.labels[j]: float(probs[j]) for j in range(len(self.labels))}, }