r"""
Few-shot object classification using jina-clip-v2 via ONNX Runtime.
Bypasses all PyTorch custom code / dtype issues on HF Spaces (T4).

Combines IMAGE embeddings from reference photos + TEXT embeddings from class names.
Dual threshold: confidence + gap between top-1 and top-2.

Usage:
    python jina_fewshot.py \
        --refs refs/ \
        --input crops/ \
        --output results/ \
        --text-weight 0.3 \
        --conf-threshold 0.75 \
        --gap-threshold 0.05

refs/ folder structure (3-10 images per class recommended):
    refs/
    ├── cigarette/
    ├── gun/
    ├── knife/
    ├── phone/
    └── nothing/      (empty hands, random objects)
"""

import argparse
import csv
import json
import time
from pathlib import Path

import numpy as np
import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download
from transformers import AutoImageProcessor, AutoTokenizer

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff"}
TRUNCATE_DIM = 1024

# ONNX model outputs: [text_unnorm, image_unnorm, text_norm, image_norm]
_TEXT_NORM_IDX = 2
_IMAGE_NORM_IDX = 3

# TrueType font used for annotations; both drawing helpers fall back to
# Pillow's built-in default font when this file cannot be loaded.
_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"


def draw_label_on_image(img: Image.Image, label: str, confidence: float) -> Image.Image:
    """Draw the label in a bar outside and on top of the image (full width).

    The input image is not modified; a new, taller RGB image is returned with
    a black bar on top containing ``"<label> (<confidence>)"`` in white,
    centred horizontally, and the original pixels pasted underneath.

    Args:
        img: Image to annotate.
        label: Text shown in the bar.
        confidence: Score shown next to the label, formatted with 2 decimals.

    Returns:
        New RGB image of size ``(w, bar_height + h)``.
    """
    img = img.convert("RGB")
    w, h = img.width, img.height
    text = f"{label} ({confidence:.2f})"
    margin = 8
    max_text_w = max(1, w - 2 * margin)

    try:
        font_size = max(10, min(h, w) // 12)
        font = ImageFont.truetype(_FONT_PATH, size=font_size)
    except OSError:
        # The default bitmap font cannot be resized, so font_size=None
        # disables the shrink-to-fit loop below.
        font = ImageFont.load_default()
        font_size = None

    # Measure the text on a throw-away canvas before allocating the output.
    dummy = Image.new("RGB", (1, 1))
    ddraw = ImageDraw.Draw(dummy)
    bbox = ddraw.textbbox((0, 0), text, font=font)
    tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]

    if font_size is not None:
        # Shrink in steps of 2pt (never below 8pt) until the text fits.
        while tw > max_text_w and font_size > 8:
            font_size = max(8, font_size - 2)
            font = ImageFont.truetype(_FONT_PATH, size=font_size)
            bbox = ddraw.textbbox((0, 0), text, font=font)
            tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]

    bar_height = th + 2 * margin
    out = Image.new("RGB", (w, bar_height + h), color=(255, 255, 255))
    draw = ImageDraw.Draw(out)
    draw.rectangle([0, 0, w, bar_height], fill=(0, 0, 0))
    x = (w - tw) // 2
    y = margin
    draw.text((x, y), text, fill=(255, 255, 255), font=font)
    out.paste(img, (0, bar_height))
    return out


def draw_bboxes_on_image(
    img: Image.Image,
    boxes: list[tuple[float, float, float, float, str, float]],
) -> Image.Image:
    """Draw bboxes and labels (label conf) on image.

    Args:
        img: Source image; it is converted to RGB and the converted image is
            the one drawn on and returned.
        boxes: list of (x1, y1, x2, y2, label, conf). Coordinates are
            truncated to ints before drawing.

    Returns:
        RGB image with a green rectangle and a ``"<label> <conf>"`` caption
        per box.
    """
    img = img.convert("RGB")
    draw = ImageDraw.Draw(img)
    w, h = img.width, img.height
    try:
        font = ImageFont.truetype(_FONT_PATH, size=max(10, min(h, w) // 20))
    except OSError:
        font = ImageFont.load_default()

    for (x1, y1, x2, y2, label, conf) in boxes:
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        draw.rectangle([x1, y1, x2, y2], outline=(0, 255, 0), width=2)
        text = f"{label} {conf:.2f}"
        # Caption sits 16px above the box, clamped to the top edge.
        draw.text((x1, max(0, y1 - 16)), text, fill=(0, 255, 0), font=font)
    return img


# Text prompts per class; their embeddings are averaged and blended with the
# image prototype in build_refs(). Classes not listed here get generic prompts.
CLASS_PROMPTS = {
    "knife": [
        "a knife",
        "a person holding a knife",
        "a sharp blade knife",
    ],
    "gun": [
        "a gun",
        "a pistol",
        "a handgun",
        "a person holding a gun",
        "a person holding a pistol",
        "a firearm weapon",
    ],
    "cigarette": [
        "a cigarette",
        "a person smoking a cigarette",
        "a lit cigarette in hand",
    ],
    "phone": [
        "a phone",
        "a person holding a smartphone",
        "a mobile phone cell phone",
    ],
    "nothing": [
        "a person with empty hands",
        "a person standing with no objects",
        "empty hands no weapon",
    ],
}


def parse_args():
    """Parse the command-line options of the classifier script."""
    p = argparse.ArgumentParser(description="Jina-CLIP-v2 few-shot classifier (ONNX)")
    p.add_argument("--refs", required=True, help="Reference images folder")
    p.add_argument("--input", required=True, help="Query crop images folder")
    p.add_argument("--output", default="jinaclip_results", help="Output folder")
    p.add_argument("--dim", type=int, default=TRUNCATE_DIM, help="Embedding dim (64-1024)")
    p.add_argument("--text-weight", type=float, default=0.3,
                   help="Text embedding weight (0.0=image only, default 0.3)")
    p.add_argument("--conf-threshold", type=float, default=0.75,
                   help="Min confidence to accept prediction (default 0.75)")
    p.add_argument("--gap-threshold", type=float, default=0.05,
                   help="Min gap between top-1 and top-2 (default 0.05)")
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--save-refs", action="store_true",
                   help="Save reference embeddings to .npy for fast reload")
    return p.parse_args()


def _download_onnx_model():
    """
    Download the ONNX model from HF Hub.
    Try fp32 (model.onnx + model.onnx_data) first.
    Both files must be in the same directory for ONNX Runtime to find
    the external data file.

    Returns:
        Local path of the downloaded ``onnx/model.onnx`` file.
    """
    print(" Downloading ONNX model files from jinaai/jina-clip-v2...")

    # Download both files — hf_hub_download puts them in the same snapshot dir
    onnx_path = hf_hub_download(
        repo_id="jinaai/jina-clip-v2",
        filename="onnx/model.onnx",
    )
    # External weights file — MUST be downloaded to same directory
    hf_hub_download(
        repo_id="jinaai/jina-clip-v2",
        filename="onnx/model.onnx_data",
    )

    print(f" Downloaded: {onnx_path}")
    print(" External data: model.onnx_data (same directory)")
    return onnx_path


def _truncate_and_normalize(embs: np.ndarray, dim: int) -> np.ndarray:
    """Truncate embeddings to ``dim`` columns, scrub NaN/inf, L2-normalize rows.

    Args:
        embs: 2-D array of embeddings, one row per item.
        dim: Number of leading columns to keep; a falsy value or a value not
            smaller than the current width leaves the width unchanged.

    Returns:
        float32 array whose rows have unit L2 norm (rows that are all zero
        stay zero because the norm is clamped to 1e-12 before dividing).
    """
    if dim and dim < embs.shape[-1]:
        embs = embs[:, :dim]
    embs = np.nan_to_num(embs, nan=0.0, posinf=0.0, neginf=0.0)
    norms = np.linalg.norm(embs, axis=-1, keepdims=True)
    norms = np.maximum(norms, 1e-12)
    return (embs / norms).astype(np.float32)


class JinaCLIPv2Encoder:
    """
    ONNX Runtime based encoder for jina-clip-v2.
    Completely bypasses PyTorch — no dtype/NaN issues.
    """

    def __init__(self, device="cuda"):
        """Download the model, create the ONNX session and run sanity checks.

        Args:
            device: ``"cuda"`` to prefer the CUDA execution provider when
                ONNX Runtime reports it as available; anything else (or no
                CUDA provider) selects the CPU provider only.
        """
        self.device = device
        print("[*] Loading jina-clip-v2 (ONNX)...")
        t0 = time.perf_counter()

        # Download ONNX model (fp32 with external data)
        onnx_path = _download_onnx_model()

        # Pick providers: prefer CUDA if available
        available = ort.get_available_providers()
        if "CUDAExecutionProvider" in available and device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        print(f" ONNX providers: {providers}")

        self.session = ort.InferenceSession(onnx_path, providers=providers)

        # Load tokenizer and image processor
        self.tokenizer = AutoTokenizer.from_pretrained(
            "jinaai/jina-clip-v2", trust_remote_code=True
        )
        self.image_processor = AutoImageProcessor.from_pretrained(
            "jinaai/jina-clip-v2", trust_remote_code=True
        )

        # Inspect ONNX model I/O
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]
        print(f" ONNX inputs: {self.input_names}")
        print(f" ONNX outputs: {self.output_names}")

        # Build input name mapping by substring match on the graph's input
        # names; an input that matches nothing is simply left as None and is
        # then omitted from the feeds in _run_image/_run_text.
        self._pixel_name = None
        self._ids_name = None
        self._mask_name = None
        for name in self.input_names:
            nl = name.lower()
            if "pixel" in nl:
                self._pixel_name = name
            elif "input_id" in nl:
                self._ids_name = name
            elif "attention" in nl or "mask" in nl:
                self._mask_name = name
        print(f" Mapped: pixel={self._pixel_name}, ids={self._ids_name}, mask={self._mask_name}")

        # Sanity checks: a failure is only reported, not raised.
        _dummy = Image.new("RGB", (512, 512), color=(255, 0, 0))
        _test = self.encode_images([_dummy], dim=64)
        _norm = float(np.linalg.norm(_test))
        _is_nan = bool(np.isnan(_norm))
        print(f" [SANITY] dummy image embed norm={_norm:.4f}, nan={_is_nan}")
        if _is_nan or _norm < 0.01:
            print(" [ERROR] ONNX vision encoder broken!")
        else:
            print(" [OK] ONNX vision encoder producing valid embeddings")

        _test_t = self.encode_texts(["a red square"], dim=64)
        _tn = float(np.linalg.norm(_test_t))
        print(f" [SANITY] dummy text embed norm={_tn:.4f}")

        elapsed = time.perf_counter() - t0
        print(f"[*] Loaded in {elapsed:.1f}s (ONNX, providers={providers})\n")

    def _run_image(self, pixel_values: np.ndarray) -> np.ndarray:
        """Run ONNX for images only. Returns normalized image embeddings."""
        bs = pixel_values.shape[0]
        # Dummy text input (minimal tokens): one token id 0 per image row.
        dummy_ids = np.zeros((bs, 1), dtype=np.int64)
        dummy_mask = np.ones((bs, 1), dtype=np.int64)

        feeds = {}
        if self._pixel_name:
            feeds[self._pixel_name] = pixel_values.astype(np.float32)
        if self._ids_name:
            feeds[self._ids_name] = dummy_ids
        if self._mask_name:
            feeds[self._mask_name] = dummy_mask

        outputs = self.session.run(self.output_names, feeds)
        return outputs[_IMAGE_NORM_IDX]

    def _run_text(self, input_ids: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
        """Run ONNX for text only. Returns normalized text embeddings."""
        bs = input_ids.shape[0]
        # Placeholder pixel values: one all-zero 3x512x512 image per text row.
        # NOTE(review): this allocates bs*3*512*512 floats; presumably the
        # graph needs both modalities with matching batch size — confirm
        # before trying to shrink it.
        dummy_pv = np.zeros((bs, 3, 512, 512), dtype=np.float32)

        feeds = {}
        if self._pixel_name:
            feeds[self._pixel_name] = dummy_pv
        if self._ids_name:
            feeds[self._ids_name] = input_ids.astype(np.int64)
        if self._mask_name:
            feeds[self._mask_name] = attention_mask.astype(np.int64)

        outputs = self.session.run(self.output_names, feeds)
        return outputs[_TEXT_NORM_IDX]

    def encode_images(self, images: list[Image.Image], dim: int = TRUNCATE_DIM) -> np.ndarray:
        """Embed PIL images.

        Args:
            images: Images to embed; each is converted to RGB first.
            dim: Number of leading embedding dimensions to keep.

        Returns:
            float32 array with one L2-normalized row per image.
        """
        rgb = [img.convert("RGB") for img in images]
        processed = self.image_processor(rgb, return_tensors="np")
        pv = processed["pixel_values"]
        # The processor may hand back a tensor-like object exposing .numpy()
        # or something array-like; accept both.
        if hasattr(pv, "numpy"):
            pixel_values = pv.numpy().astype(np.float32)
        else:
            pixel_values = np.asarray(pv, dtype=np.float32)

        embs = self._run_image(pixel_values)
        return _truncate_and_normalize(embs, dim)

    def encode_texts(self, texts: list[str], dim: int = TRUNCATE_DIM) -> np.ndarray:
        """Embed text strings.

        Args:
            texts: Strings to embed (padded to a common length, truncated at
                512 tokens).
            dim: Number of leading embedding dimensions to keep.

        Returns:
            float32 array with one L2-normalized row per text.
        """
        tokens = self.tokenizer(
            texts, return_tensors="np", padding=True, truncation=True, max_length=512
        )
        input_ids = tokens["input_ids"].astype(np.int64)
        attention_mask = tokens["attention_mask"].astype(np.int64)

        embs = self._run_text(input_ids, attention_mask)
        return _truncate_and_normalize(embs, dim)

    def encode_image_paths(self, paths: list[str], dim: int = TRUNCATE_DIM,
                           batch_size: int = 16) -> np.ndarray:
        """Embed image files in batches.

        Args:
            paths: Image file paths.
            dim: Number of leading embedding dimensions to keep.
            batch_size: Number of images sent through the model per call.

        Returns:
            float32 array with one L2-normalized row per path, in input order.

        Raises:
            ValueError: If ``paths`` is empty.
        """
        if not paths:
            raise ValueError("encode_image_paths() needs at least one image path")

        all_embs = []
        for i in range(0, len(paths), batch_size):
            batch = []
            for p in paths[i:i + batch_size]:
                # convert() forces the pixel data to load, so the file handle
                # can be released right away instead of leaking.
                with Image.open(p) as opened:
                    batch.append(opened.convert("RGB"))
            all_embs.append(self.encode_images(batch, dim))
        return np.concatenate(all_embs, axis=0)


def build_refs(encoder: JinaCLIPv2Encoder, refs_dir: Path, dim: int,
               text_weight: float, batch_size: int):
    """Build one prototype embedding per class folder under ``refs_dir``.

    For each sub-folder, the mean image embedding of its reference photos is
    blended with the mean text embedding of the class prompts
    (``(1 - text_weight) * image + text_weight * text``) and re-normalized.
    Folders without any supported image file are skipped.

    Args:
        encoder: Encoder used for both images and prompts.
        refs_dir: Folder containing one sub-folder per class.
        dim: Embedding dimension passed to the encoder.
        text_weight: Weight of the text prototype in the blend.
        batch_size: Image batch size passed to the encoder.

    Returns:
        ``(labels, embeddings)`` where ``labels`` is the list of class names
        and ``embeddings`` is a 2-D array with one row per label.

    Raises:
        ValueError: If ``refs_dir`` has no sub-folders, or none of them
            contains a usable image.
    """
    class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subfolders in {refs_dir}")

    labels, embeddings = [], []
    print(f" Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")

    for d in class_dirs:
        name = d.name
        paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if not paths:
            print(f" [!] Skipping {name}: no reference images")
            continue

        img_embs = encoder.encode_image_paths(paths, dim, batch_size)
        img_avg = np.nan_to_num(img_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)

        prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
        text_embs = encoder.encode_texts(prompts, dim)
        text_avg = np.nan_to_num(text_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)

        combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
        combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
        combined = combined / (np.linalg.norm(combined) + 1e-12)

        labels.append(name)
        embeddings.append(combined)

        # Diagnostic only: cosine similarity between the two prototypes.
        img_norm = img_avg / (np.linalg.norm(img_avg) + 1e-12)
        text_norm = text_avg / (np.linalg.norm(text_avg) + 1e-12)
        sim = float(np.nan_to_num(np.dot(img_norm, text_norm), nan=0.0))
        print(f" {name:<14}: {len(paths)} imgs + {len(prompts)} prompts | "
              f"img-text sim: {sim:.4f}")

    if not embeddings:
        raise ValueError(f"No reference images found in any subfolder of {refs_dir}")

    return labels, np.stack(embeddings)


def classify(query_emb: np.ndarray, ref_labels: list[str], ref_embs: np.ndarray,
             conf_threshold: float, gap_threshold: float) -> dict:
    """Classify one query embedding against the class prototypes.

    The prediction is accepted only if the best similarity reaches
    ``conf_threshold`` AND it beats the runner-up by at least
    ``gap_threshold``; otherwise the prediction is ``"unknown"``.

    When there is only one reference class there is no runner-up: it is
    reported as ``second_best=""`` with ``second_conf=0.0``, so the gap
    equals the confidence.

    Args:
        query_emb: Query embedding with a leading axis of length 1 (it is
            removed with ``squeeze(0)`` after the matrix product).
        ref_labels: Class names, aligned with the rows of ``ref_embs``.
        ref_embs: One prototype embedding per row.
        conf_threshold: Minimum similarity of the best class.
        gap_threshold: Minimum difference between best and second best.

    Returns:
        Dict with keys ``prediction``, ``raw_prediction``, ``confidence``,
        ``gap``, ``second_best``, ``second_conf``, ``status`` and
        ``all_sims`` (label -> similarity).

    Raises:
        ValueError: If ``ref_labels`` is empty.
    """
    if not ref_labels:
        raise ValueError("classify() needs at least one reference class")

    sims = (query_emb @ ref_embs.T).squeeze(0)
    sims = np.nan_to_num(sims.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)

    sorted_idx = np.argsort(sims)[::-1]
    best_idx = int(sorted_idx[0])
    conf = float(sims[best_idx])

    if len(sorted_idx) > 1:
        second_idx = int(sorted_idx[1])
        second_best = ref_labels[second_idx]
        second_conf = float(sims[second_idx])
    else:
        # Single reference class: no competitor to compare against.
        second_best = ""
        second_conf = 0.0
    gap = conf - second_conf

    conf_ok = conf >= conf_threshold
    gap_ok = gap >= gap_threshold

    if conf_ok and gap_ok:
        prediction = ref_labels[best_idx]
        status = "accepted"
    else:
        prediction = "unknown"
        reasons = []
        if not conf_ok:
            reasons.append(f"conf {conf:.4f} < {conf_threshold}")
        if not gap_ok:
            reasons.append(f"gap {gap:.4f} < {gap_threshold}")
        status = "rejected: " + ", ".join(reasons)

    return {
        "prediction": prediction,
        "raw_prediction": ref_labels[best_idx],
        "confidence": conf,
        "gap": gap,
        "second_best": second_best,
        "second_conf": second_conf,
        "status": status,
        "all_sims": {label: float(sims[j]) for j, label in enumerate(ref_labels)},
    }


def main():
    """Classify every image in ``--input`` and write CSV + annotated copies."""
    args = parse_args()
    input_dir, output_dir = Path(args.input), Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    paths = sorted(p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
    if not paths:
        print(f"[!] No images in {input_dir}")
        return

    print(f"[*] {len(paths)} query images")
    print(f"[*] Conf threshold: {args.conf_threshold} | Gap threshold: {args.gap_threshold}\n")

    encoder = JinaCLIPv2Encoder("cuda")

    print("[*] Building references...")
    ref_labels, ref_embs = build_refs(
        encoder, Path(args.refs), args.dim, args.text_weight, args.batch_size
    )
    print(f"\n[*] {len(ref_labels)} classes: {ref_labels}\n")

    if args.save_refs:
        np.save(output_dir / "ref_embeddings.npy", ref_embs)
        with open(output_dir / "ref_labels.json", "w") as jf:
            json.dump(ref_labels, jf)
        print(f"[*] Saved refs to {output_dir}\n")

    csv_path = output_dir / "classifications.csv"

    times = []
    counts = {"unknown": 0}
    for label in ref_labels:
        counts[label] = 0
    accepted, rejected = 0, 0

    hdr = " ".join(f"{label:>10}" for label in ref_labels)
    print(f"{'Image':<30} {'Result':<10} {'Conf':>6} {'Gap':>6} {hdr} {'Status'}")
    print("=" * (30 + 10 + 14 + len(hdr) + 40))

    # The context manager guarantees the CSV is flushed and closed even if an
    # image fails to load or encode part-way through the run.
    with open(csv_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["image", "prediction", "raw_prediction", "confidence", "gap",
                    "second_best", "second_conf", "status"]
                   + [f"sim_{label}" for label in ref_labels]
                   + ["time_ms"])

        for p in paths:
            t0 = time.perf_counter()
            # convert() loads the pixels, so the file handle can be closed
            # immediately; the RGB copy is used for encoding and annotation.
            with Image.open(p) as opened:
                img = opened.convert("RGB")
            q = encoder.encode_images([img], args.dim)
            ms = (time.perf_counter() - t0) * 1000
            times.append(ms)

            result = classify(q, ref_labels, ref_embs,
                              args.conf_threshold, args.gap_threshold)
            counts[result["prediction"]] += 1

            if result["prediction"] != "unknown":
                accepted += 1
            else:
                rejected += 1

            annotated = draw_label_on_image(img, result["prediction"], result["confidence"])
            out_path = output_dir / p.name
            annotated.save(out_path)

            sim_str = " ".join(f"{result['all_sims'][label]:>10.4f}" for label in ref_labels)
            print(f"{p.name:<30} {result['prediction']:<10} "
                  f"{result['confidence']:>6.4f} {result['gap']:>6.4f} "
                  f"{sim_str} {result['status']}")

            w.writerow([
                p.name,
                result["prediction"],
                result["raw_prediction"],
                f"{result['confidence']:.4f}",
                f"{result['gap']:.4f}",
                result["second_best"],
                f"{result['second_conf']:.4f}",
                result["status"],
            ] + [f"{result['all_sims'][label]:.4f}" for label in ref_labels]
              + [f"{ms:.1f}"])

    n = len(times)
    total = sum(times)
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(" Model : jina-clip-v2 (ONNX Runtime, fp32)")
    print(f" Embed dim : {args.dim}")
    print(f" Text weight : {args.text_weight}")
    print(f" Conf threshold : {args.conf_threshold}")
    print(f" Gap threshold : {args.gap_threshold}")
    print(f" Images : {n}")
    if n:
        print(f" Accepted : {accepted} ({accepted/n*100:.1f}%)")
        print(f" Rejected : {rejected} ({rejected/n*100:.1f}%)")
    print(" ──────────────────────────────────────────")
    for label in ref_labels + ["unknown"]:
        c = counts.get(label, 0)
        pct = (c / n * 100) if n else 0
        print(f" {label:<14}: {c:>4} ({pct:.1f}%)")
    print(" ──────────────────────────────────────────")
    if n:
        print(f" Total : {total:.0f}ms ({total/1000:.2f}s)")
        print(f" Avg/image : {total/n:.1f}ms")
        print(f" Throughput : {n/(total/1000):.1f} img/s")
    print(f" CSV : {csv_path}")
    print(f" Annotated imgs : {output_dir}")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()