import ast
import base64
import copy
import io
import json
import math
import os
import re
import shutil
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import botocore
import cv2
import gradio as gr
import numpy as np
import pandas as pd
import pytesseract
import requests
import spaces
from pdfminer.layout import LTChar
from PIL import Image, ImageDraw, ImageFont
from presidio_analyzer import AnalyzerEngine, RecognizerResult

from tools.config import (
    AWS_LLM_PII_OPTION,
    AWS_PII_OPTION,
    CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE,
    CLOUD_LLM_PII_MODEL_CHOICE,
    CONVERT_LINE_TO_WORD_LEVEL,
    DEFAULT_INFERENCE_SERVER_VLM_MODEL,
    DEFAULT_LANGUAGE,
    DEFAULT_NEW_BATCH_CHAR_COUNT,
    DEFAULT_NEW_BATCH_WORD_COUNT,
    FULL_COMPREHEND_ENTITY_LIST,
    HYBRID_OCR_CONFIDENCE_THRESHOLD,
    HYBRID_OCR_MAX_NEW_TOKENS,
    HYBRID_OCR_MAX_WORDS,
    HYBRID_OCR_PADDING,
    IMAGES_DPI,
    INFERENCE_SERVER_API_URL,
    INFERENCE_SERVER_DISABLE_THINKING,
    INFERENCE_SERVER_LLM_PII_MODEL_CHOICE,
    INFERENCE_SERVER_MODEL_NAME,
    INFERENCE_SERVER_PII_OPTION,
    INFERENCE_SERVER_TIMEOUT,
    LINE_TO_WORD_SEGMENT_MAX_WORKERS,
    LLM_MAX_NEW_TOKENS,
    LLM_TEMPERATURE,
    LOAD_PADDLE_AT_STARTUP,
    LOCAL_OCR_MODEL_OPTIONS,
    LOCAL_OCR_READING_ORDER,
    LOCAL_PII_OPTION,
    LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE,
    LOCAL_TRANSFORMERS_LLM_PII_OPTION,
    MAX_NEW_TOKENS,
    MAX_SPACES_GPU_RUN_TIME,
    MAX_WORKERS,
    MERGE_BOUNDING_BOXES,
    OUTPUT_FOLDER,
    PADDLE_DET_DB_UNCLIP_RATIO,
    PADDLE_DEVICE,
    PADDLE_FONT_PATH,
    PADDLE_MODEL_PATH,
    PADDLE_PRESERVE_LINE_BOXES,
    PADDLE_USE_TEXTLINE_ORIENTATION,
    PREPARE_PAGE_FOR_HYBRID_VLM_BEFORE_PADDLE,
    PREPROCESS_LOCAL_OCR_IMAGES,
    REPORT_VLM_OUTPUTS_TO_GUI,
    SAVE_EXAMPLE_HYBRID_IMAGES,
    SAVE_PAGE_OCR_VISUALISATIONS,
    SAVE_PREPROCESS_IMAGES,
    SAVE_TEXTRACT_BEDROCK_HYBRID_EXAMPLES,
    SAVE_VLM_INPUT_IMAGES,
    SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL,
    SPACES_ZERO_GPU,
    TESSERACT_SEGMENTATION_LEVEL,
    TESSERACT_WORD_LEVEL_OCR,
    USE_LLAMA_SWAP,
    USE_TRANSFORMERS_VLM_MODEL_AS_LLM,
    VLM_DEFAULT_STREAM,
    VLM_HYBRID_MIN_IMAGE_SIZE,
    VLM_MAX_ASPECT_RATIO,
    VLM_MAX_DPI,
    VLM_MAX_IMAGE_SIZE,
    VLM_MIN_DPI,
    VLM_MIN_IMAGE_SIZE,
)
from tools.helper_functions import (
    clean_unicode_text,
    extract_balanced_json_array,
    get_system_font_path,
    model_from_ocr_boxes,
    strip_vlm_thinking_tags,
)
from tools.inference_attention import resolve_paddle_attn_implementation
from tools.llm_funcs import _extract_choice_message_text
from tools.load_spacy_model_custom_recognisers import custom_entities
from tools.ocr_reading_order import build_line_groups
from tools.presidio_analyzer_custom import recognizer_result_from_dict
from tools.run_vlm import (
    extract_text_from_image_vlm,
    full_page_ocr_people_vlm_prompt,
    full_page_ocr_signature_vlm_prompt,
    full_page_ocr_vlm_prompt,
    model_default_do_sample,
    model_default_max_new_tokens,
    model_default_min_p,
    model_default_presence_penalty,
    model_default_prompt,
    model_default_repetition_penalty,
    model_default_seed,
    model_default_temperature,
    model_default_top_k,
    model_default_top_p,
)
from tools.secure_path_utils import validate_folder_containment
from tools.secure_regex_utils import safe_sanitize_text
from tools.word_segmenter import AdaptiveSegmenter

# ---- Tesseract discovery helpers -------------------------------------------------


def _expand_tessdata_path(path: str) -> str:
    """Expand ``~`` and environment variables in ``path`` and make it absolute."""
    return os.path.abspath(os.path.expandvars(os.path.expanduser(path)))


def _is_probable_tessdata_dir(path: str) -> bool:
    """Return True if ``path`` (after ~/$VAR expansion) looks like a tessdata directory.

    A directory qualifies if it contains ``eng.traineddata`` or any other
    ``*.traineddata`` file. Any filesystem error is treated as "not a tessdata dir".
    """
    try:
        if not path:
            return False
        p = _expand_tessdata_path(path)
        if not os.path.isdir(p):
            return False
        # If eng.traineddata exists, it's definitely a tessdata dir. Otherwise allow
        # the folder if it contains *any* traineddata (language packs vary).
        if os.path.isfile(os.path.join(p, "eng.traineddata")):
            return True
        for name in os.listdir(p):
            if name.lower().endswith(".traineddata"):
                return True
    except Exception:
        return False
    return False


def _guess_tessdata_dir_from_tesseract_exe(tesseract_exe: str | None) -> str | None:
    """Guess the tessdata directory that ships alongside a tesseract executable.

    Candidates, checked in order:
      1. ``<exe_dir>/tessdata`` - Windows installer layout
         (``C:\\Program Files\\Tesseract-OCR\\tessdata``).
      2. ``<exe_dir>/../share/tessdata`` - prefix layout: conda-forge on Windows
         (``Library\\bin`` -> ``Library\\share\\tessdata``) and ``bin`` -> ``share``
         prefixes on Linux/macOS.
      3. ``<exe_dir>/../../share/tessdata`` - kept as a fallback for backward
         compatibility with the previous lookup.

    Returns:
        The first candidate that looks like a tessdata dir, else None (also None
        when ``tesseract_exe`` is falsy).
    """
    if not tesseract_exe:
        return None

    exe_dir = os.path.dirname(os.path.abspath(tesseract_exe))
    parent_dir = os.path.dirname(exe_dir)
    grandparent_dir = os.path.dirname(parent_dir)

    candidates = (
        os.path.join(exe_dir, "tessdata"),
        os.path.join(parent_dir, "share", "tessdata"),
        os.path.join(grandparent_dir, "share", "tessdata"),
    )
    for candidate in candidates:
        if _is_probable_tessdata_dir(candidate):
            return candidate
    return None


def _strip_wrapping_quotes(path: str | None) -> str:
    """
    Some .env setups accidentally include quotes in values, e.g.:
        TESSDATA_PREFIX="tesseract/tessdata"

    If those quotes end up in the environment variable value (including a stray
    trailing quote), Tesseract will literally try to open paths like
    '"..."/eng.traineddata' and fail.

    Returns "" for falsy input; otherwise the trimmed value with wrapping or
    stray single/double quotes removed.
    """
    if not path:
        return ""
    s = str(path).strip()
    if len(s) >= 2 and ((s[0] == s[-1] == '"') or (s[0] == s[-1] == "'")):
        return s[1:-1].strip()
    # Also handle common "one-sided" cases, e.g. a stray trailing quote.
    return s.strip('"').strip("'").strip()


def _locate_tesseract_executable() -> str | None:
    """Return a path to the tesseract executable, or None if it cannot be located.

    pytesseract's default ``tesseract_cmd`` is the bare name ``"tesseract"``;
    ``os.path.abspath`` of a bare name points into the current working directory,
    so bare names are resolved through PATH first.
    """
    tesseract_cmd = getattr(pytesseract.pytesseract, "tesseract_cmd", None)
    if tesseract_cmd:
        tesseract_cmd = str(tesseract_cmd)
        resolved = shutil.which(tesseract_cmd)
        if resolved:
            return resolved
        # An explicit (but not currently executable) path is still a useful hint.
        if os.path.dirname(tesseract_cmd):
            return tesseract_cmd
    return shutil.which("tesseract")


def _resolve_tessdata_dir() -> str | None:
    """
    Return an absolute tessdata directory if we can find one.

    Priority:
    1) Existing TESSDATA_PREFIX if it points to a valid tessdata dir
    2) tools.config.TESSERACT_DATA_FOLDER if it points to a valid tessdata dir
    3) Guess based on tesseract executable location (PATH / pytesseract config)

    Configured values are returned with ``~`` and environment variables expanded,
    i.e. the same path that was validated.
    """
    env_prefix = _strip_wrapping_quotes(os.environ.get("TESSDATA_PREFIX", ""))
    if _is_probable_tessdata_dir(env_prefix):
        return _expand_tessdata_path(env_prefix)

    try:
        from tools.config import TESSERACT_DATA_FOLDER

        cfg_dir = _strip_wrapping_quotes(TESSERACT_DATA_FOLDER)
        if _is_probable_tessdata_dir(cfg_dir):
            return _expand_tessdata_path(cfg_dir)
    except Exception:
        # config import is optional for library use
        pass

    return _guess_tessdata_dir_from_tesseract_exe(_locate_tesseract_executable())


def _ensure_tessdata_available_in_env(existing_config: str | None) -> str | None:
    """
    Ensure Tesseract can find language traineddata files by:
    - setting TESSDATA_PREFIX if we can resolve tessdata dir
    - adding --tessdata-dir to the tesseract config string when not already present

    Returns ``existing_config`` unchanged if no tessdata dir can be resolved;
    otherwise the (stripped) config string, with ``--tessdata-dir`` appended on
    non-Windows platforms when it was not already present.
    """
    tessdata_dir = _resolve_tessdata_dir()
    if not tessdata_dir:
        return existing_config

    # Overwrite (not setdefault) so we can repair misquoted values already present in env.
    os.environ["TESSDATA_PREFIX"] = tessdata_dir

    cfg = (existing_config or "").strip()
    if "--tessdata-dir" in cfg:
        return cfg

    # On Windows, pytesseract parses config with shlex(posix=False), which can
    # preserve quote characters in values and make Tesseract treat them as part
    # of the path (e.g. '"C:\\...\\tessdata"/eng.traineddata'). Rely on
    # TESSDATA_PREFIX there instead of injecting --tessdata-dir.
    if sys.platform == "win32":
        return cfg

    return (cfg + f' --tessdata-dir "{tessdata_dir}"').strip()


# AWS Comprehend billing: 1 unit = 100 characters (entity recognition, PII, etc.)
COMPREHEND_CHARACTERS_PER_UNIT = 100

# Phrase-ending punctuation marks (batch boundaries)
PHRASE_ENDING_PUNCTUATION = {".", "!", "?", ";", ":"}

# When Bedrock VLM word count differs from Textract by this many or less, we still
# accept Bedrock text and derive word-level boxes from the Textract line bbox via
# line-to-word segmentation.
MAX_WORD_COUNT_DIFF_FOR_LINE_DERIVED_WORDS = 6


def ends_with_phrase_punctuation(word: str) -> bool:
    """Check if a word ends with phrase-ending punctuation.

    Trailing whitespace is ignored; empty/falsy input returns False.
    """
    if not word:
        return False
    # str.endswith accepts a tuple of suffixes, so one call checks every mark.
    return word.rstrip().endswith(tuple(PHRASE_ENDING_PUNCTUATION))


# --- Language utilities ---
def _normalize_lang(language: str) -> str:
    """Strip, lower-case and replace '-' with '_'; falsy input yields "en"."""
    return language.strip().lower().replace("-", "_") if language else "en"


def _tesseract_lang_code(language: str) -> str:
    """Map a user language input to a Tesseract traineddata code.

    Accepts ISO 639-1 codes, common ISO 639-2 variants and a few regional
    tags (after normalisation). Unknown inputs fall back to "eng".
    """
    lang = _normalize_lang(language)
    mapping = {
        # Common
        "en": "eng",
        "eng": "eng",
        "fr": "fra",
        "fre": "fra",
        "fra": "fra",
        "de": "deu",
        "ger": "deu",
        "deu": "deu",
        "es": "spa",
        "spa": "spa",
        "it": "ita",
        "ita": "ita",
        "nl": "nld",
        "dut": "nld",
        "nld": "nld",
        "pt": "por",
        "por": "por",
        "ru": "rus",
        "rus": "rus",
        "ar": "ara",
        "ara": "ara",
        # Nordics
        "sv": "swe",
        "swe": "swe",
        "no": "nor",
        "nb": "nor",
        "nn": "nor",
        "nor": "nor",
        "fi": "fin",
        "fin": "fin",
        "da": "dan",
        "dan": "dan",
        # Eastern/Central
        "pl": "pol",
        "pol": "pol",
        "cs": "ces",
        "cz": "ces",
        "ces": "ces",
        "hu": "hun",
        "hun": "hun",
        "ro": "ron",
        "rum": "ron",
        "ron": "ron",
        "bg": "bul",
        "bul": "bul",
        "el": "ell",
        "gre": "ell",
        "ell": "ell",
        # Asian
        "ja": "jpn",
        "jp": "jpn",
        "jpn": "jpn",
        "zh": "chi_sim",
        "zh_cn": "chi_sim",
        "zh_hans": "chi_sim",
        "chi_sim": "chi_sim",
        "zh_tw": "chi_tra",
        "zh_hk": "chi_tra",
        "zh_tr": "chi_tra",
        "chi_tra": "chi_tra",
        "hi": "hin",
        "hin": "hin",
        "bn": "ben",
        "ben": "ben",
        "ur": "urd",
        "urd": "urd",
        "fa": "fas",
        "per": "fas",
        "fas": "fas",
    }
    return mapping.get(lang, "eng")


def _paddle_lang_code(language: str) -> str:
    """Map a user language input to a PaddleOCR language code.

    PaddleOCR supports codes like: 'en', 'ch', 'chinese_cht', 'korean', 'japan', 'german', 'fr', 'it', 'es',
    as well as script packs like 'arabic', 'cyrillic', 'latin'. Unknown inputs fall back to "en".
    """
    lang = _normalize_lang(language)
    mapping = {
        "en": "en",
        "fr": "fr",
        "de": "german",
        "es": "es",
        "it": "it",
        "pt": "pt",
        "nl": "nl",
        "ru": "cyrillic",  # Russian is covered by cyrillic models
        "uk": "cyrillic",
        "bg": "cyrillic",
        "sr": "cyrillic",
        "ar": "arabic",
        "tr": "tr",
        "fa": "arabic",  # fallback to arabic script pack
        "zh": "ch",
        "zh_cn": "ch",
        "zh_tw": "chinese_cht",
        "zh_hk": "chinese_cht",
        "ja": "japan",
        "jp": "japan",
        "ko": "korean",
        "hi": "latin",  # fallback; dedicated Hindi not always available
    }
    return mapping.get(lang, "en")


# Module-scoped PaddleOCR state: the lazily created instance, the lock guarding its
# creation, the cached PaddleOCR class, and kwargs registered for lazy init.
_module_paddle_ocr = None
_module_paddle_ocr_lock = threading.Lock()
_paddleocr_class = None
_module_paddle_kwargs_template: Optional[Dict[str, Any]] = None


def _configure_paddle_ocr_environment() -> None:
    """Set PaddleOCR env vars before import (fonts, model dir)."""
    if PADDLE_MODEL_PATH and PADDLE_MODEL_PATH.strip():
        os.environ["PADDLEOCR_MODEL_DIR"] = PADDLE_MODEL_PATH
        print(f"Setting PaddleOCR model path to: {PADDLE_MODEL_PATH}")
    else:
        print("Using default PaddleOCR model storage location")

    if (
        PADDLE_FONT_PATH
        and PADDLE_FONT_PATH.strip()
        and os.path.exists(PADDLE_FONT_PATH)
    ):
        os.environ["PADDLE_PDX_LOCAL_FONT_FILE_PATH"] = PADDLE_FONT_PATH
        print(f"Setting PaddleOCR font path to configured font: {PADDLE_FONT_PATH}")
    else:
        system_font_path = get_system_font_path()
        if system_font_path:
            os.environ["PADDLE_PDX_LOCAL_FONT_FILE_PATH"] = system_font_path
            print(f"Setting PaddleOCR font path to system font: {system_font_path}")
        else:
            print(
                "Warning: No suitable system font found. PaddleOCR may download default fonts."
            )


def _default_paddle_device() -> str:
    """Resolve PaddleOCR device for the transformers inference backend.

    PADDLE_DEVICE wins if set; otherwise "gpu:0" when torch reports CUDA,
    else "cpu" (also "cpu" when torch is not installed).
    """
    if PADDLE_DEVICE:
        return PADDLE_DEVICE
    try:
        import torch

        if torch.cuda.is_available():
            return "gpu:0"
    except ImportError:
        pass
    return "cpu"


def _default_paddle_engine_config(device: Optional[str] = None) -> Dict[str, Any]:
    """engine_config for PaddleOCR transformers backend (see PaddleOCR inference-engine docs).

    ``device`` strings starting with "gpu" (optionally "gpu:<id>") produce a GPU
    config; anything else produces a CPU config. None means the default device.
    """
    device = device or _default_paddle_device()
    use_gpu = device.startswith("gpu")
    config: Dict[str, Any] = {"dtype": "float32"}
    if use_gpu:
        device_id = 0
        if ":" in device:
            try:
                device_id = int(device.split(":", 1)[1])
            except ValueError:
                device_id = 0
        config.update(
            {
                "device_type": "gpu",
                "device_id": device_id,
                "attn_implementation": resolve_paddle_attn_implementation(),
            }
        )
    else:
        config["device_type"] = "cpu"
    return config


def _default_paddle_kwargs(lang: Optional[str] = None) -> Dict[str, Any]:
    """Default PaddleOCR constructor kwargs (PP-OCRv6 medium det/rec, transformers engine)."""
    if lang is None:
        lang = _paddle_lang_code(DEFAULT_LANGUAGE)
    device = _default_paddle_device()
    return {
        "text_detection_model_name": "PP-OCRv6_medium_det",
        "text_recognition_model_name": "PP-OCRv6_medium_rec",
        "engine": "transformers",
        "device": device,
        "engine_config": _default_paddle_engine_config(device),
        "det_db_unclip_ratio": PADDLE_DET_DB_UNCLIP_RATIO,
        "use_textline_orientation": PADDLE_USE_TEXTLINE_ORIENTATION,
        "use_doc_orientation_classify": False,
        "use_doc_unwarping": False,
        "lang": lang,
    }


def _finalize_paddle_kwargs(paddle_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Merge caller overrides with GPU-aware defaults.

    ``engine`` and ``device`` are filled in only when missing. The default
    engine_config is derived from the *effective* device (the caller's, if given)
    so that e.g. ``device="cpu"`` on a CUDA host does not inherit a GPU
    engine_config. Caller-supplied engine_config keys override those defaults; a
    ``None`` engine_config is treated as empty.

    Returns a new dict; ``paddle_kwargs`` is not mutated.
    """
    kwargs = dict(paddle_kwargs)
    defaults = _default_paddle_kwargs(kwargs.get("lang"))
    kwargs.setdefault("engine", defaults["engine"])
    kwargs.setdefault("device", defaults["device"])

    default_engine_config = _default_paddle_engine_config(kwargs.get("device"))
    caller_engine_config = kwargs.get("engine_config") or {}
    kwargs["engine_config"] = {**default_engine_config, **caller_engine_config}
    # PP-OCRv6 transformers models only support eager (not sdpa / flash_attention_2).
    kwargs["engine_config"][
        "attn_implementation"
    ] = resolve_paddle_attn_implementation()
    return kwargs


def _log_paddle_runtime_diagnostics(device: str) -> None:
    """Log torch/CUDA state when PaddleOCR is initialized (helps debug ZeroGPU)."""
    try:
        import torch
    except ImportError:
        print(
            f"PaddleOCR init: device={device!r}, torch not installed "
            "(required for engine=transformers)"
        )
        return

    cuda_available = torch.cuda.is_available()
    parts = [
        f"PaddleOCR init: device={device!r}",
        "engine=transformers",
        f"attn_implementation={resolve_paddle_attn_implementation()!r}",
        f"torch={torch.__version__}",
        f"cuda_available={cuda_available}",
        f"spaces_zero_gpu={SPACES_ZERO_GPU}",
    ]
    if cuda_available:
        try:
            parts.append(
                f"cuda_device={torch.cuda.get_device_name(torch.cuda.current_device())}"
            )
        except Exception:
            pass
    print(", ".join(parts))


def _import_paddleocr_class():
    """Import and cache the PaddleOCR class (configuring env vars first).

    Raises:
        ImportError: if paddleocr cannot be imported.
    """
    global _paddleocr_class
    if _paddleocr_class is not None:
        return _paddleocr_class

    _configure_paddle_ocr_environment()
    try:
        from paddleocr import PaddleOCR as paddleocr_cls

        print("PaddleOCR imported successfully")
    except Exception as e:
        raise ImportError(
            f"Error importing PaddleOCR: {e}. Please install it using "
            "'pip install paddleocr paddlepaddle' in your python environment and retry."
        ) from e

    _paddleocr_class = paddleocr_cls
    return _paddleocr_class


def register_module_paddle_kwargs(paddle_kwargs: Dict[str, Any]) -> None:
    """Store Paddle kwargs for lazy init (ZeroGPU: init runs in @spaces.GPU worker)."""
    global _module_paddle_kwargs_template
    _module_paddle_kwargs_template = _finalize_paddle_kwargs(paddle_kwargs)


def get_or_create_module_paddle_ocr(
    paddle_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Return a module-scoped PaddleOCR instance (required for ZeroGPU pickling).

    kwargs precedence for the first creation: explicit ``paddle_kwargs`` >
    kwargs registered via ``register_module_paddle_kwargs`` > defaults. Once an
    instance exists, ``paddle_kwargs`` is ignored.

    Raises:
        ImportError: if PaddleOCR cannot be imported or constructed.
    """
    global _module_paddle_ocr
    if _module_paddle_ocr is not None:
        return _module_paddle_ocr

    with _module_paddle_ocr_lock:
        # Re-check under the lock: another thread may have created it meanwhile.
        if _module_paddle_ocr is not None:
            return _module_paddle_ocr

        PaddleOCR = _import_paddleocr_class()
        if paddle_kwargs is not None:
            kwargs = dict(paddle_kwargs)
            kwargs.setdefault("lang", _paddle_lang_code(DEFAULT_LANGUAGE))
            kwargs = _finalize_paddle_kwargs(kwargs)
        elif _module_paddle_kwargs_template is not None:
            kwargs = dict(_module_paddle_kwargs_template)
        else:
            kwargs = _default_paddle_kwargs()

        _log_paddle_runtime_diagnostics(str(kwargs.get("device", "")))
        try:
            _module_paddle_ocr = PaddleOCR(**kwargs)
        except Exception as e:
            if (
                "WinError 127" in str(e)
                or "could not be found" in str(e).lower()
                or "dll" in str(e).lower()
            ):
                print(
                    f"Warning: GPU initialization failed (likely missing CUDA/cuDNN dependencies): {e}"
                )
                print("PaddleOCR will not be available. To fix GPU issues:")
                print("1. Install Visual C++ Redistributables (latest version)")
                print("2. Ensure CUDA runtime libraries are in your PATH")
                print(
                    "3. Or reinstall paddlepaddle CPU version: pip install paddlepaddle"
                )
            raise ImportError(
                f"Error initializing PaddleOCR: {e}. Please install it using "
                "'pip install paddleocr paddlepaddle' in your python environment and retry."
            ) from e

        print("Module PaddleOCR instance initialized")
    return _module_paddle_ocr


def _paddle_result_to_plain_dict(result: Any) -> Dict[str, Any]:
    """Convert a PaddleOCR result object to a plain, pickle-safe dictionary.

    Dict results are shallow-copied. For other objects, the recognition fields
    (rec_texts/rec_scores/rec_polys/rec_models) and image_width/image_height are
    read via ``.get`` (when available) or attribute access. In both cases the
    recognition fields are converted to lists (strings are left alone).
    """
    list_keys = ("rec_texts", "rec_scores", "rec_polys", "rec_models")
    size_keys = ("image_width", "image_height")

    def _as_list(value: Any) -> Any:
        if hasattr(value, "__iter__") and not isinstance(value, str):
            return list(value)
        return value

    def _lookup(key: str, default: Any) -> Tuple[bool, Any]:
        """Return (found, value) for ``key`` on a mapping-like or attribute-bearing result."""
        try:
            present = key in result
        except TypeError:
            # Objects without __contains__/__iter__ cannot be probed with `in`.
            present = False
        if not (present or hasattr(result, key)):
            return False, None
        if hasattr(result, "get"):
            return True, result.get(key, default)
        return True, getattr(result, key, default)

    if isinstance(result, dict):
        plain_dict = dict(result)
    else:
        plain_dict = {}
        for key in list_keys:
            found, value = _lookup(key, [])
            if found and value is not None:
                plain_dict[key] = _as_list(value)
        for key in size_keys:
            found, value = _lookup(key, None)
            if found and value is not None:
                plain_dict[key] = value

    for key in list_keys:
        value = plain_dict.get(key)
        if value is not None:
            plain_dict[key] = _as_list(value)
    return plain_dict


def _paddle_results_to_plain_dicts(results: Any) -> List[Dict[str, Any]]:
    """Convert each PaddleOCR result to a plain dict; falsy input is returned as-is."""
    if not results:
        return results
    return [_paddle_result_to_plain_dict(result) for result in results]


@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
def paddle_predict(image: Union[np.ndarray, str]) -> List[Dict[str, Any]]:
    """
    Run Paddle OCR on a numpy image or file path.

    Uses a module-scoped PaddleOCR instance so ZeroGPU only pickles the image
    input (not the live transformers-backed model with forward hooks).
    """
    ocr = get_or_create_module_paddle_ocr()
    paddle_results = ocr.predict(image)
    return _paddle_results_to_plain_dicts(paddle_results)


# ZeroGPU: the main web process has no real GPU. PaddleOCR must initialize inside
# paddle_predict (@spaces.GPU worker) so transformers weights land on CUDA.
if LOAD_PADDLE_AT_STARTUP and not SPACES_ZERO_GPU:
    try:
        get_or_create_module_paddle_ocr()
    except Exception as e:
        print(f"Warning: LOAD_PADDLE_AT_STARTUP failed: {e}")


@dataclass
class OCRResult:
    """A single OCR text box: text, pixel box (left/top/width/height) and metadata."""

    text: str
    left: int
    top: int
    width: int
    height: int
    conf: Optional[float] = None
    line: Optional[int] = None
    model: Optional[str] = (
        None  # Track which OCR model was used (e.g., "Tesseract", "Paddle", "VLM")
    )


@dataclass
class CustomImageRecognizerResult:
    """A recognised entity span together with its on-image bounding box and colour."""

    entity_type: str
    start: int
    end: int
    score: float
    left: int
    top: int
    width: int
    height: int
    text: str
    color: tuple = (0, 0, 0)


class ImagePreprocessor:
    """ImagePreprocessor class. Parent class for image preprocessing objects."""

    def __init__(self, use_greyscale: bool = True) -> None:
        self.use_greyscale = use_greyscale

    def preprocess_image(self, image: Image.Image) -> Tuple[Image.Image, dict]:
        """Identity transform; subclasses override. Returns (image, metadata)."""
        return image, {}

    def convert_image_to_array(self, image: Image.Image) -> np.ndarray:
        """Return ``image`` as a numpy array (converted to 'L' first if greyscale)."""
        if isinstance(image, np.ndarray):
            img = image
        else:
            if self.use_greyscale:
                image = image.convert("L")
            img = np.asarray(image)
        return img

    @staticmethod
    def _get_bg_color(
        image: np.ndarray, is_greyscale: bool, invert: bool = False
    ) -> Union[int, Tuple[int, int, int]]:
        """Estimate the background colour as the most frequent value.

        Greyscale: the modal pixel value (``np.bincount`` requires non-negative
        integer pixels). Colour: the per-channel mode over ``image.reshape(-1, 3)``.
        Ties resolve to the smallest value in both cases.
        """
        # Note: Modified to expect numpy array for bincount
        if invert:
            image = 255 - image  # Simple inversion for greyscale numpy array

        if is_greyscale:
            bg_color = int(np.bincount(image.flatten()).argmax())
        else:
            # Per-channel mode. Previously scipy.stats.mode(...)[0][0], which breaks
            # on SciPy >= 1.11 (keepdims=False returns a 1-D mode array). np.unique
            # returns sorted values and argmax picks the first maximum, so ties
            # resolve to the smallest value exactly as scipy.stats.mode does.
            pixels = image.reshape(-1, 3)
            channel_modes = []
            for channel in range(pixels.shape[1]):
                values, counts = np.unique(pixels[:, channel], return_counts=True)
                channel_modes.append(values[counts.argmax()].item())
            bg_color = tuple(channel_modes)

        return bg_color

    @staticmethod
    def _get_image_contrast(image: np.ndarray) -> Tuple[float, float]:
        """Return (standard deviation, mean) of pixel intensities."""
        contrast = np.std(image)
        mean_intensity = np.mean(image)
        return contrast, mean_intensity


class BilateralFilter(ImagePreprocessor):
    """Applies bilateral filtering."""

    def __init__(
        self, diameter: int = 9, sigma_color: int = 75, sigma_space: int = 75
    ) -> None:
        super().__init__(use_greyscale=True)
        self.diameter = diameter
        self.sigma_color = sigma_color
        self.sigma_space = sigma_space

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, dict]:
        """Apply ``cv2.bilateralFilter``; returns (filtered image, parameters used)."""
        # Modified to accept and return numpy array for consistency in the pipeline
        filtered_image = cv2.bilateralFilter(
            image, self.diameter, self.sigma_color, self.sigma_space
        )
        metadata = {
            "diameter": self.diameter,
            "sigma_color": self.sigma_color,
            "sigma_space": self.sigma_space,
        }
        return filtered_image, metadata


class SegmentedAdaptiveThreshold(ImagePreprocessor):
    """Applies adaptive thresholding."""

    def __init__(
        self,
        block_size: int = 21,
        contrast_threshold: int = 40,
        c_low_contrast: int = 5,
        c_high_contrast: int = 10,
        bg_threshold: int = 127,
    ) -> None:
        super().__init__(use_greyscale=True)
        # cv2.adaptiveThreshold requires an odd block size.
        self.block_size = block_size if block_size % 2 == 1 else block_size + 1
        self.c_low_contrast = c_low_contrast
        self.c_high_contrast = c_high_contrast
        self.bg_threshold = bg_threshold
        self.contrast_threshold = contrast_threshold

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, dict]:
        """Binarise a greyscale image with Gaussian adaptive thresholding.

        The constant C depends on the image contrast; the threshold is inverted
        (with -C) for dark backgrounds. Returns (binary image, {"C", "background_color", "contrast"}).
        """
        # Modified to accept and return numpy array
        background_color = self._get_bg_color(image, True)
        contrast, _ = self._get_image_contrast(image)
        c = (
            self.c_low_contrast
            if contrast <= self.contrast_threshold
            else self.c_high_contrast
        )

        if background_color < self.bg_threshold:  # Dark background, light text
            adaptive_threshold_image = cv2.adaptiveThreshold(
                image,
                255,
                cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY_INV,
                self.block_size,
                -c,
            )
        else:  # Light background, dark text
            adaptive_threshold_image = cv2.adaptiveThreshold(
                image,
                255,
                cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY,
                self.block_size,
                c,
            )
        metadata = {"C": c, "background_color": background_color, "contrast": contrast}
        return adaptive_threshold_image, metadata


class ImageRescaling(ImagePreprocessor):
    """Rescales images based on their size."""

    def __init__(self, target_dpi: int = 300, assumed_input_dpi: int = 96) -> None:
        super().__init__(use_greyscale=True)
        self.target_dpi = target_dpi
        self.assumed_input_dpi = assumed_input_dpi

    def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, dict]:
        """Resize by target_dpi / assumed_input_dpi; returns (image, {"scale_factor": ...})."""
        # Modified to accept and return numpy array
        scale_factor = self.target_dpi / self.assumed_input_dpi
        metadata = {"scale_factor": 1.0}

        if scale_factor != 1.0:
            width = int(image.shape[1] * scale_factor)
            height = int(image.shape[0] * scale_factor)
            dimensions = (width, height)

            # Use better interpolation for upscaling vs downscaling
            interpolation = cv2.INTER_CUBIC if scale_factor > 1.0 else cv2.INTER_AREA
            rescaled_image = cv2.resize(image, dimensions, interpolation=interpolation)
            metadata["scale_factor"] = scale_factor
            return rescaled_image, metadata

        return image, metadata


class ContrastSegmentedImageEnhancer(ImagePreprocessor):
    """Class containing all logic to perform contrastive segmentation."""

    def __init__(
        self,
        bilateral_filter: Optional[BilateralFilter] = None,
        adaptive_threshold: Optional[SegmentedAdaptiveThreshold] = None,
        image_rescaling: Optional[ImageRescaling] = None,
        low_contrast_threshold: int = 40,
    ) -> None:
        super().__init__(use_greyscale=True)
        self.bilateral_filter = bilateral_filter or BilateralFilter()
        self.adaptive_threshold = adaptive_threshold or SegmentedAdaptiveThreshold()
        self.image_rescaling = image_rescaling or ImageRescaling()
        self.low_contrast_threshold = low_contrast_threshold

    def _improve_contrast(self, image: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Apply CLAHE when contrast (pixel std) is at or below the threshold.

        Returns (possibly adjusted image, original contrast, adjusted contrast).
        """
        contrast, mean_intensity = self._get_image_contrast(image)
        if contrast <= self.low_contrast_threshold:
            # Using CLAHE as a generally more robust alternative
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            adjusted_image = clahe.apply(image)
            adjusted_contrast, _ = self._get_image_contrast(adjusted_image)
        else:
            adjusted_image = image
            adjusted_contrast = contrast
        return adjusted_image, contrast, adjusted_contrast

    def _deskew(self, image_np: np.ndarray) -> np.ndarray:
        """
        Corrects the skew of an image.
        This method works best on a grayscaled image.

        Returns ``image_np`` unchanged when there is no ink to measure (blank page)
        or the detected skew is below 0.1 degrees.
        """
        # We'll work with a copy for angle detection
        gray = (
            cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY)
            if len(image_np.shape) == 3
            else image_np.copy()
        )

        # Invert the image for contour finding
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

        coords = np.column_stack(np.where(thresh > 0))
        if coords.shape[0] == 0:
            # Blank page: nothing to measure, and minAreaRect rejects an empty point set.
            return image_np
        angle = cv2.minAreaRect(coords)[-1]

        # Depending on the OpenCV version, minAreaRect reports angles in [-90, 0)
        # (older releases) or (0, 90] (newer releases). A rectangle's angle is only
        # defined modulo 90, so fold it into [-45, 45) before negating. For the old
        # range this is identical to "-(90 + angle) if angle < -45 else -angle";
        # for the new range it stops an upright page being rotated by 90 degrees.
        if angle < -45:
            angle += 90
        elif angle >= 45:
            angle -= 90
        angle = -angle

        # Don't rotate if the angle is negligible
        if abs(angle) < 0.1:
            return image_np

        h, w = image_np.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, angle, 1.0)

        # Use the original numpy image for the rotation to preserve quality
        rotated = cv2.warpAffine(
            image_np, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE
        )

        return rotated

    def preprocess_image(
        self,
        image: Image.Image,
        perform_deskew: bool = False,
        perform_binarization: bool = False,
    ) -> Tuple[Image.Image, dict]:
        """
        A pipeline for OCR preprocessing.
        Order: Deskew -> Greyscale -> Rescale -> Denoise -> Enhance Contrast -> Binarize

        Returns a greyscale ('L') PIL image and merged rescale/threshold metadata.
        """
        # 1. Convert PIL image to NumPy array for OpenCV processing
        # Assuming the original image is RGB
        image_np = np.array(image.convert("RGB"))
        # OpenCV uses BGR, so we convert RGB to BGR
        image_np_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        # 2. Deskew the image. This is best done early on the full-quality image.
        if perform_deskew:
            deskewed_image_np = self._deskew(image_np_bgr)
        else:
            deskewed_image_np = image_np_bgr

        # 3. Convert to greyscale
        gray_image_np = cv2.cvtColor(deskewed_image_np, cv2.COLOR_BGR2GRAY)

        # 4. Rescale image to optimal DPI
        rescaled_image_np, scale_metadata = self.image_rescaling.preprocess_image(
            gray_image_np
        )

        # 5. Apply filtering for noise reduction
        filtered_image_np, _ = self.bilateral_filter.preprocess_image(rescaled_image_np)

        # 6. Improve contrast
        adjusted_image_np, _, _ = self._improve_contrast(filtered_image_np)

        # 7. Adaptive Thresholding (Binarization) - Final optional step
        if perform_binarization:
            final_image_np, threshold_metadata = (
                self.adaptive_threshold.preprocess_image(adjusted_image_np)
            )
        else:
            final_image_np = adjusted_image_np
            threshold_metadata = {}

        # Combine metadata
        final_metadata = {**scale_metadata, **threshold_metadata}

        # The final image is greyscale, so it's safe to use 'L' mode
        return Image.fromarray(final_image_np).convert("L"), final_metadata


def rescale_ocr_data(ocr_data, scale_factor: float):
    """Map OCR box coordinates back to the pre-rescale image, in place.

    ``ocr_data`` must hold parallel lists under "text", "conf", "left", "top",
    "width" and "height" (presumably pytesseract ``Output.DICT`` output - confirm
    against callers). Entries with conf -1 (structural elements) are skipped; all
    others have each coordinate divided by ``scale_factor`` and truncated to int.

    Returns the same, mutated ``ocr_data`` object.
    """
    for i in range(len(ocr_data["text"])):
        # conf may be an int, a float or a numeric string such as "96.5";
        # int("96.5") raises, so parse via float. -1 marks structural elements.
        if float(ocr_data["conf"][i]) > -1:
            for key in ("left", "top", "width", "height"):
                ocr_data[key][i] = int(ocr_data[key][i] / scale_factor)
    return ocr_data


def filter_entities_for_language(
    entities: List[str], valid_language_entities: List[str], language: str
) -> List[str]:
    """Return the entities from ``entities`` that are supported for ``language``.

    Order of ``entities`` is preserved. ``None`` for either list is treated as
    empty. Shows a Gradio info toast for non-English languages.
    """
    if not valid_language_entities:
        print(f"No valid entities supported for language: {language}")
    if not entities:
        print(f"No entities provided for language: {language}")

    valid_entities = set(valid_language_entities or ())
    filtered_entities = [entity for entity in (entities or []) if entity in valid_entities]

    if not filtered_entities:
        print(f"No relevant entities supported for language: {language}")

    if language != "en":
        gr.Info(
            f"Using {str(filtered_entities)} entities for local model analysis for language: {language}"
        )

    return filtered_entities


def _get_tesseract_psm(segmentation_level: str) -> int:
    """
    Get the appropriate Tesseract PSM (Page Segmentation Mode) value based on segmentation level.

    Args:
        segmentation_level: "word" or "line" (case-insensitive)

    Returns:
        PSM value for Tesseract configuration (6 for line, 11 for word or unknown)
    """
    level = segmentation_level.lower()
    if level == "line":
        return 6  # Uniform block of text
    if level == "word":
        return 11  # Sparse text (word-level)
    print(
        f"Warning: Unknown segmentation level '{segmentation_level}', defaulting to word-level (PSM 11)"
    )
    return 11


def _extract_page_number_for_vlm_log(image_name: Optional[str]) -> Optional[int]:
    """
    Best-effort 0-based page index from an image basename for VLM log filenames.

    Tries end-anchored patterns first, then the last ``_page_`` occurrence anywhere
    in the name (e.g. ``doc.pdf_0_page_0000.png``).
    """
    if not image_name:
        return None
    end_patterns = (
        r"_page_(\d+)\.(?:png|jpg|jpeg|webp|tif|tiff)$",
        r"_page_(\d+)\.png$",
        r"_(\d+)\.png$",
        r"page_(\d+)\.png$",
    )
    for pattern in end_patterns:
        match = re.search(pattern, image_name, re.IGNORECASE)
        if match:
            return int(match.group(1))
    last_page: Optional[int] = None
    for match in re.finditer(r"(?i)_page_(\d+)", image_name):
        last_page = int(match.group(1))
    return last_page


def _filename_safe_token(token: str) -> str:
    """Replace every character other than word characters, '.' and '-' with '_'.

    Stops caller-supplied labels such as "Azure/OpenAI" from injecting path
    separators (and therefore non-existent sub-directories) into log filenames.
    """
    return re.sub(r"[^\w.-]", "_", token)


def save_vlm_prompt_response(
    prompt: str,
    response_text: str,
    output_folder: str,
    model_choice: str,
    image_name: Optional[str] = None,
    page_number: Optional[int] = None,
    temperature: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    model_type: str = "VLM",
    task_suffix: Optional[str] = None,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    image_width: Optional[int] = None,
    image_height: Optional[int] = None,
) -> str:
    """
    Save VLM prompt and response to a text file for traceability.

    Args:
        prompt: Prompt sent to VLM
        response_text: Response text from VLM
        output_folder: Output folder path
        model_choice: Model used
        image_name: Optional image name (without extension) for the filename
        page_number: Optional 0-based page index for the filename (overrides parsing
            ``image_name``). Displayed in the log body as 1-based when set or parsed.
        temperature: Temperature used (if applicable)
        max_new_tokens: Max tokens used (if applicable)
        top_p: Top-p parameter used (if applicable)
        model_type: Type of model (e.g., "VLM", "Bedrock", "Inference Server", "Gemini", "Azure/OpenAI").
            Lower-cased and made filename-safe (spaces and path separators become '_')
            for the filename; written verbatim in the log body.
        task_suffix: Optional suffix to add to filename (e.g., "_person", "_sig") to distinguish task types
        input_tokens: Input token count (API usage where available; local/estimated for Transformers)
        output_tokens: Output token count (same)
        image_width: Pixel width of the image sent to the VLM (after any resize/pad in the pipeline)
        image_height: Pixel height of the image sent to the VLM

    Returns:
        Path to the saved file
    """
    # Create VLM logs subfolder
    vlm_logs_folder = os.path.join(output_folder, "vlm_prompts_responses")
    os.makedirs(vlm_logs_folder, exist_ok=True)

    # Add task suffix to filename if provided
    suffix_str = f"_{_filename_safe_token(task_suffix)}" if task_suffix else ""
    model_type_slug = _filename_safe_token(model_type.lower().replace(" ", "_"))

    effective_page: Optional[int] = page_number
    if effective_page is None and image_name:
        effective_page = _extract_page_number_for_vlm_log(image_name)

    # Filenames always include a page segment: _page_NNNN or _page_unknown
    if isinstance(effective_page, int):
        page_part = f"_page_{effective_page:04d}"
    else:
        page_part = "_page_unknown"

    if image_name:
        safe_image_name = "".join(
            c for c in image_name if c.isalnum() or c in (" ", "-", "_", ".")
        ).strip()
        safe_image_name = safe_image_name.replace(" ", "_")
        safe_image_name = safe_image_name.rsplit(".", 1)[0]
        filename = f"vlm_{safe_image_name}{page_part}{suffix_str}_{model_type_slug}.txt"
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"vlm_{model_type_slug}{page_part}{suffix_str}_{timestamp}.txt"

    filepath = os.path.join(vlm_logs_folder, filename)

    # Write prompt and response to file
    with open(filepath, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("VLM OCR - PROMPT AND RESPONSE LOG\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        if image_name:
            f.write(f"Image: {image_name}\n")
        if isinstance(effective_page, int):
            f.write(f"Page: {effective_page + 1}\n")
        if image_width is not None and image_height is not None:
            f.write(
                f"Image input size (pixels): {image_width} x {image_height} (width x height)\n"
            )
        elif image_width is not None or image_height is not None:
            f.write(
                f"Image input size (pixels): width={image_width}, height={image_height}\n"
            )
        f.write(
            f"Input tokens: {input_tokens if input_tokens is not None else '(not reported)'}\n"
        )
        f.write(
            f"Output tokens: {output_tokens if output_tokens is not None else '(not reported)'}\n"
        )
        f.write(f"Model: {model_choice}\n")
        f.write(f"Model Type: {model_type}\n")
        if temperature is not None:
            f.write(f"Temperature: {temperature}\n")
        if max_new_tokens is not None:
            f.write(f"Max New Tokens: {max_new_tokens}\n")
        if top_p is not None:
            f.write(f"Top-p: {top_p}\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("PROMPT\n")
        f.write("=" * 80 + "\n\n")
        f.write(prompt)
        f.write("\n\n" + "=" * 80 + "\n")
        f.write("VLM RESPONSE\n")
        f.write("=" * 80 + "\n\n")
        f.write(response_text)
        f.write("\n\n" + "=" * 80 + "\n")
        f.write("END OF LOG\n")
        f.write("=" * 80 + "\n")

    return filepath


def _exif_resolution_to_float(value: Any) -> Optional[float]:
    """Convert EXIF/TIFF resolution (RATIONAL) to float.

    Accepts objects with numerator/denominator (e.g. PIL IFDRational), a
    (numerator, denominator) 2-tuple, or anything ``float()`` accepts. Returns
    None for None, a zero denominator, or unparseable input.
    """
    if value is None:
        return None
    try:
        if hasattr(value, "numerator") and hasattr(value, "denominator"):
            d = float(value.denominator)
            if d == 0:
                return None
            return float(value.numerator) / d
        if isinstance(value, tuple) and len(value) == 2:
            d = float(value[1])
            if d == 0:
                return None
            return float(value[0]) / d
        return float(value)
    except (TypeError, ValueError):
        return None


def _best_effort_pil_dpi(image: Image.Image, fallback: float = 72.0) -> float:
    """
    Best-effort DPI from PIL metadata before defaulting to ``fallback``.

    Tries, in order: ``info['dpi']``, JPEG JFIF density/unit, EXIF X/Y resolution,
    TIFF ``tag_v2`` resolution tags.
    """

    def _positive_max_pair(a: Any, b: Any) -> Optional[float]:
        try:
            x = float(a) if a is not None else 0.0
            y = float(b) if b is not None else 0.0
            m = max(x, y)
            return m if m > 0 else None
        except (TypeError, ValueError):
            return None

    dpi_raw = image.info.get("dpi")
    if dpi_raw is not None:
        if isinstance(dpi_raw, tuple) and len(dpi_raw) >= 2:
            v = _positive_max_pair(dpi_raw[0], dpi_raw[1])
            if v is not None:
                return v
        else:
            try:
                f = float(dpi_raw)
                if f > 0:
                    return f
            except (TypeError, ValueError):
                pass

    jfif_unit = image.info.get("jfif_unit")
    jden = image.info.get("jfif_density")
    if jden is not None and isinstance(jden, (tuple, list)) and len(jden) >= 2:
        v = _positive_max_pair(jden[0], jden[1])
        if v is not None:
            if jfif_unit == 1:
                return v
            if jfif_unit == 2:
                return v * 2.54

    try:
        exif = image.getexif()
        if exif:
            xres = _exif_resolution_to_float(exif.get(282))
            yres = _exif_resolution_to_float(exif.get(283))
            if xres is not None and yres is not None and xres > 0 and yres > 0:
                d = max(xres, yres)
                unit = exif.get(296, 2)
                if unit == 2 or unit is None:
                    return d
                if unit == 3:
                    return d * 2.54
    except Exception:
        pass

    try:
        tv = getattr(image, "tag_v2", None)
        if tv is not None:
            xres = _exif_resolution_to_float(tv.get(282))
            yres = _exif_resolution_to_float(tv.get(283))
            if xres is not None and yres is not None and xres > 0 and yres > 0:
                d = max(xres, yres)
                unit = tv.get(296)
                if unit == 2 or unit is None:
                    return d
                if unit == 3:
                    return d * 2.54
    except Exception:
        pass

    # No DPI metadata found (or parse failed) across known PIL/JFIF/EXIF/TIFF fields.
    # Make this explicit in logs so "reported DPI" isn't mistaken for a detected value.
try: w, h = image.size except Exception: w, h = None, None try: img_format = getattr(image, "format", None) except Exception: img_format = None print( "VLM image preparation: DPI metadata not found; " f"using fallback {fallback:.1f} DPI" + ( f" ({w}x{h}{', ' + str(img_format) if img_format else ''})" if w and h else "" ) ) return fallback def _save_image_with_config_dpi(image: Image.Image, path: str, **kwargs: Any) -> None: """ Write a PIL image to disk with DPI metadata from ``IMAGES_DPI`` (PNG pHYs, JPEG JFIF). Extra kwargs are forwarded to ``Image.save`` (e.g. format=, optimize=). """ _d = max(1, int(round(float(IMAGES_DPI)))) kw = dict(kwargs) kw.pop("dpi", None) image.save(path, dpi=(_d, _d), **kw) def _prepare_image_for_vlm( image: Image.Image, ocr_method: Optional[str] = None, max_image_size: Optional[int] = VLM_MAX_IMAGE_SIZE, hybrid_vlm: bool = False, ) -> Image.Image: """ Prepare image for VLM: enforce pixel count and reported DPI bounds. Scaling by factor ``s`` updates effective DPI as ``reported_dpi * s`` (same physical document size). Chooses ``s`` so that: - ``VLM_MIN_IMAGE_SIZE`` (full page) or ``VLM_HYBRID_MIN_IMAGE_SIZE`` (hybrid crops) <= width*height*s^2 <= ``max_image_size`` - ``VLM_MIN_DPI`` <= reported_dpi * s <= ``VLM_MAX_DPI`` If constraints conflict, caps (max pixels / max DPI) take precedence and a warning is printed. Args: image: PIL Image to prepare ocr_method: If it contains ``bedrock`` (case-insensitive), max pixel budget is raised to 33554432 for Bedrock VLM. max_image_size: Upper bound on total pixels (default ``VLM_MAX_IMAGE_SIZE``). hybrid_vlm: If True, use ``VLM_HYBRID_MIN_IMAGE_SIZE`` as minimum pixels; otherwise ``VLM_MIN_IMAGE_SIZE`` (whole-page VLM). Returns: Resized RGB-safe image when needed; DPI metadata updated after resize. 
""" if image is None: return image if ocr_method and "bedrock" in ocr_method.lower(): max_image_size = min( 33554432, VLM_MAX_IMAGE_SIZE ) # Bedrock has a specific max pixel budget - it will fail if exceeded if max_image_size is None or max_image_size <= 0: max_image_size = VLM_MAX_IMAGE_SIZE min_image_size = VLM_HYBRID_MIN_IMAGE_SIZE if hybrid_vlm else VLM_MIN_IMAGE_SIZE min_image_size = max(0, int(min_image_size)) dpi_lo = min(VLM_MIN_DPI, VLM_MAX_DPI) dpi_hi = max(VLM_MIN_DPI, VLM_MAX_DPI) width, height = image.size area = float(width * height) if area <= 0: return image current_dpi = _best_effort_pil_dpi(image, fallback=float(IMAGES_DPI)) if current_dpi <= 0: current_dpi = float(IMAGES_DPI) # Effective DPI after uniform scale s is current_dpi * s. s_min_dpi = dpi_lo / current_dpi s_max_dpi = dpi_hi / current_dpi s_min_px = math.sqrt(min_image_size / area) if min_image_size > 0 else 0.0 s_max_px = math.sqrt(max_image_size / area) if max_image_size > 0 else float("inf") s_lo = max(s_min_dpi, s_min_px) s_hi = min(s_max_dpi, s_max_px) if s_lo > s_hi: print( f"VLM image preparation warning: constraints conflict " f"(DPI {dpi_lo:.1f}-{dpi_hi:.1f}, pixels {min_image_size}-{max_image_size}, " f"reported DPI {current_dpi:.1f}, {width}x{height}). " f"Using scale {s_hi:.4f} (capping size/DPI)." 
) s = s_hi else: s = 1.0 if s < s_lo: s = s_lo elif s > s_hi: s = s_hi if abs(s - 1.0) < 1e-6: return image new_w = max(1, int(round(width * s))) new_h = max(1, int(round(height * s))) new_pixels = new_w * new_h achieved_dpi = current_dpi * s metadata_dpi = min(dpi_hi, max(dpi_lo, achieved_dpi)) if abs(s - 1.0) > 0.02: print( f"VLM image preparation: {width}x{height} ({int(area):,} px, ~{current_dpi:.1f} DPI) " f"-> {new_w}x{new_h} (achieved ~{achieved_dpi:.1f} DPI, {new_pixels:,} px, config requirement ~{metadata_dpi:.1f} DPI), scale {s:.4f}" ) resample = Image.Resampling.LANCZOS if s < 1.0 else Image.Resampling.BICUBIC image = image.resize((new_w, new_h), resample) try: image.info["dpi"] = (metadata_dpi, metadata_dpi) except Exception: pass return image def _pad_image_for_vlm_aspect_ratio( image: Image.Image, max_aspect: Optional[float] = None, ) -> Image.Image: """ Pad image so aspect ratio max(w/h, h/w) <= max_aspect (default: ``VLM_MAX_ASPECT_RATIO``). Used for Bedrock, inference-server, Gemini, Azure/OpenAI, and local transformers VLM to avoid API or model issues on very long/thin hybrid crops. Returns RGB image. """ if max_aspect is None: max_aspect = VLM_MAX_ASPECT_RATIO if image is None: return image try: w, h = image.size if w < 1 or h < 1: return image.convert("RGB") if image.mode != "RGB" else image current = max(w / float(h), h / float(w)) if current <= max_aspect: return image.convert("RGB") if image.mode != "RGB" else image img = image.convert("RGB") if image.mode != "RGB" else image if w >= h: new_h = max(int(math.ceil(w / max_aspect)), 1) if new_h != h: print( f"VLM aspect ratio padding: width >= height. Original size: {w}x{h}. New size: {w}x{new_h}. Applied extra padding to height to achieve aspect ratio <= {max_aspect}." ) new_w, new_h = w, new_h else: new_w = max(int(math.ceil(h / max_aspect)), 1) if new_w != w: print( f"VLM aspect ratio padding: height > width. Original size: {w}x{h}. New size: {new_w}x{h}. 
Applied extra padding to width to achieve aspect ratio <= {max_aspect}." ) new_w, new_h = new_w, h canvas = Image.new("RGB", (new_w, new_h), color=(255, 255, 255)) scale = min(new_w / float(w), new_h / float(h)) if scale < 1.0: rw = max(1, int(round(w * scale))) rh = max(1, int(round(h * scale))) paste_img = img.resize((rw, rh), Image.Resampling.LANCZOS) else: paste_img = img rw, rh = w, h ox = max((new_w - rw) // 2, 0) oy = max((new_h - rh) // 2, 0) canvas.paste(paste_img, (ox, oy)) return canvas except Exception: return image.convert("RGB") if image.mode != "RGB" else image def _prepare_hybrid_line_crop_for_vlm(image: Image.Image) -> Image.Image: """ Resize/DPI-budget and aspect-pad a line crop for hybrid local VLM or hybrid inference-server. Matches the image pipeline in ``_vlm_ocr_predict`` immediately before ``extract_text_from_image_vlm``. Caller should supply an RGB image (e.g. after line crop). """ image = _prepare_image_for_vlm(image, hybrid_vlm=True) try: image = _pad_image_for_vlm_aspect_ratio(image) except Exception: pass return image def _call_inference_server_vlm_api( image: Image.Image, prompt: str, api_url: str = None, model_name: str = None, max_new_tokens: int = None, temperature: float = None, top_p: float = None, top_k: int = None, repetition_penalty: float = None, timeout: int = None, stream: bool = VLM_DEFAULT_STREAM, seed: int = None, do_sample: bool = None, min_p: float = None, presence_penalty: float = None, use_llama_swap: bool = USE_LLAMA_SWAP, disable_thinking: bool = INFERENCE_SERVER_DISABLE_THINKING, apply_aspect_ratio_padding: bool = True, ) -> Tuple[str, int, int, int, int]: """ Calls a inference-server API endpoint with an image and text prompt. This function converts a PIL Image to base64 and sends it to the inference-server API endpoint using the OpenAI-compatible chat completions format. 
Args: image: PIL Image to process prompt: Text prompt for the VLM api_url: Base URL of the inference-server API (defaults to INFERENCE_SERVER_API_URL from config) model_name: Optional model name to use (defaults to INFERENCE_SERVER_MODEL_NAME from config) max_new_tokens: Maximum number of tokens to generate temperature: Sampling temperature top_p: Nucleus sampling parameter top_k: Top-k sampling parameter repetition_penalty: Penalty for token repetition timeout: Request timeout in seconds (defaults to INFERENCE_SERVER_TIMEOUT from config) stream: Whether to stream the response seed: Random seed for generation do_sample: If True, use sampling (do_sample=True). If False, use greedy decoding (do_sample=False). min_p: Minimum probability threshold for token sampling. presence_penalty: Penalty for token presence. use_llama_swap: Whether to use llama-swap for the model (defaults to USE_LLAMA_SWAP from config). If True and model_name is provided, the model name will be included in the payload. disable_thinking: When True, adds chat_template_kwargs={"enable_thinking": False} to the request payload. This is the vLLM-native equivalent of appending in the local transformers path (VLM_DISABLE_QWEN3_5_THINKING). Defaults to INFERENCE_SERVER_DISABLE_THINKING from config. apply_aspect_ratio_padding: When False, send the image unchanged (caller already ran ``_pad_image_for_vlm_aspect_ratio``). Full-page callers keep the default True. Returns: Tuple of response text, input tokens, output tokens, and image width/height in pixels (as encoded for the API, including padding when applied). On success, also prints prompt/completion token counts and output tok/s to stdout. 
Raises: ConnectionError: If the API request fails ValueError: If the response format is invalid """ if api_url is None: api_url = INFERENCE_SERVER_API_URL if model_name is None: model_name = ( INFERENCE_SERVER_MODEL_NAME if INFERENCE_SERVER_MODEL_NAME else None ) if timeout is None: timeout = INFERENCE_SERVER_TIMEOUT # Pad image so aspect ratio <= VLM_MAX_ASPECT_RATIO; hybrid crops can be very long/thin if apply_aspect_ratio_padding: try: image = _pad_image_for_vlm_aspect_ratio(image) except Exception as e: print( f"Warning: could not pad image for inference-server VLM aspect ratio: {e}" ) # Convert PIL Image to base64 buffer = io.BytesIO() image.save(buffer, format="PNG") image_bytes = buffer.getvalue() image_base64 = base64.b64encode(image_bytes).decode("utf-8") # Prepare the request payload in OpenAI-compatible format messages = [ { "role": "user", "content": [ { "type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}, }, {"type": "text", "text": prompt}, ], } ] payload = { "messages": messages, "stream": stream, } # Add model name if specified and use llama-swap if model_name and use_llama_swap: payload["model"] = model_name # Disable thinking for Qwen3/Qwen3.5 models served by vLLM. vLLM applies the chat template # server-side and honours enable_thinking=False via chat_template_kwargs, which is the exact # server-side equivalent of appending in the local transformers path. 
if disable_thinking: payload["chat_template_kwargs"] = {"enable_thinking": False} if do_sample is not None: payload["do_sample"] = do_sample if temperature is not None: payload["temperature"] = temperature if top_p is not None: payload["top_p"] = top_p if min_p is not None: payload["min_p"] = min_p if top_k is not None: payload["top_k"] = top_k if repetition_penalty is not None: payload["repeat_penalty"] = repetition_penalty if presence_penalty is not None: payload["presence_penalty"] = presence_penalty if max_new_tokens is not None: payload["max_tokens"] = max_new_tokens if seed is not None: payload["seed"] = seed # Handle deterministic (greedy) generation if do_sample is False: # Greedy decoding (deterministic): always pick the highest probability token # This emulates transformers' do_sample=False behavior payload["temperature"] = 0 # Temperature=0 makes it deterministic payload["top_k"] = 1 # Only consider top 1 token (greedy) payload["top_p"] = 1.0 # Consider all tokens (but top_k=1 overrides this) payload["min_p"] = 0.0 # Minimum probability threshold for token sampling. payload["presence_penalty"] = 1.0 # Penalty for token presence. payload["repeat_penalty"] = 1.0 # No penalty for deterministic endpoint = f"{api_url}/v1/chat/completions" # Retry logic: try up to 5 times for connection errors max_retries = 5 retry_delay = 2 # seconds between retries for attempt in range(max_retries): try: if stream: # Handle streaming response response = requests.post( endpoint, json=payload, headers={"Content-Type": "application/json"}, stream=True, timeout=timeout, ) response.raise_for_status() stream_start = time.perf_counter() # Some OpenAI-compatible servers stream *cumulative* delta.content (full text # generated so far on each chunk). Printing every chunk reprints completed lines. # Others stream incremental tokens only. Support both. 
accumulated_response = "" # Qwen3 / vLLM --reasoning-parser: assistant text may stream only in # delta.reasoning_content (or reasoning) while content stays empty. Non-streaming # responses are handled by _extract_choice_message_text; mirror that here. accumulated_reasoning = "" # Track only the current in-progress line (since the last newline) so GUI line # reporting doesn't repeatedly emit the entire response. line_buffer = "" accumulated_response_line = "" output_tokens = 0 final_chunk = None for line in response.iter_lines(): if not line: # Skip empty lines continue line = line.decode("utf-8") if line.startswith("data: "): data = line[6:] # Remove 'data: ' prefix if data.strip() == "[DONE]": break try: chunk = json.loads(data) # Store the last chunk in case it contains usage info final_chunk = chunk if "choices" in chunk and len(chunk["choices"]) > 0: delta = chunk["choices"][0].get("delta", {}) token = delta.get("content") or "" reason_piece = delta.get("reasoning_content") if not isinstance(reason_piece, str): reason_piece = "" if not reason_piece: r_alt = delta.get("reasoning") reason_piece = ( r_alt if isinstance(r_alt, str) else "" ) if reason_piece: if reason_piece == accumulated_reasoning: pass elif ( accumulated_reasoning and reason_piece.startswith( accumulated_reasoning ) ): accumulated_reasoning = reason_piece else: accumulated_reasoning += reason_piece output_tokens += 1 if not token: continue if token == accumulated_response: continue if accumulated_response and token.startswith( accumulated_response ): new_part = token[len(accumulated_response) :] if new_part: print(new_part, end="", flush=True) accumulated_response = token else: print(token, end="", flush=True) accumulated_response += token output_tokens += 1 # Maintain line-only buffer for GUI reporting. try: # Prefer the delta we just received; supports both cumulative and incremental streams. 
if accumulated_response and token.startswith( accumulated_response ): # This is a cumulative chunk; line_buffer should only get the new part. line_buffer += ( new_part if "new_part" in locals() else "" ) else: # This is an incremental token (or reset); append the token to current line. line_buffer += token except Exception: # If anything is inconsistent, fall back to using the full accumulated response. line_buffer = accumulated_response if "\n" in line_buffer: parts = line_buffer.split("\n") complete_lines = parts[:-1] line_buffer = parts[-1] if parts else "" accumulated_response_line = line_buffer if REPORT_VLM_OUTPUTS_TO_GUI: for _ln in complete_lines: if _ln.strip(): try: gr.Info(_ln, duration=2) except Exception: pass else: accumulated_response_line = line_buffer except json.JSONDecodeError: continue print() # newline after stream finishes _content_stripped = accumulated_response.strip() text = ( accumulated_response if _content_stripped else accumulated_reasoning.strip() ) stream_elapsed_s = time.perf_counter() - stream_start # Try to extract token usage from final chunk if available input_tokens = 0 if final_chunk and "usage" in final_chunk: usage = final_chunk["usage"] input_tokens = usage.get("prompt_tokens", 0) # Use the actual output tokens from usage if available, otherwise use our count output_tokens_from_usage = usage.get("completion_tokens", 0) if output_tokens_from_usage > 0: output_tokens = output_tokens_from_usage else: # Estimate input tokens based on prompt length and image # Rough approximation: prompt tokens + image tokens (estimate based on image size) prompt_word_count = len(prompt.split()) # Estimate image tokens: roughly 1 token per 100 pixels (very rough approximation) image_tokens_estimate = max( 100, (image.size[0] * image.size[1]) // 100 ) input_tokens = prompt_word_count + image_tokens_estimate if stream_elapsed_s > 0 and output_tokens > 0: gen_tok_s = output_tokens / stream_elapsed_s print( f"Inference-server VLM: 
prompt_tokens={input_tokens}, " f"completion_tokens={output_tokens}, speed={gen_tok_s:.2f} tok/s", flush=True, ) else: print( f"Inference-server VLM: prompt_tokens={input_tokens}, " f"completion_tokens={output_tokens}, speed=n/a", flush=True, ) iw, ih = image.size return text, input_tokens, output_tokens, iw, ih else: # Handle non-streaming response req_start = time.perf_counter() response = requests.post( endpoint, json=payload, headers={"Content-Type": "application/json"}, timeout=timeout, ) response.raise_for_status() result = response.json() # Ensure the response has the expected format if "choices" not in result or len(result["choices"]) == 0: raise ValueError( "Invalid response format from inference-server: no choices found" ) choice = result["choices"][0] content = _extract_choice_message_text(choice) if not (content and str(content).strip()): raise ValueError( "Invalid response format from inference-server: no content in message" ) # Extract token usage from response input_tokens = 0 output_tokens = 0 if "usage" in result: usage = result["usage"] input_tokens = usage.get("prompt_tokens", 0) output_tokens = usage.get("completion_tokens", 0) req_elapsed_s = time.perf_counter() - req_start if req_elapsed_s > 0 and output_tokens > 0: gen_tok_s = output_tokens / req_elapsed_s print( f"Inference-server VLM: prompt_tokens={input_tokens}, " f"completion_tokens={output_tokens}, speed={gen_tok_s:.2f} tok/s", flush=True, ) else: print( f"Inference-server VLM: prompt_tokens={input_tokens}, " f"completion_tokens={output_tokens}, speed=n/a", flush=True, ) iw, ih = image.size return content, input_tokens, output_tokens, iw, ih except ( requests.exceptions.RequestException, requests.exceptions.HTTPError, ) as e: # Retry on connection errors or HTTP errors (like 500 Server Error) if attempt < max_retries - 1: print( f"Inference-server VLM API call failed (attempt {attempt + 1}/{max_retries}): {str(e)}" ) print(f"Retrying in {retry_delay} seconds...") time.sleep(retry_delay) 
retry_delay *= 2 # Exponential backoff continue else: # Final attempt failed, raise the error raise ConnectionError( f"Failed to connect to inference-server at {api_url} after {max_retries} attempts: {str(e)}" ) except json.JSONDecodeError as e: # Don't retry on JSON decode errors - these are likely permanent issues raise ValueError(f"Invalid JSON response from inference-server: {str(e)}") except Exception as e: # Don't retry on other exceptions - these are likely permanent issues raise RuntimeError(f"Error calling inference-server API: {str(e)}") def _call_bedrock_vlm_api( image: Image.Image, prompt: str, model_choice: str = None, bedrock_runtime=None, max_new_tokens: int = None, temperature: float = None, top_p: float = None, timeout: int = 60, max_retries: int = 5, retry_delay_seconds: float = 2.0, ) -> Tuple[str, int, int, int, int]: """ Calls AWS Bedrock API with an image and text prompt for vision models. Args: image: PIL Image to process prompt: Text prompt for the VLM model_choice: Bedrock model ID (e.g., "anthropic.claude-3-5-sonnet-20241022-v2:0") bedrock_runtime: boto3 Bedrock runtime client max_new_tokens: Maximum number of tokens to generate temperature: Sampling temperature top_p: Nucleus sampling parameter timeout: Request timeout in seconds max_retries: Maximum number of retry attempts on failure (default 5) retry_delay_seconds: Delay in seconds between retries (default 2.0) Returns: Response text, input/output tokens, and image width/height in pixels (after aspect-ratio padding sent to the API). Raises: ConnectionError: If the API request fails after all retries ValueError: If the response format is invalid """ if bedrock_runtime is None: raise ValueError("bedrock_runtime client is required for Bedrock VLM calls") if model_choice is None: raise ValueError("model_choice is required for Bedrock VLM calls") # Bedrock Converse API requires image aspect ratio <= 20:1. Pad to VLM_MAX_ASPECT_RATIO first. 
try: image = _pad_image_for_vlm_aspect_ratio(image) except Exception as aspect_error: print( f"Warning: could not adjust image aspect ratio for Bedrock VLM: {aspect_error}" ) # Final safeguard: never send aspect > 10:1 (AWS Converse limit) try: w, h = image.size if w > 0 and h > 0: aspect = max(w / float(h), h / float(w)) if aspect > 10.0: image = _pad_image_for_vlm_aspect_ratio(image, max_aspect=10.0) print( f"Bedrock VLM: re-padded image to satisfy aspect ratio (was {aspect:.1f}:1)." ) except Exception: pass # Encode the (possibly padded) image and send to Bedrock buffer = io.BytesIO() image.save(buffer, format="PNG") image_bytes = buffer.getvalue() base64.b64encode(image_bytes).decode("utf-8") # Prepare messages for Bedrock converse API # Bedrock supports images in the content array messages = [ { "role": "user", "content": [ {"image": {"format": "png", "source": {"bytes": image_bytes}}}, {"text": prompt}, ], } ] # Build inference config inference_config = { "maxTokens": max_new_tokens if max_new_tokens is not None else 4096, } if temperature is not None: inference_config["temperature"] = temperature if top_p is not None: inference_config["topP"] = top_p last_error = None for attempt in range(1, max_retries + 1): try: # Call Bedrock converse API api_response = bedrock_runtime.converse( modelId=model_choice, messages=messages, inferenceConfig=inference_config, ) # Extract response text output_message = api_response["output"]["message"] if "content" in output_message and len(output_message["content"]) > 0: # Handle reasoning content if present if "reasoningContent" in output_message["content"][0]: # Extract the output text (skip reasoning) if len(output_message["content"]) > 1: text = output_message["content"][1]["text"] else: text = "" else: text = output_message["content"][0]["text"] else: raise ValueError("No content in Bedrock response") # Extract token usage from API response input_tokens = 0 output_tokens = 0 if "usage" in api_response: usage = 
api_response["usage"] input_tokens = usage.get("inputTokens", 0) output_tokens = usage.get("outputTokens", 0) iw, ih = image.size return text, input_tokens, output_tokens, iw, ih except Exception as e: last_error = e if attempt < max_retries: print( f"Bedrock API attempt {attempt}/{max_retries} failed: {e}. " f"Retrying in {retry_delay_seconds}s..." ) time.sleep(retry_delay_seconds) else: raise ConnectionError( f"Failed to call Bedrock API after {max_retries} attempts: {str(last_error)}" ) from last_error def _call_gemini_vlm_api( image: Image.Image, prompt: str, client=None, config=None, model_choice: str = None, max_new_tokens: int = None, temperature: float = None, timeout: int = 60, ) -> Tuple[str, int, int]: """ Calls Gemini API with an image and text prompt for vision models. Args: image: PIL Image to process prompt: Text prompt for the VLM client: Gemini ai.Client instance config: Gemini types.GenerateContentConfig instance model_choice: Gemini model name (e.g., "gemini-1.5-pro") max_new_tokens: Maximum number of tokens to generate temperature: Sampling temperature timeout: Request timeout in seconds Returns: Tuple[str, int, int]: The generated text response, input tokens, output tokens Raises: ConnectionError: If the API request fails ValueError: If the response format is invalid """ if client is None: raise ValueError("Gemini client is required for Gemini VLM calls") if model_choice is None: raise ValueError("model_choice is required for Gemini VLM calls") try: image = _pad_image_for_vlm_aspect_ratio(image) except Exception as aspect_error: print( f"Warning: could not adjust image aspect ratio for Gemini VLM: {aspect_error}" ) # Convert PIL Image to base64 buffer = io.BytesIO() image.save(buffer, format="PNG") image_bytes = buffer.getvalue() base64.b64encode(image_bytes).decode("utf-8") # Prepare content for Gemini API # Gemini expects content as a list with parts containing image and text try: # Use the client to generate content with image # Gemini API 
expects the image as part of the content try: import google.genai.types as types except ImportError: raise ImportError( "Google GenAI types not available. Please install google-genai package." ) # Create content with image and text # For Gemini, we can pass image bytes directly or use inline_data parts = [ types.Part.from_bytes(data=image_bytes, mime_type="image/png"), types.Part.from_text(text=prompt), ] # Update config if needed if config is None: config = types.GenerateContentConfig( temperature=temperature if temperature is not None else 0.7, max_output_tokens=( max_new_tokens if max_new_tokens is not None else 4096 ), ) else: # Update existing config if temperature is not None: config.temperature = temperature if max_new_tokens is not None: config.max_output_tokens = max_new_tokens response = client.models.generate_content( model=model_choice, contents=parts, config=config ) # Extract text from response text = "" if hasattr(response, "text"): text = response.text elif hasattr(response, "candidates") and len(response.candidates) > 0: if hasattr(response.candidates[0], "content"): if hasattr(response.candidates[0].content, "parts"): text_parts = [] for part in response.candidates[0].content.parts: if hasattr(part, "text"): text_parts.append(part.text) text = "".join(text_parts) if not text: raise ValueError("No text content in Gemini response") # Extract token usage from response input_tokens = 0 output_tokens = 0 try: if hasattr(response, "usage_metadata"): usage = response.usage_metadata if hasattr(usage, "prompt_token_count"): input_tokens = usage.prompt_token_count if hasattr(usage, "candidates_token_count"): output_tokens = usage.candidates_token_count except Exception: pass # Token usage not available, return 0 return text, input_tokens, output_tokens except Exception as e: raise ConnectionError(f"Failed to call Gemini API: {str(e)}") def _call_azure_openai_vlm_api( image: Image.Image, prompt: str, client=None, model_choice: str = None, max_new_tokens: int 
= None, temperature: float = None, timeout: int = 60, ) -> Tuple[str, int, int]: """ Calls Azure/OpenAI API with an image and text prompt for vision models. Args: image: PIL Image to process prompt: Text prompt for the VLM client: OpenAI client instance model_choice: Model name (e.g., "gpt-4o", "gpt-4-vision-preview") max_new_tokens: Maximum number of tokens to generate temperature: Sampling temperature timeout: Request timeout in seconds Returns: Tuple[str, int, int]: The generated text response, input tokens, output tokens Raises: ConnectionError: If the API request fails ValueError: If the response format is invalid """ if client is None: raise ValueError("OpenAI client is required for Azure/OpenAI VLM calls") if model_choice is None: raise ValueError("model_choice is required for Azure/OpenAI VLM calls") try: image = _pad_image_for_vlm_aspect_ratio(image) except Exception as aspect_error: print( f"Warning: could not adjust image aspect ratio for Azure/OpenAI VLM: {aspect_error}" ) # Convert PIL Image to base64 buffer = io.BytesIO() image.save(buffer, format="PNG") image_bytes = buffer.getvalue() image_base64 = base64.b64encode(image_bytes).decode("utf-8") # Prepare messages in OpenAI format messages = [ { "role": "user", "content": [ { "type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}, }, {"type": "text", "text": prompt}, ], } ] try: # Call OpenAI chat completions API response = client.chat.completions.create( model=model_choice, messages=messages, temperature=temperature if temperature is not None else 0.7, max_completion_tokens=( max_new_tokens if max_new_tokens is not None else 4096 ), ) # Extract text from response text = "" if response.choices and len(response.choices) > 0: message = response.choices[0].message if hasattr(message, "content") and message.content: text = message.content else: raise ValueError("No content in OpenAI response") else: raise ValueError("No choices in OpenAI response") # Extract token usage from 
response input_tokens = 0 output_tokens = 0 try: if hasattr(response, "usage"): usage = response.usage if hasattr(usage, "prompt_tokens"): input_tokens = usage.prompt_tokens if hasattr(usage, "completion_tokens"): output_tokens = usage.completion_tokens except Exception: pass # Token usage not available, return 0 return text, input_tokens, output_tokens except Exception as e: raise ConnectionError(f"Failed to call Azure/OpenAI API: {str(e)}") def _repair_vlm_json_common_quote_issues(s: str) -> str: """ Best-effort repair for minor JSON issues seen in VLM outputs. Common cases: - Stray quote after numeric conf: {"conf": 0.85"} -> {"conf": 0.85} - Conf glued to extra text: {"conf": 0.85"some text..."} -> {"conf": 0.85} - A second trailing "conf : .6}" fragment appended inside an object """ if not s or not isinstance(s, str): return s out = s # Remove a stray quote (and any non-delimiter junk) immediately after numeric conf/confidence. # Keep the numeric value and let json.loads succeed. out = re.sub( r'("conf(?:idence)?"\s*:\s*)(-?\d+(?:\.\d+)?)(?:"[^,}\]]*)', r"\1\2", out, ) # Drop malformed trailing fragments like: , conf : .6}" # (unquoted key, often followed by extra braces/quotes). out = re.sub( r",\s*conf\s*:\s*-?(?:\d+(?:\.\d+)?|\.\d+)\s*\}?\s*\"?\s*", "", out, flags=re.IGNORECASE, ) # If the model accidentally emits single quotes around keys/strings, prefer leaving as-is # (other code paths already rely on strict JSON). Avoid aggressive rewriting here. return out def _best_effort_extract_text_conf_from_messy_jsonish( raw: str, ) -> Optional[Dict[str, Any]]: """ Last-resort extractor for hybrid VLM OCR single-line payloads when JSON is too broken. Pulls the first `text` string and the first numeric `conf/confidence` value. 
""" if not raw or not isinstance(raw, str): return None s = raw.strip() # Extract text (prefer quoted JSON-like "text": "...") text_match = re.search( r'"text"\s*:\s*"(?P(?:[^"\\\\]|\\\\.)*)"', s, flags=re.IGNORECASE, ) if not text_match: text_match = re.search( r"\btext\b\s*[:=]\s*\"(?P(?:[^\"\\\\]|\\\\.)*)\"", s, flags=re.IGNORECASE, ) if not text_match: return None text_val = text_match.group("text") try: # Unescape common sequences text_val = bytes(text_val, "utf-8").decode("unicode_escape") except Exception: pass text_val = str(text_val).strip() if not text_val: return None # Extract first confidence number after conf/confidence (tolerate unquoted key) conf_match = re.search( r"(?:\"?conf(?:idence)?\"?)\s*:\s*(?P-?(?:\d+(?:\.\d+)?|\.\d+))", s, flags=re.IGNORECASE, ) out: Dict[str, Any] = {"text": text_val} if conf_match: out["confidence"] = conf_match.group("conf") return out def _extract_last_text_dict_from_vlm_response(raw: str) -> Optional[Dict[str, Any]]: """ Extract the last JSON object that contains a "text" key from a VLM response. Handles thinking blocks (e.g. ...) that may contain multiple dicts; the final answer is assumed to be the last valid dict in correct format. Returns: The parsed dict with "text" (and optionally "confidence"), or None if none found. 
""" if not raw or not isinstance(raw, str): return None last_valid = None i = 0 while i < len(raw): start = raw.find("{", i) if start == -1: break depth = 1 j = start + 1 while j < len(raw) and depth > 0: if raw[j] == "{": depth += 1 elif raw[j] == "}": depth -= 1 j += 1 if depth != 0: i = start + 1 continue snippet = raw[start:j] try: snippet = _repair_vlm_json_common_quote_issues(snippet) obj = json.loads(snippet) if isinstance(obj, dict): norm = _normalize_single_line_text_dict(obj) if norm is not None: last_valid = norm except (json.JSONDecodeError, TypeError): fallback = _best_effort_extract_text_conf_from_messy_jsonish(snippet) if fallback is not None: norm = _normalize_single_line_text_dict(fallback) if norm is not None: last_valid = norm i = j return last_valid def _extract_and_combine_text_dicts_from_vlm_response( raw: str, ) -> Optional[Dict[str, Any]]: """ Extract all JSON objects that contain a "text" key from a VLM response, then combine them. If the VLM returns each word in its own dict (e.g. [{"text": "Hello", "confidence": 0.9}, ...]), the text from each entry is joined with spaces and confidence values are averaged. Returns: A single dict with "text" (combined) and "confidence" (average), or None if no valid dict found. 
""" if not raw or not isinstance(raw, str): return None raw = strip_vlm_thinking_tags(raw) collected = [] i = 0 while i < len(raw): start = raw.find("{", i) if start == -1: break depth = 1 j = start + 1 while j < len(raw) and depth > 0: if raw[j] == "{": depth += 1 elif raw[j] == "}": depth -= 1 j += 1 if depth != 0: i = start + 1 continue snippet = raw[start:j] try: snippet = _repair_vlm_json_common_quote_issues(snippet) obj = json.loads(snippet) if isinstance(obj, dict): norm = _normalize_single_line_text_dict(obj) if norm is not None: collected.append(norm) except (json.JSONDecodeError, TypeError): fallback = _best_effort_extract_text_conf_from_messy_jsonish(snippet) if fallback is not None: norm = _normalize_single_line_text_dict(fallback) if norm is not None: collected.append(norm) i = j if not collected: return None if len(collected) == 1: return collected[0] # Multiple entries: combine text and average confidence texts = [] confidences = [] for entry in collected: t = entry.get("text") if t is not None and isinstance(t, str) and t.strip(): texts.append(t.strip()) c = entry.get("confidence", entry.get("conf")) if c is not None: try: confidences.append(float(c)) except (TypeError, ValueError): pass combined_text = " ".join(texts) avg_confidence = sum(confidences) / len(confidences) if confidences else 1.0 return {"text": combined_text, "confidence": avg_confidence} def _vlm_ocr_predict( image: Image.Image, prompt: str = model_default_prompt, ) -> Dict[str, Any]: """ VLM OCR prediction function that mimics PaddleOCR's interface. 
Args: image: PIL Image to process prompt: Text prompt for the VLM Returns: Dictionary in PaddleOCR format with 'rec_texts' and 'rec_scores' """ try: # Validate image exists and is not None if image is None: print("VLM OCR error: Image is None") return {"rec_texts": [], "rec_scores": []} # Validate image has valid size (at least 10x10 pixels) try: width, height = image.size if width < 10 or height < 10: print( f"VLM OCR error: Image is too small ({width}x{height} pixels). Minimum size is 10x10." ) return {"rec_texts": [], "rec_scores": []} except Exception as size_error: print(f"VLM OCR error: Could not get image size: {size_error}") return {"rec_texts": [], "rec_scores": []} # Ensure image is in RGB mode (convert if needed) try: if image.mode != "RGB": # print(f"VLM OCR: Converting image from {image.mode} to RGB mode") image = image.convert("RGB") # Update width/height after conversion (should be same, but ensure consistency) width, height = image.size except Exception as convert_error: print(f"VLM OCR error: Could not convert image to RGB: {convert_error}") return {"rec_texts": [], "rec_scores": []} # Same pipeline as hybrid inference-server line crops try: image = _prepare_hybrid_line_crop_for_vlm(image) width, height = image.size except Exception as prep_error: print(f"VLM OCR error: Could not prepare image for VLM: {prep_error}") return {"rec_texts": [], "rec_scores": []} # Use the VLM to extract text # Pass None for parameters to prioritize model-specific defaults from run_vlm.py # If model defaults are not available, general defaults will be used (matching current values) extracted_text, _, _ = extract_text_from_image_vlm( text=prompt, image=image, max_new_tokens=None, # Use model default if available, otherwise MAX_NEW_TOKENS from config temperature=None, # Use model default if available, otherwise 0.7 top_p=None, # Use model default if available, otherwise 0.9 min_p=None, # Use model default if available, otherwise 0.0 top_k=None, # Use model default if 
available, otherwise 50 repetition_penalty=None, # Use model default if available, otherwise 1.3 presence_penalty=None, # Use model default if available, otherwise None (only supported by Qwen3-VL models) ) # Check if extracted_text is None or empty if extracted_text is None: # print("VLM OCR warning: extract_text_from_image_vlm returned None") return {"rec_texts": [], "rec_scores": []} if not isinstance(extracted_text, str): # print(f"VLM OCR warning: extract_text_from_image_vlm returned unexpected type: {type(extracted_text)}") return {"rec_texts": [], "rec_scores": []} if not extracted_text.strip(): # print("VLM OCR warning: Extracted text is empty after stripping") return {"rec_texts": [], "rec_scores": []} # Parse VLM response: expect dictionary format {"text": "...", "confidence": ...} # If VLM returns multiple dicts (e.g. one per word), combine text and average confidence parsed = _extract_and_combine_text_dicts_from_vlm_response(extracted_text) if parsed is None: return {"rec_texts": [], "rec_scores": []} text_content = parsed.get("text") confidence = parsed.get("confidence") if text_content is None or not isinstance(text_content, str): return {"rec_texts": [], "rec_scores": []} # Clamp confidence to [0, 1]; default 1.0 if missing or invalid try: score = float(confidence) if confidence is not None else 1.0 score = max(0.0, min(1.0, score)) except (TypeError, ValueError): score = 1.0 cleaned_text = re.sub(r"[\r\n]+", " ", text_content).strip() words = cleaned_text.split() # Enforce output length below HYBRID_OCR_MAX_WORDS (truncate if over) if len(words) > HYBRID_OCR_MAX_WORDS: words = words[:HYBRID_OCR_MAX_WORDS] result = { "rec_texts": words, "rec_scores": [score] * len(words), } return result except Exception: # print(f"VLM OCR error: {e}") # print(f"VLM OCR error traceback: {traceback.format_exc()}") return {"rec_texts": [], "rec_scores": []} @spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME) def _process_page_result_with_hybrid_vlm_ocr( page_results: list, 
image: Image.Image, img_width: int, img_height: int, input_image_width: int, input_image_height: int, confidence_threshold: float, image_name: str, output_folder: str, padding: int = 0, ): """ Processes OCR page results using a hybrid system that combines PaddleOCR for initial recognition and VLM for low-confidence lines. When PaddleOCR's recognition confidence for a detected line is below the specified threshold, the line is re-processed using a higher-quality (but slower) VLM model and the result is used to replace the low-confidence recognition. Results are kept in PaddleOCR's standard output format for downstream compatibility. Args: page_results (list): The list of page result dicts from PaddleOCR to process. Each dict should contain keys like 'rec_texts', 'rec_scores', 'rec_polys', and optionally 'image_width', 'image_height', and 'rec_models'. image (PIL.Image.Image): The PIL Image object of the full page to allow line cropping. img_width (int): The width of the (possibly preprocessed) image in pixels. img_height (int): The height of the (possibly preprocessed) image in pixels. input_image_width (int): The original image width (before any resizing/preprocessing). input_image_height (int): The original image height (before any resizing/preprocessing). confidence_threshold (float): Lines recognized by PaddleOCR with confidence lower than this threshold will be replaced using the VLM. image_name (str): The name of the source image, used for logging/debugging. output_folder (str): The output folder path for saving example images. padding (int): Padding to add around line crops. Returns: Modified page_results with VLM replacements for low-confidence lines. """ if len(page_results) > 1: print( f"Hybrid Paddle+VLM: PaddleOCR returned {len(page_results)} result dicts for one image " f"({image_name!r}); applying line-level VLM re-OCR only to the first dict to avoid duplicate VLM calls." 
) _hybrid_page_iter = page_results[:1] if len(page_results) > 1 else page_results def _normalize_paddle_result_lists(rec_texts, rec_scores, rec_polys): """ Normalizes PaddleOCR result lists to ensure they all have the same length. Pads missing entries with appropriate defaults: - rec_texts: empty string "" - rec_scores: 0.0 (low confidence) - rec_polys: empty list [] Args: rec_texts: List of recognized text strings rec_scores: List of confidence scores rec_polys: List of bounding box polygons Returns: Tuple of (normalized_rec_texts, normalized_rec_scores, normalized_rec_polys, max_length) """ len_texts = len(rec_texts) len_scores = len(rec_scores) len_polys = len(rec_polys) max_length = max(len_texts, len_scores, len_polys) # Only normalize if there's a mismatch if max_length > 0 and ( len_texts != max_length or len_scores != max_length or len_polys != max_length ): print( f"Warning: List length mismatch detected - rec_texts: {len_texts}, " f"rec_scores: {len_scores}, rec_polys: {len_polys}. " f"Padding to length {max_length}." 
) # Pad rec_texts if len_texts < max_length: rec_texts = list(rec_texts) + [""] * (max_length - len_texts) # Pad rec_scores if len_scores < max_length: rec_scores = list(rec_scores) + [0.0] * (max_length - len_scores) # Pad rec_polys if len_polys < max_length: rec_polys = list(rec_polys) + [[]] * (max_length - len_polys) return rec_texts, rec_scores, rec_polys, max_length # Helper function to create safe filename (inlined to avoid needing instance_self) def _create_safe_filename_with_confidence( original_text: str, new_text: str, conf: int, new_conf: int, ocr_type: str = "OCR", ) -> str: """Creates a safe filename using confidence values when text sanitization fails.""" # Helper to sanitize text similar to _sanitize_filename def _sanitize_text_for_filename( text: str, max_length: int = 20, fallback_prefix: str = "unknown_text", ) -> str: """Sanitizes text for use in filenames.""" sanitized = safe_sanitize_text(text) # Remove leading/trailing underscores and spaces sanitized = sanitized.strip("_ ") # If empty after sanitization, use a default value if not sanitized: sanitized = fallback_prefix # Limit to max_length characters if len(sanitized) > max_length: sanitized = sanitized[:max_length] sanitized = sanitized.rstrip("_") # Final check: if still empty or too short, use fallback if not sanitized or len(sanitized) < 3: sanitized = fallback_prefix return sanitized # Try to sanitize both texts safe_original = _sanitize_text_for_filename( original_text, max_length=15, fallback_prefix=f"orig_conf_{conf}" ) safe_new = _sanitize_text_for_filename( new_text, max_length=15, fallback_prefix=f"new_conf_{new_conf}" ) # If both sanitizations resulted in fallback names, create a confidence-based name if safe_original.startswith("orig_conf") and safe_new.startswith("new_conf"): return f"{ocr_type}_conf_{conf}_to_conf_{new_conf}" return f"{safe_original}_conf_{conf}_to_{safe_new}_conf_{new_conf}" # Process each page result in paddle_results (see _hybrid_page_iter when len > 1) 
for page_result in _hybrid_page_iter: # Extract text recognition results from the paddle format rec_texts = page_result.get("rec_texts", list()) rec_scores = page_result.get("rec_scores", list()) rec_polys = page_result.get("rec_polys", list()) # Normalize lists to ensure they all have the same length rec_texts, rec_scores, rec_polys, num_lines = _normalize_paddle_result_lists( rec_texts, rec_scores, rec_polys ) # Update page_result with normalized lists page_result["rec_texts"] = rec_texts page_result["rec_scores"] = rec_scores page_result["rec_polys"] = rec_polys # Initialize rec_models list with "Paddle" as default for all lines if ( "rec_models" not in page_result or len(page_result.get("rec_models", [])) != num_lines ): rec_models = ["Paddle"] * num_lines page_result["rec_models"] = rec_models else: rec_models = page_result["rec_models"] # Since we're using the exact image PaddleOCR processed, coordinates are directly in image space # No coordinate conversion needed - coordinates match the image dimensions exactly # Process each line # print(f"Processing {num_lines} lines from PaddleOCR results...") for i in range(num_lines): line_text = rec_texts[i] line_conf = float(rec_scores[i]) * 100 # Convert to percentage bounding_box = rec_polys[i] # Skip if bounding box is empty (from padding) # Handle numpy arrays, lists, and None values safely if bounding_box is None: continue # Convert to list first to handle numpy arrays safely if hasattr(bounding_box, "tolist"): box = bounding_box.tolist() else: box = bounding_box # Check if box is empty (handles both list and numpy array cases) if not box or (isinstance(box, list) and len(box) == 0): continue # Skip empty lines if not line_text.strip(): continue # Convert polygon to bounding box x_coords = [p[0] for p in box] y_coords = [p[1] for p in box] line_left_paddle = float(min(x_coords)) line_top_paddle = float(min(y_coords)) line_right_paddle = float(max(x_coords)) line_bottom_paddle = float(max(y_coords)) 
line_width_paddle = line_right_paddle - line_left_paddle line_height_paddle = line_bottom_paddle - line_top_paddle # Since we're using the exact image PaddleOCR processed, coordinates are already in image space # No conversion needed - use coordinates directly line_left = line_left_paddle line_top = line_top_paddle line_width = line_width_paddle line_height = line_height_paddle # Initialize model as PaddleOCR (default) # Count words in PaddleOCR output paddle_words = line_text.split() paddle_word_count = len(paddle_words) # If confidence is low, use VLM for a second opinion if line_conf <= confidence_threshold: # Ensure minimum line height for VLM processing # If line_height is too small, use a minimum height based on typical text line height min_line_height = max( line_height, 20 ) # Minimum 20 pixels for text line # Calculate crop coordinates with padding # Convert floats to integers and apply padding, clamping to image bounds crop_left = max(0, int(round(line_left - padding))) crop_top = max(0, int(round(line_top - padding))) crop_right = min( img_width, int(round(line_left + line_width + padding)) ) crop_bottom = min( img_height, int(round(line_top + min_line_height + padding)) ) # Ensure crop dimensions are valid if crop_right <= crop_left or crop_bottom <= crop_top: # Invalid crop, keep original PaddleOCR result continue # Crop the line image cropped_image = image.crop( (crop_left, crop_top, crop_right, crop_bottom) ) # Check if cropped image is too small for VLM processing crop_width = crop_right - crop_left crop_height = crop_bottom - crop_top if crop_width < 10 or crop_height < 10: continue # Ensure cropped image is in RGB mode before passing to VLM if cropped_image.mode != "RGB": cropped_image = cropped_image.convert("RGB") # Save input image for debugging if environment variable is set if SAVE_VLM_INPUT_IMAGES: try: vlm_debug_dir = os.path.join( output_folder, "hybrid_paddle_vlm_visualisations/hybrid_analysis_input_images", ) os.makedirs(vlm_debug_dir, 
exist_ok=True) line_text_safe = safe_sanitize_text(line_text) line_text_shortened = line_text_safe[:20] image_name_safe = safe_sanitize_text(image_name) image_name_shortened = image_name_safe[:20] filename = f"{image_name_shortened}_{line_text_shortened}_hybrid_analysis_input_image.png" filepath = os.path.join(vlm_debug_dir, filename) _save_image_with_config_dpi(cropped_image, filepath) # print(f"Saved VLM input image to: {filepath}") except Exception as save_error: print(f"Warning: Could not save VLM input image: {save_error}") # Use VLM for OCR on this line with error handling vlm_result = None vlm_rec_texts = [] vlm_rec_scores = [] try: vlm_result = _vlm_ocr_predict(cropped_image) vlm_rec_texts = ( vlm_result.get("rec_texts", []) if vlm_result else [] ) vlm_rec_scores = ( vlm_result.get("rec_scores", []) if vlm_result else [] ) except Exception: # Ensure we keep original PaddleOCR result on error vlm_rec_texts = [] vlm_rec_scores = [] if vlm_rec_texts and vlm_rec_scores: # Combine VLM words into a single text string vlm_text = " ".join(vlm_rec_texts) vlm_word_count = len(vlm_rec_texts) vlm_conf = float( np.median(vlm_rec_scores) ) # Keep as 0-1 range for paddle format # Only replace if word counts match word_count_allowed_difference = 4 if ( vlm_word_count - paddle_word_count <= word_count_allowed_difference and vlm_word_count - paddle_word_count >= -word_count_allowed_difference ): text_output = f" Re-OCR'd line: '{line_text}' (conf: {line_conf:.1f}, words: {paddle_word_count}) " text_output += f"-> '{vlm_text}' (conf: {vlm_conf*100:.1f}, words: {vlm_word_count}) [VLM]" print(text_output) if REPORT_VLM_OUTPUTS_TO_GUI: try: gr.Info(text_output, duration=2) except Exception: # gr.Info may not be available in worker process, ignore pass # For exporting example image comparisons safe_filename = _create_safe_filename_with_confidence( line_text, vlm_text, int(line_conf), int(vlm_conf * 100), "VLM", ) if SAVE_EXAMPLE_HYBRID_IMAGES: # Normalize and validate image_name 
to prevent path traversal attacks normalized_image_name = os.path.normpath( image_name + "_hybrid_paddle_vlm" ) if ( ".." in normalized_image_name or "/" in normalized_image_name or "\\" in normalized_image_name ): normalized_image_name = "safe_image" hybrid_ocr_examples_folder = ( output_folder + f"/hybrid_ocr_examples/{normalized_image_name}" ) # Validate the constructed path is safe if not validate_folder_containment( hybrid_ocr_examples_folder, OUTPUT_FOLDER ): raise ValueError( f"Unsafe hybrid_ocr_examples folder path: {hybrid_ocr_examples_folder}" ) if not os.path.exists(hybrid_ocr_examples_folder): os.makedirs(hybrid_ocr_examples_folder) output_image_path = ( hybrid_ocr_examples_folder + f"/{safe_filename}.png" ) # print(f"Saving example image to {output_image_path}") _save_image_with_config_dpi( cropped_image, output_image_path ) # Replace with VLM result in paddle_results format # Update rec_texts, rec_scores, and rec_models for this line rec_texts[i] = vlm_text rec_scores[i] = vlm_conf rec_models[i] = "VLM" # Ensure page_result is updated with the modified rec_models list page_result["rec_models"] = rec_models else: print( f" Line: '{line_text}' (conf: {line_conf:.1f}, words: {paddle_word_count}) -> " f"VLM result '{vlm_text}' (conf: {vlm_conf*100:.1f}, words: {vlm_word_count}) " f"word count mismatch. Keeping PaddleOCR result." 
) else: # VLM returned empty or no results - keep original PaddleOCR result if line_conf <= confidence_threshold: pass # Debug: Print summary of model labels before returning for page_idx, page_result in enumerate(page_results): rec_models = page_result.get("rec_models", []) sum(1 for m in rec_models if m == "VLM") sum(1 for m in rec_models if m == "Paddle") return page_results def _convert_single_line_to_word_level_standalone( line_text: str, line_left: int, line_top: int, line_width: int, line_height: int, line_conf: float, image: Image.Image, image_width: int, image_height: int, output_folder: str, image_name: str = None, line_model: str = "Bedrock VLM", ) -> Dict[str, List]: """ Converts a single line (text + line bbox) to word-level bounding boxes using AdaptiveSegmenter. Used by the hybrid Textract+Bedrock path when word count differs by <= MAX_WORD_COUNT_DIFF_FOR_LINE_DERIVED_WORDS so we keep Bedrock text but derive word boxes from the line. Returns dict with keys "text", "left", "top", "width", "height", "conf", "model" (all lists, coordinates in full image space). 
""" output = { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } if not (line_text or "").strip(): return output if image is None or output_folder is None: return output if hasattr(image, "size"): image_np = np.array(image) if len(image_np.shape) == 3: image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) elif len(image_np.shape) == 2: image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR) else: image_np = image.copy() if len(image_np.shape) == 2: image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR) actual_height, actual_width = image_np.shape[:2] if actual_width != image_width or actual_height != image_height: image_width = actual_width image_height = actual_height line_left = int(max(0, min(line_left, image_width - 1))) line_top = int(max(0, min(line_top, image_height - 1))) line_width = int(max(1, min(line_width, image_width - line_left))) line_height = int(max(1, min(line_height, image_height - line_top))) if line_left >= image_width or line_top >= image_height: return output if line_left + line_width > image_width: line_width = image_width - line_left if line_top + line_height > image_height: line_height = image_height - line_top if line_width <= 0 or line_height <= 0: return output try: line_image = image_np[ line_top : line_top + line_height, line_left : line_left + line_width, ] except IndexError: return output if line_image is None or line_image.size == 0 or len(line_image.shape) < 2: return output conf_val = line_conf if isinstance(line_conf, (int, float)) else 100 try: conf_val = max(0, min(100, float(conf_val))) except (TypeError, ValueError): conf_val = 100 single_line_data = { "text": [line_text], "left": [0], "top": [0], "width": [line_width], "height": [line_height], "conf": [conf_val], "line": [0], } segmenter = AdaptiveSegmenter(output_folder=output_folder) try: word_output, _ = segmenter.segment( single_line_data, line_image, image_name=image_name ) except Exception: word_output = None if not word_output or 
not word_output.get("text"): words = line_text.split() if words: num_chars = len("".join(words)) num_spaces = len(words) - 1 char_space_ratio = 2.0 denom = (num_chars * char_space_ratio + num_spaces) or 1 estimated_space_width = line_width / denom avg_char_width = estimated_space_width * char_space_ratio current_left = 0 for word in words: word_width = len(word) * avg_char_width clamped_left = max(0, min(current_left, line_width)) clamped_width = max(0, min(word_width, line_width - clamped_left)) output["text"].append(word) output["left"].append(line_left + clamped_left) output["top"].append(line_top) output["width"].append(clamped_width) output["height"].append(line_height) output["conf"].append(conf_val) output["model"].append(line_model) current_left += word_width + estimated_space_width return output for j in range(len(word_output["text"])): output["text"].append(word_output["text"][j]) output["left"].append(line_left + word_output["left"][j]) output["top"].append(line_top + word_output["top"][j]) output["width"].append(word_output["width"][j]) output["height"].append(word_output["height"][j]) output["conf"].append( word_output["conf"][j] if j < len(word_output.get("conf") or []) else conf_val ) output["model"].append(line_model) return output def _process_textract_page_with_hybrid_bedrock_vlm( page_line_level_ocr_results: Dict[str, Any], page_line_level_ocr_results_with_words: Dict[str, Any], image: Image.Image, img_width: int, img_height: int, confidence_threshold: float, padding: int, bedrock_runtime: Any, model_choice: str, output_folder: str, image_name: str, ) -> Tuple[Dict[str, Any], int, int, str]: """ For a single page's Textract results, re-run Bedrock VLM on lines whose line-level confidence is below the threshold. 
Uses the actual line-level page OCR object (page_line_level_ocr_results) for confidence and bbox; the ocr results with words object is updated only at the end when mapping back corrected text/confidence/words for successfully re-OCR'd lines. Returns: Tuple of (page_line_level_ocr_results_with_words, vlm_input_tokens, vlm_output_tokens, vlm_model_name) for usage logging. """ _empty_return = (page_line_level_ocr_results_with_words, 0, 0, model_choice or "") if image is None or not page_line_level_ocr_results_with_words: print("Image is None or no page line level OCR results with words found") return _empty_return results = page_line_level_ocr_results_with_words.get("results") or {} if not results: print("No results found") return _empty_return if bedrock_runtime is None or not model_choice: print("Bedrock runtime is None or model choice is not set") print(f"Bedrock runtime: {bedrock_runtime}") print(f"Model choice: {model_choice}") return _empty_return line_level_results = page_line_level_ocr_results.get("results") or [] if not line_level_results: return _empty_return # Build line-level items from the actual line-level OCR (OCRResult list) # Match by result.line -> key "text_line_{line}" in the with_words dict line_level_items = [] for result in line_level_results: line_num = ( getattr(result, "line", None) if hasattr(result, "line") else result.get("line") if isinstance(result, dict) else None ) if line_num is None or line_num < 1: continue key = f"text_line_{line_num}" if key not in results: continue if isinstance(result, dict): conf = result.get("conf", result.get("confidence")) else: conf = getattr(result, "conf", None) if conf is None: conf = 0 try: line_conf = float(conf) except (TypeError, ValueError): line_conf = 0 if isinstance(result, dict): left = result.get("left", 0) top = result.get("top", 0) w = result.get("width", 0) h = result.get("height", 0) else: left = getattr(result, "left", 0) top = getattr(result, "top", 0) w = getattr(result, "width", 0) h = 
getattr(result, "height", 0) bbox = (left, top, left + w, top + h) line_level_items.append((key, line_conf, bbox)) # Ensure RGB for Bedrock if image.mode != "RGB": image = image.convert("RGB") # Optional: folder and log file for saving example images and prompt/response (when config set) save_examples = SAVE_TEXTRACT_BEDROCK_HYBRID_EXAMPLES hybrid_examples_folder = None inference_log_path = None if save_examples and output_folder and image_name: normalized_image_name = os.path.normpath( image_name + "_textract_bedrock_hybrid" ) if ( ".." in normalized_image_name or "/" in normalized_image_name or "\\" in normalized_image_name ): normalized_image_name = "safe_image" hybrid_examples_folder = os.path.join( output_folder, "textract_bedrock_hybrid_examples", normalized_image_name ) if validate_folder_containment(hybrid_examples_folder, OUTPUT_FOLDER): if not os.path.exists(hybrid_examples_folder): os.makedirs(hybrid_examples_folder) page_no = page_line_level_ocr_results_with_words.get("page", "?") inference_log_path = os.path.join( hybrid_examples_folder, f"page_{page_no}_inference_log.jsonl" ) else: save_examples = False # Build list of (key, line_conf, bbox, cropped) for lines below threshold tasks = [] for key, line_conf, bbox in line_level_items: if line_conf > confidence_threshold: continue left, top, right, bottom = bbox crop_left = max(0, int(left) - padding) crop_top = max(0, int(top) - padding) crop_right = min(img_width, int(right) + padding) crop_bottom = min(img_height, int(bottom) + padding) if crop_right <= crop_left or crop_bottom <= crop_top: continue cropped = image.crop((crop_left, crop_top, crop_right, crop_bottom)) if cropped.size[0] < 10 or cropped.size[1] < 10: continue tasks.append((key, line_conf, bbox, cropped)) def _run_one_line_vlm( task: Tuple[str, float, Tuple, Image.Image], ) -> Dict[str, Any]: """Run Bedrock VLM on one line crop. 
Returns dict with key, line_conf, bbox, cropped, and either vlm_result or error.""" key, line_conf, bbox, cropped = task prompt_used = model_default_prompt try: vlm_result = _bedrock_vlm_ocr_predict( cropped, model_choice=model_choice, bedrock_runtime=bedrock_runtime, return_prompt_and_response=save_examples, ) if save_examples: prompt_used = vlm_result.get("prompt", prompt_used) return { "key": key, "line_conf": line_conf, "bbox": bbox, "cropped": cropped, "vlm_result": vlm_result, "prompt_used": prompt_used, "raw_response": ( vlm_result.get("raw_response") if save_examples else None ), "error": None, } except Exception as e: return { "key": key, "line_conf": line_conf, "bbox": bbox, "cropped": cropped, "vlm_result": None, "prompt_used": prompt_used, "raw_response": None, "error": str(e), } # Run VLM inference in parallel for all low-confidence lines hybrid_vlm_input_tokens = 0 hybrid_vlm_output_tokens = 0 updates = [] vlm_results_list = [] if tasks: max_workers_hybrid = min(MAX_WORKERS, len(tasks)) with ThreadPoolExecutor(max_workers=max_workers_hybrid) as executor: vlm_results_list = list(executor.map(_run_one_line_vlm, tasks)) # Process each VLM result (post-processing and optional logging on main thread) for res in vlm_results_list: key = res["key"] line_conf = res["line_conf"] bbox = res["bbox"] cropped = res["cropped"] prompt_used = res["prompt_used"] raw_response = res["raw_response"] if res["error"] is not None: print(f"Hybrid Textract-Bedrock VLM failed for line {key}: {res['error']}") if save_examples and hybrid_examples_folder and inference_log_path: try: safe_name = ( f"{safe_sanitize_text(key)}_conf_{int(line_conf)}_error.png" ) crop_path = os.path.join(hybrid_examples_folder, safe_name) _save_image_with_config_dpi(cropped, crop_path) log_entry = { "key": key, "line_conf": line_conf, "prompt": prompt_used, "error": res["error"], "raw_response": None, } with open(inference_log_path, "a", encoding="utf-8") as log_f: log_f.write(json.dumps(log_entry, 
ensure_ascii=False) + "\n") except Exception as save_err: print(f"Could not save hybrid example for {key}: {save_err}") continue vlm_result = res["vlm_result"] hybrid_vlm_input_tokens += vlm_result.get("vlm_input_tokens", 0) hybrid_vlm_output_tokens += vlm_result.get("vlm_output_tokens", 0) rec_texts = vlm_result.get("rec_texts", []) rec_scores = vlm_result.get("rec_scores", []) if not rec_texts or not rec_scores: if save_examples and hybrid_examples_folder and inference_log_path: try: safe_name = f"{safe_sanitize_text(key)}_conf_{int(line_conf)}.png" crop_path = os.path.join(hybrid_examples_folder, safe_name) _save_image_with_config_dpi(cropped, crop_path) log_entry = { "key": key, "line_conf": line_conf, "prompt": prompt_used, "raw_response": raw_response, "error": None, "parsed_rec_texts": [], } with open(inference_log_path, "a", encoding="utf-8") as log_f: log_f.write(json.dumps(log_entry, ensure_ascii=False) + "\n") except Exception as save_err: print(f"Could not save hybrid example for {key}: {save_err}") continue if save_examples and hybrid_examples_folder and inference_log_path: try: safe_name = f"{safe_sanitize_text(key)}_conf_{int(line_conf)}.png" crop_path = os.path.join(hybrid_examples_folder, safe_name) _save_image_with_config_dpi(cropped, crop_path) log_entry = { "key": key, "line_conf": line_conf, "prompt": prompt_used, "raw_response": raw_response, "error": None, "parsed_rec_texts": rec_texts, "parsed_rec_scores": rec_scores, } with open(inference_log_path, "a", encoding="utf-8") as log_f: log_f.write(json.dumps(log_entry, ensure_ascii=False) + "\n") except Exception as save_err: print(f"Could not save hybrid example for {key}: {save_err}") # Textract may have split punctuation into separate words (SPLIT_PUNCTUATION_FROM_WORDS), # while the VLM returns punctuation attached (e.g. "ACTION."). Expand VLM words to match # original word count by splitting trailing punctuation when the next original word is punct-only. 
line_data_for_words = results.get(key) or {} original_words_list = line_data_for_words.get("words") or [] original_word_count = len(original_words_list) expanded_texts = [] expanded_scores = [] j = 0 i = 0 while i < original_word_count and j < len(rec_texts): word = rec_texts[j] score = rec_scores[j] if j < len(rec_scores) else rec_scores[-1] next_orig = ( original_words_list[i + 1] if i + 1 < original_word_count else None ) next_orig_text = ( (next_orig.get("text", "") or "") if isinstance(next_orig, dict) else "" ) next_is_punct_only = bool(next_orig_text) and not re.search( r"[\w]", next_orig_text ) word_has_trailing_punct = bool(word) and bool(re.search(r"[^\w\s]$", word)) if next_is_punct_only and word_has_trailing_punct and len(word) > 1: match = re.search(r"^(.*?)([^\w\s]+)$", word) if match: main, trail = match.group(1), match.group(2) expanded_texts.append(main) expanded_scores.append(score) expanded_texts.append(trail) expanded_scores.append(score) i += 2 j += 1 else: expanded_texts.append(word) expanded_scores.append(score) i += 1 j += 1 else: expanded_texts.append(word) expanded_scores.append(score) i += 1 j += 1 same_word_count = len(expanded_texts) == original_word_count and j == len( rec_texts ) if same_word_count: rec_texts = expanded_texts rec_scores = expanded_scores new_text = " ".join(rec_texts) # rec_scores from Bedrock are 0-1; store as 0-100 to match Textract # Use original word-level bounding boxes so replacement is shown per word, not per line new_words = [] for i, txt in enumerate(rec_texts): score = rec_scores[i] if i < len(rec_scores) else rec_scores[-1] try: sc = float(score) if 0 <= sc <= 1: sc = sc * 100 sc = max(0, min(100, sc)) except (TypeError, ValueError): sc = 100 # Retain original word bounding box when available (word-level replacement boxes) word_bbox = bbox if i < len(original_words_list): orig_word = original_words_list[i] if isinstance(orig_word, dict): orig_bbox = orig_word.get("bounding_box") if isinstance(orig_bbox, 
(list, tuple)) and len(orig_bbox) == 4: word_bbox = orig_bbox new_words.append( { "text": txt, "confidence": round(sc, 0), "bounding_box": word_bbox, "model": "Bedrock VLM", } ) avg_conf = (sum(rec_scores) / len(rec_scores)) * 100 if rec_scores else 100 avg_conf = max(0, min(100, avg_conf)) # Accept VLM result if: (1) word count matches and conf > 50, or # (2) word count differs by <= MAX_WORD_COUNT_DIFF and conf > 50 — then use # Bedrock text + Textract line bbox and derive word-level boxes via line-to-word. vlm_conf_above_50 = avg_conf > 50 word_count_diff = abs(original_word_count - len(rec_texts)) use_line_derived_words = ( not same_word_count and vlm_conf_above_50 and word_count_diff <= MAX_WORD_COUNT_DIFF_FOR_LINE_DERIVED_WORDS ) if same_word_count and vlm_conf_above_50: updates.append((key, new_text, avg_conf, new_words, line_conf)) elif use_line_derived_words: # Use Bedrock text and Textract line bbox; derive word-level boxes. left, top, right, bottom = bbox line_w = max(1, int(right) - int(left)) line_h = max(1, int(bottom) - int(top)) word_level = _convert_single_line_to_word_level_standalone( new_text, int(left), int(top), line_w, line_h, avg_conf, image, img_width, img_height, output_folder, image_name=image_name, line_model="Bedrock VLM", ) derived_words = [] for idx in range(len(word_level.get("text") or [])): left = word_level["left"][idx] top = word_level["top"][idx] width = word_level["width"][idx] height = word_level["height"][idx] confidence = ( word_level["conf"][idx] if idx < len(word_level.get("conf") or []) else avg_conf ) try: confidence = max(0, min(100, float(confidence))) except (TypeError, ValueError): confidence = avg_conf derived_words.append( { "text": word_level["text"][idx], "confidence": round(confidence, 0), "bounding_box": (left, top, left + width, top + height), "model": "Bedrock VLM", } ) if derived_words: updates.append((key, new_text, avg_conf, derived_words, line_conf)) else: print( f" Skipping VLM result for {key}: 
line-to-word returned no words (original={original_word_count}, VLM={len(rec_texts)}). Keeping Textract." ) else: if not same_word_count and not use_line_derived_words: print( f" Skipping VLM result for {key}: word count mismatch (original={original_word_count}, VLM={len(expanded_texts)}). Keeping Textract." ) elif not vlm_conf_above_50: print( f" Skipping VLM result for {key}: VLM confidence {avg_conf:.0f} not above 50. Keeping Textract." ) # Map back into ocr results with words and update line-level OCRResult objects line_level_results = page_line_level_ocr_results.get("results") or [] for key, new_text, avg_conf, new_words, line_conf in updates: line_data = results.get(key) if line_data is not None and isinstance(line_data, dict): line_data["text"] = new_text line_data["confidence"] = round(avg_conf, 0) line_data["words"] = new_words line_data["model"] = "Bedrock VLM" # Update corresponding line-level OCRResult (by line number from key "text_line_N") try: line_num = int(key.replace("text_line_", "")) idx = line_num - 1 if 0 <= idx < len(line_level_results): line_result = line_level_results[idx] if hasattr(line_result, "text"): line_result.text = new_text line_result.conf = round(avg_conf, 0) line_result.model = "Bedrock VLM" elif isinstance(line_result, dict): line_result["text"] = new_text line_result["conf"] = round(avg_conf, 0) line_result["model"] = "Bedrock VLM" except (ValueError, TypeError): pass print( f" Re-OCR'd line (Textract conf: {line_conf:.0f}) -> '{new_text}' (conf: {avg_conf:.0f}) [Bedrock VLM]" ) if len(updates) == 0: page_no = page_line_level_ocr_results_with_words.get("page", "?") print( f" Hybrid Textract + Bedrock VLM: no lines on page {page_no} met the low-confidence criteria (threshold={confidence_threshold:.0f}); no Bedrock VLM inference run for this page." 
) return ( page_line_level_ocr_results_with_words, hybrid_vlm_input_tokens, hybrid_vlm_output_tokens, model_choice or "", ) def _inference_server_ocr_predict( image: Image.Image, prompt: str = model_default_prompt, max_retries: int = 5, model_name: str = None, image_hybrid_line_prepared: bool = False, ) -> Dict[str, Any]: """ Inference-server OCR prediction function that mimics PaddleOCR's interface. Calls an external inference-server API instead of a local model. Args: image: PIL Image to process prompt: Text prompt for the VLM max_retries: Maximum number of retry attempts for API calls (default: 5) model_name: Name of the inference-server model to use image_hybrid_line_prepared: If True, ``image`` was already processed with ``_prepare_hybrid_line_crop_for_vlm`` (hybrid Paddle + inference-server path). Returns: Dictionary in PaddleOCR format with 'rec_texts' and 'rec_scores' Raises: Exception: If all retry attempts fail after max_retries attempts """ try: # Validate image exists and is not None if image is None: print("Inference-server OCR error: Image is None") return {"rec_texts": [], "rec_scores": []} # Validate image has valid size (at least 10x10 pixels) try: width, height = image.size if width < 10 or height < 10: print( f"Inference-server OCR error: Image is too small ({width}x{height} pixels). Minimum size is 10x10." 
) return {"rec_texts": [], "rec_scores": []} except Exception as size_error: print(f"Inference-server OCR error: Could not get image size: {size_error}") return {"rec_texts": [], "rec_scores": []} # Ensure image is in RGB mode (convert if needed) try: if image.mode != "RGB": image = image.convert("RGB") width, height = image.size except Exception as convert_error: print( f"Inference-server OCR error: Could not convert image to RGB: {convert_error}" ) return {"rec_texts": [], "rec_scores": []} if not image_hybrid_line_prepared: # Check and resize image if it exceeds maximum size or DPI limits try: image = _prepare_image_for_vlm(image, hybrid_vlm=True) width, height = image.size except Exception as prep_error: print( f"Inference-server OCR error: Could not prepare image for VLM: {prep_error}" ) return {"rec_texts": [], "rec_scores": []} # Use the inference-server API to extract text with retry logic extracted_text = None for attempt in range(1, max_retries + 1): try: # Determine model_name: use provided parameter, then DEFAULT_INFERENCE_SERVER_VLM_MODEL, then INFERENCE_SERVER_MODEL_NAME final_model_name = model_name if final_model_name is None or final_model_name == "": final_model_name = ( DEFAULT_INFERENCE_SERVER_VLM_MODEL if DEFAULT_INFERENCE_SERVER_VLM_MODEL else None ) if final_model_name is None or final_model_name == "": final_model_name = ( INFERENCE_SERVER_MODEL_NAME if INFERENCE_SERVER_MODEL_NAME else None ) extracted_text, _vlm_input_tokens, _vlm_output_tokens, _, _ = ( _call_inference_server_vlm_api( image=image, prompt=prompt, model_name=final_model_name, max_new_tokens=HYBRID_OCR_MAX_NEW_TOKENS, temperature=None, top_p=None, top_k=None, repetition_penalty=None, seed=None, do_sample=model_default_do_sample, min_p=None, presence_penalty=None, use_llama_swap=USE_LLAMA_SWAP, apply_aspect_ratio_padding=not image_hybrid_line_prepared, ) ) # If we get here, the API call succeeded break except Exception as api_error: print( f"Inference-server OCR retry attempt 
{attempt}/{max_retries} failed: {api_error}" ) if attempt == max_retries: # All retries exhausted, raise the exception raise Exception( f"Inference-server OCR failed after {max_retries} attempts. Last error: {str(api_error)}" ) from api_error # Continue to next retry attempt # Check if extracted_text is None or empty if extracted_text is None: return {"rec_texts": [], "rec_scores": []} if not isinstance(extracted_text, str): return {"rec_texts": [], "rec_scores": []} if extracted_text.strip(): # Try to parse VLM/LLM response for {"text": "...", "confidence": ...} or "conf" # If multiple dicts (e.g. one per word), combine text and average confidence parsed = _extract_and_combine_text_dicts_from_vlm_response(extracted_text) if parsed is not None and isinstance(parsed.get("text"), str): text_content = parsed.get("text") # Prefer "confidence", fallback to "conf" (VLM may use either) confidence = parsed.get("confidence", parsed.get("conf")) try: score = float(confidence) if confidence is not None else 1.0 # Normalise: if > 1 assume percentage (0–100), else 0–1 if score > 1.0: score = score / 100.0 score = max(0.0, min(1.0, score)) except (TypeError, ValueError): score = 1.0 cleaned_text = re.sub(r"[\r\n]+", " ", text_content).strip() else: # No parseable dict: use raw text and default confidence cleaned_text = re.sub(r"[\r\n]+", " ", extracted_text).strip() score = 1.0 # Split into words for compatibility with PaddleOCR format words = cleaned_text.split() # If text has more words than the line-level limit, assume something went wrong and skip it if len(words) > HYBRID_OCR_MAX_WORDS: print( f"Inference-server OCR warning: Extracted text has {len(words)} words, which exceeds the {HYBRID_OCR_MAX_WORDS} word limit. Skipping." 
) return {"rec_texts": [], "rec_scores": []} # Create PaddleOCR-compatible result; use VLM/LLM confidence when available result = { "rec_texts": words, "rec_scores": [score] * len(words), } return result else: return {"rec_texts": [], "rec_scores": []} except Exception as e: # Re-raise if it's the retry exhaustion exception if "failed after" in str(e) and "attempts" in str(e): raise # Otherwise, handle other exceptions as before print(f"Inference-server OCR error: {e}") import traceback print(f"Inference-server OCR error traceback: {traceback.format_exc()}") return {"rec_texts": [], "rec_scores": []} def _bedrock_vlm_ocr_predict( image: Image.Image, prompt: str = model_default_prompt, model_choice: str = None, bedrock_runtime=None, max_retries: int = 10, return_prompt_and_response: bool = False, ) -> Dict[str, Any]: """ Bedrock VLM OCR prediction function that mimics PaddleOCR's interface. Args: image: PIL Image to process prompt: Text prompt for the VLM model_choice: Bedrock model ID bedrock_runtime: boto3 Bedrock runtime client max_retries: Maximum number of retry attempts for API calls (default: 5) return_prompt_and_response: If True, add "prompt" and "raw_response" to the returned dict for logging (raw_response is the raw API text before parsing). Returns: Dictionary in PaddleOCR format with 'rec_texts' and 'rec_scores' (and optionally 'prompt', 'raw_response' when return_prompt_and_response is True). 
""" extracted_text = None vlm_input_tokens_used = 0 vlm_output_tokens_used = 0 def _add_prompt_response(d: Dict[str, Any]) -> Dict[str, Any]: if return_prompt_and_response: d["prompt"] = prompt d["raw_response"] = extracted_text d["vlm_input_tokens"] = vlm_input_tokens_used d["vlm_output_tokens"] = vlm_output_tokens_used return d try: # Validate image exists and is not None if image is None: print("Bedrock VLM OCR error: Image is None") return _add_prompt_response({"rec_texts": [], "rec_scores": []}) # Validate image has valid size (at least 10x10 pixels) try: width, height = image.size if width < 10 or height < 10: print( f"Bedrock VLM OCR error: Image is too small ({width}x{height} pixels). Minimum size is 10x10." ) return _add_prompt_response({"rec_texts": [], "rec_scores": []}) except Exception as size_error: print(f"Bedrock VLM OCR error: Could not get image size: {size_error}") return _add_prompt_response({"rec_texts": [], "rec_scores": []}) # Ensure image is in RGB mode (convert if needed) try: if image.mode != "RGB": image = image.convert("RGB") width, height = image.size except Exception as convert_error: print( f"Bedrock VLM OCR error: Could not convert image to RGB: {convert_error}" ) return _add_prompt_response({"rec_texts": [], "rec_scores": []}) # Check and resize image if it exceeds maximum size or DPI limits # Skip resizing for AWS Bedrock VLM OCR try: from tools.config import BEDROCK_VLM_TEXT_EXTRACT_OPTION image = _prepare_image_for_vlm( image, ocr_method=BEDROCK_VLM_TEXT_EXTRACT_OPTION, hybrid_vlm=True, ) width, height = image.size except Exception as prep_error: print( f"Bedrock VLM OCR error: Could not prepare image for VLM: {prep_error}" ) return _add_prompt_response({"rec_texts": [], "rec_scores": []}) # Use the Bedrock API to extract text with retry logic for attempt in range(1, max_retries + 1): try: extracted_text, _vlm_input_tokens, _vlm_output_tokens, _, _ = ( _call_bedrock_vlm_api( image=image, prompt=prompt, model_choice=model_choice, 
bedrock_runtime=bedrock_runtime, max_new_tokens=MAX_NEW_TOKENS, temperature=model_default_temperature, top_p=model_default_top_p, ) ) vlm_input_tokens_used = _vlm_input_tokens vlm_output_tokens_used = _vlm_output_tokens # If we get here, the API call succeeded break except Exception as api_error: print( f"Bedrock VLM OCR retry attempt {attempt}/{max_retries} failed: {api_error}" ) if attempt == max_retries: raise Exception( f"Bedrock VLM OCR failed after {max_retries} attempts. Last error: {str(api_error)}" ) from api_error # Check if extracted_text is None or empty if extracted_text is None: return _add_prompt_response({"rec_texts": [], "rec_scores": []}) if not isinstance(extracted_text, str): return _add_prompt_response({"rec_texts": [], "rec_scores": []}) if extracted_text.strip(): # If Bedrock returns multiple dicts (e.g. one per word) {"text": "...", "confidence": ...}, combine and average (same as local VLM / inference server) parsed = _extract_and_combine_text_dicts_from_vlm_response(extracted_text) if parsed is not None and isinstance(parsed.get("text"), str): text_content = parsed.get("text", "").strip() conf = parsed.get("confidence", parsed.get("conf")) try: score = float(conf) if conf is not None else 1.0 if score > 1.0: score = score / 100.0 score = max(0.0, min(1.0, score)) except (TypeError, ValueError): score = 1.0 if text_content: words = re.sub(r"[\r\n]+", " ", text_content).strip().split() if len(words) <= HYBRID_OCR_MAX_WORDS: return _add_prompt_response( { "rec_texts": words, "rec_scores": [score] * len(words), } ) # Reject parsed result with empty text or zero confidence (e.g. {"text": "", "conf": 0.0}) if not text_content or score <= 0.0: return _add_prompt_response({"rec_texts": [], "rec_scores": []}) # Try to parse VLM JSON response [{'bbox': [...], 'text': '...', 'conf': 0-1}, ...] 
lines_data = None text = extracted_text.strip() try: text = _fix_malformed_bbox_in_json_string(text) except Exception: pass try: lines_data = json.loads(text) except json.JSONDecodeError: pass if lines_data is None: json_match = re.search(r"```(?:json)?\s*(\[.*?\])", text, re.DOTALL) if json_match: try: lines_data = json.loads(json_match.group(1)) except json.JSONDecodeError: pass if lines_data is None and "[" in text: start_idx = text.find("[") bracket_count = 0 end_idx = start_idx for i in range(start_idx, len(text)): if text[i] == "[": bracket_count += 1 elif text[i] == "]": bracket_count -= 1 if bracket_count == 0: end_idx = i break if end_idx > start_idx: try: lines_data = json.loads(text[start_idx : end_idx + 1]) except json.JSONDecodeError: pass if lines_data is None: try: python_data = ast.literal_eval(text) if isinstance(python_data, list): lines_data = python_data except Exception: pass if isinstance(lines_data, list) and len(lines_data) > 0: rec_texts = [] rec_scores = [] for line_item in lines_data: if not isinstance(line_item, dict): continue line_text = line_item.get("text_content") or line_item.get( "text", "" ) if line_text is None: line_text = "" line_text = str(line_text).strip() if not line_text: continue conf = line_item.get("confidence", line_item.get("conf")) try: score = float(conf) if conf is not None else 1.0 if score > 1.0: score = score / 100.0 score = max(0.0, min(1.0, score)) except (TypeError, ValueError): score = 1.0 rec_texts.append(line_text) rec_scores.append(score) if rec_texts: return _add_prompt_response( {"rec_texts": rec_texts, "rec_scores": rec_scores} ) # Fallback: treat response as plain text (e.g. different prompt) cleaned_text = re.sub(r"[\r\n]+", " ", extracted_text) cleaned_text = cleaned_text.strip() words = cleaned_text.split() if len(words) > HYBRID_OCR_MAX_WORDS: print( f"Bedrock VLM OCR warning: Extracted text has {len(words)} words, which exceeds the {HYBRID_OCR_MAX_WORDS} word limit. Skipping." 
) return _add_prompt_response({"rec_texts": [], "rec_scores": []}) result = { "rec_texts": words, "rec_scores": [1.0] * len(words), } return _add_prompt_response(result) else: return _add_prompt_response({"rec_texts": [], "rec_scores": []}) except Exception as e: print(f"Bedrock VLM OCR error: {e}") import traceback print(f"Bedrock VLM OCR error traceback: {traceback.format_exc()}") return _add_prompt_response({"rec_texts": [], "rec_scores": []}) def _gemini_vlm_ocr_predict( image: Image.Image, prompt: str = model_default_prompt, model_choice: str = None, client=None, config=None, max_retries: int = 5, ) -> Dict[str, Any]: """ Gemini VLM OCR prediction function that mimics PaddleOCR's interface. Args: image: PIL Image to process prompt: Text prompt for the VLM model_choice: Gemini model name client: Gemini ai.Client instance config: Gemini types.GenerateContentConfig instance max_retries: Maximum number of retry attempts for API calls (default: 5) Returns: Dictionary in PaddleOCR format with 'rec_texts' and 'rec_scores' """ try: # Validate image exists and is not None if image is None: print("Gemini VLM OCR error: Image is None") return {"rec_texts": [], "rec_scores": []} # Validate image has valid size (at least 10x10 pixels) try: width, height = image.size if width < 10 or height < 10: print( f"Gemini VLM OCR error: Image is too small ({width}x{height} pixels). Minimum size is 10x10." 
) return {"rec_texts": [], "rec_scores": []} except Exception as size_error: print(f"Gemini VLM OCR error: Could not get image size: {size_error}") return {"rec_texts": [], "rec_scores": []} # Ensure image is in RGB mode (convert if needed) try: if image.mode != "RGB": image = image.convert("RGB") width, height = image.size except Exception as convert_error: print( f"Gemini VLM OCR error: Could not convert image to RGB: {convert_error}" ) return {"rec_texts": [], "rec_scores": []} # Check and resize image if it exceeds maximum size or DPI limits try: image = _prepare_image_for_vlm(image, hybrid_vlm=True) width, height = image.size except Exception as prep_error: print( f"Gemini VLM OCR error: Could not prepare image for VLM: {prep_error}" ) return {"rec_texts": [], "rec_scores": []} # Use the Gemini API to extract text with retry logic extracted_text = None for attempt in range(1, max_retries + 1): try: extracted_text, _, _ = _call_gemini_vlm_api( image=image, prompt=prompt, client=client, config=config, model_choice=model_choice, max_new_tokens=MAX_NEW_TOKENS, temperature=model_default_temperature, ) # If we get here, the API call succeeded break except Exception as api_error: print( f"Gemini VLM OCR retry attempt {attempt}/{max_retries} failed: {api_error}" ) if attempt == max_retries: raise Exception( f"Gemini VLM OCR failed after {max_retries} attempts. Last error: {str(api_error)}" ) from api_error # Check if extracted_text is None or empty if extracted_text is None: return {"rec_texts": [], "rec_scores": []} if not isinstance(extracted_text, str): return {"rec_texts": [], "rec_scores": []} if extracted_text.strip(): # Try to parse VLM JSON response [{'bbox': [...], 'text': '...', 'conf': 0-1}, ...] 
lines_data = None text = extracted_text.strip() try: text = _fix_malformed_bbox_in_json_string(text) except Exception: pass try: lines_data = json.loads(text) except json.JSONDecodeError: pass if lines_data is None: json_match = re.search(r"```(?:json)?\s*(\[.*?\])", text, re.DOTALL) if json_match: try: lines_data = json.loads(json_match.group(1)) except json.JSONDecodeError: pass if lines_data is None and "[" in text: start_idx = text.find("[") bracket_count = 0 end_idx = start_idx for i in range(start_idx, len(text)): if text[i] == "[": bracket_count += 1 elif text[i] == "]": bracket_count -= 1 if bracket_count == 0: end_idx = i break if end_idx > start_idx: try: lines_data = json.loads(text[start_idx : end_idx + 1]) except json.JSONDecodeError: pass if lines_data is None: try: python_data = ast.literal_eval(text) if isinstance(python_data, list): lines_data = python_data except Exception: pass if isinstance(lines_data, list) and len(lines_data) > 0: rec_texts = [] rec_scores = [] for line_item in lines_data: if not isinstance(line_item, dict): continue line_text = line_item.get("text_content") or line_item.get( "text", "" ) if line_text is None: line_text = "" line_text = str(line_text).strip() if not line_text: continue conf = line_item.get("confidence", line_item.get("conf")) try: score = float(conf) if conf is not None else 1.0 if score > 1.0: score = score / 100.0 score = max(0.0, min(1.0, score)) except (TypeError, ValueError): score = 1.0 rec_texts.append(line_text) rec_scores.append(score) if rec_texts: return {"rec_texts": rec_texts, "rec_scores": rec_scores} # Fallback: treat response as plain text (e.g. different prompt) cleaned_text = re.sub(r"[\r\n]+", " ", extracted_text) cleaned_text = cleaned_text.strip() words = cleaned_text.split() if len(words) > HYBRID_OCR_MAX_WORDS: print( f"Gemini VLM OCR warning: Extracted text has {len(words)} words, which exceeds the {HYBRID_OCR_MAX_WORDS} word limit. Skipping." 
) return {"rec_texts": [], "rec_scores": []} result = { "rec_texts": words, "rec_scores": [1.0] * len(words), } return result else: return {"rec_texts": [], "rec_scores": []} except Exception as e: print(f"Gemini VLM OCR error: {e}") import traceback print(f"Gemini VLM OCR error traceback: {traceback.format_exc()}") return {"rec_texts": [], "rec_scores": []} def _azure_openai_vlm_ocr_predict( image: Image.Image, prompt: str = model_default_prompt, model_choice: str = None, client=None, max_retries: int = 5, ) -> Dict[str, Any]: """ Azure/OpenAI VLM OCR prediction function that mimics PaddleOCR's interface. Args: image: PIL Image to process prompt: Text prompt for the VLM model_choice: Model name (e.g., "gpt-4o", "gpt-4-vision-preview") client: OpenAI client instance max_retries: Maximum number of retry attempts for API calls (default: 5) Returns: Dictionary in PaddleOCR format with 'rec_texts' and 'rec_scores' """ try: # Validate image exists and is not None if image is None: print("Azure/OpenAI VLM OCR error: Image is None") return {"rec_texts": [], "rec_scores": []} # Validate image has valid size (at least 10x10 pixels) try: width, height = image.size if width < 10 or height < 10: print( f"Azure/OpenAI VLM OCR error: Image is too small ({width}x{height} pixels). Minimum size is 10x10." 
) return {"rec_texts": [], "rec_scores": []} except Exception as size_error: print(f"Azure/OpenAI VLM OCR error: Could not get image size: {size_error}") return {"rec_texts": [], "rec_scores": []} # Ensure image is in RGB mode (convert if needed) try: if image.mode != "RGB": image = image.convert("RGB") width, height = image.size except Exception as convert_error: print( f"Azure/OpenAI VLM OCR error: Could not convert image to RGB: {convert_error}" ) return {"rec_texts": [], "rec_scores": []} # Check and resize image if it exceeds maximum size or DPI limits try: image = _prepare_image_for_vlm(image, hybrid_vlm=True) width, height = image.size except Exception as prep_error: print( f"Azure/OpenAI VLM OCR error: Could not prepare image for VLM: {prep_error}" ) return {"rec_texts": [], "rec_scores": []} # Use the Azure/OpenAI API to extract text with retry logic extracted_text = None for attempt in range(1, max_retries + 1): try: extracted_text, _, _ = _call_azure_openai_vlm_api( image=image, prompt=prompt, client=client, model_choice=model_choice, max_new_tokens=MAX_NEW_TOKENS, temperature=model_default_temperature, ) # If we get here, the API call succeeded break except Exception as api_error: print( f"Azure/OpenAI VLM OCR retry attempt {attempt}/{max_retries} failed: {api_error}" ) if attempt == max_retries: raise Exception( f"Azure/OpenAI VLM OCR failed after {max_retries} attempts. Last error: {str(api_error)}" ) from api_error # Check if extracted_text is None or empty if extracted_text is None: return {"rec_texts": [], "rec_scores": []} if not isinstance(extracted_text, str): return {"rec_texts": [], "rec_scores": []} if extracted_text.strip(): # Try to parse VLM JSON response [{'bbox': [...], 'text': '...', 'conf': 0-1}, ...] 
lines_data = None text = extracted_text.strip() try: text = _fix_malformed_bbox_in_json_string(text) except Exception: pass try: lines_data = json.loads(text) except json.JSONDecodeError: pass if lines_data is None: json_match = re.search(r"```(?:json)?\s*(\[.*?\])", text, re.DOTALL) if json_match: try: lines_data = json.loads(json_match.group(1)) except json.JSONDecodeError: pass if lines_data is None and "[" in text: start_idx = text.find("[") bracket_count = 0 end_idx = start_idx for i in range(start_idx, len(text)): if text[i] == "[": bracket_count += 1 elif text[i] == "]": bracket_count -= 1 if bracket_count == 0: end_idx = i break if end_idx > start_idx: try: lines_data = json.loads(text[start_idx : end_idx + 1]) except json.JSONDecodeError: pass if lines_data is None: try: python_data = ast.literal_eval(text) if isinstance(python_data, list): lines_data = python_data except Exception: pass if isinstance(lines_data, list) and len(lines_data) > 0: rec_texts = [] rec_scores = [] for line_item in lines_data: if not isinstance(line_item, dict): continue line_text = line_item.get("text_content") or line_item.get( "text", "" ) if line_text is None: line_text = "" line_text = str(line_text).strip() if not line_text: continue conf = line_item.get("confidence", line_item.get("conf")) try: score = float(conf) if conf is not None else 1.0 if score > 1.0: score = score / 100.0 score = max(0.0, min(1.0, score)) except (TypeError, ValueError): score = 1.0 rec_texts.append(line_text) rec_scores.append(score) if rec_texts: return {"rec_texts": rec_texts, "rec_scores": rec_scores} # Fallback: treat response as plain text (e.g. different prompt) cleaned_text = re.sub(r"[\r\n]+", " ", extracted_text) cleaned_text = cleaned_text.strip() words = cleaned_text.split() if len(words) > HYBRID_OCR_MAX_WORDS: print( f"Azure/OpenAI VLM OCR warning: Extracted text has {len(words)} words, which exceeds the {HYBRID_OCR_MAX_WORDS} word limit. Skipping." 
) return {"rec_texts": [], "rec_scores": []} result = { "rec_texts": words, "rec_scores": [1.0] * len(words), } return result else: return {"rec_texts": [], "rec_scores": []} except Exception as e: print(f"Azure/OpenAI VLM OCR error: {e}") import traceback print(f"Azure/OpenAI VLM OCR error traceback: {traceback.format_exc()}") return {"rec_texts": [], "rec_scores": []} def plot_text_bounding_boxes( image: Image.Image, bounding_boxes: List[Dict], image_name: str = "initial_vlm_output_bounding_boxes.png", image_folder: str = "inference_server_visualisations", output_folder: str = OUTPUT_FOLDER, task_type: str = "ocr", ): """ Plots bounding boxes on an image with markers for each a name, using PIL, normalised coordinates, and different colors. Args: image: The PIL Image object. bounding_boxes: A list of bounding boxes containing the name of the object and their positions in normalized [y1 x1 y2 x2] format. image_name: The name of the image for debugging. image_folder: The folder name (relative to output_folder) where the image will be saved. output_folder: The folder where the image will be saved. task_type: The type of task the bounding boxes are for ("ocr", "person", "signature"). 
""" # Load the image img = image width, height = img.size # Create a drawing object draw = ImageDraw.Draw(img) # Parsing out the markdown fencing bbox_list = _parse_vlm_bbox_dict_list(bounding_boxes) font = ImageFont.load_default() # Iterate over the bounding boxes for i, bbox_dict in enumerate(bbox_list): color = "green" # Extract the bounding box coordinates (preserve the original dict for text extraction) if "bb" in bbox_dict: bbox_coords = bbox_dict["bb"] elif "bbox" in bbox_dict: bbox_coords = bbox_dict["bbox"] elif "bbox_2d" in bbox_dict: bbox_coords = bbox_dict["bbox_2d"] else: # Skip if no valid bbox found continue # Ensure bbox_coords is a list with 4 elements if not isinstance(bbox_coords, list) or len(bbox_coords) != 4: # Try to fix malformed bbox fixed_bbox = _fix_malformed_bbox(bbox_coords) if fixed_bbox is not None: bbox_coords = fixed_bbox else: continue # Convert normalized coordinates to absolute coordinates abs_y1 = int(bbox_coords[1] / 999 * height) abs_x1 = int(bbox_coords[0] / 999 * width) abs_y2 = int(bbox_coords[3] / 999 * height) abs_x2 = int(bbox_coords[2] / 999 * width) if abs_x1 > abs_x2: abs_x1, abs_x2 = abs_x2, abs_x1 if abs_y1 > abs_y2: abs_y1, abs_y2 = abs_y2, abs_y1 # Draw the bounding box draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=1) # Draw the text - extract from the original dictionary, not the coordinates text_to_draw = "No text" if "text" in bbox_dict: text_to_draw = bbox_dict["text"] elif "text_content" in bbox_dict: text_to_draw = bbox_dict["text_content"] draw.text((abs_x1, abs_y2), text_to_draw, fill=color, font=font) try: debug_dir = os.path.join( output_folder, image_folder, ) # Security: Validate that the constructed path is safe normalized_debug_dir = os.path.normpath(os.path.abspath(debug_dir)) if not validate_folder_containment(normalized_debug_dir, OUTPUT_FOLDER): raise ValueError( f"Unsafe image folder path: {debug_dir}. 
Must be contained within {OUTPUT_FOLDER}" ) os.makedirs(normalized_debug_dir, exist_ok=True) # Increment the number at the end of image_name before .png # This converts zero-indexed input to one-indexed output incremented_image_name = image_name if image_name.endswith(".png"): # Find the number pattern at the end before .png # Matches patterns like: _0.png, _00.png, 0.png, 00.png, etc. pattern = r"(\d+)(\.png)$" match = re.search(pattern, image_name) if match: number_str = match.group(1) number = int(number_str) incremented_number = number + 1 # Preserve the same number of digits (padding with zeros if needed) incremented_str = str(incremented_number).zfill(len(number_str)) incremented_image_name = re.sub( pattern, lambda m: incremented_str + m.group(2), image_name ) image_name_safe = safe_sanitize_text(incremented_image_name) image_name_shortened = image_name_safe[:50] task_type_suffix = f"_{task_type}" if task_type != "ocr" else "" filename = ( f"{image_name_shortened}_initial_bounding_box_output{task_type_suffix}.png" ) filepath = os.path.join(normalized_debug_dir, filename) _save_image_with_config_dpi(img, filepath) except Exception as e: print(f"Error saving image with bounding boxes: {e}") def parse_json(json_output): # Parsing out the markdown fencing and Qwen thinking tags if not isinstance(json_output, str): return json_output json_output = strip_vlm_thinking_tags(json_output) lines = json_output.splitlines() for i, line in enumerate(lines): if line == "```json": json_output = "\n".join( lines[i + 1 :] ) # Remove everything before "```json" json_output = json_output.split("```")[ 0 ] # Remove everything after the closing "```" break # Exit the loop once "```json" is found return json_output def _parse_vlm_bbox_dict_list(bounding_boxes: str) -> List[Dict]: """Parse a VLM bbox JSON/list response, ignoring thinking tags and extra prose.""" if not bounding_boxes or not isinstance(bounding_boxes, str): return [] cleaned = parse_json(bounding_boxes) cleaned = 
_preprocess_vlm_ocr_json_string(cleaned) if not cleaned: return [] for candidate in (cleaned, extract_balanced_json_array(cleaned)): if not candidate: continue try: data = json.loads(candidate) if isinstance(data, list): return data if isinstance(data, dict): return [data] except json.JSONDecodeError: pass try: data = ast.literal_eval(candidate) if isinstance(data, list): return data if isinstance(data, dict): return [data] except Exception: pass return [] def _fix_malformed_bbox_in_json_string(json_string): """ Fixes malformed bounding box values in a JSON string before parsing. Handles cases like: - "bb": "779, 767, 874, 789], "text" (missing opening bracket, missing closing quote) - "bb": "[779, 767, 874, 789]" (stringified array) - "bb": "779, 767, 874, 789" (no brackets) Args: json_string: The raw JSON string that may contain malformed bbox values Returns: str: The JSON string with malformed bbox values fixed """ import re # Pattern 1: Match malformed bbox like: "bb": "779, 767, 874, 789], "text" # The issue: missing opening bracket, missing closing quote after the bracket # Matches: "bb": " followed by numbers, ], then , " pattern1 = ( r'("(?:bb|bbox|bbox_2d)"\s*:\s*)"(\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+)\]\s*,\s*"' ) def fix_bbox_match1(match): key_part = match.group(1) # "bb": " bbox_str = match.group(2) # "779, 767, 874, 789" # Format as proper JSON array (no quotes around it) fixed_bbox = "[" + bbox_str.strip() + "]" # Return the fixed version: "bb": [779, 767, 874, 789], " return key_part + fixed_bbox + ', "' # Pattern 2: Match malformed bbox like: "bb": "779, 767, 874, 789]" # Missing opening bracket, but has closing quote pattern2 = r'("(?:bb|bbox|bbox_2d)"\s*:\s*)"(\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+)\]"' def fix_bbox_match2(match): key_part = match.group(1) bbox_str = match.group(2) fixed_bbox = "[" + bbox_str.strip() + "]" return key_part + fixed_bbox + '"' # Pattern 3: Match malformed bbox like: "bb": "779, 767, 874, 789] (end of object, no quote) 
pattern3 = ( r'("(?:bb|bbox|bbox_2d)"\s*:\s*)"(\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+)\]\s*\}' ) def fix_bbox_match3(match): key_part = match.group(1) bbox_str = match.group(2) fixed_bbox = "[" + bbox_str.strip() + "]" return key_part + fixed_bbox + "}" # Apply the fixes in order fixed_json = re.sub(pattern1, fix_bbox_match1, json_string) fixed_json = re.sub(pattern2, fix_bbox_match2, fixed_json) fixed_json = re.sub(pattern3, fix_bbox_match3, fixed_json) return fixed_json def _repair_vlm_json_stray_coordinate_strings( json_string: str, default_text: str = "[UNKNOWN]", default_conf: float = 0.9, ) -> str: """ Fix invalid JSON where the VLM repeats bbox coords as a lone quoted string instead of \"text\" / \"conf\", e.g.: {\"bbox\": [870, 290, 913, 316], \"870, 290, 913, 316\"} """ if not json_string or not default_text: return json_string text_lit = json.dumps(str(default_text)) conf_lit = json.dumps(float(default_conf)) # After a closing `]` (end of bbox array), comma, quoted string of four ints only pattern = re.compile(r'(\])\s*,\s*"(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)"') def _repl(m): return f'{m.group(1)}, "text": {text_lit}, "conf": {conf_lit}' prev = None s = json_string while prev != s: prev = s s = pattern.sub(_repl, s) return s def _repair_vlm_json_missing_text_key_after_bbox(json_string: str) -> str: """ Fix invalid JSON where the VLM omits the \"text\" key before the OCR string. Broken (not valid JSON — bare string where a key:value pair is required): {\"bbox\": [42, 70, 223, 115], \"Method\", \"conf\": 0.882} Fixed: {\"bbox\": [42, 70, 223, 115], \"text\": \"Method\", \"conf\": 0.882} Common Bedrock/VLM mistake after a correct first line or two. """ if not json_string or not isinstance(json_string, str): return json_string # After closing bbox array `]`, comma, a quoted string, comma, then "conf": # Insert "text": before that string (only when the string is not already "text"). 
pattern = re.compile( r'\]\s*,\s*"((?:[^"\\]|\\.)*)"\s*,\s*"conf"\s*:', ) def _repl(m) -> str: inner = m.group(1) text_lit = json.dumps(inner) return f'], "text": {text_lit}, "conf":' prev = None s = json_string while prev != s: prev = s s = pattern.sub(_repl, s) return s def _preprocess_vlm_ocr_json_string( raw: Optional[str], implied_label: Optional[str] = None, ) -> str: """Chain bbox fixes and stray-coordinate repair before json.loads.""" if not raw or not isinstance(raw, str): return "" s = strip_vlm_thinking_tags(raw.strip()) s = _fix_malformed_bbox_in_json_string(s) label = implied_label if implied_label else "[UNKNOWN]" s = _repair_vlm_json_stray_coordinate_strings(s, default_text=label) s = _repair_vlm_json_missing_text_key_after_bbox(s) s = _repair_vlm_json_common_quote_issues(s) return s CUSTOM_VLM_CANONICAL_LABELS = frozenset({"[FACE]", "[SIGNATURE]"}) def _get_vlm_item_conf_field(item: dict): """Best-effort confidence field from common or fuzzy VLM keys.""" if not isinstance(item, dict): return None for k in ("confidence", "conf", "confidence_level", "confidence_score"): if item.get(k) is not None: return item.get(k) # Fuzzy match: any key containing "conf" (covers confidence_level, conf_score, etc.) for key, val in item.items(): if val is None: continue lk = str(key).lower() if "conf" in lk: return val return None def _extract_vlm_line_text(item: dict) -> str: """ Best-effort string for general OCR line items when the model uses alternate keys. Order avoids grabbing non-OCR fields like 'label' used for classes. """ for key in ("text", "text_content", "content", "transcription"): val = item.get(key) if val is None: continue if isinstance(val, str): s = val.strip() if s: return s else: s = str(val).strip() if s: return s # Fuzzy match: accept any key that contains "text", but avoid obvious non-text fields. 
for key, val in item.items(): if val is None: continue lk = str(key).lower() if "text" not in lk: continue if lk in ("context", "texture", "text_direction"): continue if "label" in lk: continue if isinstance(val, str): s = val.strip() if s: return s else: s = str(val).strip() if s: return s return "" def _get_vlm_item_bbox_field(item: dict): """Raw bbox value from common VLM keys (may be list or malformed string).""" if not isinstance(item, dict): return None if item.get("bbox_2d") is not None: return item.get("bbox_2d") if item.get("bbox") is not None: return item.get("bbox") if item.get("bb") is not None: return item.get("bb") # Fuzzy match: any key containing bbox/bounding_box. for key, val in item.items(): if val is None: continue lk = str(key).lower() if "bbox" in lk or "bounding_box" in lk or "boundingbox" in lk: return val return None def _normalize_single_line_text_dict(obj: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ Map content/conf aliases to text/confidence for hybrid / single-line VLM dicts. """ if not isinstance(obj, dict): return None text = obj.get("text") if text is None or (isinstance(text, str) and not text.strip()): for alt in ("content", "transcription", "text_content"): v = obj.get(alt) if v is not None: if isinstance(v, str) and v.strip(): text = v.strip() break if not isinstance(v, str) and str(v).strip(): text = str(v).strip() break if text is None: return None if not isinstance(text, str): text = str(text) conf = _get_vlm_item_conf_field(obj) out = {"text": text} if conf is not None: out["confidence"] = conf return out def _fix_malformed_bbox(bbox): """ Attempts to fix malformed bounding box values. 
Handles cases where bbox is: - A string like "779, 767, 874, 789]" (missing opening bracket) - A string like "[779, 767, 874, 789]" (should be parsed) - A string like "779, 767, 874, 789" (no brackets at all) - Already a valid list (returns as-is) Args: bbox: The bounding box value (could be list, string, or other) Returns: list: A list of 4 numbers [x1, y1, x2, y2], or None if parsing fails """ # If it's already a valid list, return it if isinstance(bbox, list) and len(bbox) == 4: return bbox # If it's not a string, we can't fix it if not isinstance(bbox, str): return None try: # Remove any leading/trailing whitespace bbox_str = bbox.strip() # Remove quotes if present if bbox_str.startswith('"') and bbox_str.endswith('"'): bbox_str = bbox_str[1:-1] elif bbox_str.startswith("'") and bbox_str.endswith("'"): bbox_str = bbox_str[1:-1] # Try to extract numbers from various formats # Pattern 1: "779, 767, 874, 789]" (missing opening bracket) # Pattern 2: "[779, 767, 874, 789]" (has brackets) # Pattern 3: "779, 767, 874, 789" (no brackets) # Remove brackets if present if bbox_str.startswith("["): bbox_str = bbox_str[1:] if bbox_str.endswith("]"): bbox_str = bbox_str[:-1] # Split by comma and extract numbers parts = [part.strip() for part in bbox_str.split(",")] if len(parts) != 4: return None # Convert each part to float coords = [] for part in parts: try: coords.append(float(part)) except (ValueError, TypeError): return None return coords except Exception: return None def _parse_vlm_line_item_to_geometry( line_item: dict, implied_label: Optional[str], warn_prefix: str, ) -> Optional[Tuple[str, List[float], float]]: """ Parse one VLM JSON line object into text, xyxy floats, and raw confidence. For person/signature passes (implied_label in CUSTOM_VLM_CANONICAL_LABELS), text is always the canonical label when the bbox is valid; model text keys are ignored. 
""" if not isinstance(line_item, dict): return None canon = None if implied_label and str(implied_label).strip() in CUSTOM_VLM_CANONICAL_LABELS: canon = str(implied_label).strip() raw_bbox = _get_vlm_item_bbox_field(line_item) if raw_bbox is None: raw_bbox = [] fixed_bbox = _fix_malformed_bbox(raw_bbox) if fixed_bbox is not None: bbox = fixed_bbox if not isinstance(raw_bbox, list) or len(raw_bbox) != 4: dbg_txt = canon or _extract_vlm_line_text(line_item) or "?" print( f"{warn_prefix}: Fixed malformed bbox for line '{dbg_txt[:50]}': " f"{raw_bbox} -> {fixed_bbox}" ) elif isinstance(raw_bbox, list) and len(raw_bbox) == 4: bbox = raw_bbox else: dbg_txt = canon or _extract_vlm_line_text(line_item) or "?" print( f"{warn_prefix} warning: Invalid bbox format for line '{dbg_txt[:50]}': {raw_bbox}" ) return None try: x1 = float(bbox[0]) y1 = float(bbox[1]) x2 = float(bbox[2]) y2 = float(bbox[3]) except (ValueError, TypeError): dbg_txt = canon or _extract_vlm_line_text(line_item) or "?" print( f"{warn_prefix} warning: Invalid bbox coordinates for line '{dbg_txt[:50]}': {bbox}" ) return None if x2 <= x1 or y2 <= y1: dbg_txt = canon or _extract_vlm_line_text(line_item) or "?" 
print( f"{warn_prefix} warning: Invalid bbox dimensions for line '{dbg_txt[:50]}': {bbox}" ) return None if canon: text = canon conf_raw = _get_vlm_item_conf_field(line_item) if conf_raw is None: conf_raw = 0.9 else: text = _extract_vlm_line_text(line_item) if not text: return None conf_raw = _get_vlm_item_conf_field(line_item) if conf_raw is None: conf_raw = 100 try: confidence = float(conf_raw) except (TypeError, ValueError): confidence = 0.9 if canon else 100.0 return (text, [x1, y1, x2, y2], confidence) def _vlm_page_ocr_predict( image: Image.Image, image_name: str = "vlm_page_ocr_input_image.png", normalised_coords_range: Optional[int] = 999, output_folder: str = OUTPUT_FOLDER, detect_people_only: bool = False, detect_signatures_only: bool = False, progress: Optional[gr.Progress] = gr.Progress(), page_index_0: Optional[int] = None, ) -> Tuple[Dict[str, List], int, int, str]: """ VLM page-level OCR prediction that returns structured line-level results with bounding boxes. Args: image: PIL Image to process (full page) image_name: Name of the image for debugging normalised_coords_range: If set, bounding boxes are assumed to be in normalized coordinates from 0 to this value (e.g., 999 as used in the full-page VLM prompt). Coordinates will be rescaled to match the processed image size. If None, coordinates are assumed to be in absolute pixel coordinates. output_folder: The folder where output images will be saved Returns: Dictionary with 'text', 'left', 'top', 'width', 'height', 'conf', 'model' keys matching the format expected by perform_ocr """ try: # Validate image exists and is not None if image is None: print("VLM page OCR error: Image is None") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Validate image has valid size (at least 10x10 pixels) try: width, height = image.size if width < 10 or height < 10: print( f"VLM page OCR error: Image is too small ({width}x{height} pixels). Minimum size is 10x10." 
) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } except Exception as size_error: print(f"VLM page OCR error: Could not get image size: {size_error}") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Ensure image is in RGB mode (convert if needed) try: if image.mode != "RGB": image = image.convert("RGB") width, height = image.size except Exception as convert_error: print( f"VLM page OCR error: Could not convert image to RGB: {convert_error}" ) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Check and resize image if it exceeds maximum size or DPI limits scale_x = 1.0 scale_y = 1.0 try: original_width, original_height = image.size processed_image = _prepare_image_for_vlm(image) # Pad so aspect ratio <= VLM_MAX_ASPECT_RATIO for hybrid/long pages processed_image = _pad_image_for_vlm_aspect_ratio(processed_image) processed_width, processed_height = processed_image.size # Use float division to avoid rounding errors scale_x = ( float(original_width) / float(processed_width) if processed_width > 0 else 1.0 ) scale_y = ( float(original_height) / float(processed_height) if processed_height > 0 else 1.0 ) # Debug: print scale factors to verify if scale_x != 1.0 or scale_y != 1.0: print(f"Scale factors: x={scale_x:.6f}, y={scale_y:.6f}") print( f"Original: {original_width}x{original_height}, Processed: {processed_width}x{processed_height}" ) except Exception as prep_error: print(f"VLM page OCR error: Could not prepare image for VLM: {prep_error}") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Save input image for debugging if environment variable is set if SAVE_VLM_INPUT_IMAGES: try: vlm_debug_dir = os.path.join( output_folder, "vlm_visualisations/vlm_input_images", ) os.makedirs(vlm_debug_dir, exist_ok=True) # Increment the number at the end of image_name before .png # 
This converts zero-indexed input to one-indexed output incremented_image_name = image_name if image_name.endswith(".png"): # Find the number pattern at the end before .png # Matches patterns like: _0.png, _00.png, 0.png, 00.png, etc. pattern = r"(\d+)(\.png)$" match = re.search(pattern, image_name) if match: number_str = match.group(1) number = int(number_str) incremented_number = number + 1 # Preserve the same number of digits (padding with zeros if needed) incremented_str = str(incremented_number).zfill(len(number_str)) incremented_image_name = re.sub( pattern, lambda m: incremented_str + m.group(2), image_name ) image_name_safe = safe_sanitize_text(incremented_image_name) image_name_shortened = image_name_safe[:50] filename = f"{image_name_shortened}_vlm_page_input_image.png" filepath = os.path.join(vlm_debug_dir, filename) _save_image_with_config_dpi(processed_image, filepath) # print(f"Saved VLM input image to: {filepath}") except Exception as save_error: print(f"Warning: Could not save VLM input image: {save_error}") # Create prompt that requests structured JSON output with bounding boxes if detect_people_only: # progress(0.5, "Detecting faces on page...") # print("Detecting faces on page...") prompt = full_page_ocr_people_vlm_prompt task_type = "face" elif detect_signatures_only: # progress(0.5, "Detecting signatures on page...") # print("Detecting signatures on page...") prompt = full_page_ocr_signature_vlm_prompt task_type = "signature" else: prompt = full_page_ocr_vlm_prompt task_type = "ocr" # Use the VLM to extract structured text # Pass explicit model_default_* values for consistency with _inference_server_page_ocr_predict extracted_text, vlm_input_tokens, vlm_output_tokens = ( extract_text_from_image_vlm( text=prompt, image=processed_image, max_new_tokens=model_default_max_new_tokens, temperature=model_default_temperature, top_p=model_default_top_p, min_p=model_default_min_p, top_k=model_default_top_k, 
repetition_penalty=model_default_repetition_penalty, presence_penalty=model_default_presence_penalty, seed=model_default_seed, do_sample=model_default_do_sample, ) ) # Save prompt and response to file if extracted_text and isinstance(extracted_text, str) and output_folder: try: # Determine task suffix based on detection type task_suffix = None if detect_people_only: task_suffix = "face" elif detect_signatures_only: task_suffix = "sig" # Get model name for logging vlm_model_name = ( SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL if SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL else "VLM" ) saved_file = save_vlm_prompt_response( prompt=prompt, response_text=extracted_text, output_folder=output_folder, model_choice=vlm_model_name, image_name=image_name, page_number=page_index_0, temperature=model_default_temperature, max_new_tokens=model_default_max_new_tokens, top_p=model_default_top_p, model_type="VLM", task_suffix=task_suffix, input_tokens=vlm_input_tokens, output_tokens=vlm_output_tokens, image_width=processed_image.size[0], image_height=processed_image.size[1], ) print(f"Saved VLM prompt/response to: {saved_file}") except Exception as save_error: print(f"Warning: Could not save VLM prompt/response: {save_error}") # Check if extracted_text is None or empty if extracted_text is None or not isinstance(extracted_text, str): print( "VLM page OCR warning: extract_text_from_image_vlm returned None or invalid type" ) return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, ( SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL if SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL else "VLM" ), ) # Try to parse JSON from the response # The VLM might return JSON wrapped in markdown code blocks or with extra text extracted_text = extracted_text.strip() # Fix malformed bounding box values in the JSON string before parsing # This handles cases like: "bb": "779, 767, 874, 789], _inf_implied = None if detect_people_only: _inf_implied = "[FACE]" elif detect_signatures_only: 
_inf_implied = "[SIGNATURE]" extracted_text = _preprocess_vlm_ocr_json_string( extracted_text, implied_label=_inf_implied ) lines_data = None # First, try to parse the entire response as JSON try: lines_data = json.loads(extracted_text) except json.JSONDecodeError: pass # If that fails, try to extract JSON from markdown code blocks if lines_data is None: json_match = re.search( r"```(?:json)?\s*(\[.*?\])", extracted_text, re.DOTALL ) if json_match: try: lines_data = json.loads(json_match.group(1)) except json.JSONDecodeError: pass # If that fails, try to find JSON array in the text (more lenient) if lines_data is None: # Try to find array starting with [ and ending with ] # This is a simple approach - look for balanced brackets start_idx = extracted_text.find("[") if start_idx >= 0: bracket_count = 0 end_idx = start_idx for i in range(start_idx, len(extracted_text)): if extracted_text[i] == "[": bracket_count += 1 elif extracted_text[i] == "]": bracket_count -= 1 if bracket_count == 0: end_idx = i break if end_idx > start_idx: try: lines_data = json.loads(extracted_text[start_idx : end_idx + 1]) except json.JSONDecodeError: pass # If that fails, try parsing multiple JSON arrays (may span multiple lines) # This handles cases where the response has multiple JSON arrays separated by newlines # Each array might be on a single line or span multiple lines if lines_data is None: try: combined_data = [] # Find all JSON arrays in the text (they may span multiple lines) # This approach handles both single-line and multi-line arrays text = extracted_text while True: start_idx = text.find("[") if start_idx < 0: break # Find the matching closing bracket bracket_count = 0 end_idx = start_idx for i in range(start_idx, len(text)): if text[i] == "[": bracket_count += 1 elif text[i] == "]": bracket_count -= 1 if bracket_count == 0: end_idx = i break if end_idx > start_idx: try: array_str = text[start_idx : end_idx + 1] array_data = json.loads(array_str) if isinstance(array_data, 
list): combined_data.extend(array_data) except json.JSONDecodeError: pass # Move past this array to find the next one text = text[end_idx + 1 :] if combined_data: lines_data = combined_data except Exception: pass # If that fails, try to interpret the response as a Python literal (handles single-quoted lists/dicts) if lines_data is None: try: python_data = ast.literal_eval(extracted_text) if isinstance(python_data, list): lines_data = python_data except Exception: pass # Final attempt: try to parse as-is if lines_data is None: try: lines_data = json.loads(extracted_text) except json.JSONDecodeError: pass # If we still couldn't parse JSON, return empty results if lines_data is None: print("VLM page OCR error: Could not parse JSON response") print( f"Response text: {extracted_text[:500]}" ) # Print first 500 chars for debugging return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } if isinstance(lines_data, dict): lines_data = [lines_data] elif not isinstance(lines_data, list): print(f"VLM page OCR error: Expected list, got {type(lines_data)}") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } if SAVE_VLM_INPUT_IMAGES: try: plot_text_bounding_boxes( processed_image, extracted_text, image_name=image_name, image_folder="vlm_visualisations", output_folder=output_folder, task_type=task_type, ) except Exception as viz_error: print(f"Warning: VLM bbox visualization failed: {viz_error}") # Store a copy of the processed image for debug visualization (before rescaling) # IMPORTANT: This must be the EXACT same image that was sent to the API processed_image_for_debug = ( processed_image.copy() if SAVE_VLM_INPUT_IMAGES else None ) # Collect all valid bounding boxes before rescaling for debug visualization pre_scaled_boxes = [] # Convert VLM results to expected format result = { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } for line_item in 
lines_data: parsed = _parse_vlm_line_item_to_geometry( line_item, _inf_implied, "VLM page OCR", ) if parsed is None: continue text, bbox_xyxy, confidence = parsed x1, y1, x2, y2 = bbox_xyxy # If coordinates are normalized (0 to normalised_coords_range), rescale directly to processed image dimensions # This matches the ocr.ipynb approach: direct normalization to image size using /999 * dimension # ocr.ipynb uses: abs_x1 = int(bounding_box["bbox_2d"][0]/999 * width) # abs_y1 = int(bounding_box["bbox_2d"][1]/999 * height) if normalised_coords_range is not None and normalised_coords_range > 0: # Direct normalization: match ocr.ipynb approach exactly # Formula: (coord / normalised_coords_range) * image_dimension # Note: ocr.ipynb uses 999, but we allow configurable range x1 = (x1 / float(normalised_coords_range)) * processed_width y1 = (y1 / float(normalised_coords_range)) * processed_height x2 = (x2 / float(normalised_coords_range)) * processed_width y2 = (y2 / float(normalised_coords_range)) * processed_height # Store bounding box after normalization (if applied) but before rescaling to original image space if processed_image_for_debug is not None: pre_scaled_boxes.append({"bbox": (x1, y1, x2, y2), "text": text}) # Step 3: Scale coordinates back to original image space if image was resized if scale_x != 1.0 or scale_y != 1.0: x1 = x1 * scale_x y1 = y1 * scale_y x2 = x2 * scale_x y2 = y2 * scale_y # Convert from (x1, y1, x2, y2) to (left, top, width, height) left = int(round(x1)) top = int(round(y1)) width = int(round(x2 - x1)) height = int(round(y2 - y1)) # Ensure confidence is in valid range (0-100). VLM may return 0-1; scale to 0-100. 
try: confidence = float(confidence) if 0 <= confidence <= 1: confidence = confidence * 100 confidence = max(0, min(100, confidence)) # Clamp to 0-100 except (ValueError, TypeError): confidence = 100 # Default if invalid result["text"].append( clean_unicode_text(text, preserve_international_scripts=True) ) result["left"].append(left) result["top"].append(top) result["width"].append(width) result["height"].append(height) result["conf"].append(int(round(confidence))) result["model"].append("VLM") # Get model name for tracking vlm_model_name = ( SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL if SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL else "VLM" ) return result, vlm_input_tokens, vlm_output_tokens, vlm_model_name except Exception as e: print(f"VLM page OCR error: {e}") import traceback print(f"VLM page OCR error traceback: {traceback.format_exc()}") return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, "VLM", ) def _inference_server_page_ocr_predict( image: Image.Image, image_name: str = "inference_server_page_ocr_input_image.png", normalised_coords_range: Optional[int] = 999, output_folder: str = OUTPUT_FOLDER, detect_people_only: bool = False, detect_signatures_only: bool = False, progress: Optional[gr.Progress] = gr.Progress(), model_name: str = None, page_index_0: Optional[int] = None, ) -> Tuple[Dict[str, List], int, int, str]: """ Inference-server page-level OCR prediction that returns structured line-level results with bounding boxes. Calls an external inference-server API instead of a local model. Args: image: PIL Image to process (full page) image_name: Name of the image for debugging normalised_coords_range: If set, bounding boxes are assumed to be in normalized coordinates from 0 to this value (e.g., 999 as used in the full-page VLM prompt). Coordinates will be rescaled to match the processed image size. If None, coordinates are assumed to be in absolute pixel coordinates. 
output_folder: The folder where output images will be saved Returns: Dictionary with 'text', 'left', 'top', 'width', 'height', 'conf', 'model' keys matching the format expected by perform_ocr """ try: def _empty_inference_server_page_result( resolved_name: Optional[str] = None, ) -> Tuple[Dict[str, List], int, int, str]: """Always return (ocr_dict, in_tokens, out_tokens, model_name) for perform_ocr.""" nm = resolved_name if nm is None or nm == "": nm = ( DEFAULT_INFERENCE_SERVER_VLM_MODEL if DEFAULT_INFERENCE_SERVER_VLM_MODEL else None ) if nm is None or nm == "": nm = ( INFERENCE_SERVER_MODEL_NAME if INFERENCE_SERVER_MODEL_NAME else None ) if nm is None or nm == "": nm = "Inference Server" return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, nm, ) # Validate image exists and is not None if image is None: print("Inference-server page OCR error: Image is None") return _empty_inference_server_page_result() # Validate image has valid size (at least 10x10 pixels) try: width, height = image.size if width < 10 or height < 10: print( f"Inference-server page OCR error: Image is too small ({width}x{height} pixels). Minimum size is 10x10." 
) return _empty_inference_server_page_result() except Exception as size_error: print( f"Inference-server page OCR error: Could not get image size: {size_error}" ) return _empty_inference_server_page_result() # Ensure image is in RGB mode (convert if needed) try: if image.mode != "RGB": image = image.convert("RGB") width, height = image.size except Exception as convert_error: print( f"Inference-server page OCR error: Could not convert image to RGB: {convert_error}" ) return _empty_inference_server_page_result() # Check and resize image if it exceeds maximum size or DPI limits scale_x = 1.0 scale_y = 1.0 # In _inference_server_page_ocr_predict, around line 1465-1471: try: original_width, original_height = image.size processed_image = _prepare_image_for_vlm(image) processed_width, processed_height = processed_image.size # Use float division to avoid rounding errors scale_x = ( float(original_width) / float(processed_width) if processed_width > 0 else 1.0 ) scale_y = ( float(original_height) / float(processed_height) if processed_height > 0 else 1.0 ) # Debug: print scale factors to verify if scale_x != 1.0 or scale_y != 1.0: print(f"Scale factors: x={scale_x:.6f}, y={scale_y:.6f}") print( f"Original: {original_width}x{original_height}, Processed: {processed_width}x{processed_height}" ) except Exception as prep_error: print( f"Inference-server page OCR error: Could not prepare image for VLM: {prep_error}" ) return _empty_inference_server_page_result() # Save input image for debugging if environment variable is set if SAVE_VLM_INPUT_IMAGES: try: vlm_debug_dir = os.path.join( output_folder, "inference_server_visualisations/vlm_input_images", ) os.makedirs(vlm_debug_dir, exist_ok=True) # Increment the number at the end of image_name before .png # This converts zero-indexed input to one-indexed output incremented_image_name = image_name if image_name.endswith(".png"): # Find the number pattern at the end before .png # Matches patterns like: _0.png, _00.png, 0.png, 00.png, 
etc. pattern = r"(\d+)(\.png)$" match = re.search(pattern, image_name) if match: number_str = match.group(1) number = int(number_str) incremented_number = number + 1 # Preserve the same number of digits (padding with zeros if needed) incremented_str = str(incremented_number).zfill(len(number_str)) incremented_image_name = re.sub( pattern, lambda m: incremented_str + m.group(2), image_name ) image_name_safe = safe_sanitize_text(incremented_image_name) image_name_shortened = image_name_safe[:50] filename = ( f"{image_name_shortened}_inference_server_page_input_image.png" ) filepath = os.path.join(vlm_debug_dir, filename) print(f"Saving inference-server input image to: {filename}") _save_image_with_config_dpi(processed_image, filepath) # print(f"Saved VLM input image to: {filepath}") except Exception as save_error: print(f"Warning: Could not save VLM input image: {save_error}") # Create prompt that requests structured JSON output with bounding boxes if detect_people_only: # progress(0.5, "Detecting faces on page...") # print("Detecting faces on page...") prompt = full_page_ocr_people_vlm_prompt task_type = "face" elif detect_signatures_only: # progress(0.5, "Detecting signatures on page...") # print("Detecting signatures on page...") prompt = full_page_ocr_signature_vlm_prompt task_type = "signature" else: prompt = full_page_ocr_vlm_prompt task_type = "ocr" # Use the inference-server API to extract structured text # Note: processed_width and processed_height were already captured on line 1921 # after _prepare_image_for_vlm, so we use those values for normalization # Determine model_name: use provided parameter, then DEFAULT_INFERENCE_SERVER_VLM_MODEL, then INFERENCE_SERVER_MODEL_NAME final_model_name = model_name if final_model_name is None or final_model_name == "": final_model_name = ( DEFAULT_INFERENCE_SERVER_VLM_MODEL if DEFAULT_INFERENCE_SERVER_VLM_MODEL else None ) if final_model_name is None or final_model_name == "": final_model_name = ( 
INFERENCE_SERVER_MODEL_NAME if INFERENCE_SERVER_MODEL_NAME else None ) ( extracted_text, vlm_input_tokens, vlm_output_tokens, vlm_sent_w, vlm_sent_h, ) = _call_inference_server_vlm_api( image=processed_image, prompt=prompt, model_name=final_model_name, max_new_tokens=model_default_max_new_tokens, temperature=None, top_p=None, top_k=None, repetition_penalty=None, seed=None, do_sample=model_default_do_sample, min_p=None, presence_penalty=None, use_llama_swap=USE_LLAMA_SWAP, ) # Save prompt and response to file if extracted_text and isinstance(extracted_text, str) and output_folder: try: # Determine task suffix based on detection type task_suffix = None if detect_people_only: task_suffix = "face" elif detect_signatures_only: task_suffix = "sig" saved_file = save_vlm_prompt_response( prompt=prompt, response_text=extracted_text, output_folder=output_folder, model_choice=final_model_name or "unknown", image_name=image_name, page_number=page_index_0, temperature=model_default_temperature, max_new_tokens=model_default_max_new_tokens, top_p=model_default_top_p, model_type="Inference Server", task_suffix=task_suffix, input_tokens=vlm_input_tokens, output_tokens=vlm_output_tokens, image_width=vlm_sent_w, image_height=vlm_sent_h, ) print(f"Saved inference-server VLM prompt/response to: {saved_file}") except Exception as save_error: print( f"Warning: Could not save inference-server VLM prompt/response: {save_error}" ) # Check if extracted_text is None or empty if extracted_text is None or not isinstance(extracted_text, str): print( "Inference-server page OCR warning: API returned None or invalid type" ) return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, final_model_name or "Inference Server", ) # Try to parse JSON from the response # The API might return JSON wrapped in markdown code blocks or with extra text extracted_text = extracted_text.strip() # Fix malformed bounding box values in the JSON string before parsing # 
This handles cases like: "bb": "779, 767, 874, 789], _inf_server_implied = None if detect_people_only: _inf_server_implied = "[FACE]" elif detect_signatures_only: _inf_server_implied = "[SIGNATURE]" extracted_text = _preprocess_vlm_ocr_json_string( extracted_text, implied_label=_inf_server_implied ) lines_data = None # First, try to parse the entire response as JSON try: lines_data = json.loads(extracted_text) except json.JSONDecodeError: pass # If that fails, try to extract JSON from markdown code blocks if lines_data is None: json_match = re.search( r"```(?:json)?\s*(\[.*?\])", extracted_text, re.DOTALL ) if json_match: try: lines_data = json.loads(json_match.group(1)) except json.JSONDecodeError: pass # If that fails, try to find JSON array in the text (more lenient) if lines_data is None: # Try to find array starting with [ and ending with ] start_idx = extracted_text.find("[") if start_idx >= 0: bracket_count = 0 end_idx = start_idx for i in range(start_idx, len(extracted_text)): if extracted_text[i] == "[": bracket_count += 1 elif extracted_text[i] == "]": bracket_count -= 1 if bracket_count == 0: end_idx = i break if end_idx > start_idx: try: lines_data = json.loads(extracted_text[start_idx : end_idx + 1]) except json.JSONDecodeError: pass # If that fails, try parsing multiple JSON arrays (may span multiple lines) # This handles cases where the response has multiple JSON arrays separated by newlines # Each array might be on a single line or span multiple lines if lines_data is None: try: combined_data = [] # Find all JSON arrays in the text (they may span multiple lines) # This approach handles both single-line and multi-line arrays text = extracted_text while True: start_idx = text.find("[") if start_idx < 0: break # Find the matching closing bracket bracket_count = 0 end_idx = start_idx for i in range(start_idx, len(text)): if text[i] == "[": bracket_count += 1 elif text[i] == "]": bracket_count -= 1 if bracket_count == 0: end_idx = i break if end_idx > 
start_idx: try: array_str = text[start_idx : end_idx + 1] array_data = json.loads(array_str) if isinstance(array_data, list): combined_data.extend(array_data) except json.JSONDecodeError: pass # Move past this array to find the next one text = text[end_idx + 1 :] if combined_data: lines_data = combined_data except Exception: pass # If that fails, try to interpret the response as a Python literal (handles single-quoted lists/dicts) if lines_data is None: try: python_data = ast.literal_eval(extracted_text) if isinstance(python_data, list): lines_data = python_data except Exception: pass # Final attempt: try to parse as-is if lines_data is None: try: lines_data = json.loads(extracted_text) except json.JSONDecodeError: pass # If we still couldn't parse JSON, return empty results if lines_data is None: print("Inference-server page OCR error: Could not parse JSON response") print( f"Response text: {extracted_text[:500]}" ) # Print first 500 chars for debugging return _empty_inference_server_page_result(final_model_name) if isinstance(lines_data, dict): lines_data = [lines_data] elif not isinstance(lines_data, list): print( f"Inference-server page OCR error: Expected list, got {type(lines_data)}" ) return _empty_inference_server_page_result(final_model_name) if SAVE_VLM_INPUT_IMAGES: plot_text_bounding_boxes( processed_image, extracted_text, image_name=image_name, image_folder="inference_server_visualisations", output_folder=output_folder, task_type=task_type, ) # Store a copy of the processed image for debug visualization (before rescaling) # IMPORTANT: This must be the EXACT same image that was sent to the API processed_image_for_debug = ( processed_image.copy() if SAVE_VLM_INPUT_IMAGES else None ) # Collect all valid bounding boxes before rescaling for debug visualization pre_scaled_boxes = [] # Convert API results to expected format result = { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } for line_item in lines_data: parsed = 
_parse_vlm_line_item_to_geometry( line_item, _inf_server_implied, "Inference-server page OCR", ) if parsed is None: continue text, bbox_xyxy, confidence = parsed x1, y1, x2, y2 = bbox_xyxy # If coordinates are normalized (0 to normalised_coords_range), rescale directly to processed image dimensions if normalised_coords_range is not None and normalised_coords_range > 0: # Formula: (coord / normalised_coords_range) * image_dimension (e.g. 999 from full-page VLM prompt) x1 = (x1 / float(normalised_coords_range)) * processed_width y1 = (y1 / float(normalised_coords_range)) * processed_height x2 = (x2 / float(normalised_coords_range)) * processed_width y2 = (y2 / float(normalised_coords_range)) * processed_height # Store bounding box after normalization (if applied) but before rescaling to original image space if processed_image_for_debug is not None: pre_scaled_boxes.append({"bbox": (x1, y1, x2, y2), "text": text}) # Step 3: Scale coordinates back to original image space if image was resized if scale_x != 1.0 or scale_y != 1.0: x1 = x1 * scale_x y1 = y1 * scale_y x2 = x2 * scale_x y2 = y2 * scale_y # Convert from (x1, y1, x2, y2) to (left, top, width, height) left = int(round(x1)) top = int(round(y1)) width = int(round(x2 - x1)) height = int(round(y2 - y1)) # Ensure confidence is in valid range (0-100). VLM may return 0-1; scale to 0-100. 
try: confidence = float(confidence) if 0 <= confidence <= 1: confidence = confidence * 100 confidence = max(0, min(100, confidence)) # Clamp to 0-100 except (ValueError, TypeError): confidence = 50 # Default if invalid result["text"].append( clean_unicode_text(text, preserve_international_scripts=True) ) result["left"].append(left) result["top"].append(top) result["width"].append(width) result["height"].append(height) result["conf"].append(int(round(confidence))) result["model"].append("Inference Server") # Get model name for tracking vlm_model_name = final_model_name or "Inference Server" return result, vlm_input_tokens, vlm_output_tokens, vlm_model_name except Exception as e: print(f"Inference-server page OCR error: {e}") import traceback print(f"Inference-server page OCR error traceback: {traceback.format_exc()}") # Determine model name for error case error_model_name = ( model_name or DEFAULT_INFERENCE_SERVER_VLM_MODEL or INFERENCE_SERVER_MODEL_NAME or "Inference Server" ) return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, error_model_name, ) def _parse_vlm_page_ocr_response( extracted_text: str, processed_image: Image.Image, processed_width: int, processed_height: int, scale_x: float, scale_y: float, normalised_coords_range: Optional[int], model_name: str = "Cloud VLM", implied_label: Optional[str] = None, ) -> Dict[str, List]: """ Helper function to parse VLM page OCR response and convert to expected format. Shared by all cloud VLM page OCR functions. 
Args: extracted_text: Raw text response from VLM processed_image: The processed image that was sent to the VLM processed_width: Width of processed image processed_height: Height of processed image scale_x: Scale factor for x coordinates (original/processed) scale_y: Scale factor for y coordinates (original/processed) normalised_coords_range: If set, bounding boxes are in normalized coordinates (0 to this value) model_name: Name of the model for the 'model' field in results implied_label: When set (e.g. \"[FACE]\" for face pass), used to repair malformed JSON where the model omits \"text\", and to fill missing text on dict entries that only have bbox. Returns: Dictionary with 'text', 'left', 'top', 'width', 'height', 'conf', 'model' keys """ # Use actual image dimensions to ensure consistency (in case image was modified) actual_width, actual_height = processed_image.size if actual_width != processed_width or actual_height != processed_height: print( f"{model_name} page OCR warning: Image dimensions mismatch. " f"Expected {processed_width}x{processed_height}, got {actual_width}x{actual_height}. " f"Using actual dimensions." 
) processed_width = actual_width processed_height = actual_height extracted_text = _preprocess_vlm_ocr_json_string( extracted_text, implied_label=implied_label ) lines_data = None # Try various JSON parsing strategies (same as _vlm_page_ocr_predict) try: lines_data = json.loads(extracted_text) except json.JSONDecodeError: pass if lines_data is None: json_match = re.search(r"```(?:json)?\s*(\[.*?\])", extracted_text, re.DOTALL) if json_match: try: lines_data = json.loads(json_match.group(1)) except json.JSONDecodeError: pass if lines_data is None: start_idx = extracted_text.find("[") if start_idx >= 0: bracket_count = 0 end_idx = start_idx for i in range(start_idx, len(extracted_text)): if extracted_text[i] == "[": bracket_count += 1 elif extracted_text[i] == "]": bracket_count -= 1 if bracket_count == 0: end_idx = i break if end_idx > start_idx: try: lines_data = json.loads(extracted_text[start_idx : end_idx + 1]) except json.JSONDecodeError: pass if lines_data is None: try: python_data = ast.literal_eval(extracted_text) if isinstance(python_data, list): lines_data = python_data elif isinstance(python_data, dict): lines_data = python_data except Exception: pass if lines_data is None: print(f"{model_name} page OCR error: Could not parse JSON response") print(f"Response text: {extracted_text[:500]}") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } if isinstance(lines_data, dict): lines_data = [lines_data] elif not isinstance(lines_data, list): print(f"{model_name} page OCR error: Expected list, got {type(lines_data)}") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Convert VLM results to expected format result = { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } for line_item in lines_data: parsed = _parse_vlm_line_item_to_geometry( line_item, implied_label, f"{model_name} page OCR", ) if parsed is None: continue text, 
bbox_xyxy, confidence = parsed x1, y1, x2, y2 = bbox_xyxy if normalised_coords_range is not None and normalised_coords_range > 0: x1 = (x1 / float(normalised_coords_range)) * processed_width y1 = (y1 / float(normalised_coords_range)) * processed_height x2 = (x2 / float(normalised_coords_range)) * processed_width y2 = (y2 / float(normalised_coords_range)) * processed_height if scale_x != 1.0 or scale_y != 1.0: x1 = x1 * scale_x y1 = y1 * scale_y x2 = x2 * scale_x y2 = y2 * scale_y left = int(round(x1)) top = int(round(y1)) width = int(round(x2 - x1)) height = int(round(y2 - y1)) try: confidence = float(confidence) if 0 <= confidence <= 1: confidence = confidence * 100 confidence = max(0, min(100, confidence)) except (ValueError, TypeError): confidence = 100 result["text"].append( clean_unicode_text(text, preserve_international_scripts=True) ) result["left"].append(left) result["top"].append(top) result["width"].append(width) result["height"].append(height) result["conf"].append(int(round(confidence))) result["model"].append(model_name) return result def _bedrock_page_ocr_predict( image: Image.Image, image_name: str = "bedrock_page_ocr_input_image.png", normalised_coords_range: Optional[int] = None, output_folder: str = OUTPUT_FOLDER, detect_people_only: bool = False, detect_signatures_only: bool = False, progress: Optional[gr.Progress] = gr.Progress(), model_choice: str = None, bedrock_runtime=None, page_index_0: Optional[int] = None, ) -> Tuple[Dict[str, List], int, int, str]: """ Bedrock page-level OCR prediction that returns structured line-level results with bounding boxes. Args: image: PIL Image to process (full page) image_name: Name of the image for debugging normalised_coords_range: If set, bounding boxes are assumed to be in normalized coordinates from 0 to this value (e.g., 999 as used in the full-page VLM prompt). Coordinates will be rescaled to match the processed image size. If None, coordinates are assumed to be in absolute pixel coordinates. 
output_folder: The folder where output images will be saved detect_people_only: If True, only detect people in images detect_signatures_only: If True, only detect signatures in images progress: Gradio progress tracker model_choice: Bedrock model ID bedrock_runtime: boto3 Bedrock runtime client Returns: Dictionary with 'text', 'left', 'top', 'width', 'height', 'conf', 'model' keys """ try: if image is None: print("Bedrock page OCR error: Image is None") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } try: width, height = image.size if width < 10 or height < 10: print( f"Bedrock page OCR error: Image is too small ({width}x{height} pixels)." ) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } except Exception as size_error: print(f"Bedrock page OCR error: Could not get image size: {size_error}") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } try: if image.mode != "RGB": image = image.convert("RGB") except Exception as convert_error: print( f"Bedrock page OCR error: Could not convert image to RGB: {convert_error}" ) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Same preparation as other full-page VLMs: min/max pixels + DPI bounds; Bedrock # gets a higher max pixel budget via ocr_method. scale_x/scale_y map boxes to original. 
scale_x = 1.0 scale_y = 1.0 try: original_width, original_height = image.size processed_image = _prepare_image_for_vlm( image, ocr_method="AWS Bedrock VLM page OCR", hybrid_vlm=False, ) processed_width, processed_height = processed_image.size scale_x = ( float(original_width) / float(processed_width) if processed_width > 0 else 1.0 ) scale_y = ( float(original_height) / float(processed_height) if processed_height > 0 else 1.0 ) except Exception as prep_error: print(f"Bedrock page OCR error: Could not prepare image: {prep_error}") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Create prompt if detect_people_only: # progress(0.5, "Detecting faces on page...") # print("Detecting faces on page...") prompt = full_page_ocr_people_vlm_prompt elif detect_signatures_only: # progress(0.5, "Detecting signatures on page...") # print("Detecting signatures on page...") prompt = full_page_ocr_signature_vlm_prompt else: prompt = full_page_ocr_vlm_prompt # Save input image for debugging if environment variable is set if SAVE_VLM_INPUT_IMAGES: try: vlm_debug_dir = os.path.join( output_folder, "bedrock_visualisations/vlm_input_images", ) os.makedirs(vlm_debug_dir, exist_ok=True) # Increment the number at the end of image_name before .png # This converts zero-indexed input to one-indexed output incremented_image_name = image_name if image_name.endswith(".png"): # Find the number pattern at the end before .png # Matches patterns like: _0.png, _00.png, 0.png, 00.png, etc. 
pattern = r"(\d+)(\.png)$" match = re.search(pattern, image_name) if match: number_str = match.group(1) number = int(number_str) incremented_number = number + 1 # Preserve the same number of digits (padding with zeros if needed) incremented_str = str(incremented_number).zfill(len(number_str)) incremented_image_name = re.sub( pattern, lambda m: incremented_str + m.group(2), image_name ) image_name_safe = safe_sanitize_text(incremented_image_name) # Extract page number from image_name if present (e.g., "file_1.png" -> "1") # Look for patterns like "_1.png", "_01.png", "_page_1.png", etc. page_number = None page_patterns = [ r"_page_(\d+)\.png$", # _page_1.png r"_(\d+)\.png$", # _1.png, _01.png r"page_(\d+)\.png$", # page_1.png ] for pattern in page_patterns: match = re.search(pattern, incremented_image_name, re.IGNORECASE) if match: page_number = match.group(1) break # Use longer name limit to preserve page numbers, but still truncate if very long # Remove .png extension before truncating to preserve more of the name image_name_no_ext = image_name_safe.replace(".png", "").replace( ".PNG", "" ) if len(image_name_no_ext) > 100: image_name_shortened = image_name_no_ext[:100] else: image_name_shortened = image_name_no_ext # Construct filename with page number if found if page_number: filename = ( f"{image_name_shortened}_page_{page_number}_bedrock_input.png" ) else: filename = f"{image_name_shortened}_bedrock_page_input_image.png" filepath = os.path.join(vlm_debug_dir, filename) print(f"Saving Bedrock VLM input image to: {filename}") _save_image_with_config_dpi(processed_image, filepath) # print(f"Saved Bedrock VLM input image to: {filepath}") except Exception as save_error: print(f"Warning: Could not save Bedrock VLM input image: {save_error}") # Call Bedrock API ( extracted_text, vlm_input_tokens, vlm_output_tokens, vlm_sent_w, vlm_sent_h, ) = _call_bedrock_vlm_api( image=processed_image, prompt=prompt, model_choice=model_choice, bedrock_runtime=bedrock_runtime, 
max_new_tokens=model_default_max_new_tokens, temperature=model_default_temperature, top_p=model_default_top_p, ) # Save prompt and response to file (including when response is empty, e.g. no faces/signatures) if extracted_text is not None and output_folder: try: # Determine task suffix based on detection type task_suffix = None if detect_people_only: task_suffix = "face" elif detect_signatures_only: task_suffix = "sig" response_str = ( extracted_text if isinstance(extracted_text, str) else str(extracted_text or "") ) saved_file = save_vlm_prompt_response( prompt=prompt, response_text=response_str, output_folder=output_folder, model_choice=model_choice or "unknown", image_name=image_name, page_number=page_index_0, temperature=model_default_temperature, max_new_tokens=model_default_max_new_tokens, top_p=model_default_top_p, model_type="Bedrock", task_suffix=task_suffix, input_tokens=vlm_input_tokens, output_tokens=vlm_output_tokens, image_width=vlm_sent_w, image_height=vlm_sent_h, ) print(f"Saved Bedrock VLM prompt/response to: {saved_file}") except Exception as save_error: print( f"Warning: Could not save Bedrock VLM prompt/response: {save_error}" ) if extracted_text is None or not isinstance(extracted_text, str): print("Bedrock page OCR warning: No valid response") return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, model_choice or "Bedrock", ) _bedrock_implied_label = None if detect_people_only: _bedrock_implied_label = "[FACE]" elif detect_signatures_only: _bedrock_implied_label = "[SIGNATURE]" # Plot bounding boxes from VLM response if enabled if SAVE_VLM_INPUT_IMAGES: try: # Determine task type based on prompt task_type = "ocr" if detect_people_only: task_type = "face" elif detect_signatures_only: task_type = "signature" _viz_json = _preprocess_vlm_ocr_json_string( extracted_text, implied_label=_bedrock_implied_label ) plot_text_bounding_boxes( processed_image, _viz_json, image_name=image_name, 
image_folder="bedrock_visualisations", output_folder=output_folder, task_type=task_type, ) except Exception as plot_error: print( f"Warning: Could not plot Bedrock VLM bounding boxes: {plot_error}" ) # Parse response using shared helper result = _parse_vlm_page_ocr_response( extracted_text=extracted_text, processed_image=processed_image, processed_width=processed_width, processed_height=processed_height, scale_x=scale_x, scale_y=scale_y, normalised_coords_range=normalised_coords_range, model_name="Bedrock", implied_label=_bedrock_implied_label, ) # Get model name for tracking vlm_model_name = model_choice or "Bedrock" return result, vlm_input_tokens, vlm_output_tokens, vlm_model_name except Exception as e: print(f"Bedrock page OCR error: {e}") import traceback print(f"Bedrock page OCR error traceback: {traceback.format_exc()}") return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, model_choice or "Bedrock", ) def _gemini_page_ocr_predict( image: Image.Image, image_name: str = "gemini_page_ocr_input_image.png", normalised_coords_range: Optional[int] = None, output_folder: str = OUTPUT_FOLDER, detect_people_only: bool = False, detect_signatures_only: bool = False, progress: Optional[gr.Progress] = gr.Progress(), model_choice: str = None, client=None, config=None, page_index_0: Optional[int] = None, ) -> Tuple[Dict[str, List], int, int, str]: """ Gemini page-level OCR prediction that returns structured line-level results with bounding boxes. 
Args: image: PIL Image to process (full page) image_name: Name of the image for debugging normalised_coords_range: If set, bounding boxes are assumed to be in normalized coordinates output_folder: The folder where output images will be saved detect_people_only: If True, only detect people in images detect_signatures_only: If True, only detect signatures in images progress: Gradio progress tracker model_choice: Gemini model name client: Gemini ai.Client instance config: Gemini types.GenerateContentConfig instance Returns: Dictionary with 'text', 'left', 'top', 'width', 'height', 'conf', 'model' keys """ try: if image is None: print("Gemini page OCR error: Image is None") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } try: width, height = image.size if width < 10 or height < 10: print( f"Gemini page OCR error: Image is too small ({width}x{height} pixels)." ) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } except Exception as size_error: print(f"Gemini page OCR error: Could not get image size: {size_error}") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } try: if image.mode != "RGB": image = image.convert("RGB") except Exception as convert_error: print( f"Gemini page OCR error: Could not convert image to RGB: {convert_error}" ) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } scale_x = 1.0 scale_y = 1.0 try: original_width, original_height = image.size processed_image = _prepare_image_for_vlm(image) processed_width, processed_height = processed_image.size scale_x = ( float(original_width) / float(processed_width) if processed_width > 0 else 1.0 ) scale_y = ( float(original_height) / float(processed_height) if processed_height > 0 else 1.0 ) except Exception as prep_error: print(f"Gemini page OCR error: Could not prepare image: {prep_error}") return { "text": [], "left": 
[], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Create prompt if detect_people_only: # progress(0.5, "Detecting faces on page...") # print("Detecting faces on page...") prompt = full_page_ocr_people_vlm_prompt elif detect_signatures_only: # progress(0.5, "Detecting signatures on page...") # print("Detecting signatures on page...") prompt = full_page_ocr_signature_vlm_prompt else: prompt = full_page_ocr_vlm_prompt # Call Gemini API extracted_text, vlm_input_tokens, vlm_output_tokens = _call_gemini_vlm_api( image=processed_image, prompt=prompt, client=client, config=config, model_choice=model_choice, max_new_tokens=model_default_max_new_tokens, temperature=model_default_temperature, ) # Save prompt and response to file (including when response is empty, e.g. no faces/signatures) if extracted_text is not None and output_folder: try: # Determine task suffix based on detection type task_suffix = None if detect_people_only: task_suffix = "face" elif detect_signatures_only: task_suffix = "sig" response_str = ( extracted_text if isinstance(extracted_text, str) else str(extracted_text or "") ) saved_file = save_vlm_prompt_response( prompt=prompt, response_text=response_str, output_folder=output_folder, model_choice=model_choice or "unknown", image_name=image_name, page_number=page_index_0, temperature=model_default_temperature, max_new_tokens=model_default_max_new_tokens, model_type="Gemini", task_suffix=task_suffix, input_tokens=vlm_input_tokens, output_tokens=vlm_output_tokens, image_width=processed_image.size[0], image_height=processed_image.size[1], ) print(f"Saved Gemini VLM prompt/response to: {saved_file}") except Exception as save_error: print( f"Warning: Could not save Gemini VLM prompt/response: {save_error}" ) if extracted_text is None or not isinstance(extracted_text, str): print("Gemini page OCR warning: No valid response") return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, 
model_choice or "Gemini", ) _gem_implied = None if detect_people_only: _gem_implied = "[FACE]" elif detect_signatures_only: _gem_implied = "[SIGNATURE]" # Parse response using shared helper result = _parse_vlm_page_ocr_response( extracted_text=extracted_text, processed_image=processed_image, processed_width=processed_width, processed_height=processed_height, scale_x=scale_x, scale_y=scale_y, normalised_coords_range=normalised_coords_range, model_name="Gemini", implied_label=_gem_implied, ) # Get model name for tracking vlm_model_name = model_choice or "Gemini" return result, vlm_input_tokens, vlm_output_tokens, vlm_model_name except Exception as e: print(f"Gemini page OCR error: {e}") import traceback print(f"Gemini page OCR error traceback: {traceback.format_exc()}") return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, model_choice or "Gemini", ) def _azure_openai_page_ocr_predict( image: Image.Image, image_name: str = "azure_openai_page_ocr_input_image.png", normalised_coords_range: Optional[int] = None, output_folder: str = OUTPUT_FOLDER, detect_people_only: bool = False, detect_signatures_only: bool = False, progress: Optional[gr.Progress] = gr.Progress(), model_choice: str = None, client=None, page_index_0: Optional[int] = None, ) -> Tuple[Dict[str, List], int, int, str]: """ Azure/OpenAI page-level OCR prediction that returns structured line-level results with bounding boxes. 
Args: image: PIL Image to process (full page) image_name: Name of the image for debugging normalised_coords_range: If set, bounding boxes are assumed to be in normalized coordinates output_folder: The folder where output images will be saved detect_people_only: If True, only detect people in images detect_signatures_only: If True, only detect signatures in images progress: Gradio progress tracker model_choice: Model name (e.g., "gpt-4o", "gpt-4-vision-preview") client: OpenAI client instance Returns: Dictionary with 'text', 'left', 'top', 'width', 'height', 'conf', 'model' keys """ try: if image is None: print("Azure/OpenAI page OCR error: Image is None") return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } try: width, height = image.size if width < 10 or height < 10: print( f"Azure/OpenAI page OCR error: Image is too small ({width}x{height} pixels)." ) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } except Exception as size_error: print( f"Azure/OpenAI page OCR error: Could not get image size: {size_error}" ) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } try: if image.mode != "RGB": image = image.convert("RGB") except Exception as convert_error: print( f"Azure/OpenAI page OCR error: Could not convert image to RGB: {convert_error}" ) return { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } scale_x = 1.0 scale_y = 1.0 try: original_width, original_height = image.size processed_image = _prepare_image_for_vlm(image) processed_width, processed_height = processed_image.size scale_x = ( float(original_width) / float(processed_width) if processed_width > 0 else 1.0 ) scale_y = ( float(original_height) / float(processed_height) if processed_height > 0 else 1.0 ) except Exception as prep_error: print(f"Azure/OpenAI page OCR error: Could not prepare image: {prep_error}") return { "text": 
[], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], } # Create prompt if detect_people_only: # progress(0.5, "Detecting faces on page...") # print("Detecting faces on page...") prompt = full_page_ocr_people_vlm_prompt elif detect_signatures_only: # progress(0.5, "Detecting signatures on page...") # print("Detecting signatures on page...") prompt = full_page_ocr_signature_vlm_prompt else: prompt = full_page_ocr_vlm_prompt # Call Azure/OpenAI API extracted_text, vlm_input_tokens, vlm_output_tokens = ( _call_azure_openai_vlm_api( image=processed_image, prompt=prompt, client=client, model_choice=model_choice, max_new_tokens=model_default_max_new_tokens, temperature=model_default_temperature, ) ) # Save prompt and response to file (including when response is empty, e.g. no faces/signatures) if extracted_text is not None and output_folder: try: # Determine task suffix based on detection type task_suffix = None if detect_people_only: task_suffix = "face" elif detect_signatures_only: task_suffix = "sig" response_str = ( extracted_text if isinstance(extracted_text, str) else str(extracted_text or "") ) saved_file = save_vlm_prompt_response( prompt=prompt, response_text=response_str, output_folder=output_folder, model_choice=model_choice or "unknown", image_name=image_name, page_number=page_index_0, temperature=model_default_temperature, max_new_tokens=model_default_max_new_tokens, model_type="Azure/OpenAI", task_suffix=task_suffix, input_tokens=vlm_input_tokens, output_tokens=vlm_output_tokens, image_width=processed_image.size[0], image_height=processed_image.size[1], ) print(f"Saved Azure/OpenAI VLM prompt/response to: {saved_file}") except Exception as save_error: print( f"Warning: Could not save Azure/OpenAI VLM prompt/response: {save_error}" ) if extracted_text is None or not isinstance(extracted_text, str): print("Azure/OpenAI page OCR warning: No valid response") return ( { "text": [], "left": [], "top": [], "width": [], "height": [], 
"conf": [], "model": [], }, 0, 0, model_choice or "Azure/OpenAI", ) _azure_implied = None if detect_people_only: _azure_implied = "[FACE]" elif detect_signatures_only: _azure_implied = "[SIGNATURE]" # Parse response using shared helper result = _parse_vlm_page_ocr_response( extracted_text=extracted_text, processed_image=processed_image, processed_width=processed_width, processed_height=processed_height, scale_x=scale_x, scale_y=scale_y, normalised_coords_range=normalised_coords_range, model_name="Azure/OpenAI", implied_label=_azure_implied, ) # Get model name for tracking vlm_model_name = model_choice or "Azure/OpenAI" return result, vlm_input_tokens, vlm_output_tokens, vlm_model_name except Exception as e: print(f"Azure/OpenAI page OCR error: {e}") import traceback print(f"Azure/OpenAI page OCR error traceback: {traceback.format_exc()}") return ( { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], }, 0, 0, model_choice or "Azure/OpenAI", ) class CustomImageAnalyzerEngine: def __init__( self, analyzer_engine: Optional[AnalyzerEngine] = None, ocr_engine: str = "tesseract", tesseract_config: Optional[str] = None, paddle_kwargs: Optional[Dict[str, Any]] = None, image_preprocessor: Optional[ImagePreprocessor] = None, language: Optional[str] = DEFAULT_LANGUAGE, output_folder: str = OUTPUT_FOLDER, save_page_ocr_visualisations: bool = SAVE_PAGE_OCR_VISUALISATIONS, ): """ Initializes the CustomImageAnalyzerEngine. :param ocr_engine: The OCR engine to use ("tesseract", "paddle", "vlm", "hybrid-paddle", "hybrid-vlm", "hybrid-paddle-vlm", "hybrid-paddle-inference-server", or "inference-server"). :param analyzer_engine: The Presidio AnalyzerEngine instance. :param tesseract_config: Configuration string for Tesseract. If None, uses TESSERACT_SEGMENTATION_LEVEL config. :param paddle_kwargs: Dictionary of keyword arguments for PaddleOCR constructor. :param image_preprocessor: Optional image preprocessor. 
:param language: Preferred OCR language (e.g., "en", "fr", "de"). Defaults to DEFAULT_LANGUAGE. :param output_folder: The folder to save the output images to. """ if ocr_engine not in LOCAL_OCR_MODEL_OPTIONS: raise ValueError( f"ocr_engine must be one of the following: {LOCAL_OCR_MODEL_OPTIONS}" ) self.ocr_engine = ocr_engine # Language setup self.language = language or DEFAULT_LANGUAGE or "en" self.tesseract_lang = _tesseract_lang_code(self.language) self.paddle_lang = _paddle_lang_code(self.language) # Security: Validate and normalize output_folder at construction time # This ensures the object is always in a secure state and prevents # any future code from accidentally using an untrusted directory normalized_output_folder = os.path.normpath(os.path.abspath(output_folder)) if not validate_folder_containment(normalized_output_folder, OUTPUT_FOLDER): raise ValueError( f"Unsafe output folder path: {output_folder}. Must be contained within {OUTPUT_FOLDER}" ) self.output_folder = normalized_output_folder self.save_page_ocr_visualisations = bool(save_page_ocr_visualisations) if ( self.ocr_engine == "paddle" or self.ocr_engine == "hybrid-paddle" or self.ocr_engine == "hybrid-paddle-vlm" or self.ocr_engine == "hybrid-paddle-inference-server" ): if paddle_kwargs is None: paddle_kwargs = _default_paddle_kwargs(self.paddle_lang) else: paddle_kwargs = dict(paddle_kwargs) paddle_kwargs.setdefault("lang", self.paddle_lang) paddle_kwargs = _finalize_paddle_kwargs(paddle_kwargs) if SPACES_ZERO_GPU: register_module_paddle_kwargs(paddle_kwargs) self.paddle_ocr = None print( "ZeroGPU: PaddleOCR will load on GPU inside paddle_predict " f"(device={paddle_kwargs.get('device')!r}, " f"engine={paddle_kwargs.get('engine')!r})" ) else: try: self.paddle_ocr = get_or_create_module_paddle_ocr(paddle_kwargs) except Exception as e: if ( "WinError 127" in str(e) or "could not be found" in str(e).lower() or "dll" in str(e).lower() ): print( f"Warning: GPU initialization failed (likely missing 
CUDA/cuDNN dependencies): {e}" ) print("PaddleOCR will not be available. To fix GPU issues:") print("1. Install Visual C++ Redistributables (latest version)") print("2. Ensure CUDA runtime libraries are in your PATH") print( "3. Or reinstall paddlepaddle CPU version: pip install paddlepaddle" ) raise elif self.ocr_engine == "hybrid-vlm": # VLM-based hybrid OCR - no additional initialization needed # VLM weights load at import if LOAD_TRANSFORMERS_VLM_MODEL_AT_START=True, else on first VLM call print( f"Initializing hybrid VLM OCR with model: {SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL}" ) self.paddle_ocr = None # Not using PaddleOCR elif self.ocr_engine == "vlm": # VLM page-level OCR - no additional initialization needed # VLM weights load at import if LOAD_TRANSFORMERS_VLM_MODEL_AT_START=True, else on first VLM call print( f"Initializing VLM OCR with model: {SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL}" ) self.paddle_ocr = None # Not using PaddleOCR if self.ocr_engine == "hybrid-paddle-vlm": # Hybrid PaddleOCR + VLM - requires both PaddleOCR and VLM # VLM weights load at import if LOAD_TRANSFORMERS_VLM_MODEL_AT_START=True, else on first VLM call print( f"Initializing hybrid PaddleOCR + VLM OCR with model: {SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL}" ) if self.ocr_engine == "hybrid-paddle-inference-server": # Hybrid PaddleOCR + Inference-server - requires both PaddleOCR and inference-server API print("Initializing hybrid PaddleOCR + Inference-server OCR") if not analyzer_engine: analyzer_engine = AnalyzerEngine() self.analyzer_engine = analyzer_engine # Set Tesseract configuration based on segmentation level if tesseract_config: self.tesseract_config = tesseract_config else: # Following function does not actually work correctly, so always use PSM 11 psm_value = TESSERACT_SEGMENTATION_LEVEL # _get_tesseract_psm(TESSERACT_SEGMENTATION_LEVEL) self.tesseract_config = f"--oem 3 --psm {psm_value}" # print( # f"Tesseract configured for {TESSERACT_SEGMENTATION_LEVEL}-level segmentation 
(PSM {psm_value})" # ) if not image_preprocessor: image_preprocessor = ContrastSegmentedImageEnhancer() self.image_preprocessor = image_preprocessor def _sanitize_filename( self, text: str, max_length: int = 20, fallback_prefix: str = "unknown_text" ) -> str: """ Sanitizes text for use in filenames by removing invalid characters and limiting length. :param text: The text to sanitize :param max_length: Maximum length of the sanitized text :param fallback_prefix: Prefix to use if sanitization fails :return: Sanitized text safe for filenames """ # Remove or replace invalid filename characters # Windows: < > : " | ? * \ / # Unix: / (forward slash) sanitized = safe_sanitize_text(text) # Remove leading/trailing underscores and spaces sanitized = sanitized.strip("_ ") # If empty after sanitization, use a default value if not sanitized: sanitized = fallback_prefix # Limit to max_length characters if len(sanitized) > max_length: sanitized = sanitized[:max_length] # Ensure we don't end with an underscore if we cut in the middle sanitized = sanitized.rstrip("_") # Final check: if still empty or too short, use fallback if not sanitized or len(sanitized) < 3: sanitized = fallback_prefix return sanitized def _create_safe_filename_with_confidence( self, original_text: str, new_text: str, conf: int, new_conf: int, ocr_type: str = "OCR", ) -> str: """ Creates a safe filename using confidence values when text sanitization fails. Args: original_text: Original text from Tesseract new_text: New text from VLM/PaddleOCR conf: Original confidence score new_conf: New confidence score ocr_type: Type of OCR used (VLM, Paddle, etc.) 
Returns: Safe filename string """ # Try to sanitize both texts safe_original = self._sanitize_filename( original_text, max_length=15, fallback_prefix=f"orig_conf_{conf}" ) safe_new = self._sanitize_filename( new_text, max_length=15, fallback_prefix=f"new_conf_{new_conf}" ) # If both sanitizations resulted in fallback names, create a confidence-based name if safe_original.startswith("unknown_text") and safe_new.startswith( "unknown_text" ): return f"{ocr_type}_conf_{conf}_to_conf_{new_conf}" return f"{safe_original}_conf_{conf}_to_{safe_new}_conf_{new_conf}" def _is_line_level_data(self, ocr_data: Dict[str, List]) -> bool: """ Determines if OCR data contains line-level results (multiple words per bounding box). Args: ocr_data: Dictionary with OCR data Returns: True if data appears to be line-level, False otherwise """ if not ocr_data or not ocr_data.get("text"): return False # Check if any text entries contain multiple words for text in ocr_data["text"]: if text.strip() and len(text.split()) > 1: return True return False def _convert_paddle_to_tesseract_format( self, paddle_results: List[Any], input_image_width: int = None, input_image_height: int = None, image_name: str = None, image: Image.Image = None, ) -> Dict[str, List]: """Converts PaddleOCR result format to Tesseract's dictionary format using relative coordinates. This function uses a safer approach: converts PaddleOCR coordinates to relative (0-1) coordinates based on whatever coordinate space PaddleOCR uses, then scales them to the input image dimensions. This avoids issues with PaddleOCR's internal image resizing. 
Args: paddle_results: List of PaddleOCR result dictionaries input_image_width: Width of the input image passed to PaddleOCR (target dimensions for scaling) input_image_height: Height of the input image passed to PaddleOCR (target dimensions for scaling) image_name: Name of the image image: Image object """ output = { "text": list(), "left": list(), "top": list(), "width": list(), "height": list(), "conf": list(), "model": list(), } # paddle_results is now a list of dictionaries with detailed information if not paddle_results: return output # Validate that we have target dimensions if input_image_width is None or input_image_height is None: print( "Warning: Input image dimensions not provided. PaddleOCR coordinates may be incorrectly scaled." ) # Fallback: we'll try to detect from coordinates, but this is less reliable use_relative_coords = False else: use_relative_coords = True for page_result in paddle_results: # Extract text recognition results from the new format rec_texts = page_result.get("rec_texts", list()) rec_scores = page_result.get("rec_scores", list()) rec_polys = page_result.get("rec_polys", list()) rec_models = page_result.get("rec_models", list()) # PaddleOCR may return image dimensions in the result - check for them # Some versions of PaddleOCR include this information result_image_width = page_result.get("image_width") result_image_height = page_result.get("image_height") # PaddleOCR typically returns coordinates in the input image space # However, it may internally resize images, so we need to check if coordinates # are in a different space by comparing with explicit metadata or detecting from coordinates # First pass: determine PaddleOCR's coordinate space by finding max coordinates # This tells us what coordinate space PaddleOCR is actually using max_x_coord = 0 max_y_coord = 0 for bounding_box in rec_polys: if hasattr(bounding_box, "tolist"): box = bounding_box.tolist() else: box = bounding_box if box and len(box) > 0: x_coords = [p[0] for p in 
box] y_coords = [p[1] for p in box] max_x_coord = max(max_x_coord, max(x_coords) if x_coords else 0) max_y_coord = max(max_y_coord, max(y_coords) if y_coords else 0) # Determine PaddleOCR's coordinate space dimensions # Priority: explicit result metadata > input dimensions (standard PaddleOCR behavior) # Note: PaddleOCR typically returns coordinates in the input image space. # We only use a different coordinate space if PaddleOCR provides explicit metadata. # Using max coordinates to detect coordinate space is unreliable because: # 1. Text might not extend to image edges # 2. There might be padding # 3. Max coordinates don't necessarily equal image dimensions if result_image_width is not None and result_image_height is not None: # Use explicit metadata from PaddleOCR if available (most reliable) paddle_coord_width = result_image_width paddle_coord_height = result_image_height # Only use relative conversion if coordinate space differs from input if ( paddle_coord_width != input_image_width or paddle_coord_height != input_image_height ): # print( # f"PaddleOCR metadata indicates coordinate space ({paddle_coord_width}x{paddle_coord_height}) " # f"differs from input ({input_image_width}x{input_image_height}). " # f"Using metadata for coordinate conversion." # ) pass elif input_image_width is not None and input_image_height is not None: # Default: assume coordinates are in input image space (standard PaddleOCR behavior) # This is the most common case and avoids incorrect scaling paddle_coord_width = input_image_width paddle_coord_height = input_image_height else: # Fallback: use max coordinates if we have no other information paddle_coord_width = max_x_coord if max_x_coord > 0 else 1 paddle_coord_height = max_y_coord if max_y_coord > 0 else 1 use_relative_coords = False print( f"Warning: No input dimensions provided. Using detected coordinate space ({paddle_coord_width}x{paddle_coord_height}) from max coordinates." 
) # Validate coordinate space dimensions if paddle_coord_width is None or paddle_coord_height is None: paddle_coord_width = input_image_width or 1 paddle_coord_height = input_image_height or 1 use_relative_coords = False if paddle_coord_width <= 0 or paddle_coord_height <= 0: print( f"Warning: Invalid PaddleOCR coordinate space dimensions ({paddle_coord_width}x{paddle_coord_height}). Using input dimensions." ) paddle_coord_width = input_image_width or 1 paddle_coord_height = input_image_height or 1 use_relative_coords = False # If coordinate space matches input dimensions, coordinates are already in the correct space # Only use relative coordinate conversion if coordinate space differs from input if ( paddle_coord_width == input_image_width and paddle_coord_height == input_image_height and input_image_width is not None and input_image_height is not None ): # Coordinates are already in input space, no conversion needed use_relative_coords = False # print( # f"PaddleOCR coordinates are in input image space ({input_image_width}x{input_image_height}). " # f"Using coordinates directly without conversion." # ) # Second pass: convert coordinates using relative coordinate approach # Use default "Paddle" if rec_models is not available or doesn't match length if len(rec_models) != len(rec_texts): # print( # f"Warning: rec_models length ({len(rec_models)}) doesn't match rec_texts length ({len(rec_texts)}). Using default 'Paddle' for all." 
# ) rec_models = ["Paddle"] * len(rec_texts) # Update page_result to keep it consistent page_result["rec_models"] = rec_models else: # Ensure we're using the rec_models from page_result (which may have been modified) rec_models = page_result.get("rec_models", rec_models) # Debug: Print model distribution vlm_count = sum(1 for m in rec_models if m == "VLM") if vlm_count > 0: print( f"Found {vlm_count} VLM-labeled lines out of {len(rec_models)} total lines in page_result" ) for line_text, line_confidence, bounding_box, line_model in zip( rec_texts, rec_scores, rec_polys, rec_models ): # bounding_box is now a numpy array with shape (4, 2) # Convert to list of coordinates if it's a numpy array if hasattr(bounding_box, "tolist"): box = bounding_box.tolist() else: box = bounding_box if not box or len(box) == 0: continue # box is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] x_coords = [p[0] for p in box] y_coords = [p[1] for p in box] # Extract bounding box coordinates in PaddleOCR's coordinate space line_left_paddle = float(min(x_coords)) line_top_paddle = float(min(y_coords)) line_right_paddle = float(max(x_coords)) line_bottom_paddle = float(max(y_coords)) line_width_paddle = line_right_paddle - line_left_paddle line_height_paddle = line_bottom_paddle - line_top_paddle # Convert to relative coordinates (0-1) based on PaddleOCR's coordinate space # Then scale to input image dimensions if ( use_relative_coords and paddle_coord_width > 0 and paddle_coord_height > 0 ): # Normalize to relative coordinates [0-1] rel_left = line_left_paddle / paddle_coord_width rel_top = line_top_paddle / paddle_coord_height rel_width = line_width_paddle / paddle_coord_width rel_height = line_height_paddle / paddle_coord_height # Scale to input image dimensions line_left = rel_left * input_image_width line_top = rel_top * input_image_height line_width = rel_width * input_image_width line_height = rel_height * input_image_height else: # Fallback: use coordinates directly (may cause issues if 
coordinate spaces don't match) line_left = line_left_paddle line_top = line_top_paddle line_width = line_width_paddle line_height = line_height_paddle # if input_image_width and input_image_height: # print(f"Warning: Using PaddleOCR coordinates directly. This may cause scaling issues.") # Ensure coordinates are within valid bounds if input_image_width and input_image_height: line_left = max(0, min(line_left, input_image_width)) line_top = max(0, min(line_top, input_image_height)) line_width = max(0, min(line_width, input_image_width - line_left)) line_height = max( 0, min(line_height, input_image_height - line_top) ) # Add line-level data output["text"].append(line_text) output["left"].append(round(line_left, 2)) output["top"].append(round(line_top, 2)) output["width"].append(round(line_width, 2)) output["height"].append(round(line_height, 2)) output["conf"].append(int(line_confidence * 100)) output["model"].append(line_model if line_model else "Paddle") return output @staticmethod def _process_one_line_to_words( task: Tuple, output_folder: str, image_name: Optional[str], thread_local_segmenter: Optional[threading.local] = None, ) -> Tuple[int, Dict[str, List]]: """ Process a single line to word-level bounding boxes. Used by _convert_line_to_word_level for parallel execution. Args: task: (line_index, line_image, line_text, line_conf, line_model, line_left, line_top, line_width, line_height) output_folder: Passed to AdaptiveSegmenter image_name: Passed to segmenter.segment() Returns: (line_index, word_dict) with word_dict having keys text, left, top, width, height, conf, model (all lists). """ ( i, line_image, line_text, line_conf, line_model, line_left, line_top, line_width, line_height, ) = task word_dict = { "text": [], "left": [], "top": [], "width": [], "height": [], "conf": [], "model": [], # Preserve the originating line index so downstream can keep Paddle's # native line grouping even after word-level conversion. 
"line": [], } if thread_local_segmenter is not None: segmenter = getattr(thread_local_segmenter, "segmenter", None) if segmenter is None: segmenter = AdaptiveSegmenter(output_folder=output_folder) thread_local_segmenter.segmenter = segmenter else: segmenter = AdaptiveSegmenter(output_folder=output_folder) single_line_data = { "text": [line_text], "left": [0], "top": [0], "width": [line_width], "height": [line_height], "conf": [line_conf], "line": [i], } word_output, _ = segmenter.segment( single_line_data, line_image, image_name=image_name ) if not word_output or not word_output.get("text"): words = line_text.split() if words: num_chars = len("".join(words)) num_spaces = len(words) - 1 if num_chars > 0: char_space_ratio = 2.0 estimated_space_width = ( line_width / (num_chars * char_space_ratio + num_spaces) if (num_chars * char_space_ratio + num_spaces) > 0 else line_width / num_chars ) avg_char_width = estimated_space_width * char_space_ratio current_left = 0 for word in words: word_width = len(word) * avg_char_width clamped_left = max(0, min(current_left, line_width)) clamped_width = max( 0, min(word_width, line_width - clamped_left) ) word_dict["text"].append(word) word_dict["left"].append(line_left + clamped_left) word_dict["top"].append(line_top) word_dict["width"].append(clamped_width) word_dict["height"].append(line_height) word_dict["conf"].append(line_conf) word_dict["model"].append(line_model) word_dict["line"].append(i) current_left += word_width + estimated_space_width return (i, word_dict) for j in range(len(word_output["text"])): word_dict["text"].append(word_output["text"][j]) word_dict["left"].append(line_left + word_output["left"][j]) word_dict["top"].append(line_top + word_output["top"][j]) word_dict["width"].append(word_output["width"][j]) word_dict["height"].append(word_output["height"][j]) word_dict["conf"].append(word_output["conf"][j]) word_dict["model"].append(line_model) word_dict["line"].append(i) return (i, word_dict) def 
_convert_line_to_word_level( self, line_data: Dict[str, List], image_width: int, image_height: int, image: Image.Image, image_name: str = None, ) -> Dict[str, List]: """ Converts line-level OCR results to word-level using AdaptiveSegmenter.segment(). This method processes each line individually using the adaptive segmentation algorithm. Lines are processed in parallel with ThreadPoolExecutor when there is more than one. Args: line_data: Dictionary with keys "text", "left", "top", "width", "height", "conf" (all lists) image_width: Width of the full image image_height: Height of the full image image: PIL Image object of the full image image_name: Name of the image Returns: Dictionary with same keys as input, containing word-level bounding boxes """ output = { "text": list(), "left": list(), "top": list(), "width": list(), "height": list(), "conf": list(), "model": list(), "line": list(), } if not line_data or not line_data.get("text"): return output # Timing hooks removed (test-only). # Validate that image is not None before processing if image is None: print( "Warning: Image is None in _convert_line_to_word_level. Returning empty output." ) return output # Convert PIL Image to numpy array (BGR format for OpenCV) if hasattr(image, "size"): # PIL Image image_np = np.array(image) if len(image_np.shape) == 3: # Convert RGB to BGR for OpenCV image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) elif len(image_np.shape) == 2: # Grayscale - convert to BGR image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR) else: # Already numpy array image_np = image.copy() if len(image_np.shape) == 2: image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR) # Validate that image_np dimensions match the expected image_width and image_height # PIL Image.size returns (width, height), but numpy array shape is (height, width, channels) actual_height, actual_width = image_np.shape[:2] if actual_width != image_width or actual_height != image_height: print( f"Warning: Image dimension mismatch! 
Expected {image_width}x{image_height}, but got {actual_width}x{actual_height}" ) image_width = actual_width image_height = actual_height # Build list of tasks: one per valid line (crop and validate on main thread) _start_task_build = time.perf_counter() tasks = [] for i in range(len(line_data["text"])): line_text = line_data["text"][i] line_conf = line_data["conf"][i] if "model" in line_data and len(line_data["model"]) > i: line_model = line_data["model"][i] else: line_model = "Paddle" f_left = float(line_data["left"][i]) f_top = float(line_data["top"][i]) f_width = float(line_data["width"][i]) f_height = float(line_data["height"][i]) is_normalized = ( f_left <= 1.0 and f_top <= 1.0 and f_width <= 1.0 and f_height <= 1.0 ) if is_normalized: line_left = float(round(f_left * image_width)) line_top = float(round(f_top * image_height)) line_width = float(round(f_width * image_width)) line_height = float(round(f_height * image_height)) else: line_left = float(round(f_left)) line_top = float(round(f_top)) line_width = float(round(f_width)) line_height = float(round(f_height)) if not line_text.strip(): continue line_left = int(max(0, min(line_left, image_width - 1))) line_top = int(max(0, min(line_top, image_height - 1))) line_width = int(max(1, min(line_width, image_width - line_left))) line_height = int(max(1, min(line_height, image_height - line_top))) if line_left >= image_width or line_top >= image_height: continue if line_left + line_width > image_width: line_width = image_width - line_left if line_top + line_height > image_height: line_height = image_height - line_top if line_width <= 0 or line_height <= 0: continue try: line_image = image_np[ line_top : line_top + line_height, line_left : line_left + line_width, ].copy() except IndexError: continue if line_image.size == 0 or len(line_image.shape) < 2: continue tasks.append( ( i, line_image, line_text, line_conf, line_model, line_left, line_top, line_width, line_height, ) ) if not tasks: return output # Timing 
hooks removed (test-only). # Process lines in parallel. Dedicated worker cap is safer for this CPU-heavy path. max_workers = min(LINE_TO_WORD_SEGMENT_MAX_WORKERS, len(tasks)) # Timing hooks removed (test-only). process_one = partial( CustomImageAnalyzerEngine._process_one_line_to_words, output_folder=self.output_folder, image_name=image_name, thread_local_segmenter=threading.local(), ) if max_workers <= 1: results = [process_one(task) for task in tasks] else: with ThreadPoolExecutor(max_workers=max_workers) as executor: results = list(executor.map(process_one, tasks)) # Timing hooks removed (test-only). # Merge results in line order to preserve document order # Timing hooks removed (test-only). for _i, word_dict in sorted(results, key=lambda x: x[0]): for key in output: output[key].extend(word_dict[key]) # Timing hooks removed (test-only). return output def _visualize_tesseract_bounding_boxes( self, image: Image.Image, ocr_data: Dict[str, List], image_name: str = None, visualisation_folder: str = "tesseract_visualisations", ) -> None: """ Visualizes Tesseract OCR bounding boxes with confidence-based colors and a legend. 
Args: image: The PIL Image object ocr_data: Tesseract OCR data dictionary image_name: Optional name for the saved image file """ if not ocr_data or not ocr_data.get("text"): return # Convert PIL image to OpenCV format image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # Get image dimensions height, width = image_cv.shape[:2] # Define confidence ranges and colors confidence_ranges = [ (80, 100, (0, 255, 0), "High (80-100%)"), # Green (50, 79, (0, 165, 255), "Medium (50-79%)"), # Orange (0, 49, (0, 0, 255), "Low (0-49%)"), # Red ] # Process each detected text element for i in range(len(ocr_data["text"])): text = ocr_data["text"][i] conf = int(ocr_data["conf"][i]) # Skip empty text or invalid confidence if not text.strip() or conf == -1: continue left = ocr_data["left"][i] top = ocr_data["top"][i] width_box = ocr_data["width"][i] height_box = ocr_data["height"][i] # Calculate bounding box coordinates x1 = int(left) y1 = int(top) x2 = int(left + width_box) y2 = int(top + height_box) # Ensure coordinates are within image bounds x1 = max(0, min(x1, width)) y1 = max(0, min(y1, height)) x2 = max(0, min(x2, width)) y2 = max(0, min(y2, height)) # Skip if bounding box is invalid if x2 <= x1 or y2 <= y1: continue # Determine color based on confidence score color = (0, 0, 255) # Default to red for min_conf, max_conf, conf_color, _ in confidence_ranges: if min_conf <= conf <= max_conf: color = conf_color break # Draw bounding box cv2.rectangle(image_cv, (x1, y1), (x2, y2), color, 1) # Add legend self._add_confidence_legend(image_cv, confidence_ranges) # Save the visualization tesseract_viz_folder = os.path.join(self.output_folder, visualisation_folder) # Double-check the constructed path is safe if not validate_folder_containment(tesseract_viz_folder, OUTPUT_FOLDER): raise ValueError( f"Unsafe tesseract visualisations folder path: {tesseract_viz_folder}" ) os.makedirs(tesseract_viz_folder, exist_ok=True) # Generate filename if image_name: # Remove file extension if 
present base_name = os.path.splitext(image_name)[0] filename = f"{base_name}_{visualisation_folder}.jpg" else: timestamp = int(time.time()) filename = f"{visualisation_folder}_{timestamp}.jpg" output_path = os.path.join(tesseract_viz_folder, filename) # Save the image max_filesize = 500 * 1024 # 500kb in bytes quality = 95 # Start high, OpenCV JPEG quality range is 0-100 # Try lowering JPEG quality until file is below size limit is_saved = False while quality >= 10: cv2.imwrite(output_path, image_cv, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) if ( os.path.exists(output_path) and os.path.getsize(output_path) <= max_filesize ): is_saved = True break quality -= 5 if not is_saved: # Save as lowest acceptable quality if cannot get under 500kb, or raise warning cv2.imwrite(output_path, image_cv, [int(cv2.IMWRITE_JPEG_QUALITY), 10]) # Optionally log warning here that file could not be compressed below 500kb print(f"Tesseract visualization saved to: {output_path}") def _add_confidence_legend( self, image_cv: np.ndarray, confidence_ranges: List[Tuple] ) -> None: """ Adds a confidence legend to the visualization image. 
Args: image_cv: OpenCV image array confidence_ranges: List of tuples containing (min_conf, max_conf, color, label) """ height, width = image_cv.shape[:2] # Legend parameters legend_width = 200 legend_height = 100 legend_x = width - legend_width - 20 legend_y = 20 # Draw legend background cv2.rectangle( image_cv, (legend_x, legend_y), (legend_x + legend_width, legend_y + legend_height), (255, 255, 255), # White background -1, ) cv2.rectangle( image_cv, (legend_x, legend_y), (legend_x + legend_width, legend_y + legend_height), (0, 0, 0), # Black border 2, ) # Add title title_text = "Confidence Levels" font_scale = 0.6 font_thickness = 2 (title_width, title_height), _ = cv2.getTextSize( title_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness ) title_x = legend_x + (legend_width - title_width) // 2 title_y = legend_y + title_height + 10 cv2.putText( image_cv, title_text, (title_x, title_y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), # Black text font_thickness, ) # Add confidence range items item_spacing = 25 start_y = title_y + 25 for i, (min_conf, max_conf, color, label) in enumerate(confidence_ranges): item_y = start_y + i * item_spacing # Draw color box box_size = 15 box_x = legend_x + 10 box_y = item_y - box_size cv2.rectangle( image_cv, (box_x, box_y), (box_x + box_size, box_y + box_size), color, -1, ) cv2.rectangle( image_cv, (box_x, box_y), (box_x + box_size, box_y + box_size), (0, 0, 0), # Black border 1, ) # Add label text label_x = box_x + box_size + 10 label_y = item_y - 5 cv2.putText( image_cv, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), # Black text 1, ) # Calculate line-level bounding boxes and average confidence def _calculate_line_bbox(self, group): # Get the leftmost and rightmost positions left = group["left"].min() top = group["top"].min() right = (group["left"] + group["width"]).max() bottom = (group["top"] + group["height"]).max() # Calculate width and height width = right - left height = bottom - top # 
Calculate average confidence avg_conf = round(group["conf"].mean(), 0) return pd.Series( { "text": " ".join(group["text"].astype(str).tolist()), "left": left, "top": top, "width": width, "height": height, "conf": avg_conf, } ) def _perform_hybrid_ocr( self, image: Image.Image, confidence_threshold: int = HYBRID_OCR_CONFIDENCE_THRESHOLD, padding: int = HYBRID_OCR_PADDING, ocr: Optional[Any] = None, image_name: str = "unknown_image_name", ) -> Dict[str, list]: """ Performs hybrid OCR on an image using Tesseract for initial OCR and PaddleOCR/VLM to enhance results for low-confidence or uncertain words. Args: image (Image.Image): The input image (PIL format) to be processed. confidence_threshold (int, optional): Tesseract confidence threshold below which words are re-analyzed with secondary OCR (PaddleOCR/VLM). Defaults to HYBRID_OCR_CONFIDENCE_THRESHOLD. padding (int, optional): Pixel padding (in all directions) to add around each word box when cropping for secondary OCR. Defaults to HYBRID_OCR_PADDING. ocr (Optional[Any], optional): An instance of the PaddleOCR or VLM engine. If None, will use the instance's `paddle_ocr` attribute if available. Only necessary for PaddleOCR-based pipelines. image_name (str, optional): Optional name of the image, useful for debugging and visualization. Returns: Dict[str, list]: OCR results in the dictionary format of pytesseract.image_to_data (keys: 'text', 'left', 'top', 'width', 'height', 'conf', 'model', ...). """ # Determine if we're using VLM or PaddleOCR use_vlm = self.ocr_engine == "hybrid-vlm" if not use_vlm: if ocr is None: if hasattr(self, "paddle_ocr") and self.paddle_ocr is not None: ocr = self.paddle_ocr elif not SPACES_ZERO_GPU: raise ValueError( "No OCR object provided and 'paddle_ocr' is not initialized." ) # print("Starting hybrid OCR process...") # 1. 
Get initial word-level results from Tesseract tesseract_data = pytesseract.image_to_data( image, output_type=pytesseract.Output.DICT, config=self.tesseract_config, lang=self.tesseract_lang, ) if TESSERACT_WORD_LEVEL_OCR is False: ocr_df = pd.DataFrame(tesseract_data) # Filter out invalid entries (confidence == -1) ocr_df = ocr_df[ocr_df.conf != -1] # Group by line and aggregate text line_groups = ocr_df.groupby(["block_num", "par_num", "line_num"]) ocr_data = line_groups.apply(self._calculate_line_bbox).reset_index() # Overwrite tesseract_data with the aggregated data tesseract_data = { "text": ocr_data["text"].tolist(), "left": ocr_data["left"].astype(int).tolist(), "top": ocr_data["top"].astype(int).tolist(), "width": ocr_data["width"].astype(int).tolist(), "height": ocr_data["height"].astype(int).tolist(), "conf": ocr_data["conf"].tolist(), "model": ["Tesseract"] * len(ocr_data), # Add model field } final_data = { "text": list(), "left": list(), "top": list(), "width": list(), "height": list(), "conf": list(), "model": list(), # Track which model was used for each word } num_words = len(tesseract_data["text"]) # This handles the "no text on page" case. If num_words is 0, the loop is skipped # and an empty dictionary with empty lists is returned, which is the correct behavior. 
for i in range(num_words): text = tesseract_data["text"][i] conf = int(tesseract_data["conf"][i]) # Skip empty text boxes or non-word elements (like page/block markers) if not text.strip() or conf == -1: continue left = tesseract_data["left"][i] top = tesseract_data["top"][i] width = tesseract_data["width"][i] height = tesseract_data["height"][i] # line_number = tesseract_data['abs_line_id'][i] # Initialize model as Tesseract (default) model_used = "Tesseract" # If confidence is low, use PaddleOCR for a second opinion if conf <= confidence_threshold: img_width, img_height = image.size crop_left = max(0, left - padding) crop_top = max(0, top - padding) crop_right = min(img_width, left + width + padding) crop_bottom = min(img_height, top + height + padding) # Ensure crop dimensions are valid if crop_right <= crop_left or crop_bottom <= crop_top: continue # Skip invalid crops cropped_image = image.crop( (crop_left, crop_top, crop_right, crop_bottom) ) if use_vlm: # Use VLM for OCR vlm_result = _vlm_ocr_predict(cropped_image) rec_texts = vlm_result.get("rec_texts", []) rec_scores = vlm_result.get("rec_scores", []) else: # Use PaddleOCR cropped_image_np = np.array(cropped_image) if len(cropped_image_np.shape) == 2: cropped_image_np = np.stack([cropped_image_np] * 3, axis=-1) # paddle_results = ocr.predict(cropped_image_np) paddle_results = paddle_predict(cropped_image_np) if paddle_results and paddle_results[0]: rec_texts = paddle_results[0].get("rec_texts", []) rec_scores = paddle_results[0].get("rec_scores", []) else: rec_texts = [] rec_scores = [] if rec_texts and rec_scores: new_text = " ".join(rec_texts) new_conf = int(round(np.median(rec_scores) * 100, 0)) # Only replace if Paddle's/VLM's confidence is better if new_conf >= conf: ocr_type = "VLM" if use_vlm else "Paddle" message_output = f" Re-OCR'd word: '{text}' (conf: {conf}) -> '{new_text}' (conf: {new_conf:.0f}) [{ocr_type}]" print(message_output) if REPORT_VLM_OUTPUTS_TO_GUI: try: gr.Info(message_output, 
duration=2) except Exception: # gr.Info may not be available in worker process, ignore pass # For exporting example image comparisons, not used here safe_filename = self._create_safe_filename_with_confidence( text, new_text, conf, new_conf, ocr_type ) if SAVE_EXAMPLE_HYBRID_IMAGES: # Normalize and validate image_name to prevent path traversal attacks normalized_image_name = os.path.normpath( image_name + "_" + ocr_type ) # Ensure the image name doesn't contain path traversal characters if ( ".." in normalized_image_name or "/" in normalized_image_name or "\\" in normalized_image_name ): normalized_image_name = ( "safe_image" # Fallback to safe default ) hybrid_ocr_examples_folder = ( self.output_folder + f"/hybrid_ocr_examples/{normalized_image_name}" ) # Validate the constructed path is safe before creating directories if not validate_folder_containment( hybrid_ocr_examples_folder, OUTPUT_FOLDER ): raise ValueError( f"Unsafe hybrid_ocr_examples folder path: {hybrid_ocr_examples_folder}" ) if not os.path.exists(hybrid_ocr_examples_folder): os.makedirs(hybrid_ocr_examples_folder) output_image_path = ( hybrid_ocr_examples_folder + f"/{safe_filename}.png" ) print(f"Saving example image to {output_image_path}") _save_image_with_config_dpi( cropped_image, output_image_path ) text = new_text conf = new_conf model_used = ocr_type # Update model to VLM or Paddle else: ocr_type = "VLM" if use_vlm else "Paddle" print( f" '{text}' (conf: {conf}) -> {ocr_type} result '{new_text}' (conf: {new_conf:.0f}) was not better. Keeping original." ) else: # OCR ran but found nothing, discard original word ocr_type = "VLM" if use_vlm else "Paddle" print( f" '{text}' (conf: {conf}) -> No text found by {ocr_type}. Discarding." 
) text = "" # Append the final result (either original, replaced, or skipped if empty) if text.strip(): final_data["text"].append( clean_unicode_text(text, preserve_international_scripts=True) ) final_data["left"].append(left) final_data["top"].append(top) final_data["width"].append(width) final_data["height"].append(height) final_data["conf"].append(int(conf)) final_data["model"].append(model_used) # final_data['line_number'].append(int(line_number)) return final_data def _perform_hybrid_paddle_vlm_ocr( self, image: Image.Image, ocr: Optional[Any] = None, paddle_results: List[Any] = None, confidence_threshold: int = HYBRID_OCR_CONFIDENCE_THRESHOLD, padding: int = HYBRID_OCR_PADDING, image_name: str = "unknown_image_name", input_image_width: int = None, input_image_height: int = None, ) -> List[Any]: """ Performs OCR using PaddleOCR at line level, then VLM for low-confidence lines. Returns modified paddle_results in the same format as PaddleOCR output. Args: image: PIL Image to process ocr: PaddleOCR instance (optional, uses self.paddle_ocr if not provided) paddle_results: PaddleOCR results in original format (List of dicts with rec_texts, rec_scores, rec_polys) confidence_threshold: Confidence threshold below which VLM is used padding: Padding to add around line crops image_name: Name of the image for logging/debugging input_image_width: Original image width (before preprocessing) input_image_height: Original image height (before preprocessing) Returns: Modified paddle_results with VLM replacements for low-confidence lines """ if ocr is None: if hasattr(self, "paddle_ocr") and self.paddle_ocr is not None: ocr = self.paddle_ocr elif not SPACES_ZERO_GPU: raise ValueError( "No OCR object provided and 'paddle_ocr' is not initialized." 
) if paddle_results is None or not paddle_results: return paddle_results print("Starting hybrid PaddleOCR + VLM OCR process...") # Get image dimensions img_width, img_height = image.size # Use original dimensions if provided, otherwise use current image dimensions if input_image_width is None: input_image_width = img_width if input_image_height is None: input_image_height = img_height # Convert PaddleOCR result objects to plain dictionaries for pickling # The @spaces.GPU decorator requires picklable arguments, but PaddleOCR # result objects contain CopyableWeakMethod references that can't be pickled copied_paddle_results = [ _paddle_result_to_plain_dict(result) for result in paddle_results ] modified_paddle_results = _process_page_result_with_hybrid_vlm_ocr( copied_paddle_results, image, img_width, img_height, input_image_width, input_image_height, confidence_threshold, image_name, self.output_folder, padding, ) return modified_paddle_results def _perform_hybrid_paddle_inference_server_ocr( self, image: Image.Image, ocr: Optional[Any] = None, paddle_results: List[Any] = None, confidence_threshold: int = HYBRID_OCR_CONFIDENCE_THRESHOLD, padding: int = HYBRID_OCR_PADDING, image_name: str = "unknown_image_name", input_image_width: int = None, input_image_height: int = None, model_name: str = None, ) -> List[Any]: """ Performs OCR using PaddleOCR at line level, then inference-server API for low-confidence lines. Returns modified paddle_results in the same format as PaddleOCR output. 
Args: image: PIL Image to process ocr: PaddleOCR instance (optional, uses self.paddle_ocr if not provided) paddle_results: PaddleOCR results in original format (List of dicts with rec_texts, rec_scores, rec_polys) confidence_threshold: Confidence threshold below which inference-server is used padding: Padding to add around line crops image_name: Name of the image for logging/debugging input_image_width: Original image width (before preprocessing) input_image_height: Original image height (before preprocessing) model_name: Name of the inference-server model to use Returns: Modified paddle_results with inference-server replacements for low-confidence lines """ if ocr is None: if hasattr(self, "paddle_ocr") and self.paddle_ocr is not None: ocr = self.paddle_ocr elif not SPACES_ZERO_GPU: raise ValueError( "No OCR object provided and 'paddle_ocr' is not initialized." ) if paddle_results is None or not paddle_results: return paddle_results print("Starting hybrid PaddleOCR + Inference-server OCR process...") # Get image dimensions img_width, img_height = image.size # Use original dimensions if provided, otherwise use current image dimensions if input_image_width is None: input_image_width = img_width if input_image_height is None: input_image_height = img_height # Create a deep copy of paddle_results to modify copied_paddle_results = copy.deepcopy(paddle_results) def _normalize_paddle_result_lists(rec_texts, rec_scores, rec_polys): """ Normalizes PaddleOCR result lists to ensure they all have the same length. 
Pads missing entries with appropriate defaults: - rec_texts: empty string "" - rec_scores: 0.0 (low confidence) - rec_polys: empty list [] Args: rec_texts: List of recognized text strings rec_scores: List of confidence scores rec_polys: List of bounding box polygons Returns: Tuple of (normalized_rec_texts, normalized_rec_scores, normalized_rec_polys, max_length) """ len_texts = len(rec_texts) len_scores = len(rec_scores) len_polys = len(rec_polys) max_length = max(len_texts, len_scores, len_polys) # Only normalize if there's a mismatch if max_length > 0 and ( len_texts != max_length or len_scores != max_length or len_polys != max_length ): print( f"Warning: List length mismatch detected - rec_texts: {len_texts}, " f"rec_scores: {len_scores}, rec_polys: {len_polys}. " f"Padding to length {max_length}." ) # Pad rec_texts if len_texts < max_length: rec_texts = list(rec_texts) + [""] * (max_length - len_texts) # Pad rec_scores if len_scores < max_length: rec_scores = list(rec_scores) + [0.0] * (max_length - len_scores) # Pad rec_polys if len_polys < max_length: rec_polys = list(rec_polys) + [[]] * (max_length - len_polys) return rec_texts, rec_scores, rec_polys, max_length def _process_page_result_with_hybrid_inference_server_ocr( page_results: list, image: Image.Image, img_width: int, img_height: int, input_image_width: int, input_image_height: int, confidence_threshold: float, image_name: str, instance_self: object, padding: int = 0, ): """ Processes OCR page results using a hybrid system that combines PaddleOCR for initial recognition and an inference server for low-confidence lines. When PaddleOCR's recognition confidence for a detected line is below the specified threshold, the line is re-processed using a higher-quality (but slower) server model and the result is used to replace the low-confidence recognition. Results are kept in PaddleOCR's standard output format for downstream compatibility. 
Args: page_results (list): The list of page result dicts from PaddleOCR to process. Each dict should contain keys like 'rec_texts', 'rec_scores', 'rec_polys', and optionally 'image_width', 'image_height', and 'rec_models'. image (PIL.Image.Image): The PIL Image object of the full page to allow line cropping. img_width (int): The width of the (possibly preprocessed) image in pixels. img_height (int): The height of the (possibly preprocessed) image in pixels. input_image_width (int): The original image width (before any resizing/preprocessing). input_image_height (int): The original image height (before any resizing/preprocessing). confidence_threshold (float): Lines recognized by PaddleOCR with confidence lower than this threshold will be replaced using the inference server. image_name (str): The name of the source image, used for logging/debugging. instance_self (object): The enclosing class instance to access inference invocation. padding (int): Padding to add around line crops. Returns: None. Modifies page_results in place with higher-confidence text replacements when possible. 
""" # Process each page result in paddle_results for page_result in page_results: # Extract text recognition results from the paddle format rec_texts = page_result.get("rec_texts", list()) rec_scores = page_result.get("rec_scores", list()) rec_polys = page_result.get("rec_polys", list()) # Normalize lists to ensure they all have the same length rec_texts, rec_scores, rec_polys, num_lines = ( _normalize_paddle_result_lists(rec_texts, rec_scores, rec_polys) ) # Update page_result with normalized lists page_result["rec_texts"] = rec_texts page_result["rec_scores"] = rec_scores page_result["rec_polys"] = rec_polys # Initialize rec_models list with "Paddle" as default for all lines if ( "rec_models" not in page_result or len(page_result.get("rec_models", [])) != num_lines ): rec_models = ["Paddle"] * num_lines page_result["rec_models"] = rec_models else: rec_models = page_result["rec_models"] # Since we're using the exact image PaddleOCR processed, coordinates are directly in image space # No coordinate conversion needed - coordinates match the image dimensions exactly # Process each line for i in range(num_lines): line_text = rec_texts[i] line_conf = float(rec_scores[i]) * 100 # Convert to percentage bounding_box = rec_polys[i] # Skip if bounding box is empty (from padding) # Handle numpy arrays, lists, and None values safely if bounding_box is None: print( f"Current line {i + 1} of {num_lines}: Bounding box is None" ) continue # Convert to list first to handle numpy arrays safely if hasattr(bounding_box, "tolist"): box = bounding_box.tolist() else: box = bounding_box # Check if box is empty (handles both list and numpy array cases) if not box or (isinstance(box, list) and len(box) == 0): print(f"Current line {i + 1} of {num_lines}: Box is empty") continue # Skip empty lines if not line_text.strip(): print( f"Current line {i + 1} of {num_lines}: Line text is empty" ) continue # Convert polygon to bounding box x_coords = [p[0] for p in box] y_coords = [p[1] for p in 
box] line_left_paddle = float(min(x_coords)) line_top_paddle = float(min(y_coords)) line_right_paddle = float(max(x_coords)) line_bottom_paddle = float(max(y_coords)) line_width_paddle = line_right_paddle - line_left_paddle line_height_paddle = line_bottom_paddle - line_top_paddle # Since we're using the exact image PaddleOCR processed, coordinates are already in image space line_left = line_left_paddle line_top = line_top_paddle line_width = line_width_paddle line_height = line_height_paddle # Count words in PaddleOCR output paddle_words = line_text.split() paddle_word_count = len(paddle_words) # If confidence is low, use inference-server for a second opinion if line_conf <= confidence_threshold: # Ensure minimum line height for inference-server processing min_line_height = max( line_height, 20 ) # Minimum 20 pixels for text line # Calculate crop coordinates with padding # Convert floats to integers and apply padding, clamping to image bounds crop_left = max(0, int(round(line_left - padding))) crop_top = max(0, int(round(line_top - padding))) crop_right = min( img_width, int(round(line_left + line_width + padding)) ) crop_bottom = min( img_height, int(round(line_top + min_line_height + padding)) ) # Ensure crop dimensions are valid if crop_right <= crop_left or crop_bottom <= crop_top: # Invalid crop, keep original PaddleOCR result print( f"Current line {i + 1} of {num_lines}: Invalid crop, keeping original PaddleOCR result" ) continue # Crop the line image cropped_image = image.crop( (crop_left, crop_top, crop_right, crop_bottom) ) # Check if cropped image is too small for inference-server processing crop_width = crop_right - crop_left crop_height = crop_bottom - crop_top if crop_width < 10 or crop_height < 10: # Keep original PaddleOCR result for this line print( f"Current line {i + 1} of {num_lines}: Cropped image is too small, keeping original PaddleOCR result" ) continue # Ensure cropped image is in RGB mode before passing to inference-server if 
cropped_image.mode != "RGB": cropped_image = cropped_image.convert("RGB") # Match hybrid local VLM: resize/DPI budget then aspect-pad before API try: prepared_for_inference = _prepare_hybrid_line_crop_for_vlm( cropped_image ) except Exception as prep_err: print( f"Current line {i + 1} of {num_lines}: " f"Could not prepare image for inference server: {prep_err}" ) continue # Save the same pixels sent to the API when debugging if SAVE_VLM_INPUT_IMAGES: try: inference_server_debug_dir = os.path.join( self.output_folder, "hybrid_paddle_inference_server_visualisations/hybrid_analysis_input_images", ) os.makedirs(inference_server_debug_dir, exist_ok=True) line_text_safe = safe_sanitize_text(line_text) line_text_shortened = line_text_safe[:20] image_name_safe = safe_sanitize_text(image_name) image_name_shortened = image_name_safe[:20] filename = f"{image_name_shortened}_{line_text_shortened}_hybrid_analysis_input_image.png" filepath = os.path.join( inference_server_debug_dir, filename ) _save_image_with_config_dpi( prepared_for_inference, filepath ) except Exception as save_error: print( f"Warning: Could not save inference-server input image: {save_error}" ) # Use inference-server for OCR on this line with error handling inference_server_result = None inference_server_rec_texts = [] inference_server_rec_scores = [] print( f" Line {i + 1}/{num_lines}: Sending to inference server " f"(Paddle conf: {line_conf:.1f}%, words: {paddle_word_count})" ) try: inference_server_result = _inference_server_ocr_predict( prepared_for_inference, model_name=model_name, image_hybrid_line_prepared=True, ) inference_server_rec_texts = ( inference_server_result.get("rec_texts", []) if inference_server_result else [] ) inference_server_rec_scores = ( inference_server_result.get("rec_scores", []) if inference_server_result else [] ) except Exception as e: print( f"Current line {i + 1} of {num_lines}: Error in inference-server OCR: {e}" ) # Ensure we keep original PaddleOCR result on error 
inference_server_rec_texts = [] inference_server_rec_scores = [] if not ( inference_server_rec_texts and inference_server_rec_scores ): # Inference server returned empty or no results - keep Paddle print( f" Line {i + 1}/{num_lines}: Inference server returned no results " f"(Paddle conf: {line_conf:.1f}%, text: '{line_text[:40]}{'...' if len(line_text) > 40 else ''}'), keeping Paddle result." ) if inference_server_rec_texts and inference_server_rec_scores: # Combine inference-server words into a single text string inference_server_text = " ".join(inference_server_rec_texts) ### If text starts with "Cannot read", then skip this line if inference_server_text.startswith('""'): print( "Inference server text starts with '" "', skipping line {i + 1} of {num_lines}" ) continue inference_server_word_count = len( inference_server_rec_texts ) inference_server_conf = float( np.median(inference_server_rec_scores) ) # Keep as 0-1 range for paddle format # Only replace if word counts match word_count_allowed_difference = 7 if ( inference_server_word_count - paddle_word_count <= word_count_allowed_difference and inference_server_word_count - paddle_word_count >= -word_count_allowed_difference ): message_output = ( f" Re-OCR'd line: '{line_text}' (conf: {line_conf:.1f}, words: {paddle_word_count}) " f"-> '{inference_server_text}' (conf: {inference_server_conf*100:.1f}, words: {inference_server_word_count}) [Inference Server]" ) print(message_output) if REPORT_VLM_OUTPUTS_TO_GUI: try: gr.Info(message_output, duration=2) except Exception: # gr.Info may not be available in worker process, ignore pass # For exporting example image comparisons safe_filename = ( instance_self._create_safe_filename_with_confidence( line_text, inference_server_text, int(line_conf), int(inference_server_conf * 100), "Inference Server", ) ) if SAVE_EXAMPLE_HYBRID_IMAGES: # Normalize and validate image_name to prevent path traversal attacks normalized_image_name = os.path.normpath( image_name + 
"_hybrid_paddle_inference_server" ) if ( ".." in normalized_image_name or "/" in normalized_image_name or "\\" in normalized_image_name ): normalized_image_name = "safe_image" hybrid_ocr_examples_folder = ( instance_self.output_folder + f"/hybrid_ocr_examples/{normalized_image_name}" ) # Validate the constructed path is safe if not validate_folder_containment( hybrid_ocr_examples_folder, OUTPUT_FOLDER ): raise ValueError( f"Unsafe hybrid_ocr_examples folder path: {hybrid_ocr_examples_folder}" ) if not os.path.exists(hybrid_ocr_examples_folder): os.makedirs(hybrid_ocr_examples_folder) output_image_path = ( hybrid_ocr_examples_folder + f"/{safe_filename}.png" ) _save_image_with_config_dpi( prepared_for_inference, output_image_path ) # Replace with inference-server result in paddle_results format # Update rec_texts, rec_scores, and rec_models for this line rec_texts[i] = inference_server_text rec_scores[i] = inference_server_conf rec_models[i] = "Inference Server" # Ensure page_result is updated with the modified rec_models list page_result["rec_models"] = rec_models else: print( f" Line: '{line_text}' (conf: {line_conf:.1f}, words: {paddle_word_count}) -> " f"Inference-server result '{inference_server_text}' (conf: {inference_server_conf*100:.1f}, words: {inference_server_word_count}) " f"word count mismatch. Keeping PaddleOCR result." 
) else: # Inference-server returned empty or no results - keep original PaddleOCR result if line_conf <= confidence_threshold: pass return page_results modified_paddle_results = _process_page_result_with_hybrid_inference_server_ocr( copied_paddle_results, image, img_width, img_height, input_image_width, input_image_height, confidence_threshold, image_name, self, padding, ) return modified_paddle_results def perform_ocr( self, image: Union[str, Image.Image, np.ndarray], ocr: Optional[Any] = None, bedrock_runtime=None, gemini_client=None, gemini_config=None, azure_openai_client=None, vlm_model_choice: str = None, inference_server_model_name: str = None, page_index_0: Optional[int] = None, ) -> Tuple[List[OCRResult], int, int, str]: """ Performs OCR on the given image using the configured engine. page_index_0: 0-based page index for VLM prompt/response log filenames when the basename does not encode the page (optional). """ if isinstance(image, str): image_path = image image_name = os.path.basename(image) image = Image.open(image) elif isinstance(image, np.ndarray): image = Image.fromarray(image) image_path = "" image_name = "unknown_image_name" # Pre-process image # Store original dimensions BEFORE preprocessing (needed for coordinate conversion) original_image_width = None original_image_height = None original_image_for_visualization = ( None # Store original image for visualization ) if PREPROCESS_LOCAL_OCR_IMAGES: # print("Pre-processing image...") # Get original dimensions before preprocessing original_image_width, original_image_height = image.size # Store original image for visualization (coordinates are in original space) original_image_for_visualization = image.copy() image, preprocessing_metadata = self.image_preprocessor.preprocess_image( image ) # Only export preprocessed images when they are actually used as OCR input. # Full-page VLM-style OCR paths use the original image for coordinate consistency. 
save_preprocessed_for_engine = self.ocr_engine not in ( "vlm", "inference-server", "bedrock-vlm", "gemini-vlm", "azure-openai-vlm", ) if SAVE_PREPROCESS_IMAGES and save_preprocessed_for_engine: # print("Saving pre-processed image...") image_basename = os.path.basename(image_name) output_path = os.path.join( self.output_folder, "preprocessed_images", image_basename + "_preprocessed_image.png", ) os.makedirs(os.path.dirname(output_path), exist_ok=True) _save_image_with_config_dpi(image, output_path) # print(f"Pre-processed image saved to {output_path}") else: preprocessing_metadata = dict() original_image_width, original_image_height = image.size # When preprocessing is disabled, the current image is the original original_image_for_visualization = image.copy() image_width, image_height = image.size # Store original image for line-to-word conversion when PaddleOCR processes original image original_image_for_cropping = None paddle_processed_original = False # Note: In testing I haven't seen that this necessarily improves results if self.ocr_engine == "hybrid-paddle": try: pass except Exception as e: raise ImportError( f"Error importing PaddleOCR: {e}. Please install it using 'pip install paddleocr paddlepaddle' in your python environment and retry." 
) # Try hybrid with original image for cropping: ocr_data = self._perform_hybrid_ocr(image, image_name=image_name) elif self.ocr_engine == "hybrid-vlm": # Try hybrid VLM with original image for cropping: ocr_data = self._perform_hybrid_ocr(image, image_name=image_name) # Initialize VLM token tracking variables vlm_total_input_tokens = 0 vlm_total_output_tokens = 0 vlm_model_name = "" if self.ocr_engine == "vlm": # VLM page-level OCR - sends whole page to VLM and gets structured line-level results # Use original image (before preprocessing) for VLM since coordinates should be in original space vlm_image = ( original_image_for_visualization if original_image_for_visualization is not None else image ) ocr_data, vlm_input_tokens, vlm_output_tokens, vlm_model_name = ( _vlm_page_ocr_predict( vlm_image, image_name=image_name, output_folder=self.output_folder, page_index_0=page_index_0, ) ) vlm_total_input_tokens = vlm_input_tokens vlm_total_output_tokens = vlm_output_tokens # VLM returns data already in the expected format, so no conversion needed elif self.ocr_engine == "inference-server": # Inference-server page-level OCR - sends whole page to inference-server API and gets structured line-level results # Use original image (before preprocessing) for inference-server since coordinates should be in original space inference_server_image = ( original_image_for_visualization if original_image_for_visualization is not None else image ) ocr_data, vlm_input_tokens, vlm_output_tokens, vlm_model_name = ( _inference_server_page_ocr_predict( inference_server_image, image_name=image_name, normalised_coords_range=999, output_folder=self.output_folder, model_name=inference_server_model_name, page_index_0=page_index_0, ) ) vlm_total_input_tokens = vlm_input_tokens vlm_total_output_tokens = vlm_output_tokens # Inference-server returns data already in the expected format, so no conversion needed elif self.ocr_engine == "bedrock-vlm": # Bedrock page-level OCR - sends whole page to Bedrock 
API and gets structured line-level results # Use original image (before preprocessing) for Bedrock since coordinates should be in original space bedrock_image = ( original_image_for_visualization if original_image_for_visualization is not None else image ) # Get model choice from parameter or config from tools.config import CLOUD_VLM_MODEL_CHOICE model_choice = ( vlm_model_choice if vlm_model_choice else CLOUD_VLM_MODEL_CHOICE ) # Full-page VLM prompt instructs all models to use 0-999 coordinates; convert for any model_choice normalised_coords_range = 999 ocr_data, vlm_input_tokens, vlm_output_tokens, vlm_model_name = ( _bedrock_page_ocr_predict( bedrock_image, image_name=image_name, normalised_coords_range=normalised_coords_range, output_folder=self.output_folder, model_choice=model_choice, bedrock_runtime=bedrock_runtime, page_index_0=page_index_0, ) ) vlm_total_input_tokens = vlm_input_tokens vlm_total_output_tokens = vlm_output_tokens # Bedrock returns data already in the expected format, so no conversion needed elif self.ocr_engine == "gemini-vlm": # Gemini page-level OCR - sends whole page to Gemini API and gets structured line-level results # Use original image (before preprocessing) for Gemini since coordinates should be in original space gemini_image = ( original_image_for_visualization if original_image_for_visualization is not None else image ) # Get model choice from parameter or config from tools.config import CLOUD_VLM_MODEL_CHOICE model_choice = ( vlm_model_choice if vlm_model_choice else CLOUD_VLM_MODEL_CHOICE ) ocr_data, vlm_input_tokens, vlm_output_tokens, vlm_model_name = ( _gemini_page_ocr_predict( gemini_image, image_name=image_name, normalised_coords_range=999, # Full-page prompt uses 0-999 coordinates output_folder=self.output_folder, model_choice=model_choice, client=gemini_client, config=gemini_config, page_index_0=page_index_0, ) ) vlm_total_input_tokens = vlm_input_tokens vlm_total_output_tokens = vlm_output_tokens # Gemini returns data 
already in the expected format, so no conversion needed elif self.ocr_engine == "azure-openai-vlm": # Azure/OpenAI page-level OCR - sends whole page to Azure/OpenAI API and gets structured line-level results # Use original image (before preprocessing) for Azure/OpenAI since coordinates should be in original space azure_image = ( original_image_for_visualization if original_image_for_visualization is not None else image ) # Get model choice from parameter or config from tools.config import CLOUD_VLM_MODEL_CHOICE model_choice = ( vlm_model_choice if vlm_model_choice else CLOUD_VLM_MODEL_CHOICE ) ocr_data, vlm_input_tokens, vlm_output_tokens, vlm_model_name = ( _azure_openai_page_ocr_predict( azure_image, image_name=image_name, normalised_coords_range=999, # Full-page prompt uses 0-999 coordinates output_folder=self.output_folder, model_choice=model_choice, client=azure_openai_client, page_index_0=page_index_0, ) ) vlm_total_input_tokens = vlm_input_tokens vlm_total_output_tokens = vlm_output_tokens # Azure/OpenAI returns data already in the expected format, so no conversion needed elif self.ocr_engine == "tesseract": tesseract_cfg = _ensure_tessdata_available_in_env(self.tesseract_config) ocr_data = pytesseract.image_to_data( image, output_type=pytesseract.Output.DICT, config=tesseract_cfg, lang=self.tesseract_lang, # Ensure the Tesseract language data (e.g., fra.traineddata) is installed on your system. 
) if TESSERACT_WORD_LEVEL_OCR is False: ocr_df = pd.DataFrame(ocr_data) # Filter out invalid entries (confidence == -1) ocr_df = ocr_df[ocr_df.conf != -1] # Group by line and aggregate text line_groups = ocr_df.groupby(["block_num", "par_num", "line_num"]) ocr_data = line_groups.apply(self._calculate_line_bbox).reset_index() # Convert DataFrame to dictionary of lists format expected by downstream code ocr_data = { "text": ocr_data["text"].tolist(), "left": ocr_data["left"].astype(int).tolist(), "top": ocr_data["top"].astype(int).tolist(), "width": ocr_data["width"].astype(int).tolist(), "height": ocr_data["height"].astype(int).tolist(), "conf": ocr_data["conf"].tolist(), "model": ["Tesseract"] * len(ocr_data), # Add model field } elif ( self.ocr_engine == "paddle" or self.ocr_engine == "hybrid-paddle-vlm" or self.ocr_engine == "hybrid-paddle-inference-server" ): if ocr is None: if hasattr(self, "paddle_ocr") and self.paddle_ocr is not None: ocr = self.paddle_ocr elif not SPACES_ZERO_GPU: raise ValueError( "No OCR object provided and 'paddle_ocr' is not initialised." ) try: pass except Exception as e: raise ImportError( f"Error importing PaddleOCR: {e}. Please install it using 'pip install paddleocr paddlepaddle' in your python environment and retry." ) prepare_page_for_hybrid_vlm = ( PREPARE_PAGE_FOR_HYBRID_VLM_BEFORE_PADDLE and self.ocr_engine in ["hybrid-paddle-vlm", "hybrid-paddle-inference-server"] ) paddle_prepared_width = None paddle_prepared_height = None paddle_input_image = image if prepare_page_for_hybrid_vlm: # Resize/pad the full page once so that line crops inherit the # VLM-ready pixel density (reduces per-crop VLM resizing work). 
if paddle_input_image.mode != "RGB": paddle_input_image = paddle_input_image.convert("RGB") paddle_input_image = _prepare_image_for_vlm( paddle_input_image, hybrid_vlm=True, ) paddle_prepared_width, paddle_prepared_height = paddle_input_image.size print( "Hybrid OCR: preparing PaddleOCR input page for VLM constraints " f"({paddle_prepared_width}x{paddle_prepared_height})." ) if not image_path: image_np = np.array( paddle_input_image ) # image_processed (possibly resized) # Check that sizes match the PaddleOCR input image we constructed. image_np_height, image_np_width = image_np.shape[:2] expected_w, expected_h = paddle_input_image.size if image_np_width != expected_w or image_np_height != expected_h: raise ValueError( f"Image size mismatch: {image_np_width}x{image_np_height} != {expected_w}x{expected_h}" ) # PaddleOCR may need an RGB image. Ensure it has 3 channels. if len(image_np.shape) == 2: image_np = np.stack([image_np] * 3, axis=-1) else: image_np = np.array(paddle_input_image) # paddle_results = ocr.predict(image_np) paddle_results = paddle_predict(image_np) # PaddleOCR processed the prepared image (not a file-path open) paddle_processed_original = False # Store the exact image that PaddleOCR processed (convert numpy array back to PIL Image) # This ensures we crop from the exact same image PaddleOCR analyzed. paddle_processed_image = Image.fromarray(image_np.astype(np.uint8)) else: # When using image path, load image to get dimensions temp_image = Image.open(image_path) # For file path, we still keep the original image for visualization # and for any downstream coordinate scaling. original_image_for_cropping = temp_image.copy() if prepare_page_for_hybrid_vlm: # Run PaddleOCR on the VLM-prepared page (to keep rec_polys # consistent with the line crops we send to VLM). 
paddle_input_image = temp_image if paddle_input_image.mode != "RGB": paddle_input_image = paddle_input_image.convert("RGB") paddle_input_image = _prepare_image_for_vlm( paddle_input_image, hybrid_vlm=True, ) paddle_prepared_width, paddle_prepared_height = ( paddle_input_image.size ) image_np = np.array(paddle_input_image) if len(image_np.shape) == 2: image_np = np.stack([image_np] * 3, axis=-1) # paddle_results = ocr.predict(image_np) paddle_results = paddle_predict(image_np) paddle_processed_original = False paddle_processed_image = paddle_input_image.copy() else: # Use PaddleOCR's file-path loading (original behaviour). try: # paddle_results = ocr.predict(image_path) paddle_results = paddle_predict(image_path) except Exception as ocr_path_exc: # PaddleOCR's file-path path can hit OpenCV decode issues on # specific pages. Retry using the already-loaded PIL image # to avoid the OpenCV "read from disk" path. print( f"WARNING: PaddleOCR failed reading image path via OpenCV " f"for {image_path}. Retrying with in-memory numpy image. 
" f"Error: {ocr_path_exc}" ) paddle_input_image = temp_image if paddle_input_image.mode != "RGB": paddle_input_image = paddle_input_image.convert("RGB") image_np = np.array(paddle_input_image) if len(image_np.shape) == 2: image_np = np.stack([image_np] * 3, axis=-1) paddle_results = paddle_predict(image_np) # PaddleOCR processed the original image from file path paddle_processed_original = True # Store the exact image that PaddleOCR processed (from file path) paddle_processed_image = temp_image.copy() # Save PaddleOCR visualization with bounding boxes if paddle_results and self.save_page_ocr_visualisations is True: for res in paddle_results: # self.output_folder is already validated and normalized at construction time paddle_viz_folder = os.path.join( self.output_folder, "paddle_visualisations" ) # Double-check the constructed path is safe if not validate_folder_containment( paddle_viz_folder, OUTPUT_FOLDER ): raise ValueError( f"Unsafe paddle visualisations folder path: {paddle_viz_folder}" ) os.makedirs(paddle_viz_folder, exist_ok=True) if hasattr(res, "save_to_img"): res.save_to_img(paddle_viz_folder) # If we prepared/resized the page before running PaddleOCR, ensure # each result dict reports the correct coordinate space size. # This lets _convert_paddle_to_tesseract_format scale bboxes back # into the original page pixel space reliably. 
if ( prepare_page_for_hybrid_vlm and paddle_prepared_width and paddle_prepared_height and isinstance(paddle_results, list) ): for res in paddle_results: try: if isinstance(res, dict): res["image_width"] = paddle_prepared_width res["image_height"] = paddle_prepared_height except Exception: pass if self.ocr_engine == "hybrid-paddle-vlm": modified_paddle_results = self._perform_hybrid_paddle_vlm_ocr( paddle_processed_image, # Use the exact image PaddleOCR processed ocr=ocr, paddle_results=copy.deepcopy(paddle_results), image_name=image_name, input_image_width=original_image_width, input_image_height=original_image_height, ) elif self.ocr_engine == "hybrid-paddle-inference-server": modified_paddle_results = self._perform_hybrid_paddle_inference_server_ocr( paddle_processed_image, # Use the exact image PaddleOCR processed ocr=ocr, paddle_results=copy.deepcopy(paddle_results), image_name=image_name, input_image_width=original_image_width, input_image_height=original_image_height, model_name=inference_server_model_name, ) else: modified_paddle_results = copy.deepcopy(paddle_results) ocr_data = self._convert_paddle_to_tesseract_format( modified_paddle_results, input_image_width=original_image_width, input_image_height=original_image_height, ) if self.save_page_ocr_visualisations is True: # Save output to image with identified bounding boxes # Use original image since coordinates are in original image space # Prefer original_image_for_cropping (when PaddleOCR processed from file path), # otherwise use original_image_for_visualization (stored before preprocessing) viz_image = ( original_image_for_cropping if original_image_for_cropping is not None else ( original_image_for_visualization if original_image_for_visualization is not None else image ) ) if isinstance(viz_image, Image.Image): # Convert PIL Image to numpy array in BGR format for OpenCV image_cv = cv2.cvtColor(np.array(viz_image), cv2.COLOR_RGB2BGR) else: image_cv = np.array(viz_image) if len(image_cv.shape) == 2: 
image_cv = cv2.cvtColor(image_cv, cv2.COLOR_GRAY2BGR) elif len(image_cv.shape) == 3 and image_cv.shape[2] == 3: # Assume RGB, convert to BGR image_cv = cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR) # Draw all bounding boxes on the image for i in range(len(ocr_data["text"])): left = int(ocr_data["left"][i]) top = int(ocr_data["top"][i]) width = int(ocr_data["width"][i]) height = int(ocr_data["height"][i]) # Ensure coordinates are within image bounds left = max(0, min(left, image_cv.shape[1] - 1)) top = max(0, min(top, image_cv.shape[0] - 1)) right = max(left + 1, min(left + width, image_cv.shape[1])) bottom = max(top + 1, min(top + height, image_cv.shape[0])) cv2.rectangle( image_cv, (left, top), (right, bottom), (0, 255, 0), 2 ) # Save the visualization once with all boxes drawn paddle_viz_folder = os.path.join( self.output_folder, "paddle_visualisations" ) # Double-check the constructed path is safe if not validate_folder_containment(paddle_viz_folder, OUTPUT_FOLDER): raise ValueError( f"Unsafe paddle visualisations folder path: {paddle_viz_folder}" ) os.makedirs(paddle_viz_folder, exist_ok=True) # Generate safe filename if image_name: base_name = os.path.splitext(os.path.basename(image_name))[0] # Increment the number at the end of base_name # This converts zero-indexed input to one-indexed output incremented_base_name = base_name # Find the number pattern at the end # Matches patterns like: _0, _00, 0, 00, etc. 
pattern = r"(\d+)$" match = re.search(pattern, base_name) if match: number_str = match.group(1) number = int(number_str) incremented_number = number + 1 # Preserve the same number of digits (padding with zeros if needed) incremented_str = str(incremented_number).zfill(len(number_str)) incremented_base_name = re.sub( pattern, incremented_str, base_name ) # Sanitize filename to avoid issues with special characters incremented_base_name = safe_sanitize_text( incremented_base_name, max_length=50 ) filename = f"{incremented_base_name}_initial_bounding_boxes.jpg" else: timestamp = int(time.time()) filename = f"initial_bounding_boxes_{timestamp}.jpg" output_path = os.path.join(paddle_viz_folder, filename) max_filesize = 500 * 1024 # 500kb in bytes quality = 95 # Start high, OpenCV JPEG quality range is 0-100 # Try lowering JPEG quality until file is below size limit is_saved = False while quality >= 10: cv2.imwrite( output_path, image_cv, [int(cv2.IMWRITE_JPEG_QUALITY), quality] ) if ( os.path.exists(output_path) and os.path.getsize(output_path) <= max_filesize ): is_saved = True break quality -= 5 if not is_saved: # Save as lowest acceptable quality if cannot get under 500kb, or raise warning cv2.imwrite( output_path, image_cv, [int(cv2.IMWRITE_JPEG_QUALITY), 10] ) else: raise RuntimeError(f"Unsupported OCR engine: {self.ocr_engine}") # Always check for scale_factor, even if preprocessing_metadata is empty # This ensures rescaling happens correctly when preprocessing was applied scale_factor = ( preprocessing_metadata.get("scale_factor", 1.0) if preprocessing_metadata else 1.0 ) if scale_factor != 1.0: # Skip rescaling for PaddleOCR since _convert_paddle_to_tesseract_format # already scales coordinates directly to original image dimensions # hybrid-paddle-vlm also uses PaddleOCR and converts to original space # Skip rescaling for VLM since it returns coordinates in original image space if ( self.ocr_engine == "paddle" or self.ocr_engine == "hybrid-paddle-vlm" or 
self.ocr_engine == "hybrid-paddle-inference-server" or self.ocr_engine == "vlm" or self.ocr_engine == "inference-server" or self.ocr_engine == "bedrock-vlm" or self.ocr_engine == "gemini-vlm" or self.ocr_engine == "azure-openai-vlm" ): pass # print(f"Skipping rescale_ocr_data for PaddleOCR/VLM (already scaled to original dimensions)") else: # print("rescaling ocr_data with scale_factor: ", scale_factor) ocr_data = rescale_ocr_data(ocr_data, scale_factor) # print("Finished rescaling ocr_data") # Convert line-level results to word-level if configured and needed _paddle_line_engines = ( "paddle", "hybrid-paddle-vlm", "hybrid-paddle-inference-server", ) _skip_line_to_word = PADDLE_PRESERVE_LINE_BOXES and ( self.ocr_engine in _paddle_line_engines ) if ( CONVERT_LINE_TO_WORD_LEVEL and not _skip_line_to_word and self._is_line_level_data(ocr_data) ): # print("Converting line-level OCR results to word-level...") # Check if coordinates need to be scaled to match the image we're cropping from # For PaddleOCR: _convert_paddle_to_tesseract_format converts coordinates to original image space # - If PaddleOCR processed the original image (image_path provided), crop from original image (no scaling) # - If PaddleOCR processed the preprocessed image (no image_path), scale coordinates to preprocessed space and crop from preprocessed image # For Tesseract: OCR runs on preprocessed image # - If scale_factor != 1.0, rescale_ocr_data converted coordinates to original space, so crop from original image # - If scale_factor == 1.0, coordinates are still in preprocessed space, so crop from preprocessed image needs_scaling = False crop_image = image # Default to preprocessed image crop_image_width = image_width crop_image_height = image_height if ( PREPROCESS_LOCAL_OCR_IMAGES and original_image_width and original_image_height ): if ( self.ocr_engine == "paddle" or self.ocr_engine == "hybrid-paddle-vlm" or self.ocr_engine == "hybrid-paddle-inference-server" ): # PaddleOCR coordinates are 
converted to original space by _convert_paddle_to_tesseract_format # hybrid-paddle-vlm also uses PaddleOCR and converts to original space if paddle_processed_original: # PaddleOCR processed the original image, so crop from original image # No scaling needed - coordinates are already in original space crop_image = original_image_for_cropping crop_image_width = original_image_width crop_image_height = original_image_height needs_scaling = False else: # PaddleOCR processed the preprocessed image, so scale coordinates to preprocessed space needs_scaling = True elif ( self.ocr_engine == "vlm" or self.ocr_engine == "inference-server" or self.ocr_engine == "bedrock-vlm" or self.ocr_engine == "gemini-vlm" or self.ocr_engine == "azure-openai-vlm" ): # VLM/Cloud VLM returns coordinates in original image space (since we pass original image to VLM) # So we need to crop from the original image, not the preprocessed image if original_image_for_visualization is not None: # Coordinates are in original space, so crop from original image crop_image = original_image_for_visualization crop_image_width = original_image_width crop_image_height = original_image_height needs_scaling = False else: # Fallback to preprocessed image if original not available needs_scaling = False elif self.ocr_engine == "tesseract": # For Tesseract: if scale_factor != 1.0, rescale_ocr_data converted coordinates to original space # So we need to crop from the original image, not the preprocessed image if ( scale_factor != 1.0 and original_image_for_visualization is not None ): # Coordinates are in original space, so crop from original image crop_image = original_image_for_visualization crop_image_width = original_image_width crop_image_height = original_image_height needs_scaling = False else: # scale_factor == 1.0, so coordinates are still in preprocessed space # Crop from preprocessed image - no scaling needed needs_scaling = False if needs_scaling: # Calculate scale factors from original to preprocessed 
scale_x = image_width / original_image_width scale_y = image_height / original_image_height # Scale coordinates to preprocessed image space for cropping scaled_ocr_data = { "text": ocr_data["text"], "left": [x * scale_x for x in ocr_data["left"]], "top": [y * scale_y for y in ocr_data["top"]], "width": [w * scale_x for w in ocr_data["width"]], "height": [h * scale_y for h in ocr_data["height"]], "conf": ocr_data["conf"], "model": ocr_data["model"], } ocr_data = self._convert_line_to_word_level( scaled_ocr_data, crop_image_width, crop_image_height, crop_image, image_name=image_name, ) # Scale word-level results back to original image space scale_factor_x = original_image_width / image_width scale_factor_y = original_image_height / image_height for i in range(len(ocr_data["left"])): ocr_data["left"][i] = ocr_data["left"][i] * scale_factor_x ocr_data["top"][i] = ocr_data["top"][i] * scale_factor_y ocr_data["width"][i] = ocr_data["width"][i] * scale_factor_x ocr_data["height"][i] = ocr_data["height"][i] * scale_factor_y else: # No scaling needed - coordinates match the crop image space ocr_data = self._convert_line_to_word_level( ocr_data, crop_image_width, crop_image_height, crop_image, image_name=image_name, ) # print("Finished converting line level results to word level") # The rest of your processing pipeline now works for both engines ocr_result = ocr_data # Filter out empty strings and non-positive confidence. # Do not use int(conf): fractional confidences in (0, 1) from line→word conversion # (e.g. AdaptiveSegmenter) would become 0 and drop every word (empty OCR tables / UI). 
valid_indices = [] for i, text in enumerate(ocr_result["text"]): if not (text and str(text).strip()): continue try: c = float(ocr_result["conf"][i]) except (TypeError, ValueError, IndexError, KeyError): continue if c > 0: valid_indices.append(i) # Determine default model based on OCR engine if model field is not present if "model" in ocr_result and len(ocr_result["model"]) == len( ocr_result["text"] ): # Model field exists and has correct length - use it (preserves VLM/inference-server replacements) def get_model_name(idx): return ocr_result["model"][idx] else: # Model field not present or incorrect length - use default based on engine default_model = ( "Tesseract" if self.ocr_engine == "tesseract" else ( "Paddle" if self.ocr_engine == "paddle" else ( "Tesseract" if self.ocr_engine == "hybrid-paddle" else ( "Tesseract" if self.ocr_engine == "hybrid-vlm" else ( "Paddle" if self.ocr_engine == "hybrid-paddle-vlm" else ( "Paddle" if self.ocr_engine == "hybrid-paddle-inference-server" else ( "VLM" if self.ocr_engine == "vlm" else ( "Inference Server" if self.ocr_engine == "inference-server" else ( "Bedrock" if self.ocr_engine == "bedrock-vlm" else ( "Gemini" if self.ocr_engine == "gemini-vlm" else ( "Azure/OpenAI" if self.ocr_engine == "azure-openai-vlm" else None ) ) ) ) ) ) ) ) ) ) ) def get_model_name(idx): return default_model output = [ OCRResult( text=clean_unicode_text( ocr_result["text"][i], preserve_international_scripts=True ), left=ocr_result["left"][i], top=ocr_result["top"][i], width=ocr_result["width"][i], height=ocr_result["height"][i], conf=round(float(ocr_result["conf"][i]), 0), line=( ocr_result.get("line", [None] * len(ocr_result["text"]))[i] if isinstance(ocr_result, dict) else None ), model=get_model_name(i), ) for i in valid_indices ] return output, vlm_total_input_tokens, vlm_total_output_tokens, vlm_model_name def analyze_text( self, line_level_ocr_results: List[OCRResult], ocr_results_with_words: Dict[str, Dict], 
chosen_redact_comprehend_entities: List[str], pii_identification_method: str = LOCAL_PII_OPTION, comprehend_client="", custom_entities: List[str] = custom_entities, language: Optional[str] = DEFAULT_LANGUAGE, nlp_analyser: AnalyzerEngine = None, bedrock_runtime=None, model_choice: str = CLOUD_LLM_PII_MODEL_CHOICE, custom_llm_instructions: str = "", chosen_llm_entities: List[str] = None, file_name: Optional[str] = None, page_number: Optional[int] = None, **text_analyzer_kwargs, ) -> List[CustomImageRecognizerResult]: page_text = "" page_text_mapping = list() all_text_line_results = list() comprehend_query_number = 0 # Track LLM token usage llm_total_input_tokens = 0 llm_total_output_tokens = 0 llm_model_name = "" # Extract allow_list from text_analyzer_kwargs if provided # This allows allow_list terms to "overrule" LLM PII detection results allow_list = text_analyzer_kwargs.get("allow_list", []) if allow_list is None: allow_list = [] # Default chosen_llm_entities to chosen_redact_comprehend_entities if not provided if chosen_llm_entities is None: chosen_llm_entities = chosen_redact_comprehend_entities # Filter out CUSTOM_VLM_* entities (these are handled separately via VLM, not LLM) # and validate that we have either entities or custom instructions filtered_llm_entities = [ entity for entity in (chosen_llm_entities or []) if not entity.startswith("CUSTOM_VLM_") ] # If only CUSTOM_VLM_* entities (and no custom instructions), skip LLM analysis and return blank if not filtered_llm_entities and ( not custom_llm_instructions or not custom_llm_instructions.strip() ): if pii_identification_method == AWS_LLM_PII_OPTION: return (list(), 0, "", 0, 0) raise ValueError( "No standard entities selected for LLM PII detection and no custom instructions provided. " "Please select at least one entity type (excluding CUSTOM_VLM_* entities) or provide custom instructions." 
) if not nlp_analyser: nlp_analyser = self.analyzer_engine # Collect all text and create mapping for i, line_level_ocr_result in enumerate(line_level_ocr_results): if page_text: page_text += " " start_pos = len(page_text) page_text += line_level_ocr_result.text # Note: We're not passing line_characters here since it's not needed for this use case page_text_mapping.append((start_pos, i, line_level_ocr_result, None)) # Determine language for downstream services aws_language = language or getattr(self, "language", None) or "en" valid_language_entities = nlp_analyser.registry.get_supported_entities( languages=[language] ) if "CUSTOM" not in valid_language_entities: valid_language_entities.append("CUSTOM") if "CUSTOM_FUZZY" not in valid_language_entities: valid_language_entities.append("CUSTOM_FUZZY") # Process using either Local or AWS Comprehend if pii_identification_method == LOCAL_PII_OPTION: language_supported_entities = filter_entities_for_language( custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities else: out_message = f"No relevant entities supported for language: {language}" print(out_message) raise Warning(out_message) # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept # Also exclude allow_list since we pass it explicitly presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } analyzer_result = nlp_analyser.analyze( text=page_text, language=language, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = map_back_entity_results( analyzer_result, page_text_mapping, all_text_line_results, allow_list=allow_list, ) elif pii_identification_method == AWS_PII_OPTION: # Run local detection for any custom entities 
(including CUSTOM/CUSTOM_FUZZY) local_custom_entities = [ entity for entity in (chosen_redact_comprehend_entities or []) if entity in (custom_entities or []) or entity in ("CUSTOM", "CUSTOM_FUZZY") ] if local_custom_entities: # Filter entities to only include those supported by the language language_supported_entities = filter_entities_for_language( local_custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = map_back_entity_results( page_analyser_result, page_text_mapping, all_text_line_results, allow_list=allow_list, ) # Guard: only call AWS Comprehend when at least one non-custom Comprehend entity is selected. 
aws_comprehend_entities = [ entity for entity in (chosen_redact_comprehend_entities or []) if entity in (FULL_COMPREHEND_ENTITY_LIST or []) and entity not in ("CUSTOM", "CUSTOM_FUZZY") ] # Process text in batches for AWS Comprehend current_batch = "" current_batch_mapping = list() batch_char_count = 0 batch_word_count = 0 for i, text_line in enumerate( line_level_ocr_results ): # Changed from line_level_text_results_list words = text_line.text.split() word_start_positions = list() current_pos = 0 for word in words: word_start_positions.append(current_pos) current_pos += len(word) + 1 word_idx = 0 while word_idx < len(words): word = words[word_idx] new_batch_char_count = len(current_batch) + len(word) + 1 # Check if we've hit the limit limit_reached = ( batch_word_count >= DEFAULT_NEW_BATCH_WORD_COUNT or new_batch_char_count >= DEFAULT_NEW_BATCH_CHAR_COUNT ) if limit_reached: # Add the current word to the batch first if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if ( not current_batch_mapping or current_batch_mapping[-1][1] != i ): current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, None, word_start_positions[word_idx], ) ) # Check if current word ends with phrase punctuation if ends_with_phrase_punctuation(word): # Process current batch all_text_line_results = do_aws_comprehend_call( current_batch, current_batch_mapping, comprehend_client, aws_language, text_analyzer_kwargs.get("allow_list", []), aws_comprehend_entities, all_text_line_results, ) if aws_comprehend_entities: comprehend_query_number += ( len(current_batch.strip()) + COMPREHEND_CHARACTERS_PER_UNIT - 1 ) // COMPREHEND_CHARACTERS_PER_UNIT # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx += 1 else: # Look ahead in current line for phrase-ending punctuation or end of line lookahead_idx = word_idx + 1 lookahead_batch = 
current_batch lookahead_char_count = batch_char_count lookahead_word_count = batch_word_count lookahead_mapping = list(current_batch_mapping) # Continue adding words until we find phrase-ending punctuation or end of line while lookahead_idx < len(words): lookahead_word = words[lookahead_idx] # Add the word to lookahead batch if lookahead_batch: lookahead_batch += " " lookahead_char_count += 1 lookahead_batch += lookahead_word lookahead_char_count += len(lookahead_word) lookahead_word_count += 1 if ( not lookahead_mapping or lookahead_mapping[-1][1] != i ): lookahead_mapping.append( ( lookahead_char_count - len(lookahead_word), i, text_line, None, word_start_positions[lookahead_idx], ) ) # Check if this word ends with phrase punctuation if ends_with_phrase_punctuation(lookahead_word): break lookahead_idx += 1 # Use the lookahead batch (either found phrase end or reached end of line) current_batch = lookahead_batch batch_char_count = lookahead_char_count batch_word_count = lookahead_word_count current_batch_mapping = lookahead_mapping # Process current batch all_text_line_results = do_aws_comprehend_call( current_batch, current_batch_mapping, comprehend_client, aws_language, text_analyzer_kwargs.get("allow_list", []), aws_comprehend_entities, all_text_line_results, ) if aws_comprehend_entities: comprehend_query_number += ( len(current_batch.strip()) + COMPREHEND_CHARACTERS_PER_UNIT - 1 ) // COMPREHEND_CHARACTERS_PER_UNIT # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx = lookahead_idx + 1 else: # Normal case: add word to batch if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if ( not current_batch_mapping or current_batch_mapping[-1][1] != i ): current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, None, word_start_positions[word_idx], ) ) word_idx += 1 # Process final batch if any if 
current_batch: all_text_line_results = do_aws_comprehend_call( current_batch, current_batch_mapping, comprehend_client, aws_language, text_analyzer_kwargs.get("allow_list", []), aws_comprehend_entities, all_text_line_results, ) if aws_comprehend_entities: comprehend_query_number += ( len(current_batch.strip()) + COMPREHEND_CHARACTERS_PER_UNIT - 1 ) // COMPREHEND_CHARACTERS_PER_UNIT elif pii_identification_method == AWS_LLM_PII_OPTION: # LLM-based entity detection using AWS Bedrock try: from tools.llm_entity_detection import do_llm_entity_detection_call except ImportError as e: print(f"Error importing LLM entity detection: {e}") raise ImportError( "LLM entity detection not available. Please ensure llm_funcs.py is accessible." ) if not bedrock_runtime: raise ValueError( "bedrock_runtime is required when using LLM-based PII detection" ) # Set inference method to aws-bedrock if not already set if text_analyzer_kwargs.get("inference_method") is None: text_analyzer_kwargs["inference_method"] = "aws-bedrock" # Update model_choice to use CLOUD_LLM_PII_MODEL_CHOICE for Bedrock, or value from text_analyzer_kwargs if set if text_analyzer_kwargs.get("model_choice") is None: model_choice = CLOUD_LLM_PII_MODEL_CHOICE else: model_choice = text_analyzer_kwargs.get( "model_choice", CLOUD_LLM_PII_MODEL_CHOICE ) # Set LLM model name for tracking (use custom-instructions model when applicable) custom_instructions_model = ( CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip() if isinstance(CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE, str) and CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip() else "" ) if ( (custom_llm_instructions or "").strip() and model_choice == CLOUD_LLM_PII_MODEL_CHOICE and custom_instructions_model ): llm_model_name = custom_instructions_model else: llm_model_name = model_choice or "" # Handle custom entities first (same as AWS Comprehend) # Include CUSTOM/CUSTOM_FUZZY (deny list) so deny-list words are redacted when CUSTOM is selected 
local_custom_entities = [ entity for entity in (chosen_llm_entities or []) if entity in (custom_entities or []) or entity in ("CUSTOM", "CUSTOM_FUZZY") ] if local_custom_entities: # Filter entities to only include those supported by the language language_supported_entities = filter_entities_for_language( local_custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = map_back_entity_results( page_analyser_result, page_text_mapping, all_text_line_results, allow_list=allow_list, ) # Process text in batches for LLM (same batching logic as AWS Comprehend) current_batch = "" current_batch_mapping = list() batch_char_count = 0 batch_word_count = 0 for i, text_line in enumerate(line_level_ocr_results): words = text_line.text.split() word_start_positions = list() current_pos = 0 for word in words: word_start_positions.append(current_pos) current_pos += len(word) + 1 word_idx = 0 while word_idx < len(words): word = words[word_idx] new_batch_char_count = len(current_batch) + len(word) + 1 # Check if we've hit the limit limit_reached = ( batch_word_count >= DEFAULT_NEW_BATCH_WORD_COUNT or new_batch_char_count >= DEFAULT_NEW_BATCH_CHAR_COUNT ) if limit_reached: # Add the current word to the batch first if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if ( not current_batch_mapping or current_batch_mapping[-1][1] 
!= i ): current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, None, word_start_positions[word_idx], ) ) # Check if current word ends with phrase punctuation if ends_with_phrase_punctuation(word): # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get( # "assistant_model" # ), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx += 1 else: # Look ahead in current line for phrase-ending punctuation or end of line lookahead_idx = 
word_idx + 1 lookahead_batch = current_batch lookahead_char_count = batch_char_count lookahead_word_count = batch_word_count lookahead_mapping = list(current_batch_mapping) # Continue adding words until we find phrase-ending punctuation or end of line while lookahead_idx < len(words): lookahead_word = words[lookahead_idx] # Add the word to lookahead batch if lookahead_batch: lookahead_batch += " " lookahead_char_count += 1 lookahead_batch += lookahead_word lookahead_char_count += len(lookahead_word) lookahead_word_count += 1 if ( not lookahead_mapping or lookahead_mapping[-1][1] != i ): lookahead_mapping.append( ( lookahead_char_count - len(lookahead_word), i, text_line, None, word_start_positions[lookahead_idx], ) ) # Check if this word ends with phrase punctuation if ends_with_phrase_punctuation(lookahead_word): break lookahead_idx += 1 # Use the lookahead batch (either found phrase end or reached end of line) current_batch = lookahead_batch batch_char_count = lookahead_char_count batch_word_count = lookahead_word_count current_batch_mapping = lookahead_mapping # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), 
batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get( # "assistant_model" # ), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx = lookahead_idx + 1 else: # Normal case: add word to batch if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if ( not current_batch_mapping or current_batch_mapping[-1][1] != i ): current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, None, word_start_positions[word_idx], ) ) word_idx += 1 # Process final batch if any if current_batch: # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] all_text_line_results, batch_input_tokens, batch_output_tokens = ( do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), 
max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get("inference_method"), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 elif pii_identification_method == INFERENCE_SERVER_PII_OPTION: # LLM-based entity detection using inference server try: from tools.llm_entity_detection import do_llm_entity_detection_call except ImportError as e: print(f"Error importing LLM entity detection: {e}") raise ImportError( "LLM entity detection not available. Please ensure llm_funcs.py is accessible." 
) # Set inference method to inference-server if not already set if text_analyzer_kwargs.get("inference_method") is None: text_analyzer_kwargs["inference_method"] = "inference-server" # Set API URL if not already set if text_analyzer_kwargs.get("api_url") is None: text_analyzer_kwargs["api_url"] = INFERENCE_SERVER_API_URL # Set model choice if not already set - use INFERENCE_SERVER_LLM_PII_MODEL_CHOICE if text_analyzer_kwargs.get("model_choice") is None: model_choice = INFERENCE_SERVER_LLM_PII_MODEL_CHOICE text_analyzer_kwargs["model_choice"] = model_choice else: model_choice = text_analyzer_kwargs.get("model_choice") # Set LLM model name for tracking llm_model_name = model_choice or "" # Update model_choice to use the value from text_analyzer_kwargs model_choice = text_analyzer_kwargs.get( "model_choice", INFERENCE_SERVER_LLM_PII_MODEL_CHOICE ) # Handle custom entities first (same as AWS Comprehend) # Include CUSTOM/CUSTOM_FUZZY (deny list) so deny-list words are redacted when CUSTOM is selected local_custom_entities = [ entity for entity in (chosen_llm_entities or []) if entity in (custom_entities or []) or entity in ("CUSTOM", "CUSTOM_FUZZY") ] if local_custom_entities: # Filter entities to only include those supported by the language language_supported_entities = filter_entities_for_language( local_custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = map_back_entity_results( 
page_analyser_result, page_text_mapping, all_text_line_results, allow_list=allow_list, ) # Process text in batches for LLM (same batching logic as AWS Comprehend) current_batch = "" current_batch_mapping = list() batch_char_count = 0 batch_word_count = 0 for i, text_line in enumerate(line_level_ocr_results): words = text_line.text.split() word_start_positions = list() current_pos = 0 for word in words: word_start_positions.append(current_pos) current_pos += len(word) + 1 word_idx = 0 while word_idx < len(words): word = words[word_idx] new_batch_char_count = len(current_batch) + len(word) + 1 # Check if we've hit the limit limit_reached = ( batch_word_count >= DEFAULT_NEW_BATCH_WORD_COUNT or new_batch_char_count >= DEFAULT_NEW_BATCH_CHAR_COUNT ) if limit_reached: # Add the current word to the batch first if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if ( not current_batch_mapping or current_batch_mapping[-1][1] != i ): current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, None, word_start_positions[word_idx], ) ) # Check if current word ends with phrase punctuation if ends_with_phrase_punctuation(word): # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", 
LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get( # "assistant_model" # ), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx += 1 else: # Look ahead in current line for phrase-ending punctuation or end of line lookahead_idx = word_idx + 1 lookahead_batch = current_batch lookahead_char_count = batch_char_count lookahead_word_count = batch_word_count lookahead_mapping = list(current_batch_mapping) # Continue adding words until we find phrase-ending punctuation or end of line while lookahead_idx < len(words): lookahead_word = words[lookahead_idx] # Add the word to lookahead batch if lookahead_batch: lookahead_batch += " " lookahead_char_count += 1 lookahead_batch += lookahead_word lookahead_char_count += len(lookahead_word) lookahead_word_count += 1 if ( not lookahead_mapping or lookahead_mapping[-1][1] != i ): lookahead_mapping.append( ( lookahead_char_count - len(lookahead_word), i, text_line, None, word_start_positions[lookahead_idx], ) ) # Check if this word ends with phrase punctuation if ends_with_phrase_punctuation(lookahead_word): break lookahead_idx += 1 # Use the lookahead batch (either found phrase end or reached end of line) current_batch = lookahead_batch 
batch_char_count = lookahead_char_count batch_word_count = lookahead_word_count current_batch_mapping = lookahead_mapping # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get( # "assistant_model" # ), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx = lookahead_idx + 1 else: # Normal case: add word to batch if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if 
( not current_batch_mapping or current_batch_mapping[-1][1] != i ): current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, None, word_start_positions[word_idx], ) ) word_idx += 1 # Process final batch if any if current_batch: # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] all_text_line_results, batch_input_tokens, batch_output_tokens = ( do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get("inference_method"), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 elif pii_identification_method == LOCAL_TRANSFORMERS_LLM_PII_OPTION: # LLM-based entity detection using local transformers models try: from tools.llm_entity_detection import do_llm_entity_detection_call except 
ImportError as e: print(f"Error importing LLM entity detection: {e}") raise ImportError( "LLM entity detection not available. Please ensure llm_funcs.py is accessible." ) # Set inference method to local if not already set if text_analyzer_kwargs.get("inference_method") is None: text_analyzer_kwargs["inference_method"] = "local" # Set model choice if not already set - use VLM model when USE_TRANSFORMERS_VLM_MODEL_AS_LLM else LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE if text_analyzer_kwargs.get("model_choice") is None: text_analyzer_kwargs["model_choice"] = ( SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL if USE_TRANSFORMERS_VLM_MODEL_AS_LLM else LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE ) # Update model_choice to use the value from text_analyzer_kwargs model_choice = text_analyzer_kwargs.get( "model_choice", ( SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL if USE_TRANSFORMERS_VLM_MODEL_AS_LLM else LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE ), ) # Handle custom entities first (same as AWS Comprehend) # Include CUSTOM/CUSTOM_FUZZY (deny list) so deny-list words are redacted when CUSTOM is selected local_custom_entities = [ entity for entity in (chosen_llm_entities or []) if entity in (custom_entities or []) or entity in ("CUSTOM", "CUSTOM_FUZZY") ] if local_custom_entities: # Filter entities to only include those supported by the language language_supported_entities = filter_entities_for_language( local_custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, 
allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = map_back_entity_results( page_analyser_result, page_text_mapping, all_text_line_results, allow_list=allow_list, ) # Process text in batches for LLM (same batching logic as AWS Comprehend) current_batch = "" current_batch_mapping = list() batch_char_count = 0 batch_word_count = 0 for i, text_line in enumerate(line_level_ocr_results): words = text_line.text.split() word_start_positions = list() current_pos = 0 for word in words: word_start_positions.append(current_pos) current_pos += len(word) + 1 word_idx = 0 while word_idx < len(words): word = words[word_idx] new_batch_char_count = len(current_batch) + len(word) + 1 # Check if we've hit the limit limit_reached = ( batch_word_count >= DEFAULT_NEW_BATCH_WORD_COUNT or new_batch_char_count >= DEFAULT_NEW_BATCH_CHAR_COUNT ) if limit_reached: # Add the current word to the batch first if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if ( not current_batch_mapping or current_batch_mapping[-1][1] != i ): current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, None, word_start_positions[word_idx], ) ) # Check if current word ends with phrase punctuation if ends_with_phrase_punctuation(word): # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, 
model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get( # "assistant_model" # ), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx += 1 else: # Look ahead in current line for phrase-ending punctuation or end of line lookahead_idx = word_idx + 1 lookahead_batch = current_batch lookahead_char_count = batch_char_count lookahead_word_count = batch_word_count lookahead_mapping = list(current_batch_mapping) # Continue adding words until we find phrase-ending punctuation or end of line while lookahead_idx < len(words): lookahead_word = words[lookahead_idx] # Add the word to lookahead batch if lookahead_batch: lookahead_batch += " " lookahead_char_count += 1 lookahead_batch += lookahead_word lookahead_char_count += len(lookahead_word) lookahead_word_count += 1 if ( not lookahead_mapping or lookahead_mapping[-1][1] != i ): lookahead_mapping.append( ( lookahead_char_count - len(lookahead_word), i, text_line, None, word_start_positions[lookahead_idx], ) ) # Check if this word ends with phrase punctuation if ends_with_phrase_punctuation(lookahead_word): break lookahead_idx += 1 # Use the lookahead batch (either found 
phrase end or reached end of line) current_batch = lookahead_batch batch_char_count = lookahead_char_count batch_word_count = lookahead_word_count current_batch_mapping = lookahead_mapping # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get( # "assistant_model" # ), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx = lookahead_idx + 1 else: # Normal case: add word to batch if current_batch: current_batch += " " batch_char_count += 1 
current_batch += word batch_char_count += len(word) batch_word_count += 1 if ( not current_batch_mapping or current_batch_mapping[-1][1] != i ): current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, None, word_start_positions[word_idx], ) ) word_idx += 1 # Process final batch if any if current_batch: # Remove 'CUSTOM' and 'CUSTOM_VLM_*' entities from the chosen_llm_entities list # CUSTOM_VLM_* entities are handled separately via VLM, not LLM llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" and not entity.startswith("CUSTOM_VLM_") ] all_text_line_results, batch_input_tokens, batch_output_tokens = ( do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=getattr(self, "output_folder", None), batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get("inference_method"), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 # Process results and create bounding boxes combined_results = list() for i, text_line in enumerate(line_level_ocr_results): line_results = 
next( (results for idx, results in all_text_line_results if idx == i), [] ) if line_results and i < len(ocr_results_with_words): child_level_key = list(ocr_results_with_words.keys())[i] ocr_results_with_words_line_level = ocr_results_with_words[ child_level_key ] for result in line_results: bbox_results = self.map_analyzer_results_to_bounding_boxes( [result], [ OCRResult( text=text_line.text[result.start : result.end], left=text_line.left, top=text_line.top, width=text_line.width, height=text_line.height, conf=text_line.conf, ) ], text_line.text, text_analyzer_kwargs.get("allow_list", []), ocr_results_with_words_line_level, ) combined_results.extend(bbox_results) return ( combined_results, comprehend_query_number, llm_model_name, llm_total_input_tokens, llm_total_output_tokens, ) @staticmethod def _map_one_ocr_result_to_bboxes( redaction_relevant_ocr_result: OCRResult, text_analyzer_results: List[RecognizerResult], ocr_results_with_words_child_info: Dict[str, Dict], allow_list: List[str], ) -> List[CustomImageRecognizerResult]: """Map one OCR result to bounding boxes; safe to run in a thread.""" bboxes = [] line_text = ocr_results_with_words_child_info["text"] line_length = len(line_text) redaction_text = redaction_relevant_ocr_result.text for redaction_result in text_analyzer_results: if allow_list: allow_list_normalized = [ item.strip().lower() for item in allow_list if item ] redaction_text_normalized = redaction_text.strip().lower() is_in_allow_list = redaction_text_normalized in allow_list_normalized else: is_in_allow_list = False if not is_in_allow_list: start_in_line = max(0, redaction_result.start) end_in_line = min(line_length, redaction_result.end) matched_text = line_text[start_in_line:end_in_line] matched_text.split() matching_word_boxes = [] current_position = 0 for word_info in ocr_results_with_words_child_info.get("words", []): word_text = word_info["text"] word_length = len(word_text) word_start = current_position word_end = current_position + 
word_length current_position += word_length + 1 if word_start < end_in_line and word_end > start_in_line: matching_word_boxes.append(word_info["bounding_box"]) if matching_word_boxes: left = min(box[0] for box in matching_word_boxes) top = min(box[1] for box in matching_word_boxes) right = max(box[2] for box in matching_word_boxes) bottom = max(box[3] for box in matching_word_boxes) bboxes.append( CustomImageRecognizerResult( entity_type=redaction_result.entity_type, start=start_in_line, end=end_in_line, score=round(redaction_result.score, 2), left=left, top=top, width=right - left, height=bottom - top, text=matched_text, ) ) else: line_left = redaction_relevant_ocr_result.left line_top = redaction_relevant_ocr_result.top line_width = redaction_relevant_ocr_result.width line_height = redaction_relevant_ocr_result.height if line_length > 0: text_proportion = len(matched_text) / line_length char_width_estimate = line_width / line_length estimated_left_offset = start_in_line * char_width_estimate left = line_left + estimated_left_offset top = line_top width = text_proportion * line_width height = line_height else: left = line_left top = line_top width = line_width height = line_height bboxes.append( CustomImageRecognizerResult( entity_type=redaction_result.entity_type, start=start_in_line, end=end_in_line, score=round(redaction_result.score, 2), left=left, top=top, width=width, height=height, text=matched_text, ) ) return bboxes @staticmethod def map_analyzer_results_to_bounding_boxes( text_analyzer_results: List[RecognizerResult], redaction_relevant_ocr_results: List[OCRResult], full_text: str, allow_list: List[str], ocr_results_with_words_child_info: Dict[str, Dict], ) -> List[CustomImageRecognizerResult]: if not redaction_relevant_ocr_results: return [] n = len(redaction_relevant_ocr_results) max_workers = min(MAX_WORKERS, n) with ThreadPoolExecutor(max_workers=max_workers) as executor: results = list( executor.map( lambda ocr_result: 
CustomImageAnalyzerEngine._map_one_ocr_result_to_bboxes( ocr_result, text_analyzer_results, ocr_results_with_words_child_info, allow_list, ), redaction_relevant_ocr_results, ) ) redaction_bboxes = [bbox for bbox_list in results for bbox in bbox_list] return redaction_bboxes @staticmethod def remove_space_boxes(ocr_result: dict) -> dict: """Remove OCR bboxes that are for spaces. :param ocr_result: OCR results (raw or thresholded). :return: OCR results with empty words removed. """ # Get indices of items with no text idx = list() for i, text in enumerate(ocr_result["text"]): is_not_space = text.isspace() is False if text != "" and is_not_space: idx.append(i) # Only retain items with text filtered_ocr_result = {} for key in list(ocr_result.keys()): filtered_ocr_result[key] = [ocr_result[key][i] for i in idx] return filtered_ocr_result @staticmethod def _scale_bbox_results( ocr_result: Dict[str, List[Union[int, str]]], scale_factor: float ) -> Dict[str, float]: """Scale down the bounding box results based on a scale percentage. :param ocr_result: OCR results (raw). :param scale_percent: Scale percentage for resizing the bounding box. :return: OCR results (scaled). 
""" scaled_results = deepcopy(ocr_result) coordinate_keys = ["left", "top"] dimension_keys = ["width", "height"] for coord_key in coordinate_keys: scaled_results[coord_key] = [ int(np.ceil((x) / (scale_factor))) for x in scaled_results[coord_key] ] for dim_key in dimension_keys: scaled_results[dim_key] = [ max(1, int(np.ceil(x / (scale_factor)))) for x in scaled_results[dim_key] ] return scaled_results @staticmethod def estimate_x_offset(full_text: str, start: int) -> int: # Estimate the x-offset based on character position # This is a simple estimation and might need refinement for variable-width fonts return int(start / len(full_text) * len(full_text)) def estimate_width(self, ocr_result: OCRResult, start: int, end: int) -> int: # Extract the relevant text portion relevant_text = ocr_result.text[start:end] # If the relevant text is the same as the full text, return the full width if relevant_text == ocr_result.text: return ocr_result.width # Estimate width based on the proportion of the relevant text length to the total text length total_text_length = len(ocr_result.text) relevant_text_length = len(relevant_text) if total_text_length == 0: return 0 # Avoid division by zero # Proportion of the relevant text to the total text proportion = relevant_text_length / total_text_length # Estimate the width based on the proportion estimated_width = int(proportion * ocr_result.width) return estimated_width def bounding_boxes_overlap(box1: List, box2: List): """Check if two bounding boxes overlap.""" return ( box1[0] < box2[2] and box2[0] < box1[2] and box1[1] < box2[3] and box2[1] < box1[3] ) def map_back_entity_results( page_analyser_result: dict, page_text_mapping: dict, all_text_line_results: List[Tuple], allow_list: List[str] = None, ): """ Map Presidio analyzer results back to line-level results. 
Args: page_analyser_result: Results from Presidio analyzer page_text_mapping: Mapping of batch positions to line indices all_text_line_results: Existing line-level results to append to allow_list: List of allowed text values (to skip) - case-insensitive matching """ # Normalize allow_list for case-insensitive matching if allow_list: allow_list_normalized = [item.strip().lower() for item in allow_list if item] else: allow_list_normalized = [] for entity in page_analyser_result: entity_start = entity.start entity_end = entity.end # Track if the entity has been added to any line added_to_line = False for batch_start, line_idx, original_line, chars in page_text_mapping: batch_end = batch_start + len(original_line.text) # Check if the entity overlaps with the current line if ( batch_start < entity_end and batch_end > entity_start ): # Overlap condition relative_start = max( 0, entity_start - batch_start ) # Adjust start relative to the line relative_end = min( entity_end - batch_start, len(original_line.text) ) # Adjust end relative to the line # Get the text for this entity to check against allow_list result_text = original_line.text[relative_start:relative_end] # Check if result_text is in allow_list (case-insensitive) # If allow_list contains this text, skip adding it as a PII entity # This allows allow_list terms to "overrule" PII detection result_text_normalized = result_text.strip().lower() if result_text_normalized not in allow_list_normalized: # Create a new adjusted entity adjusted_entity = copy.deepcopy(entity) adjusted_entity.start = relative_start adjusted_entity.end = relative_end # Check if this line already has an entry existing_entry = next( ( entry for idx, entry in all_text_line_results if idx == line_idx ), None, ) if existing_entry is None: all_text_line_results.append((line_idx, [adjusted_entity])) else: existing_entry.append( adjusted_entity ) # Append to the existing list of entities added_to_line = True # If the entity spans multiple lines, you 
may want to handle that here if not added_to_line: # Handle cases where the entity does not fit in any line (optional) print(f"Entity '{entity}' does not fit in any line.") return all_text_line_results def map_back_comprehend_entity_results( response: object, current_batch_mapping: List[Tuple], allow_list: List[str], chosen_redact_comprehend_entities: List[str], all_text_line_results: List[Tuple], ): """ Map AWS Comprehend entity results back to line-level results. Args: response: AWS Comprehend response object current_batch_mapping: Mapping of batch positions to line indices allow_list: List of allowed text values (to skip) - case-insensitive matching chosen_redact_comprehend_entities: List of entity types to include all_text_line_results: Existing line-level results to append to """ if not response or "Entities" not in response: return all_text_line_results # Normalize allow_list for case-insensitive matching if allow_list: allow_list_normalized = [item.strip().lower() for item in allow_list if item] else: allow_list_normalized = [] for entity in response["Entities"]: if entity.get("Type") not in chosen_redact_comprehend_entities: continue entity_start = entity["BeginOffset"] entity_end = entity["EndOffset"] # Track if the entity has been added to any line added_to_line = False # Find the correct line and offset within that line for ( batch_start, line_idx, original_line, chars, line_offset, ) in current_batch_mapping: batch_end = batch_start + len(original_line.text[line_offset:]) # Check if the entity overlaps with the current line if ( batch_start < entity_end and batch_end > entity_start ): # Overlap condition # Calculate the absolute position within the line relative_start = max(0, entity_start - batch_start + line_offset) relative_end = min( entity_end - batch_start + line_offset, len(original_line.text) ) result_text = original_line.text[relative_start:relative_end] # Check if result_text is in allow_list (case-insensitive) # If allow_list contains this 
text, skip adding it as a PII entity # This allows allow_list terms to "overrule" AWS Comprehend PII detection result_text_normalized = result_text.strip().lower() if result_text_normalized not in allow_list_normalized: adjusted_entity = entity.copy() adjusted_entity["BeginOffset"] = ( relative_start # Now relative to the full line ) adjusted_entity["EndOffset"] = relative_end recogniser_entity = recognizer_result_from_dict(adjusted_entity) existing_entry = next( ( entry for idx, entry in all_text_line_results if idx == line_idx ), None, ) if existing_entry is None: all_text_line_results.append((line_idx, [recogniser_entity])) else: existing_entry.append( recogniser_entity ) # Append to the existing list of entities added_to_line = True # Optional: Handle cases where the entity does not fit in any line if not added_to_line: print(f"Entity '{entity}' does not fit in any line.") return all_text_line_results def do_aws_comprehend_call( current_batch: str, current_batch_mapping: List[Tuple], comprehend_client: botocore.client.BaseClient, language: str, allow_list: List[str], chosen_redact_comprehend_entities: List[str], all_text_line_results: List[Tuple], max_retries: int = 10, retry_delay: int = 1, ): """ Uses AWS Comprehend to detect PII entities in a text batch and maps the results back to the original lines for further processing. Args: current_batch (str): The concatenated text being analysed for PII. current_batch_mapping (List[Tuple]): Mapping from batch offsets back to individual line offsets and line indices for result mapping. comprehend_client (botocore.client.BaseClient): AWS Comprehend boto3 client for making API calls. language (str): The ISO language code for the text (e.g. "en"). allow_list (List[str]): List of lowercased phrases or words which, if detected, should not be flagged/redacted even if AWS returns them as PII. chosen_redact_comprehend_entities (List[str]): List of PII entity types (from AWS) enabled for detection/redaction. 
all_text_line_results (List[Tuple]): Existing recognition results by line; this will be updated in-place and returned. max_retries (int, optional): Maximum number of times to retry the AWS API in case of failure. Default is 10. retry_delay (int, optional): Number of seconds to wait between retries. Default is 1. Returns: List[Tuple]: Updated list of recognition results by text line, with AWS detected PII mapped back to their source. """ if not current_batch: return all_text_line_results # Guard: if no relevant AWS entity types are selected, skip AWS entirely. # (CUSTOM/CUSTOM_FUZZY and other local-only entities are handled via Presidio.) if not chosen_redact_comprehend_entities: return all_text_line_results for attempt in range(max_retries): try: response = comprehend_client.detect_pii_entities( Text=current_batch.strip(), LanguageCode=language ) all_text_line_results = map_back_comprehend_entity_results( response, current_batch_mapping, allow_list, chosen_redact_comprehend_entities, all_text_line_results, ) return all_text_line_results except Exception as e: if attempt == max_retries - 1: print("AWS Comprehend calls failed due to", e) raise time.sleep(retry_delay) def run_page_text_redaction( language: str, chosen_redact_entities: List[str], chosen_redact_comprehend_entities: List[str], line_level_text_results_list: List[str], line_characters: List, page_analyser_results: List = list(), page_analysed_bounding_boxes: List = list(), comprehend_client=None, allow_list: List[str] = None, pii_identification_method: str = LOCAL_PII_OPTION, nlp_analyser: AnalyzerEngine = None, score_threshold: float = 0.0, custom_entities: List[str] = None, comprehend_query_number: int = 0, bedrock_runtime=None, model_choice: str = CLOUD_LLM_PII_MODEL_CHOICE, custom_llm_instructions: str = "", chosen_llm_entities: List[str] = None, output_folder: str = None, file_name: Optional[str] = None, page_number: Optional[int] = None, **text_analyzer_kwargs, ): """ This function performs text 
redaction on a page based on the specified language and chosen entities. Args: language (str): The language code for the text being processed. chosen_redact_entities (List[str]): A list of entities to be redacted from the text. chosen_redact_comprehend_entities (List[str]): A list of entities identified by AWS Comprehend for redaction. line_level_text_results_list (List[str]): A list of text lines extracted from the page. line_characters (List): A list of character-level information for each line of text. page_analyser_results (List, optional): Results from previous page analysis. Defaults to an empty list. page_analysed_bounding_boxes (List, optional): Bounding boxes for the analysed page. Defaults to an empty list. comprehend_client: The AWS Comprehend client for making API calls. Defaults to None. allow_list (List[str], optional): A list of allowed entities that should not be redacted. Defaults to None. pii_identification_method (str, optional): The method used for PII identification. Defaults to LOCAL_PII_OPTION. nlp_analyser (AnalyzerEngine, optional): The NLP analyzer engine used for local analysis. Defaults to None. score_threshold (float, optional): The threshold score for entity detection. Defaults to 0.0. custom_entities (List[str], optional): A list of custom entities for redaction. Defaults to None. comprehend_query_number (int, optional): A counter for AWS Comprehend usage in units of 100 characters (1 unit = 100 characters, per AWS billing). Defaults to 0. bedrock_runtime: The AWS Bedrock runtime client for LLM-based entity detection. Defaults to None. model_choice (str, optional): The LLM model choice for entity detection. Defaults to CLOUD_LLM_PII_MODEL_CHOICE. custom_llm_instructions (str, optional): Custom instructions for LLM-based entity detection. Defaults to "". chosen_llm_entities (List[str], optional): A list of entities for LLM-based detection. Defaults to None. 
output_folder (str, optional): Output folder for saving LLM prompts and responses. Defaults to None. file_name (str, optional): File name (without extension) for saving LLM logs. Defaults to None. page_number (int, optional): Page number for saving LLM logs. Defaults to None. **text_analyzer_kwargs: Additional keyword arguments for text analysis. """ page_text = "" page_text_mapping = list() all_text_line_results = list() # Track Comprehend usage for this page in AWS billing units (1 unit = 100 chars). # IMPORTANT: do not reset/overwrite the function argument; callers aggregate per-page usage. comprehend_units_used = 0 # Track LLM token usage llm_total_input_tokens = 0 llm_total_output_tokens = 0 llm_model_name = "" # Default chosen_llm_entities to chosen_redact_comprehend_entities if not provided if chosen_llm_entities is None: chosen_llm_entities = chosen_redact_comprehend_entities # Collect all text from the page for i, text_line in enumerate(line_level_text_results_list): if page_text: page_text += " " start_pos = len(page_text) page_text += text_line.text page_text_mapping.append((start_pos, i, text_line, line_characters[i])) # Determine language for downstream services aws_language = language or "en" valid_language_entities = nlp_analyser.registry.get_supported_entities( languages=[language] ) if "CUSTOM" not in valid_language_entities: valid_language_entities.append("CUSTOM") if "CUSTOM_FUZZY" not in valid_language_entities: valid_language_entities.append("CUSTOM_FUZZY") # Process based on identification method if pii_identification_method == LOCAL_PII_OPTION: if not nlp_analyser: raise ValueError("nlp_analyser is required for Local identification method") language_supported_entities = filter_entities_for_language( chosen_redact_entities, valid_language_entities, language ) # When only CUSTOM_VLM_* entities are chosen, local PII has nothing to do; # allow progress so image/VLM analysis can run. 
only_custom_vlm = chosen_redact_entities and all( str(e).startswith("CUSTOM_VLM_") for e in (chosen_redact_entities or []) ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities elif only_custom_vlm: # Skip local PII; leave all_text_line_results empty so pipeline continues to VLM pass else: out_message = f"No relevant entities supported for language: {language}" print(out_message) raise Warning(out_message) if language_supported_entities: # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept # Also exclude allow_list since we pass it explicitly presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, score_threshold=score_threshold, return_decision_process=True, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = map_back_entity_results( page_analyser_result, page_text_mapping, all_text_line_results, allow_list=allow_list, ) elif pii_identification_method == AWS_PII_OPTION: # Run local detection for any custom entities (including CUSTOM/CUSTOM_FUZZY) local_custom_entities = [ entity for entity in (chosen_redact_comprehend_entities or []) if entity in (custom_entities or []) or entity in ("CUSTOM", "CUSTOM_FUZZY") ] if local_custom_entities: # Filter entities to only include those supported by the language language_supported_entities = filter_entities_for_language( local_custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept # Also exclude allow_list since we pass it explicitly presidio_kwargs = { k: v for k, v in 
text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, score_threshold=score_threshold, return_decision_process=True, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = map_back_entity_results( page_analyser_result, page_text_mapping, all_text_line_results, allow_list=allow_list, ) # Guard: only call AWS Comprehend when at least one non-custom Comprehend entity is selected. aws_comprehend_entities = [ entity for entity in (chosen_redact_comprehend_entities or []) if entity in (FULL_COMPREHEND_ENTITY_LIST or []) and entity not in ("CUSTOM", "CUSTOM_FUZZY") ] # Process text in batches for AWS Comprehend current_batch = "" current_batch_mapping = list() batch_char_count = 0 batch_word_count = 0 for i, text_line in enumerate(line_level_text_results_list): words = text_line.text.split() word_start_positions = list() current_pos = 0 for word in words: word_start_positions.append(current_pos) current_pos += len(word) + 1 word_idx = 0 while word_idx < len(words): word = words[word_idx] new_batch_char_count = len(current_batch) + len(word) + 1 # Check if we've hit the limit limit_reached = ( batch_word_count >= DEFAULT_NEW_BATCH_WORD_COUNT or new_batch_char_count >= DEFAULT_NEW_BATCH_CHAR_COUNT ) if limit_reached: # Add the current word to the batch first if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if not current_batch_mapping or current_batch_mapping[-1][1] != i: current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, line_characters[i], word_start_positions[word_idx], ) ) # Check if current word ends with phrase punctuation if ends_with_phrase_punctuation(word): # Process current batch 
all_text_line_results = do_aws_comprehend_call( current_batch, current_batch_mapping, comprehend_client, aws_language, text_analyzer_kwargs.get("allow_list", allow_list or []), aws_comprehend_entities, all_text_line_results, ) if aws_comprehend_entities: comprehend_units_used += ( len(current_batch.strip()) + COMPREHEND_CHARACTERS_PER_UNIT - 1 ) // COMPREHEND_CHARACTERS_PER_UNIT # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx += 1 else: # Look ahead in current line for phrase-ending punctuation or end of line lookahead_idx = word_idx + 1 lookahead_batch = current_batch lookahead_char_count = batch_char_count lookahead_word_count = batch_word_count lookahead_mapping = list(current_batch_mapping) # Continue adding words until we find phrase-ending punctuation or end of line while lookahead_idx < len(words): lookahead_word = words[lookahead_idx] # Add the word to lookahead batch if lookahead_batch: lookahead_batch += " " lookahead_char_count += 1 lookahead_batch += lookahead_word lookahead_char_count += len(lookahead_word) lookahead_word_count += 1 if not lookahead_mapping or lookahead_mapping[-1][1] != i: lookahead_mapping.append( ( lookahead_char_count - len(lookahead_word), i, text_line, line_characters[i], word_start_positions[lookahead_idx], ) ) # Check if this word ends with phrase punctuation if ends_with_phrase_punctuation(lookahead_word): break lookahead_idx += 1 # Use the lookahead batch (either found phrase end or reached end of line) current_batch = lookahead_batch batch_char_count = lookahead_char_count batch_word_count = lookahead_word_count current_batch_mapping = lookahead_mapping # Process current batch all_text_line_results = do_aws_comprehend_call( current_batch, current_batch_mapping, comprehend_client, aws_language, text_analyzer_kwargs.get("allow_list", allow_list or []), aws_comprehend_entities, all_text_line_results, ) if aws_comprehend_entities: comprehend_units_used += ( 
len(current_batch.strip()) + COMPREHEND_CHARACTERS_PER_UNIT - 1 ) // COMPREHEND_CHARACTERS_PER_UNIT # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx = lookahead_idx + 1 else: # Normal case: add word to batch if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if not current_batch_mapping or current_batch_mapping[-1][1] != i: current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, line_characters[i], word_start_positions[word_idx], ) ) word_idx += 1 # Process final batch if any if current_batch: all_text_line_results = do_aws_comprehend_call( current_batch, current_batch_mapping, comprehend_client, aws_language, text_analyzer_kwargs.get("allow_list", allow_list or []), aws_comprehend_entities, all_text_line_results, ) if aws_comprehend_entities: comprehend_units_used += ( len(current_batch.strip()) + COMPREHEND_CHARACTERS_PER_UNIT - 1 ) // COMPREHEND_CHARACTERS_PER_UNIT elif pii_identification_method == AWS_LLM_PII_OPTION: # LLM-based entity detection using AWS Bedrock try: from tools.llm_entity_detection import do_llm_entity_detection_call except ImportError as e: print(f"Error importing LLM entity detection: {e}") raise ImportError( "LLM entity detection not available. Please ensure llm_entity_detection.py is accessible." 
) if not bedrock_runtime: raise ValueError( "bedrock_runtime is required when using LLM-based PII detection" ) # Set inference method to aws-bedrock if not already set if text_analyzer_kwargs.get("inference_method") is None: text_analyzer_kwargs["inference_method"] = "aws-bedrock" # Update model_choice to use CLOUD_LLM_PII_MODEL_CHOICE for Bedrock, or value from text_analyzer_kwargs if set if text_analyzer_kwargs.get("model_choice") is None: model_choice = CLOUD_LLM_PII_MODEL_CHOICE else: model_choice = text_analyzer_kwargs.get( "model_choice", CLOUD_LLM_PII_MODEL_CHOICE ) # Set LLM model name for tracking (use custom-instructions model when applicable) custom_instructions_model = ( CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip() if isinstance(CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE, str) and CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip() else "" ) if ( (custom_llm_instructions or "").strip() and model_choice == CLOUD_LLM_PII_MODEL_CHOICE and custom_instructions_model ): llm_model_name = custom_instructions_model else: llm_model_name = model_choice or "" # Handle custom entities first (same as AWS Comprehend) # Include CUSTOM/CUSTOM_FUZZY (deny list) so deny-list words are redacted when CUSTOM is selected local_custom_entities = [ entity for entity in (chosen_llm_entities or []) if entity in (custom_entities or []) or entity in ("CUSTOM", "CUSTOM_FUZZY") ] if local_custom_entities: # Filter entities to only include those supported by the language language_supported_entities = filter_entities_for_language( local_custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept # Also exclude allow_list since we pass it explicitly presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", 
"assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, score_threshold=score_threshold, return_decision_process=True, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = map_back_entity_results( page_analyser_result, page_text_mapping, all_text_line_results ) # Process text in batches for LLM (same batching logic as AWS Comprehend) current_batch = "" current_batch_mapping = list() batch_char_count = 0 batch_word_count = 0 for i, text_line in enumerate(line_level_text_results_list): words = text_line.text.split() word_start_positions = list() current_pos = 0 for word in words: word_start_positions.append(current_pos) current_pos += len(word) + 1 word_idx = 0 while word_idx < len(words): word = words[word_idx] new_batch_char_count = len(current_batch) + len(word) + 1 # Check if we've hit the limit limit_reached = ( batch_word_count >= DEFAULT_NEW_BATCH_WORD_COUNT or new_batch_char_count >= DEFAULT_NEW_BATCH_CHAR_COUNT ) if limit_reached: # Add the current word to the batch first if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if not current_batch_mapping or current_batch_mapping[-1][1] != i: current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, line_characters[i], word_start_positions[word_idx], ) ) # Check if current word ends with phrase punctuation if ends_with_phrase_punctuation(word): # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get( 
"allow_list", allow_list or [] ), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx += 1 else: # Look ahead in current line for phrase-ending punctuation or end of line lookahead_idx = word_idx + 1 lookahead_batch = current_batch lookahead_char_count = batch_char_count lookahead_word_count = batch_word_count lookahead_mapping = list(current_batch_mapping) # Continue adding words until we find phrase-ending punctuation or end of line while lookahead_idx < len(words): lookahead_word = words[lookahead_idx] # Add the word to lookahead batch if lookahead_batch: lookahead_batch += " " lookahead_char_count += 1 lookahead_batch += lookahead_word lookahead_char_count += len(lookahead_word) lookahead_word_count += 1 if not lookahead_mapping or lookahead_mapping[-1][1] != i: lookahead_mapping.append( ( lookahead_char_count - len(lookahead_word), i, text_line, line_characters[i], word_start_positions[lookahead_idx], ) ) # Check if this word ends with phrase punctuation if ends_with_phrase_punctuation(lookahead_word): break lookahead_idx += 1 # Use the lookahead batch (either 
found phrase end or reached end of line) current_batch = lookahead_batch batch_char_count = lookahead_char_count batch_word_count = lookahead_word_count current_batch_mapping = lookahead_mapping # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get( "allow_list", allow_list or [] ), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx = lookahead_idx + 1 else: # Normal case: add word to batch if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if not current_batch_mapping or current_batch_mapping[-1][1] != i: current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, line_characters[i], 
word_start_positions[word_idx], ) ) word_idx += 1 # Process final batch if any if current_batch: # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" ] all_text_line_results, batch_input_tokens, batch_output_tokens = ( do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", allow_list or []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get("inference_method"), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 comprehend_query_number += 1 elif pii_identification_method == INFERENCE_SERVER_PII_OPTION: # LLM-based entity detection using inference server try: from tools.llm_entity_detection import do_llm_entity_detection_call except ImportError as e: print(f"Error importing LLM entity detection: {e}") raise ImportError( "LLM entity detection not available. Please ensure llm_entity_detection.py is accessible." 
) # Set inference method to inference-server if not already set if text_analyzer_kwargs.get("inference_method") is None: text_analyzer_kwargs["inference_method"] = "inference-server" # Set API URL if not already set if text_analyzer_kwargs.get("api_url") is None: text_analyzer_kwargs["api_url"] = INFERENCE_SERVER_API_URL # Set model choice if not already set - use INFERENCE_SERVER_LLM_PII_MODEL_CHOICE if text_analyzer_kwargs.get("model_choice") is None: text_analyzer_kwargs["model_choice"] = INFERENCE_SERVER_LLM_PII_MODEL_CHOICE # Update model_choice to use the value from text_analyzer_kwargs model_choice = text_analyzer_kwargs.get( "model_choice", INFERENCE_SERVER_LLM_PII_MODEL_CHOICE ) # Set LLM model name for tracking llm_model_name = model_choice or "" # Handle custom entities first (same as AWS Comprehend) # Include CUSTOM/CUSTOM_FUZZY (deny list) so deny-list words are redacted when CUSTOM is selected local_custom_entities = [ entity for entity in (chosen_llm_entities or []) if entity in (custom_entities or []) or entity in ("CUSTOM", "CUSTOM_FUZZY") ] if local_custom_entities: # Filter entities to only include those supported by the language language_supported_entities = filter_entities_for_language( local_custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept # Also exclude allow_list since we pass it explicitly presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, score_threshold=score_threshold, return_decision_process=True, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = 
map_back_entity_results( page_analyser_result, page_text_mapping, all_text_line_results ) # Process text in batches for LLM (same batching logic as AWS Comprehend) current_batch = "" current_batch_mapping = list() batch_char_count = 0 batch_word_count = 0 for i, text_line in enumerate(line_level_text_results_list): words = text_line.text.split() word_start_positions = list() current_pos = 0 for word in words: word_start_positions.append(current_pos) current_pos += len(word) + 1 word_idx = 0 while word_idx < len(words): word = words[word_idx] new_batch_char_count = len(current_batch) + len(word) + 1 # Check if we've hit the limit limit_reached = ( batch_word_count >= DEFAULT_NEW_BATCH_WORD_COUNT or new_batch_char_count >= DEFAULT_NEW_BATCH_CHAR_COUNT ) if limit_reached: # Add the current word to the batch first if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if not current_batch_mapping or current_batch_mapping[-1][1] != i: current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, line_characters[i], word_start_positions[word_idx], ) ) # Check if current word ends with phrase punctuation if ends_with_phrase_punctuation(word): # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get( "allow_list", allow_list or [] ), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), 
output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx += 1 else: # Look ahead in current line for phrase-ending punctuation or end of line lookahead_idx = word_idx + 1 lookahead_batch = current_batch lookahead_char_count = batch_char_count lookahead_word_count = batch_word_count lookahead_mapping = list(current_batch_mapping) # Continue adding words until we find phrase-ending punctuation or end of line while lookahead_idx < len(words): lookahead_word = words[lookahead_idx] # Add the word to lookahead batch if lookahead_batch: lookahead_batch += " " lookahead_char_count += 1 lookahead_batch += lookahead_word lookahead_char_count += len(lookahead_word) lookahead_word_count += 1 if not lookahead_mapping or lookahead_mapping[-1][1] != i: lookahead_mapping.append( ( lookahead_char_count - len(lookahead_word), i, text_line, line_characters[i], word_start_positions[lookahead_idx], ) ) # Check if this word ends with phrase punctuation if ends_with_phrase_punctuation(lookahead_word): break lookahead_idx += 1 # Use the lookahead batch (either found phrase end or reached end of line) current_batch = lookahead_batch batch_char_count = lookahead_char_count batch_word_count = lookahead_word_count current_batch_mapping = lookahead_mapping # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity 
in chosen_llm_entities if entity != "CUSTOM" ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get( "allow_list", allow_list or [] ), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx = lookahead_idx + 1 else: # Normal case: add word to batch if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if not current_batch_mapping or current_batch_mapping[-1][1] != i: current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, line_characters[i], word_start_positions[word_idx], ) ) word_idx += 1 # Process final batch if any if current_batch: # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" ] all_text_line_results, batch_input_tokens, batch_output_tokens = ( 
do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", allow_list or []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get("inference_method"), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens # LLM-based detection is metered separately; keep counter semantics consistent # (this function reports Comprehend units only). elif pii_identification_method == LOCAL_TRANSFORMERS_LLM_PII_OPTION: # LLM-based entity detection using local transformers models try: from tools.llm_entity_detection import do_llm_entity_detection_call except ImportError as e: print(f"Error importing LLM entity detection: {e}") raise ImportError( "LLM entity detection not available. Please ensure llm_entity_detection.py is accessible." 
) # Set inference method to local if not already set if text_analyzer_kwargs.get("inference_method") is None: text_analyzer_kwargs["inference_method"] = "local" # Set model choice if not already set - use VLM model when USE_TRANSFORMERS_VLM_MODEL_AS_LLM else LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE if text_analyzer_kwargs.get("model_choice") is None: text_analyzer_kwargs["model_choice"] = ( SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL if USE_TRANSFORMERS_VLM_MODEL_AS_LLM else LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE ) # Update model_choice to use the value from text_analyzer_kwargs model_choice = text_analyzer_kwargs.get( "model_choice", ( SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL if USE_TRANSFORMERS_VLM_MODEL_AS_LLM else LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE ), ) # Handle custom entities first (same as AWS Comprehend) # Include CUSTOM/CUSTOM_FUZZY (deny list) so deny-list words are redacted when CUSTOM is selected local_custom_entities = [ entity for entity in (chosen_llm_entities or []) if entity in (custom_entities or []) or entity in ("CUSTOM", "CUSTOM_FUZZY") ] if local_custom_entities: # Filter entities to only include those supported by the language language_supported_entities = filter_entities_for_language( local_custom_entities, valid_language_entities, language ) if language_supported_entities: text_analyzer_kwargs["entities"] = language_supported_entities # Filter out LLM-specific parameters that Presidio AnalyzerEngine doesn't accept # Also exclude allow_list since we pass it explicitly presidio_kwargs = { k: v for k, v in text_analyzer_kwargs.items() if k not in [ "inference_method", "model_choice", "api_url", "local_model", "tokenizer", "assistant_model", "client", "client_config", "temperature", "max_tokens", "custom_instructions", "allow_list", ] } page_analyser_result = nlp_analyser.analyze( text=page_text, language=language, score_threshold=score_threshold, return_decision_process=True, allow_list=allow_list, **presidio_kwargs, ) all_text_line_results = 
map_back_entity_results( page_analyser_result, page_text_mapping, all_text_line_results ) # Process text in batches for LLM (same batching logic as AWS Comprehend) current_batch = "" current_batch_mapping = list() batch_char_count = 0 batch_word_count = 0 for i, text_line in enumerate(line_level_text_results_list): words = text_line.text.split() word_start_positions = list() current_pos = 0 for word in words: word_start_positions.append(current_pos) current_pos += len(word) + 1 word_idx = 0 while word_idx < len(words): word = words[word_idx] new_batch_char_count = len(current_batch) + len(word) + 1 # Check if we've hit the limit limit_reached = ( batch_word_count >= DEFAULT_NEW_BATCH_WORD_COUNT or new_batch_char_count >= DEFAULT_NEW_BATCH_CHAR_COUNT ) if limit_reached: # Add the current word to the batch first if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if not current_batch_mapping or current_batch_mapping[-1][1] != i: current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, line_characters[i], word_start_positions[word_idx], ) ) # Check if current word ends with phrase punctuation if ends_with_phrase_punctuation(word): # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get( "allow_list", allow_list or [] ), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), 
output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx += 1 else: # Look ahead in current line for phrase-ending punctuation or end of line lookahead_idx = word_idx + 1 lookahead_batch = current_batch lookahead_char_count = batch_char_count lookahead_word_count = batch_word_count lookahead_mapping = list(current_batch_mapping) # Continue adding words until we find phrase-ending punctuation or end of line while lookahead_idx < len(words): lookahead_word = words[lookahead_idx] # Add the word to lookahead batch if lookahead_batch: lookahead_batch += " " lookahead_char_count += 1 lookahead_batch += lookahead_word lookahead_char_count += len(lookahead_word) lookahead_word_count += 1 if not lookahead_mapping or lookahead_mapping[-1][1] != i: lookahead_mapping.append( ( lookahead_char_count - len(lookahead_word), i, text_line, line_characters[i], word_start_positions[lookahead_idx], ) ) # Check if this word ends with phrase punctuation if ends_with_phrase_punctuation(lookahead_word): break lookahead_idx += 1 # Use the lookahead batch (either found phrase end or reached end of line) current_batch = lookahead_batch batch_char_count = lookahead_char_count batch_word_count = lookahead_word_count current_batch_mapping = lookahead_mapping # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity 
in chosen_llm_entities if entity != "CUSTOM" ] # Process current batch ( all_text_line_results, batch_input_tokens, batch_output_tokens, ) = do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get( "allow_list", allow_list or [] ), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get( "inference_method" ), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) comprehend_query_number += 1 # Reset batch current_batch = "" batch_word_count = 0 batch_char_count = 0 current_batch_mapping = list() word_idx = lookahead_idx + 1 else: # Normal case: add word to batch if current_batch: current_batch += " " batch_char_count += 1 current_batch += word batch_char_count += len(word) batch_word_count += 1 if not current_batch_mapping or current_batch_mapping[-1][1] != i: current_batch_mapping.append( ( batch_char_count - len(word), i, text_line, line_characters[i], word_start_positions[word_idx], ) ) word_idx += 1 # Process final batch if any if current_batch: # Remove 'CUSTOM' entities from the chosen_llm_entities list llm_chosen_redact_comprehend_entities = [ entity for entity in chosen_llm_entities if entity != "CUSTOM" ] all_text_line_results, batch_input_tokens, batch_output_tokens = ( 
do_llm_entity_detection_call( current_batch, current_batch_mapping, bedrock_runtime=bedrock_runtime, language=aws_language, allow_list=text_analyzer_kwargs.get("allow_list", allow_list or []), chosen_redact_llm_entities=llm_chosen_redact_comprehend_entities, all_text_line_results=all_text_line_results, model_choice=model_choice, temperature=text_analyzer_kwargs.get( "temperature", LLM_TEMPERATURE ), max_tokens=text_analyzer_kwargs.get( "max_tokens", LLM_MAX_NEW_TOKENS ), output_folder=output_folder, batch_number=comprehend_query_number + 1, custom_instructions=custom_llm_instructions, file_name=file_name, page_number=page_number, inference_method=text_analyzer_kwargs.get("inference_method"), # local_model=text_analyzer_kwargs.get("local_model"), # tokenizer=text_analyzer_kwargs.get("tokenizer"), # assistant_model=text_analyzer_kwargs.get("assistant_model"), client=text_analyzer_kwargs.get("client"), client_config=text_analyzer_kwargs.get("client_config"), api_url=text_analyzer_kwargs.get("api_url"), ) ) # Accumulate token usage llm_total_input_tokens += batch_input_tokens llm_total_output_tokens += batch_output_tokens comprehend_query_number += 1 # Process results for each line for i, text_line in enumerate(line_level_text_results_list): line_results = next( (results for idx, results in all_text_line_results if idx == i), [] ) if line_results: text_line_bounding_boxes = merge_text_bounding_boxes( line_results, line_characters[i] ) page_analyser_results.extend(line_results) page_analysed_bounding_boxes.extend(text_line_bounding_boxes) return ( page_analysed_bounding_boxes, comprehend_units_used, llm_model_name, llm_total_input_tokens, llm_total_output_tokens, ) def _char_bbox_and_text(char: Any) -> Tuple[Optional[List[float]], str]: """ Get bbox and text from a character object. Supports both pdfminer LTChar and PyMuPDF dict format {"text": ..., "bbox": [x0,y0,x1,y1], ...}. Returns (bbox_list or None, text_str). 
""" if isinstance(char, LTChar): return ( getattr(char, "bbox", None), getattr(char, "_text", None) or (char.get_text() if callable(getattr(char, "get_text", None)) else "") or "", ) if isinstance(char, dict) and "bbox" in char: bbox = char["bbox"] text = char.get("text", "") return ( bbox if isinstance(bbox, (list, tuple)) and len(bbox) >= 4 else None, text, ) return (None, "") def merge_text_bounding_boxes( analyser_results: dict, characters: List[Any], combine_pixel_dist: int = 20, vertical_padding: int = 0, ): """ Merge identified bounding boxes containing PII that are very close to one another. Supports both pdfminer LTChar objects and PyMuPDF-style dicts with "bbox" and "text" keys. """ analysed_bounding_boxes = list() original_bounding_boxes = list() # List to hold original bounding boxes if len(analyser_results) > 0 and len(characters) > 0: # Extract bounding box coordinates for sorting bounding_boxes = list() for result in analyser_results: char_boxes = [] char_text = [] for char in characters[result.start : result.end]: bbox, text = _char_bbox_and_text(char) if bbox is not None: char_boxes.append(bbox) char_text.append(text) if char_boxes: # Calculate the bounding box that encompasses all characters left = min(box[0] for box in char_boxes) bottom = min(box[1] for box in char_boxes) right = max(box[2] for box in char_boxes) top = max(box[3] for box in char_boxes) + vertical_padding bbox = [left, bottom, right, top] bounding_boxes.append( (bottom, left, result, bbox, char_text) ) # (y, x, result, bbox, text) # Store original bounding boxes original_bounding_boxes.append( { "text": "".join(char_text), # Keep both keys for compatibility across UI/table/render paths. # OCR word/line results use "bounding_box"; decision/review tables often use "boundingBox". 
"boundingBox": bbox, "bounding_box": bbox, "result": copy.deepcopy(result), } ) # Sort the results by y-coordinate and then by x-coordinate bounding_boxes.sort() if MERGE_BOUNDING_BOXES: merged_bounding_boxes = list() current_box = None current_y = None current_result = None current_text = list() for y, x, result, next_box, text in bounding_boxes: if current_y is None or current_box is None: # Initialize the first bounding box current_box = next_box current_y = next_box[1] current_result = result current_text = list(text) else: vertical_diff_bboxes = abs(next_box[1] - current_y) horizontal_diff_bboxes = abs(next_box[0] - current_box[2]) if ( vertical_diff_bboxes <= 5 and horizontal_diff_bboxes <= combine_pixel_dist ): # Merge bounding boxes # print("Merging boxes") merged_box = current_box.copy() merged_result = current_result merged_text = current_text.copy() merged_box[2] = next_box[2] # Extend horizontally merged_box[3] = max( current_box[3], next_box[3] ) # Adjust the top merged_result.end = max( current_result.end, result.end ) # Extend text range try: if current_result.entity_type != result.entity_type: merged_result.entity_type = ( current_result.entity_type + " - " + result.entity_type ) else: merged_result.entity_type = current_result.entity_type except Exception as e: print("Unable to combine result entity types:", e) if current_text: merged_text.append(" ") # Add space between texts merged_text.extend(text) merged_bounding_boxes.append( { "text": "".join(merged_text), "boundingBox": merged_box, "bounding_box": merged_box, "result": merged_result, } ) else: # Start a new bounding box current_box = next_box current_y = next_box[1] current_result = result current_text = list(text) # Combine original and merged bounding boxes analysed_bounding_boxes.extend(original_bounding_boxes) analysed_bounding_boxes.extend(merged_bounding_boxes) else: # Keep boxes without merging analysed_bounding_boxes.extend(original_bounding_boxes) # print("Analysed bounding boxes:", 
analysed_bounding_boxes) return analysed_bounding_boxes def recreate_page_line_level_ocr_results_with_page( page_line_level_ocr_results_with_words: dict, ): reconstructed_results = list() # Assume all lines belong to the same page, so we can just read it from one item # page = next(iter(page_line_level_ocr_results_with_words.values()))["page"] page = page_line_level_ocr_results_with_words["page"] for line_data in page_line_level_ocr_results_with_words["results"].values(): bbox = line_data.get("bounding_box") or line_data.get("boundingBox") if not bbox: continue text = line_data["text"] if line_data["line"]: line_number = line_data["line"] # Support both "confidence" (Textract) and "conf" (other OCR) if line_data["words"]: conf = sum( word.get("confidence", word.get("conf", 0.0)) for word in line_data["words"] ) / len(line_data["words"]) else: conf = line_data.get("confidence", line_data.get("conf", 0.0)) # Recreate the OCRResult model = line_data.get("model") line_result = OCRResult( text=text, left=bbox[0], top=bbox[1], width=bbox[2] - bbox[0], height=bbox[3] - bbox[1], line=line_number, conf=round(float(conf), 0), model=model, ) reconstructed_results.append(line_result) page_line_level_ocr_results_with_page = { "page": page, "results": reconstructed_results, } return page_line_level_ocr_results_with_page _PUNCTUATION_SPLIT_RE = re.compile(r"([(\[{]*)(.*?)_?([.,?!:;)\}\]]*)$") def split_words_and_punctuation_from_line( line_of_words: List[OCRResult], ) -> List[OCRResult]: """ Takes a list of OCRResult objects and splits words with trailing/leading punctuation. For a word like "example.", it creates two new OCRResult objects for "example" and "." and estimates their bounding boxes. Words with internal hyphens like "high-tech" are preserved. """ # Punctuation that will be split off. Hyphen is not included. 
new_word_list = list() for word_result in line_of_words: word_text = word_result.text # This regex finds a central "core" word, and captures leading and trailing punctuation. # Compiled once to avoid re-parsing the regex on every line. # Handles cases like "(word)." -> group1='(', group2='word', group3='.' match = _PUNCTUATION_SPLIT_RE.match(word_text) # Handle words with internal hyphens that might confuse the regex if "-" in word_text and not match.group(2): core_part_text = word_text leading_punc = "" trailing_punc = "" elif match: leading_punc, core_part_text, trailing_punc = match.groups() else: # Failsafe new_word_list.append(word_result) continue # If no split is needed, just add the original and continue if not leading_punc and not trailing_punc: new_word_list.append(word_result) continue # --- A split is required --- # Estimate new bounding boxes by proportionally allocating width original_width = word_result.width if not word_text or original_width == 0: continue # Failsafe avg_char_width = original_width / len(word_text) current_left = word_result.left # Add leading punctuation if it exists if leading_punc: punc_width = avg_char_width * len(leading_punc) new_word_list.append( OCRResult( text=leading_punc, left=current_left, top=word_result.top, width=punc_width, height=word_result.height, conf=word_result.conf, model=word_result.model, ) ) current_left += punc_width # Add the core part of the word if core_part_text: core_width = avg_char_width * len(core_part_text) new_word_list.append( OCRResult( text=core_part_text, left=current_left, top=word_result.top, width=core_width, height=word_result.height, conf=word_result.conf, model=word_result.model, ) ) current_left += core_width # Add trailing punctuation if it exists if trailing_punc: punc_width = avg_char_width * len(trailing_punc) new_word_list.append( OCRResult( text=trailing_punc, left=current_left, top=word_result.top, width=punc_width, height=word_result.height, conf=word_result.conf, 
model=word_result.model, ) ) return new_word_list def create_ocr_result_with_children( combined_results: dict, i: int, current_bbox: dict, current_line: list ): combined_results["text_line_" + str(i)] = { "line": i, "text": current_bbox.text, "bounding_box": ( current_bbox.left, current_bbox.top, current_bbox.left + current_bbox.width, current_bbox.top + current_bbox.height, ), "words": [ { "text": word.text, "bounding_box": ( word.left, word.top, word.left + word.width, word.top + word.height, ), "conf": word.conf, } for word in current_line ], "conf": current_bbox.conf, "model": getattr(current_bbox, "model", None), } return combined_results["text_line_" + str(i)] def combine_ocr_results( ocr_results: List[OCRResult], x_threshold: float = 50.0, y_threshold: float = 12.0, page: int = 1, preserve_line_boxes: bool = False, reading_order_mode: Optional[str] = None, ): """ Group OCR results into lines, splitting words from punctuation. When reading_order_mode is "column" (default), boxes are ordered for multi-column layouts before line numbers are assigned. Set preserve_line_boxes=True to keep each input box as its own line (Paddle line-level fast path). """ if not ocr_results: return {"page": page, "results": []}, {"page": page, "results": {}} mode = (reading_order_mode or LOCAL_OCR_READING_ORDER).strip().lower() if mode == "paddle_native": # Force Paddle's native textline boxes to be treated as final line groups. 
preserve_line_boxes = True lines, _, _ = build_line_groups( ocr_results, reading_order_mode=mode, preserve_line_boxes=preserve_line_boxes, y_threshold=y_threshold if mode == "legacy" else None, ) page_line_level_ocr_results = list() page_line_level_ocr_results_with_words = {} line_counter = 1 for line in lines: if not line: continue # Process the line to split punctuation from words processed_line = split_words_and_punctuation_from_line(line) # Re-calculate the line-level text and bounding box from the ORIGINAL words line_text = " ".join([word.text for word in line]) line_left = line[0].left line_top = min(word.top for word in line) line_right = max(word.left + word.width for word in line) line_bottom = max(word.top + word.height for word in line) line_conf = round( sum(word.conf for word in line) / len(line), 0 ) # This is mean confidence for the line final_line_bbox = OCRResult( text=line_text, left=line_left, top=line_top, width=line_right - line_left, height=line_bottom - line_top, line=line_counter, conf=line_conf, model=model_from_ocr_boxes(line), ) page_line_level_ocr_results.append(final_line_bbox) # Use the PROCESSED line to create the children. Creates a result within page_line_level_ocr_results_with_words page_line_level_ocr_results_with_words["text_line_" + str(line_counter)] = ( create_ocr_result_with_children( page_line_level_ocr_results_with_words, line_counter, final_line_bbox, processed_line, ) ) line_counter += 1 page_level_results_with_page = { "page": page, "results": page_line_level_ocr_results, } page_level_results_with_words = { "page": page, "results": page_line_level_ocr_results_with_words, } return page_level_results_with_page, page_level_results_with_words