from __future__ import annotations

import logging
import time
import uuid
from typing import Generator, List, Optional

from models import AuditProgressUpdate, LeakageReport, DeliveryCountMap
from tracer import make_trace_entry

logger = logging.getLogger(__name__)


class AuditOrchestrator:
    """Run the six-agent invoice audit pipeline, streaming progress updates.

    Agents run in a fixed order: invoice extraction, product matching,
    pricing/GST checks, delivery-photo counting, invoice-vs-delivery
    reconciliation, and savings-report generation. Each agent contributes one
    trace entry to the tracer (a zero-duration stub when the agent is
    skipped), so trace entries are always collected in the same order.
    """

    def __init__(
        self,
        invoice_extractor,
        product_matcher,
        pricing_agent,
        visual_counter,
        reconciliation_agent,
        savings_agent,
        storage,
        tracer,
    ) -> None:
        """Store the collaborating agents and services.

        Args:
            invoice_extractor: Agent whose ``extract(path, audit_run_id)``
                returns ``(invoice, trace)``.
            product_matcher: Agent whose ``normalize(invoice, audit_run_id)``
                returns ``(invoice, trace)``.
            pricing_agent: Agent whose ``run(invoice, audit_run_id)`` returns
                ``(flags, meta, trace)``.
            visual_counter: Agent whose ``count_photos(paths, audit_run_id)``
                returns ``(count_map, low_confidence_photos, trace)``, or
                ``None`` to disable photo counting (and therefore
                reconciliation).
            reconciliation_agent: Agent whose
                ``run(invoice, count_map, audit_run_id)`` returns
                ``(flags, unexpected_deliveries, trace)``.
            savings_agent: Agent whose ``generate_report(...)`` returns
                ``(report, trace)``.
            storage: Persistence backend; the invoice is saved only when
                ``storage.available`` is truthy.
            tracer: Collects per-agent trace entries, then finalises and
                publishes them at the end of a successful audit.
        """
        self._extractor = invoice_extractor
        self._matcher = product_matcher
        self._pricer = pricing_agent
        self._counter = visual_counter
        self._reconciler = reconciliation_agent
        self._saver = savings_agent
        self._storage = storage
        self._tracer = tracer

    def _collect_skipped_trace(
        self,
        agent_name: str,
        audit_run_id: str,
        input_summary: str,
        output_summary: str,
    ) -> None:
        """Record a zero-duration trace entry for an agent that did not run."""
        now = time.monotonic()
        self._tracer.collect(
            make_trace_entry(
                agent_name=agent_name,
                agent_version="1.0.0",
                audit_run_id=audit_run_id,
                t_start=now,
                t_end=now,
                input_summary=input_summary,
                output_summary=output_summary,
            )
        )

    def run_audit(
        self,
        invoice_file_path: str,
        delivery_photo_paths: Optional[List[str]] = None,
        supplier_name: str = "",
    ) -> Generator[AuditProgressUpdate, None, LeakageReport]:
        """Audit one invoice, yielding progress updates as each agent runs.

        This is a generator. Iterate it to receive ``AuditProgressUpdate``
        objects. The final ``LeakageReport`` is the generator's return value,
        obtained through ``yield from`` or ``StopIteration.value``.

        Args:
            invoice_file_path: Path of the invoice file given to the extractor.
            delivery_photo_paths: Optional delivery photo paths. Counting and
                reconciliation run only when photos are given AND a visual
                counter is configured.
            supplier_name: Supplier name used when extraction found none.

        Returns:
            The report produced by the savings agent.

        Raises:
            Exception: any exception raised by an agent, storage or tracer
                is logged, reported as a ``stage="error"`` update, and then
                re-raised.
        """
        delivery_photo_paths = delivery_photo_paths or []
        photos_provided = bool(delivery_photo_paths)
        # Delivery data exists only if photos were supplied AND something can
        # count them. Without a counter the count map stays empty, and
        # reconciling against it would treat every invoice line as undelivered.
        has_delivery_data = photos_provided and self._counter is not None
        audit_run_id = str(uuid.uuid4())

        if photos_provided and not has_delivery_data:
            logger.warning(
                "AuditOrchestrator: %d delivery photo(s) supplied for audit %s "
                "but no visual counter is configured; skipping counting and "
                "reconciliation",
                len(delivery_photo_paths),
                audit_run_id,
            )

        try:
            # ── Agent 1: Invoice Extractor ────────────────────────────────
            yield AuditProgressUpdate(
                stage="extracting",
                message="Extracting invoice data…",
                agent_name="Invoice_Extractor",
            )
            invoice, trace1 = self._extractor.extract(invoice_file_path, audit_run_id)
            self._tracer.collect(trace1)
            yield AuditProgressUpdate(
                stage="extracting",
                message=f"Extracted {len(invoice.items)} line items",
                agent_name="Invoice_Extractor",
                duration_ms=trace1.duration_ms,
            )

            # Use the caller-supplied supplier only as a fallback when
            # extraction found none.
            # NOTE(review): an earlier comment called this an "override"; if
            # the user's value should always win, drop the
            # `not invoice.supplier` check.
            if supplier_name and not invoice.supplier:
                invoice.supplier = supplier_name

            # ── Agent 2: Product Matcher ──────────────────────────────────
            yield AuditProgressUpdate(
                stage="normalising",
                message="Normalising product names…",
                agent_name="Product_Matcher",
            )
            invoice, trace2 = self._matcher.normalize(invoice, audit_run_id)
            self._tracer.collect(trace2)
            yield AuditProgressUpdate(
                stage="normalising",
                message=f"Normalised; {len(invoice.unmatched_products)} unmatched",
                agent_name="Product_Matcher",
                duration_ms=trace2.duration_ms,
            )

            # ── Agent 3: Pricing Agent ────────────────────────────────────
            yield AuditProgressUpdate(
                stage="checking_prices",
                message="Checking prices and GST…",
                agent_name="Pricing_Agent",
            )
            pricing_flags, pricing_meta, trace3 = self._pricer.run(invoice, audit_run_id)
            self._tracer.collect(trace3)
            yield AuditProgressUpdate(
                stage="checking_prices",
                message=f"Found {len(pricing_flags)} pricing flags",
                agent_name="Pricing_Agent",
                duration_ms=trace3.duration_ms,
            )

            # ── Agent 4: Visual Counter ───────────────────────────────────
            delivery_count_map: DeliveryCountMap = {}
            low_confidence_photos: List[str] = []
            if has_delivery_data:
                yield AuditProgressUpdate(
                    stage="counting_products",
                    message=f"Analysing {len(delivery_photo_paths)} delivery photo(s)…",
                    agent_name="Visual_Counter",
                )
                delivery_count_map, low_confidence_photos, trace4 = self._counter.count_photos(
                    delivery_photo_paths, audit_run_id
                )
                self._tracer.collect(trace4)
                yield AuditProgressUpdate(
                    stage="counting_products",
                    message=f"Detected {sum(delivery_count_map.values())} products",
                    agent_name="Visual_Counter",
                    duration_ms=trace4.duration_ms,
                )
            else:
                # Emit a stub trace so canonical agent order is preserved.
                self._collect_skipped_trace(
                    "Visual_Counter",
                    audit_run_id,
                    input_summary=f"{len(delivery_photo_paths)} photos (skipped)",
                    output_summary=(
                        "skipped — visual counter not configured"
                        if photos_provided
                        else "skipped — no delivery photos provided"
                    ),
                )

            # ── Agent 5: Reconciliation Agent ─────────────────────────────
            reconciliation_flags: list = []
            unexpected_deliveries: List[str] = []
            if has_delivery_data:
                yield AuditProgressUpdate(
                    stage="reconciling",
                    message="Reconciling invoice vs delivery…",
                    agent_name="Reconciliation_Agent",
                )
                reconciliation_flags, unexpected_deliveries, trace5 = self._reconciler.run(
                    invoice, delivery_count_map, audit_run_id
                )
                self._tracer.collect(trace5)
                yield AuditProgressUpdate(
                    stage="reconciling",
                    message=f"Found {len(reconciliation_flags)} shortage flags",
                    agent_name="Reconciliation_Agent",
                    duration_ms=trace5.duration_ms,
                )
            else:
                # Emit a stub trace so canonical agent order is preserved.
                self._collect_skipped_trace(
                    "Reconciliation_Agent",
                    audit_run_id,
                    input_summary=(
                        "skipped — no delivery counts"
                        if photos_provided
                        else "skipped — no delivery photos"
                    ),
                    output_summary="skipped",
                )

            # ── Agent 6: Savings Agent ────────────────────────────────────
            all_flags = pricing_flags + reconciliation_flags
            yield AuditProgressUpdate(
                stage="generating_report",
                message="Generating savings report…",
                agent_name="Savings_Agent",
            )
            report, trace6 = self._saver.generate_report(
                audit_run_id=audit_run_id,
                invoice=invoice,
                all_flags=all_flags,
                skipped_price_checks=pricing_meta.skipped_price_checks,
                skipped_gst_checks=pricing_meta.skipped_gst_checks,
                unexpected_deliveries=unexpected_deliveries,
                low_confidence_photos=low_confidence_photos,
                has_delivery_data=has_delivery_data,
            )
            self._tracer.collect(trace6)
            yield AuditProgressUpdate(
                stage="generating_report",
                message=f"Report ready — ₹{report.total_leakage_inr:.2f} leakage found",
                agent_name="Savings_Agent",
                duration_ms=trace6.duration_ms,
            )

            # ── Persist and publish ───────────────────────────────────────
            if self._storage.available:
                self._storage.save_invoice(audit_run_id, invoice)
            entries = self._tracer.finalise(audit_run_id)
            # Publishing is handed storage even when it is unavailable;
            # presumably publish_async copes with that — verify in tracer.
            self._tracer.publish_async(audit_run_id, entries, self._storage)

            yield AuditProgressUpdate(stage="complete", message="Audit complete")
            return report

        except Exception as e:
            logger.exception("AuditOrchestrator: unhandled error in audit %s", audit_run_id)
            # Tell the consumer the audit failed before propagating the error.
            yield AuditProgressUpdate(
                stage="error",
                message=str(e),
            )
            raise