#!/usr/bin/env python3
"""Helpers for loading a span-scoring ONNX model and decoding its output.

The module covers three concerns:

* loading a tokenizer/config/ONNX session from a local directory or a
  Hugging Face Hub repository (``safe_auto_tokenizer``, ``load_onnx_session``);
* reading per-label decoding settings from a model config
  (``label_*_from_config``);
* turning a ``[num_labels, seq_len, seq_len]`` score matrix into a list of
  non-overlapping, plausibility-checked character spans
  (``decode_span_matrix`` and its helpers).

NOTE(review): the label set and validators (PPSN, ``D6W`` routing key,
``IE``-prefixed account numbers) look targeted at Irish identifiers --
confirm against the model card.
"""
from __future__ import annotations

import json
import logging
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Any

import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig, AutoTokenizer

logger = logging.getLogger(__name__)

# File names that are copied when a tokenizer is staged into a temp directory.
TOKENIZER_FILES = [
    "tokenizer_config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
    "vocab.json",
    "merges.txt",
    "added_tokens.json",
    "sentencepiece.bpe.model",
    "spiece.model",
]

# Fallback maximum span width (in tokens) per label, used when the model
# config does not provide its own value.
DEFAULT_LABEL_MAX_SPAN_TOKENS = {
    "PPSN": 9,
    "POSTCODE": 8,
    "PHONE_NUMBER": 10,
    "PASSPORT_NUMBER": 8,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 19,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 15,
    "FIRST_NAME": 5,
    "LAST_NAME": 8,
}

# Fallback minimum number of non-whitespace characters per label, used when
# the model config does not provide its own value.
DEFAULT_LABEL_MIN_NONSPACE_CHARS = {
    "PPSN": 8,
    "POSTCODE": 6,
    "PHONE_NUMBER": 7,
    "PASSPORT_NUMBER": 7,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 6,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 6,
    "FIRST_NAME": 2,
    "LAST_NAME": 2,
}

# Final tie-breaker when sorting spans in ``dedupe_spans``: lower sorts first;
# labels missing from this table sort last (99).
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "FIRST_NAME": 9,
    "LAST_NAME": 10,
}

# Labels that get the "prefer the longer candidate" treatment.
_NAME_LABELS = frozenset({"FIRST_NAME", "LAST_NAME"})


def normalize_entity_name(label: str) -> str:
    """Return *label* stripped, without a leading ``B-``/``I-`` tag, upper-cased.

    ``None`` or an empty string yields ``""``. The prefix test is
    case-sensitive and runs before upper-casing, so ``"b-foo"`` becomes
    ``"B-FOO"`` rather than ``"FOO"``.
    """
    label = (label or "").strip()
    if label.startswith(("B-", "I-")):
        label = label[2:]
    return label.upper()


def _sanitize_tokenizer_dir(tokenizer_path: Path) -> str:
    """Return a tokenizer location whose config has no ``fix_mistral_regex`` key.

    When ``tokenizer_config.json`` is absent, or does not contain the key,
    the original path is returned unchanged. Otherwise the known tokenizer
    files are copied into a fresh temporary directory, the config is written
    there without the key, and that directory is returned.

    The temporary directory is not removed by this function.

    NOTE(review): the key is presumably dropped so it cannot clash with the
    explicit ``fix_mistral_regex`` keyword passed by ``safe_auto_tokenizer``
    -- confirm.
    """
    tokenizer_cfg_path = tokenizer_path / "tokenizer_config.json"
    if not tokenizer_cfg_path.exists():
        return str(tokenizer_path)

    data = json.loads(tokenizer_cfg_path.read_text(encoding="utf-8"))
    if "fix_mistral_regex" not in data:
        return str(tokenizer_path)

    tmpdir = Path(tempfile.mkdtemp(prefix="openmed_span_tokenizer_"))
    keep = set(TOKENIZER_FILES)
    for child in tokenizer_path.iterdir():
        if child.is_file() and child.name in keep:
            (tmpdir / child.name).write_bytes(child.read_bytes())

    # Written last so it replaces the config copied by the loop above.
    data.pop("fix_mistral_regex", None)
    (tmpdir / "tokenizer_config.json").write_text(
        json.dumps(data, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    return str(tmpdir)


def safe_auto_tokenizer(tokenizer_ref: str):
    """Load a tokenizer for *tokenizer_ref*, trying progressively simpler calls.

    *tokenizer_ref* is treated as a local path when it exists on disk and as
    a Hub model repository id otherwise. For a repository, the files listed
    in ``TOKENIZER_FILES`` that exist remotely are downloaded into a
    temporary directory first. Either way the location is passed through
    ``_sanitize_tokenizer_dir`` before loading.

    Loading attempts, in order:

    1. fast tokenizer with ``fix_mistral_regex=True`` (any failure falls through);
    2. fast tokenizer with ``fix_mistral_regex=False`` (only ``TypeError``
       falls through; other errors propagate);
    3. fast tokenizer without the keyword (any failure falls through);
    4. slow tokenizer (errors propagate).
    """
    tokenizer_path = Path(tokenizer_ref)
    if tokenizer_path.exists():
        tokenizer_ref = _sanitize_tokenizer_dir(tokenizer_path)
    else:
        api = HfApi()
        files = set(api.list_repo_files(repo_id=tokenizer_ref, repo_type="model"))
        available = [name for name in TOKENIZER_FILES if name in files]
        # Only create the staging directory when something will be put in it,
        # so a repository without tokenizer files leaves no empty temp dir.
        if available:
            tmpdir = Path(tempfile.mkdtemp(prefix="openmed_remote_span_tokenizer_"))
            for name in available:
                src = hf_hub_download(repo_id=tokenizer_ref, filename=name, repo_type="model")
                (tmpdir / Path(name).name).write_bytes(Path(src).read_bytes())
            tokenizer_ref = _sanitize_tokenizer_dir(tmpdir)

    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True)
    except Exception:
        logger.debug("Tokenizer load with fix_mistral_regex=True failed for %s", tokenizer_ref, exc_info=True)

    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False)
    except TypeError:
        logger.debug("Tokenizer load with fix_mistral_regex=False failed for %s", tokenizer_ref, exc_info=True)

    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True)
    except Exception:
        logger.debug("Fast tokenizer load failed for %s; using slow tokenizer", tokenizer_ref, exc_info=True)
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False)


def label_names_from_config(config) -> list[str]:
    """Return the normalized label names stored in ``config.span_label_names``.

    Raises:
        ValueError: if the attribute is missing, ``None`` or empty.
    """
    names = list(getattr(config, "span_label_names", None) or [])
    if not names:
        raise ValueError("Missing span_label_names in config")
    return [normalize_entity_name(name) for name in names]


def _label_overrides(config, attribute: str, cast) -> dict:
    """Read a ``{label: value}`` mapping from *config*, normalizing keys and casting values."""
    raw = getattr(config, attribute, None) or {}
    return {normalize_entity_name(key): cast(value) for key, value in raw.items()}


def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Return the score threshold for every label.

    Values come from ``config.span_label_thresholds``; labels in
    ``span_label_names`` that have no entry get *default_threshold*.
    """
    out = _label_overrides(config, "span_label_thresholds", float)
    for label in label_names_from_config(config):
        out.setdefault(label, float(default_threshold))
    return out


def label_max_span_tokens_from_config(config) -> dict[str, int]:
    """Return the maximum span width in tokens for every label.

    Precedence: ``config.span_label_max_span_tokens``, then
    ``DEFAULT_LABEL_MAX_SPAN_TOKENS``, then ``8`` for any remaining label
    named in the config.
    """
    out = _label_overrides(config, "span_label_max_span_tokens", int)
    for label, value in DEFAULT_LABEL_MAX_SPAN_TOKENS.items():
        out.setdefault(label, value)
    for label in label_names_from_config(config):
        out.setdefault(label, 8)
    return out


def label_min_nonspace_chars_from_config(config) -> dict[str, int]:
    """Return the minimum non-whitespace character count for every label.

    Precedence: ``config.span_label_min_nonspace_chars``, then
    ``DEFAULT_LABEL_MIN_NONSPACE_CHARS``, then ``1`` for any remaining label
    named in the config.
    """
    out = _label_overrides(config, "span_label_min_nonspace_chars", int)
    for label, value in DEFAULT_LABEL_MIN_NONSPACE_CHARS.items():
        out.setdefault(label, value)
    for label in label_names_from_config(config):
        out.setdefault(label, 1)
    return out


def overlaps(a: dict, b: dict) -> bool:
    """Return True when the half-open ranges ``[start, end)`` of *a* and *b* intersect."""
    return not (a["end"] <= b["start"] or b["end"] <= a["start"])


def dedupe_spans(spans: list[dict]) -> list[dict]:
    """Greedily keep the best-scoring spans that do not overlap each other.

    Candidates are visited by descending score (ties: earlier start, earlier
    end, then ``OUTPUT_PRIORITY``); a candidate is dropped if it overlaps any
    span already kept. The result is sorted by position.
    """
    ordered = sorted(
        spans,
        key=lambda item: (
            -float(item.get("score", 0.0)),
            item["start"],
            item["end"],
            OUTPUT_PRIORITY.get(item["label"], 99),
        ),
    )
    kept: list[dict] = []
    for span in ordered:
        if any(overlaps(span, other) for other in kept):
            continue
        kept.append(span)
    kept.sort(key=lambda item: (item["start"], item["end"], OUTPUT_PRIORITY.get(item["label"], 99)))
    return kept


def valid_offset(offset: tuple[int, int]) -> bool:
    """Return True when *offset* is a ``(start, end)`` pair with ``end > start``.

    ``None``, and sequences with fewer than two items, are invalid. The
    length is checked explicitly instead of relying on truthiness so that a
    NumPy row (whose truth value is ambiguous) is accepted as well as a tuple.
    """
    if offset is None or len(offset) < 2:
        return False
    return int(offset[1]) > int(offset[0])


def nonspace_length(text: str, start: int, end: int) -> int:
    """Count the non-whitespace characters in ``text[start:end]``."""
    return sum(0 if ch.isspace() else 1 for ch in text[int(start) : int(end)])


def alnum_upper(text: str) -> str:
    """Return *text* upper-cased with every non-alphanumeric character removed."""
    return "".join(ch for ch in text.upper() if ch.isalnum())


def _digit_count(value: str) -> int:
    """Count the digit characters in *value*."""
    return sum(1 for ch in value if ch.isdigit())


def is_reasonable_span_text(label: str, text: str, start: int, end: int) -> bool:
    """Return True when ``text[start:end]`` is a plausible value for *label*.

    The slice is stripped first; an empty result is never plausible. Labels
    without a dedicated rule are accepted as long as they are non-empty.
    """
    value = text[int(start) : int(end)].strip()
    if not value:
        return False
    upper = alnum_upper(value)

    if label in _NAME_LABELS:
        if not any(ch.isalpha() for ch in value):
            return False
        if any(ch.isdigit() for ch in value):
            return False
        # Reject a span that begins immediately after a digit.
        if start > 0 and text[int(start) - 1].isdigit():
            return False
        return True

    if label == "EMAIL":
        if "@" not in value:
            return False
        local, _, domain = value.partition("@")
        return bool(local) and "." in domain

    if label == "PHONE_NUMBER":
        return _digit_count(value) >= 7

    if label == "PPSN":
        # Seven digits followed by one or two letters.
        return len(upper) in {8, 9} and upper[:7].isdigit() and upper[7:].isalpha()

    if label == "POSTCODE":
        # Only letters, digits and whitespace are allowed.
        if any(not (ch.isalnum() or ch.isspace()) for ch in value):
            return False
        compact = value.replace(" ", "").replace("\u00A0", "").replace("\u202F", "")
        if len(compact) != 7:
            return False
        routing, unique = compact[:3], compact[3:]
        # "D6W" is the special-cased routing key that is not letter + two
        # digits; compare case-insensitively like the rest of this branch.
        routing_ok = (routing[0].isalpha() and routing[1:].isdigit()) or routing.upper() == "D6W"
        unique_ok = unique[0].isalpha() and unique[1:].isalnum()
        return routing_ok and unique_ok

    if label == "PASSPORT_NUMBER":
        return 7 <= len(upper) <= 10 and upper.isalnum()

    if label == "BANK_ROUTING_NUMBER":
        return _digit_count(value) == 6

    if label == "SWIFT_BIC":
        return len(upper) in {8, 11} and upper.isalnum()

    if label == "CREDIT_DEBIT_CARD":
        return 12 <= _digit_count(value) <= 19

    if label == "ACCOUNT_NUMBER":
        # Values starting with "IE" must be exactly 22 alphanumerics.
        if upper.startswith("IE"):
            return len(upper) == 22
        return 6 <= _digit_count(value) <= 34

    return True


def _span_score(span: dict) -> float:
    """Return the span's score as a float, defaulting to 0.0."""
    return float(span.get("score", 0.0))


def prefer_long_name_spans(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Collapse name candidates that share a label and start into one span.

    Non-name spans pass through untouched. For each group of
    ``FIRST_NAME``/``LAST_NAME`` candidates with the same label and start
    offset, the longest candidate wins if its score is at least
    ``max(threshold + 0.15, 0.7 * best_score)``; otherwise the best-scoring
    candidate wins. The chosen span takes the position of the group's first
    member, so the relative order of the input is preserved.

    An empty input is returned as-is.
    """
    if not spans:
        return spans

    # Group once up front instead of rescanning the whole list per span.
    groups: dict[tuple, list[dict]] = defaultdict(list)
    for span in spans:
        if span["label"] in _NAME_LABELS:
            groups[(span["label"], span["start"])].append(span)

    preferred: list[dict] = []
    handled: set[tuple] = set()
    for span in spans:
        label = span["label"]
        if label not in _NAME_LABELS:
            preferred.append(span)
            continue

        key = (label, span["start"])
        if key in handled:
            continue
        handled.add(key)

        candidates = groups[key]
        if len(candidates) == 1:
            preferred.append(span)
            continue

        best_by_score = max(candidates, key=_span_score)
        longest = max(candidates, key=lambda item: (item["end"] - item["start"], _span_score(item)))
        threshold = float(thresholds.get(label, 0.5))
        if _span_score(longest) >= max(threshold + 0.15, _span_score(best_by_score) * 0.7):
            preferred.append(longest)
        else:
            preferred.append(best_by_score)
    return preferred


def decode_span_matrix(
    text: str,
    offsets: list[tuple[int, int]],
    span_scores: np.ndarray,
    config,
    min_score: float,
) -> list[dict]:
    """Decode a span score matrix into non-overlapping character spans.

    Args:
        text: The text the offsets refer to.
        offsets: One ``(start_char, end_char)`` pair per token position.
        span_scores: Array indexed as ``[label, start_token, end_token]``.
        config: Model config providing ``span_label_names`` and optional
            per-label settings.
        min_score: Threshold for labels without a configured threshold.

    Returns:
        Dicts with ``start``, ``end``, ``label`` and ``score`` keys, sorted
        by position, after name-span preference and overlap removal.

    Raises:
        ValueError: if *span_scores* is not three-dimensional, or the config
            has no label names.
    """
    label_names = label_names_from_config(config)
    thresholds = label_thresholds_from_config(config, min_score)
    max_span_tokens = label_max_span_tokens_from_config(config)
    min_nonspace_chars = label_min_nonspace_chars_from_config(config)

    if span_scores.ndim != 3:
        raise ValueError(f"Expected [num_labels, seq_len, seq_len] span scores, got shape {span_scores.shape}")
    num_labels, seq_len, end_len = span_scores.shape

    # Never index past the shortest of the two matrix axes and the offsets.
    usable_len = min(seq_len, end_len, len(offsets))
    # Offset validity does not depend on the label, so compute it once.
    token_ok = [valid_offset(offsets[index]) for index in range(usable_len)]

    spans: list[dict] = []
    for label_index in range(min(num_labels, len(label_names))):
        label = label_names[label_index]
        threshold = thresholds.get(label, min_score)
        max_width = max(1, int(max_span_tokens.get(label, 8)))
        min_chars = max(1, int(min_nonspace_chars.get(label, 1)))
        label_scores = span_scores[label_index]

        for start_idx in range(usable_len):
            if not token_ok[start_idx]:
                continue
            start_char = int(offsets[start_idx][0])
            max_end = min(usable_len, start_idx + max_width)

            for end_idx in range(start_idx, max_end):
                if not token_ok[end_idx]:
                    continue
                score = float(label_scores[start_idx, end_idx])
                # Written as "not >=" so that a NaN score is rejected too.
                if not score >= threshold:
                    continue
                end_char = int(offsets[end_idx][1])
                if end_char <= start_char:
                    continue
                if nonspace_length(text, start_char, end_char) < min_chars:
                    continue
                if not is_reasonable_span_text(label, text, start_char, end_char):
                    continue
                spans.append({"start": start_char, "end": end_char, "label": label, "score": score})

    return dedupe_spans(prefer_long_name_spans(spans, thresholds))


def sigmoid_np(values: np.ndarray) -> np.ndarray:
    """Return the element-wise logistic sigmoid of *values*.

    Inputs are clipped to ``[-60, 60]`` before exponentiation.
    """
    clipped = np.clip(values, -60.0, 60.0)
    return 1.0 / (1.0 + np.exp(-clipped))


def load_onnx_session(model_ref: str, onnx_file: str = "model_quantized.onnx", onnx_subfolder: str = "onnx"):
    """Create an ONNX Runtime session plus tokenizer and config for *model_ref*.

    For a local path the ONNX file is looked up in ``onnx_subfolder`` (when
    given) and then directly under the model directory; if neither exists
    the first candidate is still passed on, so the session constructor
    reports the failure. For anything else *model_ref* is treated as a Hub
    repository id and the file is downloaded.

    Returns:
        ``(session, tokenizer, config)``. The session uses
        ``CPUExecutionProvider`` only.
    """
    import onnxruntime as ort

    model_path = Path(model_ref)
    if model_path.exists():
        candidates = []
        if onnx_subfolder:
            candidates.append(model_path / onnx_subfolder / onnx_file)
        candidates.append(model_path / onnx_file)
        onnx_path = next((path for path in candidates if path.exists()), candidates[0])
    else:
        remote_name = f"{onnx_subfolder}/{onnx_file}" if onnx_subfolder else onnx_file
        onnx_path = Path(hf_hub_download(repo_id=model_ref, filename=remote_name, repo_type="model"))

    # Identical for local and remote references.
    config = AutoConfig.from_pretrained(model_ref)
    tokenizer = safe_auto_tokenizer(model_ref)
    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    return session, tokenizer, config


def run_onnx_span(session, encoded: dict[str, Any]) -> np.ndarray:
    """Run *session* on the tokenizer output and return its first output.

    Only entries whose key matches a session input name are fed;
    ``offset_mapping`` is always excluded.

    Raises:
        ValueError: if the session returns no outputs.
    """
    input_names = {item.name for item in session.get_inputs()}
    feed = {
        key: value
        for key, value in encoded.items()
        if key != "offset_mapping" and key in input_names
    }
    outputs = session.run(None, feed)
    if not outputs:
        raise ValueError("ONNX session returned no outputs")
    return outputs[0]