"""Shared Gradio deployment helpers: session identity, FastAPI mount, logging."""
from __future__ import annotations

import hmac
import html
import logging
from datetime import datetime
from pathlib import Path
from typing import Any

import boto3
import gradio as gr
from botocore.exceptions import (
    BotoCoreError,
    ClientError,
    NoCredentialsError,
    PartialCredentialsError,
)
from fastapi import FastAPI, status
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

from tools.auth import authenticate_user
from tools.aws_functions import upload_log_file_to_s3
from tools.config import (
    ACCESS_LOG_DYNAMODB_TABLE_NAME,
    ACCESS_LOGS_FOLDER,
    ALLOWED_HOSTS,
    ALLOWED_ORIGINS,
    AWS_USER_POOL_ID,
    COGNITO_AUTH,
    CSV_ACCESS_LOG_HEADERS,
    CSV_USAGE_LOG_HEADERS,
    CUSTOM_HEADER,
    CUSTOM_HEADER_VALUE,
    DEFAULT_COST_CODE,
    DISPLAY_FILE_NAMES_IN_LOGS,
    DYNAMODB_ACCESS_LOG_HEADERS,
    DYNAMODB_USAGE_LOG_HEADERS,
    FASTAPI_ROOT_PATH,
    GRADIO_SERVER_NAME,
    GRADIO_SERVER_PORT,
    HOST_NAME,
    LOG_FILE_NAME,
    ROOT_PATH,
    RUN_FASTAPI,
    S3_ACCESS_LOGS_FOLDER,
    S3_OUTPUTS_FOLDER,
    S3_USAGE_LOGS_FOLDER,
    SAVE_LOGS_TO_CSV,
    SAVE_LOGS_TO_DYNAMODB,
    SAVE_OUTPUTS_TO_S3,
    USAGE_LOG_DYNAMODB_TABLE_NAME,
    USAGE_LOG_FILE_NAME,
    USAGE_LOGS_FOLDER,
)
from tools.custom_csvlogger import CSVLogger_custom
# Module-level logger named after this module (most diagnostics in this file
# are still emitted with print()).
logger = logging.getLogger(__name__)
def validate_custom_header(request: gr.Request) -> None:
    """Enforce the optional shared-secret request header.

    Does nothing unless both CUSTOM_HEADER and CUSTOM_HEADER_VALUE are
    configured. Otherwise the request must carry CUSTOM_HEADER with a value
    equal to CUSTOM_HEADER_VALUE.

    Args:
        request: Incoming Gradio request; a missing ``headers`` attribute is
            treated as "no headers".

    Raises:
        ValueError: If the header is absent or its value does not match.
    """
    if not CUSTOM_HEADER or not CUSTOM_HEADER_VALUE:
        return
    headers = getattr(request, "headers", None) or {}
    if CUSTOM_HEADER not in headers:
        print("Custom header value not found.")
        raise ValueError("Custom header value not found.")
    supplied = headers[CUSTOM_HEADER]
    # The header acts as a shared secret, so compare in constant time to avoid
    # leaking how many leading characters matched via response timing.
    matches = hmac.compare_digest(
        str(supplied).encode("utf-8"),
        str(CUSTOM_HEADER_VALUE).encode("utf-8"),
    )
    if matches:
        print("Custom header supplied and matches CUSTOM_HEADER_VALUE")
        return
    print("Custom header value does not match expected value.")
    raise ValueError("Custom header value does not match expected value.")
def resolve_session_identity(request: gr.Request) -> str:
    """Resolve the identifier used to key this user's session.

    Precedence:
      1. ``request.username`` (Gradio auth), when set.
      2. The ``x-cognito-id`` request header.
      3. The ``x-amzn-oidc-identity`` header. When AWS_USER_POOL_ID is set,
         the Cognito user is looked up. Their non-empty ``email`` attribute is
         used instead of the raw ID. Any lookup failure, or a missing or empty
         email, falls back to the header value.
      4. ``request.session_hash``.

    Raises:
        ValueError: Propagated from validate_custom_header when a configured
            custom header is missing or wrong.
    """
    validate_custom_header(request)
    headers = getattr(request, "headers", None) or {}

    if request.username:
        return request.username

    if "x-cognito-id" in headers:
        out_session_hash = headers["x-cognito-id"]
        print("Cognito ID found:", out_session_hash)
        return out_session_hash

    if "x-amzn-oidc-identity" not in headers:
        return request.session_hash

    out_session_hash = headers["x-amzn-oidc-identity"]
    if AWS_USER_POOL_ID:
        try:
            cognito_client = boto3.client("cognito-idp")
            response = cognito_client.admin_get_user(
                UserPoolId=AWS_USER_POOL_ID,
                Username=out_session_hash,
            )
            # A user with no email attribute is an expected case, so default
            # to None instead of letting next() raise StopIteration.
            email = next(
                (
                    attr.get("Value")
                    for attr in response.get("UserAttributes", [])
                    if attr.get("Name") == "email"
                ),
                None,
            )
        except (
            ClientError,
            NoCredentialsError,
            PartialCredentialsError,
            BotoCoreError,
        ) as exc:
            print(f"Error fetching Cognito user details: {exc}")
            print("Falling back to using AWS ID as session hash")
        except Exception as exc:
            # Best-effort enrichment: never fail session resolution here.
            print(f"Unexpected error when fetching Cognito user details: {exc}")
            print("Falling back to using AWS ID as session hash")
        else:
            if email:
                print("Cognito email address found, will be used as session hash")
                return email
            print("No Cognito email address found for user")
            print("Falling back to using AWS ID as session hash")
    print("AWS ID found, will be used as username for session:", out_session_hash)
    return out_session_hash
def build_s3_outputs_prefix(
    session_hash: str,
    base_folder: str = S3_OUTPUTS_FOLDER,
    *,
    session_scoped: bool = True,
) -> str:
    """Build the S3 key prefix for output uploads.

    When SAVE_OUTPUTS_TO_S3 is enabled and ``base_folder`` is non-empty, the
    result is ``<base_folder>/[<session_hash>/]<YYYYMMDD>/``. The session
    segment is included only when ``session_scoped`` is true and
    ``session_hash`` is non-empty. In every other case ``base_folder`` (or
    ``""`` when it is falsy) is returned unchanged.

    Args:
        session_hash: Session identifier to nest outputs under.
        base_folder: Root S3 prefix; defaults to S3_OUTPUTS_FOLDER.
        session_scoped: Whether to add the per-session segment.

    Returns:
        The prefix, ending in ``/`` whenever any segment was appended.
    """
    prefix = base_folder or ""
    # With S3 saving disabled, or no base folder, nothing is appended.
    if not SAVE_OUTPUTS_TO_S3 or not prefix:
        return prefix
    if session_scoped and session_hash:
        prefix = prefix.rstrip("/") + "/" + session_hash + "/"
    # Date segment comes from the server's local clock (naive datetime.now()).
    today_suffix = datetime.now().strftime("%Y%m%d") + "/"
    return prefix.rstrip("/") + "/" + today_suffix
def gradio_head_html(root_path: str = ROOT_PATH) -> str:
    """Return an HTML ``<head>`` snippet containing a ``<base href>`` tag.

    When the UI is served under a reverse-proxy subpath, the base tag makes
    relative URLs in the page resolve against that subpath.

    Args:
        root_path: URL subpath the UI is served under. Leading and trailing
            slashes are optional. An empty value or ``None`` yields
            ``<base href="/">``.

    Returns:
        The snippet, with the href HTML-escaped so an unusual configured path
        cannot break out of the attribute.
    """
    stripped = (root_path or "").strip("/")
    base_href = f"/{stripped}/" if stripped else "/"
    if root_path:
        print(f"Setting HTML base href for Gradio to: '{base_href}'")
    return f'\n<base href="{html.escape(base_href, quote=True)}">\n'
def create_fastapi_app(*, root_path: str | None = None) -> FastAPI:
    """Create the FastAPI app that hosts the Gradio UI.

    The app uses ``lifespan`` from tools.helper_functions. Middleware is added
    conditionally:

    - CORS middleware (credentials allowed, all methods and headers) when
      ALLOWED_ORIGINS is set.
    - TrustedHostMiddleware when ALLOWED_HOSTS is set.

    A ``GET /health`` endpoint returning ``{"status": "ok"}`` is registered.

    Args:
        root_path: ASGI root path. ``None`` falls back to FASTAPI_ROOT_PATH.
            The value is normalised to ``/prefix``; an empty or unset value
            means no root path.

    Returns:
        The configured FastAPI application.
    """
    # Imported lazily, presumably to avoid an import cycle with
    # tools.helper_functions (TODO confirm).
    from tools.helper_functions import lifespan

    effective_root = root_path if root_path is not None else FASTAPI_ROOT_PATH
    # Treat an unset (None) root like "" rather than crashing on None.strip.
    stripped_root = str(effective_root or "").strip("/")
    clean_root = f"/{stripped_root}" if stripped_root else ""
    fastapi_app = FastAPI(lifespan=lifespan, root_path=clean_root)

    if ALLOWED_ORIGINS:
        print(f"CORS enabled. Allowing origins: {ALLOWED_ORIGINS}")
        fastapi_app.add_middleware(
            CORSMiddleware,
            allow_origins=ALLOWED_ORIGINS,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    if ALLOWED_HOSTS:
        fastapi_app.add_middleware(TrustedHostMiddleware, allowed_hosts=ALLOWED_HOSTS)

    @fastapi_app.get("/health", status_code=status.HTTP_200_OK)
    def health_check():
        return {"status": "ok"}

    return fastapi_app
def _cognito_auth():
    """Return the Gradio ``auth`` callable to use, or None when auth is off.

    Yields ``authenticate_user`` only when COGNITO_AUTH is truthy.
    """
    if not COGNITO_AUTH:
        return None
    return authenticate_user
class _LogField:
"""Minimal stand-in for Gradio components in CSVLogger_custom."""
def __init__(self, label: str) -> None:
self.label = label
def flag(self, sample: Any, flag_dir: Any = None) -> str:
return str(sample) if sample is not None else ""
class PlatformAccessLogger:
    """Record access events (session hash and host name) via CSVLogger_custom.

    Each ``log`` call does two things:

    1. Writes one row to CSV and/or DynamoDB, according to SAVE_LOGS_TO_CSV
       and SAVE_LOGS_TO_DYNAMODB.
    2. Passes the local access-log file to upload_log_file_to_s3.
    """

    def __init__(self) -> None:
        self._callback = CSVLogger_custom(dataset_file_name=LOG_FILE_NAME)
        column_labels = ("session_hash", "host_name")
        self._fields = [_LogField(column) for column in column_labels]
        self._callback.setup(self._fields, ACCESS_LOGS_FOLDER)

    def log(self, session_hash: str, host_name: str = HOST_NAME) -> None:
        """Write one access row for *session_hash* / *host_name*, then upload the log.

        NOTE(review): the upload runs regardless of SAVE_LOGS_TO_CSV; whether
        upload_log_file_to_s3 tolerates a missing local file is not visible here.
        """
        access_row = [session_hash, host_name]
        self._callback.flag(
            access_row,
            save_to_csv=SAVE_LOGS_TO_CSV,
            save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB,
            dynamodb_table_name=ACCESS_LOG_DYNAMODB_TABLE_NAME,
            dynamodb_headers=DYNAMODB_ACCESS_LOG_HEADERS,
            replacement_headers=CSV_ACCESS_LOG_HEADERS or None,
        )
        # Plain string concatenation: ACCESS_LOGS_FOLDER is expected to end
        # with a path separator (TODO confirm against config).
        local_log_path = ACCESS_LOGS_FOLDER + LOG_FILE_NAME
        upload_log_file_to_s3(local_log_path, S3_ACCESS_LOGS_FOLDER)
class PlatformAgentUsageLogger:
    """Write agent-orchestration usage rows using the main app's log schema.

    Rows are passed to CSVLogger_custom, which writes CSV and/or DynamoDB
    depending on config. The local usage CSV is then handed to
    upload_log_file_to_s3.
    """

    # Default column labels, used when CSV_USAGE_LOG_HEADERS is not configured.
    # NOTE(review): build_agent_usage_log_row emits values in this order; a
    # configured CSV_USAGE_LOG_HEADERS is assumed to keep the same column order.
    _MAIN_USAGE_FIELD_LABELS: tuple[str, ...] = (
        "session_hash_textbox",
        "doc_full_file_name_textbox",
        "data_full_file_name_textbox",
        "actual_time_taken_number",
        "total_page_count",
        "textract_query_number",
        "pii_detection_method",
        "comprehend_query_number",
        "cost_code",
        "textract_handwriting_signature",
        "host_name_textbox",
        "text_extraction_method",
        "is_this_a_textract_api_call",
        "task",
        "vlm_model_name",
        "vlm_total_input_tokens",
        "vlm_total_output_tokens",
        "llm_model_name",
        "llm_total_input_tokens",
        "llm_total_output_tokens",
    )

    def __init__(self) -> None:
        self._callback = CSVLogger_custom(dataset_file_name=USAGE_LOG_FILE_NAME)
        if CSV_USAGE_LOG_HEADERS:
            header_source = CSV_USAGE_LOG_HEADERS
        else:
            header_source = self._MAIN_USAGE_FIELD_LABELS
        self._fields = [_LogField(column) for column in header_source]
        self._callback.setup(self._fields, USAGE_LOGS_FOLDER)

    def log_row(self, row: list[Any]) -> None:
        """Persist one pre-built usage *row*, then upload the local usage CSV."""
        flag_options = dict(
            save_to_csv=SAVE_LOGS_TO_CSV,
            save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB,
            dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME,
            dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS,
            replacement_headers=CSV_USAGE_LOG_HEADERS or None,
        )
        self._callback.flag(row, **flag_options)
        local_log_path = USAGE_LOGS_FOLDER + USAGE_LOG_FILE_NAME
        upload_log_file_to_s3(local_log_path, S3_USAGE_LOGS_FOLDER)
def _doc_name_for_usage_log(document_name: str) -> str:
    """Return the document label recorded in usage logs.

    The result depends on the input and DISPLAY_FILE_NAMES_IN_LOGS:

    - Empty input yields ``""``.
    - With DISPLAY_FILE_NAMES_IN_LOGS enabled, the file stem is logged.
    - Otherwise the generic ``"document"`` placeholder hides the real name.
    """
    if not document_name:
        return ""
    if not DISPLAY_FILE_NAMES_IN_LOGS:
        return "document"
    return Path(document_name).stem
def build_agent_usage_log_row(
    *,
    session_hash: str,
    duration_seconds: float | int | str = "",
    document_name: str = "",
    total_page_count: int | str = 0,
    ocr_method: str = "",
    pii_method: str = "",
    llm_model_name: str = "",
    vlm_model_name: str = "",
    llm_input_tokens: int | str = 0,
    llm_output_tokens: int | str = 0,
    vlm_input_tokens: int | str = 0,
    vlm_output_tokens: int | str = 0,
    task: str = "agent",
) -> list[Any]:
    """Build a usage log row matching the main redaction app schema.

    Columns the agent does not use are filled with fixed placeholders:
    empty data-file name, zero Textract/Comprehend counts, ``False`` flags.
    Cost code and host name come from DEFAULT_COST_CODE and HOST_NAME.

    Returns:
        A 20-element list ordered like the main app's usage-log columns.
    """
    # Each value is paired with the column it fills; only values are returned.
    labelled_columns: list[tuple[str, Any]] = [
        ("session_hash_textbox", session_hash),
        ("doc_full_file_name_textbox", _doc_name_for_usage_log(document_name)),
        ("data_full_file_name_textbox", ""),
        ("actual_time_taken_number", duration_seconds),
        ("total_page_count", total_page_count),
        ("textract_query_number", 0),
        ("pii_detection_method", pii_method),
        ("comprehend_query_number", 0),
        ("cost_code", DEFAULT_COST_CODE),
        ("textract_handwriting_signature", False),
        ("host_name_textbox", HOST_NAME),
        ("text_extraction_method", ocr_method),
        ("is_this_a_textract_api_call", False),
        ("task", task),
        ("vlm_model_name", vlm_model_name),
        ("vlm_total_input_tokens", vlm_input_tokens),
        ("vlm_total_output_tokens", vlm_output_tokens),
        ("llm_model_name", llm_model_name),
        ("llm_total_input_tokens", llm_input_tokens),
        ("llm_total_output_tokens", llm_output_tokens),
    ]
    return [value for _column, value in labelled_columns]
# Module-level singletons for Pi / lightweight callers.
# Both start as None and are created lazily by get_access_logger() and
# get_agent_usage_logger() on first use.
_access_logger: PlatformAccessLogger | None = None
_agent_usage_logger: PlatformAgentUsageLogger | None = None
def get_access_logger() -> PlatformAccessLogger:
    """Return the shared PlatformAccessLogger, constructing it on first call.

    NOTE(review): initialisation is unsynchronised, so concurrent first calls
    could construct more than one instance.
    """
    global _access_logger
    if _access_logger is not None:
        return _access_logger
    _access_logger = PlatformAccessLogger()
    return _access_logger
def get_agent_usage_logger() -> PlatformAgentUsageLogger:
    """Return the shared PlatformAgentUsageLogger, constructing it on first call.

    NOTE(review): initialisation is unsynchronised, so concurrent first calls
    could construct more than one instance.
    """
    global _agent_usage_logger
    if _agent_usage_logger is not None:
        return _agent_usage_logger
    _agent_usage_logger = PlatformAgentUsageLogger()
    return _agent_usage_logger
def log_platform_access(session_hash: str, host_name: str = HOST_NAME) -> None:
    """Best-effort write of one access-log row for *session_hash*.

    This is a no-op when neither SAVE_LOGS_TO_CSV nor SAVE_LOGS_TO_DYNAMODB is
    enabled. Failures are logged as warnings and not raised, whether they
    come from the local filesystem (OSError) or AWS (botocore
    ClientError / BotoCoreError). A logging problem therefore never
    interrupts the session.
    """
    if not SAVE_LOGS_TO_CSV and not SAVE_LOGS_TO_DYNAMODB:
        return
    try:
        get_access_logger().log(session_hash, host_name)
    except OSError as exc:
        logger.warning(
            "Access log write failed (%s); session UI continues. "
            "On ECS/HF Pi images set ACCESS_LOGS_FOLDER=/tmp/pi-logs/ "
            "(see agent-redact/pi/bootstrap_pi_config.py).",
            exc,
        )
    except (BotoCoreError, ClientError) as exc:
        # DynamoDB / S3 failures are just as non-fatal as local disk failures.
        logger.warning(
            "Access log write to AWS failed (%s); session UI continues.",
            exc,
        )
def log_agent_usage_event(
    *,
    session_hash: str,
    duration_seconds: float | int | str = "",
    document_name: str = "",
    total_page_count: int | str = 0,
    ocr_method: str = "",
    pii_method: str = "",
    llm_model_name: str = "",
    vlm_model_name: str = "",
    llm_input_tokens: int | str = 0,
    llm_output_tokens: int | str = 0,
    vlm_input_tokens: int | str = 0,
    vlm_output_tokens: int | str = 0,
    task: str = "agent",
) -> None:
    """Log one Pi agent run to the main-app usage CSV / DynamoDB / S3 locations.

    This is a no-op when neither SAVE_LOGS_TO_CSV nor SAVE_LOGS_TO_DYNAMODB is
    enabled. As in log_platform_access, write failures are logged as warnings
    and not propagated to the caller. This covers OSError from the local
    filesystem and botocore ClientError / BotoCoreError from AWS.
    """
    if not SAVE_LOGS_TO_CSV and not SAVE_LOGS_TO_DYNAMODB:
        return
    row = build_agent_usage_log_row(
        session_hash=session_hash,
        duration_seconds=duration_seconds,
        document_name=document_name,
        total_page_count=total_page_count,
        ocr_method=ocr_method,
        pii_method=pii_method,
        llm_model_name=llm_model_name,
        vlm_model_name=vlm_model_name,
        llm_input_tokens=llm_input_tokens,
        llm_output_tokens=llm_output_tokens,
        vlm_input_tokens=vlm_input_tokens,
        vlm_output_tokens=vlm_output_tokens,
        task=task,
    )
    try:
        get_agent_usage_logger().log_row(row)
    except OSError as exc:
        logger.warning(
            "Usage log write failed (%s); continuing. "
            "Ensure USAGE_LOGS_FOLDER points to a writable path.",
            exc,
        )
    except (BotoCoreError, ClientError) as exc:
        logger.warning("Usage log write to AWS failed (%s); continuing.", exc)
def log_pi_usage_event(**kwargs: Any) -> None:
    """Back-compat alias mapping legacy Pi keyword arguments onto log_agent_usage_event.

    The legacy keys ``event``, ``provider``, ``model`` and
    ``deployment_profile`` are consumed. When no explicit ``llm_model_name``
    is given, one is synthesised as ``provider/model/deployment_profile/event``,
    skipping empty parts. If every part is empty, ``model`` is used as-is.
    Remaining kwargs are forwarded unchanged.
    """
    event, provider, model, deployment_profile = (
        kwargs.pop(key, "")
        for key in ("event", "provider", "model", "deployment_profile")
    )
    llm_model_name = kwargs.pop("llm_model_name", "")
    if not llm_model_name:
        # A join of non-empty parts is itself non-empty, so `or` only falls
        # back to `model` when every part was empty.
        llm_model_name = (
            "/".join(filter(None, (provider, model, deployment_profile, event)))
            or model
        )
    log_agent_usage_event(llm_model_name=llm_model_name, **kwargs)
def wire_pi_usage_logging(**kwargs: Any) -> None:
    """Forward keyword arguments unchanged to log_agent_usage_event.

    Direct-call convenience for Gradio event handlers; always returns None.
    """
    return log_agent_usage_event(**kwargs)
def wire_access_logging(
    session_hash_component: gr.Component,
    host_name_component: gr.Component,
    access_logs_state: gr.Component,
    access_s3_logs_loc_state: gr.Component,
    *,
    flag_output: gr.Component | None = None,
) -> gr.events.EventListener:
    """Attach access logging to changes of ``session_hash_component``.

    The chain follows the main-app pattern:

    - On every change, the session hash and host name values are recorded
      through a CSVLogger_custom bound to LOG_FILE_NAME.
    - On success, upload_log_file_to_s3 is called with the values of
      ``access_logs_state`` and ``access_s3_logs_loc_state``.

    Args:
        session_hash_component: Component whose change triggers logging.
        host_name_component: Component holding the host name to record.
        access_logs_state: Input for the upload step (local log location).
        access_s3_logs_loc_state: Input for the upload step (S3 location).
        flag_output: Optional component that receives the flag callback's
            return value.

    Returns:
        The ``.success`` listener, so callers can chain further handlers.
    """
    logged_components = (session_hash_component, host_name_component)
    access_callback = CSVLogger_custom(dataset_file_name=LOG_FILE_NAME)
    access_callback.setup(list(logged_components), ACCESS_LOGS_FOLDER)

    if flag_output is None:
        flag_outputs = []
    else:
        flag_outputs = [flag_output]

    # Kept as a lambda, since Gradio derives the event's default API name
    # from the function name.
    change_listener = session_hash_component.change(
        lambda *args: access_callback.flag(
            list(args),
            save_to_csv=SAVE_LOGS_TO_CSV,
            save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB,
            dynamodb_table_name=ACCESS_LOG_DYNAMODB_TABLE_NAME,
            dynamodb_headers=DYNAMODB_ACCESS_LOG_HEADERS,
            replacement_headers=CSV_ACCESS_LOG_HEADERS or None,
        ),
        list(logged_components),
        outputs=flag_outputs,
        preprocess=False,
    )
    return change_listener.success(
        fn=upload_log_file_to_s3,
        inputs=[access_logs_state, access_s3_logs_loc_state],
        outputs=[],
    )
def mount_or_launch(
    demo: gr.Blocks,
    *,
    fastapi_app: FastAPI | None = None,
    allowed_paths: list[str] | None = None,
    css: str | None = None,
    head_extra: str = "",
    theme: gr.themes.Base | None = None,
    server_name: str | None = None,
    server_port: int | None = None,
    show_error: bool = True,
    queue_kwargs: dict[str, Any] | None = None,
    root_path: str | None = None,
    fastapi_root_path: str | None = None,
) -> FastAPI | None:
    """Serve ``demo`` by mounting it on FastAPI or by launching it directly.

    With RUN_FASTAPI set, the demo is mounted on ``fastapi_app`` and the
    resulting app is returned. When ``fastapi_app`` is None, an app is built
    with create_fastapi_app(root_path=fastapi_root_path). Otherwise
    ``demo.launch`` is called and None is returned.

    Defaults:
        theme: ``gr.themes.Default(primary_hue="blue")``.
        server_name / server_port: GRADIO_SERVER_NAME / GRADIO_SERVER_PORT
            (only used by the launch path).
        root_path: ROOT_PATH; also determines the mount path and the
            ``<base href>`` head snippet.
        queue_kwargs: ``demo.queue`` is called only when this is non-empty.
    """
    theme = theme if theme is not None else gr.themes.Default(primary_hue="blue")
    server_name = GRADIO_SERVER_NAME if server_name is None else server_name
    server_port = GRADIO_SERVER_PORT if server_port is None else server_port
    if queue_kwargs:
        demo.queue(**queue_kwargs)

    ui_root = ROOT_PATH if root_path is None else root_path
    head = gradio_head_html(ui_root) + (head_extra or "")
    auth = _cognito_auth()
    allowed = allowed_paths or []
    mount_path = f"/{ui_root.strip('/')}" if str(ui_root).strip("/") else ""

    if not RUN_FASTAPI:
        demo.launch(
            theme=theme,
            head=head,
            css=css,
            show_error=show_error,
            server_name=server_name,
            server_port=server_port,
            root_path=ui_root,
            auth=auth,
            allowed_paths=allowed,
        )
        return None

    host_app = (
        fastapi_app
        if fastapi_app is not None
        else create_fastapi_app(root_path=fastapi_root_path)
    )
    return gr.mount_gradio_app(
        host_app,
        demo,
        path=mount_path,
        head=head,
        css=css,
        theme=theme,
        show_error=show_error,
        auth=auth,
        allowed_paths=allowed,
    )