| from __future__ import annotations |
|
|
| import asyncio |
| import dataclasses |
| import json |
| import logging |
| import os |
| import shutil |
| import tempfile |
| import threading |
| from pathlib import Path |
| from queue import Empty, Queue |
| from typing import List, Optional |
|
|
| import gradio as gr |
| from fastapi import Request |
| from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse |
|
|
| from catalog import get_catalog |
| from storage import StorageManager |
| from tracer import AgentTracer |
| from agents.invoice_extractor import InvoiceExtractorAgent |
| from agents.product_matcher import ProductMatcherAgent |
| from agents.pricing_agent import PricingAgent |
| from agents.visual_counter import VisualCounterAgent |
| from agents.reconciliation_agent import ReconciliationAgent |
| from agents.savings_agent import SavingsAgent |
| from pipeline import AuditOrchestrator |
|
|
# Process-wide logging configuration. The separator between the logger name
# and the message is an em dash; it had been stored mis-encoded.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s — %(message)s",
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
| |
|
|
# Set to True by load_models() only when every model loaded without error.
MODELS_LOADED = False
# Audit pipeline entry point; remains None until load_models() has built it.
ORCHESTRATOR: Optional[AuditOrchestrator] = None
# Storage manager instance; created by load_models().
STORAGE: Optional[StorageManager] = None
# Hugging Face token taken from the HF_TOKEN environment variable ("" if unset).
_HF_TOKEN = os.getenv("HF_TOKEN", "")
# Catalog object fetched once at import time and shared by the agents and /api/health.
_CATALOG = get_catalog()
|
|
| |
|
|
def load_models() -> None:
    """Initialise storage, load the ML models and build the global orchestrator.

    Populates the module globals ``STORAGE``, ``ORCHESTRATOR`` and
    ``MODELS_LOADED``.

    Model loading is best-effort: any exception raised while importing the ML
    libraries or loading a model is logged and swallowed so the app can still
    start. Models assigned before the failure are kept, the remaining ones stay
    ``None`` and ``MODELS_LOADED`` stays ``False``. The orchestrator is built
    in every case.
    """
    global MODELS_LOADED, ORCHESTRATOR, STORAGE

    STORAGE = StorageManager()
    STORAGE.initialise_schema()
    if not STORAGE.available:
        logger.warning("Storage unavailable — running without price history / audit log")

    # An empty token is normalised to None so that libraries fall back to
    # their own default credential lookup instead of sending "".
    hf_token = _HF_TOKEN or None
    tracer = AgentTracer(hf_token=hf_token)

    vision_llm = None
    text_llm = None
    ort_session = None
    class_names: List[str] = []

    try:
        # Heavy, optional dependencies are imported lazily so that a missing
        # package degrades the app instead of breaking module import.
        import onnxruntime as ort
        import torch
        from huggingface_hub import hf_hub_download
        from llama_cpp import Llama
        from transformers import AutoProcessor

        try:
            from transformers import AutoModelForImageTextToText as _VisionModel
        except ImportError:
            from transformers import AutoModelForMultimodalLM as _VisionModel

        # ---- Vision model: merged fine-tuned weights + base-repo processor ----
        logger.info("Downloading vision model (MiniCPM-V 4.6 merged)…")
        base_repo = "openbmb/MiniCPM-V-4.6"
        merged_repo = "build-small-hackathon/minicpm-v-4-6-indian-invoice-extraction-merged"

        has_cuda = torch.cuda.is_available()
        dtype = torch.bfloat16 if has_cuda else torch.float32
        device = "cuda" if has_cuda else "cpu"

        logger.info("Loading merged vision model from %s …", merged_repo)
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the model repository; only use with repositories you trust.
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
        }
        if hf_token:
            model_kwargs["token"] = hf_token
        if has_cuda:
            model_kwargs["device_map"] = "auto"
        vision_model = _VisionModel.from_pretrained(merged_repo, **model_kwargs)
        if not has_cuda:
            # With device_map="auto" placement is handled by the loader, so an
            # explicit move is only done on the CPU path.
            vision_model.to(device)
        vision_model.eval()

        processor_kwargs = {"trust_remote_code": True}
        if hf_token:
            processor_kwargs["token"] = hf_token
        vision_processor = AutoProcessor.from_pretrained(base_repo, **processor_kwargs)
        vision_llm = (vision_model, vision_processor)
        logger.info("Vision LLM ready (device=%s dtype=%s)", device, dtype)

        # ---- Text model (GGUF via llama.cpp) ----
        logger.info("Downloading text model (MiniCPM5-1B)…")
        text_model_path = hf_hub_download(
            repo_id="build-small-hackathon/minicpm5-1b-indian-fmcg-normalizer",
            filename="MiniCPM5-1B.Q4_K_M.gguf",
            token=hf_token,  # same credentials as the vision model download
        )
        text_llm = Llama(
            model_path=text_model_path,
            n_ctx=8192,
            n_threads=4,
            verbose=False,
        )
        logger.info("Text LLM ready")

        # ---- YOLO detector (ONNX, CPU) ----
        logger.info("Downloading YOLO model…")
        yolo_repo = "build-small-hackathon/yolo26n-indian-fmcg-detection"
        onnx_path = hf_hub_download(
            repo_id=yolo_repo,
            filename="yolo26n_fmcg.onnx",
            token=hf_token,
        )
        class_names_path = hf_hub_download(
            repo_id=yolo_repo,
            filename="class_names.json",
            token=hf_token,
        )
        with open(class_names_path, encoding="utf-8") as f:
            class_names = json.load(f)
        ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        logger.info("YOLO ready (%d classes)", len(class_names))

        MODELS_LOADED = True
    except Exception:
        # Deliberate top-level boundary: the traceback is logged and the app
        # continues with whichever models were loaded before the failure.
        logger.exception("Model loading failed — app will start in degraded mode")

    ORCHESTRATOR = AuditOrchestrator(
        invoice_extractor=InvoiceExtractorAgent(vision_llm),
        product_matcher=ProductMatcherAgent(text_llm, _CATALOG),
        pricing_agent=PricingAgent(STORAGE, _CATALOG),
        # The visual counter is optional: it is omitted when the detector
        # session could not be created.
        visual_counter=VisualCounterAgent(ort_session, class_names) if ort_session else None,
        reconciliation_agent=ReconciliationAgent(),
        savings_agent=SavingsAgent(text_llm),
        storage=STORAGE,
        tracer=tracer,
    )
    logger.info("AuditOrchestrator ready. Models loaded: %s", MODELS_LOADED)
|
|
|
|
| |
|
|
# Register the "static/" directory with Gradio as a static path.
# NOTE(review): this path is relative, so it presumably resolves against the
# process working directory, while homepage() resolves index.html relative to
# this file — confirm both point at the same directory in deployment.
gr.set_static_paths(paths=["static/"])
# Server object on which the HTTP routes and the API function are registered.
app = gr.Server()
|
|
|
|
| |
|
|
@app.get("/")
async def homepage():
    """Return ``static/index.html`` (located beside this module) as an HTML response."""
    index_file = Path(__file__).parent.joinpath("static", "index.html")
    with index_file.open(encoding="utf-8") as fh:
        markup = fh.read()
    return HTMLResponse(content=markup)
|
|
|
|
| |
|
|
@app.get("/api/health")
async def health():
    """Report service status: model load flag, storage availability and catalog size."""
    models_ready = MODELS_LOADED
    # Storage counts as available only when it exists and reports itself usable.
    storage_ok = bool(STORAGE.available) if STORAGE else False
    product_count = len(_CATALOG.all_product_ids())

    body = {
        "status": "ok",
        "models_loaded": models_ready,
        "storage_available": storage_ok,
        "catalog_size": product_count,
    }
    return JSONResponse(body)
|
|
|
|
@app.get("/api/history")
async def get_history(limit: int = 20):
    """Return the most recent audits as ``{"audits": [...]}``.

    Args:
        limit: Requested number of audits. It is clamped to the range 0..100
            before being handed to the storage layer.

    Returns:
        A JSON response. The list is empty when storage is missing or
        unavailable.
    """
    if not STORAGE or not STORAGE.available:
        return JSONResponse({"audits": []})

    # Clamp on both sides. Previously only the upper bound was enforced, so a
    # negative value reached the storage layer unchanged (some SQL backends
    # treat a negative LIMIT as "no limit").
    bounded_limit = max(0, min(limit, 100))
    return JSONResponse({"audits": STORAGE.get_recent_audits(bounded_limit)})
|
|
|
|
| |
@app.api(name="health_check")
def health_check() -> dict:
    """Return a minimal status dict with the current model load flag."""
    status: dict = {"status": "ok"}
    status["models_loaded"] = MODELS_LOADED
    return status
|
|
|
|
| |
|
|
def _run_audit_thread(
    orchestrator: AuditOrchestrator,
    invoice_path: str,
    photo_paths: List[str],
    supplier: str,
    q: Queue,
) -> None:
    """Run the audit generator to completion and report its events on *q*.

    Every value yielded by ``orchestrator.run_audit(...)`` is queued as
    ``("progress", value)``. The generator's return value is queued as
    ``("result", value)``. If anything raises, the exception is logged and
    ``("error", str(exc))`` is queued instead. ``("done", None)`` is always
    queued last.
    """
    try:
        pipeline = orchestrator.run_audit(invoice_path, photo_paths, supplier)
        final_report = None
        finished = False
        while not finished:
            try:
                step = next(pipeline)
            except StopIteration as stop:
                # The generator's ``return`` value travels on StopIteration.
                final_report = stop.value
                finished = True
            else:
                q.put(("progress", step))
        q.put(("result", final_report))
    except Exception as exc:
        logger.exception("Audit pipeline raised an exception")
        q.put(("error", str(exc)))
    finally:
        q.put(("done", None))
|
|
|
|
def _dataclass_to_json(obj) -> str:
    """Convert a dataclass instance (nested dataclasses included) to JSON text.

    Values the JSON encoder cannot handle are rendered with ``str()``;
    non-ASCII characters are written as-is rather than escaped.
    """
    as_plain_dict = dataclasses.asdict(obj)
    return json.dumps(as_plain_dict, ensure_ascii=False, default=str)
|
|
|
|
# Seconds to wait for the next pipeline event before emitting an SSE keep-alive.
_AUDIT_QUEUE_POLL_TIMEOUT_S = 180.0
# Only the first N uploaded delivery photos are considered.
_AUDIT_MAX_DELIVERY_PHOTOS = 5


def _sse_frame(payload: dict) -> bytes:
    """Encode *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode()


@app.post("/api/audit")
async def run_audit_endpoint(request: Request):
    """Run an invoice audit and stream its progress as Server-Sent Events.

    Expected multipart form fields:
        invoice_file: the invoice upload (required).
        supplier_name: optional text; anything that is not a string is
            treated as empty.
        delivery_photos: zero or more uploads; only the first five are
            considered and entries without a filename are skipped.

    Returns:
        * an SSE stream with a single ``error`` event when the orchestrator
          has not been built yet;
        * a JSON error with status 400 (form could not be parsed), 422
          (``invoice_file`` missing) or 500 (uploads could not be saved);
        * otherwise an SSE stream of ``progress`` events ending with one
          ``result`` or ``error`` event.

    The uploads are written to a temporary directory that is removed by the
    worker thread once the pipeline has finished, so a client disconnect can
    never delete files the pipeline is still reading.
    """
    # Snapshot the global so the check and the later use see the same object.
    orchestrator = ORCHESTRATOR
    if orchestrator is None:
        async def _err():
            yield b'data: {"type":"error","message":"Models not initialised yet. Please wait."}\n\n'
        return StreamingResponse(_err(), media_type="text/event-stream")

    # ---- Parse the multipart form ----
    try:
        form = await request.form()
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)

    invoice_upload = form.get("invoice_file")
    if invoice_upload is None or not hasattr(invoice_upload, "read"):
        return JSONResponse({"error": "invoice_file is required"}, status_code=422)

    # A file sent under "supplier_name" must not reach the pipeline as a
    # non-string value.
    raw_supplier = form.get("supplier_name", "")
    supplier_name: str = raw_supplier if isinstance(raw_supplier, str) else ""

    delivery_uploads = [
        val
        for key, val in form.multi_items()
        if key == "delivery_photos" and hasattr(val, "read")
    ]

    # ---- Persist the uploads where the pipeline can read them ----
    tmpdir = tempfile.mkdtemp(prefix="kirana_")
    try:
        suffix = Path(invoice_upload.filename or "invoice.jpg").suffix.lower() or ".jpg"
        invoice_path = os.path.join(tmpdir, f"invoice{suffix}")
        Path(invoice_path).write_bytes(await invoice_upload.read())

        photo_paths: List[str] = []
        for i, photo in enumerate(delivery_uploads[:_AUDIT_MAX_DELIVERY_PHOTOS]):
            if not photo.filename:
                continue
            psuffix = Path(photo.filename).suffix.lower() or ".jpg"
            photo_path = os.path.join(tmpdir, f"photo_{i}{psuffix}")
            Path(photo_path).write_bytes(await photo.read())
            photo_paths.append(photo_path)
    except Exception as exc:
        shutil.rmtree(tmpdir, ignore_errors=True)
        return JSONResponse({"error": str(exc)}, status_code=500)

    # ---- Run the pipeline in a background thread ----
    q: Queue = Queue()

    def _worker() -> None:
        # The worker owns the temp dir: it is deleted only after the pipeline
        # has stopped touching the files, whether or not the client is still
        # connected and whether or not the stream was ever consumed.
        try:
            _run_audit_thread(orchestrator, invoice_path, photo_paths, supplier_name, q)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

    worker = threading.Thread(target=_worker, daemon=True)
    try:
        worker.start()
    except RuntimeError:
        # The thread could not be started, so nobody else will clean up.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise

    async def event_stream():
        loop = asyncio.get_running_loop()
        while True:
            try:
                # Block in the default executor so the event loop stays free.
                event_type, data = await loop.run_in_executor(
                    None, q.get, True, _AUDIT_QUEUE_POLL_TIMEOUT_S
                )
            except Empty:
                # SSE comment line: keeps the connection open, ignored by clients.
                yield b": keep-alive\n\n"
                continue

            if event_type == "progress":
                yield _sse_frame({
                    "type": "progress",
                    "stage": data.stage,
                    "message": data.message,
                    "agent_name": data.agent_name,
                    "duration_ms": data.duration_ms,
                })
            elif event_type == "result":
                try:
                    report = json.loads(_dataclass_to_json(data))
                except (TypeError, ValueError) as exc:
                    # Tell the client instead of silently dropping the stream.
                    logger.exception("Audit report could not be serialised")
                    yield _sse_frame({"type": "error", "message": str(exc)})
                else:
                    yield _sse_frame({"type": "result", "report": report})
                return
            elif event_type == "error":
                yield _sse_frame({"type": "error", "message": str(data)})
                return
            elif event_type == "done":
                return

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
|
|
|
| |
|
|
# Executed at import time, so storage, models and the orchestrator are set up
# whether the module is imported by a server or run as a script.
load_models()


# Launch the server only when this file is run directly.
if __name__ == "__main__":
    app.launch(show_error=True)
|
|