| """Retrieval engine: FAISS dense + BM25 sparse + cross-encoder rerank |
| + GoldenRetriever (TF-IDF) + HyDE for low-confidence queries. |
| |
| Wired together in `multi_step_retrieve`. |
| """ |
| from __future__ import annotations |
|
|
| import re |
| import sqlite3 |
| import time |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import faiss |
| import numpy as np |
| import pyarrow as pa |
| import pyarrow.parquet as pq |
| from sentence_transformers import SentenceTransformer |
|
|
| from . import config |
|
|
|
|
@dataclass
class RetrievedDoc:
    """A single retrieved question/answer pair plus its ranking metadata.

    The field order is part of the positional constructor signature.
    """

    rank: int                  # position in a result list (callers assign 1-based values)
    score: float               # retrieval score; rendered by ``score_pct``
    rerank_score: float = 0.0  # stays 0.0 unless a reranker fills it in
    query: str = ""            # stored question text
    answer: str = ""           # stored answer text
    state: str = ""
    crop: str = ""
    year: int = 0
    source: str = "kcc"        # origin tag of the document

    @property
    def score_pct(self) -> str:
        """Return ``score`` as a percentage string with one decimal place."""
        percent = self.score * 100
        return f"{percent:.1f}%"
|
|
|
|
| |
# Lower-case phrases that mark an answer as a generic referral or a
# "no information" reply. KCCRetriever drops any answer whose lower-cased
# text contains one of these substrings.
_GENERIC_SKIP = (
    "contact your kvk",
    "contact nearest kvk",
    "visit kvk",
    "kvk se sampark karen",
    "nazdiki kvk",
    "apne kshetra ke",
    "krishi vigyan kendra se",
    "district agriculture office",
    "krishi vibhag se sampark",
    "contact agriculture department",
    "not available in our system",
    "uplabdh nahi hai",
    "jaankari nahi hai",
    "information not available",
    "please contact",
    "please visit",
    "kindly contact",
    "apne nazdiki bank",
    "bank se sampark karen",
)
|
|
|
|
class KCCRetriever:
    """Hybrid retriever: FAISS dense search, optional BM25 merge, optional rerank.

    Construction loads the FAISS index, the Parquet metadata columns and the
    sentence-embedding model. The cross-encoder reranker and the SQLite FTS5
    BM25 index are optional: when either cannot be loaded the failure is
    printed and that stage is disabled instead of raising.
    """

    # ASCII characters that may not appear in an unquoted FTS5 query token
    # (only letters, digits and "_" may). Non-ASCII text is left untouched so
    # that non-Latin scripts keep their combining marks.
    _FTS5_UNSAFE = re.compile(r"[^0-9A-Za-z_\u0080-\U0010FFFF]+")

    # Suffix appended to the rank tag in ``format_context`` per document
    # source. "\u00b7" is U+00B7 MIDDLE DOT.
    _SOURCE_MARKS = {"golden": "\u00b7G", "icar": "\u00b7I"}

    def __init__(self,
                 index_path: Path = config.FAISS_INDEX_FILE,
                 metadata_path: Path = config.METADATA_FILE,
                 model_name: str = config.EMBEDDING_MODEL,
                 top_k: int = config.TOP_K,
                 nprobe: int = 64,
                 device: Optional[str] = None):
        """Load the index, metadata, embedder and the optional components.

        Args:
            index_path: FAISS index file.
            metadata_path: Parquet file holding one metadata row per vector.
            model_name: SentenceTransformer model used to embed queries.
            top_k: default number of results for ``search``.
            nprobe: value assigned to the index's ``nprobe`` attribute.
            device: torch device; picks "cuda" when available, else "cpu".

        Raises:
            FileNotFoundError: if the index or metadata file is missing.
            AssertionError: if the index and metadata row counts differ.
        """
        self._top_k = top_k

        if not index_path.exists():
            raise FileNotFoundError(f"FAISS index not found: {index_path}")
        if not metadata_path.exists():
            raise FileNotFoundError(f"Metadata not found: {metadata_path}")

        t0 = time.perf_counter()
        self._index = faiss.read_index(str(index_path))
        self._index.nprobe = nprobe
        print(f"[retr] FAISS loaded: {self._index.ntotal:,} vectors "
              f"({time.perf_counter()-t0:.1f}s)", flush=True)

        t0 = time.perf_counter()
        self._meta = pq.read_table(
            metadata_path,
            columns=["QueryText", "KccAns", "StateName", "Crop", "year"],
            memory_map=True,
        )
        self._col_query = self._meta.column("QueryText")
        self._col_ans = self._meta.column("KccAns")
        self._col_state = self._meta.column("StateName")
        self._col_crop = self._meta.column("Crop")
        self._col_year = self._meta.column("year")
        print(f"[retr] Metadata loaded: {self._meta.num_rows:,} rows "
              f"({time.perf_counter()-t0:.1f}s)", flush=True)

        # Explicit raise instead of ``assert`` so the check also runs under
        # ``python -O``; the exception type is kept for existing callers.
        if self._index.ntotal != self._meta.num_rows:
            raise AssertionError(
                f"MISMATCH: FAISS {self._index.ntotal} vs meta {self._meta.num_rows}. "
                "Re-run pipelines/reindex_*.py.")

        import torch
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        t0 = time.perf_counter()
        self._model = SentenceTransformer(model_name, device=device)
        # One throwaway encode so the first real query does not pay start-up cost.
        self._model.encode(["warmup"], convert_to_numpy=True, normalize_embeddings=True)
        print(f"[retr] Embedder ready on {device} ({time.perf_counter()-t0:.1f}s)",
              flush=True)

        # Optional cross-encoder reranker (best effort).
        self._reranker = None
        if config.RERANKER_MODEL:
            try:
                from sentence_transformers import CrossEncoder
                self._reranker = CrossEncoder(config.RERANKER_MODEL, device=device)
                print(f"[retr] Reranker ready: {config.RERANKER_MODEL}", flush=True)
            except Exception as e:
                print(f"[retr] Reranker disabled: {e}", flush=True)

        # Optional read-only BM25 index stored in SQLite (best effort).
        self._bm25_conn: Optional[sqlite3.Connection] = None
        if config.BM25_INDEX_FILE.exists():
            try:
                self._bm25_conn = sqlite3.connect(
                    f"file:{config.BM25_INDEX_FILE}?mode=ro", uri=True,
                    check_same_thread=False)
                self._bm25_conn.execute("PRAGMA cache_size = -32768")
                print("[retr] BM25 hybrid enabled", flush=True)
            except Exception as e:
                print(f"[retr] BM25 disabled: {e}", flush=True)

    @property
    def index_size(self) -> int:
        """Number of vectors in the FAISS index."""
        return self._index.ntotal

    def encode_query(self, query: str) -> np.ndarray:
        """Embed the stripped query as a normalized float32 array."""
        return self._model.encode(
            [query.strip()],
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        ).astype(np.float32)

    @staticmethod
    def _is_usable_answer(answer: str) -> bool:
        """Return False for answers that are too short or generic referrals."""
        if len(answer) < config.MIN_ANSWER_CHARS:
            return False
        lowered = answer.lower()
        return not any(phrase in lowered for phrase in _GENERIC_SKIP)

    def search(self,
               query: str,
               top_k: Optional[int] = None,
               crop_filter: Optional[str] = None,
               state: str = "",
               district: str = "",
               min_score: float = 0.0) -> List[RetrievedDoc]:
        """Run dense search, merge BM25 hits, then rerank when possible.

        Args:
            query: free-text query.
            top_k: number of results wanted; falls back to the constructor default.
            crop_filter: preferred crop (case-insensitive substring match on
                the dense side); non-matching docs are used only as filler.
            state: adds a score boost to docs whose state matches exactly
                (case-insensitive).
            district: adds a score boost to docs whose question text contains it.
            min_score: dense hits scoring below this are ignored.

        Returns:
            Docs with ``rank`` set to 1..n. When the reranker ran, at most
            ``top_k`` docs sorted by ``rerank_score``. Otherwise dense hits
            followed by BM25 hits, up to ``2 * top_k`` docs in total.
        """
        k = top_k or self._top_k
        docs = self._dense_search(query, k, crop_filter, state, district, min_score)
        docs = self._merge_bm25(docs, query, k, crop_filter)
        docs = self._rerank(query, docs, k)
        for position, doc in enumerate(docs, start=1):
            doc.rank = position
        return docs

    def _dense_search(self, query: str, k: int, crop_filter: Optional[str],
                      state: str, district: str,
                      min_score: float) -> List[RetrievedDoc]:
        """Return up to ``k`` filtered, de-duplicated, score-boosted FAISS hits."""
        q_vec = self.encode_query(query)

        # Over-fetch: filtering, de-duplication and the crop preference all
        # discard candidates.
        rerank_n = config.RERANKER_TOP_N if self._reranker else 0
        search_k = max(k * 3, rerank_n + k)
        if crop_filter:
            search_k = max(search_k, k * 20)
        sims, row_ids = self._index.search(q_vec, search_k)

        crop_lower = crop_filter.lower() if crop_filter else None
        state_lower = state.strip().lower() if state else ""
        district_lower = district.strip().lower() if district else ""

        docs: List[RetrievedDoc] = []
        fallback: List[RetrievedDoc] = []  # off-crop docs, used only as filler
        seen: set = set()

        for sim, row in zip(sims[0], row_ids[0]):
            # The fallback list only matters while ``docs`` is short, so a
            # full ``docs`` list means nothing later can change the result.
            if len(docs) >= k:
                break
            if row < 0 or float(sim) < min_score:
                continue
            row = int(row)

            ans = str(self._col_ans[row].as_py() or "").strip()
            if not self._is_usable_answer(ans) or ans in seen:
                continue
            seen.add(ans)

            query_text = str(self._col_query[row].as_py() or "")
            crop_v = str(self._col_crop[row].as_py() or "").strip()
            state_v = str(self._col_state[row].as_py() or "").strip()
            year_v = int(self._col_year[row].as_py() or 0)
            score = float(sim)

            # Additive boosts: recency, same state, district mentioned in the
            # stored question.
            if year_v >= 2023:
                score += 0.15
            elif year_v >= 2021:
                score += 0.08
            elif year_v >= 2018:
                score += 0.03
            if state_lower and state_v.lower() == state_lower:
                score += 0.12
            if district_lower and district_lower in query_text.lower():
                score += 0.25

            doc = RetrievedDoc(
                rank=0, score=score, query=query_text.strip(),
                answer=ans, state=state_v, crop=crop_v, year=year_v, source="kcc")

            if not crop_lower or crop_lower in crop_v.lower():
                docs.append(doc)
            elif len(fallback) < k:
                fallback.append(doc)

        if crop_filter and len(docs) < k:
            docs.extend(fallback[:k - len(docs)])
        return docs

    def _merge_bm25(self, docs: List[RetrievedDoc], query: str, k: int,
                    crop_filter: Optional[str]) -> List[RetrievedDoc]:
        """Append BM25 hits with unseen answers until ``2 * k`` docs are held."""
        if self._bm25_conn is None:
            return docs
        merged = list(docs)
        existing = {d.answer for d in merged}
        for hit in self.bm25_search(query, top_n=max(k * 2, 20), crop_filter=crop_filter):
            if len(merged) >= k * 2:
                break
            if hit.answer not in existing:
                existing.add(hit.answer)
                merged.append(hit)
        return merged

    def _rerank(self, query: str, docs: List[RetrievedDoc],
                k: int) -> List[RetrievedDoc]:
        """Sort by cross-encoder score and keep ``k``; best effort.

        If there is no reranker, fewer than two docs, or scoring fails, the
        input list is returned unchanged.
        """
        if not self._reranker or len(docs) <= 1:
            return docs
        try:
            raw = self._reranker.predict([(query, d.answer) for d in docs])
            rerank_scores = [float(value) for value in raw]
        except Exception as e:
            # Reranking is optional: report the failure, keep retrieval order.
            print(f"[retr] Rerank failed, keeping retrieval order: {e}", flush=True)
            return docs
        for doc, value in zip(docs, rerank_scores):
            doc.rerank_score = value
        return sorted(docs, key=lambda d: d.rerank_score, reverse=True)[:k]

    @staticmethod
    def _sanitize_fts5(query: str) -> str:
        """Reduce ``query`` to space-separated tokens that FTS5 can parse.

        Every run of ASCII characters other than letters, digits and "_" is
        replaced by a single space. Ordinary punctuation such as "?" or "-"
        would otherwise make the MATCH expression a syntax error.
        """
        cleaned = KCCRetriever._FTS5_UNSAFE.sub(" ", query)
        return " ".join(cleaned.split())

    def bm25_search(self, query: str, top_n: int = 20,
                    crop_filter: Optional[str] = None) -> List[RetrievedDoc]:
        """Query the FTS5 table ``kcc_bm25`` and return scored docs.

        Scores are the FTS5 ``rank`` values rescaled linearly to 0.05..1.0
        within the returned rows (best row scores 1.0). If a crop-filtered
        query yields no rows it is retried once without the filter. Returns
        an empty list when BM25 is disabled, the query sanitizes to nothing,
        or SQLite rejects the query.
        """
        if self._bm25_conn is None:
            return []
        safe_q = self._sanitize_fts5(query)
        if not safe_q:
            return []

        sql = ("SELECT query_text, answer, state, crop, year, rank "
               "FROM kcc_bm25 WHERE kcc_bm25 MATCH ?")
        params: list = [safe_q]
        if crop_filter:
            sql += " AND crop = ?"
            params.append(crop_filter)
        sql += " ORDER BY rank LIMIT ?"
        params.append(top_n)

        try:
            rows = self._bm25_conn.execute(sql, params).fetchall()
        except sqlite3.OperationalError as e:
            print(f"[retr] BM25 query failed: {e}", flush=True)
            return []
        if not rows:
            # Retry without the crop restriction, but only if one was applied.
            return self.bm25_search(query, top_n, None) if crop_filter else []

        # FTS5 rank: smaller is better, so the minimum is the best match.
        bm25_ranks = [row[5] for row in rows]
        best, worst = min(bm25_ranks), max(bm25_ranks)
        span = (worst - best) or 1.0

        out: List[RetrievedDoc] = []
        seen: set = set()
        for qt, ans, state, crop, year, bm25_rank in rows:
            ans = str(ans or "").strip()
            if not self._is_usable_answer(ans) or ans in seen:
                continue
            seen.add(ans)
            score = 0.05 + 0.95 * (1.0 - (bm25_rank - best) / span)
            out.append(RetrievedDoc(
                rank=len(out) + 1,  # contiguous even when rows are skipped
                score=round(score, 4),
                query=str(qt or "").strip(), answer=ans,
                state=str(state or ""), crop=str(crop or ""),
                year=int(year or 0), source="kcc"))
        return out

    @staticmethod
    def format_context(docs: List[RetrievedDoc], max_answer_chars: int = 600) -> str:
        """Render docs as a text block, truncating long answers.

        Answers longer than ``max_answer_chars`` are cut and suffixed with an
        ellipsis. The rank tag carries a source marker for "golden" and
        "icar" docs.
        """
        if not docs:
            return "No relevant past Q&A found."
        lines = ["RELEVANT PAST FARMER Q&A FROM KCC DATABASE:"]
        for d in docs:
            ans = d.answer[:max_answer_chars]
            if len(d.answer) > max_answer_chars:
                ans += "\u2026"  # U+2026 HORIZONTAL ELLIPSIS
            tag = f"[{d.rank}{KCCRetriever._SOURCE_MARKS.get(d.source, '')}]"
            lines.append(
                f"\n{tag} Similarity: {d.score_pct} | State: {d.state} | Crop: {d.crop} | Year: {d.year}"
                f"\n Q: {d.query}"
                f"\n A: {ans}"
            )
        return "\n".join(lines)
|
|
|
|
| |
|
|
class KCCGoldenRetriever:
    """Tiny TF-IDF retriever over hand-curated KCC question/answer pairs.

    ``lookup`` returns only entries whose stored question has a cosine
    similarity of at least ``SIMILARITY_THRESHOLD`` with the incoming query;
    returned docs carry ``source="golden"``.
    """
    SIMILARITY_THRESHOLD = 0.85

    def __init__(self, golden_path: Path = config.GOLDEN_SET_FILE):
        """Load the golden set from JSON and build the TF-IDF matrix.

        The file is expected to hold an object with an ``"entries"`` list of
        dicts. A missing file, an empty list, a missing scikit-learn install
        or an empty vocabulary all leave the retriever in a state where
        ``lookup`` returns ``[]``.
        """
        import json
        self._entries: list = []
        self._vectorizer = None
        self._tfidf = None
        self._cos = None
        if not golden_path.exists():
            print(f"[golden] not found at {golden_path}", flush=True)
            return
        with open(golden_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self._entries = data.get("entries", [])
        if not self._entries:
            return

        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
            from sklearn.metrics.pairwise import cosine_similarity
        except ImportError as e:
            print(f"[golden] disabled, scikit-learn unavailable: {e}", flush=True)
            return

        # ``.get`` keeps one malformed entry from aborting construction,
        # matching how ``lookup`` reads the entries.
        queries = [str(e.get("query") or "") for e in self._entries]
        vectorizer = TfidfVectorizer(ngram_range=(1, 2),
                                     max_features=20000,
                                     sublinear_tf=True, min_df=1)
        try:
            tfidf = vectorizer.fit_transform(queries)
        except ValueError as e:
            # Raised by scikit-learn when no query yields any term.
            print(f"[golden] disabled, could not index entries: {e}", flush=True)
            return
        self._cos = cosine_similarity
        self._vectorizer = vectorizer
        self._tfidf = tfidf
        print(f"[golden] {len(self._entries)} entries indexed", flush=True)

    @property
    def size(self) -> int:
        """Number of loaded entries (may be non-zero even if lookup is disabled)."""
        return len(self._entries)

    def lookup(self, query: str, top_k: int = 3) -> List[RetrievedDoc]:
        """Return up to ``top_k`` golden entries at or above the threshold.

        Results are ordered by descending similarity with ``rank`` 1..n and
        ``score`` set to the cosine similarity.
        """
        if self._vectorizer is None or self._tfidf is None or not self._entries:
            return []
        if top_k <= 0:
            return []
        q_vec = self._vectorizer.transform([query])
        sims = self._cos(q_vec, self._tfidf).flatten()
        top_idx = sims.argsort()[::-1][:top_k]
        out: List[RetrievedDoc] = []
        for idx in top_idx:
            sc = float(sims[idx])
            # Candidates are in descending order, so the first miss ends it.
            if sc < self.SIMILARITY_THRESHOLD:
                break
            e = self._entries[idx]
            out.append(RetrievedDoc(
                rank=len(out) + 1, score=sc, query=e.get("query", ""),
                answer=e.get("answer", ""), state=e.get("state", ""),
                crop=e.get("crop", ""), year=int(e.get("year", 0) or 0),
                source="golden"))
        return out
|
|
|
|
| |
|
|
def _doc_confidence(doc: RetrievedDoc) -> float:
    """Return the rerank score when it is set (non-zero), else the retrieval score."""
    return doc.rerank_score if doc.rerank_score else doc.score


def multi_step_retrieve(retriever: KCCRetriever,
                        golden: Optional[KCCGoldenRetriever],
                        query: str,
                        normalized_query: str,
                        crop: Optional[str],
                        problem_type: str,
                        *,
                        top_k: int = 5,
                        state: str = "",
                        district: str = "",
                        run_hyde: bool = True) -> List[RetrievedDoc]:
    """Golden lookup, then hybrid search, then HyDE when confidence is low.

    Steps:
      1. Up to two golden matches for the raw ``query`` come first.
      2. ``normalized_query`` is extended with ``crop`` (if not already
         present) and ``problem_type`` (unless "general") and searched.
         Docs whose answer is already present are skipped.
      3. If ``run_hyde`` is set and the first doc's confidence is below
         ``config.HYDE_TRIGGER_THRESHOLD``, an expanded query from
         ``hyde.expand_query`` is searched as well. When that adds new docs,
         all non-golden docs are re-ordered by confidence so the new docs can
         displace weak base hits. Appending them alone would let the final
         cut drop them whenever step 2 already filled ``top_k`` slots.

    Returns:
        At most ``top_k`` docs with ``rank`` set to 1..n, golden docs first.
    """
    docs: List[RetrievedDoc] = []
    if golden is not None and golden.size > 0:
        docs.extend(golden.lookup(query, top_k=2))
    seen = {d.answer for d in docs}

    rewritten = normalized_query
    if crop and crop.lower() not in rewritten.lower():
        rewritten = f"{rewritten} {crop}"
    if problem_type and problem_type != "general":
        rewritten = f"{rewritten} {problem_type}"

    base = retriever.search(rewritten, top_k=top_k, crop_filter=crop,
                            state=state, district=district)
    for d in base:
        if d.answer not in seen:
            seen.add(d.answer)
            docs.append(d)

    top_score = _doc_confidence(docs[0]) if docs else 0.0
    if run_hyde and top_score < config.HYDE_TRIGGER_THRESHOLD:
        from .hyde import expand_query
        hyde_query = expand_query(query, crop=crop, problem_type=problem_type)
        if hyde_query and hyde_query.strip() != query.strip():
            extra = retriever.search(hyde_query, top_k=top_k, crop_filter=crop,
                                     state=state, district=district)
            added = False
            for d in extra:
                if d.answer not in seen:
                    seen.add(d.answer)
                    docs.append(d)
                    added = True
            if added:
                # Golden docs stay pinned in front; the rest compete on confidence.
                pinned = [d for d in docs if d.source == "golden"]
                others = sorted((d for d in docs if d.source != "golden"),
                                key=_doc_confidence, reverse=True)
                docs = pinned + others

    docs = docs[:top_k]
    for position, d in enumerate(docs, start=1):
        d.rank = position
    return docs
|
|
|
|
| |
|
|
# Module-level singleton slots. Both start as None and are filled on first
# use by get_retriever() / get_golden_retriever().
_retriever: Optional[KCCRetriever] = None
_golden: Optional[KCCGoldenRetriever] = None
|
|
|
|
def get_retriever(**kwargs) -> KCCRetriever:
    """Return the shared KCCRetriever, building it on the first call.

    ``kwargs`` are forwarded to the constructor only when the instance is
    created; later calls return the existing instance and ignore them.
    """
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = KCCRetriever(**kwargs)
    return _retriever
|
|
|
|
def get_golden_retriever() -> KCCGoldenRetriever:
    """Return the shared KCCGoldenRetriever, building it on the first call."""
    global _golden
    if _golden is not None:
        return _golden
    _golden = KCCGoldenRetriever()
    return _golden
|
|