| """ |
| LLM-based entity detection using AWS Bedrock. |
| This module builds entity-detection prompts, sends them to the configured LLM backend, parses the JSON reply, and maps the detected PII entities back to text lines. |
| """ |
|
|
| import json |
| import os |
| import re |
| from datetime import datetime |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import boto3 |
| from gradio import Progress |
|
|
| from tools.config import ( |
| CHOSEN_LLM_PII_INFERENCE_METHOD, |
| CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE, |
| CLOUD_LLM_PII_MODEL_CHOICE, |
| INFERENCE_SERVER_API_URL, |
| LLM_MAX_NEW_TOKENS, |
| LLM_TEMPERATURE, |
| model_name_map, |
| ) |
| from tools.llm_entity_detection_prompts import ( |
| create_entity_detection_prompt, |
| create_entity_detection_system_prompt, |
| ) |
|
|
| |
# Maximum length of each sanitized sheet/column name fragment that is embedded
# in the filename of a tabular LLM prompt/response log.
LLM_LOG_TABULAR_NAME_MAX_LEN = 25

# The request helper lives in a sibling module. Import failures are reported
# but do not prevent this module from being imported.
try:
    from tools.llm_funcs import (
        send_request,
    )
except ImportError as e:
    print(f"Warning: Could not import LLM functions: {e}")
    print("LLM-based entity detection will not be available.")
    print("Please ensure llm_funcs.py is in the tools folder.")
    # Bind the name that the import above was supposed to provide, so later
    # references resolve to None instead of raising NameError.
    send_request = None
    # Placeholders retained from the original fallback for backward compatibility.
    call_aws_bedrock = None
    construct_azure_client = None
    ResponseObject = None
|
|
|
|
| def _find_text_in_passage( |
| search_text: str, |
| original_text: str, |
| reported_offset: Optional[int] = None, |
| start_from: int = 0, |
| ) -> Optional[Tuple[int, int]]: |
| """ |
| Find the position of search_text in original_text and return (begin, end) offsets. |
| |
| Only considers occurrences at or after start_from. This allows a "first pass" where |
| each entity is matched starting after the previous entity's end, so repeated phrases |
| (e.g. "University of Notre Dame" vs "University" + "of Notre Dame") map to the |
| correct occurrence. |
| |
| Args: |
| search_text: The text to search for |
| original_text: The text to search in |
| reported_offset: Optional offset reported by LLM (used to disambiguate multiple matches) |
| start_from: Only consider matches at or after this position (default 0). |
| |
| Returns: |
| Tuple of (begin_offset, end_offset) if found, None otherwise |
| """ |
| if not search_text: |
| return None |
|
|
| def first_or_closest( |
| positions: List[int], length: int |
| ) -> Optional[Tuple[int, int]]: |
| candidates = [p for p in positions if p >= start_from] |
| if not candidates: |
| return None |
| if reported_offset is not None: |
| closest_pos = min(candidates, key=lambda p: abs(p - reported_offset)) |
| else: |
| closest_pos = min(candidates) |
| return (closest_pos, closest_pos + length) |
|
|
| |
| search_text_clean = search_text.rstrip("...").strip() |
|
|
| |
| all_positions = [] |
| start = 0 |
| while True: |
| pos = original_text.find(search_text, start) |
| if pos == -1: |
| break |
| all_positions.append(pos) |
| start = pos + 1 |
|
|
| if all_positions: |
| result = first_or_closest(all_positions, len(search_text)) |
| if result is not None: |
| return result |
|
|
| |
| if search_text_clean != search_text: |
| all_positions_clean = [] |
| start = 0 |
| while True: |
| pos = original_text.find(search_text_clean, start) |
| if pos == -1: |
| break |
| all_positions_clean.append(pos) |
| start = pos + 1 |
|
|
| if all_positions_clean: |
| result = first_or_closest(all_positions_clean, len(search_text_clean)) |
| if result is not None: |
| return result |
|
|
| |
| search_text_lower = search_text.lower() |
| original_text_lower = original_text.lower() |
| all_positions_lower = [] |
| start = 0 |
| while True: |
| pos = original_text_lower.find(search_text_lower, start) |
| if pos == -1: |
| break |
| all_positions_lower.append(pos) |
| start = pos + 1 |
|
|
| if all_positions_lower: |
| result = first_or_closest(all_positions_lower, len(search_text)) |
| if result is not None: |
| return result |
|
|
| |
| if search_text_clean != search_text: |
| search_text_clean_lower = search_text_clean.lower() |
| all_positions_clean_lower = [] |
| start = 0 |
| while True: |
| pos = original_text_lower.find(search_text_clean_lower, start) |
| if pos == -1: |
| break |
| all_positions_clean_lower.append(pos) |
| start = pos + 1 |
|
|
| if all_positions_clean_lower: |
| result = first_or_closest(all_positions_clean_lower, len(search_text_clean)) |
| if result is not None: |
| return result |
|
|
| return None |
|
|
|
|
| def _find_all_text_in_passage( |
| search_text: str, original_text: str |
| ) -> List[Tuple[int, int]]: |
| """ |
| Find all positions of search_text in original_text and return a list of (begin, end) offsets. |
| Uses the same search strategy as _find_text_in_passage (exact, then cleaned, then case-insensitive). |
| LLM offset values are never used; positions come only from search. |
| |
| Returns: |
| List of (begin_offset, end_offset) tuples, sorted by begin_offset (ascending). |
| """ |
| if not search_text: |
| return [] |
|
|
| search_text_clean = search_text.rstrip("...").strip() |
|
|
| def find_all_exact(needle: str, haystack: str) -> List[Tuple[int, int]]: |
| result = [] |
| start = 0 |
| while True: |
| pos = haystack.find(needle, start) |
| if pos == -1: |
| break |
| result.append((pos, pos + len(needle))) |
| start = pos + 1 |
| return result |
|
|
| positions = find_all_exact(search_text, original_text) |
| if positions: |
| return sorted(positions, key=lambda p: p[0]) |
|
|
| if search_text_clean != search_text: |
| positions = find_all_exact(search_text_clean, original_text) |
| if positions: |
| return sorted(positions, key=lambda p: p[0]) |
|
|
| |
| needle_lower = search_text.lower() |
| haystack_lower = original_text.lower() |
| positions = find_all_exact(needle_lower, haystack_lower) |
| if positions: |
| |
| return sorted( |
| [(p[0], p[0] + len(search_text)) for p in positions], key=lambda p: p[0] |
| ) |
|
|
| if search_text_clean != search_text: |
| needle_clean_lower = search_text_clean.lower() |
| positions = find_all_exact(needle_clean_lower, haystack_lower) |
| if positions: |
| return sorted( |
| [(p[0], p[0] + len(search_text_clean)) for p in positions], |
| key=lambda p: p[0], |
| ) |
|
|
| return [] |
|
|
|
|
| def _entity_get(obj: Dict[str, Any], key: str, default: Any = None) -> Any: |
| """Get value from entity dict with case-insensitive key lookup (e.g. BeginOffset vs beginOffset).""" |
| key_lower = key.lower() |
| for k, v in obj.items(): |
| if k.lower() == key_lower: |
| return v |
| return default |
|
|
|
|
def _balanced_json_object(text: str) -> Optional[str]:
    """Return the first brace-balanced ``{...}`` substring of text, or None.

    Braces are counted naively: braces inside JSON string values are not skipped.
    """
    start = text.find("{")
    if start < 0:
        return None
    depth = 0
    for idx in range(start, len(text)):
        char = text[idx]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[start : idx + 1]
    return None


def _is_quotable_identifier(value: str) -> bool:
    """Return True for a bare identifier that is not a JSON literal (true/false/null)."""
    return bool(re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", value)) and (
        value.lower() not in ("true", "false", "null")
    )


def _quote_value_after_key(match) -> str:
    """re.sub callback: turn ``"key": VALUE<sep>`` into ``"key": "VALUE"<sep>``."""
    key_part, value, separator = match.group(1), match.group(2), match.group(3)
    if _is_quotable_identifier(value):
        return f'{key_part}: "{value}"{separator}'
    return match.group(0)


def _quote_value_after_colon(match) -> str:
    """re.sub callback: turn ``: VALUE<sep>`` into ``: "VALUE"<sep>``.

    Group 1 already contains the colon, so it is emitted as-is.
    """
    prefix, value, suffix = match.group(1), match.group(2), match.group(3)
    if _is_quotable_identifier(value):
        return f'{prefix}"{value}"{suffix}'
    return match.group(0)


def _optional_int(value: Any) -> Optional[int]:
    """Convert value to int, returning None when it is None or not convertible."""
    if value is None:
        return None
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError):
        return None


def _score_or_default(value: Any, default: float = 0.8) -> float:
    """Convert value to float, falling back to default for null/invalid scores."""
    try:
        return float(value)
    except (ValueError, TypeError):
        return default


def parse_llm_entity_response(
    response_text: str,
    original_text: str,
) -> List[Dict[str, Any]]:
    """
    Parse an LLM response and extract entity information.

    The JSON object is taken from a fenced code block when one is present,
    otherwise from the first ``{... "entities" ...}`` span in the response.
    Common LLM formatting mistakes (trailing commas, unquoted identifier values,
    stray backticks / ``<end_of_turn>`` markers, ``<think>`` blocks) are repaired
    before parsing.

    LLM BeginOffset/EndOffset are used only to define order and to disambiguate
    repeated text. Positions are resolved by text search: for each entity (sorted
    by reported BeginOffset, entities without one last), the entity's Text is
    searched for in the passage starting from the end of the preceding entity's
    resolved span; if not found there, the search restarts from the beginning of
    the passage. This makes repeated phrases (e.g. "University of Notre Dame"
    once, then "University" and "of Notre Dame" separately) map to the correct
    occurrence and avoids duplicate redaction boxes.

    Malformed individual entries do not abort the whole response: entries that
    are not JSON objects or lack a Type are skipped with a warning, a missing or
    invalid Score falls back to 0.8, and a non-string Text is converted with str().

    Args:
        response_text: The LLM response text (should contain JSON)
        original_text: The original text that was analyzed (for validation)

    Returns:
        List of entity dictionaries with keys: Type, BeginOffset, EndOffset, Score, Text.
        Empty when no JSON could be found or parsed.
    """
    entities_out: List[Dict[str, Any]] = []

    # Drop reasoning blocks emitted by "thinking" models before looking for JSON.
    response_text = re.sub(
        r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE
    )
    response_text = re.sub(
        r"<thinking>.*?</thinking>", "", response_text, flags=re.DOTALL | re.IGNORECASE
    )

    # Preferred source: the contents of a (possibly unterminated) fenced code block.
    json_str = None
    if "```" in response_text:
        code_block = re.search(
            r"```(?:json)?\s*\n?(.*?)(?:\n?```|$)", response_text, re.DOTALL
        )
        if code_block:
            candidate = code_block.group(1).strip()
            candidate = re.sub(r"<end_of_turn>\s*$", "", candidate, flags=re.IGNORECASE)
            candidate = candidate.rstrip()
            if "{" in candidate:
                json_str = _balanced_json_object(candidate)
                if json_str is None:
                    # Unbalanced (e.g. truncated) object: keep everything from
                    # the first brace and let the parser report the problem.
                    json_str = candidate[candidate.find("{") :]

    # Fallback: find an object mentioning "entities" anywhere in the response.
    if json_str is None:
        json_match = re.search(
            r'\{[^{}]*"entities"[^{}]*\[.*?\].*?\}', response_text, re.DOTALL
        )
        if not json_match:
            json_match = re.search(r'\{.*?"entities".*?\}', response_text, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)

    if not json_str:
        print("Warning: Could not find JSON in LLM response")
        print(f"Response text: {response_text[:500]}")
        return entities_out

    try:
        # Remove leftover fences / end-of-turn markers, then narrow to the object.
        json_str = json_str.strip()
        json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
        json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
        json_str = re.sub(r"<end_of_turn>\s*$", "", json_str, flags=re.IGNORECASE)
        json_str = json_str.strip()
        json_str = _balanced_json_object(json_str) or json_str

        # Trailing commas before a closing brace/bracket are not valid JSON.
        json_str = re.sub(r",\s*}", "}", json_str)
        json_str = re.sub(r",\s*]", "]", json_str)

        # Quote bare identifier values such as "Type": NAME, leaving the JSON
        # literals true/false/null untouched in both variants.
        json_str = re.sub(
            r'("[\w]+")\s*:\s*([A-Za-z_][A-Za-z0-9_]*)\s*([,}\]])',
            _quote_value_after_key,
            json_str,
        )
        json_str = re.sub(
            r'("[\w]+")\s*:\s*([A-Za-z_][A-Za-z0-9_]*)\s*(\n)',
            _quote_value_after_key,
            json_str,
        )

        # Final tidy-up after the substitutions above.
        json_str = json_str.rstrip()
        json_str = re.sub(r"`+$", "", json_str)
        json_str = re.sub(r"<end_of_turn>\s*$", "", json_str, flags=re.IGNORECASE)
        json_str = json_str.rstrip()
        json_str = _balanced_json_object(json_str) or json_str

        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            print(
                f"Initial JSON parse failed: {e}. Attempting more aggressive fixes..."
            )
            # Quote any bare identifier that follows a colon, whatever the key
            # looks like.
            json_str = re.sub(
                r"(:\s*)([A-Za-z_][A-Za-z0-9_]*)(\s*[,}\]])",
                _quote_value_after_colon,
                json_str,
            )
            try:
                data = json.loads(json_str)
            except json.JSONDecodeError as e2:
                print(f"JSON parsing failed after fixes: {e2}")
                print(f"Cleaned JSON string (first 1000 chars): {json_str[:1000]}")
                raise

        entity_items = data.get("entities") if isinstance(data, dict) else None
        if isinstance(entity_items, list):
            raw_entities: List[Dict[str, Any]] = []
            for entity in entity_items:
                if not isinstance(entity, dict):
                    print(f"Warning: Skipping non-object entity entry: {entity!r}")
                    continue
                entity_type_val = _entity_get(entity, "Type")
                if entity_type_val is None:
                    print(f"Warning: Entity missing Type field: {entity}")
                    continue

                entity_text = _entity_get(entity, "Text", "")
                if entity_text is not None and not isinstance(entity_text, str):
                    # e.g. a phone number emitted as a JSON number
                    entity_text = str(entity_text)

                reported_begin = _optional_int(_entity_get(entity, "BeginOffset"))
                reported_end = _optional_int(_entity_get(entity, "EndOffset"))

                # Without Text, fall back to slicing the passage by the reported
                # offsets, but only when they describe a valid non-empty span.
                if (
                    not entity_text
                    and reported_begin is not None
                    and reported_end is not None
                    and 0 <= reported_begin < reported_end <= len(original_text)
                ):
                    entity_text = original_text[reported_begin:reported_end]
                if not entity_text:
                    print(
                        f"Warning: Entity of type '{entity_type_val}' has no Text value and invalid offsets"
                    )
                    continue

                raw_entities.append(
                    {
                        "Type": str(entity_type_val),
                        "Text": entity_text,
                        "Score": _score_or_default(_entity_get(entity, "Score", 0.8)),
                        "reported_begin": reported_begin,
                    }
                )

            # Order by reported start; entities without one go last (stable sort
            # keeps their relative order).
            ordered = sorted(
                raw_entities,
                key=lambda r: (
                    r["reported_begin"] is None,
                    r["reported_begin"] or 0,
                ),
            )
            search_start = 0
            for rec in ordered:
                search_text = rec["Text"]
                result = _find_text_in_passage(
                    search_text,
                    original_text,
                    reported_offset=rec["reported_begin"],
                    start_from=search_start,
                )
                if result is None:
                    # Not found after the previous entity: retry from the start.
                    result = _find_text_in_passage(
                        search_text,
                        original_text,
                        reported_offset=rec["reported_begin"],
                        start_from=0,
                    )
                if result is None:
                    print(
                        f"Warning: Could not find text '{search_text[:50]}...' in original passage"
                    )
                    continue
                start, end = result
                entities_out.append(
                    {
                        "Type": rec["Type"],
                        "BeginOffset": start,
                        "EndOffset": end,
                        "Score": rec["Score"],
                        "Text": original_text[start:end],
                    }
                )
                search_start = end
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON from LLM response: {e}")
        print(f"Response text: {response_text[:500]}")
    except (ValueError, KeyError) as e:
        print(f"Error processing entity data: {e}")

    return entities_out
|
|
|
|
| def _sanitize_for_filename(s: str, max_len: Optional[int] = None) -> str: |
| """Sanitize a string for use in a filename (alphanumeric, spaces to underscores).""" |
| out = ( |
| "".join(c for c in (s or "") if c.isalnum() or c in (" ", "-", "_")) |
| .strip() |
| .replace(" ", "_") |
| ) |
| if max_len is not None and len(out) > max_len: |
| out = out[:max_len] |
| return out or "unknown" |
|
|
|
|
def save_llm_prompt_response(
    system_prompt: str,
    user_prompt: str,
    response_text: str,
    output_folder: str,
    batch_number: int,
    model_choice: str,
    entities_to_detect: List[str],
    language: str,
    temperature: float,
    max_tokens: int,
    file_name: Optional[str] = None,
    page_number: Optional[int] = None,
    sheet_name: Optional[str] = None,
    column_name: Optional[str] = None,
    row_number: Optional[int] = None,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
) -> str:
    """
    Write one LLM call (prompts, settings and response) to a text log file.

    The file is created under ``<output_folder>/llm_prompts_responses``. System
    prompt, user prompt and response are each written in their own delimited
    section so the log never conflates them. The filename is chosen as follows:

      * tabular call (any of sheet_name / column_name / row_number given):
        ``llm[_<sheet>][_<column>][_row<NNNNN>]_batch_<NNNN>.txt``
      * document call (file_name and page_number given):
        ``llm_<file>_page_<NNNN>_batch_<NNNN>.txt``
      * otherwise: ``llm_batch_<NNNN>_<YYYYmmdd_HHMMSS>.txt``

    Args:
        system_prompt: System prompt sent to the LLM (whitespace-trimmed in the log).
        user_prompt: User prompt sent to the LLM (whitespace-trimmed in the log).
        response_text: Response text from the LLM.
        output_folder: Output folder path.
        batch_number: Batch number for this call.
        model_choice: Model used.
        entities_to_detect: List of entities being detected.
        language: Language code.
        temperature: Temperature used.
        max_tokens: Max tokens used.
        file_name: Optional file name (without extension) for the filename / log header.
        page_number: Optional page number; used as-is in the filename and shown as
            page_number + 1 in the log header.
        sheet_name: Optional Excel sheet name (tabular data).
        column_name: Optional column name (tabular data); shortened in the filename if long.
        row_number: Optional row number (tabular data).
        input_tokens: Optional input token count from the LLM call.
        output_tokens: Optional output token count from the LLM call.

    Returns:
        Path to the saved file.
    """
    rule = "=" * 80

    # None is tolerated for either prompt and logged as an empty section.
    system_prompt_str = ("" if system_prompt is None else system_prompt).strip()
    user_prompt_str = ("" if user_prompt is None else user_prompt).strip()

    llm_logs_folder = os.path.join(output_folder, "llm_prompts_responses")
    os.makedirs(llm_logs_folder, exist_ok=True)

    # Pick the filename scheme: tabular, per-page document, or timestamped.
    is_tabular = any(
        marker is not None for marker in (column_name, sheet_name, row_number)
    )
    if is_tabular:
        name_bits = ["llm"]
        for label in (sheet_name, column_name):
            if label:
                name_bits.append(
                    _sanitize_for_filename(label, LLM_LOG_TABULAR_NAME_MAX_LEN)
                )
        if row_number is not None:
            name_bits.append(f"row{row_number:05d}")
        name_bits.append(f"batch_{batch_number:04d}")
        filename = "_".join(name_bits) + ".txt"
    elif file_name and page_number is not None:
        safe_file_name = _sanitize_for_filename(file_name)
        filename = (
            f"llm_{safe_file_name}_page_{page_number:04d}_batch_{batch_number:04d}.txt"
        )
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"llm_batch_{batch_number:04d}_{timestamp}.txt"
    filepath = os.path.join(llm_logs_folder, filename)

    with open(filepath, "w", encoding="utf-8") as log:
        emit = log.write

        # Banner
        emit(f"{rule}\n")
        emit("LLM ENTITY DETECTION - PROMPT AND RESPONSE LOG\n")
        emit(f"{rule}\n\n")

        # Header: optional context fields first, then the fixed settings.
        emit(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        optional_fields = (
            ("File", file_name, bool(file_name)),
            ("Sheet", sheet_name, bool(sheet_name)),
            ("Column", column_name, column_name is not None),
            ("Row", row_number, row_number is not None),
            (
                "Page",
                None if page_number is None else page_number + 1,
                page_number is not None,
            ),
            ("Input tokens", input_tokens, input_tokens is not None),
            ("Output tokens", output_tokens, output_tokens is not None),
        )
        for label, value, present in optional_fields:
            if present:
                emit(f"{label}: {value}\n")

        fixed_fields = (
            ("Batch Number", batch_number),
            ("Model", model_choice),
            ("Language", language),
            ("Temperature", temperature),
            ("Max Tokens", max_tokens),
            ("Entities to Detect", ", ".join(entities_to_detect)),
        )
        for label, value in fixed_fields:
            emit(f"{label}: {value}\n")

        # System prompt section
        emit(f"\n{rule}\n")
        emit("SYSTEM PROMPT (sent as system role)\n")
        emit(f"{rule}\n")
        emit("--- BEGIN SYSTEM PROMPT ---\n")
        emit(system_prompt_str)
        emit("\n--- END SYSTEM PROMPT ---\n")

        # User prompt section, with a hint when both prompts are the same text.
        emit(f"\n{rule}\n")
        emit("USER PROMPT (sent as user role)\n")
        emit(f"{rule}\n")
        if system_prompt_str and system_prompt_str == user_prompt_str:
            emit(
                "[NOTE: System and user prompt content were identical - check caller.]\n"
            )
        emit("--- BEGIN USER PROMPT ---\n")
        emit(user_prompt_str)
        emit("\n--- END USER PROMPT ---\n")

        # Response section (written verbatim) and footer
        emit(f"\n\n{rule}\n")
        emit("LLM RESPONSE\n")
        emit(f"{rule}\n\n")
        emit(response_text)
        emit(f"\n\n{rule}\n")
        emit("END OF LOG\n")
        emit(f"{rule}\n")

    return filepath
|
|
|
|
def call_llm_for_entity_detection(
    text: str,
    entities_to_detect: List[str],
    language: str,
    bedrock_runtime: Optional[boto3.Session.client] = None,
    model_choice: str = CLOUD_LLM_PII_MODEL_CHOICE,
    temperature: float = LLM_TEMPERATURE,
    max_tokens: int = LLM_MAX_NEW_TOKENS,
    max_retries: int = 10,
    retry_delay: int = 3,
    output_folder: Optional[str] = None,
    batch_number: int = 0,
    custom_instructions: str = "",
    file_name: Optional[str] = None,
    page_number: Optional[int] = None,
    sheet_name: Optional[str] = None,
    column_name: Optional[str] = None,
    row_number: Optional[int] = None,
    inference_method: Optional[str] = None,
    local_model=None,
    tokenizer=None,
    assistant_model=None,
    client=None,
    client_config=None,
    api_url: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int, int]:
    """
    Call an LLM to detect entities in text using one of several inference methods.

    Builds the system and user prompts, sends them through ``send_request``,
    optionally logs the exchange to output_folder, and parses the response into
    entity dictionaries.

    Args:
        text: Text to analyze
        entities_to_detect: List of entity types to detect. Types starting with
            "CUSTOM_VLM_" are excluded from the prompts.
        language: Language code
        bedrock_runtime: AWS Bedrock runtime client (required for AWS method)
        model_choice: Model identifier (varies by inference method). When custom
            instructions are given, the default model is in use, and a dedicated
            custom-instructions model is configured, that model is used instead.
        temperature: Temperature for LLM generation (lower = more deterministic)
        max_tokens: Maximum tokens in response (used when constructing a Gemini
            client and recorded in the log)
        max_retries: Maximum retry attempts (accepted for interface
            compatibility; not referenced in this function)
        retry_delay: Delay between retries in seconds (accepted for interface
            compatibility; not referenced in this function)
        output_folder: Optional folder to save prompt/response logs
        batch_number: Batch number for logging
        custom_instructions: Optional custom instructions to include in the prompt.
            Non-string values and the literal string "true" are treated as empty.
        file_name: Optional file name (without extension) for saving logs
        page_number: Optional page number for saving logs (document flow)
        sheet_name: Optional Excel sheet name for tabular logs
        column_name: Optional column name for tabular logs
        row_number: Optional row number (1-based) for tabular logs
        inference_method: Inference method to use ("aws-bedrock", "local",
            "inference-server", "azure-openai", "gemini"). If None, uses
            CHOSEN_LLM_PII_INFERENCE_METHOD from config. Overridden by the
            "source" registered for model_choice in model_name_map, if any.
        local_model: Local model instance (accepted for interface compatibility;
            not referenced in this function)
        tokenizer: Tokenizer instance (accepted for interface compatibility; not
            referenced in this function)
        assistant_model: Assistant model (accepted for interface compatibility;
            not referenced in this function)
        client: API client for "azure-openai" or "gemini"; constructed here if omitted
        client_config: Client config for "gemini"; constructed here if omitted
        api_url: API URL for "inference-server"; defaults to INFERENCE_SERVER_API_URL

    Returns:
        Tuple of (entities, input_tokens, output_tokens). When only CUSTOM_VLM_*
        entities were requested and there are no custom instructions, no LLM call
        is made and ``([], 0, 0)`` is returned.

    Raises:
        ValueError: If there is nothing to detect (no entities and no custom
            instructions), if a Gemini / Azure client cannot be constructed, or
            if no API URL is available for the inference-server method.
        Exception: Any error raised by ``send_request`` is printed and re-raised.
    """
    # Normalise custom_instructions: UI components may hand over booleans or
    # the string "true" instead of real instruction text.
    if not isinstance(custom_instructions, str):
        custom_instructions = (
            ""
            if custom_instructions is True or not custom_instructions
            else str(custom_instructions)
        )
    if (
        isinstance(custom_instructions, str)
        and custom_instructions.strip().lower() == "true"
    ):
        custom_instructions = ""

    if inference_method is None:
        inference_method = CHOSEN_LLM_PII_INFERENCE_METHOD

    # Switch to the dedicated custom-instructions model only when the caller
    # did not pick a model explicitly (i.e. the default is still in place).
    custom_instructions_model = (
        CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip()
        if isinstance(CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE, str)
        and CLOUD_LLM_PII_CUSTOM_INSTRUCTIONS_MODEL_CHOICE.strip()
        else ""
    )
    if (
        custom_instructions.strip()
        and model_choice == CLOUD_LLM_PII_MODEL_CHOICE
        and custom_instructions_model
    ):
        model_choice = custom_instructions_model

    # CUSTOM_VLM_* entities are not sent to the text LLM.
    filtered_entities = [
        entity for entity in entities_to_detect if not entity.startswith("CUSTOM_VLM_")
    ]

    if not filtered_entities and (
        not custom_instructions or not custom_instructions.strip()
    ):
        if not entities_to_detect:
            raise ValueError(
                "No standard entities selected and no custom instructions provided. "
                "Please select at least one entity type (excluding CUSTOM_VLM_* entities) or provide custom instructions for LLM-based PII detection."
            )
        # Only CUSTOM_VLM_* entities were requested: nothing for the LLM to do.
        # Return the same 3-tuple shape as the normal path so callers can
        # always unpack (entities, input_tokens, output_tokens).
        return [], 0, 0

    # A model registered in model_name_map dictates the inference method.
    if model_choice and model_name_map and model_choice in model_name_map:
        registered_source = model_name_map[model_choice].get("source", "AWS")
        source_to_method = {
            "Local": "local",
            "inference-server": "inference-server",
            "Azure/OpenAI": "azure-openai",
            "Gemini": "gemini",
            "AWS": "aws-bedrock",
        }
        inference_method = source_to_method.get(registered_source, inference_method)

    system_prompt = create_entity_detection_system_prompt(
        filtered_entities, language, custom_instructions
    )
    user_prompt = create_entity_detection_prompt(
        text, filtered_entities, language, custom_instructions
    )

    # send_request identifies the backend by "model source" name.
    model_source_map = {
        "aws-bedrock": "AWS",
        "local": "Local",
        "inference-server": "inference-server",
        "azure-openai": "Azure/OpenAI",
        "gemini": "Gemini",
    }
    model_source = model_source_map.get(inference_method, "AWS")

    # Build a Gemini client on demand when the caller did not supply one.
    if inference_method == "gemini" and (client is None or client_config is None):
        from tools.llm_funcs import construct_gemini_generative_model

        try:
            client, client_config = construct_gemini_generative_model(
                in_api_key="",
                temperature=temperature,
                model_choice=model_choice,
                system_prompt=system_prompt,
                max_tokens=max_tokens,
            )
        except Exception as e:
            raise ValueError(
                f"Failed to construct Gemini client: {e}. "
                f"Ensure GEMINI_API_KEY is set or pass client and client_config."
            ) from e

    # Build an Azure/OpenAI client on demand when the caller did not supply one.
    if inference_method == "azure-openai" and client is None:
        from tools.llm_funcs import construct_azure_client

        try:
            client, _ = construct_azure_client(
                in_api_key="",
                endpoint="",
            )
        except Exception as e:
            raise ValueError(
                f"Failed to construct Azure/OpenAI client: {e}. "
                f"Ensure AZURE_OPENAI_API_KEY is set or pass client."
            ) from e

    if inference_method == "inference-server" and api_url is None:
        api_url = INFERENCE_SERVER_API_URL
        if not api_url:
            raise ValueError(
                "api_url is required when using inference-server method. "
                "Set INFERENCE_SERVER_API_URL in config or pass api_url parameter."
            )

    try:
        (
            response,
            conversation_history,
            response_text,
            num_transformer_input_tokens,
            num_transformer_generated_tokens,
        ) = send_request(
            prompt=user_prompt,
            conversation_history=[],
            client=client,
            config=client_config,
            model_choice=model_choice,
            system_prompt=system_prompt,
            temperature=temperature,
            bedrock_runtime=bedrock_runtime,
            model_source=model_source,
            progress=Progress(track_tqdm=False),
            api_url=api_url,
        )
    except Exception as e:
        print(f"LLM entity detection failed: {e}")
        raise

    # Token usage: try the response object's own accounting first.
    input_tokens = 0
    output_tokens = 0
    try:
        if isinstance(response, dict) and "usage" in response:
            # dict response carrying prompt_tokens / completion_tokens
            input_tokens = response["usage"].get("prompt_tokens", 0)
            output_tokens = response["usage"].get("completion_tokens", 0)
        elif hasattr(response, "usage_metadata"):
            if isinstance(response.usage_metadata, dict):
                # dict-style metadata carrying inputTokens / outputTokens
                input_tokens = response.usage_metadata.get("inputTokens", 0)
                output_tokens = response.usage_metadata.get("outputTokens", 0)
            elif hasattr(response.usage_metadata, "prompt_token_count"):
                # attribute-style metadata
                input_tokens = response.usage_metadata.prompt_token_count
                output_tokens = response.usage_metadata.candidates_token_count
    except (KeyError, AttributeError) as e:
        print(f"Warning: Could not extract token usage from response: {e}")

    # Counts returned directly by send_request take precedence when positive.
    if num_transformer_input_tokens and num_transformer_input_tokens > 0:
        input_tokens = num_transformer_input_tokens
    if num_transformer_generated_tokens and num_transformer_generated_tokens > 0:
        output_tokens = num_transformer_generated_tokens

    # Logging is best-effort: a failure to write the log must not lose the result.
    if output_folder and response_text:
        try:
            save_llm_prompt_response(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                response_text=response_text,
                output_folder=output_folder,
                batch_number=batch_number,
                model_choice=model_choice,
                entities_to_detect=entities_to_detect,
                language=language,
                temperature=temperature,
                max_tokens=max_tokens,
                file_name=file_name,
                page_number=page_number,
                sheet_name=sheet_name,
                column_name=column_name,
                row_number=row_number,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )
        except Exception as e:
            print(f"Warning: Could not save LLM prompt/response: {e}")

    entities = parse_llm_entity_response(response_text, text)

    return entities, input_tokens, output_tokens
|
|
|
|
def map_back_llm_entity_results(
    entities: List[Dict[str, Any]],
    current_batch_mapping: List[Tuple],
    allow_list: List[str],
    chosen_redact_llm_entities: List[str],
    all_text_line_results: List[Tuple],
) -> List[Tuple]:
    """
    Map LLM-detected entities (batch-relative offsets) back to line-level results.

    For every entity, each mapped line whose batch span overlaps the entity gets a
    recognizer result with offsets relative to that line's text, clamped to the
    line. Entities whose line-level text is on the allow list are skipped.

    Args:
        entities: List of entity dictionaries from the LLM (Type, BeginOffset,
            EndOffset, optional Score)
        current_batch_mapping: 5-tuples of (batch_start, line_idx, original_line,
            chars, line_offset); original_line must expose a ``text`` attribute
        allow_list: List of allowed text values (to skip) - matched
            case-insensitively after trimming whitespace
        chosen_redact_llm_entities: List of entity types to include (not
            referenced in this function)
        all_text_line_results: Existing (line_idx, results) entries; extended in place

    Returns:
        Updated all_text_line_results
    """
    if not entities:
        return all_text_line_results

    # Normalised allow list for case-insensitive comparison.
    skip_texts = (
        {item.strip().lower() for item in allow_list if item} if allow_list else set()
    )

    for entity in entities:
        entity_type = entity.get("Type")
        entity_start = entity["BeginOffset"]
        entity_end = entity["EndOffset"]

        placed = False

        for (
            batch_start,
            line_idx,
            original_line,
            _chars,
            line_offset,
        ) in current_batch_mapping:
            # A line may enter the batch part-way through; shift is where the
            # batched portion begins within the line's text (0 when not given).
            shift = 0 if line_offset is None else line_offset
            batch_end = batch_start + len(original_line.text[shift:])

            overlaps = batch_start < entity_end and batch_end > entity_start
            if not overlaps:
                continue

            # Translate batch offsets to line offsets and clamp to the line.
            relative_start = max(0, entity_start - batch_start + shift)
            relative_end = min(
                entity_end - batch_start + shift, len(original_line.text)
            )

            result_text = original_line.text[relative_start:relative_end]
            if result_text.strip().lower() in skip_texts:
                continue

            adjusted_entity = {
                "Type": entity_type,
                "BeginOffset": relative_start,
                "EndOffset": relative_end,
                "Score": entity.get("Score", 0.8),
            }

            # Imported lazily, only once a result actually has to be built.
            from tools.presidio_analyzer_custom import (
                recognizer_result_from_dict,
            )

            recogniser_entity = recognizer_result_from_dict(adjusted_entity)

            # Append to the existing entry for this line, or start a new one.
            line_results = next(
                (
                    results
                    for idx, results in all_text_line_results
                    if idx == line_idx
                ),
                None,
            )
            if line_results is None:
                all_text_line_results.append((line_idx, [recogniser_entity]))
            else:
                line_results.append(recogniser_entity)

            placed = True

        if not placed:
            print(
                f"Entity '{entity_type}' at position {entity_start}-{entity_end} does not fit in any line."
            )

    return all_text_line_results
|
|
|
|
def do_llm_entity_detection_call(
    current_batch: str,
    current_batch_mapping: List[Tuple],
    bedrock_runtime: Optional[boto3.Session.client] = None,
    language: str = "en",
    allow_list: Optional[List[str]] = None,
    chosen_redact_llm_entities: Optional[List[str]] = None,
    all_text_line_results: Optional[List[Tuple]] = None,
    model_choice: str = CLOUD_LLM_PII_MODEL_CHOICE,
    temperature: float = LLM_TEMPERATURE,
    max_tokens: int = LLM_MAX_NEW_TOKENS,
    output_folder: Optional[str] = None,
    batch_number: int = 0,
    custom_instructions: str = "",
    file_name: Optional[str] = None,
    page_number: Optional[int] = None,
    inference_method: Optional[str] = None,
    local_model=None,
    tokenizer=None,
    assistant_model=None,
    client=None,
    client_config=None,
    api_url: Optional[str] = None,
) -> Tuple[List[Tuple], int, int]:
    """
    Run LLM entity detection on one text batch and map the results to lines.

    Calls call_llm_for_entity_detection on the whitespace-trimmed batch, then
    map_back_llm_entity_results to convert the batch-relative entity offsets
    into line-level results.

    Args:
        current_batch: Text batch to analyze
        current_batch_mapping: Mapping of batch positions to line indices
        bedrock_runtime: AWS Bedrock runtime client (required for AWS method)
        language: Language code
        allow_list: List of allowed text values (None is treated as empty)
        chosen_redact_llm_entities: List of entity types to detect (None is
            treated as empty)
        all_text_line_results: Existing line-level results (None is treated as empty)
        model_choice: Model identifier (varies by inference method)
        temperature: Temperature for LLM generation
        max_tokens: Maximum tokens in response
        output_folder: Optional folder to save prompt/response logs
        batch_number: Batch number for logging
        custom_instructions: Optional custom instructions to include in the prompt
        file_name: Optional file name (without extension) for saving logs
        page_number: Optional page number for saving logs
        inference_method: Inference method to use (if None, uses config default)
        local_model: Local model instance (forwarded to the detection call)
        tokenizer: Tokenizer instance (forwarded to the detection call)
        assistant_model: Assistant model (forwarded to the detection call)
        client: API client (forwarded to the detection call)
        client_config: Client config (forwarded to the detection call)
        api_url: API URL for inference-server (forwarded to the detection call)

    Returns:
        Tuple of (updated all_text_line_results, input_tokens, output_tokens).
        An empty batch returns the existing results with zero token counts
        without calling the LLM.

    Raises:
        Exception: Any error from detection or mapping is printed and re-raised.
    """
    if not current_batch:
        return (all_text_line_results or [], 0, 0)

    if allow_list is None:
        allow_list = []
    if chosen_redact_llm_entities is None:
        chosen_redact_llm_entities = []
    if all_text_line_results is None:
        all_text_line_results = []

    try:
        # NOTE(review): the text is stripped before detection, so entity offsets
        # are relative to the stripped batch. If current_batch can start with
        # whitespace and the mapping is relative to the unstripped batch, offsets
        # would be shifted - confirm against how batches are built.
        detection = call_llm_for_entity_detection(
            text=current_batch.strip(),
            entities_to_detect=chosen_redact_llm_entities,
            language=language,
            bedrock_runtime=bedrock_runtime,
            model_choice=model_choice,
            temperature=temperature,
            max_tokens=max_tokens,
            output_folder=output_folder,
            batch_number=batch_number,
            custom_instructions=custom_instructions,
            file_name=file_name,
            page_number=page_number,
            inference_method=inference_method,
            local_model=local_model,
            tokenizer=tokenizer,
            assistant_model=assistant_model,
            client=client,
            client_config=client_config,
            api_url=api_url,
        )

        # The detection call normally returns (entities, input_tokens,
        # output_tokens); tolerate a bare entity list (no token counts) instead
        # of failing on tuple unpacking.
        if isinstance(detection, tuple):
            entities, input_tokens, output_tokens = detection
        else:
            entities, input_tokens, output_tokens = detection, 0, 0

        all_text_line_results = map_back_llm_entity_results(
            entities,
            current_batch_mapping,
            allow_list,
            chosen_redact_llm_entities,
            all_text_line_results,
        )

        return all_text_line_results, input_tokens, output_tokens

    except Exception as e:
        print(f"LLM entity detection call failed: {e}")
        raise
|
|