import os, uuid, asyncio, logging from datetime import datetime, timedelta from typing import Dict, Optional, List, Any from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, HttpUrl, root_validator, ValidationError from fastapi.encoders import jsonable_encoder from app.lens_images_core import translate_lens from app.lens_text_core import translate_lens_text PORT = int(os.getenv("PORT", 8080)) MAX_WORKERS = int(os.getenv("MAX_WORKERS", 8)) MAX_WORKERS_IMAGES = int(os.getenv("MAX_WORKERS_IMAGES", MAX_WORKERS)) MAX_WORKERS_TEXT = int(os.getenv("MAX_WORKERS_TEXT", 3)) RESULTS_TTL = int(os.getenv("RESULTS_TTL_SECONDS", 300)) MAX_B64_IMG_LEN = int(os.getenv("MAX_BASE64_IMAGE_LENGTH", 5_000_000)) JOB_DELAY_SEC = int(os.getenv("JOB_DELAY_SECONDS", 0.1)) logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", ) log = logging.getLogger("ocr_ws") ENABLE_BACKGROUND_WORKERS = os.getenv("ENABLE_BACKGROUND_WORKERS", "0").strip().lower() in ("1","true","yes","on") workers_started: bool = False _workers_lock = asyncio.Lock() async def ensure_workers_started(): global workers_started if workers_started: return async with _workers_lock: if workers_started: return for _ in range(MAX_WORKERS_IMAGES): asyncio.create_task(worker("lens_images", jobq_img)) for _ in range(MAX_WORKERS_TEXT): asyncio.create_task(worker("lens_text", jobq_text)) workers_started = True log.info("workers started on-demand") app = FastAPI() app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) jobq_img: asyncio.Queue = asyncio.Queue() jobq_text: asyncio.Queue = asyncio.Queue() class Position(BaseModel): top: float; left: float; width: float; height: float viewport_width: int; viewport_height: int scroll_x: float; scroll_y: float class PipelineEvent(BaseModel): stage: str; at: datetime; target: Optional[str] = None class Context(BaseModel): page_url: Optional[HttpUrl] = None timestamp: Optional[datetime] = None class Metadata(BaseModel): image_id: str original_image_url: Optional[HttpUrl] = None position: Optional[Position] = None pipeline: List[PipelineEvent] = [] ocr_image: Optional[str] = None extra: Optional[Dict[str, Any]] = None @root_validator(pre=True) def _no_blob_urls(cls, v): url = v.get("original_image_url") if not url: v["original_image_url"] = None return v if isinstance(url, str) and url.startswith("blob:"): raise ValueError("original_image_url must be http(s)") return v class Job(BaseModel): mode: str = "lens_images" lang: str = "en" type: str = "image" src: Optional[HttpUrl] = None menu: Optional[str] = None context: Optional[Context] = None metadata: Metadata @root_validator(pre=True) def _src_no_blob(cls, v): s = v.get("src") if not s: v["src"] = None return v if isinstance(s, str) and s.startswith("blob:"): raise ValueError("src must be http(s)") return v class WsMessage(BaseModel): type: str id: Optional[str] = None payload: Optional[Job] = None jobq: asyncio.Queue = asyncio.Queue() pending_ws: Dict[str, WebSocket] = {} results: Dict[str, dict] = {} @app.api_route("/health", methods=["GET", "HEAD"]) async def health(): return {"ok": True, "timestamp": datetime.utcnow().isoformat()} @app.post("/translate") async def translate(job: Job): await ensure_workers_started() if job.mode not in ("lens_images", "lens_text"): raise HTTPException(400, "unsupported mode") jid = uuid.uuid4().hex job.metadata.pipeline.append(PipelineEvent(stage="received_rest", at=datetime.utcnow())) if job.mode == "lens_images": await jobq_img.put((jid, job)) else: await jobq_text.put((jid, job)) results[jid] = {"status": "queued", "_created_at": datetime.utcnow()} return {"id": jid, "status": "queued"} @app.get("/translate/{jid}") async def poll(jid: str): if jid not in results: raise HTTPException(404) payload = results[jid].copy(); payload.pop("_created_at", None) return {"id": jid, **payload} @app.websocket("/ws") async def ws_endpoint(ws: WebSocket): await ws.accept() await ensure_workers_started() try: while True: raw = await ws.receive_json() try: msg = WsMessage(**raw) except ValidationError as ve: await ws.send_json({"type": "error","detail": ve.errors()}) continue match msg.type: case "job": jid = msg.id or uuid.uuid4().hex pending_ws[jid] = ws await ws.send_json(jsonable_encoder({"type": "ack", "id": jid})) if msg.payload.mode == "lens_images": await jobq_img.put((jid, msg.payload)) elif msg.payload.mode == "lens_text": await jobq_text.put((jid, msg.payload)) else: await ws.send_json({"type": "error","detail": "unsupported_mode"}) pending_ws.pop(jid, None) continue results[jid] = {"status": "queued", "_created_at": datetime.utcnow()} case _: await ws.send_json({"type": "error","detail": "unknown_type"}) except WebSocketDisconnect: pass finally: for jid, sock in list(pending_ws.items()): if sock is ws: pending_ws.pop(jid, None) async def worker(mode: str, q: asyncio.Queue): while True: jid, job = await q.get() try: job.metadata.pipeline.append(PipelineEvent(stage="worker_start", at=datetime.utcnow())) if not job.src: raise RuntimeError("src missing") log.info("worker start %s mode=%s src=%s", jid, job.mode, job.src) if mode == "lens_images": res = await translate_lens(str(job.src), job.lang) elif mode == "lens_text": res = await translate_lens_text(str(job.src)) else: raise RuntimeError(f"unsupported mode {mode}") img_b64 = res.get("image") if img_b64 and len(img_b64) > MAX_B64_IMG_LEN: res.pop("image", None) job.metadata.extra = job.metadata.extra or {} job.metadata.extra.setdefault(job.mode, {})["dropped_ocr_image_due_to_size"] = True job.metadata.pipeline.append(PipelineEvent(stage="translated", at=datetime.utcnow())) payload = {**res, "metadata": job.metadata.dict()} serial = jsonable_encoder({"type": "result", "id": jid, "result": payload}) ws = pending_ws.pop(jid, None) if ws: try: await ws.send_json(serial) log.info("sent WS result %s", jid) except Exception: pending_ws.pop(jid, None) results[jid] = {"status": "done", "result": payload, "_created_at": datetime.utcnow()} log.info("worker done %s mode=%s", jid, job.mode) except Exception as e: log.exception("worker error %s", jid, exc_info=e) err_txt = (str(e) or e.__class__.__name__) err_type = e.__class__.__name__ err = {"type": "error", "id": jid, "error": err_txt, "error_type": err_type} ws = pending_ws.pop(jid, None) if ws: try: await ws.send_json(jsonable_encoder(err)) except Exception: pass results[jid] = {"status": "error", "result": err_txt, "error_type": err_type, "_created_at": datetime.utcnow()} finally: q.task_done() if JOB_DELAY_SEC > 0: await asyncio.sleep(JOB_DELAY_SEC) async def cleanup(): while True: await asyncio.sleep(60) cutoff = datetime.utcnow() - timedelta(seconds=RESULTS_TTL) for jid in [k for k,v in results.items() if v.get("_created_at") < cutoff]: results.pop(jid, None) @app.on_event("startup") async def startup(): if ENABLE_BACKGROUND_WORKERS: for _ in range(MAX_WORKERS_IMAGES): asyncio.create_task(worker("lens_images", jobq_img)) for _ in range(MAX_WORKERS_TEXT): asyncio.create_task(worker("lens_text", jobq_text)) asyncio.create_task(cleanup()) log.info( "startup OK – %d image workers, %d text workers, TTL=%ds (workers_enabled=%s)", MAX_WORKERS_IMAGES, MAX_WORKERS_TEXT, RESULTS_TTL, ENABLE_BACKGROUND_WORKERS )