kcc-agri / pipelines /eval_chatbot.py
hritikm15's picture
Day 9 — v4 merge deploy: kcc_core + advisors + Proof tab + pest heatmap
49818d2 verified
#!/usr/bin/env python3
"""Quality eval harness — runs the chat pipeline on a fixed test set and
scores answers on four rubrics that don't need a human:
  1. Citation rate — answers that cite [1][2] when problem_type ∈ REQUIRES_CITATION
  2. Banned-chemical rate — answers that mention any banned chemical (lower is better)
  3. Latency p50/p95 — retrieval + generation
  4. Top-1 score — average rerank score of #1 doc
Run BEFORE and AFTER any retrieval/LLM change to catch regressions.
Usage:
    python -m pipelines.eval_chatbot --test eval/test_queries.json --out eval/results.json
"""
from __future__ import annotations
import argparse
import json
import statistics
import time
from pathlib import Path
from kcc_core import classify, citation_guard, config, prompt
from kcc_core import llm as llmmod
from kcc_core import retrieval as retr
from kcc_core.prompt import BANNED_CHEMICALS, REQUIRES_CITATION
def _p95(samples: list[float]) -> float | None:
    """Return the 95th percentile of *samples* rounded to 0.1, or None.

    ``statistics.quantiles(samples, n=20)`` returns 19 cut points; the last
    one is the 95th percentile. None is returned when fewer than 20 samples
    are available.
    """
    if len(samples) < 20:
        return None
    return round(statistics.quantiles(samples, n=20)[-1], 1)


def _summarize(rows: list[dict]) -> dict:
    """Aggregate per-query result rows into the summary metrics dict.

    Each row must carry the keys written by ``main``: ``needs_cite``,
    ``cited``, ``banned``, ``ret_ms``, ``gen_ms`` and ``top_score``.
    A metric that cannot be computed (no rows at all, no rows that require a
    citation, too few samples for a p95) is reported as None instead of
    raising.
    """
    if not rows:
        # Avoid ZeroDivisionError / statistics.StatisticsError on an empty run.
        return {
            "n": 0,
            "citation_rate": None,
            "banned_rate": None,
            "ret_ms_p50": None,
            "ret_ms_p95": None,
            "gen_ms_p50": None,
            "gen_ms_p95": None,
            "top_score_mean": None,
        }

    # Citation rate is measured only over queries whose problem type
    # requires a citation.
    citable = [r for r in rows if r["needs_cite"]]
    cite_rate = (
        sum(1 for r in citable if r["cited"]) / len(citable)
        if citable else None
    )
    banned_rate = sum(1 for r in rows if r["banned"]) / len(rows)
    ret_times = [r["ret_ms"] for r in rows]
    gen_times = [r["gen_ms"] for r in rows]
    top_scores = [r["top_score"] for r in rows]

    return {
        "n": len(rows),
        "citation_rate": round(cite_rate, 3) if cite_rate is not None else None,
        "banned_rate": round(banned_rate, 3),
        "ret_ms_p50": round(statistics.median(ret_times), 1),
        "ret_ms_p95": _p95(ret_times),
        "gen_ms_p50": round(statistics.median(gen_times), 1),
        "gen_ms_p95": _p95(gen_times),
        "top_score_mean": round(statistics.mean(top_scores), 3),
    }


def main(argv: list[str] | None = None) -> None:
    """Run every test query through the chat pipeline and write scored results.

    Reads a JSON list of test cases from ``--test`` (each entry needs a
    ``"query"`` key and may carry ``"state"`` / ``"district"``), runs
    retrieval, prompt building, generation and the citation guard for each
    one, then writes ``{"summary": ..., "rows": ...}`` as JSON to ``--out``
    and prints the summary.

    Args:
        argv: Command-line arguments to parse. None (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", default="eval/test_queries.json")
    parser.add_argument("--out", default="eval/results.json")
    args = parser.parse_args(argv)

    # Context manager so the test file is closed even if parsing fails.
    with open(args.test, "r", encoding="utf-8") as fh:
        tests = json.load(fh)
    print(f"[eval] loaded {len(tests)} test queries")

    retriever = retr.get_retriever()
    golden = retr.get_golden_retriever()

    rows = []
    for idx, t in enumerate(tests, start=1):
        q = t["query"]
        crop = classify.detect_crop(q)
        ptype = classify.classify_problem(q)
        normalized = classify.normalize_query(q)

        # Retrieval latency covers only the multi-step retrieve call.
        t0 = time.perf_counter()
        docs = retr.multi_step_retrieve(
            retriever, golden, q, normalized, crop, ptype,
            top_k=5,
            state=t.get("state", ""),
            district=t.get("district", ""),
            run_hyde=True,
        )
        ret_ms = (time.perf_counter() - t0) * 1000

        ctx = retr.KCCRetriever.format_context(docs)
        prm = prompt.build_prompt(
            q, ctx,
            problem_type=ptype,
            language=classify.detect_language(q),
            detected_crop=crop,
        )

        # Generation latency covers only the LLM call, not the citation guard.
        t1 = time.perf_counter()
        ans = llmmod.generate(prm, max_tokens=400, temperature=0.1)
        gen_ms = (time.perf_counter() - t1) * 1000
        ans, _w = citation_guard.review(ans, problem_type=ptype)

        # Prefer the rerank score of the top document; fall back to its raw
        # score when the rerank score is missing or zero, and to 0.0 when
        # nothing was retrieved.
        if docs:
            top_score = docs[0].rerank_score or docs[0].score or 0.0
        else:
            top_score = 0.0

        cited = bool(citation_guard.has_citations(ans))
        # Negation-aware: "do NOT use Endosulfan" is correct, not a leak.
        banned = citation_guard.banned_chemical_check(ans or "")

        rows.append({
            "query": q,
            "crop": crop,
            "problem_type": ptype,
            "top_score": round(top_score, 3),
            "ret_ms": round(ret_ms, 1),
            "gen_ms": round(gen_ms, 1),
            "cited": cited,
            "needs_cite": ptype in REQUIRES_CITATION,
            "banned": banned,
            "answer": ans,
        })
        print(f" {idx}/{len(tests)} {q[:60]}... "
              f"top={top_score:.2f} ret={ret_ms:.0f}ms cited={cited}")

    # Aggregate
    summary = _summarize(rows)
    out = {"summary": summary, "rows": rows}
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)
    print(f"\n[eval] summary: {json.dumps(summary, indent=2)}")
    print(f"[eval] wrote {args.out}")
if __name__ == "__main__":
    # Script entry point: hand main()'s return value to SystemExit so it
    # becomes the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)