pharmaspine-backend / scripts /load_eval_pack_to_db.py
ashish1265659565's picture
Upload folder using huggingface_hub
08fd094 verified
Raw
History Blame Contribute Delete
7.07 kB
"""Load all 10,000 evaluation cases from output/*.csv into Postgres.
This is the project's evaluation pack: golden, adversarial, governance, retrieval-stress, and SME-feedback cases.
These rows are benchmark *questions*, not documents to chunk. They belong in
``eval_cases`` so runners and dashboards can query them at full scale.
Usage:
python3 scripts/load_eval_pack_to_db.py
python3 scripts/load_eval_pack_to_db.py --verify-only
"""
from __future__ import annotations
import argparse
import csv
import json
import os
import sys
from datetime import UTC, datetime
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parent.parent
OUTPUT_DIR = REPO_ROOT / "output"
_env_file = REPO_ROOT / ".env"
if _env_file.exists():
for raw_line in _env_file.read_text().splitlines():
line = raw_line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, value = line.split("=", 1)
os.environ.setdefault(key.strip(), value.strip())
import psycopg # noqa: E402
def _to_libpq_dsn(url: str) -> str:
    """Turn a SQLAlchemy-style ``postgresql+<driver>://`` URL into a libpq URL.

    libpq (and therefore ``psycopg.connect``) does not understand the
    ``+<driver>`` suffix, so any such scheme is rewritten to plain
    ``postgresql://``. Other strings are returned unchanged.
    """
    scheme, sep, rest = url.partition("://")
    if sep and scheme.startswith("postgresql+"):
        return "postgresql://" + rest
    return url


DSN = _to_libpq_dsn(
    os.getenv(
        "AKS_DATABASE_URL",
        "postgresql+psycopg://mobcoderid-296@localhost/ai_knowledge_spine",
    )
)

# (dataset label stored in eval_cases.dataset, CSV file name under OUTPUT_DIR)
DATASETS = [
    ("golden", "golden_medical_qa.csv"),
    ("adversarial", "adversarial_medical_qa.csv"),
    ("governance", "governance_policy_cases.csv"),
    ("retrieval", "retrieval_stress_cases.csv"),
    ("sme", "smr_sme_feedback_examples.csv"),
]

# DDL executed by load(); it is split on ";" there, one statement at a time.
CREATE_TABLE_SQL = (REPO_ROOT / "schemas" / "eval_cases.sql").read_text(encoding="utf-8")
def _normalize_row(dataset: str, row: dict) -> tuple[str, str, str | None, str | None, str | None, dict]:
case_id = row.get("id", "").strip()
if not case_id:
raise ValueError(f"Row missing id in {dataset}: {row}")
therapy = row.get("therapy_area") or row.get("user_geography")
geography = row.get("geography") or row.get("user_geography")
audience = row.get("audience")
return case_id, dataset, therapy, geography, audience, dict(row)
_UPSERT_SQL = """
    INSERT INTO eval_cases (
        case_id, dataset, therapy_area, geography, audience,
        payload, loaded_at
    ) VALUES (%s, %s, %s, %s, %s, %s::json, %s)
    ON CONFLICT (case_id) DO UPDATE SET
        dataset = EXCLUDED.dataset,
        therapy_area = EXCLUDED.therapy_area,
        geography = EXCLUDED.geography,
        audience = EXCLUDED.audience,
        payload = EXCLUDED.payload,
        loaded_at = EXCLUDED.loaded_at
"""


def _dataset_params(dataset: str, path: Path, now: datetime) -> list[tuple]:
    """Read one eval CSV and return one ``_UPSERT_SQL`` parameter tuple per row."""
    params: list[tuple] = []
    # utf-8-sig drops a leading BOM (common in spreadsheet exports). With plain
    # utf-8 the first header would read "\ufeffid" and every row would lack an id.
    with path.open(newline="", encoding="utf-8-sig") as fh:
        for row in csv.DictReader(fh):
            case_id, ds, therapy, geography, audience, payload = _normalize_row(
                dataset, row
            )
            params.append(
                (case_id, ds, therapy, geography, audience, json.dumps(payload), now)
            )
    return params


def load(*, replace: bool = True) -> dict:
    """Create ``eval_cases`` if needed and upsert every CSV in ``DATASETS``.

    Everything runs in a single transaction. It is committed on success and
    rolled back on any error, including a bad row.

    Args:
        replace: When true (default), ``TRUNCATE eval_cases`` before loading.

    Returns:
        ``{"datasets": {label: rows_read}, "total": rows_read,
        "duplicate_case_ids": n}``. ``n`` counts rows whose ``case_id`` was
        already loaded earlier in this run, so they overwrote that row.

    Raises:
        FileNotFoundError: If any dataset CSV is missing. This is checked
            before connecting to the database.
        ValueError: If a row has no id (see ``_normalize_row``).
    """
    paths = [(dataset, OUTPUT_DIR / filename) for dataset, filename in DATASETS]
    # Fail fast, before touching the database, if any input file is absent.
    for _, path in paths:
        if not path.exists():
            raise FileNotFoundError(f"Missing eval file: {path}")

    summary: dict = {"datasets": {}, "total": 0}
    seen: set[str] = set()
    duplicates = 0
    now = datetime.now(UTC)
    # psycopg 3: leaving the connection block commits (or rolls back on error)
    # and closes the connection.
    with psycopg.connect(DSN) as conn, conn.cursor() as cur:
        for statement in CREATE_TABLE_SQL.split(";"):
            stmt = statement.strip()
            if stmt:
                cur.execute(stmt)
        if replace:
            cur.execute("TRUNCATE eval_cases")
        for dataset, path in paths:
            params = _dataset_params(dataset, path, now)
            for case_id, *_ in params:
                if case_id in seen:
                    duplicates += 1
                else:
                    seen.add(case_id)
            if params:
                cur.executemany(_UPSERT_SQL, params)
            summary["datasets"][dataset] = len(params)
            summary["total"] += len(params)
    summary["duplicate_case_ids"] = duplicates
    return summary
def _table_exists(cur, name: str) -> bool:
    """Return True if *name* resolves to a relation on the current search_path."""
    cur.execute("SELECT to_regclass(%s) IS NOT NULL", (name,))
    return bool(cur.fetchone()[0])


def _print_dataset_counts(cur) -> None:
    """Print the total number of eval cases and a per-dataset breakdown."""
    cur.execute("SELECT COUNT(*) FROM eval_cases")
    print(f"eval_cases total: {cur.fetchone()[0]}")
    cur.execute(
        "SELECT dataset, COUNT(*) FROM eval_cases GROUP BY dataset ORDER BY dataset"
    )
    for dataset, count in cur.fetchall():
        print(f" {dataset}: {count}")


def _print_source_coverage(cur) -> None:
    """Report source ids referenced by eval cases but absent from ``sources``.

    References come from the ``;``-separated ``required_sources`` and
    ``expected_relevant_sources`` fields of each payload.
    """
    eval_sources: set[str] = set()
    cur.execute("SELECT payload FROM eval_cases")
    for (payload,) in cur.fetchall():
        # psycopg usually returns json columns already decoded; accept text too.
        data = payload if isinstance(payload, dict) else json.loads(payload)
        for key in ("required_sources", "expected_relevant_sources"):
            raw = data.get(key, "")
            if raw:
                eval_sources.update(s.strip() for s in raw.split(";") if s.strip())
    print(f"unique sources referenced in eval_cases: {len(eval_sources)}")

    if not _table_exists(cur, "sources"):
        print("sources table does not exist; cannot check source coverage.")
        return
    cur.execute("SELECT source_id FROM sources")
    db_sources = {row[0] for row in cur.fetchall()}
    print(f"sources registered in DB: {len(db_sources)}")

    missing = sorted(eval_sources - db_sources)
    if missing:
        shown = ", ".join(missing[:15])
        if len(missing) > 15:
            shown += f", ... (+{len(missing) - 15} more)"
        print(f"MISSING in sources table ({len(missing)}):", shown)
    else:
        print("All eval-referenced sources are registered in sources.")


def verify() -> None:
    """Print what is loaded in ``eval_cases`` and check source coverage.

    Read-only. Uses a single connection. If ``eval_cases`` does not exist,
    prints a hint and returns without the coverage check.
    """
    with psycopg.connect(DSN) as conn, conn.cursor() as cur:
        if not _table_exists(cur, "eval_cases"):
            print("eval_cases table does not exist. Run: python3 scripts/load_eval_pack_to_db.py")
            return
        _print_dataset_counts(cur)
        _print_source_coverage(cur)
def main() -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--verify-only", action="store_true")
args = parser.parse_args()
if args.verify_only:
verify()
return 0
summary = load()
print(json.dumps(summary, indent=2))
print("\n--- verification ---")
verify()
print("\nRun full golden eval (2500 rows, needs gateway/memory up):")
print(" python3 eval/runners/run_golden_memory_eval.py")
print("Or sample:")
print(" python3 eval/runners/run_golden_memory_eval.py --limit 100")
return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())