| from __future__ import annotations |
|
|
| import json |
| import logging |
| import time |
| import threading |
| from typing import List, Optional, Tuple |
|
|
| from models import AnomalyFlag, LeakageReport, NormalisedInvoice, AgentTraceEntry |
| from tracer import make_trace_entry |
|
|
# Identity reported in every trace entry this agent produces.
AGENT_NAME = "Savings_Agent"
AGENT_VERSION = "1.0.0"

# NOTE(review): not referenced in the visible part of this file; presumably
# the repository id of the model behind the injected LLM client -- confirm.
MODEL_REPO = "build-small-hackathon/minicpm5-1b-indian-fmcg-normalizer"

# Upper bound, in seconds, on how long the LLM worker thread is waited for
# before the template fallback is used.
_TIMEOUT_SECONDS = 15

# System message sent with every chat completion; it asks for a bare JSON
# array of strings, which is what the response parser expects.
_SYSTEM_PROMPT = (
    "You are an Indian kirana store auditor. "
    "Given a list of billing anomaly flags from an invoice audit, "
    "return ONLY a JSON array of concise, actionable items (strings). "
    "Each item should tell the shop owner exactly what to do to recover money. "
    "No extra text, no markdown, just the JSON array."
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
def _compute_total_leakage(flags: List[AnomalyFlag]) -> float:
    """Return the sum of ``amount_inr`` over all *flags* (``0`` when empty)."""
    amounts = [flag.amount_inr for flag in flags]
    return sum(amounts)
|
|
|
|
def _default_action_for_flag(
    flag: AnomalyFlag,
    invoice_number: Optional[str],
    supplier: Optional[str],
) -> Optional[str]:
    """Compose the canned action text for a flag that carries none of its own.

    Returns ``None`` when ``flag.flag_type`` is not one of the recognised
    types, so the caller can skip the flag.
    """
    kind = flag.flag_type
    product = flag.product_name or flag.product_id
    # Tolerate flags whose metadata is missing or None instead of crashing
    # on ``.get``; the "?" placeholders below then apply.
    metadata = getattr(flag, "metadata", None) or {}

    if kind == "price_anomaly":
        return (
            f"Contact {supplier or 'supplier'} about "
            f"{product} price increase; "
            f"request credit note for ₹{flag.amount_inr:.2f}"
        )
    if kind == "delivery_shortage":
        shortage = metadata.get("shortage_quantity", "?")
        return (
            f"Request delivery of {shortage} missing "
            f"{product} "
            f"or credit note for ₹{flag.amount_inr:.2f}"
        )
    if kind == "duplicate_charge":
        return (
            f"Dispute duplicate charge for {product}; "
            f"request refund of ₹{flag.amount_inr:.2f}"
        )
    if kind == "gst_mismatch":
        expected = metadata.get("expected_gst_rate", "?")
        return (
            f"Dispute GST overcharge on {product}; "
            f"correct rate is {expected}%"
        )
    if kind == "duplicate_invoice":
        # Avoid the double space ("Invoice  already ...") that an empty
        # invoice number would otherwise leave in the sentence.
        label = f"Invoice {invoice_number}" if invoice_number else "Invoice"
        return f"{label} already processed; do not pay again"
    return None


def _build_template_action_items(
    flags: List[AnomalyFlag],
    invoice_number: Optional[str],
    supplier: Optional[str],
) -> List[str]:
    """Generate action items purely from flag data, no LLM.

    A flag's own ``action_item`` is used when it is non-empty; otherwise a
    canned sentence is composed from the flag type. Flags of an unrecognised
    type without an ``action_item`` contribute nothing.

    Args:
        flags: Anomaly flags to turn into action items.
        invoice_number: Invoice identifier used in the duplicate-invoice
            message; may be ``None``.
        supplier: Supplier name used in the price-anomaly message; falls
            back to the word "supplier" when empty.

    Returns:
        Action items in flag order, with exact duplicates removed.
    """
    items: List[str] = []
    seen: set[str] = set()

    for flag in flags:
        action = flag.action_item or _default_action_for_flag(
            flag, invoice_number, supplier
        )
        if action and action not in seen:
            seen.add(action)
            items.append(action)

    return items
|
|
|
|
class SavingsAgent:
    """Turn audit anomaly flags into a ``LeakageReport`` with action items.

    Action items are requested from the injected LLM. When that call fails,
    times out, or yields nothing usable, they are generated from the flag
    data by ``_build_template_action_items`` instead.
    """

    def __init__(self, llm) -> None:
        """Store the chat-completion client.

        Args:
            llm: Object exposing ``create_chat_completion(messages=...,
                max_tokens=..., temperature=...)`` whose result is indexed
                as ``["choices"][0]["message"]["content"]``.
        """
        self._llm = llm

    def _build_template_report(
        self,
        audit_run_id: str,
        invoice: NormalisedInvoice,
        all_flags: List[AnomalyFlag],
        skipped_price_checks: int,
        skipped_gst_checks: int,
        unexpected_deliveries: List[str],
        low_confidence_photos: List[str],
        has_delivery_data: bool,
        action_items: Optional[List[str]] = None,
    ) -> LeakageReport:
        """Assemble the ``LeakageReport`` for one audit run.

        Args:
            audit_run_id: Identifier copied onto the report.
            invoice: Source of ``invoice_number``, ``supplier`` and ``date``.
            all_flags: Flags listed on the report and summed for the total.
            skipped_price_checks: Copied onto the report.
            skipped_gst_checks: Copied onto the report.
            unexpected_deliveries: Copied onto the report.
            low_confidence_photos: Copied onto the report.
            has_delivery_data: Copied onto the report.
            action_items: Ready-made action items. When ``None`` (the
                default) they are generated from the flags with
                ``_build_template_action_items``.

        Returns:
            The populated report; ``passed_all_checks`` is true only when
            ``all_flags`` is empty.
        """
        if action_items is None:
            action_items = _build_template_action_items(
                all_flags, invoice.invoice_number, invoice.supplier
            )
        return LeakageReport(
            audit_run_id=audit_run_id,
            invoice_number=invoice.invoice_number,
            supplier=invoice.supplier,
            date=invoice.date,
            anomaly_flags=all_flags,
            total_leakage_inr=_compute_total_leakage(all_flags),
            action_items=action_items,
            has_delivery_data=has_delivery_data,
            unexpected_deliveries=unexpected_deliveries,
            low_confidence_photos=low_confidence_photos,
            skipped_price_checks=skipped_price_checks,
            skipped_gst_checks=skipped_gst_checks,
            passed_all_checks=not all_flags,
        )

    @staticmethod
    def _build_prompt(
        invoice: NormalisedInvoice, all_flags: List[AnomalyFlag]
    ) -> str:
        """Render the user message: invoice header plus one line per flag."""
        flag_lines = [
            f"- [{f.flag_type}] {f.product_name or f.product_id}: ₹{f.amount_inr:.2f} — {f.description}"
            for f in all_flags
        ]
        return (
            f"Invoice: {invoice.invoice_number or 'N/A'} from {invoice.supplier or 'unknown supplier'}\n"
            f"Anomalies found:\n" + "\n".join(flag_lines) +
            "\n\nReturn a JSON array of action items."
        )

    @staticmethod
    def _parse_action_items(text: str) -> List[str]:
        """Extract action items from the raw model reply.

        A leading Markdown code fence (and its closing fence, if present) is
        removed before the text is parsed as JSON.

        Returns:
            The non-blank string entries of the JSON array, stripped and
            de-duplicated in order. An empty list when the JSON value is not
            an array or holds no usable strings.

        Raises:
            ValueError: If the text is not valid JSON
                (``json.JSONDecodeError``).
        """
        text = text.strip()
        if text.startswith("```"):
            lines = text.split("\n")
            inner = lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
            text = "\n".join(inner)

        parsed = json.loads(text)
        if not isinstance(parsed, list):
            return []

        items: List[str] = []
        seen: set[str] = set()
        for entry in parsed:
            # Skip nulls, numbers and nested objects rather than showing
            # their repr() to the shop owner.
            if not isinstance(entry, str):
                continue
            item = entry.strip()
            if item and item not in seen:
                seen.add(item)
                items.append(item)
        return items

    def _request_llm_items(self, prompt: str) -> Optional[List[str]]:
        """Ask the LLM for action items, waiting at most ``_TIMEOUT_SECONDS``.

        Returns:
            The parsed items (possibly empty) when the call completed, or
            ``None`` when it raised or did not finish in time.
        """
        results: List[List[str]] = []
        errors: List[Exception] = []

        def _run() -> None:
            try:
                response = self._llm.create_chat_completion(
                    messages=[
                        {"role": "system", "content": _SYSTEM_PROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=512,
                    temperature=0.2,
                )
                content = response["choices"][0]["message"]["content"]
                results.append(self._parse_action_items(content))
            except Exception as exc:
                # Any failure (client error, malformed reply, bad JSON) just
                # means the template fallback is used.
                errors.append(exc)

        # The call runs in a daemon thread so a hung model cannot block the
        # audit; on timeout the thread is left running, not cancelled.
        worker = threading.Thread(target=_run, daemon=True)
        worker.start()
        worker.join(timeout=_TIMEOUT_SECONDS)

        if errors:
            logger.warning("SavingsAgent LLM failed: %s — using template", errors[0])
            return None
        if worker.is_alive():
            logger.warning("SavingsAgent LLM timed out — using template fallback")
            return None
        return results[0] if results else []

    def generate_report(
        self,
        audit_run_id: str,
        invoice: NormalisedInvoice,
        all_flags: List[AnomalyFlag],
        skipped_price_checks: int = 0,
        skipped_gst_checks: int = 0,
        unexpected_deliveries: Optional[List[str]] = None,
        low_confidence_photos: Optional[List[str]] = None,
        has_delivery_data: bool = False,
    ) -> Tuple[LeakageReport, AgentTraceEntry]:
        """Build the leakage report and its trace entry for one audit run.

        With no flags the report is produced directly and the LLM is not
        called. Otherwise the LLM is asked for action items, and the
        template items are used when it fails, times out, or returns no
        usable items.

        Args:
            audit_run_id: Identifier copied onto the report and the trace.
            invoice: The audited invoice.
            all_flags: Every anomaly flag raised for the invoice.
            skipped_price_checks: Copied onto the report.
            skipped_gst_checks: Copied onto the report.
            unexpected_deliveries: Copied onto the report; ``None`` becomes
                an empty list.
            low_confidence_photos: Copied onto the report; ``None`` becomes
                an empty list.
            has_delivery_data: Copied onto the report.

        Returns:
            ``(report, trace)``, where the trace covers the time spent in
            this call.
        """
        unexpected_deliveries = unexpected_deliveries or []
        low_confidence_photos = low_confidence_photos or []
        t_start = time.monotonic()

        # None means "let _build_template_report generate template items".
        action_items: Optional[List[str]] = None
        if all_flags:
            llm_items = self._request_llm_items(
                self._build_prompt(invoice, all_flags)
            )
            if llm_items:
                action_items = llm_items
            elif llm_items is not None:
                # The call succeeded but gave nothing usable; a report with
                # flags and no guidance is worse than the canned guidance.
                logger.warning(
                    "SavingsAgent LLM returned no usable action items — using template fallback"
                )

        report = self._build_template_report(
            audit_run_id, invoice, all_flags,
            skipped_price_checks, skipped_gst_checks,
            unexpected_deliveries, low_confidence_photos, has_delivery_data,
            action_items=action_items,
        )
        t_end = time.monotonic()

        if all_flags:
            n_flags = len(all_flags)
            n_items = len(report.action_items)
            input_summary = f"{n_flags} flags"
            output_summary = f"₹{report.total_leakage_inr:.2f} leakage; {n_items} action items"
        else:
            input_summary = "0 flags"
            output_summary = "passed all checks; no LLM call"

        trace = make_trace_entry(
            agent_name=AGENT_NAME,
            agent_version=AGENT_VERSION,
            audit_run_id=audit_run_id,
            t_start=t_start,
            t_end=t_end,
            input_summary=input_summary,
            output_summary=output_summary,
        )
        return report, trace
|
|