File size: 28,857 Bytes
d6dee9e
b915b77
 
82551bb
0930a73
82551bb
 
 
 
 
 
 
d6dee9e
 
82551bb
 
 
 
 
 
 
 
 
be400de
d6dee9e
82551bb
 
be400de
 
bf32d2c
 
89db832
 
be400de
82551bb
d6dee9e
82551bb
 
 
 
 
 
 
 
 
 
 
 
d6dee9e
82551bb
 
 
 
 
 
 
 
 
d6dee9e
82551bb
 
 
 
 
 
 
 
d6dee9e
82551bb
 
 
 
 
d6dee9e
82551bb
 
 
d6dee9e
82551bb
 
 
 
 
 
 
 
 
 
 
d6dee9e
82551bb
 
 
 
 
 
d6dee9e
82551bb
 
 
 
 
 
 
d6dee9e
 
 
 
82551bb
 
3ecb1be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2035e
3ecb1be
cc2035e
 
3ecb1be
 
82551bb
 
 
 
 
 
 
 
 
d6dee9e
82551bb
 
d6dee9e
82551bb
 
 
 
 
 
 
 
d6dee9e
82551bb
d6dee9e
82551bb
 
d6dee9e
82551bb
 
 
 
 
 
 
 
 
d6dee9e
82551bb
 
 
d6dee9e
82551bb
 
 
d6dee9e
82551bb
 
 
 
 
 
 
d6dee9e
 
82551bb
 
d6dee9e
82551bb
 
 
0930a73
 
82551bb
 
 
 
d6dee9e
82551bb
b915b77
82551bb
 
 
0930a73
d6dee9e
 
 
 
 
 
82551bb
 
 
2455309
82551bb
 
 
 
 
 
 
d6dee9e
82551bb
 
 
 
 
d6dee9e
82551bb
 
 
d6dee9e
82551bb
 
 
 
 
d6dee9e
 
82551bb
 
 
 
d6dee9e
82551bb
 
d6dee9e
82551bb
 
 
 
 
 
 
 
d6dee9e
 
 
82551bb
 
d6dee9e
82551bb
d6dee9e
82551bb
d6dee9e
 
 
 
 
82551bb
 
 
 
 
 
 
0930a73
d6dee9e
0930a73
 
82551bb
 
 
d6dee9e
82551bb
 
 
 
 
 
 
 
 
 
d6dee9e
 
 
 
82551bb
 
d6dee9e
82551bb
 
 
 
2455309
b915b77
82551bb
b915b77
 
82551bb
 
d6dee9e
82551bb
d6dee9e
82551bb
 
 
d6dee9e
 
 
 
 
 
 
 
82551bb
 
 
 
d6dee9e
82551bb
 
 
 
d6dee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82551bb
 
 
 
 
 
 
d6dee9e
 
 
 
 
 
 
 
 
89db832
82551bb
 
 
d6dee9e
82551bb
 
d6dee9e
 
 
 
 
82551bb
 
d6dee9e
 
3ecb1be
82551bb
d6dee9e
 
3ecb1be
d6dee9e
 
 
 
 
 
 
 
 
cd66f30
 
 
 
 
d6dee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b915b77
d6dee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82551bb
 
d6dee9e
 
 
 
 
 
0930a73
d6dee9e
0930a73
d6dee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82551bb
 
d6dee9e
82551bb
 
 
b915b77
82551bb
d6dee9e
b915b77
81de9b1
b915b77
2455309
 
 
 
 
 
 
 
 
 
 
 
 
 
82551bb
33708c6
81de9b1
 
80cacd4
81de9b1
80cacd4
 
81de9b1
 
 
 
 
 
33708c6
 
 
 
 
80cacd4
81de9b1
 
 
 
 
80cacd4
81de9b1
 
 
 
 
 
 
 
 
 
 
 
 
 
82551bb
 
 
80cacd4
82551bb
b915b77
82551bb
 
 
 
 
 
bb55c4d
d5c03fb
80cacd4
82551bb
 
80cacd4
d6dee9e
80cacd4
 
 
d6dee9e
be400de
82551bb
d6dee9e
bb55c4d
 
 
d6dee9e
 
80cacd4
0930a73
80cacd4
 
 
 
 
 
 
0930a73
2455309
b915b77
2455309
b915b77
 
82551bb
 
0930a73
82551bb
 
 
 
b915b77
 
 
 
 
 
 
 
0930a73
b915b77
82551bb
b915b77
 
89db832
82551bb
89db832
82551bb
 
 
 
 
80cacd4
81de9b1
80cacd4
81de9b1
80cacd4
 
 
cc2035e
81de9b1
d6dee9e
cc2035e
 
100dbc1
cc2035e
5ad3720
 
 
 
d6dee9e
cc2035e
 
 
 
 
 
 
 
 
 
d6dee9e
5ad3720
 
 
 
 
 
 
 
 
 
 
 
 
d6dee9e
 
cc2035e
 
d6dee9e
 
 
 
cd66f30
 
 
 
cc2035e
 
 
 
d6dee9e
 
cc2035e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6dee9e
cc2035e
d6dee9e
cc2035e
 
d6dee9e
cc2035e
 
 
 
 
 
 
 
 
 
 
0930a73
cc2035e
 
 
 
81de9b1
e06b945
 
cc2035e
 
 
100dbc1
 
 
 
 
 
 
 
cc2035e
 
 
 
 
 
 
 
d6dee9e
cc2035e
 
0930a73
100dbc1
 
 
cc2035e
5ad3720
be400de
bb55c4d
be400de
 
80cacd4
be400de
 
 
0930a73
100dbc1
82551bb
 
 
0930a73
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
""" Pipeline: D-FINE (person/car only) → group detections → crop regions →
find all bboxes inside each crop → Jina-CLIP-v2 embeddings and classification.
Outputs jina_crops folder and results CSV.
"""

import argparse
import csv
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, DFineForObjectDetection

# Jina-CLIP-v2 few-shot (same refs + classify as jina_fewshot.py)
from jina_fewshot import (
    IMAGE_EXTS,
    TRUNCATE_DIM,
    JinaCLIPv2Encoder,
    build_refs,
    classify as jina_classify,
    draw_bboxes_on_image,
    draw_label_on_image,
)

# Only these ref classes get bboxes on group crops and appear in the known-object gallery.
# NOTE(review): this set is not referenced by any function visible in this file — confirm
# whether the filter is applied elsewhere or was dropped.
KNOWN_DISPLAY_CLASSES = {"gun", "knife", "cigarette", "phone"}
# Default minimum classifier confidence for a crop to appear in the known-object gallery
# (used by run_single_image when min_display_conf is None).
MIN_DISPLAY_CONF = 0.7
# Person/car detections must have confidence > this to be used for grouping
PERSON_CAR_MIN_CONF = 0.9

# -----------------------------------------------------------------------------
# Detection + grouping (from reference_detection.py)
# -----------------------------------------------------------------------------

def get_box_dist(box1, box2):
    """Return the Euclidean distance between the centers of two boxes.

    Each box is given as [x1, y1, x2, y2].
    """
    centers = [
        np.array([(b[0] + b[2]) / 2, (b[1] + b[3]) / 2])
        for b in (box1, box2)
    ]
    return np.linalg.norm(centers[0] - centers[1])


def group_detections(detections, threshold):
    """
    Cluster detections whose box centers lie closer than ``threshold``.

    Two detections are linked when get_box_dist() of their boxes is below the
    threshold; a group is a connected component of that link graph, so
    membership is transitive.

    detections: list of dicts with "box" [x1,y1,x2,y2], "conf", "cls" and
        optionally "label".
    Returns one dict per group: "box" is the union of the member boxes, and
    "conf" / "cls" / "label" are taken from the highest-confidence member
    ("label" falls back to str(cls) when the member has none).
    """
    if not detections:
        return []

    count = len(detections)

    # Undirected adjacency between detections that are close enough.
    neighbors = {idx: [] for idx in range(count)}
    for a in range(count):
        for b in range(a + 1, count):
            if get_box_dist(detections[a]["box"], detections[b]["box"]) < threshold:
                neighbors[a].append(b)
                neighbors[b].append(a)

    seen = [False] * count
    merged = []

    for start in range(count):
        if seen[start]:
            continue

        # Depth-first walk over the component containing `start`.
        seen[start] = True
        pending = [start]
        members = []
        while pending:
            node = pending.pop()
            members.append(detections[node])
            for other in neighbors[node]:
                if seen[other]:
                    continue
                seen[other] = True
                pending.append(other)

        member_boxes = [m["box"] for m in members]
        union_box = [
            min(b[0] for b in member_boxes),
            min(b[1] for b in member_boxes),
            max(b[2] for b in member_boxes),
            max(b[3] for b in member_boxes),
        ]

        top = max(members, key=lambda m: m["conf"])
        merged.append({
            "box": union_box,
            "conf": top["conf"],
            "cls": top["cls"],
            "label": top.get("label", str(top["cls"])),
        })

    return merged


def box_center_inside(box, crop_box):
    """Return True when the center of ``box`` lies within ``crop_box`` (edges inclusive).

    Both arguments are [x1, y1, x2, y2].
    """
    center_x = (box[0] + box[2]) / 2
    center_y = (box[1] + box[3]) / 2
    if not crop_box[0] <= center_x <= crop_box[2]:
        return False
    return crop_box[1] <= center_y <= crop_box[3]


def expand_box_by_margin(box, margin_ratio, img_w, img_h):
    """Grow ``box`` on every side by ``margin_ratio`` of its own width/height.

    box is [x1, y1, x2, y2]; margin_ratio 0.1 means 10% per side. The result is
    clamped to [0, img_w] x [0, img_h]. A box with non-positive width or height
    is returned unchanged (the same object).
    """
    left, top, right, bottom = box
    width = right - left
    height = bottom - top
    if width <= 0 or height <= 0:
        return box

    pad_x = width * margin_ratio
    pad_y = height * margin_ratio
    return [
        max(0, left - pad_x),
        max(0, top - pad_y),
        min(img_w, right + pad_x),
        min(img_h, bottom + pad_y),
    ]


# 10% margin added around each person/car group box (via expand_box_by_margin) before
# selecting the object detections whose centers fall inside it.
PERSON_CAR_GROUP_MARGIN = 0.10
# Min side (px) for object crops taken from a person/car group crop before they are sent
# to the classifier; run_single_image grows smaller crops toward this size.
MIN_OBJECT_CROP_SIDE = 112


def squarify_crop_box(bx1, by1, bx2, by2, img_w, img_h):
    """
    Grow the shorter side of a box toward the longer one, keeping it centered.

    A taller-than-wide box is widened; otherwise the height is increased. The
    grown side is clamped to the image, so the result may stay non-square near
    image borders. Coordinates are truncated to int. If the input is degenerate,
    or truncation collapses the result, the int-truncated input box is returned.

    Returns (bx1, by1, bx2, by2) as integers.
    """
    fallback = (int(bx1), int(by1), int(bx2), int(by2))
    width, height = bx2 - bx1, by2 - by1
    if width <= 0 or height <= 0:
        return fallback

    if height > width:
        grow = (height - width) / 2.0
        bx1, bx2 = max(0, bx1 - grow), min(img_w, bx2 + grow)
    else:
        grow = (width - height) / 2.0
        by1, by2 = max(0, by1 - grow), min(img_h, by2 + grow)

    squared = (int(bx1), int(by1), int(bx2), int(by2))
    if squared[2] <= squared[0] or squared[3] <= squared[1]:
        return fallback
    return squared


def box_iou(box1, box2):
    """Intersection-over-union of two [x1,y1,x2,y2] boxes; 0.0 when the union is not positive."""

    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    overlap = overlap_w * overlap_h

    total = area(box1) + area(box2) - overlap
    if total > 0:
        return overlap / total
    return 0.0


def deduplicate_by_iou(detections, iou_threshold=0.9):
    """Drop detections that overlap an already-kept, higher-confidence one.

    Detections are visited from highest to lowest "conf"; one is kept only if
    its box has IoU < iou_threshold with every detection kept so far.
    """
    if not detections:
        return []

    survivors = []
    for candidate in sorted(detections, key=lambda det: -det["conf"]):
        overlaps_kept = False
        for winner in survivors:
            if box_iou(candidate["box"], winner["box"]) >= iou_threshold:
                overlaps_kept = True
                break
        if not overlaps_kept:
            survivors.append(candidate)

    return survivors


def parse_args():
    """Build the command-line parser for the batch pipeline and parse sys.argv."""
    parser = argparse.ArgumentParser(
        description="D-FINE (person/car) → group → Jina-CLIP-v2 on crops inside groups"
    )
    add = parser.add_argument

    # Input / output locations
    add("--refs", required=True, help="Reference images folder for Jina (e.g. refs/)")
    add("--input", required=True, help="Full-frame images folder")
    add("--output", default="pipeline_results", help="Output folder (CSV, etc.)")

    # Detection, grouping and crop geometry
    add("--det-threshold", type=float, default=0.13, help="D-FINE score threshold")
    add(
        "--group-dist",
        type=float,
        default=None,
        help="Group distance (default: 0.1 * max(H,W))",
    )
    add(
        "--min-side",
        type=int,
        default=40,
        help="Min side of expanded bbox in px (skip smaller)",
    )
    add(
        "--crop-dedup-iou",
        type=float,
        default=0.35,
        help="Min IoU to treat two crops as same object (keep larger)",
    )
    add(
        "--no-squarify",
        action="store_true",
        help="Skip squarify; use expanded bbox only (tighter crops, often better recognition)",
    )
    add(
        "--padding",
        type=float,
        default=0.2,
        help="Crop padding around group box (0.2 = 20%%)",
    )

    # Jina classification
    add("--conf-threshold", type=float, default=0.75, help="Jina accept confidence")
    add("--gap-threshold", type=float, default=0.05, help="Jina accept gap")
    add("--text-weight", type=float, default=0.3)

    # Run control
    add("--max-images", type=int, default=None)
    add("--device", default=None)
    add(
        "--dfine-model",
        choices=list(DFINE_MODEL_IDS.keys()),
        default="large-obj365",
        help="D-FINE model",
    )
    return parser.parse_args()


def get_person_car_label_ids(model):
    """Collect the integer label IDs whose name denotes a person or a car.

    Reads model.config.id2label. A label matches when its lower-cased name
    contains "person" or equals "car" or "suv". Entries whose key cannot be
    converted to int are ignored.
    """
    label_map = getattr(model.config, "id2label", None) or {}
    vehicle_names = ("car", "suv")
    matched = set()

    for raw_id, raw_name in label_map.items():
        try:
            label_id = int(raw_id)
        except (ValueError, TypeError):
            continue

        lowered = (raw_name or "").lower()
        if "person" not in lowered and lowered not in vehicle_names:
            continue
        matched.add(label_id)

    return matched


def run_dfine(image, processor, model, device, score_threshold):
    """Run D-FINE on one image and return every detection above the threshold.

    image: a PIL image, or an array accepted by Image.fromarray.
    Returns a list of dicts with keys "box" ([x1,y1,x2,y2] floats, scaled to the
    image size by the processor's post-processing), "conf" (float score),
    "cls" (int label id) and "label" (name from model.config.id2label, or the
    id as a string when no name is known).
    """
    from PIL import Image

    rgb = image if isinstance(image, Image.Image) else Image.fromarray(image)
    rgb = rgb.convert("RGB")
    width, height = rgb.size

    batch = {
        name: tensor.to(device)
        for name, tensor in processor(images=rgb, return_tensors="pt").items()
    }
    with torch.no_grad():
        outputs = model(**batch)

    # Post-processing expects (height, width) per image, on the same device as the logits.
    sizes = torch.tensor([[height, width]], device=device).to(outputs["logits"].device)
    post = processor.post_process_object_detection(
        outputs,
        target_sizes=sizes,
        threshold=score_threshold,
    )

    names = getattr(model.config, "id2label", {}) or {}
    found = []
    for per_image in post:
        triples = zip(per_image["scores"], per_image["labels"], per_image["boxes"])
        for score, label_id, box in triples:
            class_id = int(label_id.item())
            found.append({
                "box": [float(v) for v in box.cpu().tolist()],
                "conf": float(score.item()),
                "cls": class_id,
                "label": names.get(class_id, str(class_id)),
            })

    return found


def main():
    """
    Batch CLI entry point.

    For every image in --input: run D-FINE, keep confident person/car
    detections, group them by proximity, collect the non-person/car detections
    whose centers fall inside each (margin-expanded) group box, pad and
    de-duplicate those boxes, optionally squarify them, classify each crop with
    Jina-CLIP-v2 against the --refs gallery, save the labelled crop under
    <output>/jina_crops and append one row to <output>/results.csv.

    In the CSV, crop_x1..crop_y2 hold the group box and bbox_x1..bbox_y2 hold
    the final object crop box, both in full-image pixels.

    Raises SystemExit when the refs or input folder is missing, or when the
    input folder contains no images. Note: --padding is parsed but not used here.
    """
    args = parse_args()
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    input_dir = Path(args.input)
    output_dir = Path(args.output)
    refs_dir = Path(args.refs)
    output_dir.mkdir(parents=True, exist_ok=True)

    if not refs_dir.is_dir():
        raise SystemExit(f"Refs folder not found: {refs_dir}")
    if not input_dir.is_dir():
        raise SystemExit(f"Input folder not found: {input_dir}")

    paths = sorted(
        p for p in input_dir.iterdir()
        if p.suffix.lower() in IMAGE_EXTS
    )
    if args.max_images is not None:
        paths = paths[: args.max_images]

    if not paths:
        raise SystemExit(f"No images in {input_dir}")

    # Load D-FINE
    dfine_model_id = DFINE_MODEL_IDS.get(args.dfine_model, DFINE_MODEL_IDS["large-obj365"])
    print(f"[*] Loading D-FINE ({dfine_model_id})...")
    t0 = time.perf_counter()
    image_processor = AutoImageProcessor.from_pretrained(dfine_model_id)
    dfine_model = DFineForObjectDetection.from_pretrained(dfine_model_id)
    dfine_model = dfine_model.to(device).eval()
    person_car_ids = get_person_car_label_ids(dfine_model)
    print(f"  Person/car label IDs: {person_car_ids} ({time.perf_counter()-t0:.1f}s)")

    # Load Jina-CLIP-v2 + build refs
    print("[*] Loading Jina-CLIP-v2 and building refs...")
    t0 = time.perf_counter()
    jina_encoder = JinaCLIPv2Encoder(device)
    ref_labels, ref_embs = build_refs(
        jina_encoder,
        refs_dir,
        TRUNCATE_DIM,
        args.text_weight,
        batch_size=16,
    )
    print(f"  Jina refs: {ref_labels} ({time.perf_counter()-t0:.1f}s)\n")

    jina_crops_dir = output_dir / "jina_crops"
    jina_crops_dir.mkdir(parents=True, exist_ok=True)

    # Helpers for crop de-duplication; defined once rather than per image/candidate.
    def crop_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    def is_same_object(box_a, box_b):
        # Same object if the crops overlap enough, or either center lies in the other crop.
        if box_iou(box_a, box_b) >= args.crop_dedup_iou:
            return True
        return box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a)

    csv_path = output_dir / "results.csv"
    # Context manager guarantees the CSV is flushed and closed even if an image fails mid-run.
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "image",
            "crop_filename",
            "group_idx",
            "crop_x1",
            "crop_y1",
            "crop_x2",
            "crop_y2",
            "bbox_x1",
            "bbox_y1",
            "bbox_x2",
            "bbox_y2",
            "dfine_label",
            "dfine_conf",
            "jina_prediction",
            "jina_confidence",
            "jina_status",
        ])

        for img_path in paths:
            # convert() returns an independent in-memory image, so the file handle can be closed.
            with Image.open(img_path) as src:
                pil = src.convert("RGB")
            img_w, img_h = pil.size
            group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)

            # 1) D-FINE: detect everything, keep all bboxes for the image
            detections = run_dfine(
                pil,
                image_processor,
                dfine_model,
                device,
                args.det_threshold,
            )

            person_car = [
                d for d in detections
                if d["cls"] in person_car_ids and d["conf"] > PERSON_CAR_MIN_CONF
            ]
            if not person_car:
                continue

            # 2) Group person/car detections, keep the 10 most confident groups
            grouped = group_detections(person_car, group_dist)
            grouped.sort(key=lambda x: x["conf"], reverse=True)
            top_groups = grouped[:10]

            # 3) Collect candidate crops (non-person/car bboxes inside person/car groups)
            # Each: (expanded_box, detection, gidx, group_x1, group_y1, group_x2, group_y2)
            candidates = []

            for gidx, grp in enumerate(top_groups):
                x1, y1, x2, y2 = grp["box"]
                group_box_with_margin = expand_box_by_margin(
                    [x1, y1, x2, y2], PERSON_CAR_GROUP_MARGIN, img_w, img_h
                )

                inside = [
                    d for d in detections
                    if box_center_inside(d["box"], group_box_with_margin)
                    and d["cls"] not in person_car_ids
                ]
                inside = deduplicate_by_iou(inside, iou_threshold=0.9)

                for d in inside:
                    bx1, by1, bx2, by2 = [float(x) for x in d["box"]]
                    obj_w, obj_h = bx2 - bx1, by2 - by1
                    if obj_w <= 0 or obj_h <= 0:
                        continue

                    # Small objects (min side < 24 px): expand by 60%; larger: 30%
                    pad_ratio = 0.6 if min(obj_w, obj_h) < 24 else 0.3
                    pad_x = obj_w * pad_ratio
                    pad_y = obj_h * pad_ratio
                    bx1 = max(0, int(bx1 - pad_x))
                    by1 = max(0, int(by1 - pad_y))
                    bx2 = min(img_w, int(bx2 + pad_x))
                    by2 = min(img_h, int(by2 + pad_y))

                    if bx2 <= bx1 or by2 <= by1:
                        continue

                    if min(bx2 - bx1, by2 - by1) < args.min_side:
                        continue

                    candidates.append(([bx1, by1, bx2, by2], d, gidx, x1, y1, x2, y2))

            # 4) Dedup on EXPANDED boxes (before squarify), keep larger; then squarify only kept
            candidates.sort(key=lambda c: -crop_area(c[0]))
            kept = []
            for c in candidates:
                if not any(is_same_object(c[0], k[0]) for k in kept):
                    kept.append(c)

            # 5) Optionally squarify, then run Jina on kept crops
            for i, (expanded_box, d, gidx, x1, y1, x2, y2) in enumerate(kept):
                if not args.no_squarify:
                    bx1, by1, bx2, by2 = squarify_crop_box(
                        expanded_box[0],
                        expanded_box[1],
                        expanded_box[2],
                        expanded_box[3],
                        img_w,
                        img_h,
                    )
                else:
                    bx1, by1, bx2, by2 = expanded_box

                crop_pil = pil.crop((bx1, by1, bx2, by2))
                crop_name = f"{img_path.stem}_g{gidx}_{i}_{bx1}_{by1}_{bx2}_{by2}{img_path.suffix}"

                q_jina = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
                result_jina = jina_classify(
                    q_jina,
                    ref_labels,
                    ref_embs,
                    args.conf_threshold,
                    args.gap_threshold,
                )

                # Only predictions that name a reference class are drawn as-is.
                if result_jina["prediction"] in ref_labels:
                    label_jina = result_jina["prediction"]
                    conf_jina = result_jina["confidence"]
                else:
                    label_jina = f"unnamed (dfine: {d['label']})"
                    conf_jina = 0.0

                ann_jina = draw_label_on_image(crop_pil, label_jina, conf_jina)
                ann_jina.save(jina_crops_dir / crop_name)

                writer.writerow([
                    img_path.name,
                    crop_name,
                    gidx,
                    x1,
                    y1,
                    x2,
                    y2,
                    bx1,
                    by1,
                    bx2,
                    by2,
                    d["label"],
                    f"{d['conf']:.4f}",
                    result_jina["prediction"],
                    f"{result_jina['confidence']:.4f}",
                    result_jina["status"],
                ])

    print(f"[*] Wrote {csv_path}")
    print(f"[*] Jina crops: {jina_crops_dir}")


# -----------------------------------------------------------------------------
# Single-image runner for Gradio app: D-FINE first, then Jina
# -----------------------------------------------------------------------------

# Cached D-FINE state for run_single_image, reloaded when the requested model key changes:
# (DFINE_MODEL_IDS key, image_processor, dfine_model, person_car_ids)
_APP_DFINE = None
# {classifier_name: (classifier_instance, cache_key)}; cache_key identifies the
# refs/labels the instance was built for, so a change triggers a rebuild.
_APP_CLASSIFIERS = {}

# Short model key -> Hugging Face model id passed to from_pretrained().
DFINE_MODEL_IDS = {
    # obj365
    "small-obj365": "ustc-community/dfine-small-obj365",
    "medium-obj365": "ustc-community/dfine-medium-obj365",
    "large-obj365": "ustc-community/dfine-large-obj365",
    # coco
    "small-coco": "ustc-community/dfine-small-coco",
    "medium-coco": "ustc-community/dfine-medium-coco",
    "large-coco": "ustc-community/dfine-large-coco",
    # obj2coco
    "small-obj2coco": "ustc-community/dfine-small-obj2coco",
    "medium-obj2coco": "ustc-community/dfine-medium-obj2coco",
    "large-obj2coco": "ustc-community/dfine-large-obj2coco-e25",
}

# Classifier names accepted by _load_classifier.
CLASSIFIER_CHOICES = ["jina", "siglip-224", "siglip-256", "siglip-384", "siglip2_onnx"]


def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
    """Create and initialise the classifier selected by ``classifier_name``.

    "jina" returns the tuple ("jina_wrapped", encoder, ref_labels, ref_embs)
    built from refs_dir with a fixed text weight of 0.3. "siglip-*" and
    "siglip2_onnx" return a classifier object on which build_refs() has been
    called with refs_dir and labels. Raises ValueError for any other name or
    for a "siglip-*" key missing from SIGLIP_MODELS.
    """
    refs_path = Path(refs_dir) if refs_dir else refs_dir

    if classifier_name == "jina":
        encoder = JinaCLIPv2Encoder(device)
        jina_labels, jina_embs = build_refs(encoder, refs_path, TRUNCATE_DIM, 0.3, batch_size=16)
        return ("jina_wrapped", encoder, jina_labels, jina_embs)

    is_siglip = classifier_name.startswith("siglip-")
    if not is_siglip and classifier_name != "siglip2_onnx":
        raise ValueError(f"Unknown classifier: {classifier_name}. Choose from {CLASSIFIER_CHOICES}")

    if is_siglip:
        # Imported lazily so the SigLIP dependency is only needed when selected.
        from siglip_zeroshot import SigLIPClassifier, SIGLIP_MODELS
        if classifier_name not in SIGLIP_MODELS:
            raise ValueError(f"Unknown SigLIP model: {classifier_name}. Choose from {list(SIGLIP_MODELS.keys())}")
        model = SigLIPClassifier(device, model_key=classifier_name)
    else:
        from siglip2_onnx_zeroshot import SigLIP2ONNXClassifier
        model = SigLIP2ONNXClassifier(device)

    model.build_refs(refs_dir=refs_path, labels=labels)
    return model


def _classify_crop(classifier, crop, conf_threshold, gap_threshold):
    """Classify one crop with either classifier flavour returned by _load_classifier.

    A ("jina_wrapped", encoder, ref_labels, ref_embs) tuple is handled by
    encoding the crop and calling jina_classify; any other object is expected
    to expose classify_crop(crop, conf_threshold, gap_threshold).
    """
    is_jina = isinstance(classifier, tuple) and classifier[0] == "jina_wrapped"
    if not is_jina:
        return classifier.classify_crop(crop, conf_threshold, gap_threshold)

    _tag, encoder, labels, embeddings = classifier
    query = encoder.encode_images([crop], TRUNCATE_DIM)
    return jina_classify(query, labels, embeddings, conf_threshold, gap_threshold)


def run_single_image(
    pil_image,
    refs_dir=None,
    device=None,
    dfine_model="large",
    det_threshold=0.3,
    conf_threshold=0.75,
    gap_threshold=0.05,
    min_side=40,
    crop_dedup_iou=0.35,
    squarify=True,
    min_display_conf=None,
    classifier="siglip-256",
    labels=None,
):
    """
    Run D-FINE on one image, then classify small-object crops.

    pil_image: PIL image, or an array accepted by Image.fromarray.
    refs_dir: path to refs folder (str or Path), optional if labels provided.
    labels: list of class label strings for zero-shot classifiers.
    dfine_model: key from DFINE_MODEL_IDS; unknown keys fall back to "large-obj365".
    det_threshold: D-FINE score threshold.
    conf_threshold / gap_threshold: passed through to the classifier.
    min_side: accepted for interface compatibility; not used by this function.
    crop_dedup_iou: IoU at or above which two padded crops count as the same object.
    squarify: when True, crops are squarified before classification.
    min_display_conf: minimum confidence for the known-object gallery
        (defaults to MIN_DISPLAY_CONF).
    classifier: name understood by _load_classifier.

    The D-FINE model and the classifier are cached in module globals between
    calls and rebuilt when the model key, refs_dir or labels change.

    Returns (group_crop_images, known_crop_composites, status_message), where
    the two image lists hold numpy arrays. Problems such as a missing refs
    folder or no person/car detection are reported through status_message with
    empty lists rather than by raising.
    """
    global _APP_DFINE

    if min_display_conf is None:
        min_display_conf = MIN_DISPLAY_CONF

    if refs_dir:
        refs_dir = Path(refs_dir)
        if not refs_dir.is_dir():
            return [], [], f"Refs folder not found: {refs_dir}"

    if not refs_dir and not labels:
        return [], [], "Provide either refs_dir or labels."

    dfine_model = (dfine_model or "large-obj365").strip().lower()
    if dfine_model not in DFINE_MODEL_IDS:
        dfine_model = "large-obj365"
    model_id = DFINE_MODEL_IDS[dfine_model]

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[*] Device: {device}")

    pil = pil_image.convert("RGB") if isinstance(pil_image, Image.Image) else Image.fromarray(pil_image).convert("RGB")
    img_w, img_h = pil.size
    group_dist = 0.1 * max(img_h, img_w)

    # Load D-FINE (reload if user switched model)
    if _APP_DFINE is None or _APP_DFINE[0] != dfine_model:
        print(f"[*] Loading D-FINE ({model_id})...")
        image_processor = AutoImageProcessor.from_pretrained(model_id)
        dfine_model_obj = DFineForObjectDetection.from_pretrained(model_id)
        dfine_model_obj = dfine_model_obj.to(device).eval()
        person_car_ids = get_person_car_label_ids(dfine_model_obj)
        _APP_DFINE = (dfine_model, image_processor, dfine_model_obj, person_car_ids)

    _model_key, image_processor, dfine_model_obj, person_car_ids = _APP_DFINE

    # Apply the user's D-FINE detection threshold to the chosen model
    detections = run_dfine(pil, image_processor, dfine_model_obj, device, det_threshold)
    person_car = [d for d in detections if d["cls"] in person_car_ids and d["conf"] > PERSON_CAR_MIN_CONF]
    if not person_car:
        return [], [], (
            f"No person/car detected (or none with confidence > {PERSON_CAR_MIN_CONF}). "
            "No small-object crops."
        )

    grouped = group_detections(person_car, group_dist)
    grouped.sort(key=lambda x: x["conf"], reverse=True)
    top_groups = grouped[:10]

    # Load classifier. The cache key covers BOTH refs_dir and labels: the classifier is
    # built from both, so keying on labels alone would reuse a stale instance when only
    # refs_dir changes.
    cache_key = (str(refs_dir) if refs_dir else "", str(labels) if labels else "")
    clf_key = classifier
    if clf_key not in _APP_CLASSIFIERS or _APP_CLASSIFIERS[clf_key][1] != cache_key:
        clf_instance = _load_classifier(classifier, device, refs_dir=refs_dir, labels=labels)
        _APP_CLASSIFIERS[clf_key] = (clf_instance, cache_key)

    clf_instance = _APP_CLASSIFIERS[clf_key][0]

    def crop_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    results_per_crop = []
    group_crop_images = []
    classification_log = []

    # Non-person/car detections from the full-frame pass, reused per group
    other_detections = [d for d in detections if d["cls"] not in person_car_ids]

    # For each person/car group: crop (with 10% margin), reuse full-frame detections that fall inside, then classify
    for gidx, grp in enumerate(top_groups):
        group_box = [grp["box"][0], grp["box"][1], grp["box"][2], grp["box"][3]]
        crop_box = expand_box_by_margin(group_box, PERSON_CAR_GROUP_MARGIN, img_w, img_h)
        gx1 = max(0, int(crop_box[0]))
        gy1 = max(0, int(crop_box[1]))
        gx2 = min(img_w, int(crop_box[2]))
        gy2 = min(img_h, int(crop_box[3]))
        if gx2 <= gx1 or gy2 <= gy1:
            continue
        crop_pil = pil.crop((gx1, gy1, gx2, gy2)).copy().convert("RGB")
        crop_w, crop_h = crop_pil.size

        # Filter full-frame detections whose center falls inside this crop box, then remap to crop-local coords
        inside = []
        for d in other_detections:
            if not box_center_inside(d["box"], [gx1, gy1, gx2, gy2]):
                continue
            remapped = dict(d)
            fx1, fy1, fx2, fy2 = d["box"]
            remapped["box"] = [
                max(0, fx1 - gx1),
                max(0, fy1 - gy1),
                min(crop_w, fx2 - gx1),
                min(crop_h, fy2 - gy1),
            ]
            inside.append(remapped)
        inside = deduplicate_by_iou(inside, iou_threshold=0.9)

        candidates = []
        for d in inside:
            bx1, by1, bx2, by2 = [float(x) for x in d["box"]]
            obj_w, obj_h = bx2 - bx1, by2 - by1
            if obj_w <= 0 or obj_h <= 0:
                continue
            # Small objects (min side < 24 px) get 60% padding per side; larger ones 30%.
            pad_ratio = 0.6 if min(obj_w, obj_h) < 24 else 0.3
            pad_x = obj_w * pad_ratio
            pad_y = obj_h * pad_ratio
            bx1 = max(0.0, bx1 - pad_x)
            by1 = max(0.0, by1 - pad_y)
            bx2 = min(crop_w, bx2 + pad_x)
            by2 = min(crop_h, by2 + pad_y)
            if bx2 <= bx1 or by2 <= by1:
                continue
            w, h = bx2 - bx1, by2 - by1
            if min(w, h) < MIN_OBJECT_CROP_SIDE:
                # Grow the shorter side first, then top up either side that is still too
                # small (clamping at the crop border can leave it short).
                half = (MIN_OBJECT_CROP_SIDE - min(w, h)) / 2.0
                if w < h:
                    bx1 = max(0, bx1 - half)
                    bx2 = min(crop_w, bx2 + half)
                else:
                    by1 = max(0, by1 - half)
                    by2 = min(crop_h, by2 + half)
                w, h = bx2 - bx1, by2 - by1
                if w < MIN_OBJECT_CROP_SIDE:
                    add = (MIN_OBJECT_CROP_SIDE - w) / 2
                    bx1 = max(0, bx1 - add)
                    bx2 = min(crop_w, bx2 + add)
                if h < MIN_OBJECT_CROP_SIDE:
                    add = (MIN_OBJECT_CROP_SIDE - h) / 2
                    by1 = max(0, by1 - add)
                    by2 = min(crop_h, by2 + add)
            bx1, by1, bx2, by2 = int(bx1), int(by1), int(bx2), int(by2)
            if bx2 <= bx1 or by2 <= by1:
                continue
            candidates.append(([bx1, by1, bx2, by2], d))

        # Dedup padded crops, largest first: drop a crop that overlaps a kept one enough,
        # or whose center lies in a kept crop (or vice versa).
        candidates.sort(key=lambda c: -crop_area(c[0]))
        kept = []
        for c in candidates:
            expanded_box = c[0]
            if not any(
                box_iou(expanded_box, k[0]) >= crop_dedup_iou
                or box_center_inside(expanded_box, k[0])
                or box_center_inside(k[0], expanded_box)
                for k in kept
            ):
                kept.append(c)

        boxes_to_draw = []
        for (bx1, by1, bx2, by2), d in kept:
            if squarify:
                bx1, by1, bx2, by2 = squarify_crop_box(bx1, by1, bx2, by2, crop_w, crop_h)
            small_crop = crop_pil.crop((bx1, by1, bx2, by2))
            result = _classify_crop(clf_instance, small_crop, conf_threshold, gap_threshold)
            raw_pred = result["prediction"]
            pred = raw_pred if raw_pred != "unknown" else f"unknown ({d['label']})"
            conf = result["confidence"]
            results_per_crop.append((gidx, (bx1, by1, bx2, by2), small_crop, pred, conf))
            boxes_to_draw.append((bx1, by1, bx2, by2, pred, conf))

            # Build per-crop log line
            sims_str = ", ".join(f"{k}: {v:.4f}" for k, v in result.get("all_sims", {}).items())
            classification_log.append(
                f"[group {gidx}] dfine: {d['label']} ({d['conf']:.3f}) → "
                f"{pred} (conf={conf:.4f}, gap={result['gap']:.4f}, 2nd={result.get('second_best','?')}) "
                f"| {result['status']} | {sims_str}"
            )

        # Draw bboxes on this group crop (bboxes already in crop coords)
        if boxes_to_draw:
            crop_pil_drawn = draw_bboxes_on_image(crop_pil.copy(), boxes_to_draw)
        else:
            crop_pil_drawn = crop_pil
        group_crop_images.append(np.array(crop_pil_drawn))

    log_text = f"Classifier: {classifier} | {len(results_per_crop)} crops classified\n"
    log_text += "\n".join(classification_log) if classification_log else "(no crops)"

    if not results_per_crop:
        return group_crop_images, [], log_text + "\nNo small-object crops: no object detections (gun/phone/etc.) found inside person/car groups, or all were below min size."

    # Build known-only gallery: only objects with conf >= min_display_conf
    known_crop_composites = []
    for (_gidx, _box, object_crop, pred, conf) in results_per_crop:
        if pred.startswith("unknown") or conf < min_display_conf:
            continue
        composite = draw_label_on_image(object_crop, pred, conf)
        known_crop_composites.append(np.array(composite))

    return group_crop_images, known_crop_composites, log_text


# Run the batch CLI pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()