from __future__ import annotations import json import os from pathlib import Path from typing import List, Optional, Sequence, Union import numpy as np import onnxruntime as ort from PIL import Image import torch from torchvision import transforms PathLike = Union[str, os.PathLike] class Vocabulary: def __init__(self, serialized: dict): self.specials = serialized.get("specials", ["", "", ""]) self.char2idx: dict[str, int] = serialized["char2idx"] idx2char_raw = serialized["idx2char"] if isinstance(idx2char_raw, dict): self.idx2char = {int(k): v for k, v in idx2char_raw.items()} else: self.idx2char = {int(idx): char for idx, char in enumerate(idx2char_raw)} def encode(self, text: str) -> List[int]: sos = self.char2idx[""] eos = self.char2idx[""] body = [self.char2idx[c] for c in text if c in self.char2idx] return [sos, *body, eos] def decode(self, tokens: Sequence[int]) -> str: pad = self.char2idx[""] sos = self.char2idx[""] eos = self.char2idx[""] result: List[str] = [] for token in tokens: if token in (pad, sos): continue if token == eos: break result.append(self.idx2char[token]) return "".join(result) def __len__(self) -> int: return len(self.char2idx) def _load_config(config_path: Path) -> dict: with open(config_path, "r", encoding="utf-8") as f: return json.load(f) def _resolve_hparams(config_data: dict) -> dict: candidates = config_data.get("hyperparameters", config_data) defaults = { "img_height": 128, "img_width": 320, "max_decode_len": 128, } resolved = {k: candidates.get(k, v) for k, v in defaults.items()} resolved["img_height"] = int(resolved["img_height"]) resolved["img_width"] = int(resolved["img_width"]) resolved["max_decode_len"] = int(resolved["max_decode_len"]) return resolved class ONNXPredictor: def __init__( self, model_path: PathLike, config_path: PathLike, providers: Optional[list[str]] = None, ) -> None: config_data = _load_config(Path(config_path)) if "vocab" not in config_data: raise ValueError("config.json must include serialized vocabulary under key 'vocab'.") self.vocab = Vocabulary(config_data["vocab"]) self.hparams = _resolve_hparams(config_data) self.session = ort.InferenceSession( str(model_path), providers=providers or ort.get_available_providers(), ) self.transform = transforms.Compose( [ transforms.Resize((self.hparams["img_height"], self.hparams["img_width"])), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] ) def _prepare_image(self, image: Union[PathLike, Image.Image]) -> np.ndarray: if isinstance(image, Image.Image): pil_image = image.convert("RGB") else: image_path = Path(image).expanduser() if not image_path.exists(): raise FileNotFoundError(f"Image not found: {image_path}") pil_image = Image.open(image_path).convert("RGB") tensor = self.transform(pil_image) # (C, H, W) return tensor.unsqueeze(0).cpu().numpy().astype(np.float32) # (1, C, H, W) def _greedy_decode(self, image_array: np.ndarray) -> List[int]: sos_idx = self.vocab.char2idx[""] eos_idx = self.vocab.char2idx[""] pad_idx = self.vocab.char2idx[""] generated = [sos_idx] max_len = self.hparams["max_decode_len"] for _ in range(max_len - 1): # leave room for EOS tgt = np.full((1, max_len), pad_idx, dtype=np.int64) tgt[0, : len(generated)] = generated outputs = self.session.run( ["logits"], { "images": image_array, "tgt": tgt, }, ) logits = outputs[0] # (1, seq, vocab) next_pos = len(generated) - 1 # position we just filled next_token = int(logits[0, next_pos, :].argmax(axis=-1)) if next_token == eos_idx: break generated.append(next_token) return generated def predict(self, image: Union[PathLike, Image.Image]) -> str: image_array = self._prepare_image(image) tokens = self._greedy_decode(image_array) return self.vocab.decode(tokens) def main() -> None: # Edit these paths for quick experiments model_path = "checkpoints_base/khmer_ocr.onnx" config_path = "checkpoints_base/config.json" image_path = "/home/metythorn/konai/services/ocr-service/data/raw/samples/image copy.png" providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] predictor = ONNXPredictor(model_path=model_path, config_path=config_path, providers=providers) print(predictor.predict(image_path)) if __name__ == "__main__": main()