kcc-agri / kcc_core /retrieval.py
hritikm15's picture
Day 9 — v4 merge deploy: kcc_core + advisors + Proof tab + pest heatmap
49818d2 verified
"""Retrieval engine: FAISS dense + BM25 sparse + cross-encoder rerank
+ GoldenRetriever (TF-IDF) + HyDE for low-confidence queries.
Wired together in `multi_step_retrieve`.
"""
from __future__ import annotations
import re
import sqlite3
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import faiss
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from sentence_transformers import SentenceTransformer
from . import config
@dataclass
class RetrievedDoc:
    """A single retrieved Q&A record together with its retrieval scores."""

    rank: int                  # position in the result list
    score: float               # retrieval score; rendered by ``score_pct``
    rerank_score: float = 0.0  # stays 0.0 unless a reranker assigns a value
    query: str = ""            # stored question text
    answer: str = ""           # stored answer text
    state: str = ""
    crop: str = ""
    year: int = 0
    source: str = "kcc"  # "kcc" | "golden" | "icar"

    @property
    def score_pct(self) -> str:
        """Return ``score`` scaled to a percentage with one decimal, e.g. ``"82.3%"``."""
        percent = self.score * 100
        return "{:.1f}%".format(percent)
# Lower-case phrases that mark an answer as generic "go ask someone else"
# boilerplate. Answers containing any of them are dropped at the FAISS layer;
# the phrases match the shape of the KCC corpus.
_GENERIC_SKIP = (
    # Referrals to a KVK (Krishi Vigyan Kendra)
    "contact your kvk",
    "contact nearest kvk",
    "visit kvk",
    "kvk se sampark karen",
    "nazdiki kvk",
    "apne kshetra ke",
    "krishi vigyan kendra se",
    # Referrals to the agriculture department / office
    "district agriculture office",
    "krishi vibhag se sampark",
    "contact agriculture department",
    # "No information" style answers
    "not available in our system",
    "uplabdh nahi hai",
    "jaankari nahi hai",
    "information not available",
    # Generic redirects
    "please contact",
    "please visit",
    "kindly contact",
    # Referrals to a bank
    "apne nazdiki bank",
    "bank se sampark karen",
)
class KCCRetriever:
    """FAISS + BM25 hybrid + cross-encoder reranker.

    Dense candidates come from a FAISS index whose row ``i`` corresponds to
    row ``i`` of the Parquet metadata table. If a BM25 (SQLite FTS5) index
    file exists its hits are merged in, and if a reranker model is configured
    the merged candidates are re-ordered by the cross-encoder.
    """

    # Every ASCII character that is not a letter, digit or underscore.
    # Outside a quoted string FTS5 accepts only those (plus non-ASCII
    # characters) in a bareword; anything else is a query syntax error.
    _FTS5_UNSAFE_RE = re.compile(r"[\x00-\x2f\x3a-\x40\x5b-\x5e\x60\x7b-\x7f]+")

    def __init__(self,
                 index_path: Path = config.FAISS_INDEX_FILE,
                 metadata_path: Path = config.METADATA_FILE,
                 model_name: str = config.EMBEDDING_MODEL,
                 top_k: int = config.TOP_K,
                 nprobe: int = 64,
                 device: Optional[str] = None):
        """Load the FAISS index, metadata, embedder and optional components.

        Args:
            index_path: FAISS index file.
            metadata_path: Parquet file with the QueryText / KccAns /
                StateName / Crop / year columns.
            model_name: SentenceTransformer model used to embed queries.
            top_k: Default number of results for ``search``.
            nprobe: Value assigned to the index's ``nprobe`` attribute.
            device: Torch device; defaults to "cuda" when available, else "cpu".

        Raises:
            FileNotFoundError: If the index or metadata file is missing.
            AssertionError: If the index and metadata row counts differ.
        """
        self._top_k = top_k
        if not index_path.exists():
            raise FileNotFoundError(f"FAISS index not found: {index_path}")
        if not metadata_path.exists():
            raise FileNotFoundError(f"Metadata not found: {metadata_path}")

        t0 = time.perf_counter()
        self._index = faiss.read_index(str(index_path))
        self._index.nprobe = nprobe
        print(f"[retr] FAISS loaded: {self._index.ntotal:,} vectors "
              f"({time.perf_counter()-t0:.1f}s)", flush=True)

        t0 = time.perf_counter()
        self._meta = pq.read_table(
            metadata_path,
            columns=["QueryText", "KccAns", "StateName", "Crop", "year"],
            memory_map=True,
        )
        self._col_query = self._meta.column("QueryText")
        self._col_ans = self._meta.column("KccAns")
        self._col_state = self._meta.column("StateName")
        self._col_crop = self._meta.column("Crop")
        self._col_year = self._meta.column("year")
        print(f"[retr] Metadata loaded: {self._meta.num_rows:,} rows "
              f"({time.perf_counter()-t0:.1f}s)", flush=True)

        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the exception type is unchanged for callers.
        if self._index.ntotal != self._meta.num_rows:
            raise AssertionError(
                f"MISMATCH: FAISS {self._index.ntotal} vs meta {self._meta.num_rows}. "
                "Re-run pipelines/reindex_*.py.")

        import torch
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        t0 = time.perf_counter()
        self._model = SentenceTransformer(model_name, device=device)
        # One throwaway encode so the first real query does not pay the
        # model's first-call cost.
        self._model.encode(["warmup"], convert_to_numpy=True, normalize_embeddings=True)
        print(f"[retr] Embedder ready on {device} ({time.perf_counter()-t0:.1f}s)",
              flush=True)

        # Optional cross-encoder; any load failure just disables reranking.
        self._reranker = None
        if config.RERANKER_MODEL:
            try:
                from sentence_transformers import CrossEncoder
                self._reranker = CrossEncoder(config.RERANKER_MODEL, device=device)
                print(f"[retr] Reranker ready: {config.RERANKER_MODEL}", flush=True)
            except Exception as e:
                print(f"[retr] Reranker disabled: {e}", flush=True)

        # Optional BM25 index, opened read-only; failure disables the hybrid merge.
        self._bm25_conn: Optional[sqlite3.Connection] = None
        if config.BM25_INDEX_FILE.exists():
            try:
                self._bm25_conn = sqlite3.connect(
                    f"file:{config.BM25_INDEX_FILE}?mode=ro", uri=True,
                    check_same_thread=False)
                self._bm25_conn.execute("PRAGMA cache_size = -32768")
                print("[retr] BM25 hybrid enabled", flush=True)
            except Exception as e:
                print(f"[retr] BM25 disabled: {e}", flush=True)

    @property
    def index_size(self) -> int:
        """Number of vectors in the FAISS index."""
        return self._index.ntotal

    def encode_query(self, query: str) -> np.ndarray:
        """Embed the stripped query as a normalized float32 array (one row)."""
        return self._model.encode(
            [query.strip()],
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        ).astype(np.float32)

    def search(self,
               query: str,
               top_k: Optional[int] = None,
               crop_filter: Optional[str] = None,
               state: str = "",
               district: str = "",
               min_score: float = 0.0) -> List[RetrievedDoc]:
        """Run the hybrid search and return ranked documents.

        Args:
            query: Query text to embed and (when BM25 is enabled) to match.
            top_k: Number of results wanted; falls back to the instance default.
            crop_filter: Preferred crop. Matching rows (case-insensitive
                substring of the row's crop) fill the result first;
                non-matching rows only top it up.
            state: Rows whose state equals this (case-insensitive) get a
                score bonus.
            district: Rows whose stored question mentions this get a score bonus.
            min_score: FAISS hits scoring below this are skipped.

        Returns:
            Docs with ``rank`` set to 1..n. When the reranker runs, the list
            is sorted by rerank score and truncated to ``k``. Otherwise it
            holds the FAISS picks followed by BM25 additions and may contain
            up to ``2 * k`` docs.
        """
        k = top_k or self._top_k
        q_vec = self.encode_query(query)
        rerank_n = config.RERANKER_TOP_N if self._reranker else 0
        search_k = max(k * 3, rerank_n + k)
        if crop_filter:
            # Over-fetch so enough rows of the requested crop can be found.
            search_k = max(search_k, k * 20)
        scores, indices = self._index.search(q_vec, search_k)

        crop_lower = crop_filter.lower() if crop_filter else None
        state_lower = state.strip().lower() if state else ""
        district_lower = district.strip().lower() if district else ""

        docs: List[RetrievedDoc] = []
        fallback: List[RetrievedDoc] = []
        seen: set = set()
        for s, i in zip(scores[0], indices[0]):
            # Once the primary list is full nothing below can change the
            # result (fallback is only used to top up a short list), so stop
            # decoding further rows.
            if len(docs) >= k:
                break
            if i < 0 or float(s) < min_score:
                continue
            row = int(i)
            ans = str(self._col_ans[row].as_py() or "").strip()
            if len(ans) < config.MIN_ANSWER_CHARS:
                continue
            ans_lower = ans.lower()
            if any(p in ans_lower for p in _GENERIC_SKIP):
                continue
            if ans in seen:
                continue
            seen.add(ans)

            query_text = str(self._col_query[row].as_py() or "").strip()
            crop_v = str(self._col_crop[row].as_py() or "").strip()
            state_v = str(self._col_state[row].as_py() or "").strip()
            year_v = int(self._col_year[row].as_py() or 0)

            score = float(s)
            # NOTE(review): the boosted score is only stored on the doc.
            # Candidates are admitted in raw FAISS order, so these boosts do
            # not influence which docs are selected here; the reranker (when
            # present) re-sorts by its own score.
            if year_v >= 2023:
                score += 0.15
            elif year_v >= 2021:
                score += 0.08
            elif year_v >= 2018:
                score += 0.03
            if state_lower and state_v.lower() == state_lower:
                score += 0.12
            if district_lower and district_lower in query_text.lower():
                score += 0.25

            doc = RetrievedDoc(
                rank=0, score=score, query=query_text,
                answer=ans, state=state_v, crop=crop_v, year=year_v, source="kcc")
            if crop_lower is None or crop_lower in crop_v.lower():
                docs.append(doc)
            elif len(fallback) < k:
                fallback.append(doc)

        if crop_filter and len(docs) < k:
            docs.extend(fallback[:k - len(docs)])

        # Hybrid merge with BM25: append unseen answers, capped at 2*k docs.
        if self._bm25_conn is not None:
            bm25 = self.bm25_search(query, top_n=max(k * 2, 20), crop_filter=crop_filter)
            existing = {d.answer for d in docs}
            for bd in bm25:
                if bd.answer not in existing and len(docs) < k * 2:
                    existing.add(bd.answer)
                    docs.append(bd)

        # Cross-encoder rerank
        if self._reranker and len(docs) > 1:
            try:
                pairs = [(query, d.answer) for d in docs]
                rerank_scores = self._reranker.predict(pairs)
                for d, rs in zip(docs, rerank_scores):
                    d.rerank_score = float(rs)
                docs.sort(key=lambda d: d.rerank_score, reverse=True)
                docs = docs[:k]
            except Exception as e:
                # Best-effort: keep the retrieval order, but say so instead
                # of failing silently.
                print(f"[retr] Rerank failed, keeping retrieval order: {e}", flush=True)

        for pos, d in enumerate(docs, start=1):
            d.rank = pos
        return docs

    @staticmethod
    def _sanitize_fts5(query: str) -> str:
        """Turn free text into a safe FTS5 MATCH expression.

        ASCII punctuation and control characters are replaced by spaces and
        every remaining token is wrapped in double quotes. Quoting keeps
        tokens such as ``AND`` / ``OR`` / ``NOT`` from being parsed as FTS5
        operators; removing punctuation (``?``, ``,``, ``-``, ``'`` ...)
        avoids FTS5 syntax errors. Non-ASCII text is left untouched.
        Returns "" when nothing searchable remains.
        """
        cleaned = KCCRetriever._FTS5_UNSAFE_RE.sub(" ", query)
        return " ".join(f'"{token}"' for token in cleaned.split())

    def bm25_search(self, query: str, top_n: int = 20,
                    crop_filter: Optional[str] = None) -> List[RetrievedDoc]:
        """Search the SQLite FTS5 index and return docs scored in [0.05, 1.0].

        Args:
            query: Free-text query; sanitized before being passed to MATCH.
            top_n: Maximum number of rows fetched.
            crop_filter: If given, restrict to rows whose ``crop`` column
                equals it; when that yields nothing the search is retried
                without the filter.

        Returns:
            Docs in index rank order with duplicate and too-short answers
            removed; an empty list when BM25 is disabled, the query is empty
            after sanitizing, or SQLite rejects the query.
        """
        if self._bm25_conn is None:
            return []
        safe_q = self._sanitize_fts5(query)
        if not safe_q:
            return []

        crop_clause = ""
        params: list = [safe_q]
        if crop_filter:
            crop_clause = " AND crop = ?"
            params.append(crop_filter)
        params.append(top_n)
        sql = f"""SELECT query_text, answer, state, crop, year, rank
                  FROM kcc_bm25 WHERE kcc_bm25 MATCH ?{crop_clause}
                  ORDER BY rank LIMIT ?"""
        try:
            rows = self._bm25_conn.cursor().execute(sql, params).fetchall()
        except sqlite3.OperationalError as e:
            print(f"[retr] BM25 query failed: {e}", flush=True)
            return []
        if not rows:
            return [] if crop_filter is None else self.bm25_search(query, top_n, None)

        # Min-max normalize the fetched ranks so the best row (smallest
        # rank, given ORDER BY rank) scores 1.0 and the worst scores 0.05.
        ranks = [r[5] for r in rows]
        best, worst = min(ranks), max(ranks)
        span = (worst - best) or 1.0

        out: List[RetrievedDoc] = []
        seen: set = set()
        for qt, ans, state, crop, year, rank in rows:
            ans = str(ans or "").strip()
            if len(ans) < config.MIN_ANSWER_CHARS or ans in seen:
                continue
            seen.add(ans)
            score = 0.05 + 0.95 * (1.0 - (rank - best) / span)
            out.append(RetrievedDoc(
                rank=len(out) + 1,  # contiguous even when rows are skipped
                score=round(score, 4),
                query=str(qt or "").strip(), answer=ans,
                state=str(state or ""), crop=str(crop or ""),
                year=int(year or 0), source="kcc"))
        return out

    @staticmethod
    def format_context(docs: List[RetrievedDoc], max_answer_chars: int = 600) -> str:
        """Render docs as a plain-text context block.

        Each doc becomes a ``[rank]`` entry with its score, state, crop and
        year, followed by the question and the answer truncated to
        ``max_answer_chars`` (with a trailing ellipsis when cut). Golden and
        ICAR docs get a ``·G`` / ``·I`` marker after the rank.
        """
        if not docs:
            return "No relevant past Q&A found."
        source_marks = {"golden": "·G", "icar": "·I"}
        lines = ["RELEVANT PAST FARMER Q&A FROM KCC DATABASE:"]
        for d in docs:
            ans = d.answer[:max_answer_chars]
            if len(d.answer) > max_answer_chars:
                ans += "…"
            tag = f"[{d.rank}{source_marks.get(d.source, '')}]"
            lines.append(
                f"\n{tag} Similarity: {d.score_pct} | State: {d.state} | Crop: {d.crop} | Year: {d.year}"
                f"\n Q: {d.query}"
                f"\n A: {ans}"
            )
        return "\n".join(lines)
# ── Golden retriever (verified Q&A pairs, TF-IDF) ─────────────────────────────
class KCCGoldenRetriever:
    """Tiny TF-IDF retriever over hand-curated KCC pairs.

    Intended to be consulted before FAISS. ``lookup`` returns only entries
    whose TF-IDF cosine similarity to the query is at least
    ``SIMILARITY_THRESHOLD``, and every returned doc has ``source="golden"``.
    """

    SIMILARITY_THRESHOLD = 0.85

    def __init__(self, golden_path: Path = config.GOLDEN_SET_FILE):
        """Load the golden set and fit a TF-IDF index over its queries.

        The file is JSON with an ``"entries"`` list; each entry needs a
        ``"query"`` key. If the file is missing, has no entries, or
        scikit-learn is not installed, the retriever stays usable but
        ``lookup`` always returns an empty list.
        """
        import json
        self._entries: list = []
        self._vectorizer = None
        self._tfidf = None
        self._cos = None  # cosine_similarity callable, set once sklearn imports

        if not golden_path.exists():
            print(f"[golden] not found at {golden_path}", flush=True)
            return
        with open(golden_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self._entries = data.get("entries", [])
        if not self._entries:
            return

        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
            from sklearn.metrics.pairwise import cosine_similarity
        except ImportError as e:
            # Still optional, but report it so a disabled golden set is not
            # a silent surprise.
            print(f"[golden] disabled, scikit-learn unavailable: {e}", flush=True)
            return

        self._cos = cosine_similarity
        queries = [e["query"] for e in self._entries]
        self._vectorizer = TfidfVectorizer(ngram_range=(1, 2),
                                           max_features=20000,
                                           sublinear_tf=True, min_df=1)
        self._tfidf = self._vectorizer.fit_transform(queries)
        print(f"[golden] {len(self._entries)} entries indexed", flush=True)

    @property
    def size(self) -> int:
        """Number of loaded golden entries (indexed or not)."""
        return len(self._entries)

    def lookup(self, query: str, top_k: int = 3) -> List[RetrievedDoc]:
        """Return up to ``top_k`` golden entries at or above the threshold.

        Results are ordered by descending similarity and ranked from 1; the
        list is empty when nothing is indexed or nothing is similar enough.
        """
        if self._vectorizer is None or self._tfidf is None or not self._entries:
            return []
        q_vec = self._vectorizer.transform([query])
        sims = self._cos(q_vec, self._tfidf).flatten()
        top_idx = sims.argsort()[::-1][:top_k]

        out: List[RetrievedDoc] = []
        for i, idx in enumerate(top_idx):
            sc = float(sims[idx])
            if sc < self.SIMILARITY_THRESHOLD:
                break  # candidates are in descending order; the rest are lower
            e = self._entries[idx]
            out.append(RetrievedDoc(
                rank=i + 1, score=sc, query=e.get("query", ""),
                answer=e.get("answer", ""), state=e.get("state", ""),
                crop=e.get("crop", ""), year=int(e.get("year", 0) or 0),
                source="golden"))
        return out
# ── Multi-step orchestrator ───────────────────────────────────────────────────
def _doc_confidence(doc: RetrievedDoc) -> float:
    """Return the doc's rerank score when non-zero, otherwise its retrieval score."""
    return doc.rerank_score if doc.rerank_score else doc.score


def multi_step_retrieve(retriever: KCCRetriever,
                        golden: Optional[KCCGoldenRetriever],
                        query: str,
                        normalized_query: str,
                        crop: Optional[str],
                        problem_type: str,
                        *,
                        top_k: int = 5,
                        state: str = "",
                        district: str = "",
                        run_hyde: bool = True) -> List[RetrievedDoc]:
    """Hybrid: golden lookup first, then FAISS+BM25+rerank, then HyDE if low-confidence.

    HyDE expands the query via ``expand_query`` and re-searches. It only
    fires when ``run_hyde`` is true and the confidence of the first doc is
    below ``config.HYDE_TRIGGER_THRESHOLD``.

    Args:
        retriever: Main retriever; its ``search`` is called with the
            rewritten query and, when HyDE fires, with the expanded query.
        golden: Optional golden retriever, looked up with the raw ``query``.
        query: Raw user query (used for golden lookup and HyDE expansion).
        normalized_query: Query used for the main search; ``crop`` and
            ``problem_type`` are appended when not already covered.
        crop: Optional crop, also passed on as ``crop_filter``.
        problem_type: Appended to the search text unless empty or "general".
        top_k: Maximum number of docs returned.
        state: Forwarded to ``retriever.search``.
        district: Forwarded to ``retriever.search``.
        run_hyde: Set to False to skip the HyDE step entirely.

    Returns:
        At most ``top_k`` docs, golden hits first, de-duplicated by answer
        text, with ``rank`` numbered from 1.
    """
    docs: List[RetrievedDoc] = []
    if golden is not None and golden.size > 0:
        # Golden hits are already threshold-filtered, so they go first.
        docs.extend(golden.lookup(query, top_k=2))
    n_golden = len(docs)

    rewritten = normalized_query
    if crop and crop.lower() not in rewritten.lower():
        rewritten = f"{rewritten} {crop}"
    if problem_type and problem_type != "general":
        rewritten = f"{rewritten} {problem_type}"

    base = retriever.search(rewritten, top_k=top_k, crop_filter=crop,
                            state=state, district=district)
    seen = {d.answer for d in docs}
    for d in base:
        if d.answer not in seen:
            d.rank = len(docs) + 1
            docs.append(d)
            seen.add(d.answer)

    top_score = _doc_confidence(docs[0]) if docs else 0.0
    needs_hyde = run_hyde and top_score < config.HYDE_TRIGGER_THRESHOLD
    if needs_hyde:
        from .hyde import expand_query
        hyde_query = expand_query(query, crop=crop, problem_type=problem_type)
        if hyde_query and hyde_query.strip() != query.strip():
            extra = retriever.search(hyde_query, top_k=top_k, crop_filter=crop,
                                     state=state, district=district)
            added = False
            for d in extra:
                if d.answer not in seen and len(docs) < top_k * 2:
                    docs.append(d)
                    seen.add(d.answer)
                    added = True
            if added:
                # HyDE hits used to be appended after the base hits and then
                # cut off by the final [:top_k], so they could never surface
                # once the base search had filled top_k. Let them compete:
                # keep golden hits in front and order everything else by
                # confidence (stable, so ties keep their original order).
                tail = sorted(docs[n_golden:], key=_doc_confidence, reverse=True)
                docs = docs[:n_golden] + tail
                for pos, d in enumerate(docs, start=1):
                    d.rank = pos
    return docs[:top_k]
# ── Module singletons ─────────────────────────────────────────────────────────
# Cached module-level instances. Both start as None and are created on first
# use by get_retriever() / get_golden_retriever().
_retriever: Optional[KCCRetriever] = None
_golden: Optional[KCCGoldenRetriever] = None
def get_retriever(**kwargs) -> KCCRetriever:
    """Return the shared ``KCCRetriever``, building it on the first call.

    ``kwargs`` are forwarded to the constructor only when the instance is
    created; on later calls they are ignored and the cached instance is
    returned as-is.
    """
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = KCCRetriever(**kwargs)
    return _retriever
def get_golden_retriever() -> KCCGoldenRetriever:
    """Return the shared ``KCCGoldenRetriever``, building it on the first call."""
    global _golden
    if _golden is not None:
        return _golden
    _golden = KCCGoldenRetriever()
    return _golden