"""Retrieval engine: FAISS dense + BM25 sparse + cross-encoder rerank +
GoldenRetriever (TF-IDF) + HyDE for low-confidence queries.

Wired together in `multi_step_retrieve`.
"""
from __future__ import annotations

import re
import sqlite3
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import faiss
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from sentence_transformers import SentenceTransformer

from . import config


@dataclass
class RetrievedDoc:
    """One retrieved Q&A pair plus the scores attached by the retrieval stages."""

    rank: int                  # 1-based position in the list the doc was returned in
    score: float               # retrieval score (cosine + boosts, BM25-normalised, or TF-IDF)
    rerank_score: float = 0.0  # cross-encoder score; stays 0.0 when no rerank happened
    query: str = ""
    answer: str = ""
    state: str = ""
    crop: str = ""
    year: int = 0
    source: str = "kcc"  # "kcc" | "golden" | "icar"

    @property
    def score_pct(self) -> str:
        """Return `score` rendered as a percentage with one decimal place."""
        return f"{self.score * 100:.1f}%"


# Generic boilerplate to drop at FAISS layer; matches the KCC corpus shape.
_GENERIC_SKIP = (
    "contact your kvk", "contact nearest kvk", "visit kvk",
    "kvk se sampark karen", "nazdiki kvk", "apne kshetra ke",
    "krishi vigyan kendra se", "district agriculture office",
    "krishi vibhag se sampark", "contact agriculture department",
    "not available in our system", "uplabdh nahi hai", "jaankari nahi hai",
    "information not available", "please contact", "please visit",
    "kindly contact", "apne nazdiki bank", "bank se sampark karen",
)

# Any ASCII character that is not a letter, digit, underscore or whitespace is
# illegal inside an FTS5 bareword and makes MATCH raise a syntax error.
# Non-ASCII characters (e.g. Devanagari) are legal in barewords and are kept.
_FTS5_UNSAFE_RE = re.compile(r"[^0-9A-Za-z_\s\u0080-\U0010FFFF]+")

# FTS5 treats these as operators only when written in capitals; lower-casing
# them turns them back into ordinary search terms.
_FTS5_KEYWORDS = frozenset({"AND", "OR", "NOT", "NEAR"})


class KCCRetriever:
    """FAISS + BM25 hybrid + cross-encoder reranker."""

    def __init__(self,
                 index_path: Path = config.FAISS_INDEX_FILE,
                 metadata_path: Path = config.METADATA_FILE,
                 model_name: str = config.EMBEDDING_MODEL,
                 top_k: int = config.TOP_K,
                 nprobe: int = 64,
                 device: Optional[str] = None):
        """Load the FAISS index, Parquet metadata, embedder and optional extras.

        Args:
            index_path: FAISS index file to read.
            metadata_path: Parquet file holding the per-vector metadata columns.
            model_name: SentenceTransformer model used to embed queries.
            top_k: Default number of documents returned by `search`.
            nprobe: Value assigned to the loaded index's `nprobe` attribute.
            device: Torch device string; when None, "cuda" is used if
                available, otherwise "cpu".

        Raises:
            FileNotFoundError: If the index or metadata file does not exist.
            AssertionError: If the index and metadata row counts differ.
        """
        self._top_k = top_k
        if not index_path.exists():
            raise FileNotFoundError(f"FAISS index not found: {index_path}")
        if not metadata_path.exists():
            raise FileNotFoundError(f"Metadata not found: {metadata_path}")

        t0 = time.perf_counter()
        self._index = faiss.read_index(str(index_path))
        self._index.nprobe = nprobe
        print(f"[retr] FAISS loaded: {self._index.ntotal:,} vectors "
              f"({time.perf_counter()-t0:.1f}s)", flush=True)

        t0 = time.perf_counter()
        self._meta = pq.read_table(
            metadata_path,
            columns=["QueryText", "KccAns", "StateName", "Crop", "year"],
            memory_map=True,
        )
        self._col_query = self._meta.column("QueryText")
        self._col_ans = self._meta.column("KccAns")
        self._col_state = self._meta.column("StateName")
        self._col_crop = self._meta.column("Crop")
        self._col_year = self._meta.column("year")
        print(f"[retr] Metadata loaded: {self._meta.num_rows:,} rows "
              f"({time.perf_counter()-t0:.1f}s)", flush=True)

        # Explicit raise instead of `assert`: this check must survive `python -O`,
        # because FAISS row ids are used directly as metadata row numbers.
        if self._index.ntotal != self._meta.num_rows:
            raise AssertionError(
                f"MISMATCH: FAISS {self._index.ntotal} vs meta {self._meta.num_rows}. "
                "Re-run pipelines/reindex_*.py.")

        import torch
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        t0 = time.perf_counter()
        self._model = SentenceTransformer(model_name, device=device)
        # One throwaway encode so the first real query does not pay start-up cost.
        self._model.encode(["warmup"], convert_to_numpy=True,
                           normalize_embeddings=True)
        print(f"[retr] Embedder ready on {device} ({time.perf_counter()-t0:.1f}s)",
              flush=True)

        # The reranker is optional: any load failure just disables it.
        self._reranker = None
        if config.RERANKER_MODEL:
            try:
                from sentence_transformers import CrossEncoder
                self._reranker = CrossEncoder(config.RERANKER_MODEL, device=device)
                print(f"[retr] Reranker ready: {config.RERANKER_MODEL}", flush=True)
            except Exception as e:
                print(f"[retr] Reranker disabled: {e}", flush=True)

        # BM25 is optional too: enabled only when the SQLite file exists and opens.
        self._bm25_conn: Optional[sqlite3.Connection] = None
        if config.BM25_INDEX_FILE.exists():
            try:
                self._bm25_conn = sqlite3.connect(
                    f"file:{config.BM25_INDEX_FILE}?mode=ro",
                    uri=True, check_same_thread=False)
                # Negative cache_size is expressed in KiB (SQLite PRAGMA docs).
                self._bm25_conn.execute("PRAGMA cache_size = -32768")
                print("[retr] BM25 hybrid enabled", flush=True)
            except Exception as e:
                print(f"[retr] BM25 disabled: {e}", flush=True)

    @property
    def index_size(self) -> int:
        """Number of vectors in the loaded FAISS index."""
        return self._index.ntotal

    def encode_query(self, query: str) -> np.ndarray:
        """Embed the stripped query as a normalised float32 array (batch of one)."""
        return self._model.encode(
            [query.strip()],
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        ).astype(np.float32)

    def search(self,
               query: str,
               top_k: Optional[int] = None,
               crop_filter: Optional[str] = None,
               state: str = "",
               district: str = "",
               min_score: float = 0.0) -> List[RetrievedDoc]:
        """Run dense search, merge BM25 hits, then rerank when a reranker exists.

        Args:
            query: Query text to embed and, for BM25, to match.
            top_k: Number of FAISS documents to keep; falls back to the
                instance default when None or 0.
            crop_filter: Case-insensitive substring that the row's crop should
                contain; non-matching rows are used only as fill-in.
            state: Rows whose state equals this (case-insensitive) get a boost.
            district: Rows whose stored query text contains this get a boost.
            min_score: FAISS hits scoring below this are ignored.

        Returns:
            Documents with `rank` numbered from 1. After a successful rerank
            the list holds at most `k` items; otherwise it can hold up to
            `2 * k` because BM25 hits are appended after the dense hits.
        """
        k = top_k or self._top_k
        q_vec = self.encode_query(query)

        # Over-fetch so filtering and deduplication still leave enough rows.
        rerank_n = config.RERANKER_TOP_N if self._reranker else 0
        search_k = max(k * 3, rerank_n + k)
        if crop_filter:
            search_k = max(search_k, k * 20)
        scores, indices = self._index.search(q_vec, search_k)

        crop_lower = crop_filter.lower() if crop_filter else None
        state_lower = state.strip().lower() if state else ""
        district_lower = district.strip().lower() if district else ""

        docs: List[RetrievedDoc] = []
        fallback: List[RetrievedDoc] = []  # rows that fail the crop filter
        seen: set = set()
        for raw_score, row in zip(scores[0], indices[0]):
            # `fallback` is consulted only while `docs` is short, so once
            # `docs` is full nothing later in the scan can change the result.
            if len(docs) >= k:
                break
            if row < 0 or float(raw_score) < min_score:
                continue

            ans = str(self._col_ans[row].as_py() or "").strip()
            if len(ans) < config.MIN_ANSWER_CHARS:
                continue
            ans_lower = ans.lower()
            if any(p in ans_lower for p in _GENERIC_SKIP):
                continue
            if ans in seen:
                continue
            seen.add(ans)

            query_raw = str(self._col_query[row].as_py() or "")
            crop_v = str(self._col_crop[row].as_py() or "").strip()
            state_v = str(self._col_state[row].as_py() or "").strip()
            year_v = int(self._col_year[row].as_py() or 0)

            score = float(raw_score)
            # NOTE: boosts only affect *which docs reach the reranker*; reranker re-sorts.
            # NOTE(review): candidates are collected in raw FAISS order, so in
            # this method the boosts only change the reported `score` — confirm
            # whether a sort by boosted score was intended.
            if year_v >= 2023:
                score += 0.15
            elif year_v >= 2021:
                score += 0.08
            elif year_v >= 2018:
                score += 0.03
            if state_lower and state_v.lower() == state_lower:
                score += 0.12
            if district_lower and district_lower in query_raw.lower():
                score += 0.25

            doc = RetrievedDoc(
                rank=0, score=score,
                query=query_raw.strip(),
                answer=ans, state=state_v, crop=crop_v, year=year_v,
                source="kcc")

            if crop_lower:
                if crop_lower in crop_v.lower():
                    docs.append(doc)
                elif len(fallback) < k:
                    fallback.append(doc)
            else:
                docs.append(doc)

        # Top up with off-crop rows when the crop filter left us short.
        if crop_filter and len(docs) < k:
            docs.extend(fallback[:k - len(docs)])

        # Hybrid merge with BM25
        if self._bm25_conn is not None:
            bm25 = self.bm25_search(query, top_n=max(k * 2, 20),
                                    crop_filter=crop_filter)
            existing = {d.answer for d in docs}
            for bd in bm25:
                if bd.answer not in existing and len(docs) < k * 2:
                    existing.add(bd.answer)
                    docs.append(bd)

        # Cross-encoder rerank
        if self._reranker and len(docs) > 1:
            try:
                pairs = [(query, d.answer) for d in docs]
                rerank_scores = self._reranker.predict(pairs)
                for d, rs in zip(docs, rerank_scores):
                    d.rerank_score = float(rs)
                docs.sort(key=lambda d: d.rerank_score, reverse=True)
                docs = docs[:k]
            except Exception as e:
                # Best-effort: keep the retrieval (cosine) order, but say so.
                print(f"[retr] Rerank failed, keeping retrieval order: {e}",
                      flush=True)

        for pos, d in enumerate(docs):
            d.rank = pos + 1
        return docs

    @staticmethod
    def _sanitize_fts5(query: str) -> str:
        """Reduce free text to whitespace-separated FTS5 barewords.

        ASCII punctuation and control characters are replaced with spaces,
        because a character such as `?`, `-`, `.` or `'` outside a quoted
        string makes the MATCH expression a syntax error. Bare upper-case
        operator words are lower-cased so they are searched as plain terms.
        Non-ASCII text passes through unchanged.
        """
        cleaned = _FTS5_UNSAFE_RE.sub(" ", query)
        return " ".join(
            tok.lower() if tok in _FTS5_KEYWORDS else tok
            for tok in cleaned.split()
        )

    def bm25_search(self, query: str, top_n: int = 20,
                    crop_filter: Optional[str] = None) -> List[RetrievedDoc]:
        """Query the SQLite `kcc_bm25` table and return normalised hits.

        Args:
            query: Free text; sanitised before being used as the MATCH term.
            top_n: Maximum number of rows requested from SQLite.
            crop_filter: When given, rows must have exactly this `crop` value.
                If that yields nothing, the search is retried without it.

        Returns:
            Deduplicated documents with scores mapped into [0.05, 1.0], best
            row first. Empty when BM25 is disabled, the sanitised query is
            empty, or SQLite raises `OperationalError`.
        """
        if self._bm25_conn is None:
            return []
        safe_q = self._sanitize_fts5(query)
        if not safe_q:
            return []

        crop_clause = ""
        params: list = [safe_q]
        if crop_filter:
            crop_clause = " AND crop = ?"
            params.append(crop_filter)
        params.append(top_n)
        sql = f"""SELECT query_text, answer, state, crop, year, rank
                  FROM kcc_bm25
                  WHERE kcc_bm25 MATCH ?{crop_clause}
                  ORDER BY rank LIMIT ?"""
        try:
            rows = self._bm25_conn.cursor().execute(sql, params).fetchall()
        except sqlite3.OperationalError:
            return []
        if not rows:
            return [] if crop_filter is None else self.bm25_search(query, top_n, None)

        # Rows are ordered by ascending `rank`, so the smallest value is the
        # best match; rescale linearly so best -> 1.0 and worst -> 0.05.
        ranks = [r[5] for r in rows]
        best, worst = min(ranks), max(ranks)
        span = (worst - best) or 1.0  # avoid dividing by zero when all ranks tie

        out: List[RetrievedDoc] = []
        seen: set = set()
        for qt, ans, state, crop, year, rank in rows:
            ans = str(ans or "").strip()
            if len(ans) < config.MIN_ANSWER_CHARS or ans in seen:
                continue
            seen.add(ans)
            score = 0.05 + 0.95 * (1.0 - (rank - best) / span)
            out.append(RetrievedDoc(
                rank=len(out) + 1,  # contiguous even when rows are skipped
                score=round(score, 4),
                query=str(qt or "").strip(), answer=ans,
                state=str(state or ""), crop=str(crop or ""),
                year=int(year or 0), source="kcc"))
        return out

    @staticmethod
    def format_context(docs: List[RetrievedDoc], max_answer_chars: int = 600) -> str:
        """Render documents as a text block, truncating long answers.

        Answers longer than `max_answer_chars` are cut and suffixed with "…".
        Golden and ICAR documents get a `·G` / `·I` suffix inside the rank tag.
        """
        if not docs:
            return "No relevant past Q&A found."
        lines = ["RELEVANT PAST FARMER Q&A FROM KCC DATABASE:"]
        for d in docs:
            ans = d.answer[:max_answer_chars] + (
                "…" if len(d.answer) > max_answer_chars else "")
            if d.source == "golden":
                suffix = "·G"
            elif d.source == "icar":
                suffix = "·I"
            else:
                suffix = ""
            tag = f"[{d.rank}{suffix}]"
            lines.append(
                f"\n{tag} Similarity: {d.score_pct} | State: {d.state} | Crop: {d.crop} | Year: {d.year}"
                f"\n Q: {d.query}"
                f"\n A: {ans}"
            )
        return "\n".join(lines)


# ── Golden retriever (verified Q&A pairs, TF-IDF) ─────────────────────────────

class KCCGoldenRetriever:
    """Tiny TF-IDF retriever over hand-curated KCC pairs.

    Hits BEFORE FAISS — entries whose cosine similarity to the query is
    >= `SIMILARITY_THRESHOLD` are returned with `source="golden"`, which
    `KCCRetriever.format_context` renders with a `·G` tag.
    """

    SIMILARITY_THRESHOLD = 0.85

    def __init__(self, golden_path: Path = config.GOLDEN_SET_FILE):
        """Load the golden JSON file and fit a TF-IDF matrix over its queries.

        The retriever stays empty (and `lookup` returns nothing) when the file
        is missing, has no "entries", or scikit-learn cannot be imported.
        """
        import json

        self._entries: list = []
        self._vectorizer = None
        self._tfidf = None
        self._cos = None
        if not golden_path.exists():
            print(f"[golden] not found at {golden_path}", flush=True)
            return
        with open(golden_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self._entries = data.get("entries", [])
        if not self._entries:
            return
        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
            from sklearn.metrics.pairwise import cosine_similarity
            self._cos = cosine_similarity
            queries = [e["query"] for e in self._entries]
            self._vectorizer = TfidfVectorizer(ngram_range=(1, 2),
                                               max_features=20000,
                                               sublinear_tf=True, min_df=1)
            self._tfidf = self._vectorizer.fit_transform(queries)
            print(f"[golden] {len(self._entries)} entries indexed", flush=True)
        except ImportError:
            self._vectorizer = None

    @property
    def size(self) -> int:
        """Number of golden entries loaded."""
        return len(self._entries)

    def lookup(self, query: str, top_k: int = 3) -> List[RetrievedDoc]:
        """Return up to `top_k` golden entries at or above the threshold.

        Entries are visited in descending similarity, so the scan stops at the
        first one that falls below `SIMILARITY_THRESHOLD`.
        """
        if self._vectorizer is None or self._tfidf is None or not self._entries:
            return []
        q_vec = self._vectorizer.transform([query])
        sims = self._cos(q_vec, self._tfidf).flatten()
        top_idx = sims.argsort()[::-1][:top_k]
        out: List[RetrievedDoc] = []
        for i, idx in enumerate(top_idx):
            sc = float(sims[idx])
            if sc < self.SIMILARITY_THRESHOLD:
                break
            e = self._entries[idx]
            out.append(RetrievedDoc(
                rank=i + 1, score=sc,
                query=e.get("query", ""), answer=e.get("answer", ""),
                state=e.get("state", ""), crop=e.get("crop", ""),
                year=int(e.get("year", 0) or 0),
                source="golden"))
        return out


# ── Multi-step orchestrator ───────────────────────────────────────────────────

def multi_step_retrieve(retriever: KCCRetriever,
                        golden: Optional[KCCGoldenRetriever],
                        query: str,
                        normalized_query: str,
                        crop: Optional[str],
                        problem_type: str,
                        *,
                        top_k: int = 5,
                        state: str = "",
                        district: str = "",
                        run_hyde: bool = True) -> List[RetrievedDoc]:
    """Hybrid: golden lookup first, then FAISS+BM25+rerank, then HyDE if
    low-confidence.

    HyDE = generate hypothetical answer with the LLM, embed it, re-search.
    Only fires when top score < HYDE_TRIGGER_THRESHOLD and run_hyde=True.

    Args:
        retriever: Dense/sparse retriever used for the base and HyDE searches.
        golden: Optional golden-set retriever, queried with the raw `query`.
        query: Original user query (used for golden lookup and HyDE expansion).
        normalized_query: Cleaned query; the base search runs on this with
            `crop` and `problem_type` appended when not already implied.
        crop: Optional crop name, also passed on as the crop filter.
        problem_type: Appended to the search text unless empty or "general".
        top_k: Maximum number of documents returned.
        state: Passed through to `retriever.search`.
        district: Passed through to `retriever.search`.
        run_hyde: Set False to skip the HyDE step entirely.

    Returns:
        At most `top_k` documents, golden hits first, deduplicated by answer.
    """
    docs: List[RetrievedDoc] = []
    if golden is not None and golden.size > 0:
        gold = golden.lookup(query, top_k=2)
        # If we have a high-confidence golden hit, prepend it.
        docs.extend(gold)

    rewritten = normalized_query
    if crop and crop.lower() not in rewritten.lower():
        rewritten = f"{rewritten} {crop}"
    if problem_type and problem_type != "general":
        rewritten = f"{rewritten} {problem_type}"

    base = retriever.search(rewritten, top_k=top_k, crop_filter=crop,
                            state=state, district=district)
    seen = {d.answer for d in docs}
    for d in base:
        if d.answer not in seen:
            d.rank = len(docs) + 1
            docs.append(d)
            seen.add(d.answer)

    # Confidence of the leading doc: its rerank score when non-zero, otherwise
    # its retrieval score; 0.0 when nothing was found.
    top_score = docs[0].rerank_score if docs and docs[0].rerank_score else (
        docs[0].score if docs else 0.0)
    needs_hyde = run_hyde and top_score < config.HYDE_TRIGGER_THRESHOLD

    if needs_hyde:
        from .hyde import expand_query
        hyde_query = expand_query(query, crop=crop, problem_type=problem_type)
        if hyde_query and hyde_query.strip() != query.strip():
            extra = retriever.search(hyde_query, top_k=top_k, crop_filter=crop,
                                     state=state, district=district)
            # NOTE(review): HyDE hits are appended after the base results and
            # the list is cut to `top_k` below, so they only survive when the
            # earlier steps produced fewer than `top_k` docs — confirm intent.
            for d in extra:
                if d.answer not in seen and len(docs) < top_k * 2:
                    d.rank = len(docs) + 1
                    docs.append(d)
                    seen.add(d.answer)

    return docs[:top_k]


# ── Module singletons ─────────────────────────────────────────────────────────

_retriever: Optional[KCCRetriever] = None
_golden: Optional[KCCGoldenRetriever] = None


def get_retriever(**kwargs) -> KCCRetriever:
    """Return the shared `KCCRetriever`, building it on first use.

    `kwargs` are forwarded to the constructor only on that first call.
    """
    global _retriever
    if _retriever is None:
        _retriever = KCCRetriever(**kwargs)
    return _retriever


def get_golden_retriever() -> KCCGoldenRetriever:
    """Return the shared `KCCGoldenRetriever`, building it on first use."""
    global _golden
    if _golden is None:
        _golden = KCCGoldenRetriever()
    return _golden