from __future__ import annotations import json import logging import time import threading from pathlib import Path from typing import Tuple from models import InvoiceJSON, InvoiceLineItem, AgentTraceEntry from tracer import make_trace_entry logger = logging.getLogger(__name__) AGENT_NAME = "Invoice_Extractor" AGENT_VERSION = "1.0.0" MODEL_REPO = "build-small-hackathon/minicpm-v-4-6-indian-invoice-extraction-merged" _MAX_FILE_BYTES = 20 * 1024 * 1024 # 20 MB _ALLOWED_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp", ".pdf"} _TIMEOUT_SECONDS = 120 _EXTRACT_PROMPT = ( "You are an OCR agent for Indian kirana store invoices. " "Extract all information from this invoice image and return ONLY valid JSON " "matching this schema exactly:\n" '{"invoice_number": string|null, "supplier": string|null, "date": string|null, ' '"items": [{"product_raw": string, "quantity": number, "unit_price": number, ' '"gst_rate": number, "line_total": number}], ' '"grand_total": number, "extraction_warnings": [string]}\n' "Rules:\n" "- quantity/unit_price/gst_rate/line_total must be numbers (0 if unknown)\n" "- gst_rate is a percentage (5, 12, 18, 28, etc.) not a decimal\n" "- date format: YYYY-MM-DD if parseable, else raw string\n" "- extraction_warnings: list issues you notice\n" "Return ONLY the JSON object, no markdown, no prose." ) _RETRY_PROMPT = ( "Your previous response was not valid JSON. " "Return ONLY the JSON object with this exact schema:\n" '{"invoice_number": null, "supplier": null, "date": null, ' '"items": [], "grand_total": 0, "extraction_warnings": ["parse error"]}\n' "Try again with the invoice image." ) class ValidationError(ValueError): pass def _render_pdf_pages(file_path: str) -> list[bytes]: """Render each PDF page to PNG bytes.""" import fitz # PyMuPDF doc = fitz.open(file_path) pages = [] for page in doc: mat = fitz.Matrix(2.0, 2.0) # 2x zoom → ~150 dpi equivalent pix = page.get_pixmap(matrix=mat) pages.append(pix.tobytes("png")) doc.close() return pages def _parse_llm_json(text: str) -> dict: text = text.strip() # Strip markdown fences if present if text.startswith("```"): lines = text.split("\n") text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) return json.loads(text) def _dict_to_invoice(data: dict) -> InvoiceJSON: items = [] for raw in data.get("items", []): items.append(InvoiceLineItem( product_raw=str(raw.get("product_raw", "")), quantity=float(raw.get("quantity", 0)), unit_price=float(raw.get("unit_price", 0)), gst_rate=float(raw.get("gst_rate", 0)), line_total=float(raw.get("line_total", 0)), )) return InvoiceJSON( invoice_number=data.get("invoice_number"), supplier=data.get("supplier"), date=data.get("date"), items=items, grand_total=float(data.get("grand_total", 0)), extraction_warnings=list(data.get("extraction_warnings", [])), ) def _call_llm_with_image(llm, image_bytes: bytes, prompt: str) -> str: """Call MiniCPM-V 4.6 through the current processor + generate API.""" import io import torch from PIL import Image as PILImage model, processor = llm image = PILImage.open(io.BytesIO(image_bytes)).convert("RGB") messages = [{ "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": prompt}, ], }] downsample_mode = "16x" try: inputs = processor.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", downsample_mode=downsample_mode, max_slice_nums=36, ) except TypeError: inputs = processor.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", ) downsample_mode = None inputs = inputs.to(model.device) generate_kwargs = { **inputs, "max_new_tokens": 2048, "do_sample": False, } if downsample_mode is not None: generate_kwargs["downsample_mode"] = downsample_mode with torch.inference_mode(): try: generated_ids = model.generate(**generate_kwargs) except TypeError: generate_kwargs.pop("downsample_mode", None) generated_ids = model.generate(**generate_kwargs) prompt_len = inputs["input_ids"].shape[-1] generated_ids = generated_ids[:, prompt_len:] decoded = processor.batch_decode( generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False, ) return decoded[0].strip() def _extract_from_image(llm, image_bytes: bytes) -> Tuple[InvoiceJSON, list[str]]: """Try extraction; retry once on JSON parse failure.""" warnings = [] try: raw = _call_llm_with_image(llm, image_bytes, _EXTRACT_PROMPT) data = _parse_llm_json(raw) return _dict_to_invoice(data), warnings except (json.JSONDecodeError, ValueError): warnings.append("First LLM response was not valid JSON; retrying") try: raw = _call_llm_with_image(llm, image_bytes, _RETRY_PROMPT) data = _parse_llm_json(raw) return _dict_to_invoice(data), warnings except (json.JSONDecodeError, ValueError) as e: warnings.append(f"Second LLM response also invalid: {e}") return InvoiceJSON(extraction_warnings=warnings), warnings def _merge_invoices(pages: list[InvoiceJSON]) -> InvoiceJSON: """Merge multi-page PDF results into one InvoiceJSON.""" if not pages: return InvoiceJSON() base = pages[0] for page in pages[1:]: base.items.extend(page.items) base.extraction_warnings.extend(page.extraction_warnings) if base.grand_total == 0 and page.grand_total: base.grand_total = page.grand_total return base class InvoiceExtractorAgent: def __init__(self, llm) -> None: self._llm = llm def extract(self, file_path: str, audit_run_id: str) -> Tuple[InvoiceJSON, AgentTraceEntry]: path = Path(file_path) suffix = path.suffix.lower() # Validate format if suffix not in _ALLOWED_SUFFIXES: raise ValidationError("UNSUPPORTED_FORMAT") # Validate size size_bytes = path.stat().st_size if size_bytes > _MAX_FILE_BYTES: raise ValidationError("FILE_TOO_LARGE") size_kb = size_bytes // 1024 fmt = suffix.lstrip(".") t_start = time.monotonic() result: list[InvoiceJSON] = [] exception: list[Exception] = [] def _run(): try: if suffix == ".pdf": pages_bytes = _render_pdf_pages(file_path) page_invoices = [] for pb in pages_bytes: inv, _ = _extract_from_image(self._llm, pb) page_invoices.append(inv) result.append(_merge_invoices(page_invoices)) else: image_bytes = path.read_bytes() inv, _ = _extract_from_image(self._llm, image_bytes) result.append(inv) except Exception as e: exception.append(e) thread = threading.Thread(target=_run, daemon=True) thread.start() thread.join(timeout=_TIMEOUT_SECONDS) t_end = time.monotonic() if thread.is_alive(): # Timeout: return partial/empty with warning invoice = InvoiceJSON( extraction_warnings=["Extraction timed out after 30 seconds"] ) elif exception: raise exception[0] else: invoice = result[0] n_items = len(invoice.items) n_warnings = len(invoice.extraction_warnings) trace = make_trace_entry( agent_name=AGENT_NAME, agent_version=AGENT_VERSION, audit_run_id=audit_run_id, t_start=t_start, t_end=t_end, input_summary=f"{size_kb}KB {fmt}", output_summary=f"{n_items} line items extracted; {n_warnings} warnings", ) return invoice, trace