from __future__ import annotations

import math
import warnings
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from torchvision import models
except ImportError:  # pragma: no cover - only the CNN backbone needs torchvision
    # The STN and positional-encoding modules stay importable without
    # torchvision; KhmerOCRTransformer._load_backbone reports the missing
    # dependency when a backbone is actually requested.
    models = None  # type: ignore[assignment]


def _check_fiducial_count(num_fiducial: int) -> None:
    """Raise ``ValueError`` unless ``num_fiducial`` is a positive even integer.

    The control points are laid out as two rows of ``num_fiducial // 2``
    points each, so an odd count would silently drop one point and leave the
    point arrays inconsistent with layers sized from ``num_fiducial``.
    """
    if num_fiducial < 2 or num_fiducial % 2 != 0:
        raise ValueError(
            f"Number of fiducial points must be a positive even integer, got {num_fiducial!r}"
        )


class LocalizationNetwork(nn.Module):
    """Predicts fiducial control points for TPS-based spatial transformer.

    Args:
        F: Number of fiducial (control) points; must be even because the
            points are initialised as a top row and a bottom row.
        I_channel_num: Number of channels of the input image batch.

    Raises:
        ValueError: If ``F`` is not a positive even integer.
    """

    def __init__(self, F: int, I_channel_num: int):
        super().__init__()
        _check_fiducial_count(F)
        self.F = F
        self.I_channel_num = I_channel_num
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=self.I_channel_num,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, 1, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, 3, 1, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
        self.localization_fc2 = nn.Linear(256, self.F * 2)

        # Init fc2 so the network initially outputs a fixed set of control
        # points regardless of the input (see RARE paper Fig.6a): weights are
        # zeroed and the bias holds the flattened (x, y) points.
        half = F // 2
        ctrl_pts_x = np.linspace(-1.0, 1.0, half)
        ctrl_pts_y_top = np.linspace(0.0, -1.0, num=half)
        ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=half)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
        with torch.no_grad():
            self.localization_fc2.weight.zero_()
            # copy_ (unlike rebinding ``.data``) fails loudly on a shape mismatch.
            self.localization_fc2.bias.copy_(torch.from_numpy(initial_bias).float().view(-1))

    def forward(self, batch_I: torch.Tensor) -> torch.Tensor:
        """Return predicted control points of shape ``(batch, F, 2)``.

        Args:
            batch_I: Image batch fed to the convolutional stack, i.e.
                ``(batch, I_channel_num, height, width)``.
        """
        batch_size = batch_I.size(0)
        features = self.conv(batch_I).view(batch_size, -1)
        hidden = self.localization_fc1(features)
        return self.localization_fc2(hidden).view(batch_size, self.F, 2)


class GridGenerator(nn.Module):
    """Generates a TPS sampling grid from predicted control points.

    The base control points ``C``, the target pixel grid ``P`` and the two
    matrices derived from them (``inv_delta_C`` and ``P_hat``) depend only on
    ``F`` and the rectified image size, so they are computed once here and
    the latter two are stored as buffers.

    Args:
        F: Number of fiducial points (positive, even).
        I_r_size: ``(height, width)`` of the rectified output image.

    Raises:
        ValueError: If ``F`` is not a positive even integer.
    """

    def __init__(self, F: int, I_r_size: Tuple[int, int]):
        super().__init__()
        _check_fiducial_count(F)
        self.eps = 1e-6
        self.I_r_height, self.I_r_width = I_r_size
        self.F = F
        self.C = self._build_C(self.F)
        self.P = self._build_P(self.I_r_width, self.I_r_height)
        self.register_buffer(
            "inv_delta_C",
            torch.from_numpy(self._build_inv_delta_C(self.F, self.C)).float(),
        )
        self.register_buffer(
            "P_hat",
            torch.from_numpy(self._build_P_hat(self.F, self.C, self.P)).float(),
        )

    def _build_C(self, F: int) -> np.ndarray:
        """Return the ``(F, 2)`` base control points on the top (y=-1) and bottom (y=1) edges."""
        half = F // 2
        ctrl_pts_x = np.linspace(-1.0, 1.0, half)
        ctrl_pts_top = np.stack([ctrl_pts_x, -np.ones(half)], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, np.ones(half)], axis=1)
        return np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)

    def _build_inv_delta_C(self, F: int, C: np.ndarray) -> np.ndarray:
        """Return the inverse of the ``(F + 3, F + 3)`` TPS system matrix built from ``C``."""
        # Pairwise distances between control points via broadcasting.
        hat_C = np.linalg.norm(C[:, np.newaxis, :] - C[np.newaxis, :, :], axis=2)
        # A distance of 1 on the diagonal makes r^2 * log(r) evaluate to 0
        # there instead of 0 * log(0) = nan.
        np.fill_diagonal(hat_C, 1)
        hat_C = (hat_C ** 2) * np.log(hat_C)
        delta_C = np.concatenate(
            [
                np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),
                np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),
                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1),
            ],
            axis=0,
        )
        return np.linalg.inv(delta_C)

    def _build_P(self, I_r_width: int, I_r_height: int) -> np.ndarray:
        """Return the ``(height * width, 2)`` grid of (x, y) coordinates in (-1, 1)."""
        I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width
        I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height
        P = np.stack(np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2)
        return P.reshape([-1, 2])

    def _build_P_hat(self, F: int, C: np.ndarray, P: np.ndarray) -> np.ndarray:
        """Return the ``(n, F + 3)`` matrix ``[1, P, rbf(P, C)]`` for the ``n`` grid points."""
        n = P.shape[0]
        # (n, 1, 2) - (1, F, 2) broadcasts to every grid-point/control-point pair.
        P_diff = P[:, np.newaxis, :] - C[np.newaxis, :, :]
        rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)
        # eps guards log(0) when a grid point coincides with a control point.
        rbf = np.square(rbf_norm) * np.log(rbf_norm + self.eps)
        return np.concatenate([np.ones((n, 1)), P, rbf], axis=1)

    def build_P_prime(self, batch_C_prime: torch.Tensor) -> torch.Tensor:
        """Map the target grid through the TPS defined by ``batch_C_prime``.

        Args:
            batch_C_prime: Predicted control points, ``(batch, F, 2)``.

        Returns:
            Sampling coordinates of shape ``(batch, height * width, 2)``.
        """
        batch_size = batch_C_prime.size(0)
        # Three zero rows complete the right-hand side of the TPS system;
        # new_zeros keeps them on the input's device and dtype.
        padding = batch_C_prime.new_zeros(batch_size, 3, 2)
        batch_C_prime_with_zeros = torch.cat((batch_C_prime, padding), dim=1)
        # matmul broadcasts the shared 2-D buffers over the batch, so no
        # per-sample copies of inv_delta_C / P_hat are built explicitly.
        batch_T = torch.matmul(self.inv_delta_C, batch_C_prime_with_zeros)
        return torch.matmul(self.P_hat, batch_T)


class TPSSpatialTransformerNetwork(nn.Module):
    """TPS STN rectifier used in RARE, adapted for RGB inputs.

    Args:
        F: Number of fiducial points (positive, even).
        I_size: ``(height, width)`` of the input image (stored, not otherwise used here).
        I_r_size: ``(height, width)`` of the rectified output image.
        I_channel_num: Number of input image channels.
    """

    def __init__(
        self,
        F: int,
        I_size: Tuple[int, int],
        I_r_size: Tuple[int, int],
        I_channel_num: int = 3,
    ):
        super().__init__()
        self.F = F
        self.I_size = I_size
        self.I_r_size = I_r_size
        self.I_channel_num = I_channel_num
        # Attribute names double as state_dict prefixes; keep them unchanged.
        self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
        self.GridGenerator = GridGenerator(self.F, self.I_r_size)

    def forward(self, batch_I: torch.Tensor) -> torch.Tensor:
        """Return ``batch_I`` resampled on the predicted TPS grid, with spatial size ``I_r_size``."""
        batch_C_prime = self.LocalizationNetwork(batch_I)
        P_prime = self.GridGenerator.build_P_prime(batch_C_prime)
        height, width = self.I_r_size[0], self.I_r_size[1]
        grid = P_prime.reshape([P_prime.size(0), height, width, 2])
        return F.grid_sample(batch_I, grid, padding_mode="border", align_corners=True)


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence-first tensor.

    Args:
        d_model: Feature dimension of the sequence (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` of shape ``(seq_len, batch, d_model)``."""
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)


class KhmerOCRTransformer(nn.Module):
    """CNN backbone + Transformer encoder/decoder for text recognition.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows / output logits).
        d_model: Transformer feature dimension.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers and of decoder layers.
        backbone_name: Name of a torchvision model builder exposing an ``fc`` head.
        dim_feedforward: Hidden size of the Transformer feed-forward blocks.
        dropout: Dropout probability for the Transformer and positional encoding.
        use_stn: If true, rectify inputs with a TPS spatial transformer first.
        stn_fiducial_points: Number of TPS control points (used when ``use_stn``).
        img_height: Input/rectified image height given to the STN.
        img_width: Input/rectified image width given to the STN.
        pad_idx: Token index treated as padding in the target sequence.
    """

    # Class-level default so instances restored without running __init__
    # (e.g. whole-model pickles that predate this attribute) keep the
    # historical padding index.
    pad_idx: int = 0

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 3,
        backbone_name: str = "resnet18",
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        use_stn: bool = False,
        stn_fiducial_points: int = 20,
        img_height: int = 128,
        img_width: int = 320,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.use_stn = use_stn
        if self.use_stn:
            # Rectify curved/rotated text before the CNN backbone
            self.stn = TPSSpatialTransformerNetwork(
                F=stn_fiducial_points,
                I_size=(img_height, img_width),
                I_r_size=(img_height, img_width),
                I_channel_num=3,
            )
        else:
            self.stn = None

        resnet, feature_dim = self._load_backbone(backbone_name)
        # Drop the last two children (pooling and classifier head for ResNets)
        # to keep a spatial feature map.
        self.backbone = nn.Sequential(*(list(resnet.children())[:-2]))
        self.conv_proj = nn.Conv2d(feature_dim, d_model, kernel_size=1)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_square_subsequent_mask(
        self, sz: int, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Return a ``(sz, sz)`` float mask: ``-inf`` above the diagonal, ``0`` elsewhere.

        Args:
            sz: Sequence length.
            device: Optional device to create the mask on (default: torch's default device).
        """
        return torch.triu(torch.full((sz, sz), float("-inf"), device=device), diagonal=1)

    def forward(self, src_img: torch.Tensor, tgt_text_idx: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, tgt_len, vocab_size)``.

        Args:
            src_img: Image batch accepted by the backbone (and the STN, if enabled).
            tgt_text_idx: Integer token indices of shape ``(batch, tgt_len)``;
                positions equal to ``pad_idx`` are ignored as attention keys.
        """
        if self.stn is not None:
            src_img = self.stn(src_img)

        features = self.conv_proj(self.backbone(src_img))
        # (batch, d_model, h, w) -> (h * w, batch, d_model): sequence-first layout.
        src = self.pos_encoder(features.flatten(2).permute(2, 0, 1))
        tgt = self.pos_encoder(self.embedding(tgt_text_idx).permute(1, 0, 2))

        tgt_len = tgt.size(0)
        # Boolean causal mask (True = may not attend), created on the target
        # device. Using bool here matches the dtype of the padding mask, which
        # avoids PyTorch's mismatched-mask-type deprecation warning; it blocks
        # the same positions as the float -inf mask.
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=src.device),
            diagonal=1,
        )
        tgt_padding_mask = tgt_text_idx == self.pad_idx

        output = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        return self.fc_out(output.permute(1, 0, 2))

    def _load_backbone(self, backbone_name: str):
        """Build the torchvision backbone and return ``(model, fc_input_features)``.

        Pretrained weights are used for the names listed in ``weight_map``;
        other builders are created without pretrained weights. If loading the
        pretrained weights fails, the model falls back to random
        initialisation and a ``RuntimeWarning`` is emitted.

        Raises:
            ImportError: If torchvision is not installed.
            ValueError: If ``backbone_name`` is not a callable torchvision
                builder or the built model has no linear ``fc`` head.
        """
        if models is None:
            raise ImportError("torchvision is required to build the CNN backbone")

        name = backbone_name.lower()
        builder = getattr(models, name, None)
        # Submodules such as ``models.resnet`` exist as attributes but are not builders.
        if not callable(builder):
            raise ValueError(f"Unsupported backbone '{backbone_name}'")

        weight_map = {
            "resnet18": models.ResNet18_Weights.DEFAULT,
            "resnet34": models.ResNet34_Weights.DEFAULT,
            "resnet50": models.ResNet50_Weights.DEFAULT,
        }
        weights = weight_map.get(name)
        if weights is None:
            resnet = builder(weights=None)
        else:
            try:
                resnet = builder(weights=weights)
            except Exception as err:  # pragma: no cover - offline fallback
                # Best-effort fallback is deliberate, but training from random
                # weights instead of pretrained ones must not go unnoticed.
                warnings.warn(
                    f"Could not load pretrained weights for '{name}' ({err!r}); "
                    "falling back to random initialisation.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                resnet = builder(weights=None)

        fc = getattr(resnet, "fc", None)
        if not isinstance(fc, nn.Linear):
            raise ValueError(
                f"Unsupported backbone '{backbone_name}': model has no linear 'fc' head"
            )
        return resnet, fc.in_features