small_object_detection / dfine_jina_pipeline.py
orik-ss's picture
Removed second classification call on crops, DFINE runs once
5ad3720
Raw
History Blame Contribute Delete
28.9 kB
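"""Pipeline: D-FINE detection → group confident person/car detections → crop each
group region → take the other D-FINE bboxes whose centers fall inside each group →
classify those object crops (Jina-CLIP-v2 in the batch CLI; Jina or SigLIP via
run_single_image for the Gradio app).
The batch CLI outputs a jina_crops folder of annotated crops and a results CSV.
"""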
import argparse
import csv
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, DFineForObjectDetection
# Jina-CLIP-v2 few-shot (same refs + classify as jina_fewshot.py)
from jina_fewshot import (
IMAGE_EXTS,
TRUNCATE_DIM,
JinaCLIPv2Encoder,
build_refs,
classify as jina_classify,
draw_bboxes_on_image,
draw_label_on_image,
)
# Ref classes intended to get bboxes on group crops / appear in the known-object gallery.
# NOTE(review): not referenced by any code in this module — the gallery built in
# run_single_image filters only on the "unknown" prefix and on MIN_DISPLAY_CONF.
KNOWN_DISPLAY_CLASSES = {"gun", "knife", "cigarette", "phone"}
# Default minimum classifier confidence for an object crop to enter the
# known-object gallery returned by run_single_image (group crops are not filtered).
MIN_DISPLAY_CONF = 0.7
# Person/car detections must have confidence strictly greater than this to be
# used for grouping (applied in both main() and run_single_image()).
PERSON_CAR_MIN_CONF = 0.9
# -----------------------------------------------------------------------------
# Detection + grouping (from reference_detection.py)
# -----------------------------------------------------------------------------
def get_box_dist(box1, box2):
    """Return the Euclidean distance between the centers of two boxes.

    Both boxes are [x1, y1, x2, y2]; the result is a NumPy float.
    """
    def center(box):
        return np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])

    return np.linalg.norm(center(box1) - center(box2))
def group_detections(detections, threshold):
    """
    Cluster detections whose box centers are closer than ``threshold``.

    Two detections are linked when their center distance is strictly below
    ``threshold``; groups are the connected components of that link graph, so
    grouping is transitive (A~B and B~C puts A, B and C together).

    detections: list of {"box": [x1, y1, x2, y2], "conf", "cls", optional "label"}.
    Returns one dict per group, ordered by each group's lowest input index:
    {"box": union of member boxes, "conf"/"cls": from the most confident member,
     "label": that member's "label", or str(cls) when it has none}.
    """
    if not detections:
        return []
    total = len(detections)
    centers = [
        np.array([(d["box"][0] + d["box"][2]) / 2, (d["box"][1] + d["box"][3]) / 2])
        for d in detections
    ]
    links = [[] for _ in range(total)]
    for a in range(total):
        for b in range(a + 1, total):
            if np.linalg.norm(centers[a] - centers[b]) < threshold:
                links[a].append(b)
                links[b].append(a)

    seen = [False] * total
    merged = []
    for root in range(total):
        if seen[root]:
            continue
        # Depth-first walk (LIFO stack) over the component containing `root`.
        seen[root] = True
        members = []
        frontier = [root]
        while frontier:
            node = frontier.pop()
            members.append(detections[node])
            for other in links[node]:
                if not seen[other]:
                    seen[other] = True
                    frontier.append(other)
        best = max(members, key=lambda m: m["conf"])
        merged.append({
            "box": [
                min(m["box"][0] for m in members),
                min(m["box"][1] for m in members),
                max(m["box"][2] for m in members),
                max(m["box"][3] for m in members),
            ],
            "conf": best["conf"],
            "cls": best["cls"],
            "label": best.get("label", str(best["cls"])),
        })
    return merged
def box_center_inside(box, crop_box):
    """Return True when the center of ``box`` lies within ``crop_box`` (edges inclusive).

    Both boxes are [x1, y1, x2, y2].
    """
    mid_x = (box[0] + box[2]) / 2
    mid_y = (box[1] + box[3]) / 2
    left, top, right, bottom = crop_box[0], crop_box[1], crop_box[2], crop_box[3]
    return left <= mid_x <= right and top <= mid_y <= bottom
def expand_box_by_margin(box, margin_ratio, img_w, img_h):
    """Grow [x1, y1, x2, y2] by ``margin_ratio`` of its width/height on every side.

    e.g. margin_ratio=0.1 adds 10% of the width to each of left/right and 10% of
    the height to each of top/bottom. The result is clamped to [0, img_w] x [0, img_h].
    A degenerate box (non-positive width or height) is returned unchanged (same object).
    """
    left, top, right, bottom = box
    width = right - left
    height = bottom - top
    if width <= 0 or height <= 0:
        return box
    grow_x = width * margin_ratio
    grow_y = height * margin_ratio
    return [
        max(0, left - grow_x),
        max(0, top - grow_y),
        min(img_w, right + grow_x),
        min(img_h, bottom + grow_y),
    ]
# Fractional margin (0.10 = 10% of the group box width/height per side) added around
# each person/car group box. Detections from the single full-frame D-FINE pass whose
# centers fall inside the expanded box are treated as belonging to that group;
# run_single_image also crops the group image with this margin.
PERSON_CAR_GROUP_MARGIN = 0.10
# Minimum side (px) that run_single_image grows each object crop toward before
# classification (growth is clamped to the group crop, so a crop may stay smaller).
# main() does not use this; it skips crops whose side is below --min-side instead.
MIN_OBJECT_CROP_SIDE = 112
def squarify_crop_box(bx1, by1, bx2, by2, img_w, img_h):
    """
    Pad the shorter side of a box so it matches the longer side, centered.

    Tall boxes (height > width) grow horizontally; otherwise the box grows
    vertically (a square box is only clamped). Growth is clamped to the image,
    so near an edge the result may not be exactly square.
    Returns integer (bx1, by1, bx2, by2); the int-truncated input is returned
    for degenerate input or if clamping collapses the box.
    """
    fallback = (int(bx1), int(by1), int(bx2), int(by2))
    width = bx2 - bx1
    height = by2 - by1
    if width <= 0 or height <= 0:
        return fallback

    if height > width:
        pad = (height - width) / 2.0
        bx1, bx2 = max(0, bx1 - pad), min(img_w, bx2 + pad)
    else:
        pad = (width - height) / 2.0
        by1, by2 = max(0, by1 - pad), min(img_h, by2 + pad)

    left, top, right, bottom = (int(v) for v in (bx1, by1, bx2, by2))
    if right <= left or bottom <= top:
        return fallback
    return left, top, right, bottom
def box_iou(box1, box2):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes.

    Returns a float in [0, 1]; 0.0 when the union area is not positive.
    """
    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    overlap = overlap_w * overlap_h
    union = area(box1) + area(box2) - overlap
    if union > 0:
        return overlap / union
    return 0.0
def deduplicate_by_iou(detections, iou_threshold=0.9):
    """Greedy NMS-style dedup: keep the most confident detection of each overlapping cluster.

    Detections are visited in descending "conf" order (stable for ties); one is
    kept only if its IoU with every already-kept box is below ``iou_threshold``.
    Returns a new list in that confidence order.
    """
    survivors = []
    for det in sorted(detections, key=lambda item: -item["conf"]):
        clashes = (box_iou(det["box"], keep["box"]) >= iou_threshold for keep in survivors)
        if not any(clashes):
            survivors.append(det)
    return survivors
def parse_args():
    """Parse the batch-pipeline command line consumed by main().

    Note: --padding is accepted but main() never reads it.
    """
    parser = argparse.ArgumentParser(
        description="D-FINE (person/car) → group → Jina-CLIP-v2 on crops inside groups"
    )
    add = parser.add_argument

    # Inputs / outputs
    add("--refs", required=True, help="Reference images folder for Jina (e.g. refs/)")
    add("--input", required=True, help="Full-frame images folder")
    add("--output", default="pipeline_results", help="Output folder (CSV, etc.)")

    # Detection and grouping
    add("--det-threshold", type=float, default=0.13, help="D-FINE score threshold")
    add("--group-dist", type=float, default=None, help="Group distance (default: 0.1 * max(H,W))")

    # Object crops
    add("--min-side", type=int, default=40, help="Min side of expanded bbox in px (skip smaller)")
    add("--crop-dedup-iou", type=float, default=0.35, help="Min IoU to treat two crops as same object (keep larger)")
    add("--no-squarify", action="store_true", help="Skip squarify; use expanded bbox only (tighter crops, often better recognition)")
    add("--padding", type=float, default=0.2, help="Crop padding around group box (0.2 = 20%%)")

    # Jina classification
    add("--conf-threshold", type=float, default=0.75, help="Jina accept confidence")
    add("--gap-threshold", type=float, default=0.05, help="Jina accept gap")
    add("--text-weight", type=float, default=0.3)

    # Runtime
    add("--max-images", type=int, default=None)
    add("--device", default=None)
    add("--dfine-model", choices=list(DFINE_MODEL_IDS.keys()), default="large-obj365", help="D-FINE model")

    return parser.parse_args()
def get_person_car_label_ids(model):
    """Return the set of label IDs that count as "person or car" for grouping.

    Scans ``model.config.id2label``: a label matches when its lower-cased name
    contains "person" or equals "car" or "suv". Keys that cannot be converted
    to int are skipped; a missing/empty id2label yields an empty set.
    """
    def as_int(value):
        try:
            return int(value)
        except (ValueError, TypeError):
            return None

    def is_person_or_car(name):
        lowered = (name or "").lower()
        return "person" in lowered or lowered in ("car", "suv")

    id2label = getattr(model.config, "id2label", None) or {}
    return {
        label_id
        for label_id, name in ((as_int(key), value) for key, value in id2label.items())
        if label_id is not None and is_person_or_car(name)
    }
def run_dfine(image, processor, model, device, score_threshold):
    """Run D-FINE on one image and return every detection above ``score_threshold``.

    image: a PIL image, or an array accepted by ``Image.fromarray``; converted to RGB.
    Boxes are rescaled by ``processor.post_process_object_detection`` using the
    image's (height, width) as target size.
    Returns a list of dicts {"box": [x1, y1, x2, y2] floats, "conf": float score,
    "cls": int label id, "label": model.config.id2label name, or str(id) if absent}.
    """
    source = image if isinstance(image, Image.Image) else Image.fromarray(image)
    rgb = source.convert("RGB")
    width, height = rgb.size
    size_tensor = torch.tensor([[height, width]], device=device)

    encoded = processor(images=rgb, return_tensors="pt")
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        outputs = model(**encoded)

    results = processor.post_process_object_detection(
        outputs,
        # target sizes must live on the same device as the model outputs
        target_sizes=size_tensor.to(outputs["logits"].device),
        threshold=score_threshold,
    )
    id2label = getattr(model.config, "id2label", {}) or {}

    def to_detection(score, label_id, box):
        class_id = int(label_id.item())
        return {
            "box": [float(v) for v in box.cpu().tolist()],
            "conf": float(score.item()),
            "cls": class_id,
            "label": id2label.get(class_id, str(class_id)),
        }

    return [
        to_detection(score, label_id, box)
        for result in results
        for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"])
    ]
def main():
    """Batch CLI entry point: D-FINE → person/car groups → Jina-CLIP-v2 on object crops.

    For every image in --input (sorted, optionally capped by --max-images):

    1. Run D-FINE once on the full frame.
    2. Keep person/car detections with confidence > PERSON_CAR_MIN_CONF and
       group them by center distance (--group-dist, default 10% of the longer
       image side); at most 10 groups per image, most confident first.
    3. Every other detection whose center lies inside a group box expanded by
       PERSON_CAR_GROUP_MARGIN becomes a candidate crop: padded by 60% of its
       size per side when its shorter side is < 24 px (else 30%), and dropped
       when the padded crop's shorter side is below --min-side.
    4. Candidates that overlap (IoU >= --crop-dedup-iou, or either center inside
       the other) are deduplicated across groups, keeping the larger crop.
    5. Each kept crop is squarified (unless --no-squarify), classified with
       Jina-CLIP-v2, saved annotated under <output>/jina_crops and written as a
       row of <output>/results.csv.

    Images without a qualifying person/car detection produce no rows.

    Raises:
        SystemExit: refs/input folder missing, or no images found.
    """
    args = parse_args()
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    input_dir = Path(args.input)
    output_dir = Path(args.output)
    refs_dir = Path(args.refs)
    output_dir.mkdir(parents=True, exist_ok=True)
    if not refs_dir.is_dir():
        raise SystemExit(f"Refs folder not found: {refs_dir}")
    if not input_dir.is_dir():
        raise SystemExit(f"Input folder not found: {input_dir}")
    paths = sorted(p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
    if args.max_images is not None:
        paths = paths[: args.max_images]
    if not paths:
        raise SystemExit(f"No images in {input_dir}")

    # Load D-FINE
    dfine_model_id = DFINE_MODEL_IDS.get(args.dfine_model, DFINE_MODEL_IDS["large-obj365"])
    print(f"[*] Loading D-FINE ({dfine_model_id})...")
    t0 = time.perf_counter()
    image_processor = AutoImageProcessor.from_pretrained(dfine_model_id)
    dfine_model = DFineForObjectDetection.from_pretrained(dfine_model_id)
    dfine_model = dfine_model.to(device).eval()
    person_car_ids = get_person_car_label_ids(dfine_model)
    print(f" Person/car label IDs: {person_car_ids} ({time.perf_counter()-t0:.1f}s)")

    # Load Jina-CLIP-v2 + build refs
    print("[*] Loading Jina-CLIP-v2 and building refs...")
    t0 = time.perf_counter()
    jina_encoder = JinaCLIPv2Encoder(device)
    ref_labels, ref_embs = build_refs(
        jina_encoder,
        refs_dir,
        TRUNCATE_DIM,
        args.text_weight,
        batch_size=16,
    )
    print(f" Jina refs: {ref_labels} ({time.perf_counter()-t0:.1f}s)\n")
    jina_crops_dir = output_dir / "jina_crops"
    jina_crops_dir.mkdir(parents=True, exist_ok=True)

    # Loop-invariant helpers for step 4 (previously re-defined per image / per candidate).
    def crop_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    def is_same_object(box_a, box_b):
        """Crops match if they overlap enough or either center lies in the other."""
        return (
            box_iou(box_a, box_b) >= args.crop_dedup_iou
            or box_center_inside(box_a, box_b)
            or box_center_inside(box_b, box_a)
        )

    csv_path = output_dir / "results.csv"
    # `with` guarantees the CSV is flushed and closed even if a later image raises;
    # utf-8 keeps non-ASCII labels writable regardless of the platform default encoding.
    with open(csv_path, "w", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        # crop_* columns hold the person/car group box; bbox_* hold the object crop
        # that was actually classified (after padding/squarify).
        writer.writerow([
            "image",
            "crop_filename",
            "group_idx",
            "crop_x1",
            "crop_y1",
            "crop_x2",
            "crop_y2",
            "bbox_x1",
            "bbox_y1",
            "bbox_x2",
            "bbox_y2",
            "dfine_label",
            "dfine_conf",
            "jina_prediction",
            "jina_confidence",
            "jina_status",
        ])
        for img_path in paths:
            with Image.open(img_path) as opened:
                pil = opened.convert("RGB")
            img_w, img_h = pil.size
            group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)

            # 1) D-FINE once on the full frame; all detections are reused in step 3.
            detections = run_dfine(pil, image_processor, dfine_model, device, args.det_threshold)
            person_car = [
                d for d in detections
                if d["cls"] in person_car_ids and d["conf"] > PERSON_CAR_MIN_CONF
            ]
            if not person_car:
                continue

            # 2) Group person/car detections; keep the 10 most confident groups.
            grouped = group_detections(person_car, group_dist)
            grouped.sort(key=lambda x: x["conf"], reverse=True)
            top_groups = grouped[:10]

            # 3) Candidate object crops inside each group.
            #    Each: (expanded_box, detection, gidx, crop_idx, gx1, gy1, gx2, gy2)
            #    where g* is the (un-expanded) group box.
            candidates = []
            for gidx, grp in enumerate(top_groups):
                x1, y1, x2, y2 = grp["box"]
                group_box_with_margin = expand_box_by_margin(
                    [x1, y1, x2, y2], PERSON_CAR_GROUP_MARGIN, img_w, img_h
                )
                inside = [
                    d for d in detections
                    if box_center_inside(d["box"], group_box_with_margin)
                    and d["cls"] not in person_car_ids
                ]
                inside = deduplicate_by_iou(inside, iou_threshold=0.9)
                for crop_idx, d in enumerate(inside):
                    bx1, by1, bx2, by2 = [float(v) for v in d["box"]]
                    obj_w, obj_h = bx2 - bx1, by2 - by1
                    if obj_w <= 0 or obj_h <= 0:
                        continue
                    # Small objects (min side < 24 px): pad 60% per side; larger: 30%
                    pad_ratio = 0.6 if min(obj_w, obj_h) < 24 else 0.3
                    pad_x = obj_w * pad_ratio
                    pad_y = obj_h * pad_ratio
                    bx1 = max(0, int(bx1 - pad_x))
                    by1 = max(0, int(by1 - pad_y))
                    bx2 = min(img_w, int(bx2 + pad_x))
                    by2 = min(img_h, int(by2 + pad_y))
                    if bx2 <= bx1 or by2 <= by1:
                        continue
                    if min(bx2 - bx1, by2 - by1) < args.min_side:
                        continue
                    candidates.append(([bx1, by1, bx2, by2], d, gidx, crop_idx, x1, y1, x2, y2))

            # 4) Dedup on the padded boxes (before squarify) across all groups,
            #    largest first so the larger crop of a duplicate pair survives.
            candidates.sort(key=lambda c: -crop_area(c[0]))
            kept = []
            for cand in candidates:
                if not any(is_same_object(cand[0], k[0]) for k in kept):
                    kept.append(cand)

            # 5) Optionally squarify, classify with Jina, save crop + CSV row.
            for i, (expanded_box, d, gidx, _crop_idx, x1, y1, x2, y2) in enumerate(kept):
                if args.no_squarify:
                    bx1, by1, bx2, by2 = expanded_box
                else:
                    bx1, by1, bx2, by2 = squarify_crop_box(*expanded_box, img_w, img_h)
                crop_pil = pil.crop((bx1, by1, bx2, by2))
                crop_name = f"{img_path.stem}_g{gidx}_{i}_{bx1}_{by1}_{bx2}_{by2}{img_path.suffix}"
                q_jina = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
                result_jina = jina_classify(
                    q_jina,
                    ref_labels,
                    ref_embs,
                    args.conf_threshold,
                    args.gap_threshold,
                )
                if result_jina["prediction"] in ref_labels:
                    label_jina = result_jina["prediction"]
                    conf_jina = result_jina["confidence"]
                else:
                    label_jina = f"unnamed (dfine: {d['label']})"
                    conf_jina = 0.0
                ann_jina = draw_label_on_image(crop_pil, label_jina, conf_jina)
                ann_jina.save(jina_crops_dir / crop_name)
                writer.writerow([
                    img_path.name,
                    crop_name,
                    gidx,
                    x1,
                    y1,
                    x2,
                    y2,
                    bx1,
                    by1,
                    bx2,
                    by2,
                    d["label"],
                    f"{d['conf']:.4f}",
                    result_jina["prediction"],
                    f"{result_jina['confidence']:.4f}",
                    result_jina["status"],
                ])
    print(f"[*] Wrote {csv_path}")
    print(f"[*] Jina crops: {jina_crops_dir}")
# -----------------------------------------------------------------------------
# Single-image runner for Gradio app: D-FINE first, then Jina
# -----------------------------------------------------------------------------
# Process-level caches for run_single_image (the Gradio path) so repeated calls
# reuse already-loaded models instead of reloading them for every image.
# _APP_DFINE: None until the first call, then a 4-tuple
#   (cache key identifying the loaded model, image_processor, dfine_model, person_car_ids).
_APP_DFINE = None
# _APP_CLASSIFIERS: {classifier_name: (classifier_instance, cache_key)}, where cache_key
#   is a string describing how the instance was built (e.g. its refs_dir / labels);
#   a different key on a later call triggers a rebuild.
_APP_CLASSIFIERS = {}
# Short model keys -> Hugging Face model IDs passed to from_pretrained().
# The keys are the --dfine-model choices in parse_args; run_single_image falls
# back to "large-obj365" for keys not listed here.
DFINE_MODEL_IDS = {
    # obj365
    "small-obj365": "ustc-community/dfine-small-obj365",
    "medium-obj365": "ustc-community/dfine-medium-obj365",
    "large-obj365": "ustc-community/dfine-large-obj365",
    # coco
    "small-coco": "ustc-community/dfine-small-coco",
    "medium-coco": "ustc-community/dfine-medium-coco",
    "large-coco": "ustc-community/dfine-large-coco",
    # obj2coco (note: the large variant's checkpoint name carries an extra "-e25" suffix)
    "small-obj2coco": "ustc-community/dfine-small-obj2coco",
    "medium-obj2coco": "ustc-community/dfine-medium-obj2coco",
    "large-obj2coco": "ustc-community/dfine-large-obj2coco-e25",
}
# Classifier names understood by _load_classifier (run_single_image's `classifier` argument).
CLASSIFIER_CHOICES = ["jina", "siglip-224", "siglip-256", "siglip-384", "siglip2_onnx"]
def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
"""Factory: load and initialize a classifier by name."""
if refs_dir:
refs_dir = Path(refs_dir)
if classifier_name == "jina":
jina_encoder = JinaCLIPv2Encoder(device)
ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
return ("jina_wrapped", jina_encoder, ref_labels, ref_embs)
if classifier_name.startswith("siglip-"):
from siglip_zeroshot import SigLIPClassifier, SIGLIP_MODELS
if classifier_name not in SIGLIP_MODELS:
raise ValueError(f"Unknown SigLIP model: {classifier_name}. Choose from {list(SIGLIP_MODELS.keys())}")
clf = SigLIPClassifier(device, model_key=classifier_name)
clf.build_refs(refs_dir=refs_dir, labels=labels)
return clf
if classifier_name == "siglip2_onnx":
from siglip2_onnx_zeroshot import SigLIP2ONNXClassifier
clf = SigLIP2ONNXClassifier(device)
clf.build_refs(refs_dir=refs_dir, labels=labels)
return clf
raise ValueError(f"Unknown classifier: {classifier_name}. Choose from {CLASSIFIER_CHOICES}")
def _classify_crop(classifier, crop, conf_threshold, gap_threshold):
"""Unified classify call that works for both Jina (tuple) and SigLIP-style classifiers."""
if isinstance(classifier, tuple) and classifier[0] == "jina_wrapped":
_, jina_encoder, ref_labels, ref_embs = classifier
q = jina_encoder.encode_images([crop], TRUNCATE_DIM)
return jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
else:
return classifier.classify_crop(crop, conf_threshold, gap_threshold)
def _get_app_dfine(dfine_model, device):
    """Return (image_processor, model, person_car_ids), loading D-FINE into _APP_DFINE on demand.

    The cache key includes the device, so switching devices between calls reloads
    the model instead of feeding inputs on one device to weights on another.
    """
    global _APP_DFINE
    cache_key = (dfine_model, str(device))
    if _APP_DFINE is None or _APP_DFINE[0] != cache_key:
        model_id = DFINE_MODEL_IDS[dfine_model]
        print(f"[*] Loading D-FINE ({model_id})...")
        image_processor = AutoImageProcessor.from_pretrained(model_id)
        model = DFineForObjectDetection.from_pretrained(model_id).to(device).eval()
        _APP_DFINE = (cache_key, image_processor, model, get_person_car_label_ids(model))
    _key, image_processor, model, person_car_ids = _APP_DFINE
    return image_processor, model, person_car_ids


def _get_app_classifier(classifier, device, refs_dir, labels):
    """Return a cached classifier, rebuilding it when device, refs_dir or labels change."""
    cache_key = f"{device}|{refs_dir}|{labels}"
    cached = _APP_CLASSIFIERS.get(classifier)
    if cached is None or cached[1] != cache_key:
        cached = (_load_classifier(classifier, device, refs_dir=refs_dir, labels=labels), cache_key)
        _APP_CLASSIFIERS[classifier] = cached
    return cached[0]


def _object_crop_box(box, crop_w, crop_h):
    """Turn a crop-local detection box into an integer object-crop box, or None if degenerate.

    Pads 60% of the object size per side when its shorter side is < 24 px (else 30%),
    then grows the box toward MIN_OBJECT_CROP_SIDE; everything is clamped to the
    (crop_w, crop_h) group crop.
    """
    bx1, by1, bx2, by2 = [float(v) for v in box]
    obj_w, obj_h = bx2 - bx1, by2 - by1
    if obj_w <= 0 or obj_h <= 0:
        return None
    pad_ratio = 0.6 if min(obj_w, obj_h) < 24 else 0.3
    pad_x = obj_w * pad_ratio
    pad_y = obj_h * pad_ratio
    bx1 = max(0.0, bx1 - pad_x)
    by1 = max(0.0, by1 - pad_y)
    bx2 = min(crop_w, bx2 + pad_x)
    by2 = min(crop_h, by2 + pad_y)
    if bx2 <= bx1 or by2 <= by1:
        return None
    w, h = bx2 - bx1, by2 - by1
    if min(w, h) < MIN_OBJECT_CROP_SIDE:
        # Grow the shorter side first, then top up either side that is still short
        # (e.g. both sides were small, or clamping at the crop edge ate some growth).
        half = (MIN_OBJECT_CROP_SIDE - min(w, h)) / 2.0
        if w < h:
            bx1 = max(0, bx1 - half)
            bx2 = min(crop_w, bx2 + half)
        else:
            by1 = max(0, by1 - half)
            by2 = min(crop_h, by2 + half)
        w, h = bx2 - bx1, by2 - by1
        if w < MIN_OBJECT_CROP_SIDE:
            add = (MIN_OBJECT_CROP_SIDE - w) / 2
            bx1 = max(0, bx1 - add)
            bx2 = min(crop_w, bx2 + add)
        if h < MIN_OBJECT_CROP_SIDE:
            add = (MIN_OBJECT_CROP_SIDE - h) / 2
            by1 = max(0, by1 - add)
            by2 = min(crop_h, by2 + add)
    bx1, by1, bx2, by2 = int(bx1), int(by1), int(bx2), int(by2)
    if bx2 <= bx1 or by2 <= by1:
        return None
    return [bx1, by1, bx2, by2]


def _dedup_crop_candidates(candidates, iou_threshold):
    """Keep the largest crop per object.

    candidates are tuples whose first element is a box. Largest boxes are visited
    first; a box is dropped when it has IoU >= iou_threshold with a kept box or
    either box's center lies inside the other.
    """
    def area(cand):
        b = cand[0]
        return (b[2] - b[0]) * (b[3] - b[1])

    kept = []
    for cand in sorted(candidates, key=lambda c: -area(c)):
        box = cand[0]
        if not any(
            box_iou(box, k[0]) >= iou_threshold
            or box_center_inside(box, k[0])
            or box_center_inside(k[0], box)
            for k in kept
        ):
            kept.append(cand)
    return kept


def run_single_image(
    pil_image,
    refs_dir=None,
    device=None,
    dfine_model="large",
    det_threshold=0.3,
    conf_threshold=0.75,
    gap_threshold=0.05,
    min_side=40,
    crop_dedup_iou=0.35,
    squarify=True,
    min_display_conf=None,
    classifier="siglip-256",
    labels=None,
):
    """Run the detection + classification pipeline on one image (Gradio app path).

    D-FINE runs once on the full frame. Person/car detections with confidence
    > PERSON_CAR_MIN_CONF are grouped (distance 10% of the longer image side) and
    the 10 most confident groups are cropped with PERSON_CAR_GROUP_MARGIN. Every
    other detection whose center lies in a group crop is padded, grown toward
    MIN_OBJECT_CROP_SIDE, deduplicated (largest crop wins), optionally squarified
    and classified with ``classifier``.

    Args:
        pil_image: PIL image or array accepted by ``Image.fromarray``.
        refs_dir: reference-image folder (str or Path); optional if ``labels`` given
            (the "jina" classifier builds refs only from refs_dir).
        device: torch device; defaults to "cuda" when available, else "cpu".
        dfine_model: key of DFINE_MODEL_IDS; unknown keys (including the legacy
            default "large") fall back to "large-obj365".
        det_threshold: D-FINE score threshold.
        conf_threshold, gap_threshold: forwarded to the classifier.
        min_side: accepted for API compatibility but currently unused here (crops
            are grown toward MIN_OBJECT_CROP_SIDE instead of being skipped).
        crop_dedup_iou: IoU at/above which two object crops count as one object.
        squarify: expand each object crop toward a square before classifying.
        min_display_conf: minimum confidence for the known-object gallery
            (defaults to MIN_DISPLAY_CONF).
        classifier: one of CLASSIFIER_CHOICES.
        labels: class label strings for zero-shot classifiers.

    Returns:
        (group_crop_images, known_crop_composites, status_message). The lists
        hold numpy arrays: one annotated image per group crop, and one labelled
        crop per object not predicted "unknown" with confidence >= min_display_conf.
        Validation failures and "no person/car" return two empty lists; when groups
        exist but nothing was classified, the group images are still returned.
    """
    if min_display_conf is None:
        min_display_conf = MIN_DISPLAY_CONF
    if refs_dir:
        refs_dir = Path(refs_dir)
        if not refs_dir.is_dir():
            return [], [], f"Refs folder not found: {refs_dir}"
    if not refs_dir and not labels:
        return [], [], "Provide either refs_dir or labels."
    dfine_model = (dfine_model or "large-obj365").strip().lower()
    if dfine_model not in DFINE_MODEL_IDS:
        dfine_model = "large-obj365"
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[*] Device: {device}")
    if isinstance(pil_image, Image.Image):
        pil = pil_image.convert("RGB")
    else:
        pil = Image.fromarray(pil_image).convert("RGB")
    img_w, img_h = pil.size
    group_dist = 0.1 * max(img_h, img_w)

    # D-FINE (reloaded when the model key or device changes), then grouping.
    image_processor, dfine_model_obj, person_car_ids = _get_app_dfine(dfine_model, device)
    detections = run_dfine(pil, image_processor, dfine_model_obj, device, det_threshold)
    person_car = [
        d for d in detections
        if d["cls"] in person_car_ids and d["conf"] > PERSON_CAR_MIN_CONF
    ]
    if not person_car:
        return [], [], (
            f"No person/car detected (or none with confidence > {PERSON_CAR_MIN_CONF}). "
            "No small-object crops."
        )
    grouped = group_detections(person_car, group_dist)
    grouped.sort(key=lambda x: x["conf"], reverse=True)
    top_groups = grouped[:10]

    # Classifier is only loaded once we know there is something to classify.
    clf_instance = _get_app_classifier(classifier, device, refs_dir, labels)

    results_per_crop = []
    group_crop_images = []
    classification_log = []
    # Non-person/car detections from the full-frame pass, reused per group
    other_detections = [d for d in detections if d["cls"] not in person_car_ids]

    for gidx, grp in enumerate(top_groups):
        crop_box = expand_box_by_margin(list(grp["box"][:4]), PERSON_CAR_GROUP_MARGIN, img_w, img_h)
        gx1 = max(0, int(crop_box[0]))
        gy1 = max(0, int(crop_box[1]))
        gx2 = min(img_w, int(crop_box[2]))
        gy2 = min(img_h, int(crop_box[3]))
        if gx2 <= gx1 or gy2 <= gy1:
            continue
        crop_pil = pil.crop((gx1, gy1, gx2, gy2)).copy().convert("RGB")
        crop_w, crop_h = crop_pil.size

        # Detections centered inside this group crop, remapped to crop-local coords.
        region = [gx1, gy1, gx2, gy2]
        inside = []
        for d in other_detections:
            if not box_center_inside(d["box"], region):
                continue
            fx1, fy1, fx2, fy2 = d["box"]
            local = dict(d)
            local["box"] = [
                max(0, fx1 - gx1),
                max(0, fy1 - gy1),
                min(crop_w, fx2 - gx1),
                min(crop_h, fy2 - gy1),
            ]
            inside.append(local)
        inside = deduplicate_by_iou(inside, iou_threshold=0.9)

        candidates = []
        for d in inside:
            obj_box = _object_crop_box(d["box"], crop_w, crop_h)
            if obj_box is not None:
                candidates.append((obj_box, d, gidx))
        kept = _dedup_crop_candidates(candidates, crop_dedup_iou)

        boxes_to_draw = []
        for (bx1, by1, bx2, by2), d, _ in kept:
            if squarify:
                bx1, by1, bx2, by2 = squarify_crop_box(bx1, by1, bx2, by2, crop_w, crop_h)
            small_crop = crop_pil.crop((bx1, by1, bx2, by2))
            result = _classify_crop(clf_instance, small_crop, conf_threshold, gap_threshold)
            raw_pred = result["prediction"]
            pred = raw_pred if raw_pred != "unknown" else f"unknown ({d['label']})"
            conf = result["confidence"]
            results_per_crop.append((gidx, (bx1, by1, bx2, by2), small_crop, pred, conf))
            boxes_to_draw.append((bx1, by1, bx2, by2, pred, conf))
            sims_str = ", ".join(f"{k}: {v:.4f}" for k, v in result.get("all_sims", {}).items())
            classification_log.append(
                f"[group {gidx}] dfine: {d['label']} ({d['conf']:.3f}) → "
                f"{pred} (conf={conf:.4f}, gap={result['gap']:.4f}, 2nd={result.get('second_best','?')}) "
                f"| {result['status']} | {sims_str}"
            )

        # Boxes are already in crop-local coordinates.
        if boxes_to_draw:
            crop_pil_drawn = draw_bboxes_on_image(crop_pil.copy(), boxes_to_draw)
        else:
            crop_pil_drawn = crop_pil
        group_crop_images.append(np.array(crop_pil_drawn))

    log_text = f"Classifier: {classifier} | {len(results_per_crop)} crops classified\n"
    log_text += "\n".join(classification_log) if classification_log else "(no crops)"
    if not results_per_crop:
        return group_crop_images, [], log_text + "\nNo small-object crops: no object detections (gun/phone/etc.) found inside person/car groups, or all were below min size."

    # Known-only gallery: skip "unknown (...)" predictions and low-confidence crops.
    known_crop_composites = [
        np.array(draw_label_on_image(small_crop, pred, conf))
        for (_gidx, _box, small_crop, pred, conf) in results_per_crop
        if not (pred.startswith("unknown") or conf < min_display_conf)
    ]
    return group_crop_images, known_crop_composites, log_text
if __name__ == "__main__":
    # Script entry point: batch-process a folder of images (CLI defined in parse_args).
    main()