metythorn committed on
Commit
0f4b9e9
·
verified ·
1 Parent(s): 7ea5b09

Delete onnx.py

Browse files
Files changed (1) hide show
  1. onnx.py +0 -151
onnx.py DELETED
@@ -1,151 +0,0 @@
1
- from __future__ import annotations
2
- import json
3
- import os
4
- from pathlib import Path
5
- from typing import List, Optional, Sequence, Union
6
-
7
- import numpy as np
8
- import onnxruntime as ort
9
- from PIL import Image
10
- import torch
11
- from torchvision import transforms
12
-
13
-
14
- PathLike = Union[str, os.PathLike]
15
-
16
-
17
class Vocabulary:
    """Character-level vocabulary rebuilt from its serialized form.

    The serialized mapping must provide ``"char2idx"`` (character -> index)
    and ``"idx2char"`` (either a dict keyed by index, possibly as strings,
    or a sequence ordered by index). ``"specials"`` is optional and falls
    back to ``["<PAD>", "<SOS>", "<EOS>"]``.
    """

    def __init__(self, serialized: dict):
        """Populate the lookup tables from ``serialized``.

        Raises:
            KeyError: if ``"char2idx"`` or ``"idx2char"`` is missing.
        """
        self.specials = serialized.get("specials", ["<PAD>", "<SOS>", "<EOS>"])
        self.char2idx: dict[str, int] = serialized["char2idx"]
        raw_table = serialized["idx2char"]
        # A dict-shaped table may carry string keys (as JSON objects do), so
        # keys are coerced to int; a sequence is indexed by position.
        if isinstance(raw_table, dict):
            pairs = raw_table.items()
        else:
            pairs = enumerate(raw_table)
        self.idx2char = {int(index): symbol for index, symbol in pairs}

    def encode(self, text: str) -> List[int]:
        """Return ``<SOS>`` + indices of the known characters of ``text`` + ``<EOS>``.

        Characters absent from ``char2idx`` are dropped without error.
        """
        table = self.char2idx
        start = table["<SOS>"]
        stop = table["<EOS>"]
        ids = [start]
        for symbol in text:
            if symbol in table:
                ids.append(table[symbol])
        ids.append(stop)
        return ids

    def decode(self, tokens: Sequence[int]) -> str:
        """Turn token indices back into text.

        ``<PAD>`` and ``<SOS>`` tokens are skipped, and decoding stops at the
        first ``<EOS>``. An index missing from ``idx2char`` raises ``KeyError``.
        """
        table = self.char2idx
        pad = table["<PAD>"]
        sos = table["<SOS>"]
        eos = table["<EOS>"]
        skipped = (pad, sos)
        pieces: List[str] = []
        for token in tokens:
            if token in skipped:
                continue
            if token == eos:
                break
            pieces.append(self.idx2char[token])
        return "".join(pieces)

    def __len__(self) -> int:
        """Return the number of entries in ``char2idx``."""
        return len(self.char2idx)
49
-
50
-
51
def _load_config(config_path: Path) -> dict:
    """Read the UTF-8 encoded JSON file at ``config_path`` and return its parsed content."""
    with open(config_path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
54
-
55
-
56
def _resolve_hparams(config_data: dict) -> dict:
    """Extract the image size and decode length settings as integers.

    Values are read from ``config_data["hyperparameters"]`` when that key is
    present, otherwise from ``config_data`` itself. Missing entries fall back
    to 128 (``img_height``), 320 (``img_width``) and 128 (``max_decode_len``).
    Only these three keys appear in the result.
    """
    source = config_data.get("hyperparameters", config_data)
    fallbacks = (
        ("img_height", 128),
        ("img_width", 320),
        ("max_decode_len", 128),
    )
    return {name: int(source.get(name, default)) for name, default in fallbacks}
68
-
69
-
70
class ONNXPredictor:
    """Greedy text prediction from an image using an ONNX Runtime session.

    The session is fed two inputs named ``"images"`` and ``"tgt"`` and is
    asked for a single output named ``"logits"``. The vocabulary and the
    image/decoding hyperparameters come from a JSON config file.
    """

    def __init__(
        self,
        model_path: PathLike,
        config_path: PathLike,
        providers: Optional[list[str]] = None,
    ) -> None:
        """Load the config, build the vocabulary and open the ONNX session.

        Args:
            model_path: Path to the ``.onnx`` model file.
            config_path: Path to the JSON config; it must contain a
                ``"vocab"`` entry holding the serialized vocabulary.
            providers: ONNX Runtime execution providers to use. When omitted
                (or empty), every provider reported as available is used.

        Raises:
            ValueError: if the config has no ``"vocab"`` key.
        """
        config_data = _load_config(Path(config_path))
        if "vocab" not in config_data:
            raise ValueError("config.json must include serialized vocabulary under key 'vocab'.")
        self.vocab = Vocabulary(config_data["vocab"])
        self.hparams = _resolve_hparams(config_data)

        self.session = ort.InferenceSession(
            str(model_path),
            providers=providers or ort.get_available_providers(),
        )

        # Resize to the configured (height, width), convert to a float
        # tensor and normalize per channel. The mean/std values are the
        # widely used ImageNet statistics; presumably they match what the
        # model was trained with -- confirm against the training pipeline.
        self.transform = transforms.Compose(
            [
                transforms.Resize((self.hparams["img_height"], self.hparams["img_width"])),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def _prepare_image(self, image: Union[PathLike, Image.Image]) -> np.ndarray:
        """Convert an image (or a path to one) into the model's input array.

        Args:
            image: An already opened PIL image, or a path to an image file
                (``~`` is expanded).

        Returns:
            A ``float32`` array with a leading batch axis of size 1.

        Raises:
            FileNotFoundError: if ``image`` is a path that does not exist.
        """
        if isinstance(image, Image.Image):
            pil_image = image.convert("RGB")
        else:
            image_path = Path(image).expanduser()
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed deterministically right away
            # instead of relying on Pillow's lazy-loading behavior.
            with Image.open(image_path) as opened:
                pil_image = opened.convert("RGB")
        tensor = self.transform(pil_image)  # (C, H, W)
        return tensor.unsqueeze(0).cpu().numpy().astype(np.float32)  # (1, C, H, W)

    def _greedy_decode(self, image_array: np.ndarray) -> List[int]:
        """Generate token indices one at a time, always taking the argmax.

        Args:
            image_array: Prepared image batch passed to the session as
                ``"images"``.

        Returns:
            The generated indices, starting with ``<SOS>``. The terminating
            ``<EOS>`` is not included, and the list never grows beyond
            ``max_decode_len`` entries.
        """
        char2idx = self.vocab.char2idx
        sos_idx = char2idx["<SOS>"]
        eos_idx = char2idx["<EOS>"]
        pad_idx = char2idx["<PAD>"]
        max_len = self.hparams["max_decode_len"]
        generated = [sos_idx]

        # With fewer than two slots there is no room to generate anything.
        if max_len < 2:
            return generated

        # The target buffer always has the full (1, max_len) shape, padded
        # with <PAD>. It is allocated once; each step only writes the newly
        # generated token into the next free slot.
        tgt = np.full((1, max_len), pad_idx, dtype=np.int64)
        tgt[0, 0] = sos_idx

        for position in range(max_len - 1):  # leave room for EOS
            outputs = self.session.run(
                ["logits"],
                {
                    "images": image_array,
                    "tgt": tgt,
                },
            )
            logits = outputs[0]  # indexed as (batch, position, vocab)
            # ``position`` is the index of the most recently filled slot; the
            # scores at that position decide the next token.
            next_token = int(logits[0, position, :].argmax(axis=-1))
            if next_token == eos_idx:
                break
            generated.append(next_token)
            tgt[0, position + 1] = next_token

        return generated

    def predict(self, image: Union[PathLike, Image.Image]) -> str:
        """Return the text recognized in ``image`` (a PIL image or a file path)."""
        image_array = self._prepare_image(image)
        tokens = self._greedy_decode(image_array)
        return self.vocab.decode(tokens)
137
-
138
-
139
def main() -> None:
    """Run a single prediction with the hard-coded sample paths and print the text."""
    # Edit these paths for quick experiments
    predictor = ONNXPredictor(
        model_path="checkpoints_base/khmer_ocr.onnx",
        config_path="checkpoints_base/config.json",
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    sample_image = "/home/metythorn/konai/services/ocr-service/data/raw/samples/image copy.png"
    recognized = predictor.predict(sample_image)
    print(recognized)


if __name__ == "__main__":
    main()