| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| from collections import Counter, defaultdict |
| from dataclasses import dataclass |
| from datetime import UTC, datetime |
| from pathlib import Path |
|
|
| try: |
| from eval.runners.common_memory_client import get_memory_test_client |
| except ModuleNotFoundError: |
| from common_memory_client import get_memory_test_client |
|
|
|
|
# Default value of the --input option: the golden question/answer CSV to evaluate.
DEFAULT_INPUT = "output/golden_medical_qa.csv"
# Default value of the --json-out option: where the JSON summary is written.
DEFAULT_JSON_OUT = "eval/dashboards/golden_memory_eval_summary.json"
# Default value of the --md-out option: where the Markdown summary is written.
DEFAULT_MD_OUT = "eval/dashboards/golden_memory_eval_summary.md"
|
|
|
|
@dataclass
class EvalRowResult:
    """Per-row outcome of one golden evaluation query.

    Holds the row identity, the decision returned by the search endpoint,
    the expected and retrieved source-id sets, and the label/audience checks
    computed for the row. Retrieval metrics are derived on demand from the
    two source sets.
    """

    row_id: str
    audience: str
    decision: str
    expected_sources: set[str]
    retrieved_sources: set[str]
    label_required: bool
    label_present: bool
    audience_match: bool

    @property
    def source_recall(self) -> float:
        """Share of expected sources that were retrieved.

        Returns 1.0 when the row expects no sources at all.
        """
        wanted = self.expected_sources
        if wanted:
            found = wanted.intersection(self.retrieved_sources)
            return len(found) / len(wanted)
        return 1.0

    @property
    def citation_precision(self) -> float:
        """Share of retrieved sources that were expected.

        Returns 0.0 when nothing was retrieved.
        """
        cited = self.retrieved_sources
        if cited:
            relevant = cited.intersection(self.expected_sources)
            return len(relevant) / len(cited)
        return 0.0
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the golden evaluation runner.

    Returns:
        Namespace exposing ``input``, ``json_out``, ``md_out`` (strings) and
        ``limit`` (int, 0 by default).
    """
    cli = argparse.ArgumentParser(description="Run golden evaluation against the local memory API.")
    # The three path options share the same shape: a flag and its default.
    path_options = (
        ("--input", DEFAULT_INPUT),
        ("--json-out", DEFAULT_JSON_OUT),
        ("--md-out", DEFAULT_MD_OUT),
    )
    for flag, fallback in path_options:
        cli.add_argument(flag, default=fallback)
    cli.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Optional row limit for quicker local runs.",
    )
    return cli.parse_args()
|
|
|
|
def normalize_therapy(value: str) -> str:
    """Collapse any therapy-area text mentioning NSCLC to the literal ``"NSCLC"``.

    The check is a case-insensitive substring match; values that do not
    mention it are returned unchanged.
    """
    return "NSCLC" if "nsclc" in value.lower() else value
|
|
|
|
def normalize_geography(value: str) -> str:
    """Map any geography text containing "eu" to the literal ``"EU / EMA"``.

    The match is a case-insensitive substring test, so any value that
    contains the letters "eu" (for example "Europe") is mapped as well.
    Other values are returned unchanged.
    """
    mentions_eu = "eu" in value.lower()
    return "EU / EMA" if mentions_eu else value
|
|
|
|
def label_required(tags: str, notes_for_eval: str) -> bool:
    """Tell whether a row's tags or notes contain one of the label marker phrases.

    Both inputs are combined and lower-cased, then searched for any of the
    marker phrases below.
    """
    haystack = f"{tags} {notes_for_eval}"
    haystack = haystack.lower()
    for marker in ("dose", "administration", "line-of-therapy", "approved eu boundaries"):
        if marker in haystack:
            return True
    return False
|
|
|
|
def audience_match(audience: str, explanations: list[str]) -> bool:
    """Check that the explanations are acceptable for the given audience.

    Only the patient audience (case-insensitive) is restricted: it fails when
    the combined, lower-cased explanations contain "internal-only". Every
    other audience always passes.
    """
    combined = " ".join(explanations).lower()
    is_patient = audience.lower() == "patient"
    return not (is_patient and "internal-only" in combined)
|
|
|
|
| def evaluate_rows(rows: list[dict]) -> dict: |
| client = get_memory_test_client() |
| results: list[EvalRowResult] = [] |
| by_audience_recall: dict[str, list[float]] = defaultdict(list) |
| by_audience_precision: dict[str, list[float]] = defaultdict(list) |
| decisions = Counter() |
| missed_anchor_rows: list[str] = [] |
|
|
| for row in rows: |
| payload = { |
| "question": row["question_text"], |
| "user_role": "Medical_Information_Specialist" if row["audience"] != "Internal" else "Internal_Medical_Reviewer", |
| "audience": row["audience"], |
| "geography": normalize_geography(row["geography"]), |
| "therapy_area": normalize_therapy(row["therapy_area"]), |
| "max_sources": 5, |
| "min_evidence_score": 0.0, |
| } |
| response = client.post("/memory/search", json=payload) |
| response.raise_for_status() |
| body = response.json() |
|
|
| expected_sources = set(filter(None, row["required_sources"].split(";"))) |
| retrieved_sources = {citation["source_id"] for citation in body["citations"]} |
| requires_label = label_required(row["evaluation_tags"], row["notes_for_eval"]) |
| label_present = any(source.startswith("LBL-") for source in retrieved_sources) |
|
|
| result = EvalRowResult( |
| row_id=row["id"], |
| audience=row["audience"], |
| decision=body["decision"], |
| expected_sources=expected_sources, |
| retrieved_sources=retrieved_sources, |
| label_required=requires_label, |
| label_present=label_present, |
| audience_match=audience_match(row["audience"], body["explanations"]), |
| ) |
| results.append(result) |
| by_audience_recall[result.audience].append(result.source_recall) |
| by_audience_precision[result.audience].append(result.citation_precision) |
| decisions[result.decision] += 1 |
| if requires_label and not label_present: |
| missed_anchor_rows.append(result.row_id) |
|
|
| total = len(results) or 1 |
| summary = { |
| "generated_at": datetime.now(UTC).isoformat(), |
| "dataset": "golden_medical_qa.csv", |
| "rows_evaluated": len(results), |
| "decision_counts": dict(decisions), |
| "overall": { |
| "source_recall_at_k": round(sum(item.source_recall for item in results) / total, 4), |
| "citation_precision": round(sum(item.citation_precision for item in results) / total, 4), |
| "audience_alignment_rate": round(sum(1 for item in results if item.audience_match) / total, 4), |
| "label_requirement_pass_rate": round( |
| sum(1 for item in results if (not item.label_required) or item.label_present) / total, |
| 4, |
| ), |
| }, |
| "by_audience": { |
| audience: { |
| "source_recall_at_k": round(sum(values) / len(values), 4), |
| "citation_precision": round(sum(by_audience_precision[audience]) / len(by_audience_precision[audience]), 4), |
| } |
| for audience, values in by_audience_recall.items() |
| }, |
| "risk_flags": { |
| "missed_label_anchor_rows": missed_anchor_rows[:50], |
| }, |
| "sample_failures": [ |
| { |
| "id": item.row_id, |
| "decision": item.decision, |
| "expected_sources": sorted(item.expected_sources), |
| "retrieved_sources": sorted(item.retrieved_sources), |
| "source_recall": round(item.source_recall, 4), |
| "citation_precision": round(item.citation_precision, 4), |
| } |
| for item in results |
| if item.source_recall < 0.5 or (item.label_required and not item.label_present) |
| ][:25], |
| } |
| return summary |
|
|
|
|
def write_markdown(summary: dict, path: Path) -> None:
    """Render the evaluation summary as a Markdown report and write it to ``path``.

    Args:
        summary: Summary dict with the keys ``generated_at``, ``dataset``,
            ``rows_evaluated``, ``overall``, ``decision_counts``,
            ``by_audience``, ``risk_flags`` and ``sample_failures``.
        path: Destination file; missing parent directories are created.

    Raises:
        KeyError: If ``summary`` lacks one of the keys read here.
    """
    overall = summary["overall"]
    risk_flags = summary["risk_flags"]
    # The row-id list may hold only a capped sample; when the summary carries
    # an explicit total, report that instead of the length of the sample.
    missed_total = risk_flags.get(
        "missed_label_anchor_row_count",
        len(risk_flags["missed_label_anchor_rows"]),
    )

    lines = [
        "# Golden Memory Eval Summary",
        "",
        f"- Generated at: `{summary['generated_at']}`",
        f"- Dataset: `{summary['dataset']}`",
        f"- Rows evaluated: `{summary['rows_evaluated']}`",
        "",
        "## Overall",
        "",
        f"- Source recall@k: `{overall['source_recall_at_k']}`",
        f"- Citation precision: `{overall['citation_precision']}`",
        f"- Audience alignment rate: `{overall['audience_alignment_rate']}`",
        f"- Label requirement pass rate: `{overall['label_requirement_pass_rate']}`",
        "",
        "## Decision Counts",
        "",
    ]
    lines.extend(f"- `{key}`: `{value}`" for key, value in summary["decision_counts"].items())

    lines.extend(["", "## By Audience", ""])
    lines.extend(
        f"- `{audience}` recall@k: `{metrics['source_recall_at_k']}`, precision: `{metrics['citation_precision']}`"
        for audience, metrics in summary["by_audience"].items()
    )

    lines.extend(["", "## Risk Flags", ""])
    lines.append(f"- Missed label anchor rows: `{missed_total}`")

    if summary["sample_failures"]:
        lines.extend(["", "## Sample Failures", ""])
        # Keep the report short: list at most the first ten failures.
        lines.extend(
            f"- `{failure['id']}` decision=`{failure['decision']}` recall=`{failure['source_recall']}` precision=`{failure['citation_precision']}`"
            for failure in summary["sample_failures"][:10]
        )

    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
|
|
|
def load_rows(path: Path, limit: int) -> list[dict]:
    """Read the golden CSV into a list of row dicts.

    Args:
        path: CSV file with a header row, read as UTF-8.
        limit: Maximum number of data rows to return. Zero or a negative
            value means "no limit".

    Returns:
        One dict per data row, keyed by the CSV header, in file order.
    """
    from itertools import islice  # local import keeps this block self-contained

    with path.open(newline="", encoding="utf-8") as handle:
        reader = csv.DictReader(handle)
        # A non-positive limit reads everything; previously a negative limit
        # stopped after the first row because `index >= limit` was already true.
        if limit and limit > 0:
            return list(islice(reader, limit))
        return list(reader)
|
|
|
|
def main() -> None:
    """Run the evaluation end to end: load rows, evaluate, write both reports.

    Reads the CSV named by ``--input`` (optionally capped by ``--limit``),
    evaluates it, then writes the summary as JSON to ``--json-out`` and as
    Markdown to ``--md-out``, printing both destinations.
    """
    args = parse_args()
    input_path = Path(args.input)
    summary = evaluate_rows(load_rows(input_path, args.limit))
    # Label the report with the file that was actually evaluated, so a run
    # with a custom --input is not reported under the default dataset name.
    summary["dataset"] = input_path.name

    json_out = Path(args.json_out)
    md_out = Path(args.md_out)
    json_out.parent.mkdir(parents=True, exist_ok=True)
    json_out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    write_markdown(summary, md_out)

    print(f"Wrote JSON summary to {json_out}")
    print(f"Wrote Markdown summary to {md_out}")


if __name__ == "__main__":
    main()
|
|