"""Visual counter agent: counts detected objects in photos with an ONNX detector."""
from __future__ import annotations

import logging
import time
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image

from models import DeliveryCountMap, AgentTraceEntry
from catalog import FMCGCatalog, get_catalog
from tracer import make_trace_entry

logger = logging.getLogger(__name__)

AGENT_NAME = "Visual_Counter"
AGENT_VERSION = "1.0.0"

INPUT_SIZE = 640          # side length (pixels) of the square model input
CONF_THRESHOLD = 0.25     # minimum best-class score for a prediction to be kept
IOU_THRESHOLD = 0.45      # boxes overlapping a kept box by >= this IoU are suppressed
_TIMEOUT_SECONDS = 60     # wall-clock budget for one count_photos() call


def _iou(box_a: np.ndarray, box_b: np.ndarray) -> float:
    """Return the IoU between two ``[x1, y1, x2, y2]`` boxes.

    Returns 0.0 when the union area is not positive (degenerate boxes).
    """
    xi1 = max(box_a[0], box_b[0])
    yi1 = max(box_a[1], box_b[1])
    xi2 = min(box_a[2], box_b[2])
    yi2 = min(box_a[3], box_b[3])
    inter = max(0.0, xi2 - xi1) * max(0.0, yi2 - yi1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def _iou_many(box: np.ndarray, others: np.ndarray) -> np.ndarray:
    """Return the IoU of one ``[x1, y1, x2, y2]`` box against each row of *others*.

    Vectorised counterpart of :func:`_iou`; entries whose union area is not
    positive are reported as 0.0.
    """
    x1 = np.maximum(box[0], others[:, 0])
    y1 = np.maximum(box[1], others[:, 1])
    x2 = np.minimum(box[2], others[:, 2])
    y2 = np.minimum(box[3], others[:, 3])
    inter = np.maximum(0.0, x2 - x1) * np.maximum(0.0, y2 - y1)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
    union = area + areas - inter
    # Division by a non-positive union is masked out by the np.where below.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = inter / union
    return np.where(union > 0, ratio, 0.0)


def _nms(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> List[int]:
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    remaining box whose IoU with it is >= *iou_threshold*.  Suppression is
    class-agnostic: class labels are not consulted.

    Returns:
        Indices into *boxes* of the kept boxes, in descending score order.
    """
    boxes = np.asarray(boxes)
    order = np.asarray(scores).argsort()[::-1]
    kept: List[int] = []
    while order.size:
        best = int(order[0])
        kept.append(best)
        rest = order[1:]
        if not rest.size:
            break
        order = rest[_iou_many(boxes[best], boxes[rest]) < iou_threshold]
    return kept


class VisualCounterAgent:
    """Counts detected objects per class/product across a set of photos."""

    def __init__(self, session, class_names: List[str]) -> None:
        """Store the inference session and the class-index → name table.

        Args:
            session: Object exposing ``get_inputs()`` and ``run()``; may be
                ``None``, in which case :meth:`count_photos` skips inference.
            class_names: Class names indexed by the model's class id.
        """
        self._session = session
        self._class_names = class_names
        self._catalog: Optional[FMCGCatalog] = None  # loaded lazily

    def _get_catalog(self) -> FMCGCatalog:
        """Return the catalog, fetching it via ``get_catalog()`` on first use."""
        if self._catalog is None:
            self._catalog = get_catalog()
        return self._catalog

    def _preprocess(self, image_path: str) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Load an image and convert it to a model input tensor.

        Returns:
            ``(tensor, (orig_height, orig_width))`` where *tensor* is a float32
            array of shape ``[1, 3, INPUT_SIZE, INPUT_SIZE]`` scaled to [0, 1].
        """
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made, instead of leaving it open until garbage
        # collection.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
        orig_shape = (img.height, img.width)
        img = img.resize((INPUT_SIZE, INPUT_SIZE), Image.BILINEAR)
        arr = np.array(img, dtype=np.float32) / 255.0  # [H, W, C]
        arr = arr.transpose(2, 0, 1)                   # [C, H, W]
        arr = np.expand_dims(arr, axis=0)              # [1, C, H, W]
        return arr, orig_shape

    def _postprocess(
        self,
        raw_output: np.ndarray,
        orig_shape: Tuple[int, int],
    ) -> Dict[str, int]:
        """Turn raw detector output into per-class detection counts.

        *raw_output* is expected as ``[1, num_preds, 4 + num_classes]`` or
        ``[1, 4 + num_classes, num_preds]``; the latter is detected by the
        first dimension being smaller than the second and is transposed.
        Each prediction row is ``cx, cy, w, h`` followed by one score per
        class — no separate objectness column is read.

        Args:
            raw_output: First output array of the model, batch dim included.
            orig_shape: ``(height, width)`` of the source image.  Currently
                unused: boxes are only compared with each other for NMS, so
                they are never rescaled to the original image.

        Returns:
            Mapping of class name → number of boxes surviving confidence
            filtering and NMS.  Empty when nothing passes the threshold.
        """
        out = raw_output[0]  # remove batch dim
        if out.size == 0:
            # No predictions at all: nothing to count, and the layout
            # heuristic below would mis-handle a zero-length axis.
            return {}

        # Normalise to [num_preds, 4 + num_classes]
        if out.shape[0] < out.shape[1]:
            out = out.T

        num_classes = len(self._class_names)
        boxes_raw = out[:, :4]
        class_scores = out[:, 4:4 + num_classes]
        scores = class_scores.max(axis=1)
        class_ids = class_scores.argmax(axis=1)

        # Filter by confidence
        mask = scores > CONF_THRESHOLD
        if not mask.any():
            return {}
        boxes_raw = boxes_raw[mask]
        scores = scores[mask]
        class_ids = class_ids[mask]

        # Convert cx,cy,w,h → x1,y1,x2,y2
        half_w = boxes_raw[:, 2] / 2
        half_h = boxes_raw[:, 3] / 2
        boxes = np.zeros_like(boxes_raw)
        boxes[:, 0] = boxes_raw[:, 0] - half_w
        boxes[:, 1] = boxes_raw[:, 1] - half_h
        boxes[:, 2] = boxes_raw[:, 0] + half_w
        boxes[:, 3] = boxes_raw[:, 1] + half_h

        counts: Dict[str, int] = {}
        for idx in _nms(boxes, scores, IOU_THRESHOLD):
            cid = int(class_ids[idx])
            if 0 <= cid < len(self._class_names):
                name = self._class_names[cid]
                counts[name] = counts.get(name, 0) + 1
        return counts

    def _class_name_to_product_id(self, class_name: str) -> Optional[str]:
        """Map a class name to a catalog product id via ``lookup_alias``."""
        return self._get_catalog().lookup_alias(class_name)

    def _count_single_photo(self, path: str) -> Dict[str, int]:
        """Run preprocess → inference → postprocess for one photo."""
        inp, orig_shape = self._preprocess(path)
        input_name = self._session.get_inputs()[0].name
        raw = self._session.run(None, {input_name: inp})[0]
        return self._postprocess(raw, orig_shape)

    @staticmethod
    def _trace(
        audit_run_id: str,
        t_start: float,
        n_photos: int,
        output_summary: str,
    ) -> AgentTraceEntry:
        """Build this agent's trace entry, stamping the end time now."""
        return make_trace_entry(
            agent_name=AGENT_NAME,
            agent_version=AGENT_VERSION,
            audit_run_id=audit_run_id,
            t_start=t_start,
            t_end=time.monotonic(),
            input_summary=f"{n_photos} photos",
            output_summary=output_summary,
        )

    def count_photos(
        self,
        photo_paths: List[str],
        audit_run_id: str,
    ) -> Tuple[DeliveryCountMap, List[str], AgentTraceEntry]:
        """Count detections across *photo_paths* within ``_TIMEOUT_SECONDS``.

        Photos are processed sequentially on a daemon worker thread.  A photo
        is reported as low-confidence when processing it raises, when it
        yields no detections, or when the timeout expires before its result
        has been recorded.

        Args:
            photo_paths: Paths of the photos to analyse.
            audit_run_id: Identifier forwarded to the trace entry.

        Returns:
            ``(counts, low_confidence_paths, trace)``.  *counts* is keyed by
            product id when the class name resolves through the catalog,
            otherwise by the class name itself.  *counts* is empty when the
            session is ``None`` or when the worker failed outside per-photo
            processing.
        """
        t_start = time.monotonic()
        paths = list(photo_paths)
        n_photos = len(paths)

        if self._session is None:
            trace = self._trace(
                audit_run_id, t_start, n_photos,
                "ONNX session not loaded — skipped",
            )
            return {}, [str(p) for p in paths], trace

        aggregated: Dict[str, int] = {}
        low_confidence: List[str] = []
        errors: List[Exception] = []
        handled = 0                      # photos whose outcome has been recorded
        lock = threading.Lock()          # guards aggregated/low_confidence/handled
        done = threading.Event()         # set by the caller once it stops waiting

        def _run() -> None:
            nonlocal handled
            try:
                for path in paths:
                    if done.is_set():
                        return
                    try:
                        counts = self._count_single_photo(path)
                    except Exception as e:
                        logger.warning("VisualCounter: error processing %s: %s", path, e)
                        counts = {}

                    # Map class names → product_ids before taking the lock so
                    # catalog lookups never block the caller's snapshot.
                    resolved = [
                        (self._class_name_to_product_id(name) or name, cnt)
                        for name, cnt in counts.items()
                    ]

                    with lock:
                        if done.is_set():
                            # The caller has already taken its snapshot and
                            # flagged this photo; do not touch shared state.
                            return
                        handled += 1
                        if not resolved:
                            low_confidence.append(str(path))
                            continue
                        for key, cnt in resolved:
                            aggregated[key] = aggregated.get(key, 0) + cnt
            except Exception as e:
                errors.append(e)

        worker = threading.Thread(target=_run, daemon=True)
        worker.start()
        worker.join(timeout=_TIMEOUT_SECONDS)

        # Snapshot under the lock: after this point the worker can no longer
        # change anything that is returned to the caller.
        with lock:
            timed_out = worker.is_alive()
            if timed_out:
                done.set()
            result = dict(aggregated)
            flagged = list(low_confidence)
            n_handled = handled

        if timed_out:
            logger.warning(
                "VisualCounter: timeout after %ds — partial results",
                _TIMEOUT_SECONDS,
            )
            # Everything not recorded before the deadline is low-confidence.
            flagged.extend(str(p) for p in paths[n_handled:])
        elif errors:
            logger.error("VisualCounter: session error: %s", errors[0])
            trace = self._trace(
                audit_run_id, t_start, n_photos,
                f"ONNX session error: {errors[0]}",
            )
            return {}, flagged, trace

        n_detected = sum(result.values())
        n_low = len(flagged)
        trace = self._trace(
            audit_run_id, t_start, n_photos,
            f"{n_detected} detections across {n_photos - n_low} photos; {n_low} low-confidence",
        )
        return result, flagged, trace