Spaces:
Running
Running
| """ Pipeline: D-FINE (person/car only) → group detections → crop regions → | |
| find all bboxes inside each crop → Jina-CLIP-v2 embeddings and classification. | |
| Outputs jina_crops folder and results CSV. | |
| """ | |
| import argparse | |
| import csv | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, DFineForObjectDetection | |
| # Jina-CLIP-v2 few-shot (same refs + classify as jina_fewshot.py) | |
| from jina_fewshot import ( | |
| IMAGE_EXTS, | |
| TRUNCATE_DIM, | |
| JinaCLIPv2Encoder, | |
| build_refs, | |
| classify as jina_classify, | |
| draw_bboxes_on_image, | |
| draw_label_on_image, | |
| ) | |
# Reference classes intended to get bboxes on group crops and to appear in the
# known-object gallery.
# NOTE(review): this set is not referenced anywhere else in the visible part of
# this file — confirm it is still consumed (e.g. by the Gradio app).
KNOWN_DISPLAY_CLASSES = {"gun", "knife", "cigarette", "phone"}
# Default minimum classifier confidence for a crop to be shown in the known-object
# gallery; run_single_image() falls back to it when min_display_conf is None.
MIN_DISPLAY_CONF = 0.7
# Person/car detections must have confidence strictly greater than this to be
# used for grouping (compared with ">" in main() and run_single_image()).
PERSON_CAR_MIN_CONF = 0.9
| # ----------------------------------------------------------------------------- | |
| # Detection + grouping (from reference_detection.py) | |
| # ----------------------------------------------------------------------------- | |
def get_box_dist(box1, box2):
    """Return the Euclidean distance between the centers of two boxes.

    Each box is given as [x1, y1, x2, y2].
    """
    centers = [
        np.array([(b[0] + b[2]) / 2, (b[1] + b[3]) / 2])
        for b in (box1, box2)
    ]
    return np.linalg.norm(centers[0] - centers[1])
def group_detections(detections, threshold):
    """Merge detections whose box centers are transitively closer than ``threshold``.

    detections: list of dicts with at least "box" ([x1, y1, x2, y2]), "conf"
        and "cls"; "label" is optional.
    threshold: two detections are linked when the distance between their box
        centers (see get_box_dist) is strictly below this value.

    Returns one dict per connected component with:
        "box":   the union (enclosing) box of all members,
        "conf" / "cls": taken from the member with the highest confidence,
        "label": that member's "label", or str(cls) when it has none.
    An empty input yields an empty list.
    """
    if not detections:
        return []

    count = len(detections)
    all_boxes = [det["box"] for det in detections]

    # Undirected proximity graph: an edge links every pair of close detections.
    neighbors = {idx: [] for idx in range(count)}
    for a in range(count):
        for b in range(a + 1, count):
            if get_box_dist(all_boxes[a], all_boxes[b]) < threshold:
                neighbors[a].append(b)
                neighbors[b].append(a)

    merged = []
    seen = [False] * count
    for start in range(count):
        if seen[start]:
            continue

        # Iterative depth-first walk collecting one connected component.
        seen[start] = True
        pending = [start]
        members = []
        while pending:
            node = pending.pop()
            members.append(detections[node])
            for other in neighbors[node]:
                if not seen[other]:
                    seen[other] = True
                    pending.append(other)

        left = min(det["box"][0] for det in members)
        top = min(det["box"][1] for det in members)
        right = max(det["box"][2] for det in members)
        bottom = max(det["box"][3] for det in members)
        strongest = max(members, key=lambda det: det["conf"])
        merged.append({
            "box": [left, top, right, bottom],
            "conf": strongest["conf"],
            "cls": strongest["cls"],
            "label": strongest.get("label", str(strongest["cls"])),
        })
    return merged
def box_center_inside(box, crop_box):
    """Tell whether the center of ``box`` lies within ``crop_box`` (edges inclusive).

    Both arguments are [x1, y1, x2, y2].
    """
    center_x = (box[0] + box[2]) / 2
    center_y = (box[1] + box[3]) / 2
    within_x = crop_box[0] <= center_x <= crop_box[2]
    within_y = crop_box[1] <= center_y <= crop_box[3]
    return within_x and within_y
def expand_box_by_margin(box, margin_ratio, img_w, img_h):
    """Grow ``box`` on every side by a fraction of its own size, clamped to the image.

    box: [x1, y1, x2, y2].
    margin_ratio: fraction of the box width/height added on each side
        (0.1 adds 10% of the width left and right, 10% of the height top and bottom).
    img_w, img_h: image bounds used for clamping.

    Returns a new [x1, y1, x2, y2] list. A box with non-positive width or
    height is returned unchanged (the very same object).
    """
    left, top, right, bottom = box
    width = right - left
    height = bottom - top
    if width <= 0 or height <= 0:
        return box

    pad_x = width * margin_ratio
    pad_y = height * margin_ratio
    return [
        max(0, left - pad_x),
        max(0, top - pad_y),
        min(img_w, right + pad_x),
        min(img_h, bottom + pad_y),
    ]
# Relative margin (10% of the box size per side) applied to each person/car group
# box via expand_box_by_margin() before looking for objects inside it / cropping it.
PERSON_CAR_GROUP_MARGIN = 0.10
# Minimum side length (px) that run_single_image() tries to reach when padding an
# object box inside a person/car crop before sending it to the classifier. The
# result can still be smaller when the group crop itself is smaller.
MIN_OBJECT_CROP_SIDE = 112
def squarify_crop_box(bx1, by1, bx2, by2, img_w, img_h):
    """Widen the shorter side of a box so it approaches a square, clamped to the image.

    The shorter dimension is extended symmetrically around the box center by
    half the width/height difference on each side: a tall box gets wider, a
    wide (or already square) box gets taller. Because each side is clamped to
    [0, img_w] / [0, img_h] independently, the result may stay non-square near
    the image border.

    Returns (bx1, by1, bx2, by2) truncated to integers. If the input box is
    degenerate, or the result collapses after truncation, the integer-truncated
    input box is returned instead.
    """
    fallback = tuple(int(v) for v in (bx1, by1, bx2, by2))
    width, height = bx2 - bx1, by2 - by1
    if width <= 0 or height <= 0:
        return fallback

    half_gap = abs(height - width) / 2.0
    if height > width:
        bx1, bx2 = max(0, bx1 - half_gap), min(img_w, bx2 + half_gap)
    else:
        by1, by2 = max(0, by1 - half_gap), min(img_h, by2 + half_gap)

    squared = (int(bx1), int(by1), int(bx2), int(by2))
    if squared[2] <= squared[0] or squared[3] <= squared[1]:
        return fallback
    return squared
def box_iou(box1, box2):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes.

    Returns 0.0 when the union area is not positive (e.g. two empty boxes).
    """
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h

    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    total = area1 + area2 - intersection
    if total > 0:
        return intersection / total
    return 0.0
def deduplicate_by_iou(detections, iou_threshold=0.9):
    """Greedily drop detections that overlap an already-kept, more confident one.

    Detections (dicts with "box" and "conf") are visited from highest to lowest
    confidence; one is kept only if its IoU with every previously kept box is
    below ``iou_threshold``. Returns the kept detections in that visiting order.
    """
    survivors = []
    if not detections:
        return survivors

    ranked = sorted(detections, key=lambda det: det["conf"], reverse=True)
    for candidate in ranked:
        for winner in survivors:
            if box_iou(candidate["box"], winner["box"]) >= iou_threshold:
                break  # duplicate of a stronger detection
        else:
            survivors.append(candidate)
    return survivors
def parse_args():
    """Define the batch pipeline's command-line options and parse ``sys.argv``.

    ``--refs`` and ``--input`` are required; every other option has a default.
    ``--dfine-model`` is restricted to the keys of DFINE_MODEL_IDS, which is
    looked up when this function is called.

    NOTE(review): ``--padding`` is parsed but not read by main() in the visible
    code — confirm whether it is still needed.
    """
    parser = argparse.ArgumentParser(
        description="D-FINE (person/car) → group → Jina-CLIP-v2 on crops inside groups",
    )
    add = parser.add_argument

    # Locations
    add("--refs", required=True, help="Reference images folder for Jina (e.g. refs/)")
    add("--input", required=True, help="Full-frame images folder")
    add("--output", default="pipeline_results", help="Output folder (CSV, etc.)")

    # Detection / grouping / cropping
    add("--det-threshold", type=float, default=0.13, help="D-FINE score threshold")
    add("--group-dist", type=float, default=None, help="Group distance (default: 0.1 * max(H,W))")
    add("--min-side", type=int, default=40, help="Min side of expanded bbox in px (skip smaller)")
    add("--crop-dedup-iou", type=float, default=0.35, help="Min IoU to treat two crops as same object (keep larger)")
    add("--no-squarify", action="store_true", help="Skip squarify; use expanded bbox only (tighter crops, often better recognition)")
    add("--padding", type=float, default=0.2, help="Crop padding around group box (0.2 = 20%%)")

    # Classification
    add("--conf-threshold", type=float, default=0.75, help="Jina accept confidence")
    add("--gap-threshold", type=float, default=0.05, help="Jina accept gap")
    add("--text-weight", type=float, default=0.3)

    # Runtime
    add("--max-images", type=int, default=None)
    add("--device", default=None)
    add("--dfine-model", choices=list(DFINE_MODEL_IDS.keys()), default="large-obj365", help="D-FINE model")

    return parser.parse_args()
def get_person_car_label_ids(model):
    """Collect the integer label IDs of person-like and car-like classes.

    Reads ``model.config.id2label`` (treated as empty when missing or None).
    A label matches when its lower-cased name contains "person" or is exactly
    "car" or "suv". Entries whose key cannot be converted to int are skipped.

    Returns a set of ints.
    """
    mapping = getattr(model.config, "id2label", None) or {}
    matched = set()
    for raw_id, raw_name in mapping.items():
        try:
            label_id = int(raw_id)
        except (ValueError, TypeError):
            continue
        lowered = (raw_name or "").lower()
        if lowered in ("car", "suv") or "person" in lowered:
            matched.add(label_id)
    return matched
def run_dfine(image, processor, model, device, score_threshold):
    """Run the D-FINE detector on one image and return every detection above the threshold.

    image: a PIL image, or anything ``Image.fromarray`` accepts; converted to RGB.
    processor / model: the image processor and detection model; the processor
        must provide ``post_process_object_detection``.
    device: device the inputs are moved to before the forward pass.
    score_threshold: passed as ``threshold`` to the post-processing step.

    Returns a list of dicts with keys:
        "box":   four floats taken from the post-processed boxes, rescaled with
                 target size (height, width) of the input image
                 (presumably absolute-pixel [x1, y1, x2, y2] — confirm against
                 the processor documentation),
        "conf":  detection score as float,
        "cls":   integer label id,
        "label": name from ``model.config.id2label`` or the id as a string.
    """
    rgb = image if isinstance(image, Image.Image) else Image.fromarray(image)
    rgb = rgb.convert("RGB")
    width, height = rgb.size

    batch = processor(images=rgb, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    # Post-processing expects (height, width) on the same device as the logits.
    sizes = torch.tensor([[height, width]], device=device).to(outputs["logits"].device)
    per_image = processor.post_process_object_detection(
        outputs,
        target_sizes=sizes,
        threshold=score_threshold,
    )

    names = getattr(model.config, "id2label", {}) or {}
    found = []
    for entry in per_image:
        triples = zip(entry["scores"], entry["labels"], entry["boxes"])
        for score, label_id, box in triples:
            class_id = int(label_id.item())
            found.append({
                "box": [float(v) for v in box.cpu().tolist()],
                "conf": float(score.item()),
                "cls": class_id,
                "label": names.get(class_id, str(class_id)),
            })
    return found
def _collect_crop_candidates(detections, groups, person_car_ids, img_w, img_h, min_side):
    """Collect padded object boxes whose centers fall inside the person/car groups.

    For every group, the group box is expanded by PERSON_CAR_GROUP_MARGIN; all
    non-person/car detections centered inside it are de-duplicated (IoU >= 0.9)
    and padded by 60% (shorter side < 24 px) or 30% of their own size, clamped
    to the image. Boxes whose shorter padded side is below ``min_side`` are
    dropped.

    Returns a list of tuples
    ``(expanded_box, detection, group_idx, crop_idx, gx1, gy1, gx2, gy2)`` where
    ``expanded_box`` is an integer [x1, y1, x2, y2] in full-image coordinates
    and ``gx1..gy2`` is the (unexpanded) group box.
    """
    candidates = []
    for gidx, grp in enumerate(groups):
        x1, y1, x2, y2 = grp["box"]
        search_box = expand_box_by_margin(
            [x1, y1, x2, y2], PERSON_CAR_GROUP_MARGIN, img_w, img_h
        )
        inside = [
            d for d in detections
            if box_center_inside(d["box"], search_box) and d["cls"] not in person_car_ids
        ]
        inside = deduplicate_by_iou(inside, iou_threshold=0.9)
        for crop_idx, d in enumerate(inside):
            bx1, by1, bx2, by2 = [float(v) for v in d["box"]]
            obj_w, obj_h = bx2 - bx1, by2 - by1
            if obj_w <= 0 or obj_h <= 0:
                continue
            # Small objects (min side < 24 px): expand by 60%; larger: 30%
            pad_ratio = 0.6 if min(obj_w, obj_h) < 24 else 0.3
            pad_x = obj_w * pad_ratio
            pad_y = obj_h * pad_ratio
            bx1 = max(0, int(bx1 - pad_x))
            by1 = max(0, int(by1 - pad_y))
            bx2 = min(img_w, int(bx2 + pad_x))
            by2 = min(img_h, int(by2 + pad_y))
            if bx2 <= bx1 or by2 <= by1:
                continue
            if min(bx2 - bx1, by2 - by1) < min_side:
                continue
            candidates.append(([bx1, by1, bx2, by2], d, gidx, crop_idx, x1, y1, x2, y2))
    return candidates


def _dedup_crop_candidates(candidates, iou_threshold):
    """Keep the largest crop among candidates that appear to show the same object.

    Candidates (tuples whose first item is a box) are visited from largest to
    smallest area. Two boxes count as the same object when their IoU is at
    least ``iou_threshold`` or when either box's center lies inside the other.
    """
    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    kept = []
    for cand in sorted(candidates, key=lambda c: -area(c[0])):
        box = cand[0]
        duplicate = any(
            box_iou(box, k[0]) >= iou_threshold
            or box_center_inside(box, k[0])
            or box_center_inside(k[0], box)
            for k in kept
        )
        if not duplicate:
            kept.append(cand)
    return kept


def main():
    """Batch entry point: detect, group, crop and classify every image in a folder.

    Steps per image:
      1. run D-FINE on the full frame,
      2. group confident person/car detections by proximity (top 10 groups),
      3. collect the other detections centered inside each group and pad them,
      4. de-duplicate the padded boxes (larger box wins),
      5. optionally squarify each box, classify the crop with Jina-CLIP-v2,
         save the annotated crop and append a row to the CSV.

    Outputs ``<output>/jina_crops/`` and ``<output>/results.csv``.
    Raises SystemExit when the refs/input folders are missing or no image is found.
    """
    args = parse_args()
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    input_dir = Path(args.input)
    output_dir = Path(args.output)
    refs_dir = Path(args.refs)
    output_dir.mkdir(parents=True, exist_ok=True)
    if not refs_dir.is_dir():
        raise SystemExit(f"Refs folder not found: {refs_dir}")
    if not input_dir.is_dir():
        raise SystemExit(f"Input folder not found: {input_dir}")
    paths = sorted(
        p for p in input_dir.iterdir()
        if p.suffix.lower() in IMAGE_EXTS
    )
    if args.max_images is not None:
        paths = paths[: args.max_images]
    if not paths:
        raise SystemExit(f"No images in {input_dir}")

    # Load D-FINE
    dfine_model_id = DFINE_MODEL_IDS.get(args.dfine_model, DFINE_MODEL_IDS["large-obj365"])
    print(f"[*] Loading D-FINE ({dfine_model_id})...")
    t0 = time.perf_counter()
    image_processor = AutoImageProcessor.from_pretrained(dfine_model_id)
    dfine_model = DFineForObjectDetection.from_pretrained(dfine_model_id)
    dfine_model = dfine_model.to(device).eval()
    person_car_ids = get_person_car_label_ids(dfine_model)
    print(f" Person/car label IDs: {person_car_ids} ({time.perf_counter()-t0:.1f}s)")

    # Load Jina-CLIP-v2 + build refs
    print("[*] Loading Jina-CLIP-v2 and building refs...")
    t0 = time.perf_counter()
    jina_encoder = JinaCLIPv2Encoder(device)
    ref_labels, ref_embs = build_refs(
        jina_encoder,
        refs_dir,
        TRUNCATE_DIM,
        args.text_weight,
        batch_size=16,
    )
    print(f" Jina refs: {ref_labels} ({time.perf_counter()-t0:.1f}s)\n")

    jina_crops_dir = output_dir / "jina_crops"
    jina_crops_dir.mkdir(parents=True, exist_ok=True)

    csv_path = output_dir / "results.csv"
    # The context manager guarantees the CSV is flushed and closed even if an
    # image fails half-way through the run.
    with open(csv_path, "w", newline="") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow([
            "image",
            "crop_filename",
            "group_idx",
            "crop_x1",
            "crop_y1",
            "crop_x2",
            "crop_y2",
            "bbox_x1",
            "bbox_y1",
            "bbox_x2",
            "bbox_y2",
            "dfine_label",
            "dfine_conf",
            "jina_prediction",
            "jina_confidence",
            "jina_status",
        ])

        for img_path in paths:
            # convert() returns an independent image, so the file handle can be
            # released immediately.
            with Image.open(img_path) as raw_image:
                pil = raw_image.convert("RGB")
            img_w, img_h = pil.size
            if args.group_dist is not None:
                group_dist = args.group_dist
            else:
                group_dist = 0.1 * max(img_h, img_w)

            # 1) D-FINE: detect everything, keep all bboxes for the image
            detections = run_dfine(
                pil,
                image_processor,
                dfine_model,
                device,
                args.det_threshold,
            )
            person_car = [
                d for d in detections
                if d["cls"] in person_car_ids and d["conf"] > PERSON_CAR_MIN_CONF
            ]
            if not person_car:
                continue

            # 2) Group person/car detections and keep the 10 most confident groups
            grouped = group_detections(person_car, group_dist)
            grouped.sort(key=lambda g: g["conf"], reverse=True)
            top_groups = grouped[:10]

            # 3) Candidate crops: padded boxes of objects inside the groups
            candidates = _collect_crop_candidates(
                detections, top_groups, person_car_ids, img_w, img_h, args.min_side
            )

            # 4) Dedup on the padded boxes (before squarify), keeping the larger box
            kept = _dedup_crop_candidates(candidates, args.crop_dedup_iou)

            # 5) Optionally squarify, then classify each kept crop and record it
            for i, (expanded_box, d, gidx, _crop_idx, x1, y1, x2, y2) in enumerate(kept):
                if args.no_squarify:
                    bx1, by1, bx2, by2 = expanded_box
                else:
                    bx1, by1, bx2, by2 = squarify_crop_box(*expanded_box, img_w, img_h)
                crop_pil = pil.crop((bx1, by1, bx2, by2))
                crop_name = f"{img_path.stem}_g{gidx}_{i}_{bx1}_{by1}_{bx2}_{by2}{img_path.suffix}"

                q_jina = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
                result_jina = jina_classify(
                    q_jina,
                    ref_labels,
                    ref_embs,
                    args.conf_threshold,
                    args.gap_threshold,
                )
                # The saved image shows the reference label only when the
                # prediction is one of the reference classes.
                if result_jina["prediction"] in ref_labels:
                    label_jina = result_jina["prediction"]
                    conf_jina = result_jina["confidence"]
                else:
                    label_jina = f"unnamed (dfine: {d['label']})"
                    conf_jina = 0.0
                ann_jina = draw_label_on_image(crop_pil, label_jina, conf_jina)
                ann_jina.save(jina_crops_dir / crop_name)

                # crop_* columns hold the group box, bbox_* the final crop box.
                writer.writerow([
                    img_path.name,
                    crop_name,
                    gidx,
                    x1,
                    y1,
                    x2,
                    y2,
                    bx1,
                    by1,
                    bx2,
                    by2,
                    d["label"],
                    f"{d['conf']:.4f}",
                    result_jina["prediction"],
                    f"{result_jina['confidence']:.4f}",
                    result_jina["status"],
                ])

    print(f"[*] Wrote {csv_path}")
    print(f"[*] Jina crops: {jina_crops_dir}")
| # ----------------------------------------------------------------------------- | |
| # Single-image runner for Gradio app: D-FINE first, then Jina | |
| # ----------------------------------------------------------------------------- | |
# Cache of the loaded detector used by run_single_image():
# (DFINE_MODEL_IDS key, image_processor, dfine_model, person_car_ids).
# It is rebuilt when a different model key is requested.
_APP_DFINE = None
# Cache of loaded classifiers used by run_single_image():
# {classifier_name: (classifier_instance, cache_key)} where cache_key is
# str(labels) when labels are given, otherwise str(refs_dir).
_APP_CLASSIFIERS = {}
# Short model key -> Hugging Face model id for the supported D-FINE checkpoints.
# Defined after parse_args()/main(), which is fine because both only look it up
# when they are called.
DFINE_MODEL_IDS = {
    # obj365
    "small-obj365": "ustc-community/dfine-small-obj365",
    "medium-obj365": "ustc-community/dfine-medium-obj365",
    "large-obj365": "ustc-community/dfine-large-obj365",
    # coco
    "small-coco": "ustc-community/dfine-small-coco",
    "medium-coco": "ustc-community/dfine-medium-coco",
    "large-coco": "ustc-community/dfine-large-coco",
    # obj2coco
    "small-obj2coco": "ustc-community/dfine-small-obj2coco",
    "medium-obj2coco": "ustc-community/dfine-medium-obj2coco",
    "large-obj2coco": "ustc-community/dfine-large-obj2coco-e25",
}
# Classifier names accepted by _load_classifier().
CLASSIFIER_CHOICES = ["jina", "siglip-224", "siglip-256", "siglip-384", "siglip2_onnx"]
def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
    """Create and initialise the classifier selected by ``classifier_name``.

    classifier_name: "jina", a "siglip-*" key known to ``SIGLIP_MODELS``, or
        "siglip2_onnx".
    device: passed to the classifier/encoder constructor.
    refs_dir: reference folder; converted to ``Path`` when truthy.
    labels: label list forwarded to the SigLIP-style ``build_refs``
        (not used by the Jina branch).

    Returns either the tuple ``("jina_wrapped", encoder, ref_labels, ref_embs)``
    for Jina (refs built with a fixed text weight of 0.3), or a SigLIP-style
    classifier instance whose ``build_refs`` has already been called.

    Raises ValueError for an unknown classifier name or unknown SigLIP model key.
    """
    refs_path = Path(refs_dir) if refs_dir else refs_dir

    if classifier_name == "jina":
        encoder = JinaCLIPv2Encoder(device)
        ref_labels, ref_embs = build_refs(encoder, refs_path, TRUNCATE_DIM, 0.3, batch_size=16)
        return ("jina_wrapped", encoder, ref_labels, ref_embs)

    # The SigLIP backends are imported lazily so they are only required when selected.
    if classifier_name.startswith("siglip-"):
        from siglip_zeroshot import SigLIPClassifier, SIGLIP_MODELS
        if classifier_name not in SIGLIP_MODELS:
            raise ValueError(f"Unknown SigLIP model: {classifier_name}. Choose from {list(SIGLIP_MODELS.keys())}")
        clf = SigLIPClassifier(device, model_key=classifier_name)
    elif classifier_name == "siglip2_onnx":
        from siglip2_onnx_zeroshot import SigLIP2ONNXClassifier
        clf = SigLIP2ONNXClassifier(device)
    else:
        raise ValueError(f"Unknown classifier: {classifier_name}. Choose from {CLASSIFIER_CHOICES}")

    clf.build_refs(refs_dir=refs_path, labels=labels)
    return clf
def _classify_crop(classifier, crop, conf_threshold, gap_threshold):
    """Classify one crop with whichever classifier kind ``_load_classifier`` returned.

    A tuple tagged "jina_wrapped" is unpacked into (encoder, ref_labels,
    ref_embs): the crop is embedded and matched against the references with
    ``jina_classify``. Any other object is expected to expose
    ``classify_crop(crop, conf_threshold, gap_threshold)``, whose result is
    returned as is.
    """
    is_jina = isinstance(classifier, tuple) and classifier[0] == "jina_wrapped"
    if not is_jina:
        return classifier.classify_crop(crop, conf_threshold, gap_threshold)

    _tag, encoder, ref_labels, ref_embs = classifier
    query = encoder.encode_images([crop], TRUNCATE_DIM)
    return jina_classify(query, ref_labels, ref_embs, conf_threshold, gap_threshold)
def _single_image_box_area(box):
    """Area of an [x1, y1, x2, y2] box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def _pad_object_box_in_crop(box, crop_w, crop_h):
    """Pad an object box (crop-local coordinates) for classification.

    The box is first padded by 60% of its size (shorter side < 24 px) or 30%
    otherwise, then grown towards MIN_OBJECT_CROP_SIDE on each side that is
    still too short. Every step is clamped to the crop bounds, so the result
    may remain smaller than MIN_OBJECT_CROP_SIDE in a small crop.

    Returns integer [x1, y1, x2, y2], or None when the box is degenerate.
    """
    bx1, by1, bx2, by2 = [float(v) for v in box]
    obj_w, obj_h = bx2 - bx1, by2 - by1
    if obj_w <= 0 or obj_h <= 0:
        return None

    pad_ratio = 0.6 if min(obj_w, obj_h) < 24 else 0.3
    pad_x = obj_w * pad_ratio
    pad_y = obj_h * pad_ratio
    bx1 = max(0.0, bx1 - pad_x)
    by1 = max(0.0, by1 - pad_y)
    bx2 = min(crop_w, bx2 + pad_x)
    by2 = min(crop_h, by2 + pad_y)
    if bx2 <= bx1 or by2 <= by1:
        return None

    w, h = bx2 - bx1, by2 - by1
    if min(w, h) < MIN_OBJECT_CROP_SIDE:
        # First grow the shorter side by the full shortfall...
        half = (MIN_OBJECT_CROP_SIDE - min(w, h)) / 2.0
        if w < h:
            bx1 = max(0, bx1 - half)
            bx2 = min(crop_w, bx2 + half)
        else:
            by1 = max(0, by1 - half)
            by2 = min(crop_h, by2 + half)
        # ...then top up whichever side is still short (clamping may have
        # eaten part of the growth, or both sides were short to begin with).
        w, h = bx2 - bx1, by2 - by1
        if w < MIN_OBJECT_CROP_SIDE:
            add = (MIN_OBJECT_CROP_SIDE - w) / 2
            bx1 = max(0, bx1 - add)
            bx2 = min(crop_w, bx2 + add)
        if h < MIN_OBJECT_CROP_SIDE:
            add = (MIN_OBJECT_CROP_SIDE - h) / 2
            by1 = max(0, by1 - add)
            by2 = min(crop_h, by2 + add)

    bx1, by1, bx2, by2 = int(bx1), int(by1), int(bx2), int(by2)
    if bx2 <= bx1 or by2 <= by1:
        return None
    return [bx1, by1, bx2, by2]


def run_single_image(
    pil_image,
    refs_dir=None,
    device=None,
    dfine_model="large",
    det_threshold=0.3,
    conf_threshold=0.75,
    gap_threshold=0.05,
    min_side=40,
    crop_dedup_iou=0.35,
    squarify=True,
    min_display_conf=None,
    classifier="siglip-256",
    labels=None,
):
    """
    Run D-FINE on one image, then classify small-object crops.

    pil_image: PIL image, or an array accepted by ``Image.fromarray``.
    refs_dir: path to refs folder (str or Path), optional if labels provided.
    device: torch device string; defaults to "cuda" when available, else "cpu".
    dfine_model: key from DFINE_MODEL_IDS; unknown keys (including the default
        "large") fall back to "large-obj365".
    det_threshold: D-FINE score threshold.
    conf_threshold, gap_threshold: forwarded to the classifier.
    min_side: accepted for interface compatibility; this function does not read
        it (object crops are sized with MIN_OBJECT_CROP_SIDE instead).
    crop_dedup_iou: padded crops with at least this IoU (or with one center
        inside the other) count as the same object; the larger is kept.
    squarify: when True, each kept crop box is passed through squarify_crop_box.
    min_display_conf: minimum confidence for the known-object gallery;
        defaults to MIN_DISPLAY_CONF.
    classifier: name understood by _load_classifier.
    labels: list of class label strings for zero-shot classifiers.

    The D-FINE model and the classifier are cached in the module globals
    _APP_DFINE and _APP_CLASSIFIERS between calls.

    Returns (group_crop_images, known_crop_composites, status_message):
    group crops with classified boxes drawn on them, annotated object crops
    that are not "unknown" and reach min_display_conf, and a text log.
    """
    global _APP_DFINE

    if min_display_conf is None:
        min_display_conf = MIN_DISPLAY_CONF
    if refs_dir:
        refs_dir = Path(refs_dir)
        if not refs_dir.is_dir():
            return [], [], f"Refs folder not found: {refs_dir}"
    if not refs_dir and not labels:
        return [], [], "Provide either refs_dir or labels."

    dfine_model = (dfine_model or "large-obj365").strip().lower()
    if dfine_model not in DFINE_MODEL_IDS:
        dfine_model = "large-obj365"
    model_id = DFINE_MODEL_IDS[dfine_model]
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[*] Device: {device}")

    if isinstance(pil_image, Image.Image):
        pil = pil_image.convert("RGB")
    else:
        pil = Image.fromarray(pil_image).convert("RGB")
    img_w, img_h = pil.size
    group_dist = 0.1 * max(img_h, img_w)

    # Load D-FINE (reload if user switched model)
    if _APP_DFINE is None or _APP_DFINE[0] != dfine_model:
        print(f"[*] Loading D-FINE ({model_id})...")
        image_processor = AutoImageProcessor.from_pretrained(model_id)
        dfine_model_obj = DFineForObjectDetection.from_pretrained(model_id)
        dfine_model_obj = dfine_model_obj.to(device).eval()
        person_car_ids = get_person_car_label_ids(dfine_model_obj)
        _APP_DFINE = (dfine_model, image_processor, dfine_model_obj, person_car_ids)
    _model_key, image_processor, dfine_model_obj, person_car_ids = _APP_DFINE

    # Apply the caller's D-FINE detection threshold to the chosen model
    detections = run_dfine(pil, image_processor, dfine_model_obj, device, det_threshold)
    person_car = [
        d for d in detections
        if d["cls"] in person_car_ids and d["conf"] > PERSON_CAR_MIN_CONF
    ]
    if not person_car:
        # Report the threshold actually applied instead of a hard-coded number.
        return [], [], (
            f"No person/car detected (or none with confidence > {PERSON_CAR_MIN_CONF}). "
            "No small-object crops."
        )
    grouped = group_detections(person_car, group_dist)
    grouped.sort(key=lambda g: g["conf"], reverse=True)
    top_groups = grouped[:10]

    # Load classifier (cached per classifier name; rebuilt when labels/refs change)
    cache_key = str(labels) if labels else str(refs_dir)
    clf_key = classifier
    if clf_key not in _APP_CLASSIFIERS or _APP_CLASSIFIERS[clf_key][1] != cache_key:
        clf_instance = _load_classifier(classifier, device, refs_dir=refs_dir, labels=labels)
        _APP_CLASSIFIERS[clf_key] = (clf_instance, cache_key)
    clf_instance = _APP_CLASSIFIERS[clf_key][0]

    results_per_crop = []
    group_crop_images = []
    classification_log = []

    # Non-person/car detections from the full-frame pass, reused per group
    other_detections = [d for d in detections if d["cls"] not in person_car_ids]

    # For each person/car group: crop (with margin), reuse the full-frame
    # detections that fall inside, then classify them.
    for gidx, grp in enumerate(top_groups):
        group_box = [grp["box"][0], grp["box"][1], grp["box"][2], grp["box"][3]]
        crop_box = expand_box_by_margin(group_box, PERSON_CAR_GROUP_MARGIN, img_w, img_h)
        gx1 = max(0, int(crop_box[0]))
        gy1 = max(0, int(crop_box[1]))
        gx2 = min(img_w, int(crop_box[2]))
        gy2 = min(img_h, int(crop_box[3]))
        if gx2 <= gx1 or gy2 <= gy1:
            continue
        crop_pil = pil.crop((gx1, gy1, gx2, gy2)).copy().convert("RGB")
        crop_w, crop_h = crop_pil.size

        # Keep detections centered inside this crop and remap them to crop-local coords
        inside = []
        for d in other_detections:
            if not box_center_inside(d["box"], [gx1, gy1, gx2, gy2]):
                continue
            fx1, fy1, fx2, fy2 = d["box"]
            remapped = dict(d)
            remapped["box"] = [
                max(0, fx1 - gx1),
                max(0, fy1 - gy1),
                min(crop_w, fx2 - gx1),
                min(crop_h, fy2 - gy1),
            ]
            inside.append(remapped)
        inside = deduplicate_by_iou(inside, iou_threshold=0.9)

        candidates = []
        for d in inside:
            padded = _pad_object_box_in_crop(d["box"], crop_w, crop_h)
            if padded is not None:
                candidates.append((padded, d, gidx))

        # Dedup padded boxes, largest first
        candidates.sort(key=lambda c: -_single_image_box_area(c[0]))
        kept = []
        for c in candidates:
            expanded_box = c[0]
            if not any(
                box_iou(expanded_box, k[0]) >= crop_dedup_iou
                or box_center_inside(expanded_box, k[0])
                or box_center_inside(k[0], expanded_box)
                for k in kept
            ):
                kept.append(c)

        boxes_to_draw = []
        for (bx1, by1, bx2, by2), d, _ in kept:
            if squarify:
                bx1, by1, bx2, by2 = squarify_crop_box(bx1, by1, bx2, by2, crop_w, crop_h)
            small_crop = crop_pil.crop((bx1, by1, bx2, by2))
            result = _classify_crop(clf_instance, small_crop, conf_threshold, gap_threshold)
            raw_pred = result["prediction"]
            pred = raw_pred if raw_pred != "unknown" else f"unknown ({d['label']})"
            conf = result["confidence"]
            results_per_crop.append((gidx, (bx1, by1, bx2, by2), small_crop, pred, conf))
            boxes_to_draw.append((bx1, by1, bx2, by2, pred, conf))

            # Build per-crop log line
            sims_str = ", ".join(f"{k}: {v:.4f}" for k, v in result.get("all_sims", {}).items())
            classification_log.append(
                f"[group {gidx}] dfine: {d['label']} ({d['conf']:.3f}) → "
                f"{pred} (conf={conf:.4f}, gap={result['gap']:.4f}, 2nd={result.get('second_best','?')}) "
                f"| {result['status']} | {sims_str}"
            )

        # Draw this group's classified boxes on its crop (already in crop coords)
        if boxes_to_draw:
            crop_pil_drawn = draw_bboxes_on_image(crop_pil.copy(), boxes_to_draw)
        else:
            crop_pil_drawn = crop_pil
        group_crop_images.append(np.array(crop_pil_drawn))

    log_text = f"Classifier: {classifier} | {len(results_per_crop)} crops classified\n"
    log_text += "\n".join(classification_log) if classification_log else "(no crops)"
    if not results_per_crop:
        return group_crop_images, [], log_text + "\nNo small-object crops: no object detections (gun/phone/etc.) found inside person/car groups, or all were below min size."

    # Build known-only gallery: only objects with conf >= min_display_conf
    known_crop_composites = []
    for (_gidx, _box, object_crop, pred, conf) in results_per_crop:
        if pred.startswith("unknown") or conf < min_display_conf:
            continue
        composite = draw_label_on_image(object_crop, pred, conf)
        known_crop_composites.append(np.array(composite))
    return group_crop_images, known_crop_composites, log_text
| if __name__ == "__main__": | |
| main() |