#!/usr/bin/env python3 from __future__ import annotations import argparse import json import os os.environ.setdefault("TRANSFORMERS_NO_TF", "1") os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1") os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1") os.environ["USE_TF"] = "0" os.environ["USE_FLAX"] = "0" os.environ["USE_TORCH"] = "1" from common import decode_span_matrix, load_onnx_session, run_onnx_span, sigmoid_np def replacement(label: str) -> str: return f"[PII:{label}]" def mask_text(text: str, spans: list[dict]) -> str: out = text for span in sorted(spans, key=lambda item: (item["start"], item["end"]), reverse=True): out = out[: span["start"]] + replacement(span["label"]) + out[span["end"] :] return out def predict(text: str, session, tokenizer, config, min_score: float): encoded = tokenizer(text, return_offsets_mapping=True, return_tensors="np", truncation=True) offsets = [tuple(item) for item in encoded["offset_mapping"][0].tolist()] span_logits = run_onnx_span(session, encoded) span_scores = sigmoid_np(span_logits[0]) spans = decode_span_matrix(text, offsets, span_scores, config, min_score) for span in spans: span["replacement"] = replacement(span["label"]) return spans def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--model", required=True) parser.add_argument("--text", required=True) parser.add_argument("--min-score", type=float, default=0.5) parser.add_argument("--json", action="store_true") args = parser.parse_args() session, tokenizer, config = load_onnx_session(args.model, onnx_file="model_quantized.onnx", onnx_subfolder="onnx") spans = predict(args.text, session, tokenizer, config, args.min_score) result = { "model": args.model, "backend": "onnx_global_pointer_q8", "min_score": args.min_score, "spans": spans, "masked_text": mask_text(args.text, spans), } if args.json: print(json.dumps(result, indent=2, ensure_ascii=False)) else: print(result["masked_text"]) if __name__ == "__main__": main()