# pharmaspine-backend/eval/runners/run_golden_memory_eval.py
# ashish1265659565's picture
# Upload folder using huggingface_hub
# 08fd094 verified
# Raw
# History Blame
# 8.65 kB
from __future__ import annotations
import argparse
import csv
import json
from collections import Counter, defaultdict
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
try:
from eval.runners.common_memory_client import get_memory_test_client
except ModuleNotFoundError:
from common_memory_client import get_memory_test_client
# Default CSV path read when --input is not supplied on the command line.
DEFAULT_INPUT = "output/golden_medical_qa.csv"
# Default path of the JSON summary written when --json-out is not supplied.
DEFAULT_JSON_OUT = "eval/dashboards/golden_memory_eval_summary.json"
# Default path of the Markdown summary written when --md-out is not supplied.
DEFAULT_MD_OUT = "eval/dashboards/golden_memory_eval_summary.md"
@dataclass
class EvalRowResult:
    """Per-row evaluation outcome for one golden question.

    Holds the expected and retrieved source IDs together with the label and
    audience checks; recall and precision are derived from the two source
    sets on demand.
    """

    row_id: str
    audience: str
    decision: str
    expected_sources: set[str]
    retrieved_sources: set[str]
    label_required: bool
    label_present: bool
    audience_match: bool

    @property
    def source_recall(self) -> float:
        """Fraction of expected sources that were retrieved (1.0 if none expected)."""
        wanted = self.expected_sources
        if wanted:
            hits = wanted & self.retrieved_sources
            return len(hits) / len(wanted)
        return 1.0

    @property
    def citation_precision(self) -> float:
        """Fraction of retrieved sources that were expected (0.0 if none retrieved)."""
        found = self.retrieved_sources
        if found:
            hits = self.expected_sources & found
            return len(hits) / len(found)
        return 0.0
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a golden evaluation run.

    Returns:
        Namespace with ``input``, ``json_out``, ``md_out`` and ``limit``.
    """
    cli = argparse.ArgumentParser(description="Run golden evaluation against the local memory API.")
    # The three path options share the same shape: a flag plus a default path.
    path_options = (
        ("--input", DEFAULT_INPUT),
        ("--json-out", DEFAULT_JSON_OUT),
        ("--md-out", DEFAULT_MD_OUT),
    )
    for flag, fallback in path_options:
        cli.add_argument(flag, default=fallback)
    cli.add_argument("--limit", type=int, default=0, help="Optional row limit for quicker local runs.")
    return cli.parse_args()
def normalize_therapy(value: str) -> str:
    """Map any therapy-area label containing "nsclc" (case-insensitive) to ``"NSCLC"``.

    Every other value is returned unchanged.
    """
    return "NSCLC" if "nsclc" in value.lower() else value
def normalize_geography(value: str) -> str:
    """Map any geography label containing "eu" (case-insensitive) to ``"EU / EMA"``.

    The check is a plain substring match; values without "eu" pass through
    unchanged.
    """
    folded = value.lower()
    if "eu" not in folded:
        return value
    return "EU / EMA"
def label_required(tags: str, notes_for_eval: str) -> bool:
    """Report whether a row's tags or notes mention a label-anchored topic.

    The two fields are joined with a space, lower-cased, and searched for any
    of the trigger phrases below.
    """
    haystack = "{} {}".format(tags, notes_for_eval).lower()
    for trigger in ("dose", "administration", "line-of-therapy", "approved eu boundaries"):
        if trigger in haystack:
            return True
    return False
def audience_match(audience: str, explanations: list[str]) -> bool:
    """Check that the explanations are appropriate for the given audience.

    Only the "patient" audience (case-insensitive) is constrained: the joined,
    lower-cased explanations must not contain "internal-only". All other
    audiences always match.
    """
    combined = " ".join(explanations).lower()
    is_patient = audience.lower() == "patient"
    return not (is_patient and "internal-only" in combined)
def _rounded_mean(values: list[float]) -> float:
    """Return the mean of ``values`` rounded to 4 decimals, or 0.0 when empty."""
    if not values:
        return 0.0
    return round(sum(values) / len(values), 4)


def _split_source_ids(raw: str) -> set[str]:
    """Split a ``;``-separated source list into a set of stripped, non-empty IDs."""
    # Whitespace around the separator ("LBL-1; PUB-2") must not become part of
    # an ID, otherwise it can never equal a retrieved source_id. A missing
    # cell (None) is treated as "no sources".
    tokens = (token.strip() for token in (raw or "").split(";"))
    return {token for token in tokens if token}


def evaluate_rows(rows: list[dict]) -> dict:
    """Run every golden row through ``/memory/search`` and aggregate the metrics.

    For each row a search request is POSTed via the memory test client; the
    cited source IDs are compared with the row's ``required_sources`` and the
    label/audience checks are recorded.

    Args:
        rows: Golden rows with the keys ``id``, ``question_text``, ``audience``,
            ``geography``, ``therapy_area``, ``required_sources``,
            ``evaluation_tags`` and ``notes_for_eval``.

    Returns:
        Summary dict with overall and per-audience recall/precision, decision
        counts, up to 50 row IDs that missed a required label, and up to 25
        sample failures.

    Raises:
        Whatever ``response.raise_for_status()`` raises for a failed request.
    """
    client = get_memory_test_client()
    results: list[EvalRowResult] = []
    recall_by_audience: dict[str, list[float]] = defaultdict(list)
    precision_by_audience: dict[str, list[float]] = defaultdict(list)
    decisions: Counter = Counter()
    missed_anchor_rows: list[str] = []

    for row in rows:
        payload = {
            "question": row["question_text"],
            "user_role": "Medical_Information_Specialist" if row["audience"] != "Internal" else "Internal_Medical_Reviewer",
            "audience": row["audience"],
            "geography": normalize_geography(row["geography"]),
            "therapy_area": normalize_therapy(row["therapy_area"]),
            "max_sources": 5,
            "min_evidence_score": 0.0,
        }
        response = client.post("/memory/search", json=payload)
        response.raise_for_status()
        body = response.json()

        expected_sources = _split_source_ids(row["required_sources"])
        retrieved_sources = {citation["source_id"] for citation in body["citations"]}
        requires_label = label_required(row["evaluation_tags"], row["notes_for_eval"])
        # A label counts as present when any cited source ID starts with "LBL-".
        label_present = any(source.startswith("LBL-") for source in retrieved_sources)

        result = EvalRowResult(
            row_id=row["id"],
            audience=row["audience"],
            decision=body["decision"],
            expected_sources=expected_sources,
            retrieved_sources=retrieved_sources,
            label_required=requires_label,
            label_present=label_present,
            audience_match=audience_match(row["audience"], body["explanations"]),
        )
        results.append(result)
        recall_by_audience[result.audience].append(result.source_recall)
        precision_by_audience[result.audience].append(result.citation_precision)
        decisions[result.decision] += 1
        if requires_label and not label_present:
            missed_anchor_rows.append(result.row_id)

    summary = {
        "generated_at": datetime.now(UTC).isoformat(),
        # NOTE: the dataset name is fixed here and does not follow --input.
        "dataset": "golden_medical_qa.csv",
        "rows_evaluated": len(results),
        "decision_counts": dict(decisions),
        "overall": {
            "source_recall_at_k": _rounded_mean([item.source_recall for item in results]),
            "citation_precision": _rounded_mean([item.citation_precision for item in results]),
            "audience_alignment_rate": _rounded_mean(
                [1.0 if item.audience_match else 0.0 for item in results]
            ),
            # A row passes when no label is required or a label was cited.
            "label_requirement_pass_rate": _rounded_mean(
                [1.0 if (not item.label_required) or item.label_present else 0.0 for item in results]
            ),
        },
        "by_audience": {
            audience: {
                "source_recall_at_k": _rounded_mean(recalls),
                "citation_precision": _rounded_mean(precision_by_audience[audience]),
            }
            for audience, recalls in recall_by_audience.items()
        },
        "risk_flags": {
            # Capped at 50 IDs to keep the summary small.
            "missed_label_anchor_rows": missed_anchor_rows[:50],
        },
        "sample_failures": [
            {
                "id": item.row_id,
                "decision": item.decision,
                "expected_sources": sorted(item.expected_sources),
                "retrieved_sources": sorted(item.retrieved_sources),
                "source_recall": round(item.source_recall, 4),
                "citation_precision": round(item.citation_precision, 4),
            }
            for item in results
            if item.source_recall < 0.5 or (item.label_required and not item.label_present)
        ][:25],
    }
    return summary
def write_markdown(summary: dict, path: Path) -> None:
    """Render ``summary`` as a Markdown report and write it to ``path``.

    Parent directories of ``path`` are created when missing. At most the first
    ten entries of ``summary["sample_failures"]`` are listed, and that section
    is omitted entirely when there are no failures.
    """
    overall = summary["overall"]
    report = [
        "# Golden Memory Eval Summary",
        "",
        f"- Generated at: `{summary['generated_at']}`",
        f"- Dataset: `{summary['dataset']}`",
        f"- Rows evaluated: `{summary['rows_evaluated']}`",
        "",
        "## Overall",
        "",
        f"- Source recall@k: `{overall['source_recall_at_k']}`",
        f"- Citation precision: `{overall['citation_precision']}`",
        f"- Audience alignment rate: `{overall['audience_alignment_rate']}`",
        f"- Label requirement pass rate: `{overall['label_requirement_pass_rate']}`",
        "",
        "## Decision Counts",
        "",
    ]
    report += [f"- `{key}`: `{value}`" for key, value in summary["decision_counts"].items()]

    report += ["", "## By Audience", ""]
    report += [
        f"- `{audience}` recall@k: `{metrics['source_recall_at_k']}`, precision: `{metrics['citation_precision']}`"
        for audience, metrics in summary["by_audience"].items()
    ]

    report += ["", "## Risk Flags", ""]
    report.append(f"- Missed label anchor rows: `{len(summary['risk_flags']['missed_label_anchor_rows'])}`")

    failures = summary["sample_failures"]
    if failures:
        report += ["", "## Sample Failures", ""]
        report += [
            f"- `{failure['id']}` decision=`{failure['decision']}` recall=`{failure['source_recall']}` precision=`{failure['citation_precision']}`"
            for failure in failures[:10]
        ]

    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(report) + "\n", encoding="utf-8")
def load_rows(path: Path, limit: int) -> list[dict]:
    """Read the CSV at ``path`` into a list of row dicts.

    Args:
        path: CSV file with a header row; opened as UTF-8.
        limit: Maximum number of data rows to return. Zero, ``None`` or any
            negative value means "no limit".

    Returns:
        One dict per data row, keyed by the CSV header names, in file order.
    """
    from itertools import islice

    # Only a strictly positive limit truncates; guarding on ``> 0`` keeps a
    # stray negative value from cutting the run down to a single row.
    stop = limit if limit and limit > 0 else None
    with path.open(newline="", encoding="utf-8") as handle:
        return list(islice(csv.DictReader(handle), stop))
def main() -> None:
    """Load the golden rows, evaluate them, and write the JSON and Markdown summaries."""
    args = parse_args()
    summary = evaluate_rows(load_rows(Path(args.input), args.limit))

    json_out, md_out = Path(args.json_out), Path(args.md_out)
    json_out.parent.mkdir(parents=True, exist_ok=True)
    json_out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    write_markdown(summary, md_out)

    for label, target in (("JSON", json_out), ("Markdown", md_out)):
        print(f"Wrote {label} summary to {target}")


if __name__ == "__main__":
    main()