"""DeepEval-based quality tests for the patient summary agent.

For every scenario in ``patient_test_data.json`` the agent generates a clinical
summary, which is then scored by an LLM judge (HuggingFace router, Gemini, or a
local mock when no API key is configured) for faithfulness, answer relevancy
and clinical accuracy. Per-scenario results are appended to
``test_results.json`` and rendered into ``deepeval_test_report.md`` by the
final test in this module.
"""
import asyncio
import json
import logging
import os
import re
import sys
from datetime import datetime

import pytest
from dotenv import load_dotenv

# Load .env from root
load_dotenv(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../.env')))
load_dotenv(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.env')))
load_dotenv()  # Current dir

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
try:
    from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent
except ImportError:
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
    from ai_med_extract.agents.patient_summary_agent import PatientSummarizerAgent

from deepeval import assert_test
from deepeval.metrics import FaithfulnessMetric, AnswerRelevancyMetric, GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams
from deepeval.models.base_model import DeepEvalBaseLLM

logger = logging.getLogger(__name__)

# Global to store judge prompts for reporting
JUDGE_PROMPTS = {}  # key: metric_name, value: last_prompt

# List-valued fields that are defaulted to [] when the judge omits them.
_DEFAULT_LIST_FIELDS = ("verdicts", "truths", "claims", "statements", "steps")

# Greedy match from the first "{" to the last "}" of a judge response.
_JSON_OBJECT_RE = re.compile(r'\{.*\}', re.DOTALL)


# --- JUDGE HELPERS ---
def _describe_schema(schema):
    """Return a description of *schema* to embed in the judge system prompt."""
    if hasattr(schema, 'model_json_schema'):
        return schema.model_json_schema()
    if hasattr(schema, 'schema'):
        return schema.schema()
    return 'JSON object'


def _instantiate_schema(schema, data):
    """Build a *schema* instance from the dict *data*."""
    if hasattr(schema, 'model_validate'):
        return schema.model_validate(data)
    return schema(**data)


def _parse_json_payload(raw_content):
    """Decode the JSON object embedded in *raw_content*.

    The span from the first ``{`` to the last ``}`` is decoded when present;
    otherwise the whole text is handed to ``json.loads``.
    """
    match = _JSON_OBJECT_RE.search(raw_content)
    return json.loads(match.group(0) if match else raw_content)


def _normalize_judge_payload(data, schema):
    """Fill in fields the judge commonly omits, mutating *data* in place."""
    # Accept "steps" as an alias first, so real steps returned by the judge
    # are not overwritten by the reason-based fallback below.
    if not data.get("evaluation_steps") and data.get("steps"):
        steps = data["steps"]
        data["evaluation_steps"] = steps if isinstance(steps, list) else [steps]

    if not data.get("evaluation_steps") and "score" in data:
        fields = schema.model_fields.keys() if hasattr(schema, 'model_fields') else []
        logger.debug("Schema fields: %s", list(fields))
        # Force populate evaluation_steps from the reason
        reason = data.get("reason", "No specific steps provided.")
        if reason in ("", None):
            reason = "Clinical trajectory assessment."
        data["evaluation_steps"] = [reason]

    # An explicitly empty list is replaced with a placeholder step.
    if "evaluation_steps" in data and not data["evaluation_steps"]:
        data["evaluation_steps"] = ["General clinical audit."]

    # Final check for verdicts/truths/claims (Faithfulness/Relevancy)
    for field in _DEFAULT_LIST_FIELDS:
        data.setdefault(field, [])

    if "verdict" not in data:
        score = data.get("score", 0)
        # A missing or non-numeric score counts as a "no" verdict.
        is_positive = isinstance(score, (int, float)) and score > 0.5
        data["verdict"] = "yes" if is_positive else "no"


def _build_schema_response(raw_content, schema, metric_name):
    """Parse *raw_content* as JSON and return it as a *schema* instance."""
    data = _parse_json_payload(raw_content)
    logger.debug("Processed Judge Data for %s score: %s", metric_name, data.get('score'))
    if hasattr(schema, 'model_validate'):
        _normalize_judge_payload(data, schema)
        return schema.model_validate(data)
    return schema(**data)


def _error_response(error, schema, metric_name):
    """Log a judge failure and return a zero-score fallback.

    Returns a *schema* instance with score 0.0 when a schema was requested and
    the fallback validates; otherwise returns the string ``"Error: <message>"``.
    """
    logger.error("Judge error (%s): %s", metric_name, error)
    if schema:
        # Minimum valid mock object to prevent crash
        fallback = {
            "score": 0.0,
            "reason": f"Judge error: {str(error)}",
            "verdict": "no",
            "verdicts": [],
            "truths": [],
            "claims": [],
            "statements": [],
            "steps": ["Evaluation failed due to error"],
            "evaluation_steps": ["Evaluation failed due to error"],
        }
        try:
            return _instantiate_schema(schema, fallback)
        except Exception as fallback_error:
            logger.error("Fallback validation failed: %s", fallback_error)
    return f"Error: {str(error)}"


# --- JUDGE CONFIGURATIONS --- (Copied from test_medical_correctness.py)
class HuggingFaceJudge(DeepEvalBaseLLM):
    """Judge backed by the HuggingFace router's OpenAI-compatible chat API."""

    def __init__(self, model_name="google/gemma-3-27b-it:featherless-ai"):
        """Create the client.

        Raises:
            ValueError: if the ``HF_TOKEN`` environment variable is not set.
        """
        self.model_name = model_name
        self.api_key = os.getenv("HF_TOKEN")
        if not self.api_key:
            raise ValueError("HF_TOKEN is required for HuggingFace Judge.")
        from openai import OpenAI
        self.client = OpenAI(
            base_url="https://router.huggingface.co/v1",
            api_key=self.api_key,
        )

    def load_model(self):
        """Return the underlying OpenAI client."""
        return self.client

    async def a_generate(self, prompt: str, schema=None, **kwargs):
        """Async entry point; delegates to the blocking :meth:`generate`."""
        # Use sync generate for simplicity in this wrapper
        return self.generate(prompt, schema, **kwargs)

    def generate(self, prompt: str, schema=None, **kwargs):
        """Send *prompt* to the judge model.

        The prompt is recorded in ``JUDGE_PROMPTS`` under ``metric_name``
        (default ``'Judge'``). Without *schema* the raw completion text is
        returned; with *schema* the completion is parsed as JSON and returned
        as a schema instance. Errors never propagate: they are logged and
        turned into a zero-score fallback (see ``_error_response``).
        """
        metric_name = kwargs.get('metric_name', 'Judge')
        JUDGE_PROMPTS[metric_name] = prompt

        # If schema is provided, we need to request JSON and parse it
        system_msg = "You are a helpful assistant."
        if schema:
            system_msg = f"You are a helpful assistant that always responds in JSON format. Your response must follow this schema: {_describe_schema(schema)}"

        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=2048,
            )
            raw_content = completion.choices[0].message.content
            if not schema:
                return raw_content
            return _build_schema_response(raw_content, schema, metric_name)
        except Exception as e:
            return _error_response(e, schema, metric_name)

    def get_model_name(self):
        """Return the configured model identifier."""
        return self.model_name


class GeminiJudge(DeepEvalBaseLLM):
    """Judge backed by Google Gemini via ``google.generativeai``."""

    def __init__(self, model_name="gemini-1.5-pro", api_key=None):
        """Configure the Gemini client.

        Raises:
            ValueError: if no key is passed and ``GOOGLE_API_KEY`` is unset.
        """
        self.model_name = model_name
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            raise ValueError("GOOGLE_API_KEY is required.")
        import google.generativeai as genai
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(model_name)

    def load_model(self):
        """Return the underlying ``GenerativeModel``."""
        return self.model

    @staticmethod
    def _finalize(raw_text, schema, metric_name):
        """Return raw text, or a schema instance when *schema* was requested."""
        if not schema:
            return raw_text
        return _build_schema_response(raw_text, schema, metric_name)

    async def a_generate(self, prompt: str, schema=None, **kwargs):
        """Async generation; the blocking SDK call runs in a worker thread."""
        metric_name = kwargs.get('metric_name', 'Gemini')
        JUDGE_PROMPTS[metric_name] = prompt
        try:
            response = await asyncio.to_thread(self.model.generate_content, prompt)
            return self._finalize(response.text, schema, metric_name)
        except Exception as e:
            return _error_response(e, schema, metric_name)

    def generate(self, prompt: str, schema=None, **kwargs):
        """Synchronous generation.

        Calls the SDK directly rather than wrapping ``a_generate`` in
        ``asyncio.run``, which cannot be used while an event loop is running.
        """
        metric_name = kwargs.get('metric_name', 'Gemini')
        JUDGE_PROMPTS[metric_name] = prompt
        try:
            response = self.model.generate_content(prompt)
            return self._finalize(response.text, schema, metric_name)
        except Exception as e:
            return _error_response(e, schema, metric_name)

    def get_model_name(self):
        """Return the configured model identifier."""
        return self.model_name


class MockJudge(DeepEvalBaseLLM):
    """Offline judge that returns canned scores so the pipeline can run without API keys."""

    def __init__(self, model_name="local-mock-judge"):
        self.model_name = model_name

    def load_model(self):
        """There is no backing model."""
        return None

    def generate(self, prompt: str, schema=None, **kwargs):
        """Return a simulated judge response.

        Without *schema* the string ``"Evaluated."`` is returned. With *schema*
        a positive result is produced, unless the prompt mentions
        'signs of recovery' together with one of a fixed set of
        acute/chronic/terminal keywords, in which case a failing result is
        produced.
        """
        # Capture prompt
        metric_key = kwargs.get('metric_name', 'Mock')
        JUDGE_PROMPTS[metric_key] = prompt

        if not schema:
            return "Evaluated."

        # Default positive response (using 1-10 scale as GEval often does)
        data = {
            "score": 10.0,
            "reason": "The summary accurately reflects the patient data.",
            "verdicts": [{"verdict": "yes", "reason": "Accurate clinical statement"}],
            "truths": ["Patient data present"],
            "claims": ["Statement matches data"],
            "verdict": "yes",
            "statements": ["The summary is correct"],
            "steps": ["Step 1: Check facts", "Step 2: Verify trends"],
        }

        # DELIBERATE FAILURE LOGIC FOR MOCK MODE:
        # If the prompt contains 'signs of recovery' but the context has 'AKI' or 'Cancer', fail it.
        if "signs of recovery" in prompt.lower():
            prompt_upper = prompt.upper()
            if any(x in prompt_upper for x in ["AKI", "CANCER", "LUNG", "ALZHEIMER", "PALLIATIVE"]):
                data["score"] = 1.0
                data["reason"] = f"CRITICAL FAIL: General 'signs of recovery' claim detected in {metric_key} audit for unstable or chronic/terminal patient case."
                data["verdict"] = "no"
                data["verdicts"][0]["verdict"] = "no"
                data["verdicts"][0]["reason"] = "Inaccurate clinical claim"

        if hasattr(schema, 'model_validate'):
            return schema.model_validate(data)
        try:
            return schema(**data)
        except Exception:
            # Fallback if schema is different
            return data

    async def a_generate(self, prompt: str, schema=None, **kwargs):
        """Async entry point; delegates to :meth:`generate`."""
        return self.generate(prompt, schema, **kwargs)

    def get_model_name(self):
        """Return the mock model identifier."""
        return self.model_name


# --- INITIALIZE JUDGE ---
eval_model = None
HAS_KEY = False
SKIP_REASON = ""
USE_MOCK = False

if os.getenv("HF_TOKEN"):
    eval_model = HuggingFaceJudge()
    HAS_KEY = True
    USE_MOCK = False
elif os.getenv("GOOGLE_API_KEY"):
    eval_model = GeminiJudge()
    HAS_KEY = True
else:
    print("WARNING: No API Key found. Using MockJudge for demonstration.")
    eval_model = MockJudge()
    HAS_KEY = True  # Force True to run tests with Mock
    USE_MOCK = True


# --- DATA LOADER ---
def load_test_data():
    """Load the scenarios from ``patient_test_data.json`` next to this file."""
    data_path = os.path.join(os.path.dirname(__file__), 'patient_test_data.json')
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)


# --- CONFIGURATION ---
USE_MOCK_AGENT = False  # Set to True for instant testing of the DeepEval pipeline

# Model handed to PatientSummarizerAgent.configure_model for the real agent.
_AGENT_MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf"


class _MockAgent:
    """Stand-in agent that builds canned summaries from the patient data."""

    def generate_patient_summary(self, patient_data):
        """Return a canned narrative chosen from the last diagnosis / data text."""
        # Smarter Mock Agent: Generates variations based on data to test evaluation logic
        res = patient_data.get("result", {})
        name = res.get("patientname", "Patient")
        encounters = res.get("encounters", [])
        last_diag = encounters[-1].get("diagnosis", []) if encounters else []

        # Default dangerous generic summary
        summary = f"--- AI-GENERATED CLINICAL NARRATIVE ---\nThe patient {name} is showing signs of recovery. Stable vitals. Continue current medication.\n---"

        # Slightly smarter logic for some cases
        if any("AKI" in d or "Kidney" in d for d in last_diag):
            summary = f"--- AI-GENERATED CLINICAL NARRATIVE ---\n{name} has Acute Kidney Injury. Creatinine is 2.4 (baseline 1.6). Monitoring fluid status.\n---"
        elif "Oncology" in str(patient_data) or "Cancer" in str(patient_data):
            summary = f"--- AI-GENERATED CLINICAL NARRATIVE ---\n{name} is undergoing chemo for Breast Cancer. Neutropenia noted (WBC 3.2). Chemo held.\n---"
        return summary


@pytest.fixture(scope="module")
def agent():
    """Provide the summarisation agent (mock or real) for this module.

    For the real agent, ``model_config.get_t4_generation_config`` is wrapped to
    cap ``max_new_tokens`` at 512; the original function is restored on
    teardown so the patch does not leak into other test modules.
    """
    if USE_MOCK_AGENT:
        yield _MockAgent()
        return

    ag = PatientSummarizerAgent()
    ag.configure_model(_AGENT_MODEL_NAME)

    # Fast config for testing
    from ai_med_extract.utils import model_config
    if not hasattr(model_config, 'get_t4_generation_config'):
        yield ag
        return

    original_get_config = model_config.get_t4_generation_config

    def fast_test_config(model_type):
        config = original_get_config(model_type)
        config['max_new_tokens'] = 512  # enough for a comprehensive summary
        return config

    model_config.get_t4_generation_config = fast_test_config
    try:
        yield ag
    finally:
        model_config.get_t4_generation_config = original_get_config


# --- CLINICAL REQUIREMENTS MAPPING ---
CLINICAL_REQUIREMENTS = {
    "Acute Kidney Injury Scenario": ["creatinine", "baseline", "renal"],
    "Oncology Treatment Cycle (Breast Cancer)": ["chemo", "neutropenia", "wbc", "held"],
    "Palliative Care (Stage IV Lung Cancer - Symptom Management)": ["palliative", "hospice", "comfort", "cancer"],
    "Hypertension & Diabetes Patient": ["glucose", "blood sugar", "metformin", "hypertension"],
    "Neurological Management (Early-Stage Alzheimer's)": ["alzheimer", "memory", "cognitive", "donepezil"],
}


# --- HELPERS ---
def extract_narrative(report_text):
    """Return the text between the narrative marker and the next ``---``.

    If the marker is absent, *report_text* is returned unchanged.
    """
    marker = "--- AI-GENERATED CLINICAL NARRATIVE ---"
    if marker in report_text:
        parts = report_text.split(marker)
        return parts[1].split("---")[0].strip()
    return report_text


def _join_items(value):
    """Render a list-like field as comma-separated text.

    ``None`` becomes an empty string and a bare string is returned as-is
    instead of being split into characters.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    return ', '.join(str(item) for item in value)


def get_context(data):
    """Build the retrieval context: one patient line plus one line per encounter.

    Missing keys are tolerated (rendered as ``None`` or an empty list) rather
    than raising.
    """
    res = data.get("result", {})
    context = [f"Patient: {res.get('patientname')}, PMH: {_join_items(res.get('past_medical_history'))}"]
    for enc in res.get("encounters") or []:
        context.append(
            f"Date: {enc.get('visit_date')}, Complaint: {enc.get('chief_complaint')}, "
            f"Diagnosis: {_join_items(enc.get('diagnosis'))}, Notes: {enc.get('dr_notes')}"
        )
    return context


# --- RESULTS COLLECTOR (File-based) ---
RESULTS_FILE = os.path.join(os.path.dirname(__file__), 'test_results.json')

# Handle results file clearing - ensure it's fresh for each session
if os.path.exists(RESULTS_FILE):
    try:
        os.remove(RESULTS_FILE)
    except OSError as exc:
        # Best effort: stale results may remain, but collection must not fail.
        logger.warning("Could not remove stale results file %s: %s", RESULTS_FILE, exc)


def _append_result(entry):
    """Append *entry* to the JSON list stored in ``RESULTS_FILE``."""
    results = []
    if os.path.exists(RESULTS_FILE):
        with open(RESULTS_FILE, 'r', encoding='utf-8') as f:
            results = json.load(f)
    results.append(entry)
    with open(RESULTS_FILE, 'w', encoding='utf-8') as f:
        json.dump(results, f)


# --- TESTS ---
@pytest.mark.timeout(1200)  # 20 minutes for all scenarios
@pytest.mark.parametrize("scenario", load_test_data())
@pytest.mark.skipif(not HAS_KEY, reason=SKIP_REASON)
def test_patient_summary_quality(agent, scenario):
    """Generate a summary for *scenario*, score it, record the result, and fail on a metric failure."""
    scenario_name = scenario['name']
    patient_data = scenario['data']

    print(f"\n--- Testing Scenario: {scenario_name} ---")
    print(f"Generating summary for {scenario_name}...")

    # 0. Clear global prompts for this scenario
    JUDGE_PROMPTS.clear()

    # 1. Generate
    full_report = agent.generate_patient_summary(patient_data)
    ai_output = extract_narrative(full_report)

    # 2. Define Test Case
    test_case = LLMTestCase(
        input="Generate a clinical summary for the patient.",
        actual_output=ai_output,
        retrieval_context=get_context(patient_data),
    )

    # 3. Metrics
    faithfulness = FaithfulnessMetric(threshold=0.7, model=eval_model, truths_extraction_limit=3)
    relevancy = AnswerRelevancyMetric(threshold=0.7, model=eval_model)

    # NEW: Clinical Accuracy (GEval)
    clinical_accuracy = GEval(
        name="Clinical Accuracy",
        model=eval_model,
        criteria="Evaluate if the clinical summary accurately captures the patient's stability vs instability. A summary is ACCURATE if it correctly identifies worsening trends (like rising creatinine or falling WBC) and avoids false 'recovery' claims for terminal or acute cases.",
        evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.RETRIEVAL_CONTEXT],
        threshold=0.8,
    )

    # 4. Measure (populates score/reason on these metric objects for the report)
    faithfulness.measure(test_case)
    relevancy.measure(test_case)
    clinical_accuracy.measure(test_case)

    # 5. Assert & Collect
    failure = None
    try:
        assert_test(test_case, [faithfulness, relevancy, clinical_accuracy])
        status = "PASSED"
    except Exception as e:
        # Clean up the error message for the report
        err_msg = str(e).split('failed.')[0].strip() if 'failed.' in str(e) else str(e)
        status = f"FAILED: {err_msg}"
        failure = e

    # Capture results
    _append_result({
        "scenario": scenario_name,
        "status": status,
        "faithfulness_score": faithfulness.score if faithfulness.score is not None else 0.0,
        "faithfulness_reason": faithfulness.reason,
        "relevancy_score": relevancy.score if relevancy.score is not None else 0.0,
        "relevancy_reason": relevancy.reason,
        "clinical_accuracy_score": clinical_accuracy.score if clinical_accuracy.score is not None else 0.0,
        "clinical_accuracy_reason": clinical_accuracy.reason,
        "output_preview": ai_output,
        "patient_json": json.dumps(patient_data, indent=2),
        "prompts": JUDGE_PROMPTS.copy(),
    })

    # Re-raise only after the result is persisted, so the report still covers
    # this scenario while pytest reports the failure.
    if failure is not None:
        raise failure


# --- REPORT GENERATION ---
def _md_cell(value):
    """Make *value* safe for a single Markdown table cell (one line, pipes escaped)."""
    return " ".join(str(value).splitlines()).replace("|", "\\|")


def finalize_report():
    """Render ``RESULTS_FILE`` into ``deepeval_test_report.md`` next to this file.

    Prints a warning and returns when no results file exists.
    """
    if not os.path.exists(RESULTS_FILE):
        print("\n[WARNING] No results file found.")
        return

    with open(RESULTS_FILE, 'r', encoding='utf-8') as f:
        results = json.load(f)

    # Explicit Model Info
    if USE_MOCK_AGENT:
        agent_model = "MockAgent (Clinical Logic Simulator)"
    else:
        agent_model = "microsoft/Phi-3-mini-4k-instruct-gguf"

    judge_model = eval_model.get_model_name() if eval_model else 'Default'
    if USE_MOCK:
        judge_model += " (Internal Clinical Audit Simulator)"

    report_path = os.path.join(os.path.dirname(__file__), 'deepeval_test_report.md')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("# DeepEval Comprehensive Patient Data Test Report\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

        f.write("### Model Configuration\n")
        f.write(f"- **Summarization Agent**: {agent_model}\n")
        f.write(f"- **Evaluation Judge**: {judge_model}\n")
        if USE_MOCK:
            f.write("> [!WARNING]\n> **MOCK MODE ACTIVE**: No API keys found. Scores are simulated for pipeline verification and clinical logic testing.\n\n")
        else:
            f.write("\n")

        f.write("| Scenario | Status | Faithfulness | Relevancy | Clinical Acc |\n")
        f.write("| --- | --- | --- | --- | --- |\n")
        for res in results:
            f_score = res.get('faithfulness_score') or 0.0
            r_score = res.get('relevancy_score') or 0.0
            c_score = res.get('clinical_accuracy_score') or 0.0
            f.write(f"| {_md_cell(res['scenario'])} | {_md_cell(res['status'])} | {f_score:.2f} | {r_score:.2f} | {c_score:.2f} |\n")

        f.write("\n## Detailed Findings\n")
        for res in results:
            f_score = res.get('faithfulness_score') or 0.0
            r_score = res.get('relevancy_score') or 0.0
            c_score = res.get('clinical_accuracy_score') or 0.0

            f.write(f"### {res['scenario']}\n")
            f.write(f"- **Faithfulness Score:** {f_score:.2f}\n")
            f.write(f"  - *Reason:* {res.get('faithfulness_reason') or 'N/A'}\n")
            f.write(f"- **Relevancy Score:** {r_score:.2f}\n")
            f.write(f"  - *Reason:* {res.get('relevancy_reason') or 'N/A'}\n")
            f.write(f"- **Clinical Accuracy Score:** {c_score:.2f}\n")
            f.write(f"  - *Reason:* {res.get('clinical_accuracy_reason') or 'N/A'}\n")

            f.write("\n#### AI Summary Output\n")
            f.write(f"```text\n{res['output_preview']}\n```\n")

            # Collapsible sections keep the large JSON and prompt payloads folded.
            f.write("\n<details>\n<summary>Patient Input Data (JSON)</summary>\n\n")
            f.write(f"```json\n{res['patient_json']}\n```\n")
            f.write("</details>\n\n")

            f.write("<details>\n<summary>Judge Evaluation Prompts</summary>\n\n")
            prompts = res.get('prompts', {})
            if prompts:
                for m_name, p_text in prompts.items():
                    f.write(f"**{m_name} Metric Prompt:**\n")
                    f.write(f"```text\n{p_text}\n```\n\n")
            else:
                f.write("No prompt captured.\n")
            f.write("</details>\n\n---\n\n")

    print(f"\n[SUCCESS] Comprehensive report generated: {report_path}")


# Final test to generate report
def test_generate_final_report():
    """Write the Markdown report from the results collected in this session."""
    finalize_report()