from __future__ import annotations

import logging
import time
from collections import Counter
from typing import List, Optional, Tuple

from models import (
    AnomalyFlag,
    InvoiceLineItem,
    NormalisedInvoice,
    AgentTraceEntry,
    PricingAgentMeta,
)
from catalog import FMCGCatalog
from tracer import make_trace_entry

logger = logging.getLogger(__name__)

AGENT_NAME = "Pricing_Agent"
AGENT_VERSION = "1.0.0"


class PricingAgent:
    """Run pricing checks over a normalised invoice and emit anomaly flags.

    Four checks are performed: unit price versus stored price history, GST
    rate versus the catalog rate, repeated line items within one invoice,
    and an invoice number that storage reports as already processed.
    """

    def __init__(
        self,
        storage,
        catalog: FMCGCatalog,
        price_anomaly_threshold: float = 1.05,
    ) -> None:
        """Store the collaborators used by the checks.

        Args:
            storage: Object used here through ``available``,
                ``invoice_number_exists``, ``get_price_history`` and
                ``save_price_history``.
            catalog: Product catalog, used through ``get_by_id`` and
                ``get_gst_rate``.
            price_anomaly_threshold: Multiplier applied to the highest
                historical price; a unit price strictly above
                ``max_historical * threshold`` is flagged.
        """
        self._storage = storage
        self._catalog = catalog
        self._threshold = price_anomaly_threshold

    # ── Individual checks ─────────────────────────────────────────────────────

    def _check_price_anomaly(
        self,
        item: InvoiceLineItem,
        historical_prices: List[float],
        meta: PricingAgentMeta,
    ) -> Optional[AnomalyFlag]:
        """Flag a line whose unit price exceeds the historical maximum.

        Args:
            item: Line item being checked.
            historical_prices: Previously seen unit prices for the product.
            meta: Run counters; ``skipped_price_checks`` is incremented when
                the check cannot be performed.

        Returns:
            A ``price_anomaly`` flag, or ``None`` when the price is within
            the threshold or the check was skipped.
        """
        if not historical_prices or item.product_id is None:
            meta.skipped_price_checks += 1
            return None

        max_hist = max(historical_prices)
        if max_hist == 0:
            # A zero baseline would make the percentage increase undefined.
            meta.skipped_price_checks += 1
            return None

        if item.unit_price > max_hist * self._threshold:
            name = item.product_normalized or item.product_raw
            excess_per_unit = item.unit_price - max_hist
            excess_total = item.quantity * excess_per_unit
            pct = (item.unit_price / max_hist - 1.0) * 100
            return AnomalyFlag(
                flag_type="price_anomaly",
                product_id=item.product_id,
                product_name=name,
                amount_inr=excess_total,
                description=(
                    f"{name}: charged ₹{item.unit_price:.2f}, "
                    f"max historical ₹{max_hist:.2f} (+{pct:.1f}%)"
                ),
                action_item=(
                    f"Contact supplier about {name} "
                    f"price increase; request credit note for ₹{excess_total:.2f}"
                ),
                metadata={
                    "current_unit_price": item.unit_price,
                    "max_historical_price": max_hist,
                    "pct_increase": round(pct, 2),
                    "quantity": item.quantity,
                },
            )
        return None

    def _check_gst_mismatch(
        self,
        item: InvoiceLineItem,
        meta: PricingAgentMeta,
    ) -> Optional[AnomalyFlag]:
        """Flag a line whose GST rate differs from the catalog rate.

        A difference in either direction (beyond a 0.001 tolerance) is
        flagged. ``amount_inr`` is signed: positive when the supplier charged
        more GST than expected, negative when it charged less. The suggested
        action is worded according to that direction.

        Args:
            item: Line item being checked.
            meta: Run counters; ``skipped_gst_checks`` is incremented when the
                product or its expected rate cannot be resolved.

        Returns:
            A ``gst_mismatch`` flag, or ``None`` when the rates agree or the
            check was skipped.
        """
        if item.product_id is None:
            meta.skipped_gst_checks += 1
            return None

        entry = self._catalog.get_by_id(item.product_id)
        if entry is None:
            meta.skipped_gst_checks += 1
            return None

        expected = self._catalog.get_gst_rate(entry.hsn_code)
        if expected is None:
            meta.skipped_gst_checks += 1
            return None

        rate_delta = item.gst_rate - expected
        if abs(rate_delta) > 0.001:
            name = item.product_normalized or item.product_raw
            overcharge_per_unit = item.unit_price * rate_delta / 100.0
            amount = item.quantity * overcharge_per_unit
            if rate_delta > 0:
                action = (
                    f"Dispute GST overcharge on {name}; "
                    f"correct rate is {expected}%"
                )
            else:
                # The supplier applied a lower rate than the catalog expects;
                # advising an "overcharge" dispute here would be misleading.
                action = (
                    f"Request corrected invoice for {name}; "
                    f"GST under-charged, correct rate is {expected}%"
                )
            return AnomalyFlag(
                flag_type="gst_mismatch",
                product_id=item.product_id,
                product_name=name,
                amount_inr=amount,
                description=(
                    f"GST on {name}: "
                    f"charged {item.gst_rate}%, expected {expected}%"
                ),
                action_item=action,
                metadata={
                    "expected_gst_rate": expected,
                    "charged_gst_rate": item.gst_rate,
                    "quantity": item.quantity,
                    "unit_price": item.unit_price,
                },
            )
        return None

    def _check_duplicate_charges(
        self, items: List[InvoiceLineItem]
    ) -> List[AnomalyFlag]:
        """Flag products billed more than once with identical qty and price.

        Lines are grouped by (normalised product id, quantity, unit price);
        lines without a product id are ignored. Each group with two or more
        lines yields one flag whose amount covers the extra occurrences.

        Args:
            items: All line items of the invoice.

        Returns:
            One ``duplicate_charge`` flag per repeated group, possibly empty.
        """
        counts: Counter = Counter()
        key_meta: dict = {}
        for item in items:
            if item.product_id is None:
                continue
            key = (
                item.product_id.strip().lower(),
                float(item.quantity),
                float(item.unit_price),
            )
            counts[key] += 1
            # The last line seen for a key is the one used for reporting.
            key_meta[key] = item

        flags = []
        for key, count in counts.items():
            if count >= 2:
                item = key_meta[key]
                name = item.product_normalized or item.product_raw
                extra = count - 1
                refund = extra * item.quantity * item.unit_price
                flags.append(AnomalyFlag(
                    flag_type="duplicate_charge",
                    product_id=item.product_id,
                    product_name=name,
                    amount_inr=refund,
                    description=(
                        f"{name} charged "
                        f"{count}× (qty {item.quantity} @ ₹{item.unit_price:.2f})"
                    ),
                    action_item=(
                        f"Dispute duplicate charge for {name}; "
                        f"request refund of ₹{refund:.2f}"
                    ),
                    metadata={
                        "occurrences": count,
                        "quantity": item.quantity,
                        "unit_price": item.unit_price,
                        "duplicate_key": str(key),
                    },
                ))
        return flags

    def _check_duplicate_invoice(
        self, invoice_number: Optional[str]
    ) -> Optional[AnomalyFlag]:
        """Flag an invoice number that storage reports as already processed.

        Args:
            invoice_number: Invoice number to look up; the check is skipped
                when it is empty or storage is unavailable.

        Returns:
            A ``duplicate_invoice`` flag carrying the prior audit run id, or
            ``None``.
        """
        if not self._storage.available or not invoice_number:
            return None

        prior_id = self._storage.invoice_number_exists(invoice_number)
        if prior_id:
            return AnomalyFlag(
                flag_type="duplicate_invoice",
                amount_inr=0.0,
                description=f"Invoice {invoice_number} was already processed in audit {prior_id}",
                action_item=f"Invoice {invoice_number} already processed; do not pay again",
                metadata={
                    "original_audit_run_id": prior_id,
                },
            )
        return None

    # ── Main entry ────────────────────────────────────────────────────────────

    def run(
        self, invoice: NormalisedInvoice, audit_run_id: str
    ) -> Tuple[List[AnomalyFlag], PricingAgentMeta, AgentTraceEntry]:
        """Run every pricing check on *invoice*.

        Args:
            invoice: Invoice to audit; ``invoice_number`` and ``items`` are
                read.
            audit_run_id: Identifier of the current audit run, passed to
                storage and to the trace entry.

        Returns:
            A tuple of (flags raised, skip counters, trace entry for this
            agent run).
        """
        t_start = time.monotonic()
        meta = PricingAgentMeta()
        flags: List[AnomalyFlag] = []

        # Duplicate invoice check (needs storage)
        dup_inv_flag = self._check_duplicate_invoice(invoice.invoice_number)
        if dup_inv_flag:
            flags.append(dup_inv_flag)

        for item in invoice.items:
            # Price anomaly (needs storage history)
            if self._storage.available:
                hist = (
                    self._storage.get_price_history(item.product_id)
                    if item.product_id
                    else []
                )
                flag = self._check_price_anomaly(item, hist, meta)
                if flag:
                    flags.append(flag)
            else:
                meta.skipped_price_checks += 1

            # GST mismatch (rule-based, no storage needed)
            flag = self._check_gst_mismatch(item, meta)
            if flag:
                flags.append(flag)

        # Duplicate charges (in-memory, no storage needed)
        flags.extend(self._check_duplicate_charges(invoice.items))

        # Persist prices after checks so this run doesn't self-flag
        if self._storage.available:
            self._storage.save_price_history(audit_run_id, invoice)

        t_end = time.monotonic()
        n_flags = len(flags)
        storage_state = "available" if self._storage.available else "degraded"
        trace = make_trace_entry(
            agent_name=AGENT_NAME,
            agent_version=AGENT_VERSION,
            audit_run_id=audit_run_id,
            t_start=t_start,
            t_end=t_end,
            input_summary=f"{len(invoice.items)} items; storage={storage_state}",
            output_summary=(
                f"{n_flags} flags; {meta.skipped_price_checks} price checks skipped; "
                f"{meta.skipped_gst_checks} GST checks skipped"
            ),
        )
        return flags, meta, trace