| from __future__ import annotations |
| import json |
| import os |
| from pathlib import Path |
| from typing import List, Optional, Sequence, Union |
|
|
| import numpy as np |
| import onnxruntime as ort |
| from PIL import Image |
| import torch |
| from torchvision import transforms |
|
|
|
|
# Anything accepted as a filesystem location: a plain string or an
# ``os.PathLike`` object such as ``pathlib.Path``.
PathLike = Union[os.PathLike, str]
|
|
|
|
class Vocabulary:
    """Bidirectional mapping between characters and integer token ids.

    Built from a serialized dict that must contain ``"char2idx"``
    (character -> id) and ``"idx2char"`` (either a list indexed by id, or a
    dict whose keys are ids, possibly stringified as JSON produces).
    An optional ``"specials"`` list is stored unchanged.
    ``encode``/``decode`` require ``"<PAD>"``, ``"<SOS>"`` and ``"<EOS>"``
    to be present in ``char2idx``.
    """

    def __init__(self, serialized: dict):
        self.specials = serialized.get("specials", ["<PAD>", "<SOS>", "<EOS>"])
        self.char2idx: dict[str, int] = serialized["char2idx"]
        raw_table = serialized["idx2char"]
        # JSON object keys are always strings, so a dict-shaped table has to
        # have its keys turned back into ints; a list-shaped table is indexed
        # by position.
        if isinstance(raw_table, dict):
            table = {int(key): ch for key, ch in raw_table.items()}
        else:
            table = dict(enumerate(raw_table))
        self.idx2char = table

    def encode(self, text: str) -> List[int]:
        """Return ``[<SOS>, *ids, <EOS>]`` for *text*.

        Characters that are not in the vocabulary are silently skipped.
        """
        start = self.char2idx["<SOS>"]
        end = self.char2idx["<EOS>"]
        lookup = self.char2idx
        ids = [lookup[ch] for ch in text if ch in lookup]
        return [start] + ids + [end]

    def decode(self, tokens: Sequence[int]) -> str:
        """Convert token ids back into text.

        ``<PAD>`` and ``<SOS>`` ids are skipped, and decoding stops at the
        first ``<EOS>``. Raises ``KeyError`` for an id missing from
        ``idx2char``.
        """
        ignored = (self.char2idx["<PAD>"], self.char2idx["<SOS>"])
        stop = self.char2idx["<EOS>"]
        pieces: List[str] = []
        for tok in tokens:
            if tok in ignored:
                continue
            if tok == stop:
                break
            pieces.append(self.idx2char[tok])
        return "".join(pieces)

    def __len__(self) -> int:
        """Number of entries in ``char2idx`` (special tokens included)."""
        return len(self.char2idx)
|
|
|
|
| def _load_config(config_path: Path) -> dict: |
| with open(config_path, "r", encoding="utf-8") as f: |
| return json.load(f) |
|
|
|
|
| def _resolve_hparams(config_data: dict) -> dict: |
| candidates = config_data.get("hyperparameters", config_data) |
| defaults = { |
| "img_height": 128, |
| "img_width": 320, |
| "max_decode_len": 128, |
| } |
| resolved = {k: candidates.get(k, v) for k, v in defaults.items()} |
| resolved["img_height"] = int(resolved["img_height"]) |
| resolved["img_width"] = int(resolved["img_width"]) |
| resolved["max_decode_len"] = int(resolved["max_decode_len"]) |
| return resolved |
|
|
|
|
class ONNXPredictor:
    """Greedy OCR decoder backed by an exported ONNX Runtime model.

    At every decoding step the session receives two inputs:

    - ``"images"``: the preprocessed image batch.
    - ``"tgt"``: the current token prefix, right-padded with ``<PAD>`` to
      ``max_decode_len``.

    Its ``"logits"`` output is then read back to choose the next token.
    """

    def __init__(
        self,
        model_path: PathLike,
        config_path: PathLike,
        providers: Optional[list[str]] = None,
    ) -> None:
        """Load the vocabulary, hyperparameters and ONNX Runtime session.

        Args:
            model_path: Path to the ``.onnx`` model file.
            config_path: Path to a JSON config with a serialized vocabulary
                under ``"vocab"`` and, optionally, image-size / decode-length
                hyperparameters.
            providers: ONNX Runtime execution providers in priority order.
                ``None`` or an empty list means every provider reported by
                ``ort.get_available_providers()``.

        Raises:
            ValueError: If the config has no ``"vocab"`` entry.
        """
        config_data = _load_config(Path(config_path))
        if "vocab" not in config_data:
            raise ValueError("config.json must include serialized vocabulary under key 'vocab'.")
        self.vocab = Vocabulary(config_data["vocab"])
        self.hparams = _resolve_hparams(config_data)

        self.session = ort.InferenceSession(
            str(model_path),
            providers=providers or ort.get_available_providers(),
        )

        # Resize to the configured fixed resolution, then normalize with the
        # ImageNet channel mean/std (presumably the statistics used in training).
        self.transform = transforms.Compose(
            [
                transforms.Resize((self.hparams["img_height"], self.hparams["img_width"])),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def _prepare_image(self, image: Union[PathLike, Image.Image]) -> np.ndarray:
        """Load or convert *image* to RGB and apply ``self.transform``.

        Args:
            image: A PIL image, or a path to an image file (``~`` is expanded).

        Returns:
            A float32 array of shape ``(1, 3, img_height, img_width)``.

        Raises:
            FileNotFoundError: If *image* is a path that does not exist.
        """
        if isinstance(image, Image.Image):
            pil_image = image.convert("RGB")
        else:
            image_path = Path(image).expanduser()
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")
            # Image.open is lazy and keeps the file handle open. convert()
            # loads the pixels into a new image, so the file can be closed
            # right away instead of leaking until garbage collection.
            with Image.open(image_path) as opened:
                pil_image = opened.convert("RGB")
        tensor = self.transform(pil_image)
        return tensor.unsqueeze(0).cpu().numpy().astype(np.float32)

    def _greedy_decode(self, image_array: np.ndarray) -> List[int]:
        """Greedily generate token ids for one preprocessed image.

        Args:
            image_array: Batch produced by ``_prepare_image``.

        Returns:
            The generated ids. The list starts with ``<SOS>``, never contains
            ``<EOS>``, and has at most ``max_decode_len`` entries.
        """
        sos_idx = self.vocab.char2idx["<SOS>"]
        eos_idx = self.vocab.char2idx["<EOS>"]
        pad_idx = self.vocab.char2idx["<PAD>"]
        generated = [sos_idx]
        max_len = self.hparams["max_decode_len"]

        for _ in range(max_len - 1):
            # The model is fed a fixed-length target; positions past the
            # current prefix hold <PAD>.
            tgt = np.full((1, max_len), pad_idx, dtype=np.int64)
            tgt[0, : len(generated)] = generated
            outputs = self.session.run(
                ["logits"],
                {
                    "images": image_array,
                    "tgt": tgt,
                },
            )
            logits = outputs[0]
            # The logits at the last filled position are taken as the
            # prediction for the token that follows the prefix.
            next_pos = len(generated) - 1
            next_token = int(logits[0, next_pos, :].argmax(axis=-1))
            if next_token == eos_idx:
                break
            generated.append(next_token)

        return generated

    def predict(self, image: Union[PathLike, Image.Image]) -> str:
        """Return the recognized text for *image* (a PIL image or an image path)."""
        image_array = self._prepare_image(image)
        tokens = self._greedy_decode(image_array)
        return self.vocab.decode(tokens)
|
|
|
|
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point: run OCR on one image and print the text.

    Each option defaults to the value that used to be hard-coded, so running
    the script with no arguments behaves exactly as before. Any option can
    be overridden on the command line.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``; lets ``main``
            be called programmatically.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Run OCR inference on a single image with an exported ONNX model."
    )
    parser.add_argument(
        "image",
        nargs="?",
        default="/home/metythorn/konai/services/ocr-service/data/raw/samples/image copy.png",
        help="Path to the input image.",
    )
    parser.add_argument(
        "--model",
        default="checkpoints_base/khmer_ocr.onnx",
        help="Path to the exported .onnx model.",
    )
    parser.add_argument(
        "--config",
        default="checkpoints_base/config.json",
        help="Path to the JSON config containing the serialized vocabulary.",
    )
    parser.add_argument(
        "--providers",
        nargs="+",
        default=["CUDAExecutionProvider", "CPUExecutionProvider"],
        help="ONNX Runtime execution providers, in priority order.",
    )
    args = parser.parse_args(argv)

    predictor = ONNXPredictor(
        model_path=args.model,
        config_path=args.config,
        providers=list(args.providers),
    )
    print(predictor.predict(args.image))


if __name__ == "__main__":
    main()
|
|