# kirana-detective / pipeline.py
# naazimsnh02's picture
# Finetuning completed for yolo26n-indian-fmcg-detection and minicpm5-1b-indian-fmcg-normalizer
# 7b5611f
# Raw
# History Blame
# 8.77 kB
from __future__ import annotations
import logging
import time
import uuid
from typing import Generator, List, Optional
from models import AuditProgressUpdate, LeakageReport, DeliveryCountMap
from tracer import make_trace_entry
logger = logging.getLogger(__name__)
class AuditOrchestrator:
    """Run the audit agents in a fixed order and stream progress updates.

    The agents are invoked in this order: Invoice_Extractor, Product_Matcher,
    Pricing_Agent, Visual_Counter, Reconciliation_Agent, Savings_Agent. Every
    agent contributes exactly one trace entry to the tracer, including agents
    that are skipped, so the trace always lists the agents in the same order.
    """

    def __init__(
        self,
        invoice_extractor,
        product_matcher,
        pricing_agent,
        visual_counter,
        reconciliation_agent,
        savings_agent,
        storage,
        tracer,
    ) -> None:
        """Store the collaborating agents and services.

        Args:
            invoice_extractor: Object exposing ``extract(path, audit_run_id)``.
            product_matcher: Object exposing ``normalize(invoice, audit_run_id)``.
            pricing_agent: Object exposing ``run(invoice, audit_run_id)``.
            visual_counter: Object exposing ``count_photos(paths, audit_run_id)``,
                or ``None`` when photo counting is unavailable.
            reconciliation_agent: Object exposing
                ``run(invoice, delivery_count_map, audit_run_id)``.
            savings_agent: Object exposing ``generate_report(...)``.
            storage: Object exposing ``available`` and ``save_invoice(...)``.
            tracer: Object exposing ``collect``, ``finalise`` and
                ``publish_async``.
        """
        self._extractor = invoice_extractor
        self._matcher = product_matcher
        self._pricer = pricing_agent
        self._counter = visual_counter
        self._reconciler = reconciliation_agent
        self._saver = savings_agent
        self._storage = storage
        self._tracer = tracer

    def _collect_skipped_trace(
        self,
        agent_name: str,
        audit_run_id: str,
        input_summary: str,
        output_summary: str,
    ) -> None:
        """Collect a stub trace entry for an agent that did not run.

        The entry carries identical start and end timestamps and exists so
        that the canonical agent order is preserved in the trace.
        """
        now = time.monotonic()
        self._tracer.collect(
            make_trace_entry(
                agent_name=agent_name,
                agent_version="1.0.0",
                audit_run_id=audit_run_id,
                t_start=now,
                t_end=now,
                input_summary=input_summary,
                output_summary=output_summary,
            )
        )

    def run_audit(
        self,
        invoice_file_path: str,
        delivery_photo_paths: Optional[List[str]] = None,
        supplier_name: str = "",
    ) -> Generator[AuditProgressUpdate, None, LeakageReport]:
        """Audit one invoice, yielding progress updates along the way.

        Args:
            invoice_file_path: Path handed to the invoice extractor.
            delivery_photo_paths: Optional delivery photo paths. ``None`` and
                an empty list are treated the same (no photos).
            supplier_name: Supplier name used only when the extracted invoice
                has no supplier of its own.

        Yields:
            An ``AuditProgressUpdate`` before and after each agent that runs,
            a final ``stage="complete"`` update on success, or a
            ``stage="error"`` update immediately before an exception is
            re-raised.

        Returns:
            The report produced by the savings agent (delivered as the
            generator's return value).

        Raises:
            Exception: Any exception raised by an agent, the tracer or the
                storage is logged, reported as an ``error`` update and then
                re-raised unchanged.
        """
        photo_paths = delivery_photo_paths or []
        photos_provided = bool(photo_paths)
        # Delivery data only exists when photos were supplied AND a counter is
        # available to analyse them. Reconciling against an empty count map
        # because the counter is missing would compare the invoice with
        # "nothing delivered", so that case is treated as "no delivery data".
        has_delivery_data = photos_provided and self._counter is not None
        audit_run_id = str(uuid.uuid4())

        try:
            # ── Agent 1: Invoice Extractor ────────────────────────────────
            yield AuditProgressUpdate(
                stage="extracting",
                message="Extracting invoice data…",
                agent_name="Invoice_Extractor",
            )
            invoice, trace1 = self._extractor.extract(invoice_file_path, audit_run_id)
            self._tracer.collect(trace1)
            yield AuditProgressUpdate(
                stage="extracting",
                message=f"Extracted {len(invoice.items)} line items",
                agent_name="Invoice_Extractor",
                duration_ms=trace1.duration_ms,
            )

            # The user-provided supplier is a fallback only: it never replaces
            # a supplier that extraction already found.
            if supplier_name and not invoice.supplier:
                invoice.supplier = supplier_name

            # ── Agent 2: Product Matcher ──────────────────────────────────
            yield AuditProgressUpdate(
                stage="normalising",
                message="Normalising product names…",
                agent_name="Product_Matcher",
            )
            invoice, trace2 = self._matcher.normalize(invoice, audit_run_id)
            self._tracer.collect(trace2)
            yield AuditProgressUpdate(
                stage="normalising",
                message=f"Normalised; {len(invoice.unmatched_products)} unmatched",
                agent_name="Product_Matcher",
                duration_ms=trace2.duration_ms,
            )

            # ── Agent 3: Pricing Agent ────────────────────────────────────
            yield AuditProgressUpdate(
                stage="checking_prices",
                message="Checking prices and GST…",
                agent_name="Pricing_Agent",
            )
            pricing_flags, pricing_meta, trace3 = self._pricer.run(invoice, audit_run_id)
            self._tracer.collect(trace3)
            yield AuditProgressUpdate(
                stage="checking_prices",
                message=f"Found {len(pricing_flags)} pricing flags",
                agent_name="Pricing_Agent",
                duration_ms=trace3.duration_ms,
            )

            # ── Agent 4: Visual Counter ───────────────────────────────────
            delivery_count_map: DeliveryCountMap = {}
            low_confidence_photos: List[str] = []
            if has_delivery_data:
                yield AuditProgressUpdate(
                    stage="counting_products",
                    message=f"Analysing {len(photo_paths)} delivery photo(s)…",
                    agent_name="Visual_Counter",
                )
                delivery_count_map, low_confidence_photos, trace4 = self._counter.count_photos(
                    photo_paths, audit_run_id
                )
                self._tracer.collect(trace4)
                yield AuditProgressUpdate(
                    stage="counting_products",
                    message=f"Detected {sum(delivery_count_map.values())} products",
                    agent_name="Visual_Counter",
                    duration_ms=trace4.duration_ms,
                )
            elif photos_provided:
                # Photos were supplied but there is no counter to analyse them.
                logger.warning(
                    "AuditOrchestrator: visual counter unavailable; skipping %d "
                    "delivery photo(s) in audit %s",
                    len(photo_paths),
                    audit_run_id,
                )
                self._collect_skipped_trace(
                    "Visual_Counter",
                    audit_run_id,
                    input_summary=f"{len(photo_paths)} photo(s) (skipped)",
                    output_summary="skipped — visual counter unavailable",
                )
            else:
                # Emit a stub trace so canonical agent order is preserved
                self._collect_skipped_trace(
                    "Visual_Counter",
                    audit_run_id,
                    input_summary="0 photos (skipped)",
                    output_summary="skipped — no delivery photos provided",
                )

            # ── Agent 5: Reconciliation Agent ─────────────────────────────
            reconciliation_flags: list = []
            unexpected_deliveries: List[str] = []
            if has_delivery_data:
                yield AuditProgressUpdate(
                    stage="reconciling",
                    message="Reconciling invoice vs delivery…",
                    agent_name="Reconciliation_Agent",
                )
                reconciliation_flags, unexpected_deliveries, trace5 = self._reconciler.run(
                    invoice, delivery_count_map, audit_run_id
                )
                self._tracer.collect(trace5)
                yield AuditProgressUpdate(
                    stage="reconciling",
                    message=f"Found {len(reconciliation_flags)} shortage flags",
                    agent_name="Reconciliation_Agent",
                    duration_ms=trace5.duration_ms,
                )
            else:
                if photos_provided:
                    skip_reason = "skipped — no delivery counts available"
                else:
                    skip_reason = "skipped — no delivery photos"
                self._collect_skipped_trace(
                    "Reconciliation_Agent",
                    audit_run_id,
                    input_summary=skip_reason,
                    output_summary="skipped",
                )

            # ── Agent 6: Savings Agent ────────────────────────────────────
            all_flags = pricing_flags + reconciliation_flags
            yield AuditProgressUpdate(
                stage="generating_report",
                message="Generating savings report…",
                agent_name="Savings_Agent",
            )
            report, trace6 = self._saver.generate_report(
                audit_run_id=audit_run_id,
                invoice=invoice,
                all_flags=all_flags,
                skipped_price_checks=pricing_meta.skipped_price_checks,
                skipped_gst_checks=pricing_meta.skipped_gst_checks,
                unexpected_deliveries=unexpected_deliveries,
                low_confidence_photos=low_confidence_photos,
                has_delivery_data=has_delivery_data,
            )
            self._tracer.collect(trace6)
            yield AuditProgressUpdate(
                stage="generating_report",
                message=f"Report ready — ₹{report.total_leakage_inr:.2f} leakage found",
                agent_name="Savings_Agent",
                duration_ms=trace6.duration_ms,
            )

            # ── Persist and publish ───────────────────────────────────────
            # The invoice is saved only when storage reports itself available;
            # the trace is finalised and handed to publish_async either way.
            if self._storage.available:
                self._storage.save_invoice(audit_run_id, invoice)
            entries = self._tracer.finalise(audit_run_id)
            self._tracer.publish_async(audit_run_id, entries, self._storage)

            yield AuditProgressUpdate(stage="complete", message="Audit complete")
            return report
        except Exception as exc:
            # Surface the failure to the consumer as a progress update, then
            # let the original exception propagate with its traceback intact.
            logger.exception("AuditOrchestrator: unhandled error in audit %s", audit_run_id)
            yield AuditProgressUpdate(
                stage="error",
                message=str(exc),
            )
            raise