from __future__ import annotations import argparse import csv import json from collections import Counter, defaultdict from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path try: from eval.runners.common_memory_client import get_memory_test_client except ModuleNotFoundError: from common_memory_client import get_memory_test_client DEFAULT_INPUT = "output/golden_medical_qa.csv" DEFAULT_JSON_OUT = "eval/dashboards/golden_memory_eval_summary.json" DEFAULT_MD_OUT = "eval/dashboards/golden_memory_eval_summary.md" @dataclass class EvalRowResult: row_id: str audience: str decision: str expected_sources: set[str] retrieved_sources: set[str] label_required: bool label_present: bool audience_match: bool @property def source_recall(self) -> float: if not self.expected_sources: return 1.0 return len(self.expected_sources & self.retrieved_sources) / len(self.expected_sources) @property def citation_precision(self) -> float: if not self.retrieved_sources: return 0.0 return len(self.expected_sources & self.retrieved_sources) / len(self.retrieved_sources) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run golden evaluation against the local memory API.") parser.add_argument("--input", default=DEFAULT_INPUT) parser.add_argument("--json-out", default=DEFAULT_JSON_OUT) parser.add_argument("--md-out", default=DEFAULT_MD_OUT) parser.add_argument("--limit", type=int, default=0, help="Optional row limit for quicker local runs.") return parser.parse_args() def normalize_therapy(value: str) -> str: lowered = value.lower() if "nsclc" in lowered: return "NSCLC" return value def normalize_geography(value: str) -> str: if "eu" in value.lower(): return "EU / EMA" return value def label_required(tags: str, notes_for_eval: str) -> bool: lowered = f"{tags} {notes_for_eval}".lower() return any(token in lowered for token in ["dose", "administration", "line-of-therapy", "approved eu boundaries"]) def audience_match(audience: str, explanations: list[str]) -> bool: text = " ".join(explanations).lower() if audience.lower() == "patient": return "internal-only" not in text return True def evaluate_rows(rows: list[dict]) -> dict: client = get_memory_test_client() results: list[EvalRowResult] = [] by_audience_recall: dict[str, list[float]] = defaultdict(list) by_audience_precision: dict[str, list[float]] = defaultdict(list) decisions = Counter() missed_anchor_rows: list[str] = [] for row in rows: payload = { "question": row["question_text"], "user_role": "Medical_Information_Specialist" if row["audience"] != "Internal" else "Internal_Medical_Reviewer", "audience": row["audience"], "geography": normalize_geography(row["geography"]), "therapy_area": normalize_therapy(row["therapy_area"]), "max_sources": 5, "min_evidence_score": 0.0, } response = client.post("/memory/search", json=payload) response.raise_for_status() body = response.json() expected_sources = set(filter(None, row["required_sources"].split(";"))) retrieved_sources = {citation["source_id"] for citation in body["citations"]} requires_label = label_required(row["evaluation_tags"], row["notes_for_eval"]) label_present = any(source.startswith("LBL-") for source in retrieved_sources) result = EvalRowResult( row_id=row["id"], audience=row["audience"], decision=body["decision"], expected_sources=expected_sources, retrieved_sources=retrieved_sources, label_required=requires_label, label_present=label_present, audience_match=audience_match(row["audience"], body["explanations"]), ) results.append(result) by_audience_recall[result.audience].append(result.source_recall) by_audience_precision[result.audience].append(result.citation_precision) decisions[result.decision] += 1 if requires_label and not label_present: missed_anchor_rows.append(result.row_id) total = len(results) or 1 summary = { "generated_at": datetime.now(UTC).isoformat(), "dataset": "golden_medical_qa.csv", "rows_evaluated": len(results), "decision_counts": dict(decisions), "overall": { "source_recall_at_k": round(sum(item.source_recall for item in results) / total, 4), "citation_precision": round(sum(item.citation_precision for item in results) / total, 4), "audience_alignment_rate": round(sum(1 for item in results if item.audience_match) / total, 4), "label_requirement_pass_rate": round( sum(1 for item in results if (not item.label_required) or item.label_present) / total, 4, ), }, "by_audience": { audience: { "source_recall_at_k": round(sum(values) / len(values), 4), "citation_precision": round(sum(by_audience_precision[audience]) / len(by_audience_precision[audience]), 4), } for audience, values in by_audience_recall.items() }, "risk_flags": { "missed_label_anchor_rows": missed_anchor_rows[:50], }, "sample_failures": [ { "id": item.row_id, "decision": item.decision, "expected_sources": sorted(item.expected_sources), "retrieved_sources": sorted(item.retrieved_sources), "source_recall": round(item.source_recall, 4), "citation_precision": round(item.citation_precision, 4), } for item in results if item.source_recall < 0.5 or (item.label_required and not item.label_present) ][:25], } return summary def write_markdown(summary: dict, path: Path) -> None: overall = summary["overall"] lines = [ "# Golden Memory Eval Summary", "", f"- Generated at: `{summary['generated_at']}`", f"- Dataset: `{summary['dataset']}`", f"- Rows evaluated: `{summary['rows_evaluated']}`", "", "## Overall", "", f"- Source recall@k: `{overall['source_recall_at_k']}`", f"- Citation precision: `{overall['citation_precision']}`", f"- Audience alignment rate: `{overall['audience_alignment_rate']}`", f"- Label requirement pass rate: `{overall['label_requirement_pass_rate']}`", "", "## Decision Counts", "", ] for key, value in summary["decision_counts"].items(): lines.append(f"- `{key}`: `{value}`") lines.extend(["", "## By Audience", ""]) for audience, metrics in summary["by_audience"].items(): lines.append(f"- `{audience}` recall@k: `{metrics['source_recall_at_k']}`, precision: `{metrics['citation_precision']}`") lines.extend(["", "## Risk Flags", ""]) lines.append(f"- Missed label anchor rows: `{len(summary['risk_flags']['missed_label_anchor_rows'])}`") if summary["sample_failures"]: lines.extend(["", "## Sample Failures", ""]) for failure in summary["sample_failures"][:10]: lines.append( f"- `{failure['id']}` decision=`{failure['decision']}` recall=`{failure['source_recall']}` precision=`{failure['citation_precision']}`" ) path.parent.mkdir(parents=True, exist_ok=True) path.write_text("\n".join(lines) + "\n", encoding="utf-8") def load_rows(path: Path, limit: int) -> list[dict]: rows: list[dict] = [] with path.open(newline="", encoding="utf-8") as handle: reader = csv.DictReader(handle) for index, row in enumerate(reader, start=1): rows.append(row) if limit and index >= limit: break return rows def main() -> None: args = parse_args() rows = load_rows(Path(args.input), args.limit) summary = evaluate_rows(rows) json_out = Path(args.json_out) md_out = Path(args.md_out) json_out.parent.mkdir(parents=True, exist_ok=True) json_out.write_text(json.dumps(summary, indent=2), encoding="utf-8") write_markdown(summary, md_out) print(f"Wrote JSON summary to {json_out}") print(f"Wrote Markdown summary to {md_out}") if __name__ == "__main__": main()