# onnx.py — uploaded by metythorn via huggingface_hub
# (commit message "Upload onnx.py with huggingface_hub", commit 8871ca5, verified; 5.25 kB).
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import List, Optional, Sequence, Union
import numpy as np
import onnxruntime as ort
from PIL import Image
import torch
from torchvision import transforms
# Type alias used for the path-valued arguments in this module (model, config
# and image paths): either a plain string or any os.PathLike such as pathlib.Path.
PathLike = Union[str, os.PathLike]
class Vocabulary:
    """Character-level token vocabulary rebuilt from its serialized form.

    ``serialized`` must provide ``"char2idx"`` (character -> id) and
    ``"idx2char"``, which is either a mapping of id -> character (ids may be
    stored as strings) or a list whose positions are the ids. ``"specials"``
    is optional and defaults to ``["<PAD>", "<SOS>", "<EOS>"]``.

    ``encode``/``decode`` look up ``<PAD>``, ``<SOS>`` and ``<EOS>`` in
    ``char2idx``, so those entries must be present there.
    """

    def __init__(self, serialized: dict):
        fallback_specials = ["<PAD>", "<SOS>", "<EOS>"]
        self.specials = serialized.get("specials", fallback_specials)
        self.char2idx: dict[str, int] = serialized["char2idx"]
        stored = serialized["idx2char"]
        # Ids may arrive as strings (e.g. JSON object keys), so coerce them to
        # int regardless of whether the table is a mapping or a list.
        entries = stored.items() if isinstance(stored, dict) else enumerate(stored)
        self.idx2char = {int(key): char for key, char in entries}

    def encode(self, text: str) -> List[int]:
        """Return the ids for ``text`` wrapped in SOS/EOS; unknown characters are dropped."""
        table = self.char2idx
        start_id = table["<SOS>"]
        end_id = table["<EOS>"]
        ids = [start_id]
        ids.extend(table[ch] for ch in text if ch in table)
        ids.append(end_id)
        return ids

    def decode(self, tokens: Sequence[int]) -> str:
        """Convert ids back into text.

        PAD and SOS ids are skipped, decoding stops at the first EOS, and every
        other id is looked up in ``idx2char`` (raising ``KeyError`` if absent).
        """
        table = self.char2idx
        ignored = (table["<PAD>"], table["<SOS>"])
        stop_id = table["<EOS>"]
        pieces: List[str] = []
        for tok in tokens:
            if tok in ignored:
                continue
            if tok == stop_id:
                break
            pieces.append(self.idx2char[tok])
        return "".join(pieces)

    def __len__(self) -> int:
        """Return the number of entries in ``char2idx``."""
        return len(self.char2idx)
def _load_config(config_path: Path) -> dict:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
def _resolve_hparams(config_data: dict) -> dict:
candidates = config_data.get("hyperparameters", config_data)
defaults = {
"img_height": 128,
"img_width": 320,
"max_decode_len": 128,
}
resolved = {k: candidates.get(k, v) for k, v in defaults.items()}
resolved["img_height"] = int(resolved["img_height"])
resolved["img_width"] = int(resolved["img_width"])
resolved["max_decode_len"] = int(resolved["max_decode_len"])
return resolved
class ONNXPredictor:
    """Greedy autoregressive OCR inference over an exported ONNX model.

    The ONNX graph is fed two inputs, ``"images"`` (a preprocessed
    ``(1, C, H, W)`` float32 batch) and ``"tgt"`` (an int64 token-id buffer of
    length ``max_decode_len``, padded with PAD), and its ``"logits"`` output is
    read back. The vocabulary and the image/decoding hyperparameters come from
    a JSON config file.
    """

    def __init__(
        self,
        model_path: PathLike,
        config_path: PathLike,
        providers: Optional[list[str]] = None,
    ) -> None:
        """Load the config, build the vocabulary and open an ONNX Runtime session.

        Args:
            model_path: Path to the ``.onnx`` model file.
            config_path: Path to a JSON config that must contain a ``"vocab"``
                entry; image size and decode length are optional (see
                ``_resolve_hparams``).
            providers: ONNX Runtime execution providers in priority order.
                ``None`` or an empty list means every provider reported by
                ``ort.get_available_providers()``.

        Raises:
            ValueError: If the config has no ``"vocab"`` key.
        """
        config_data = _load_config(Path(config_path))
        if "vocab" not in config_data:
            raise ValueError("config.json must include serialized vocabulary under key 'vocab'.")
        self.vocab = Vocabulary(config_data["vocab"])
        self.hparams = _resolve_hparams(config_data)
        self.session = ort.InferenceSession(
            str(model_path),
            providers=providers or ort.get_available_providers(),
        )
        # Resize to the configured (img_height, img_width), then normalize with
        # the usual ImageNet channel statistics (presumably what the model was
        # trained with — confirm against the training pipeline).
        self.transform = transforms.Compose(
            [
                transforms.Resize((self.hparams["img_height"], self.hparams["img_width"])),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def _prepare_image(self, image: Union[PathLike, Image.Image]) -> np.ndarray:
        """Turn a PIL image or an image path into a ``(1, C, H, W)`` float32 array.

        Raises:
            FileNotFoundError: If ``image`` is a path that does not point to a file.
        """
        if isinstance(image, Image.Image):
            pil_image = image.convert("RGB")
        else:
            image_path = Path(image).expanduser()
            # is_file() (not exists()) so a directory gets a clear error here
            # instead of an opaque failure inside Pillow.
            if not image_path.is_file():
                raise FileNotFoundError(f"Image not found: {image_path}")
            # Close the handle deterministically: Pillow only auto-closes after
            # load() for single-frame images. convert() loads the pixel data
            # into a new image, so it stays valid after the file is closed.
            with Image.open(image_path) as opened:
                pil_image = opened.convert("RGB")
        tensor = self.transform(pil_image)  # (C, H, W)
        return tensor.unsqueeze(0).cpu().numpy().astype(np.float32)  # (1, C, H, W)

    def _greedy_decode(self, image_array: np.ndarray) -> List[int]:
        """Greedily decode token ids for one preprocessed image.

        Each step re-runs the model on the full PAD-filled ``tgt`` buffer and
        takes the argmax of the logits at the last filled position. Decoding
        stops when EOS is predicted (EOS itself is not appended) or when the
        buffer is full.

        Returns:
            The generated ids, starting with SOS.
        """
        sos_idx = self.vocab.char2idx["<SOS>"]
        eos_idx = self.vocab.char2idx["<EOS>"]
        pad_idx = self.vocab.char2idx["<PAD>"]
        generated = [sos_idx]
        max_len = self.hparams["max_decode_len"]
        # SOS already occupies slot 0 of the fixed-length tgt buffer, so at most
        # max_len - 1 further tokens fit before it is full.
        for _ in range(max_len - 1):
            tgt = np.full((1, max_len), pad_idx, dtype=np.int64)
            tgt[0, : len(generated)] = generated
            outputs = self.session.run(
                ["logits"],
                {
                    "images": image_array,
                    "tgt": tgt,
                },
            )
            logits = outputs[0]  # (1, seq, vocab)
            # The logits at the last filled position predict the following token.
            last_pos = len(generated) - 1
            next_token = int(logits[0, last_pos, :].argmax(axis=-1))
            if next_token == eos_idx:
                break
            generated.append(next_token)
        return generated

    def predict(self, image: Union[PathLike, Image.Image]) -> str:
        """Recognize the text in ``image`` (a PIL image or a path to one)."""
        image_array = self._prepare_image(image)
        tokens = self._greedy_decode(image_array)
        return self.vocab.decode(tokens)
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run OCR on a single image and print the decoded text.

    Every option has a default equal to the original quick-experiment value, so
    ``main()`` with no arguments behaves exactly as before. When run as a
    script, the command-line arguments are parsed instead, which avoids editing
    hard-coded paths.

    Args:
        argv: Arguments to parse (e.g. ``sys.argv[1:]``). ``None`` means
            "use all defaults" and does not read ``sys.argv``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run ONNX OCR inference on a single image.")
    parser.add_argument(
        "image",
        nargs="?",
        default="/home/metythorn/konai/services/ocr-service/data/raw/samples/image copy.png",
        help="path to the input image",
    )
    parser.add_argument(
        "--model",
        default="checkpoints_base/khmer_ocr.onnx",
        help="path to the exported ONNX model",
    )
    parser.add_argument(
        "--config",
        default="checkpoints_base/config.json",
        help="path to the JSON config containing the vocabulary",
    )
    parser.add_argument(
        "--providers",
        nargs="+",
        default=["CUDAExecutionProvider", "CPUExecutionProvider"],
        help="ONNX Runtime execution providers, in priority order",
    )
    args = parser.parse_args([] if argv is None else list(argv))

    predictor = ONNXPredictor(
        model_path=args.model,
        config_path=args.config,
        providers=list(args.providers),
    )
    print(predictor.predict(args.image))


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])