from __future__ import annotations import json import logging import time import threading from typing import List, Optional, Tuple from models import AnomalyFlag, LeakageReport, NormalisedInvoice, AgentTraceEntry from tracer import make_trace_entry logger = logging.getLogger(__name__) AGENT_NAME = "Savings_Agent" AGENT_VERSION = "1.0.0" MODEL_REPO = "build-small-hackathon/minicpm5-1b-indian-fmcg-normalizer" _TIMEOUT_SECONDS = 15 _SYSTEM_PROMPT = ( "You are an Indian kirana store auditor. " "Given a list of billing anomaly flags from an invoice audit, " "return ONLY a JSON array of concise, actionable items (strings). " "Each item should tell the shop owner exactly what to do to recover money. " "No extra text, no markdown, just the JSON array." ) def _compute_total_leakage(flags: List[AnomalyFlag]) -> float: return sum(f.amount_inr for f in flags) def _build_template_action_items( flags: List[AnomalyFlag], invoice_number: Optional[str], supplier: Optional[str], ) -> List[str]: """Generate action items purely from flag data, no LLM.""" items: List[str] = [] seen: set[str] = set() for f in flags: action = f.action_item if not action: if f.flag_type == "price_anomaly": action = ( f"Contact {supplier or 'supplier'} about " f"{f.product_name or f.product_id} price increase; " f"request credit note for ₹{f.amount_inr:.2f}" ) elif f.flag_type == "delivery_shortage": shortage = f.metadata.get("shortage_quantity", "?") action = ( f"Request delivery of {shortage} missing " f"{f.product_name or f.product_id} " f"or credit note for ₹{f.amount_inr:.2f}" ) elif f.flag_type == "duplicate_charge": action = ( f"Dispute duplicate charge for {f.product_name or f.product_id}; " f"request refund of ₹{f.amount_inr:.2f}" ) elif f.flag_type == "gst_mismatch": expected = f.metadata.get("expected_gst_rate", "?") action = ( f"Dispute GST overcharge on {f.product_name or f.product_id}; " f"correct rate is {expected}%" ) elif f.flag_type == "duplicate_invoice": action = ( f"Invoice {invoice_number or ''} already processed; do not pay again" ) if action and action not in seen: seen.add(action) items.append(action) return items class SavingsAgent: def __init__(self, llm) -> None: self._llm = llm def _build_template_report( self, audit_run_id: str, invoice: NormalisedInvoice, all_flags: List[AnomalyFlag], skipped_price_checks: int, skipped_gst_checks: int, unexpected_deliveries: List[str], low_confidence_photos: List[str], has_delivery_data: bool, ) -> LeakageReport: total = _compute_total_leakage(all_flags) action_items = _build_template_action_items( all_flags, invoice.invoice_number, invoice.supplier ) return LeakageReport( audit_run_id=audit_run_id, invoice_number=invoice.invoice_number, supplier=invoice.supplier, date=invoice.date, anomaly_flags=all_flags, total_leakage_inr=total, action_items=action_items, has_delivery_data=has_delivery_data, unexpected_deliveries=unexpected_deliveries, low_confidence_photos=low_confidence_photos, skipped_price_checks=skipped_price_checks, skipped_gst_checks=skipped_gst_checks, passed_all_checks=(len(all_flags) == 0), ) def generate_report( self, audit_run_id: str, invoice: NormalisedInvoice, all_flags: List[AnomalyFlag], skipped_price_checks: int = 0, skipped_gst_checks: int = 0, unexpected_deliveries: Optional[List[str]] = None, low_confidence_photos: Optional[List[str]] = None, has_delivery_data: bool = False, ) -> Tuple[LeakageReport, AgentTraceEntry]: unexpected_deliveries = unexpected_deliveries or [] low_confidence_photos = low_confidence_photos or [] t_start = time.monotonic() # No flags → pass immediately without LLM if not all_flags: report = self._build_template_report( audit_run_id, invoice, all_flags, skipped_price_checks, skipped_gst_checks, unexpected_deliveries, low_confidence_photos, has_delivery_data, ) t_end = time.monotonic() trace = make_trace_entry( agent_name=AGENT_NAME, agent_version=AGENT_VERSION, audit_run_id=audit_run_id, t_start=t_start, t_end=t_end, input_summary="0 flags", output_summary="passed all checks; no LLM call", ) return report, trace # Build LLM prompt flag_lines = [] for f in all_flags: flag_lines.append( f"- [{f.flag_type}] {f.product_name or f.product_id}: ₹{f.amount_inr:.2f} — {f.description}" ) prompt = ( f"Invoice: {invoice.invoice_number or 'N/A'} from {invoice.supplier or 'unknown supplier'}\n" f"Anomalies found:\n" + "\n".join(flag_lines) + "\n\nReturn a JSON array of action items." ) result_container: list[list[str]] = [] exception: list[Exception] = [] def _run(): try: response = self._llm.create_chat_completion( messages=[ {"role": "system", "content": _SYSTEM_PROMPT}, {"role": "user", "content": prompt}, ], max_tokens=512, temperature=0.2, ) text = response["choices"][0]["message"]["content"].strip() if text.startswith("```"): lines = text.split("\n") text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) parsed = json.loads(text) if isinstance(parsed, list): result_container.append([str(x) for x in parsed]) else: result_container.append([]) except Exception as e: exception.append(e) thread = threading.Thread(target=_run, daemon=True) thread.start() thread.join(timeout=_TIMEOUT_SECONDS) if thread.is_alive() or exception: if exception: logger.warning("SavingsAgent LLM failed: %s — using template", exception[0]) else: logger.warning("SavingsAgent LLM timed out — using template fallback") report = self._build_template_report( audit_run_id, invoice, all_flags, skipped_price_checks, skipped_gst_checks, unexpected_deliveries, low_confidence_photos, has_delivery_data, ) else: llm_items = result_container[0] if result_container else [] # Deduplicate while preserving order seen: set[str] = set() deduped = [] for item in llm_items: if item not in seen: seen.add(item) deduped.append(item) total = _compute_total_leakage(all_flags) report = LeakageReport( audit_run_id=audit_run_id, invoice_number=invoice.invoice_number, supplier=invoice.supplier, date=invoice.date, anomaly_flags=all_flags, total_leakage_inr=total, action_items=deduped, has_delivery_data=has_delivery_data, unexpected_deliveries=unexpected_deliveries, low_confidence_photos=low_confidence_photos, skipped_price_checks=skipped_price_checks, skipped_gst_checks=skipped_gst_checks, passed_all_checks=False, ) t_end = time.monotonic() n_flags = len(all_flags) n_items = len(report.action_items) trace = make_trace_entry( agent_name=AGENT_NAME, agent_version=AGENT_VERSION, audit_run_id=audit_run_id, t_start=t_start, t_end=t_end, input_summary=f"{n_flags} flags", output_summary=f"₹{report.total_leakage_inr:.2f} leakage; {n_items} action items", ) return report, trace