"""
Gradio app: Tab 1 = Object Detection (YOLO models/v1), Tab 2 = D-FINE + Classify (Jina).
"""

import json
import os
from pathlib import Path

# Set before importing ultralytics below; presumably it reads this at import
# time to locate its settings directory (defaults to /tmp when unset).
os.environ["YOLO_CONFIG_DIR"] = os.environ.get("YOLO_CONFIG_DIR", "/tmp")

import numpy as np
import gradio as gr
from ultralytics import YOLO

# Tab 2: D-FINE runs first, then Jina for crop classification
from dfine_jina_pipeline import run_single_image


# --- Object Detection (Tab 1) ---

PERSON_CLASS = 0
CAR_CLASS = 2
KNIFE_CLASS = 80
WEAPON_CLASS = 81

# Only these class ids are reported in the JSON and drawn on the output image.
DRAW_CLASSES = [PERSON_CLASS, CAR_CLASS, KNIFE_CLASS, WEAPON_CLASS]
CLASS_NAMES = {
    PERSON_CLASS: "person",
    CAR_CLASS: "car",
    KNIFE_CLASS: "knife",
    WEAPON_CLASS: "weapon",
}

CONF = 0.25
IMGSZ = 640

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
REFS_DIR = os.path.join(BASE_DIR, "refs")


def _load_model(version: str):
    """Load the YOLO weights at ``models/<version>/best.pt``.

    Raises:
        FileNotFoundError: if the weights file does not exist.
    """
    path = os.path.join(MODELS_DIR, version, "best.pt")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Model not found: {path}")
    return YOLO(path)


# Loaded once at import so every request reuses the same model.
MODELS = {"v1": _load_model("v1")}
MODEL_CLASSES = {
    "v1": ["person", "car", "knife", "weapon"]
}


def _to_rgb(image) -> np.ndarray:
    """Return *image* as an HxWx3 array.

    Grayscale (HxW or HxWx1) is replicated to three channels and an alpha
    channel (HxWx4) is dropped; anything else is returned unchanged.
    """
    img = image if isinstance(image, np.ndarray) else np.array(image)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 4:
        img = img[..., :3]
    return img


def run_detection(image, model):
    """Run *model* on *image* and keep only the classes in ``DRAW_CLASSES``.

    Args:
        image: Gradio image value - a numpy array (RGB, as gr.Image delivers
            by default) or anything ``np.array`` can convert, e.g. a PIL image.
        model: a loaded ultralytics ``YOLO`` model.

    Returns:
        ``(output_image, json_string)``. ``output_image`` is the annotated RGB
        array, or the untouched input when nothing relevant was found (None
        when no image was given). ``json_string`` is ``{"detections": [...]}``
        where each entry has ``class``, ``confidence`` (3 decimals) and
        ``bbox`` (Ultralytics ``xyxy`` coordinates); it is ``"{}"`` when no
        image was given.
    """
    if image is None:
        return None, "{}"

    rgb = _to_rgb(image)
    # Ultralytics interprets numpy sources as BGR (OpenCV order), while Gradio
    # hands us RGB; convert so the model sees the channel order it expects.
    bgr = np.ascontiguousarray(rgb[..., ::-1])
    results = model.predict(
        source=bgr,
        imgsz=IMGSZ,
        conf=CONF,
        device="cpu",
        verbose=False,
    )
    r = results[0]
    empty = json.dumps({"detections": []}, indent=2)
    if r.boxes is None or len(r.boxes) == 0:
        return image, empty

    cls_ids = [int(c) for c in r.boxes.cls.cpu().numpy()]
    confs = r.boxes.conf.cpu().numpy()
    boxes_xyxy = r.boxes.xyxy.cpu().numpy()

    keep = [i for i, cls_id in enumerate(cls_ids) if cls_id in DRAW_CLASSES]
    if not keep:
        return image, empty

    detections = [
        {
            "class": CLASS_NAMES.get(cls_ids[i], str(cls_ids[i])),
            "confidence": round(float(confs[i]), 3),
            "bbox": boxes_xyxy[i].tolist(),
        }
        for i in keep
    ]

    # Restrict the drawn boxes to the kept classes before plotting.
    r.boxes = r.boxes[keep]
    # Results.plot() returns a BGR array; flip back to RGB for Gradio.
    out_img = np.ascontiguousarray(r.plot()[..., ::-1])
    det_json = json.dumps({"detections": detections}, indent=2)
    return out_img, det_json


# --- D-FINE + Classify (Tab 2) ---

def _is_large(choice) -> bool:
    """True when the D-FINE model radio value names the Large model."""
    return bool(choice) and choice.strip().lower() == "large"


def _as_gallery(images):
    """Wrap images as ``(image, caption)`` pairs for gr.Gallery; None -> []."""
    return [(img, None) for img in (images or [])]


def run_dfine_classify(image, refs_path, dfine_threshold, dfine_model_choice,
                       min_display_conf=0.703, gap_threshold=0.005):
    """Tab 2: D-FINE first, then classify crops with Jina.

    Args:
        image: input image from the Tab 2 gr.Image (PIL, per ``type="pil"``).
        refs_path: reference-image folder; blank or None falls back to
            ``REFS_DIR``.
        dfine_threshold: D-FINE detection threshold from the slider.
        dfine_model_choice: radio value; "Large" (case-insensitive) selects
            the large model, anything else the medium one.
        min_display_conf: forwarded to ``run_single_image`` as
            ``min_display_conf``.
        gap_threshold: forwarded to ``run_single_image`` as ``gap_threshold``.

    Returns:
        (group_crop_gallery, known_crop_gallery, status_message). The status
        is "" when the pipeline reports none.
    """
    if image is None:
        return [], [], "Upload an image."

    refs = Path(refs_path.strip()) if refs_path and refs_path.strip() else Path(REFS_DIR)
    if not refs.is_dir():
        return [], [], f"Refs folder not found: {refs}"

    dfine_model = "large" if _is_large(dfine_model_choice) else "medium"
    group_crops, known_crops, status = run_single_image(
        image,
        refs_dir=refs,
        dfine_model=dfine_model,
        det_threshold=float(dfine_threshold),
        # The remaining settings are fixed for this UI.
        conf_threshold=0.5,
        gap_threshold=float(gap_threshold),
        min_side=24,
        crop_dedup_iou=0.4,
        min_display_conf=float(min_display_conf),
    )
    # Guard against None crop lists on every path, not just the status path.
    return (
        _as_gallery(group_crops),
        _as_gallery(known_crops),
        status if status is not None else "",
    )


def update_dfine_threshold_default(choice):
    """Reset the D-FINE threshold slider: 0.2 for Large, 0.15 otherwise."""
    return gr.update(value=0.2 if _is_large(choice) else 0.15)


IMG_HEIGHT = 400
# Injected via gr.HTML at the top of the page; currently whitespace only.
TAB_STYLE = """ """

with gr.Blocks(title="Small Object Detection") as app:
    gr.HTML(TAB_STYLE)
    gr.Markdown("# Small Object Detection")

    with gr.Tabs():
        with gr.TabItem("Object Detection"):
            gr.Markdown(
                "**Classes:** " + ", ".join(MODEL_CLASSES["v1"])
            )
            with gr.Row():
                with gr.Column(scale=1):
                    inp_det = gr.Image(
                        label="Input image",
                        height=IMG_HEIGHT,
                    )
                    btn_det = gr.Button(
                        "Detect",
                        variant="primary",
                    )
                with gr.Column(scale=1):
                    out_img_det = gr.Image(
                        label="Output",
                        height=IMG_HEIGHT,
                    )
                    det_output = gr.JSON(
                        label="Detections"
                    )
            btn_det.click(
                fn=lambda img: run_detection(img, MODELS["v1"]),
                inputs=inp_det,
                outputs=[out_img_det, det_output],
            )

        with gr.TabItem("D-FINE + Classify"):
            gr.Markdown(
                "**D-FINE** runs first (person/car grouping), then small-object crops are classified with **Jina**. "
                "Choose D-FINE model size (Medium or Large). "
                "Uses the **refs** folder (one subfolder per class, e.g. refs/phone/, refs/cigarette/) "
                "with reference images.\n\n"
                "**Gap** = how much the top class (e.g. gun) must beat the next-best class (e.g. phone). "
                "Bigger gap means the model is more sure; we only accept the label if both confidence and gap are high enough."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    inp_dfine = gr.Image(
                        type="pil",
                        label="Input image",
                        height=IMG_HEIGHT,
                    )
                    dfine_model_radio = gr.Radio(
                        choices=["Medium", "Large"],
                        value="Large",
                        label="D-FINE model",
                    )
                    # Default threshold: Large=0.2, Medium=0.15 (slider updates when model changes)
                    dfine_threshold_slider = gr.Slider(
                        minimum=0.05,
                        maximum=0.5,
                        value=0.2,
                        step=0.05,
                        label="D-FINE detection threshold (applied to chosen model)",
                    )
                    dfine_model_radio.change(
                        fn=update_dfine_threshold_default,
                        inputs=[dfine_model_radio],
                        outputs=[dfine_threshold_slider],
                    )
                    refs_path = gr.Textbox(
                        label="Refs folder path",
                        value=REFS_DIR,
                        placeholder="e.g. refs or /path/to/refs",
                    )
                    btn_dfine = gr.Button(
                        "Run D-FINE + Classify",
                        variant="primary",
                    )
                with gr.Column(scale=1):
                    threshold_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.703,
                        step=0.005,
                        label="Threshold (min display confidence)",
                    )
                    gap_slider = gr.Slider(
                        minimum=0.0,
                        maximum=0.02,
                        value=0.005,
                        step=0.001,
                        label="Gap: how much the top guess must beat the runner-up (higher = stricter, fewer accepted)",
                    )
                    out_gallery_dfine = gr.Gallery(
                        label="Person/car crops (all D-FINE objects inside drawn with label + score)",
                        height=IMG_HEIGHT,
                        columns=2,
                        object_fit="contain",
                    )
                    out_gallery_known = gr.Gallery(
                        label="Known objects (class + score above each crop)",
                        height=IMG_HEIGHT,
                        columns=4,
                        object_fit="contain",
                    )
                    out_status_dfine = gr.Textbox(
                        label="Status",
                        lines=2,
                        interactive=False,
                    )
            btn_dfine.click(
                fn=run_dfine_classify,
                inputs=[
                    inp_dfine,
                    refs_path,
                    dfine_threshold_slider,
                    dfine_model_radio,
                    threshold_slider,
                    gap_slider,
                ],
                outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
                # One pipeline run at a time.
                concurrency_limit=1,
            )

app.launch(
    server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
    server_port=int(
        os.environ.get(
            "PORT",
            os.environ.get("GRADIO_SERVER_PORT", 7860),
        )
    ),
)