naazimsnh02's picture
Model loading fix for vision model
a2c4e53
Raw
History Blame
12.9 kB
from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import os
import shutil
import tempfile
import threading
from pathlib import Path
from queue import Empty, Queue
from typing import List, Optional
import gradio as gr
from fastapi import Request
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from catalog import get_catalog
from storage import StorageManager
from tracer import AgentTracer
from agents.invoice_extractor import InvoiceExtractorAgent
from agents.product_matcher import ProductMatcherAgent
from agents.pricing_agent import PricingAgent
from agents.visual_counter import VisualCounterAgent
from agents.reconciliation_agent import ReconciliationAgent
from agents.savings_agent import SavingsAgent
from pipeline import AuditOrchestrator
# Configure root logging once at import time. Each record is rendered as
# "<timestamp> <level> <logger name> — <message>".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s — %(message)s",
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# ── Global state ──────────────────────────────────────────────────────────────
# Flipped to True by load_models() once model loading completes without error;
# reported by the health endpoints.
MODELS_LOADED = False
# Audit pipeline entry point; remains None until load_models() has built it.
ORCHESTRATOR: Optional[AuditOrchestrator] = None
# Storage manager instance; remains None until load_models() has created it.
STORAGE: Optional[StorageManager] = None
# Hugging Face token read from the HF_TOKEN environment variable ("" when unset).
_HF_TOKEN = os.getenv("HF_TOKEN", "")
# Catalog object returned by get_catalog(), created once at import time and
# shared by the agents and the health endpoint.
_CATALOG = get_catalog()
# ── Model loading ─────────────────────────────────────────────────────────────
def _load_vision_llm():
    """Load the merged MiniCPM-V vision model together with its processor.

    Returns:
        A ``(model, processor)`` tuple. The model has been switched to eval
        mode; the processor comes from the base repository.

    Raises:
        Exception: Any import, download or initialisation error is propagated
            unchanged so the caller can decide how to degrade.
    """
    # Heavy dependencies are imported lazily so that a missing package only
    # disables this one model instead of breaking module import.
    import torch
    from transformers import AutoProcessor

    try:
        from transformers import AutoModelForImageTextToText as vision_model_cls
    except ImportError:
        # Alternative class name tried when the first one cannot be imported.
        from transformers import AutoModelForMultimodalLM as vision_model_cls

    base_repo = "openbmb/MiniCPM-V-4.6"
    merged_repo = "build-small-hackathon/minicpm-v-4-6-indian-invoice-extraction-merged"

    # Query CUDA once; dtype, device and device_map all depend on it.
    has_cuda = torch.cuda.is_available()
    dtype = torch.bfloat16 if has_cuda else torch.float32
    device = "cuda" if has_cuda else "cpu"

    logger.info("Downloading vision model (MiniCPM-V 4.6 merged)…")
    # The merged repo is a fully-merged model (not a LoRA delta) — load it directly.
    # Loading base + overlaying weights fails because the repos use different param naming.
    logger.info("Loading merged vision model from %s …", merged_repo)
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": dtype,
    }
    if _HF_TOKEN:
        model_kwargs["token"] = _HF_TOKEN
    if has_cuda:
        model_kwargs["device_map"] = "auto"
    model = vision_model_cls.from_pretrained(merged_repo, **model_kwargs)
    if not has_cuda:
        # Without device_map the model has to be placed explicitly.
        model.to(device)
    model.eval()

    processor_kwargs = {"trust_remote_code": True}
    if _HF_TOKEN:
        processor_kwargs["token"] = _HF_TOKEN
    # Load processor from base repo — has complete preprocessor/chat-template configs.
    processor = AutoProcessor.from_pretrained(base_repo, **processor_kwargs)

    logger.info("Vision LLM ready (device=%s dtype=%s)", device, dtype)
    return model, processor


def _load_text_llm():
    """Download the quantised GGUF text model and wrap it in a ``Llama`` instance.

    Returns:
        The initialised ``llama_cpp.Llama`` object.

    Raises:
        Exception: Any import, download or initialisation error is propagated.
    """
    from huggingface_hub import hf_hub_download
    from llama_cpp import Llama

    logger.info("Downloading text model (MiniCPM5-1B)…")
    model_path = hf_hub_download(
        repo_id="build-small-hackathon/minicpm5-1b-indian-fmcg-normalizer",
        filename="MiniCPM5-1B.Q4_K_M.gguf",
    )
    text_llm = Llama(
        model_path=model_path,
        n_ctx=8192,
        n_threads=4,
        verbose=False,
    )
    logger.info("Text LLM ready")
    return text_llm


def _load_yolo():
    """Download the YOLO ONNX detector and its class-name list.

    Returns:
        A ``(session, class_names)`` tuple: the ONNX Runtime inference session
        (CPU provider) and the content of ``class_names.json``.

    Raises:
        Exception: Any import, download, parse or initialisation error is
            propagated.
    """
    from huggingface_hub import hf_hub_download
    import onnxruntime as ort

    repo_id = "build-small-hackathon/yolo26n-indian-fmcg-detection"
    logger.info("Downloading YOLO model…")
    onnx_path = hf_hub_download(repo_id=repo_id, filename="yolo26n_fmcg.onnx")
    class_names_path = hf_hub_download(repo_id=repo_id, filename="class_names.json")
    with open(class_names_path, encoding="utf-8") as f:
        class_names = json.load(f)
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    logger.info("YOLO ready (%d classes)", len(class_names))
    return session, class_names


def _load_or_none(label, loader):
    """Run *loader* and return its result, or ``None`` (after logging) if it raises."""
    try:
        return loader()
    except Exception:
        # Top-level boundary: one broken model must not prevent app start-up.
        logger.exception("%s model loading failed — app will start in degraded mode", label)
        return None


def load_models() -> None:
    """Initialise storage, load every model and build the global orchestrator.

    Each model is loaded independently, so a failure in one of them (missing
    package, download error, ...) is logged and leaves only that component
    unavailable; the remaining models are still attempted. ``MODELS_LOADED``
    is True only when all three loaded successfully.

    Side effects:
        Rebinds the module globals ``STORAGE``, ``ORCHESTRATOR`` and
        ``MODELS_LOADED``.
    """
    global MODELS_LOADED, ORCHESTRATOR, STORAGE

    STORAGE = StorageManager()
    STORAGE.initialise_schema()
    if not STORAGE.available:
        logger.warning("Storage unavailable — running without price history / audit log")

    tracer = AgentTracer(hf_token=_HF_TOKEN or None)

    vision_llm = _load_or_none("Vision", _load_vision_llm)
    text_llm = _load_or_none("Text", _load_text_llm)
    yolo = _load_or_none("YOLO", _load_yolo)
    ort_session, class_names = yolo if yolo is not None else (None, [])

    MODELS_LOADED = all(part is not None for part in (vision_llm, text_llm, yolo))

    # Agents receive None for any model that failed to load; the visual counter
    # is omitted entirely when the detector is unavailable.
    ORCHESTRATOR = AuditOrchestrator(
        invoice_extractor=InvoiceExtractorAgent(vision_llm),
        product_matcher=ProductMatcherAgent(text_llm, _CATALOG),
        pricing_agent=PricingAgent(STORAGE, _CATALOG),
        visual_counter=VisualCounterAgent(ort_session, class_names) if ort_session else None,
        reconciliation_agent=ReconciliationAgent(),
        savings_agent=SavingsAgent(text_llm),
        storage=STORAGE,
        tracer=tracer,
    )
    logger.info("AuditOrchestrator ready. Models loaded: %s", MODELS_LOADED)
# ── Gradio Server ─────────────────────────────────────────────────────────────
# Register the "static/" directory with Gradio via gr.set_static_paths.
# NOTE(review): the path is relative, so it is presumably resolved against the
# process working directory rather than this file's location — confirm.
gr.set_static_paths(paths=["static/"])
# Server object that all route decorators in this module attach to.
app = gr.Server()
# ── Static routes ─────────────────────────────────────────────────────────────
@app.get("/")
async def homepage():
    """Return ``static/index.html`` (located next to this module) as an HTML response."""
    index_file = Path(__file__).parent.joinpath("static", "index.html")
    markup = index_file.read_text(encoding="utf-8")
    return HTMLResponse(content=markup)
# ── Health / history ──────────────────────────────────────────────────────────
@app.get("/api/health")
async def health():
    """Report service status: model-load flag, storage availability and catalog size."""
    body = {"status": "ok", "models_loaded": MODELS_LOADED}
    # STORAGE may still be None, so guard before touching ``available``.
    body["storage_available"] = bool(STORAGE and STORAGE.available)
    body["catalog_size"] = len(_CATALOG.all_product_ids())
    return JSONResponse(body)
@app.get("/api/history")
async def get_history(limit: int = 20):
    """Return the most recent audits as ``{"audits": [...]}``.

    Args:
        limit: Maximum number of audits requested by the client. The value is
            clamped to the range 0..100 before it reaches the storage layer.

    Returns:
        A JSON response; the list is empty when storage is missing or
        unavailable.
    """
    if not STORAGE or not STORAGE.available:
        return JSONResponse({"audits": []})
    # Clamp both ends: the upper bound caps the response size, and the lower
    # bound stops a negative value from reaching storage, where (depending on
    # the backend) it could be rejected or interpreted as "no limit".
    capped_limit = max(0, min(limit, 100))
    return JSONResponse({"audits": STORAGE.get_recent_audits(capped_limit)})
# Gradio-client-compatible endpoint (satisfies Off-Brand + tool use by @gradio/client)
@app.api(name="health_check")
def health_check() -> dict:
    """Return a minimal status dict containing the model-load flag."""
    status = {"status": "ok"}
    status["models_loaded"] = MODELS_LOADED
    return status
# ── Audit streaming endpoint ──────────────────────────────────────────────────
def _run_audit_thread(
    orchestrator: AuditOrchestrator,
    invoice_path: str,
    photo_paths: List[str],
    supplier: str,
    q: Queue,
) -> None:
    """Exhaust the audit generator and forward everything it produces onto *q*.

    Intended to run in a worker thread. Items placed on the queue are
    ``(kind, payload)`` tuples:

    * ``("progress", update)`` for every value the generator yields,
    * ``("result", report)`` with the generator's return value,
    * ``("error", message)`` if anything raises,
    * ``("done", None)`` always, as the final item.
    """
    try:
        pipeline = orchestrator.run_audit(invoice_path, photo_paths, supplier)
        final_report = None
        finished = False
        # Step manually (instead of a for-loop) because the generator's return
        # value is only reachable through StopIteration.value.
        while not finished:
            try:
                step = next(pipeline)
            except StopIteration as stop:
                final_report = stop.value
                finished = True
            else:
                q.put(("progress", step))
        q.put(("result", final_report))
    except Exception as err:
        logger.exception("Audit pipeline raised an exception")
        q.put(("error", str(err)))
    finally:
        q.put(("done", None))
def _dataclass_to_json(obj) -> str:
    """Convert a dataclass instance (nested ones included) into a JSON string.

    Values the JSON encoder cannot handle are stringified with ``str``;
    non-ASCII characters are emitted unescaped.
    """
    as_mapping = dataclasses.asdict(obj)
    return json.dumps(as_mapping, ensure_ascii=False, default=str)
@app.post("/api/audit")
async def run_audit_endpoint(request: Request):
    """Accept an invoice (plus optional delivery photos) and stream the audit as SSE.

    Expected multipart form fields:
        invoice_file: Required upload containing the invoice image/document.
        supplier_name: Optional text field.
        delivery_photos: Optional, repeatable upload field; only the first
            five entries are considered.

    Returns:
        * A ``text/event-stream`` response whose ``data:`` lines carry JSON
          objects of type ``progress``, ``result`` or ``error``.
        * A JSON error response with status 400 (unparsable form), 422
          (missing invoice) or 500 (uploads could not be written to disk).

    The pipeline runs in a daemon thread and communicates through a queue;
    the temporary upload directory is removed when the stream ends.
    """
    if ORCHESTRATOR is None:
        async def _err():
            yield b'data: {"type":"error","message":"Models not initialised yet. Please wait."}\n\n'
        return StreamingResponse(_err(), media_type="text/event-stream")

    # ── Parse multipart form ──────────────────────────────────────────────────
    try:
        form = await request.form()
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)

    invoice_upload = form.get("invoice_file")
    if invoice_upload is None or not hasattr(invoice_upload, "read"):
        return JSONResponse({"error": "invoice_file is required"}, status_code=422)

    supplier_name = form.get("supplier_name", "") or ""
    if not isinstance(supplier_name, str):
        # A client may send a file under this key; only plain text is accepted.
        supplier_name = ""

    # Collect delivery photos (multiple files with same key name)
    delivery_uploads = [
        val
        for key, val in form.multi_items()
        if key == "delivery_photos" and hasattr(val, "read")
    ]

    # ── Save to temp dir ──────────────────────────────────────────────────────
    tmpdir = tempfile.mkdtemp(prefix="kirana_")
    try:
        # Only the extension of the client-supplied name is used; the file name
        # on disk is fixed, so the upload cannot choose its own path.
        suffix = Path(invoice_upload.filename or "invoice.jpg").suffix.lower() or ".jpg"
        invoice_path = os.path.join(tmpdir, f"invoice{suffix}")
        with open(invoice_path, "wb") as fh:
            fh.write(await invoice_upload.read())

        photo_paths: List[str] = []
        for i, photo in enumerate(delivery_uploads[:5]):  # max 5
            if photo.filename:
                psuffix = Path(photo.filename).suffix.lower() or ".jpg"
                pp = os.path.join(tmpdir, f"photo_{i}{psuffix}")
                with open(pp, "wb") as fh:
                    fh.write(await photo.read())
                photo_paths.append(pp)
    except Exception as exc:
        shutil.rmtree(tmpdir, ignore_errors=True)
        return JSONResponse({"error": str(exc)}, status_code=500)

    # ── Stream audit events ───────────────────────────────────────────────────
    q: Queue = Queue()
    thread = threading.Thread(
        target=_run_audit_thread,
        args=(ORCHESTRATOR, invoice_path, photo_paths, supplier_name, q),
        daemon=True,
    )
    thread.start()

    async def event_stream():
        """Relay queue items from the worker thread as server-sent events."""
        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        try:
            while True:
                try:
                    # Queue.get blocks, so it is run in the default executor to
                    # keep the event loop free.
                    event_type, data = await loop.run_in_executor(
                        None, lambda: q.get(timeout=180.0)
                    )
                except Empty:
                    # Nothing arrived within the timeout: emit an SSE comment
                    # line and keep waiting.
                    yield b": keep-alive\n\n"
                    continue

                if event_type == "progress":
                    payload = json.dumps({
                        "type": "progress",
                        "stage": data.stage,
                        "message": data.message,
                        "agent_name": data.agent_name,
                        "duration_ms": data.duration_ms,
                    }, ensure_ascii=False)
                    yield f"data: {payload}\n\n".encode()
                elif event_type == "result":
                    # Round-trip through _dataclass_to_json so values the JSON
                    # encoder cannot handle are already stringified.
                    payload = json.dumps({
                        "type": "result",
                        "report": json.loads(_dataclass_to_json(data)),
                    }, ensure_ascii=False)
                    yield f"data: {payload}\n\n".encode()
                    break
                elif event_type == "error":
                    payload = json.dumps({"type": "error", "message": str(data)})
                    yield f"data: {payload}\n\n".encode()
                    break
                elif event_type == "done":
                    break
        finally:
            # Runs on normal completion and when the stream is torn down early.
            shutil.rmtree(tmpdir, ignore_errors=True)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
# ── Startup ───────────────────────────────────────────────────────────────────
# Runs at import time, before the __main__ guard, so storage and models are
# initialised whether this module is executed directly or imported.
load_models()

if __name__ == "__main__":
    # Start the server only when the file is run as a script.
    app.launch(show_error=True)