from __future__ import annotations import argparse import csv import json from datetime import UTC, datetime from pathlib import Path try: from eval.runners.common_retrieval_client import get_retrieval_test_client except ModuleNotFoundError: from common_retrieval_client import get_retrieval_test_client DEFAULT_INPUT = "output/retrieval_stress_cases.csv" DEFAULT_JSON_OUT = "eval/dashboards/retrieval_stress_eval_summary.json" DEFAULT_MD_OUT = "eval/dashboards/retrieval_stress_eval_summary.md" def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run retrieval stress evaluation against the local retrieval API.") parser.add_argument("--input", default=DEFAULT_INPUT) parser.add_argument("--json-out", default=DEFAULT_JSON_OUT) parser.add_argument("--md-out", default=DEFAULT_MD_OUT) parser.add_argument("--limit", type=int, default=0) return parser.parse_args() def load_rows(path: Path, limit: int) -> list[dict]: rows = [] with path.open(newline="", encoding="utf-8") as handle: reader = csv.DictReader(handle) for index, row in enumerate(reader, start=1): rows.append(row) if limit and index >= limit: break return rows def normalize_therapy(value: str) -> str: return "NSCLC" if "nsclc" in value.lower() else value def normalize_geography(value: str) -> str: return "EU / EMA" if "eu" in value.lower() else value def evaluate_rows(rows: list[dict]) -> dict: client = get_retrieval_test_client() recalls = [] precisions = [] negative_avoidance_hits = 0 failures = [] for row in rows: response = client.post( "/retrieval/search", json={ "question": row["query_text"], "user_role": "Medical_Information_Specialist", "audience": "HCP", "geography": normalize_geography(row["geography"]), "therapy_area": normalize_therapy(row["therapy_area"]), "max_sources": 5, }, ) response.raise_for_status() body = response.json() expected = set(filter(None, row["expected_relevant_sources"].split(";"))) negatives = set(filter(None, row["negative_sources_to_avoid"].split(";"))) retrieved = {candidate["source_id"] for candidate in body["candidates"]} recall = len(expected & retrieved) / len(expected) if expected else 1.0 precision = len(expected & retrieved) / len(retrieved) if retrieved else 0.0 avoids_negatives = len(negatives & retrieved) == 0 recalls.append(recall) precisions.append(precision) if avoids_negatives: negative_avoidance_hits += 1 if recall < 0.5 or not avoids_negatives: failures.append( { "id": row["id"], "challenge_type": row["retrieval_challenge_type"], "retrieved_sources": sorted(retrieved), "expected_sources": sorted(expected), "negative_hits": sorted(negatives & retrieved), "recall": round(recall, 4), "precision": round(precision, 4), } ) total = len(rows) or 1 return { "generated_at": datetime.now(UTC).isoformat(), "dataset": "retrieval_stress_cases.csv", "rows_evaluated": len(rows), "overall": { "source_recall_at_k": round(sum(recalls) / total, 4), "citation_precision": round(sum(precisions) / total, 4), "negative_source_avoidance_rate": round(negative_avoidance_hits / total, 4), }, "sample_failures": failures[:25], } def write_markdown(summary: dict, path: Path) -> None: lines = [ "# Retrieval Stress Eval Summary", "", f"- Generated at: `{summary['generated_at']}`", f"- Rows evaluated: `{summary['rows_evaluated']}`", "", "## Overall", "", f"- Source recall@k: `{summary['overall']['source_recall_at_k']}`", f"- Citation precision: `{summary['overall']['citation_precision']}`", f"- Negative source avoidance rate: `{summary['overall']['negative_source_avoidance_rate']}`", ] if summary["sample_failures"]: lines.extend(["", "## Sample Failures", ""]) for failure in summary["sample_failures"][:10]: lines.append( f"- `{failure['id']}` challenge=`{failure['challenge_type']}` recall=`{failure['recall']}` precision=`{failure['precision']}`" ) path.parent.mkdir(parents=True, exist_ok=True) path.write_text("\n".join(lines) + "\n", encoding="utf-8") def main() -> None: args = parse_args() rows = load_rows(Path(args.input), args.limit) summary = evaluate_rows(rows) json_out = Path(args.json_out) md_out = Path(args.md_out) json_out.parent.mkdir(parents=True, exist_ok=True) json_out.write_text(json.dumps(summary, indent=2), encoding="utf-8") write_markdown(summary, md_out) print(f"Wrote JSON summary to {json_out}") print(f"Wrote Markdown summary to {md_out}") if __name__ == "__main__": main()