from __future__ import annotations import asyncio import dataclasses import json import logging import os import shutil import tempfile import threading from pathlib import Path from queue import Empty, Queue from typing import List, Optional import gradio as gr from fastapi import Request from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from catalog import get_catalog from storage import StorageManager from tracer import AgentTracer from agents.invoice_extractor import InvoiceExtractorAgent from agents.product_matcher import ProductMatcherAgent from agents.pricing_agent import PricingAgent from agents.visual_counter import VisualCounterAgent from agents.reconciliation_agent import ReconciliationAgent from agents.savings_agent import SavingsAgent from pipeline import AuditOrchestrator logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s") logger = logging.getLogger(__name__) # ── Global state ────────────────────────────────────────────────────────────── MODELS_LOADED = False ORCHESTRATOR: Optional[AuditOrchestrator] = None STORAGE: Optional[StorageManager] = None _HF_TOKEN = os.getenv("HF_TOKEN", "") _CATALOG = get_catalog() # ── Model loading ───────────────────────────────────────────────────────────── def load_models() -> None: global MODELS_LOADED, ORCHESTRATOR, STORAGE STORAGE = StorageManager() STORAGE.initialise_schema() if not STORAGE.available: logger.warning("Storage unavailable — running without price history / audit log") tracer = AgentTracer(hf_token=_HF_TOKEN or None) vision_llm = None text_llm = None ort_session = None class_names: List[str] = [] try: from huggingface_hub import hf_hub_download from llama_cpp import Llama import onnxruntime as ort logger.info("Downloading vision model (MiniCPM-V 4.6 merged)…") import os import torch from transformers import AutoProcessor _BASE_REPO = "openbmb/MiniCPM-V-4.6" _MERGED_REPO = "build-small-hackathon/minicpm-v-4-6-indian-invoice-extraction-merged" _dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 _device = "cuda" if torch.cuda.is_available() else "cpu" try: from transformers import AutoModelForImageTextToText as _VisionModel except ImportError: from transformers import AutoModelForMultimodalLM as _VisionModel # The merged repo is a fully-merged model (not a LoRA delta) — load it directly. # Loading base + overlaying weights fails because the repos use different param naming. logger.info("Loading merged vision model from %s …", _MERGED_REPO) _model_kwargs = { "trust_remote_code": True, "torch_dtype": _dtype, } if _HF_TOKEN: _model_kwargs["token"] = _HF_TOKEN if torch.cuda.is_available(): _model_kwargs["device_map"] = "auto" _vision_model = _VisionModel.from_pretrained(_MERGED_REPO, **_model_kwargs) if not torch.cuda.is_available(): _vision_model.to(_device) _vision_model.eval() _processor_kwargs = {"trust_remote_code": True} if _HF_TOKEN: _processor_kwargs["token"] = _HF_TOKEN # Load processor from base repo — has complete preprocessor/chat-template configs. _vision_processor = AutoProcessor.from_pretrained(_BASE_REPO, **_processor_kwargs) vision_llm = (_vision_model, _vision_processor) logger.info("Vision LLM ready (device=%s dtype=%s)", _device, _dtype) logger.info("Downloading text model (MiniCPM5-1B)…") text_model_path = hf_hub_download( repo_id="build-small-hackathon/minicpm5-1b-indian-fmcg-normalizer", filename="MiniCPM5-1B.Q4_K_M.gguf", ) text_llm = Llama( model_path=text_model_path, n_ctx=8192, n_threads=4, verbose=False, ) logger.info("Text LLM ready") logger.info("Downloading YOLO model…") onnx_path = hf_hub_download( repo_id="build-small-hackathon/yolo26n-indian-fmcg-detection", filename="yolo26n_fmcg.onnx", ) class_names_path = hf_hub_download( repo_id="build-small-hackathon/yolo26n-indian-fmcg-detection", filename="class_names.json", ) with open(class_names_path, encoding="utf-8") as f: class_names = json.load(f) ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) logger.info("YOLO ready (%d classes)", len(class_names)) MODELS_LOADED = True except Exception as exc: logger.exception("Model loading failed — app will start in degraded mode") ORCHESTRATOR = AuditOrchestrator( invoice_extractor=InvoiceExtractorAgent(vision_llm), product_matcher=ProductMatcherAgent(text_llm, _CATALOG), pricing_agent=PricingAgent(STORAGE, _CATALOG), visual_counter=VisualCounterAgent(ort_session, class_names) if ort_session else None, reconciliation_agent=ReconciliationAgent(), savings_agent=SavingsAgent(text_llm), storage=STORAGE, tracer=tracer, ) logger.info("AuditOrchestrator ready. Models loaded: %s", MODELS_LOADED) # ── Gradio Server ───────────────────────────────────────────────────────────── gr.set_static_paths(paths=["static/"]) app = gr.Server() # ── Static routes ───────────────────────────────────────────────────────────── @app.get("/") async def homepage(): html = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8") return HTMLResponse(content=html) # ── Health / history ────────────────────────────────────────────────────────── @app.get("/api/health") async def health(): return JSONResponse({ "status": "ok", "models_loaded": MODELS_LOADED, "storage_available": bool(STORAGE and STORAGE.available), "catalog_size": len(_CATALOG.all_product_ids()), }) @app.get("/api/history") async def get_history(limit: int = 20): if not STORAGE or not STORAGE.available: return JSONResponse({"audits": []}) return JSONResponse({"audits": STORAGE.get_recent_audits(min(limit, 100))}) # Gradio-client-compatible endpoint (satisfies Off-Brand + tool use by @gradio/client) @app.api(name="health_check") def health_check() -> dict: return {"status": "ok", "models_loaded": MODELS_LOADED} # ── Audit streaming endpoint ────────────────────────────────────────────────── def _run_audit_thread( orchestrator: AuditOrchestrator, invoice_path: str, photo_paths: List[str], supplier: str, q: Queue, ) -> None: """Drive the pipeline generator in a background thread and push events.""" try: gen = orchestrator.run_audit(invoice_path, photo_paths, supplier) report = None while True: try: update = next(gen) q.put(("progress", update)) except StopIteration as exc: report = exc.value break q.put(("result", report)) except Exception as exc: logger.exception("Audit pipeline raised an exception") q.put(("error", str(exc))) finally: q.put(("done", None)) def _dataclass_to_json(obj) -> str: """Serialize a dataclass (including nested) to a JSON string safely.""" return json.dumps(dataclasses.asdict(obj), default=str, ensure_ascii=False) @app.post("/api/audit") async def run_audit_endpoint(request: Request): if ORCHESTRATOR is None: async def _err(): yield b'data: {"type":"error","message":"Models not initialised yet. Please wait."}\n\n' return StreamingResponse(_err(), media_type="text/event-stream") # ── Parse multipart form ────────────────────────────────────────────────── try: form = await request.form() except Exception as exc: return JSONResponse({"error": str(exc)}, status_code=400) invoice_upload = form.get("invoice_file") if invoice_upload is None or not hasattr(invoice_upload, "read"): return JSONResponse({"error": "invoice_file is required"}, status_code=422) supplier_name: str = form.get("supplier_name", "") or "" # Collect delivery photos (multiple files with same key name) delivery_uploads = [] for key, val in form.multi_items(): if key == "delivery_photos" and hasattr(val, "read"): delivery_uploads.append(val) # ── Save to temp dir ────────────────────────────────────────────────────── tmpdir = tempfile.mkdtemp(prefix="kirana_") try: suffix = Path(invoice_upload.filename or "invoice.jpg").suffix.lower() or ".jpg" invoice_path = os.path.join(tmpdir, f"invoice{suffix}") with open(invoice_path, "wb") as fh: fh.write(await invoice_upload.read()) photo_paths: List[str] = [] for i, photo in enumerate(delivery_uploads[:5]): # max 5 if photo.filename: psuffix = Path(photo.filename).suffix.lower() or ".jpg" pp = os.path.join(tmpdir, f"photo_{i}{psuffix}") with open(pp, "wb") as fh: fh.write(await photo.read()) photo_paths.append(pp) except Exception as exc: shutil.rmtree(tmpdir, ignore_errors=True) return JSONResponse({"error": str(exc)}, status_code=500) # ── Stream audit events ─────────────────────────────────────────────────── q: Queue = Queue() thread = threading.Thread( target=_run_audit_thread, args=(ORCHESTRATOR, invoice_path, photo_paths, supplier_name, q), daemon=True, ) thread.start() async def event_stream(): loop = asyncio.get_event_loop() try: while True: try: event_type, data = await loop.run_in_executor( None, lambda: q.get(timeout=180.0) ) except Empty: yield b": keep-alive\n\n" continue if event_type == "progress": payload = json.dumps({ "type": "progress", "stage": data.stage, "message": data.message, "agent_name": data.agent_name, "duration_ms": data.duration_ms, }, ensure_ascii=False) yield f"data: {payload}\n\n".encode() elif event_type == "result": payload = json.dumps({ "type": "result", "report": json.loads(_dataclass_to_json(data)), }, ensure_ascii=False) yield f"data: {payload}\n\n".encode() break elif event_type == "error": payload = json.dumps({"type": "error", "message": str(data)}) yield f"data: {payload}\n\n".encode() break elif event_type == "done": break finally: shutil.rmtree(tmpdir, ignore_errors=True) return StreamingResponse( event_stream(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) # ── Startup ─────────────────────────────────────────────────────────────────── load_models() if __name__ == "__main__": app.launch(show_error=True)