# Source: kirana-detective / agents / savings_agent.py
# Commit 9d75c8c: "All models training uploaded" (naazimsnh02), 9.05 kB
from __future__ import annotations

import json
import logging
import math
import threading
import time
from typing import List, Optional, Tuple

from models import AnomalyFlag, LeakageReport, NormalisedInvoice, AgentTraceEntry
from tracer import make_trace_entry
# Module-level logger; records LLM failures and timeouts as warnings.
logger = logging.getLogger(__name__)

# Identity of this agent as recorded in every trace entry it emits.
AGENT_NAME = "Savings_Agent"
AGENT_VERSION = "1.0.0"

# NOTE(review): not referenced anywhere in this module — presumably the
# repository of the model behind the injected ``llm``; confirm with the caller.
MODEL_REPO = "build-small-hackathon/minicpm5-1b-indian-fmcg-normalizer"

# Maximum time, in seconds, that generate_report waits for the LLM worker
# thread before it falls back to template-generated action items.
_TIMEOUT_SECONDS = 15

# System message sent with every chat completion; it instructs the model to
# answer with nothing but a JSON array of action-item strings.
_SYSTEM_PROMPT = (
    "You are an Indian kirana store auditor. "
    "Given a list of billing anomaly flags from an invoice audit, "
    "return ONLY a JSON array of concise, actionable items (strings). "
    "Each item should tell the shop owner exactly what to do to recover money. "
    "No extra text, no markdown, just the JSON array."
)
def _compute_total_leakage(flags: List[AnomalyFlag]) -> float:
    """Return the sum of ``amount_inr`` over all flags, as a float.

    ``math.fsum`` is used instead of ``sum`` so that the result is always a
    ``float`` (``0.0`` for an empty list, matching the annotation) and so
    that adding many rupee amounts does not accumulate rounding error.

    Args:
        flags: Anomaly flags; each must expose a numeric ``amount_inr``.

    Returns:
        The total leakage in INR.
    """
    return math.fsum(flag.amount_inr for flag in flags)
def _build_template_action_items(
    flags: List[AnomalyFlag],
    invoice_number: Optional[str],
    supplier: Optional[str],
) -> List[str]:
    """Generate action items purely from flag data, no LLM.

    A flag's own ``action_item`` is used when it is non-empty; otherwise a
    sentence is built from a template selected by ``flag_type``. Flags with
    no ``action_item`` and an unrecognised ``flag_type`` contribute nothing.

    Args:
        flags: Anomaly flags to convert into action items.
        invoice_number: Invoice identifier, used in the duplicate-invoice
            message; may be ``None``.
        supplier: Supplier name, used in the price-anomaly message; falls
            back to the word "supplier" when missing.

    Returns:
        Action items in first-seen order, with exact duplicates removed.
    """
    items: List[str] = []
    seen: set[str] = set()
    for f in flags:
        action = f.action_item
        if not action:
            product = f.product_name or f.product_id
            # A flag may carry no metadata at all; treat that as empty so the
            # "?" placeholders below are used instead of raising.
            metadata = f.metadata or {}
            if f.flag_type == "price_anomaly":
                action = (
                    f"Contact {supplier or 'supplier'} about "
                    f"{product} price increase; "
                    f"request credit note for ₹{f.amount_inr:.2f}"
                )
            elif f.flag_type == "delivery_shortage":
                shortage = metadata.get("shortage_quantity", "?")
                action = (
                    f"Request delivery of {shortage} missing "
                    f"{product} "
                    f"or credit note for ₹{f.amount_inr:.2f}"
                )
            elif f.flag_type == "duplicate_charge":
                action = (
                    f"Dispute duplicate charge for {product}; "
                    f"request refund of ₹{f.amount_inr:.2f}"
                )
            elif f.flag_type == "gst_mismatch":
                expected = metadata.get("expected_gst_rate", "?")
                action = (
                    f"Dispute GST overcharge on {product}; "
                    f"correct rate is {expected}%"
                )
            elif f.flag_type == "duplicate_invoice":
                # Only mention the number when it is known, so a missing
                # number does not leave a double space in the sentence.
                subject = f"Invoice {invoice_number}" if invoice_number else "Invoice"
                action = f"{subject} already processed; do not pay again"
        if action and action not in seen:
            seen.add(action)
            items.append(action)
    return items
class SavingsAgent:
    """Turn audit anomaly flags into a leakage report with action items.

    Action items are requested from the injected LLM. When the call fails,
    times out, or yields nothing usable, they are derived from the flags by
    ``_build_template_action_items`` instead, so a report with flags never
    ends up without action items merely because the model misbehaved.
    """

    def __init__(self, llm) -> None:
        """Store the LLM client.

        Args:
            llm: Object exposing ``create_chat_completion(messages=...,
                max_tokens=..., temperature=...)`` whose result provides
                ``["choices"][0]["message"]["content"]``.
        """
        self._llm = llm

    @staticmethod
    def _assemble_report(
        audit_run_id: str,
        invoice: NormalisedInvoice,
        all_flags: List[AnomalyFlag],
        action_items: List[str],
        skipped_price_checks: int,
        skipped_gst_checks: int,
        unexpected_deliveries: List[str],
        low_confidence_photos: List[str],
        has_delivery_data: bool,
    ) -> LeakageReport:
        """Build a LeakageReport from the audit data and given action items."""
        return LeakageReport(
            audit_run_id=audit_run_id,
            invoice_number=invoice.invoice_number,
            supplier=invoice.supplier,
            date=invoice.date,
            anomaly_flags=all_flags,
            total_leakage_inr=_compute_total_leakage(all_flags),
            action_items=action_items,
            has_delivery_data=has_delivery_data,
            unexpected_deliveries=unexpected_deliveries,
            low_confidence_photos=low_confidence_photos,
            skipped_price_checks=skipped_price_checks,
            skipped_gst_checks=skipped_gst_checks,
            passed_all_checks=(len(all_flags) == 0),
        )

    def _build_template_report(
        self,
        audit_run_id: str,
        invoice: NormalisedInvoice,
        all_flags: List[AnomalyFlag],
        skipped_price_checks: int,
        skipped_gst_checks: int,
        unexpected_deliveries: List[str],
        low_confidence_photos: List[str],
        has_delivery_data: bool,
    ) -> LeakageReport:
        """Build a report whose action items come from templates, not the LLM."""
        action_items = _build_template_action_items(
            all_flags, invoice.invoice_number, invoice.supplier
        )
        return self._assemble_report(
            audit_run_id, invoice, all_flags, action_items,
            skipped_price_checks, skipped_gst_checks,
            unexpected_deliveries, low_confidence_photos, has_delivery_data,
        )

    @staticmethod
    def _parse_action_items(text: str) -> List[str]:
        """Parse the model reply into a de-duplicated list of action items.

        A leading Markdown code fence (and a closing one, if present) is
        removed before the text is decoded as JSON.

        Returns:
            The non-blank entries of the decoded JSON array as stripped
            strings, in first-seen order; an empty list when the decoded
            value is not an array.

        Raises:
            ValueError: If the text is not valid JSON.
        """
        text = text.strip()
        if text.startswith("```"):
            lines = text.split("\n")
            inner = lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
            text = "\n".join(inner)
        parsed = json.loads(text)
        if not isinstance(parsed, list):
            return []
        items: List[str] = []
        seen: set[str] = set()
        for raw in parsed:
            if raw is None:
                continue
            item = str(raw).strip()
            if item and item not in seen:
                seen.add(item)
                items.append(item)
        return items

    def _request_llm_items(self, prompt: str) -> Optional[List[str]]:
        """Ask the LLM for action items, waiting at most _TIMEOUT_SECONDS.

        Returns:
            The parsed items (possibly empty) on success, or ``None`` when
            the call raised or did not finish in time; both failures are
            logged as warnings.
        """
        results: List[List[str]] = []
        errors: List[Exception] = []

        def _run() -> None:
            try:
                response = self._llm.create_chat_completion(
                    messages=[
                        {"role": "system", "content": _SYSTEM_PROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=512,
                    temperature=0.2,
                )
                content = response["choices"][0]["message"]["content"]
                results.append(self._parse_action_items(content))
            except Exception as exc:  # any LLM/parsing failure triggers the fallback
                errors.append(exc)

        # The call runs in a worker thread so a slow model cannot stall the
        # audit. A timed-out worker cannot be cancelled: it keeps running in
        # the background, and daemon=True only ensures it never blocks
        # interpreter shutdown.
        worker = threading.Thread(target=_run, daemon=True)
        worker.start()
        worker.join(timeout=_TIMEOUT_SECONDS)

        if errors:
            logger.warning("SavingsAgent LLM failed: %s — using template", errors[0])
            return None
        if worker.is_alive():
            logger.warning("SavingsAgent LLM timed out — using template fallback")
            return None
        return results[0] if results else []

    def generate_report(
        self,
        audit_run_id: str,
        invoice: NormalisedInvoice,
        all_flags: List[AnomalyFlag],
        skipped_price_checks: int = 0,
        skipped_gst_checks: int = 0,
        unexpected_deliveries: Optional[List[str]] = None,
        low_confidence_photos: Optional[List[str]] = None,
        has_delivery_data: bool = False,
    ) -> Tuple[LeakageReport, AgentTraceEntry]:
        """Produce the leakage report and a trace entry for one audit run.

        With no flags the report is built immediately and the LLM is never
        called. Otherwise the LLM is asked for action items, and template
        items are used if it fails, times out, or returns no usable item.

        Args:
            audit_run_id: Identifier copied into the report and the trace.
            invoice: Invoice whose number, supplier and date go in the report.
            all_flags: Every anomaly flag raised for the invoice.
            skipped_price_checks: Count passed through to the report.
            skipped_gst_checks: Count passed through to the report.
            unexpected_deliveries: Passed through; ``None`` becomes ``[]``.
            low_confidence_photos: Passed through; ``None`` becomes ``[]``.
            has_delivery_data: Passed through to the report.

        Returns:
            ``(report, trace)``.
        """
        unexpected_deliveries = unexpected_deliveries or []
        low_confidence_photos = low_confidence_photos or []
        t_start = time.monotonic()

        # No flags → pass immediately without LLM
        if not all_flags:
            report = self._build_template_report(
                audit_run_id, invoice, all_flags,
                skipped_price_checks, skipped_gst_checks,
                unexpected_deliveries, low_confidence_photos, has_delivery_data,
            )
            t_end = time.monotonic()
            trace = make_trace_entry(
                agent_name=AGENT_NAME,
                agent_version=AGENT_VERSION,
                audit_run_id=audit_run_id,
                t_start=t_start,
                t_end=t_end,
                input_summary="0 flags",
                output_summary="passed all checks; no LLM call",
            )
            return report, trace

        # Build LLM prompt
        flag_lines = []
        for f in all_flags:
            line = f"- [{f.flag_type}] {f.product_name or f.product_id}: ₹{f.amount_inr:.2f}"
            # Separate the description from the amount; without a separator
            # "₹12.50" followed by "2 units short" reads as "₹12.502 units short".
            if f.description:
                line += f" — {f.description}"
            flag_lines.append(line)
        prompt = (
            f"Invoice: {invoice.invoice_number or 'N/A'} from {invoice.supplier or 'unknown supplier'}\n"
            f"Anomalies found:\n" + "\n".join(flag_lines) +
            "\n\nReturn a JSON array of action items."
        )

        llm_items = self._request_llm_items(prompt)
        if llm_items:
            report = self._assemble_report(
                audit_run_id, invoice, all_flags, llm_items,
                skipped_price_checks, skipped_gst_checks,
                unexpected_deliveries, low_confidence_photos, has_delivery_data,
            )
        else:
            if llm_items is not None:
                # The call succeeded but produced an empty or non-array reply.
                logger.warning(
                    "SavingsAgent LLM returned no usable action items — using template fallback"
                )
            report = self._build_template_report(
                audit_run_id, invoice, all_flags,
                skipped_price_checks, skipped_gst_checks,
                unexpected_deliveries, low_confidence_photos, has_delivery_data,
            )

        t_end = time.monotonic()
        n_flags = len(all_flags)
        n_items = len(report.action_items)
        trace = make_trace_entry(
            agent_name=AGENT_NAME,
            agent_version=AGENT_VERSION,
            audit_run_id=audit_run_id,
            t_start=t_start,
            t_end=t_end,
            input_summary=f"{n_flags} flags",
            output_summary=f"₹{report.total_leakage_inr:.2f} leakage; {n_items} action items",
        )
        return report, trace