import pytest
import sys
import os
import json
import logging
import asyncio
from datetime import datetime
from dotenv import load_dotenv
# Load .env files: two candidate project roots (relative to this file) are
# tried in order, then load_dotenv() is called without a path so it can look
# one up on its own.
_TEST_DIR = os.path.dirname(__file__)
for _env_rel in ('../../../.env', '../../.env'):
    load_dotenv(os.path.abspath(os.path.join(_TEST_DIR, _env_rel)))
load_dotenv()

# Make the project's `src` tree importable. If the package is not found via
# the sibling `../src` directory, retry with a `src` directory next to this file.
sys.path.append(os.path.abspath(os.path.join(_TEST_DIR, '../src')))
try:
    from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
except ImportError:
    sys.path.append(os.path.abspath(os.path.join(_TEST_DIR, 'src')))
    from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
from deepeval import assert_test
from deepeval.metrics import FaithfulnessMetric, AnswerRelevancyMetric, GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams
from deepeval.models.base_model import DeepEvalBaseLLM
# Global to store judge prompts for reporting.
# Each judge's generate/a_generate stores the prompt it received under
# kwargs['metric_name'] (or a per-judge default key), overwriting any earlier
# prompt stored under the same key.
JUDGE_PROMPTS = {} # key: metric_name, value: last_prompt
# --- JUDGE CONFIGURATIONS --- (Copied from test_medical_correctness.py)
class HuggingFaceJudge(DeepEvalBaseLLM):
    """DeepEval judge that queries a chat model through the Hugging Face router.

    Requests are sent with the OpenAI-compatible client pointed at
    ``https://router.huggingface.co/v1`` and authenticated with the
    ``HF_TOKEN`` environment variable. Every prompt is recorded in
    ``JUDGE_PROMPTS`` so it can be included in the report.
    """

    # List-valued keys that default to [] when the judge's JSON omits them.
    _LIST_FIELDS = ("verdicts", "truths", "claims", "statements", "steps")

    def __init__(self, model_name="google/gemma-3-27b-it:featherless-ai"):
        """Create the client.

        Raises:
            ValueError: if ``HF_TOKEN`` is not set.
        """
        self.model_name = model_name
        self.api_key = os.getenv("HF_TOKEN")
        if not self.api_key:
            raise ValueError("HF_TOKEN is required for HuggingFace Judge.")
        # Imported here so `openai` is only needed when this judge is used.
        from openai import OpenAI
        self.client = OpenAI(
            base_url="https://router.huggingface.co/v1",
            api_key=self.api_key,
        )

    def load_model(self):
        """Return the underlying OpenAI-compatible client."""
        return self.client

    async def a_generate(self, prompt: str, schema=None, **kwargs):
        """Async variant of :meth:`generate`.

        The blocking HTTP call is pushed to a worker thread so it does not
        stall the event loop while the request is in flight.
        """
        return await asyncio.to_thread(self.generate, prompt, schema, **kwargs)

    def generate(self, prompt: str, schema=None, **kwargs):
        """Send *prompt* to the judge model and return its answer.

        Args:
            prompt: Text sent as the user message.
            schema: Optional result class. When given, the model is asked for
                JSON, the first ``{...}`` span of the reply is parsed, and the
                result is returned as an instance of *schema*.
            **kwargs: ``metric_name`` (optional) selects the key under which
                the prompt is stored in ``JUDGE_PROMPTS``.

        Returns:
            The raw reply text when *schema* is falsy, otherwise a *schema*
            instance. If the request or parsing fails, a zero-score fallback
            *schema* instance is returned; if even that cannot be built (or no
            schema was given) the string ``"Error: <message>"`` is returned.
        """
        metric_name = kwargs.get('metric_name', 'Judge')
        JUDGE_PROMPTS[metric_name] = prompt
        system_msg = self._system_message(schema)
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=2048,
            )
            raw_content = completion.choices[0].message.content
            if not schema:
                return raw_content

            data = self._parse_json_object(raw_content)
            print(f"DEBUG: Processed Judge Data for {metric_name} score: {data.get('score')}")
            if hasattr(schema, 'model_validate'):
                return schema.model_validate(self._normalize_payload(data, schema))
            return schema(**data)
        except Exception as e:
            # Deliberately broad: any judge failure is turned into a
            # zero-score result (or an error string) instead of aborting.
            logging.error("Judge error (%s): %s", metric_name, e)
            if schema:
                fallback = self._fallback_result(schema, e)
                if fallback is not None:
                    return fallback
            return f"Error: {str(e)}"

    def get_model_name(self):
        """Return the configured model identifier."""
        return self.model_name

    @staticmethod
    def _system_message(schema):
        """Build the system message, embedding the JSON schema when one is requested."""
        if not schema:
            return "You are a helpful assistant."
        if hasattr(schema, 'model_json_schema'):
            # Preferred over the deprecated ``schema()`` when the class offers it.
            schema_desc = schema.model_json_schema()
        elif hasattr(schema, 'schema'):
            schema_desc = schema.schema()
        else:
            schema_desc = 'JSON object'
        return (
            "You are a helpful assistant that always responds in JSON format. "
            f"Your response must follow this schema: {schema_desc}"
        )

    @staticmethod
    def _parse_json_object(raw_content):
        """Parse the span from the first ``{`` to the last ``}`` of the reply.

        Raises:
            ValueError: if the reply is empty, is not valid JSON, or does not
                decode to a JSON object.
        """
        import re

        if not raw_content:
            raise ValueError("Judge returned an empty response")
        json_match = re.search(r'\{.*\}', raw_content, re.DOTALL)
        data = json.loads(json_match.group(0) if json_match else raw_content)
        if not isinstance(data, dict):
            raise ValueError("Judge response is not a JSON object")
        return data

    @classmethod
    def _normalize_payload(cls, data, schema):
        """Fill in keys the judge left out so *schema* validation can succeed."""
        if not data.get("evaluation_steps"):
            if "score" in data:
                logging.getLogger(__name__).debug(
                    "Schema fields: %s", list(getattr(schema, 'model_fields', {}))
                )
                # Reuse the judge's reason as the single evaluation step.
                reason = data.get("reason", "No specific steps provided.")
                data["evaluation_steps"] = (
                    [reason] if reason else ["Clinical trajectory assessment."]
                )
            elif "evaluation_steps" in data:
                # Key present but empty: replace it with a non-empty placeholder.
                data["evaluation_steps"] = ["General clinical audit."]
            elif data.get("steps"):
                # Accept the alternative "steps" key as the evaluation steps.
                steps = data["steps"]
                data["evaluation_steps"] = steps if isinstance(steps, list) else [steps]

        for field in cls._LIST_FIELDS:
            data.setdefault(field, [])
        if "verdict" not in data:
            data["verdict"] = "yes" if cls._score_as_float(data) > 0.5 else "no"
        return data

    @staticmethod
    def _score_as_float(data):
        """Return ``data['score']`` as a float, or 0.0 when missing or non-numeric."""
        try:
            return float(data.get("score", 0))
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _fallback_result(schema, error):
        """Build a zero-score *schema* instance describing *error*; None if that fails."""
        fallback = {
            "score": 0.0,
            "reason": f"Judge error: {str(error)}",
            "verdict": "no",
            "verdicts": [],
            "truths": [],
            "claims": [],
            "statements": [],
            "steps": ["Evaluation failed due to error"],
            "evaluation_steps": ["Evaluation failed due to error"],
        }
        try:
            if hasattr(schema, 'model_validate'):
                return schema.model_validate(fallback)
            return schema(**fallback)
        except Exception as ef:
            logging.error("Fallback validation failed: %s", ef)
            return None
class GeminiJudge(DeepEvalBaseLLM):
    """DeepEval judge backed by a ``google.generativeai`` model.

    The API key comes from the ``api_key`` argument or the ``GOOGLE_API_KEY``
    environment variable. Every prompt is recorded in ``JUDGE_PROMPTS``.
    """

    def __init__(self, model_name="gemini-1.5-pro", api_key=None):
        """Configure the SDK and create the model handle.

        Raises:
            ValueError: if no API key is supplied or found in the environment.
        """
        self.model_name = model_name
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            raise ValueError("GOOGLE_API_KEY is required.")
        # Imported here so the SDK is only needed when this judge is used.
        import google.generativeai as genai
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(model_name)

    def load_model(self):
        """Return the configured ``GenerativeModel`` handle."""
        return self.model

    def generate(self, prompt: str, schema=None, **kwargs) -> str:
        """Send *prompt* to the model and return the reply text.

        This is the synchronous primitive; it performs the call directly
        rather than via ``asyncio.run`` so it also works when invoked from
        code that is already running inside an event loop.

        Args:
            prompt: Text passed to ``generate_content``.
            schema: Accepted for interface compatibility; it is not used and
                the plain reply text is returned regardless.
            **kwargs: ``metric_name`` (optional) selects the key under which
                the prompt is stored in ``JUDGE_PROMPTS``.

        Returns:
            The reply text, or ``"Error: <message>"`` if the call fails.
        """
        JUDGE_PROMPTS[kwargs.get('metric_name', 'Gemini')] = prompt
        try:
            return self.model.generate_content(prompt).text
        except Exception as e:
            # Deliberately broad: failures are reported as an error string.
            logging.error("Gemini judge error: %s", e)
            return f"Error: {str(e)}"

    async def a_generate(self, prompt: str, schema=None, **kwargs):
        """Async variant of :meth:`generate`; the blocking call runs in a worker thread."""
        return await asyncio.to_thread(self.generate, prompt, schema, **kwargs)

    def get_model_name(self):
        """Return the configured model identifier."""
        return self.model_name
class MockJudge(DeepEvalBaseLLM):
    """Offline stand-in judge that returns canned results without any API call.

    It answers positively by default and deliberately fails prompts that
    mention "signs of recovery" together with one of a few high-risk terms.
    """

    def __init__(self, model_name="local-mock-judge"):
        self.model_name = model_name

    def load_model(self):
        """There is no underlying model to load."""
        return None

    def generate(self, prompt: str, schema=None, **kwargs) -> str:
        """Record the prompt and return a simulated judge result.

        Without a schema the fixed string "Evaluated." is returned. With a
        schema, a canned payload is validated/instantiated through it; if
        plain instantiation fails, the raw payload dict is returned instead.
        """
        metric_key = kwargs.get('metric_name', 'Mock')
        JUDGE_PROMPTS[metric_key] = prompt

        if not schema:
            return "Evaluated."

        # Default positive response (using 1-10 scale as GEval often does)
        first_verdict = {"verdict": "yes", "reason": "Accurate clinical statement"}
        payload = {
            "score": 10.0,
            "reason": "The summary accurately reflects the patient data.",
            "verdicts": [first_verdict],
            "truths": ["Patient data present"],
            "claims": ["Statement matches data"],
            "verdict": "yes",
            "statements": ["The summary is correct"],
            "steps": ["Step 1: Check facts", "Step 2: Verify trends"],
        }

        # Deliberate failure for mock mode: a generic recovery claim combined
        # with any of these terms anywhere in the prompt is scored as a fail.
        if "signs of recovery" in prompt.lower():
            risk_terms = ("AKI", "CANCER", "LUNG", "ALZHEIMER", "PALLIATIVE")
            shouted = prompt.upper()
            if any(term in shouted for term in risk_terms):
                payload["score"] = 1.0
                payload["reason"] = f"CRITICAL FAIL: General 'signs of recovery' claim detected in {metric_key} audit for unstable or chronic/terminal patient case."
                payload["verdict"] = "no"
                first_verdict["verdict"] = "no"
                first_verdict["reason"] = "Inaccurate clinical claim"

        if hasattr(schema, 'model_validate'):
            return schema.model_validate(payload)
        try:
            return schema(**payload)
        except Exception:
            # Fallback if schema is different
            return payload

    async def a_generate(self, prompt: str, schema=None, **kwargs) -> str:
        """Async entry point; simply delegates to :meth:`generate`."""
        return self.generate(prompt, schema, **kwargs)

    def get_model_name(self):
        """Return the mock model identifier."""
        return self.model_name
# --- INITIALIZE JUDGE ---
# Pick the judge by available credentials: HF_TOKEN first, then
# GOOGLE_API_KEY, otherwise the offline MockJudge.
SKIP_REASON = ""
# Every branch below ends up with a usable judge, so the tests are never
# skipped for a missing key (the mock forces this to True as well).
HAS_KEY = True
if os.getenv("HF_TOKEN"):
    eval_model, USE_MOCK = HuggingFaceJudge(), False
elif os.getenv("GOOGLE_API_KEY"):
    eval_model, USE_MOCK = GeminiJudge(), False
else:
    print("WARNING: No API Key found. Using MockJudge for demonstration.")
    eval_model, USE_MOCK = MockJudge(), True
# --- DATA LOADER ---
def load_test_data():
    """Load the scenarios from ``patient_test_data.json`` next to this file.

    The file is read as UTF-8 explicitly so the result does not depend on the
    platform's default text encoding.

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: if the data file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    data_path = os.path.join(os.path.dirname(__file__), 'patient_test_data.json')
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# --- CONFIGURATION ---
USE_MOCK_AGENT = False # Set to True for instant testing of the DeepEval pipeline


@pytest.fixture(scope="module")
def agent():
    """Provide the summarization agent shared by all tests in this module.

    With ``USE_MOCK_AGENT`` set, a lightweight canned-text agent is yielded.
    Otherwise a ``PatientSummarizerAgent`` is configured with the Phi-3 mini
    GGUF model and, when ``model_config.get_t4_generation_config`` exists,
    that function is wrapped to set ``max_new_tokens`` to 512. The wrapper is
    removed again at fixture teardown so the patch does not leak into other
    test modules.
    """
    if USE_MOCK_AGENT:
        class MockAgent:
            def generate_patient_summary(self, patient_data):
                # Smarter Mock Agent: Generates variations based on data to test evaluation logic
                res = patient_data.get("result", {})
                name = res.get("patientname", "Patient")
                encounters = res.get("encounters", [])
                last_diag = encounters[-1].get("diagnosis", []) if encounters else []
                # Default dangerous generic summary
                summary = f"--- AI-GENERATED CLINICAL NARRATIVE ---\nThe patient {name} is showing signs of recovery. Stable vitals. Continue current medication.\n---"
                # Slightly smarter logic for some cases
                if any("AKI" in d or "Kidney" in d for d in last_diag):
                    summary = f"--- AI-GENERATED CLINICAL NARRATIVE ---\n{name} has Acute Kidney Injury. Creatinine is 2.4 (baseline 1.6). Monitoring fluid status.\n---"
                elif "Oncology" in str(patient_data) or "Cancer" in str(patient_data):
                    summary = f"--- AI-GENERATED CLINICAL NARRATIVE ---\n{name} is undergoing chemo for Breast Cancer. Neutropenia noted (WBC 3.2). Chemo held.\n---"
                return summary

        yield MockAgent()
        return

    ag = PatientSummarizerAgent()
    # Use a slightly better model for clinical summary if available locally
    model_name = "microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf"
    ag.configure_model(model_name)

    # Fast config for testing
    from ai_med_extract.utils import model_config
    original_get_config = getattr(model_config, 'get_t4_generation_config', None)
    if original_get_config is None:
        yield ag
        return

    def fast_test_config(model_type):
        config = original_get_config(model_type)
        config['max_new_tokens'] = 512 # enough for a comprehensive summary
        return config

    model_config.get_t4_generation_config = fast_test_config
    try:
        yield ag
    finally:
        # Undo the monkeypatch once the module's tests are finished.
        model_config.get_t4_generation_config = original_get_config
# --- CLINICAL REQUIREMENTS MAPPING ---
# Maps a scenario title to a list of lowercase keywords.
# NOTE(review): presumably the keys match the `name` field of the scenarios
# in patient_test_data.json and the keywords are terms expected in the
# generated summary -- nothing in the visible part of this file reads this
# mapping, so confirm against its consumer.
CLINICAL_REQUIREMENTS = {
"Acute Kidney Injury Scenario": ["creatinine", "baseline", "renal"],
"Oncology Treatment Cycle (Breast Cancer)": ["chemo", "neutropenia", "wbc", "held"],
"Palliative Care (Stage IV Lung Cancer - Symptom Management)": ["palliative", "hospice", "comfort", "cancer"],
"Hypertension & Diabetes Patient": ["glucose", "blood sugar", "metformin", "hypertension"],
"Neurological Management (Early-Stage Alzheimer's)": ["alzheimer", "memory", "cognitive", "donepezil"]
}
# --- HELPERS ---
def extract_narrative(report_text):
    """Return the narrative section of a generated report.

    The narrative is the text that follows the first
    ``--- AI-GENERATED CLINICAL NARRATIVE ---`` header, cut at the next
    ``---`` and stripped of surrounding whitespace. Reports without the
    header are returned unchanged.
    """
    header = "--- AI-GENERATED CLINICAL NARRATIVE ---"
    if header not in report_text:
        return report_text
    # Segment between the first header and the next header (or end of text).
    after_header = report_text.split(header, 2)[1]
    narrative, _, _ = after_header.partition("---")
    return narrative.strip()
def _as_str_list(value):
    """Return *value* as a list of strings (None -> [], scalar -> one item)."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    return [str(value)]


def get_context(data):
    """Build the retrieval-context strings for one patient record.

    Args:
        data: Mapping with an optional ``result`` entry holding
            ``patientname``, ``past_medical_history`` and ``encounters``.

    Returns:
        A list whose first item summarises the patient and medical history,
        followed by one item per encounter. Missing scalar fields are
        rendered as ``None`` (as a missing patient name always was) instead
        of raising, missing or null lists are treated as empty, and a
        diagnosis given as a single string is kept whole rather than being
        joined character by character.
    """
    res = data.get("result") or {}
    history = _as_str_list(res.get("past_medical_history"))
    context = [f"Patient: {res.get('patientname')}, PMH: {', '.join(history)}"]
    for enc in res.get("encounters") or []:
        diagnoses = _as_str_list(enc.get("diagnosis"))
        context.append(
            f"Date: {enc.get('visit_date')}, "
            f"Complaint: {enc.get('chief_complaint')}, "
            f"Diagnosis: {', '.join(diagnoses)}, "
            f"Notes: {enc.get('dr_notes')}"
        )
    return context
# --- RESULTS COLLECTOR (File-based) ---
RESULTS_FILE = os.path.join(os.path.dirname(__file__), 'test_results.json')
# Remove any results file left over from an earlier run; this happens when
# the module is imported. Cleanup is best-effort: a missing file is the
# normal case, and any other OS-level failure is logged rather than raised.
try:
    os.remove(RESULTS_FILE)
except FileNotFoundError:
    pass
except OSError as exc:
    logging.warning("Could not remove stale results file %s: %s", RESULTS_FILE, exc)
# --- TESTS ---
@pytest.mark.timeout(1200) # 20 minutes for all scenarios
@pytest.mark.parametrize("scenario", load_test_data())
@pytest.mark.skipif(not HAS_KEY, reason=SKIP_REASON)
def test_patient_summary_quality(agent, scenario):
    """Generate a summary for one scenario, score it, and record the outcome.

    Three metrics (faithfulness, answer relevancy, and a GEval "Clinical
    Accuracy" check) are measured with the module's judge. A failing
    ``assert_test`` is caught and stored as the scenario status instead of
    being re-raised; the scores, reasons, output and captured judge prompts
    are appended to ``RESULTS_FILE``.
    """
    scenario_name, patient_data = scenario['name'], scenario['data']
    print(f"\n--- Testing Scenario: {scenario_name} ---")
    print(f"Generating summary for {scenario_name}...")

    # Start with an empty prompt log so only this scenario's prompts are stored.
    JUDGE_PROMPTS.clear()

    # Produce the summary and keep just its narrative section.
    ai_output = extract_narrative(agent.generate_patient_summary(patient_data))

    test_case = LLMTestCase(
        input="Generate a clinical summary for the patient.",
        actual_output=ai_output,
        retrieval_context=get_context(patient_data)
    )

    faithfulness = FaithfulnessMetric(threshold=0.7, model=eval_model, truths_extraction_limit=3)
    relevancy = AnswerRelevancyMetric(threshold=0.7, model=eval_model)
    clinical_accuracy = GEval(
        name="Clinical Accuracy",
        model=eval_model,
        criteria="Evaluate if the clinical summary accurately captures the patient's stability vs instability. A summary is ACCURATE if it correctly identifies worsening trends (like rising creatinine or falling WBC) and avoids false 'recovery' claims for terminal or acute cases.",
        evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.RETRIEVAL_CONTEXT],
        threshold=0.8
    )
    metrics = [faithfulness, relevancy, clinical_accuracy]

    # Measure each metric explicitly; the recorded scores and reasons are
    # read from these metric objects below.
    for metric in metrics:
        metric.measure(test_case)

    try:
        assert_test(test_case, metrics)
    except Exception as e:
        # Keep only the part of the message before 'failed.' for the report.
        message = str(e)
        if 'failed.' in message:
            message = message.split('failed.')[0].strip()
        status = f"FAILED: {message}"
    else:
        status = "PASSED"

    record = {"scenario": scenario_name, "status": status}
    labelled_metrics = (
        ("faithfulness", faithfulness),
        ("relevancy", relevancy),
        ("clinical_accuracy", clinical_accuracy),
    )
    for label, metric in labelled_metrics:
        record[f"{label}_score"] = metric.score if metric.score is not None else 0.0
        record[f"{label}_reason"] = metric.reason
    record["output_preview"] = ai_output
    record["patient_json"] = json.dumps(patient_data, indent=2)
    record["prompts"] = JUDGE_PROMPTS.copy()

    # Read-modify-write the shared results file.
    collected = []
    if os.path.exists(RESULTS_FILE):
        with open(RESULTS_FILE, 'r') as f:
            collected = json.load(f)
    collected.append(record)
    with open(RESULTS_FILE, 'w') as f:
        json.dump(collected, f)
# --- REPORT GENERATION ---
def _md_table_cell(value):
    """Return *value* as text that fits in one Markdown table cell."""
    # Collapse newlines/whitespace runs and escape pipes so a multi-line
    # failure message cannot break the table layout.
    return " ".join(str(value).split()).replace("|", "\\|")


def finalize_report():
    """Render ``RESULTS_FILE`` as ``deepeval_test_report.md`` next to this file.

    The report holds a model-configuration header, a summary table with one
    row per scenario, and a detailed section per scenario (scores, reasons,
    the generated summary, the patient JSON and the captured judge prompts).
    Prints a warning and returns without writing anything when no results
    file exists.
    """
    if not os.path.exists(RESULTS_FILE):
        print("\n[WARNING] No results file found.")
        return
    with open(RESULTS_FILE, 'r', encoding='utf-8') as f:
        results = json.load(f)

    # (result-key prefix, label used in the detailed section)
    metrics = (
        ("faithfulness", "Faithfulness"),
        ("relevancy", "Relevancy"),
        ("clinical_accuracy", "Clinical Accuracy"),
    )

    report_path = os.path.join(os.path.dirname(__file__), 'deepeval_test_report.md')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("# DeepEval Comprehensive Patient Data Test Report\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

        # Explicit Model Info
        if USE_MOCK_AGENT:
            agent_model = "MockAgent (Clinical Logic Simulator)"
        else:
            agent_model = "microsoft/Phi-3-mini-4k-instruct-gguf"
        judge_model = eval_model.get_model_name() if eval_model else 'Default'
        if USE_MOCK:
            judge_model += " (Internal Clinical Audit Simulator)"
        f.write("### Model Configuration\n")
        f.write(f"- **Summarization Agent**: {agent_model}\n")
        f.write(f"- **Evaluation Judge**: {judge_model}\n")
        if USE_MOCK:
            f.write("> [!WARNING]\n> **MOCK MODE ACTIVE**: No API keys found. Scores are simulated for pipeline verification and clinical logic testing.\n\n")
        else:
            f.write("\n")

        # Summary table: one row per scenario.
        f.write("| Scenario | Status | Faithfulness | Relevancy | Clinical Acc |\n")
        f.write("| --- | --- | --- | --- | --- |\n")
        for res in results:
            scores = " | ".join(
                f"{(res.get(f'{key}_score') or 0.0):.2f}" for key, _ in metrics
            )
            f.write(
                f"| {_md_table_cell(res['scenario'])} "
                f"| {_md_table_cell(res['status'])} | {scores} |\n"
            )

        f.write("\n## Detailed Findings\n")
        for res in results:
            f.write(f"### {res['scenario']}\n")
            for key, label in metrics:
                score = res.get(f"{key}_score") or 0.0
                # Stored reasons may be null; show N/A rather than "None".
                reason = res.get(f"{key}_reason") or 'N/A'
                f.write(f"- **{label} Score:** {score:.2f}\n")
                f.write(f" - *Reason:* {reason}\n")

            f.write("\n#### AI Summary Output\n")
            f.write(f"```text\n{res['output_preview']}\n```\n")

            # Long payloads go into collapsible sections so the findings
            # above stay easy to scan.
            f.write("\n<details>\n<summary>Patient Input Data (JSON)</summary>\n\n")
            f.write(f"```json\n{res['patient_json']}\n```\n")
            f.write("</details>\n\n")

            f.write("<details>\n<summary>Judge Evaluation Prompts</summary>\n\n")
            prompts = res.get('prompts') or {}
            if prompts:
                for m_name, p_text in prompts.items():
                    f.write(f"**{m_name} Metric Prompt:**\n")
                    f.write(f"```text\n{p_text}\n```\n\n")
            else:
                f.write("No prompt captured.\n")
            f.write("</details>\n\n---\n\n")
    print(f"\n[SUCCESS] Comprehensive report generated: {report_path}")
# Final test to generate report
def test_generate_final_report():
    """Write the Markdown report from the results collected so far."""
    finalize_report()