"""
step3_retrieval.py
==================
Enterprise-grade retrieval engine for the KCC RAG chatbot.

Architecture
------------
1. Load FAISS index + metadata once at startup (cached in-process)
2. encode_query()   → embed user query → float32 unit vector
3. search()         → FAISS top-K → fetch matching metadata rows
4. format_context() → assemble retrieved Q&A pairs as LLM context

Design decisions
----------------
- Metadata loaded as PyArrow Table (memory-mapped, zero-copy slicing)
- FAISS IVF-PQ with nprobe=64 gives ~95% recall@5
- Query embedding normalised → inner-product == cosine similarity
- Duplicate-answer dedup (same KccAns text → keep highest-scoring copy)
- Language-agnostic: model handles Hindi, Telugu, Kannada, Marathi, English

Usage (standalone test)
-----------------------
    python step3_retrieval.py "What fertilizer for wheat in Uttar Pradesh?"
    python step3_retrieval.py --top-k 10 "kharif crops pest control"
"""
|
|
| import argparse |
| import re |
| import sqlite3 |
| import sys |
| import time |
| from pathlib import Path |
| from typing import List, Optional |
| from dataclasses import dataclass, field |
|
|
| import faiss |
| import numpy as np |
| import pandas as pd |
| import pyarrow as pa |
| import pyarrow.parquet as pq |
| from sentence_transformers import SentenceTransformer |
|
|
| |
# Optional ICAR retriever: the module is imported and its retriever is built
# at import time. Any failure (missing module, error while constructing the
# retriever) is reported on stdout and leaves the feature switched off.
# NOTE(review): this import runs before the sys.path adjustment further down
# the file - confirm icar_retriever is importable without it.
try:
    from icar_retriever import get_icar_retriever as _get_icar_retriever
    _ICAR_RETRIEVER = _get_icar_retriever()
except Exception as _e:
    print(f"[ICAR] Retriever not available: {_e}")
    _ICAR_RETRIEVER = None
    _ICAR_AVAILABLE = False
else:
    # Reached only when both the import and the factory call succeeded.
    _ICAR_AVAILABLE = True
|
|
| sys.path.insert(0, str(Path(__file__).parent)) |
| import config |
|
|
| |
|
|
@dataclass
class RetrievedDoc:
    """A single Q&A record handed back by the retrieval engine.

    Only ``rank`` and ``score`` are required; every other field has a
    neutral default so partially populated records can be built.
    """

    # 1-based position in the returned result list.
    rank: int
    # Retrieval score for this record.
    score: float
    # Score assigned by the optional reranking stage; 0.0 until one is set.
    rerank_score: float = 0.0
    # Stored question text of the record.
    query: str = ""
    # Stored answer text of the record.
    answer: str = ""
    # State name attached to the record.
    state: str = ""
    # Crop name attached to the record.
    crop: str = ""
    # Year attached to the record; 0 when unknown.
    year: int = 0

    @property
    def score_pct(self) -> str:
        """Return ``score`` as a percentage string with one decimal place."""
        percentage = self.score * 100
        return "{:.1f}%".format(percentage)
|
|
|
|
| |
|
|
class KCCRetriever:
    """
    Stateful retrieval engine. Load once, query many times.

    Construction loads, in order: the FAISS index, five metadata columns from
    the Parquet file, the sentence-embedding model, an optional cross-encoder
    reranker (when ``config.RERANKER_MODEL`` is set) and an optional read-only
    SQLite FTS5 index used for BM25 keyword search.

    Thread-safety: ``encode_query`` and ``search`` do not modify instance
    state after ``__init__``; the SQLite connection is opened with
    ``check_same_thread=False`` so it can be used from other threads.
    NOTE(review): whether the underlying model / FAISS / SQLite objects are
    themselves safe under concurrent calls is not established here - confirm.
    """

    # Lower-cased phrases that mark an answer as a generic referral
    # ("contact your KVK", "information not available", ...). Any answer
    # containing one of them is dropped from vector-search results.
    _GENERIC_SKIP = (
        "contact your kvk", "contact nearest kvk", "visit kvk",
        "kvk se sampark karen", "nazdiki kvk", "apne kshetra ke",
        "krishi vigyan kendra se", "district agriculture office",
        "krishi vibhag se sampark", "contact agriculture department",
        "not available in our system", "uplabdh nahi hai",
        "jaankari nahi hai", "information not available",
        "please contact", "please visit", "kindly contact",
        "apne nazdiki bank", "bank se sampark karen",
    )

    # Runs of characters that FTS5 does not accept inside a bareword: every
    # ASCII character except letters, digits and underscore. Non-ASCII
    # characters (including Indic letters and their combining vowel signs)
    # are deliberately left untouched.
    _FTS5_UNSAFE_RE = re.compile(r"[^0-9A-Za-z_\u0080-\U0010ffff]+")

    def __init__(
        self,
        index_path: Path = config.FAISS_INDEX_FILE,
        metadata_path: Path = config.METADATA_FILE,
        model_name: str = config.EMBEDDING_MODEL,
        top_k: int = config.TOP_K,
        nprobe: int = 64,
        device: Optional[str] = None,
    ):
        """
        Load every resource needed for retrieval.

        Parameters
        ----------
        index_path    : FAISS index file.
        metadata_path : Parquet file holding QueryText / KccAns / StateName /
                        Crop / year, row-aligned with the FAISS index.
        model_name    : SentenceTransformer model used to embed queries.
        top_k         : Default number of results returned by ``search``.
        nprobe        : Value assigned to the index's ``nprobe`` attribute.
        device        : Torch device for the embedding model; when None,
                        "cuda" is used if available, otherwise "cpu".

        Raises
        ------
        FileNotFoundError : index or metadata file does not exist.
        AssertionError    : index and metadata row counts differ.
        """
        self._top_k = top_k
        self._loaded = False

        # Fail fast with an actionable message before any heavy loading.
        if not index_path.exists():
            raise FileNotFoundError(
                f"FAISS index not found: {index_path}\n"
                "Run step2_embeddings.py first."
            )
        if not metadata_path.exists():
            raise FileNotFoundError(
                f"Metadata not found: {metadata_path}\n"
                "Run step2_embeddings.py first."
            )

        # --- FAISS index -------------------------------------------------
        print(f"[KCCRetriever] Loading FAISS index β¦", flush=True)
        t0 = time.perf_counter()
        self._index = faiss.read_index(str(index_path))
        self._index.nprobe = nprobe
        faiss_load_s = time.perf_counter() - t0
        print(
            f"[KCCRetriever] FAISS loaded: {self._index.ntotal:,} vectors "
            f"({faiss_load_s:.1f}s)",
            flush=True,
        )

        # --- Metadata (only the columns that are actually served) --------
        print(f"[KCCRetriever] Loading metadata β¦", flush=True)
        t0 = time.perf_counter()
        self._meta: pa.Table = pq.read_table(
            metadata_path,
            columns=["QueryText", "KccAns", "StateName", "Crop", "year"],
            memory_map=True,
        )
        # Keep direct column handles so per-hit lookups skip the name lookup.
        self._col_query = self._meta.column("QueryText")
        self._col_ans = self._meta.column("KccAns")
        self._col_state = self._meta.column("StateName")
        self._col_crop = self._meta.column("Crop")
        self._col_year = self._meta.column("year")
        meta_load_s = time.perf_counter() - t0
        print(
            f"[KCCRetriever] Metadata loaded: {self._meta.num_rows:,} rows "
            f"({meta_load_s:.1f}s)",
            flush=True,
        )

        # FAISS ids are used as metadata row numbers, so the two must line
        # up. Raised explicitly (not via `assert`) so the check still runs
        # under `python -O`; the exception type is unchanged for callers.
        if self._index.ntotal != self._meta.num_rows:
            raise AssertionError(
                f"MISMATCH: FAISS {self._index.ntotal:,} vs metadata {self._meta.num_rows:,}. "
                "Re-run step2_embeddings.py."
            )

        # --- Embedding model ----------------------------------------------
        import torch
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[KCCRetriever] Loading embedding model ({device}) β¦", flush=True)
        t0 = time.perf_counter()
        self._model = SentenceTransformer(model_name, device=device)
        model_load_s = time.perf_counter() - t0
        # One throw-away encode so the first real query does not pay any
        # first-call initialisation cost.
        self._model.encode(["warmup"], convert_to_numpy=True,
                           normalize_embeddings=True)
        print(
            f"[KCCRetriever] Model ready ({model_load_s:.1f}s)",
            flush=True,
        )

        # --- Optional cross-encoder reranker (best effort) ----------------
        self._reranker = None
        if config.RERANKER_MODEL:
            try:
                from sentence_transformers import CrossEncoder as _CE
                print(f"[KCCRetriever] Loading reranker: {config.RERANKER_MODEL} β¦",
                      flush=True)
                self._reranker = _CE(config.RERANKER_MODEL)
                print("[KCCRetriever] Reranker ready.", flush=True)
            except Exception as e:
                print(f"[KCCRetriever] Reranker unavailable: {e}", flush=True)

        # --- Optional BM25 (SQLite FTS5) index, opened read-only ----------
        self._bm25_conn: Optional[sqlite3.Connection] = None
        bm25_path = getattr(config, "BM25_INDEX_FILE", None)
        if bm25_path and Path(bm25_path).exists():
            try:
                self._bm25_conn = sqlite3.connect(
                    f"file:{bm25_path}?mode=ro", uri=True,
                    check_same_thread=False,
                )
                # Negative cache_size is expressed in KiB (SQLite pragma).
                self._bm25_conn.execute("PRAGMA cache_size = -32768")
                print(f"[KCCRetriever] BM25 index loaded (hybrid search enabled).", flush=True)
            except Exception as e:
                print(f"[KCCRetriever] BM25 unavailable: {e}", flush=True)
                self._bm25_conn = None
        else:
            print(
                "[KCCRetriever] BM25 index not found β "
                "run step2b_bm25_index.py to enable hybrid search.",
                flush=True,
            )

        self._loaded = True
        print("[KCCRetriever] Ready.", flush=True)

    @property
    def index_size(self) -> int:
        """Number of vectors stored in the FAISS index."""
        return self._index.ntotal

    def encode_query(self, query: str) -> np.ndarray:
        """
        Embed a single user query as a float32 array for FAISS.

        The query is stripped, encoded as a one-element batch with
        ``normalize_embeddings=True`` and cast to float32.
        """
        return self._model.encode(
            [query.strip()],
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        ).astype(np.float32)

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        deduplicate: bool = True,
        min_score: float = 0.0,
        crop_filter: Optional[str] = None,
        state: str = "",
        district: str = "",
    ) -> List[RetrievedDoc]:
        """
        Retrieve Q&A pairs for a user query (vector search, optional BM25
        merge, optional cross-encoder rerank).

        Parameters
        ----------
        query       : User's natural-language question.
        top_k       : Number of results wanted; falls back to the instance
                      default when None (or 0).
        deduplicate : Drop vector hits whose answer text was already seen.
        min_score   : Discard vector hits whose raw FAISS score is below
                      this value (applied before any boost).
        crop_filter : If set, hits whose Crop contains this string
                      (case-insensitive) are taken first; other hits are
                      used only to fill up to ``top_k``.
        state       : Optional farmer state. Adds +0.12 to the score of hits
                      whose StateName equals it (case-insensitive).
        district    : Optional farmer district. Adds +0.25 to the score of
                      hits whose stored QueryText contains it
                      (case-insensitive substring).

        Returns
        -------
        List of RetrievedDoc with ``rank`` numbered from 1. With a reranker
        the list is sorted by ``rerank_score`` and cut to ``top_k``.
        Without one, the order is: vector hits in FAISS order (crop matches
        first), then BM25-only extras, and the list may hold up to
        ``2 * top_k`` entries.

        NOTE(review): the recency/state/district boosts change each doc's
        ``score`` but the list is never re-sorted by that boosted score -
        confirm whether boosts are meant to influence ordering.
        """
        k = top_k or self._top_k

        q_vec = self.encode_query(query)

        # Over-fetch so that filtering, dedup and reranking still leave
        # enough candidates; a crop filter needs a much wider net.
        rerank_n = config.RERANKER_TOP_N if self._reranker else 0
        search_k = max(k * 3, rerank_n + k) if deduplicate else k
        if crop_filter:
            search_k = max(search_k, k * 20)

        scores, indices = self._index.search(q_vec, search_k)

        state_lower = state.strip().lower() if state else ""
        district_lower = district.strip().lower() if district else ""
        crop_lower = crop_filter.lower() if crop_filter else None

        # Negative ids are skipped (not valid row numbers).
        valid_hits = [
            (float(s), int(i))
            for s, i in zip(scores[0], indices[0])
            if i >= 0 and float(s) >= min_score
        ]

        docs: List[RetrievedDoc] = []
        fallback_docs: List[RetrievedDoc] = []
        seen_answers: set = set()

        for score, idx in valid_hits:
            # Once the primary list is full nothing collected afterwards can
            # reach the result (fallbacks are only used to top up a short
            # primary list), so stop scanning.
            if len(docs) >= k:
                break

            answer_raw = self._col_ans[idx].as_py()
            answer = str(answer_raw).strip() if answer_raw is not None else ""

            # Too short to be a useful answer.
            if len(answer) < config.MIN_ANSWER_CHARS:
                continue

            # Generic "please contact ..." style referrals carry no advice.
            ans_lower = answer.lower()
            if any(phrase in ans_lower for phrase in self._GENERIC_SKIP):
                continue

            if deduplicate and answer in seen_answers:
                continue
            seen_answers.add(answer)

            year_raw = self._col_year[idx].as_py()
            crop_val = str(self._col_crop[idx].as_py() or "").strip()
            doc_state = str(self._col_state[idx].as_py() or "").strip()
            doc_year = int(year_raw) if year_raw is not None else 0
            query_text = str(self._col_query[idx].as_py() or "")

            boosted_score = score
            # Recency boost: newer records score higher.
            if doc_year >= 2023:
                boosted_score += 0.15
            elif doc_year >= 2021:
                boosted_score += 0.08
            elif doc_year >= 2018:
                boosted_score += 0.03
            # Same-state boost.
            if state_lower and doc_state.lower() == state_lower:
                boosted_score += 0.12
            # District boost: matched against the stored question text.
            if district_lower and district_lower in query_text.lower():
                boosted_score += 0.25

            doc = RetrievedDoc(
                rank=0,  # assigned once the final order is known
                score=boosted_score,
                query=query_text.strip(),
                answer=answer,
                state=doc_state,
                crop=crop_val,
                year=doc_year,
            )

            if crop_lower:
                if crop_lower in crop_val.lower():
                    docs.append(doc)
                elif len(fallback_docs) < k:
                    fallback_docs.append(doc)
            else:
                docs.append(doc)

        # Not enough crop-matched hits: top up with the best other hits.
        if crop_filter and len(docs) < k:
            docs.extend(fallback_docs[:k - len(docs)])

        # Hybrid step: append BM25 hits whose answer is not already present,
        # growing the candidate pool to at most 2*k.
        if self._bm25_conn is not None:
            bm25_docs = self.bm25_search(
                query,
                top_n=max(k * 2, 20),
                crop_filter=crop_filter,
            )
            existing_answers = {d.answer for d in docs}
            for bd in bm25_docs:
                if bd.answer not in existing_answers and len(docs) < k * 2:
                    existing_answers.add(bd.answer)
                    docs.append(bd)

        # Cross-encoder rerank (best effort): on failure the retrieval
        # order is kept, but the failure is reported instead of hidden.
        if self._reranker and len(docs) > 1:
            try:
                pairs = [(query, doc.answer) for doc in docs]
                rerank_scores = self._reranker.predict(pairs)
                for doc, rs in zip(docs, rerank_scores):
                    doc.rerank_score = float(rs)
                docs.sort(key=lambda d: d.rerank_score, reverse=True)
                docs = docs[:k]
            except Exception as e:
                print(
                    f"[KCCRetriever] Reranking failed, keeping retrieval order: {e}",
                    flush=True,
                )

        for i, doc in enumerate(docs):
            doc.rank = i + 1

        return docs

    @staticmethod
    def _sanitize_fts5(query: str) -> str:
        """
        Turn free text into a safe SQLite FTS5 MATCH expression.

        Every ASCII character that is not a letter, digit or underscore is
        treated as a separator (FTS5 rejects such characters in unquoted
        terms, so e.g. a trailing "?" would make the whole query a syntax
        error). Non-ASCII text (Hindi and other Indian scripts) is kept
        as-is. Each remaining token is wrapped in double quotes so that
        words such as AND / OR / NOT / NEAR are matched literally instead of
        being parsed as operators. Tokens are joined with spaces, which
        FTS5 treats as an implicit AND.

        Returns an empty string when no searchable token remains.
        """
        tokens = KCCRetriever._FTS5_UNSAFE_RE.sub(" ", query).split()
        return " ".join(f'"{token}"' for token in tokens)

    def bm25_search(
        self,
        query: str,
        top_n: int = 20,
        crop_filter: Optional[str] = None,
    ) -> List[RetrievedDoc]:
        """
        BM25 keyword search via SQLite FTS5.

        Returns up to ``top_n`` results in FTS5 rank order, with ``rank``
        numbered contiguously from 1 and ``score`` scaled into [0.05, 1.0]
        relative to the best and worst hit of this result set.

        Returns an empty list if the BM25 index is unavailable, the query
        contains no searchable tokens, or SQLite reports an operational
        error. When ``crop_filter`` is given (exact match on the ``crop``
        column) and yields no rows, the search is retried without it.

        BM25 complements FAISS by catching:
        - Exact chemical/variety names (e.g. "Chlorpyrifos", "PB-1121")
        - Specific dose queries ("2ml/liter")
        - Uncommon tokens the embedding model distributes poorly
        """
        if self._bm25_conn is None:
            return []

        safe_q = self._sanitize_fts5(query)
        if not safe_q:
            return []

        # All user-supplied values travel as bound parameters.
        crop_clause = ""
        params: list = [safe_q]
        if crop_filter:
            crop_clause = " AND crop = ?"
            params.append(crop_filter)

        sql = f"""
            SELECT query_text, answer, state, crop, year, rank
            FROM kcc_bm25
            WHERE kcc_bm25 MATCH ?{crop_clause}
            ORDER BY rank
            LIMIT ?
        """
        params.append(top_n)

        try:
            cur = self._bm25_conn.cursor()
            rows = cur.execute(sql, params).fetchall()
        except sqlite3.OperationalError:
            # Deliberate graceful fallback: keyword search is optional.
            return []

        if not rows:
            # Nothing for this crop: retry once without the crop constraint.
            if crop_filter:
                return self.bm25_search(query, top_n=top_n, crop_filter=None)
            return []

        # Rows are ordered by `rank` ascending, so the smallest value is the
        # best hit. Map best -> 1.0 and worst -> 0.05; `or 1.0` avoids a
        # division by zero when all ranks are equal.
        raw_ranks = [row[5] for row in rows]
        best = min(raw_ranks)
        worst = max(raw_ranks)
        span = (worst - best) or 1.0

        docs: List[RetrievedDoc] = []
        seen_answers: set = set()
        for qt, ans, state, crop, year, rank in rows:
            ans = str(ans or "").strip()
            if len(ans) < config.MIN_ANSWER_CHARS:
                continue
            if ans in seen_answers:
                continue
            seen_answers.add(ans)

            score = 0.05 + 0.95 * (1.0 - (rank - best) / span)
            docs.append(RetrievedDoc(
                # Position among the kept rows, so skipped rows leave no gaps.
                rank=len(docs) + 1,
                score=round(score, 4),
                query=str(qt or "").strip(),
                answer=ans,
                state=str(state or "").strip(),
                crop=str(crop or "").strip(),
                year=int(year) if year else 0,
            ))

        return docs

    def format_context(self, docs: List[RetrievedDoc]) -> str:
        """
        Format retrieved docs as a context block for the LLM prompt.

        Each doc becomes a numbered entry with its score, state, crop, year,
        question and answer; answers longer than 600 characters are cut and
        marked with a trailing truncation marker. Returns a fixed
        "nothing found" sentence when ``docs`` is empty.
        """
        if not docs:
            return "No relevant past Q&A found."

        MAX_ANSWER_CHARS = 600

        lines = ["RELEVANT PAST FARMER Q&A FROM KCC DATABASE:"]
        for doc in docs:
            ans = doc.answer[:MAX_ANSWER_CHARS]
            if len(doc.answer) > MAX_ANSWER_CHARS:
                ans += "β¦"
            lines.append(
                f"\n[{doc.rank}] Similarity: {doc.score_pct} | "
                f"State: {doc.state} | Crop: {doc.crop} | Year: {doc.year}"
                f"\n Q: {doc.query}"
                f"\n A: {ans}"
            )
        return "\n".join(lines)
|
|
|
|
|
|
|
|
| |
|
|
class KCCGoldenRetriever:
    """
    Lookup over a JSON file of pre-verified ("golden") Q&A entries.

    Entry questions are indexed with a TF-IDF vectorizer at construction
    time; ``lookup`` returns the entries whose cosine similarity to the
    incoming query reaches SIMILARITY_THRESHOLD. If the file is missing,
    holds no entries, or scikit-learn cannot be imported, the instance stays
    usable and ``lookup`` simply returns nothing.
    """

    # Minimum cosine similarity for an entry to count as a match.
    SIMILARITY_THRESHOLD = 0.85

    def __init__(self, golden_path: str):
        """Read ``golden_path`` (JSON with an "entries" list) and index it."""
        import json

        self._entries: list = []
        self._vectorizer = None
        self._tfidf_matrix = None

        source = Path(golden_path)
        if not source.exists():
            print(f"[KCCGoldenRetriever] Golden set not found at {golden_path} β skipping", flush=True)
            return

        with open(source, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        self._entries = payload.get("entries", [])
        if not self._entries:
            print("[KCCGoldenRetriever] Empty golden set", flush=True)
            return

        # scikit-learn is optional; without it the retriever is disabled.
        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
            from sklearn.metrics.pairwise import cosine_similarity as _cos_sim
            self._cos_sim = _cos_sim
            golden_questions = [entry["query"] for entry in self._entries]
            self._vectorizer = TfidfVectorizer(
                ngram_range=(1, 2),
                max_features=20000,
                sublinear_tf=True,
                min_df=1,
            )
            self._tfidf_matrix = self._vectorizer.fit_transform(golden_questions)
            print(f"[KCCGoldenRetriever] Loaded {len(self._entries)} golden entries", flush=True)
        except ImportError:
            print("[KCCGoldenRetriever] sklearn not available β golden retriever disabled", flush=True)
            self._vectorizer = None

    def lookup(self, query: str, top_k: int = 3) -> list:
        """
        Return golden entries similar to ``query``.

        The ``top_k`` most similar entries are considered, best first; each
        one that reaches SIMILARITY_THRESHOLD is returned as a copy of the
        stored entry with an added "score" key. Any error during matching is
        printed and yields an empty list.
        """
        index_ready = self._vectorizer is not None and self._tfidf_matrix is not None
        if not index_ready or not self._entries:
            return []
        try:
            query_vec = self._vectorizer.transform([query])
            similarities = self._cos_sim(query_vec, self._tfidf_matrix).flatten()
            best_first = similarities.argsort()[::-1][:top_k]
            matches = []
            for position in best_first:
                similarity = float(similarities[position])
                if similarity >= self.SIMILARITY_THRESHOLD:
                    match = dict(self._entries[position])
                    match["score"] = similarity
                    matches.append(match)
            return matches
        except Exception as e:
            print(f"[KCCGoldenRetriever] lookup error: {e}", flush=True)
            return []

    @property
    def size(self) -> int:
        """Number of golden entries loaded."""
        return len(self._entries)
|
|
|
|
| |
|
|
# Module-level cache for the golden retriever; filled on first request.
_golden_retriever: Optional["KCCGoldenRetriever"] = None


def get_golden_retriever(golden_path: Optional[str] = None) -> "KCCGoldenRetriever":
    """Return the shared golden retriever, building it on the first call.

    ``golden_path`` is only consulted when the instance is first created;
    when omitted, ``kcc_golden_set.json`` next to this file is used. Later
    calls return the cached instance regardless of the argument.
    """
    global _golden_retriever
    if _golden_retriever is not None:
        return _golden_retriever

    if golden_path is None:
        golden_path = str(Path(__file__).parent / "kcc_golden_set.json")
    _golden_retriever = KCCGoldenRetriever(golden_path)
    return _golden_retriever
|
|
| |
|
|
# Module-level cache for the main retriever; filled on first request.
_retriever: Optional[KCCRetriever] = None


def get_retriever(**kwargs) -> KCCRetriever:
    """Return the shared KCCRetriever, constructing it on the first call.

    ``kwargs`` are forwarded to ``KCCRetriever`` only when the instance is
    created; on later calls they are ignored and the cached instance is
    returned.
    """
    global _retriever
    if _retriever is not None:
        return _retriever

    _retriever = KCCRetriever(**kwargs)
    return _retriever
|
|
|
|
| |
|
|
def _standalone_demo(query: str, top_k: int) -> None:
    """Run one query through the shared retriever and print the outcome.

    Prints a banner, the query and index size, the number of results with
    the elapsed search time in milliseconds, and the formatted context.
    """
    banner = "=" * 70
    print(banner)
    print("KCC RAG β Step 3: Retrieval Demo")
    print(banner)

    engine = get_retriever(top_k=top_k)

    print(f"\nQuery: {query!r}")
    print(f"Index size: {engine.index_size:,} vectors\n")

    # Time only the search call; model/index loading happened above.
    started = time.perf_counter()
    hits = engine.search(query, top_k=top_k)
    elapsed_ms = (time.perf_counter() - started) * 1000

    print(f"Retrieved {len(hits)} results in {elapsed_ms:.1f} ms\n")
    print(engine.format_context(hits))
    print(banner)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: one positional query plus an optional
    # --top-k override (defaults to config.TOP_K), handed to the demo.
    cli = argparse.ArgumentParser(description="KCC RAG β Step 3: Retrieval engine test")
    cli.add_argument("query", help="Query to search for")
    cli.add_argument("--top-k", type=int, default=config.TOP_K)
    options = cli.parse_args()
    _standalone_demo(options.query, options.top_k)
|
|