| """ |
| api.py |
| ====== |
| FastAPI REST wrapper around the KCC RAG pipeline. |
| |
| Endpoints |
| --------- |
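| POST /auth/token — issue a bearer token for B2B / enterprise access |
| GET /auth/me — details of the authenticated B2B user |
| POST /query — main chat endpoint (JSON in, JSON out) |
| POST /query/stream — streaming chat (Server-Sent Events) |
| POST /pest-risk — district-level pest risk prediction (NEW v2) |
| POST /diagnose — image disease diagnosis (upload image file) |
| GET /health — liveness + readiness probe |
| GET /metrics — lightweight monitoring metrics |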
| |
| Usage (development) |
| ------------------- |
| uvicorn api:app --host 0.0.0.0 --port 8000 --reload |
| |
| Usage (production) |
| ------------------ |
| uvicorn api:app --host 0.0.0.0 --port 8000 --workers 2 |
| """ |
|
|
| import os as _os, glob as _glob |
| from pathlib import Path as _Path |
|
|
| |
def _ensure_large_files():
    """Make sure the large index/model artefacts exist next to this file.

    Every manifest entry that is missing, or smaller than 1000 bytes, is
    downloaded from the Hugging Face dataset ``hritikm15/kcc-data`` using the
    token in ``HF_TOKEN`` (may be unset).  Afterwards, if ``index/bm25_chunk_*``
    parts exist and ``index/bm25.db`` is missing or under 1 MB, the parts are
    concatenated in sorted order into ``bm25.db``.

    The reassembly writes to ``bm25.db.part`` and renames it into place only
    after every chunk has been copied, so an interrupted start-up can never
    leave a truncated ``bm25.db`` that a later run would accept as complete.
    """
    base_dir = _Path(_os.path.abspath(__file__)).parent
    hf_token = _os.environ.get("HF_TOKEN")
    dataset_id = "hritikm15/kcc-data"

    large_files = [
        "index/faiss.index",
        "index/metadata.parquet",
        "index/bm25_chunk_aa",
        "index/bm25_chunk_ab",
        "index/bm25_chunk_ac",
        "index/bm25_chunk_ad",
        "mandi_advisor/presow_v4_model.pkl",
        "mandi_advisor/presow_v4_meta.pkl",
        "mandi_advisor/feature_store.parquet",
        "pest_model/pest_risk_model_district_v2.pkl",
    ]

    def _needs_download(rel_path):
        # Files under 1000 bytes are treated as placeholders, not real data.
        target = base_dir / rel_path
        return not target.exists() or target.stat().st_size < 1000

    missing = [rel for rel in large_files if _needs_download(rel)]

    if missing:
        print(f"[startup] {len(missing)} large files missing — downloading from HF Dataset...", flush=True)
        # Imported lazily so the dependency is only needed when a download happens.
        from huggingface_hub import hf_hub_download
        for repo_file in missing:
            local_path = base_dir / repo_file
            local_path.parent.mkdir(parents=True, exist_ok=True)
            print(f"[startup] Downloading {repo_file}...", flush=True)
            hf_hub_download(
                repo_id=dataset_id,
                filename=repo_file,
                repo_type="dataset",
                token=hf_token,
                local_dir=str(base_dir),
            )
            print(f"[startup] {repo_file} done ({local_path.stat().st_size // 1_000_000}MB)", flush=True)

    bm25_db = base_dir / "index" / "bm25.db"
    chunks = sorted((base_dir / "index").glob("bm25_chunk_*"))
    if chunks and (not bm25_db.exists() or bm25_db.stat().st_size < 1_000_000):
        print(f"[startup] Reassembling bm25.db from {len(chunks)} chunks...", flush=True)
        # The temp name does not match "bm25_chunk_*", so it is never picked up
        # as an input chunk.
        partial = bm25_db.with_name(bm25_db.name + ".part")
        read_size = 64 * 1024 * 1024
        try:
            with open(partial, "wb") as out:
                for chunk_path in chunks:
                    with open(chunk_path, "rb") as src:
                        for buf in iter(lambda: src.read(read_size), b""):
                            out.write(buf)
            _os.replace(partial, bm25_db)
        finally:
            # Only present here if copying failed before the rename.
            if partial.exists():
                partial.unlink()
        print(f"[startup] bm25.db ready ({bm25_db.stat().st_size // 1_000_000}MB)", flush=True)
|
|
| _ensure_large_files() |
| |
|
|
| import io |
| import re |
| import sys |
| import time |
| from pathlib import Path |
| from typing import AsyncIterator, List, Optional |
|
|
| from fastapi import FastAPI, File, HTTPException, Response, UploadFile |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from pydantic import BaseModel, Field |
|
|
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| import config |
| from step3_retrieval import KCCRetriever, RetrievedDoc, get_retriever |
| from step4_app import ( |
| SAFETY_GUARDRAILS, |
| LOW_CONF_THRESHOLD, |
| classify_problem, |
| detect_crop, |
| multi_step_retrieve, |
| normalize_query, |
| _build_prompt, |
| detect_language, |
| _diagnose_image_gemini, |
| _get_gemini_client, |
| ) |
|
|
| |
# The FastAPI application object; all routes and middleware below attach to it.
app = FastAPI(
    title="KCC Agricultural Chatbot API",
    # Three sentences joined by single spaces.
    description=" ".join((
        "Enterprise REST API over 16.5 M Kisan Call Center Q&A pairs.",
        "Pest predictions: district-level stacking model (AUC 0.936), 1-month lead time.",
        "Supports Hindi, English, Telugu, Kannada, Marathi.",
    )),
    version="2.0.0",
)
|
|
| |
| import config as _cfg |
| import os |
|
|
| |
| import hashlib, secrets as _secrets |
| from datetime import timedelta as _timedelta, datetime as _datetime, timezone as _timezone |
| from typing import Optional as _Opt |
|
|
| try: |
| from jose import JWTError, jwt as _jwt |
| _JWT_AVAILABLE = True |
| except ImportError: |
| _JWT_AVAILABLE = False |
|
|
# Signing key for bearer tokens.  An unset *or empty* JWT_SECRET falls back to
# a random key, so tokens are never signed with an empty secret.
_JWT_SECRET = os.environ.get("JWT_SECRET") or _secrets.token_hex(32)
if not os.environ.get("JWT_SECRET"):
    # A random key differs per process: tokens issued by one worker will not
    # verify in another, nor after a restart.
    print(
        "[auth] WARNING: JWT_SECRET is not set — using a random per-process "
        "secret; tokens will not validate across workers or restarts.",
        flush=True,
    )
_JWT_ALGO = "HS256"
_JWT_EXP_HOURS = 24


# username -> password for the B2B login endpoint.
# NOTE(review): the built-in fallback passwords live in source control; they
# are kept for backward compatibility but should be overridden via environment.
_B2B_USERS = {
    "enterprise": os.environ.get("B2B_DEMO_PASSWORD", "KCC@Demo2025"),
    "admin": os.environ.get("ADMIN_PASSWORD", "Admin@KCC2025"),
}
for _pw_env in ("B2B_DEMO_PASSWORD", "ADMIN_PASSWORD"):
    if _pw_env not in os.environ:
        print(
            f"[auth] WARNING: {_pw_env} is not set — using the built-in default "
            "password; set it before deploying.",
            flush=True,
        )
|
|
# Request body for POST /auth/token: the credentials to check against _B2B_USERS.
# (Documented with comments rather than a docstring so the generated OpenAPI
# schema is unchanged.)
class _TokenRequest(BaseModel):
    username: str
    password: str
|
|
# Response body for POST /auth/token.
class _TokenResponse(BaseModel):
    access_token: str
    token_type: str = Field(default="bearer")
    # Token lifetime in seconds, derived from _JWT_EXP_HOURS.
    expires_in: int = Field(default=_JWT_EXP_HOURS * 3600)
|
|
_FALLBACK_TOKEN_PREFIX = "mock-token-"


def _fallback_signature(username: str) -> str:
    """Return the hex HMAC-SHA256 of *username* keyed with ``_JWT_SECRET``."""
    import hmac
    return hmac.new(
        _JWT_SECRET.encode("utf-8"), username.encode("utf-8"), hashlib.sha256
    ).hexdigest()


def _create_token(username: str) -> str:
    """Issue a bearer token for *username*.

    With python-jose available this is an HS256 JWT carrying ``sub`` and an
    ``exp`` claim ``_JWT_EXP_HOURS`` in the future.  Without it, the fallback
    token is ``mock-token-<username>.<hmac>``.  The signature stops a client
    from minting a token for an arbitrary username, which the previous
    unsigned ``mock-token-<username>`` format allowed.  Fallback tokens carry
    no expiry.
    """
    if not _JWT_AVAILABLE:
        return f"{_FALLBACK_TOKEN_PREFIX}{username}.{_fallback_signature(username)}"
    exp = _datetime.now(_timezone.utc) + _timedelta(hours=_JWT_EXP_HOURS)
    return _jwt.encode({"sub": username, "exp": exp}, _JWT_SECRET, algorithm=_JWT_ALGO)


def _verify_token(token: str) -> _Opt[str]:
    """Return the username encoded in *token*, or ``None`` if it is not valid."""
    if not _JWT_AVAILABLE:
        import hmac
        if not token.startswith(_FALLBACK_TOKEN_PREFIX):
            return None
        # The signature is hex, so the last "." always separates it from the
        # username even when the username itself contains dots.
        username, sep, signature = token[len(_FALLBACK_TOKEN_PREFIX):].rpartition(".")
        if not sep or not username:
            return None
        expected = _fallback_signature(username)
        # Compare as bytes: compare_digest rejects non-ASCII str input.
        if not hmac.compare_digest(signature.encode("utf-8"), expected.encode("utf-8")):
            return None
        return username
    try:
        payload = _jwt.decode(token, _JWT_SECRET, algorithms=[_JWT_ALGO])
        return payload.get("sub")
    except Exception:
        # Any decode/validation failure means "not authenticated".
        return None
|
|
| from fastapi import Depends, Header |
| from fastapi.security import HTTPBearer as _HTTPBearer, HTTPAuthorizationCredentials as _Creds |
# auto_error=False: a missing Authorization header yields None instead of an
# automatic error, so endpoints can treat authentication as optional.
_security = _HTTPBearer(auto_error=False)


def _get_current_user(creds: _Opt[_Creds] = Depends(_security)) -> _Opt[str]:
    """Resolve the bearer credentials to a username.

    Returns ``None`` when no credentials were sent or when ``_verify_token``
    rejects them.
    """
    return _verify_token(creds.credentials) if creds is not None else None
|
|
def _require_b2b(user: _Opt[str] = Depends(_get_current_user)):
    """Dependency that returns the authenticated username or raises HTTP 401."""
    if user is not None:
        return user
    raise HTTPException(
        status_code=401,
        detail="B2B endpoint requires authentication. POST /auth/token to get a token.",
    )
|
|
|
|
# Fall back to the two localhost origins when the config provides no origins.
_ALLOWED_ORIGINS = _cfg.ALLOWED_ORIGINS or [
    "http://localhost:8501",
    "http://localhost:8000",
]


# CORS policy: credentialed requests from the allowed origins, GET/POST only.
app.add_middleware(
    CORSMiddleware,
    **{
        "allow_origins": _ALLOWED_ORIGINS,
        "allow_credentials": True,
        "allow_methods": ["GET", "POST"],
        "allow_headers": ["Content-Type", "Authorization", "X-API-Key", "X-Request-ID"],
        "expose_headers": ["X-Request-ID"],
    },
)
|
|
|
|
| |
| |
| try: |
| from slowapi import Limiter |
| from slowapi.util import get_remote_address |
| from slowapi.errors import RateLimitExceeded |
| from fastapi import Request as _SlowReq |
| _RATE_LIMIT_OK = True |
| except ImportError: |
| _RATE_LIMIT_OK = False |
|
|
if _RATE_LIMIT_OK:
    # Limits are keyed by client address and overridable through
    # RATE_LIMIT_DEFAULT / RATE_LIMIT_HOUR.
    # NOTE(review): no slowapi middleware or per-route limit decorator is
    # visible in this part of the file — confirm the default limits are
    # actually enforced somewhere.
    _limiter = Limiter(
        key_func=get_remote_address,
        default_limits=[
            os.environ.get(env_name, fallback)
            for env_name, fallback in (
                ("RATE_LIMIT_DEFAULT", "120/minute"),
                ("RATE_LIMIT_HOUR", "2000/hour"),
            )
        ],
    )
    app.state.limiter = _limiter

    @app.exception_handler(RateLimitExceeded)
    async def _rate_limit_handler(request: _SlowReq, exc: RateLimitExceeded):
        """Turn a slowapi RateLimitExceeded into a JSON 429 response."""
        body = {"detail": f"Rate limit exceeded: {exc.detail}"}
        return JSONResponse(status_code=429, content=body)
|
|
|
|
| |
| |
| |
| import uuid as _uuid |
| from fastapi import Request as _ReqIDReq |
|
|
@app.middleware("http")
async def _request_id_middleware(request: _ReqIDReq, call_next):
    """Attach a request ID to every request and echo it on the response.

    A non-empty incoming ``X-Request-ID`` header is reused; otherwise a new
    UUID4 is generated.  The ID is stored on ``request.state.request_id``.
    """
    incoming = request.headers.get("X-Request-ID")
    request_id = incoming if incoming else str(_uuid.uuid4())
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response
|
|
| |
import threading as _threading

# Process-wide retriever singleton, created on first use.
_retriever: Optional[KCCRetriever] = None
# Serialises the first load: the start-up warm-up thread and request handlers
# can all call _get_retriever() before the index has finished loading.
_retriever_lock = _threading.Lock()


def _get_retriever() -> KCCRetriever:
    """Return the shared retriever, building it on first use.

    The retriever is built from ``config.FAISS_INDEX_FILE`` and
    ``config.METADATA_FILE``.  Concurrent first calls wait on the lock rather
    than each starting their own load; once set, the fast path takes no lock.
    If the load raises, the singleton stays unset and the next call retries.
    """
    global _retriever
    if _retriever is None:
        with _retriever_lock:
            # Re-check: another thread may have finished while we waited.
            if _retriever is None:
                _retriever = get_retriever(
                    index_path=config.FAISS_INDEX_FILE,
                    metadata_path=config.METADATA_FILE,
                )
    return _retriever
|
|
|
|
| |
| |
| |
|
|
# Lower-case agriculture vocabulary (English, romanised Hindi, Devanagari).
# Single-word entries are matched against query tokens; entries containing a
# space are matched as whole phrases (see _AGRI_PHRASES).
_AGRI_SIGNALS = {
    # general farming terms and crops
    "crop", "crops", "plant", "plants", "planting", "seed", "seeds", "soil", "fertilizer",
    "pesticide", "pest", "disease", "farm", "farmer", "farming", "kisan", "kheti", "fasal",
    "gehu", "gehun", "dhan", "kapas", "tamatar", "aloo", "pyaz", "mandi", "bhav", "rate",
    "harvest", "harvesting", "sow", "sown", "sowing", "grow", "grows", "growing", "grown",
    "variety", "varieties", "cultivar", "cultivars", "season", "seasonal", "kharif", "rabi", "zaid",
    "irrigation", "spray", "insecticide", "fungicide", "organic", "yield", "yields", "blight",
    "wheat", "rice", "paddy", "cotton", "maize", "sugarcane", "soybean", "mustard",
    "chilli", "brinjal", "onion", "tomato", "potato", "groundnut", "gram", "pulses", "pulse",
    "barley", "jowar", "bajra", "ragi", "arhar", "tur", "moong", "urad", "masoor", "chana",
    # pests and diseases
    "aphid", "borer", "mildew", "rust", "wilt", "thrips", "mite", "whitefly",
    "caterpillar", "jassid", "leaf spot", "mosaic", "virus", "rot", "fungus",
    "blast", "armyworm", "bollworm", "helicoverpa", "spodoptera", "girdle",
    # romanised Hindi farming words
    "khaad", "dawai", "beej", "sinchai", "pattiya", "pattiyan", "keeda", "bimari",
    "rog", "upchar", "khet", "paidavar", "safed makhi", "tela", "mahu", "tikda",
    # weather stress
    "pala", "thand", "frost", "baarish", "drought", "sukha", "heat stress",
    "ola", "flood", "andhi",
    # schemes and institutions
    "kvk", "icar", "pm kisan", "fasal bima", "kcc", "kisan call", "advisory",
    # agro-chemicals and inputs
    "endosulfan", "monocrotophos", "chlorpyrifos", "imidacloprid", "emamectin",
    "mancozeb", "propiconazole", "thiamethoxam", "rhizobium", "urea", "dap",
    # Devanagari
    "फसल", "खेत", "किसान", "खाद", "बीज", "धान", "गेहूं", "कपास",
    "मक्का", "सरसों", "चना", "अरहर", "मूंग", "प्याज", "आलू",
    "कीड़ा", "बीमारी", "दवाइ", "स्प्रे", "उपचार", "सिंचाई",
}

# Multi-word entries cannot be found by a single-token intersection, so they
# are checked separately as space-delimited phrases.
_AGRI_PHRASES = tuple(sorted(s for s in _AGRI_SIGNALS if " " in s))


# Clearly off-topic subjects.
_NON_AGRI_RE = re.compile(
    r"\b(stock market|share market|bitcoin|crypto|politics|election|"
    r"movie|cricket|football|recipe|cooking|exam|bank account|insurance claim|"
    r"marriage|divorce|job|salary|relationship|celebrity|news|tv show|"
    r"web series|ipl|bollywood|actor|actress|pakistan|china|war|army|"
    r"love|dating|girlfriend|boyfriend|password|hack|code)\b",
    re.IGNORECASE,
)


# Short conversational replies.  Not referenced by is_agriculture_query (which
# lets every query of three words or fewer through); kept for other users.
_SHORT_FOLLOWUP_RE = re.compile(r"^(ha|haan|yes|ok|okay|nahi|no|theek|accha|sahi|nahi ji|ji|acha)\b", re.IGNORECASE)


def is_agriculture_query(query: str) -> bool:
    """Return True if *query* should be treated as agriculture-related.

    Rules, in order:

    1. Queries of three words or fewer pass unless they hit the off-topic
       pattern (short follow-ups such as "yes"/"ha" are contextual replies).
    2. Queries containing Devanagari characters pass unless they hit the
       off-topic pattern.
    3. Anything else passes only if it contains an agriculture signal, either
       a single word or a multi-word phrase such as "leaf spot".  A signal
       outweighs an off-topic hit.
    """
    stripped = query.strip()
    q_lower = stripped.lower()
    off_topic_hit = _NON_AGRI_RE.search(q_lower) is not None

    if len(stripped.split()) <= 3 and not off_topic_hit:
        return True

    # U+0900..U+097F is the Devanagari block.
    if any("\u0900" <= ch <= "\u097f" for ch in stripped):
        return not off_topic_hit

    tokens = re.findall(r"\b\w+\b", q_lower)
    if _AGRI_SIGNALS.intersection(tokens):
        return True

    # Pad with spaces so phrases match on word boundaries only
    # (" pm kisan " must not match inside "rpm kisan").
    padded = " " + " ".join(tokens) + " "
    return any(f" {phrase} " in padded for phrase in _AGRI_PHRASES)
|
|
|
|
# Reply sent when the topic guard rejects a query: an English paragraph and a
# romanised-Hindi paragraph separated by a blank line.
_OFF_TOPIC_MSG = "\n\n".join((
    "I can only help with agriculture-related questions — crops, pests, diseases, "
    "fertilizers, mandi prices, irrigation, seeds, and farming advice. "
    "Please ask a farming question and I'll be happy to help! 🌾",
    "Main sirf kheti-baadi, fasal, keede-makode, bimari, khaad, mandi bhav, "
    "aur kisan-sambandhit sawalon ka jawab de sakta hoon.",
))
|
|
|
|
| |
# Request body shared by POST /query and POST /query/stream.
# (Comments rather than a docstring, so the OpenAPI schema is unchanged.)
class QueryRequest(BaseModel):
    query: str = Field(..., min_length=2, description="Farmer's question (any language)")
    # Retrieval settings, forwarded to the retriever by _request_to_settings().
    top_k: int = Field(default=5, ge=1, le=20)
    min_score: float = Field(default=0.0, ge=0.0, le=1.0)
    deduplicate: bool = Field(default=True)
    # Prior conversation turns, passed through to the prompt builder.
    history: List[dict] = Field(default_factory=list)
    state: str = Field(default="", description="Farmer's state (e.g. Madhya Pradesh) — boosts regional results")
    district: str = Field(default="", description="Farmer's district (e.g. Barwani) — boosts hyper-local results")
|
|
|
|
# One retrieved Q&A pair as returned to API clients; built from a RetrievedDoc
# by _docs_to_schema().
class SourceDoc(BaseModel):
    rank: int
    score: float  # rounded to 4 decimals by _docs_to_schema()
    query: str
    answer: str
    crop: Optional[str]
    state: Optional[str]
    year: Optional[int]
|
|
|
|
# Response body for POST /query.
class QueryResponse(BaseModel):
    answer: str
    sources: List[SourceDoc]
    detected_crop: Optional[str]
    problem_type: str
    low_confidence: bool  # top retrieval score fell below LOW_CONF_THRESHOLD
    off_topic: bool  # True when the topic guard rejected the query
    retrieval_ms: float
    generation_ms: float
    model_backend: str
    safety_warnings: List[str] = Field(default_factory=list)
|
|
|
|
# Response body for POST /diagnose.
class DiagnoseResponse(BaseModel):
    crop: str
    condition: str
    problem_type: str
    confidence: str
    visible_symptoms: Optional[str]
    # None when no treatment text was generated (healthy plant, no retrieved
    # documents, or generation failure).
    treatment_advice: Optional[str]
    sources: List[SourceDoc]
|
|
|
|
| |
# Request body for POST /pest-risk.
class PestRiskRequest(BaseModel):
    state: str = Field(..., description="Indian state name (English)")
    district: str = Field(default="", description="District name (optional — uses spatial fallback)")
    crop: str = Field(..., description="Crop name (Hindi or English)")
    month: Optional[int] = Field(default=None, ge=1, le=12, description="Month (1-12). Default: current month.")
    year: Optional[int] = Field(default=None, description="Year. Default: current year.")
|
|
|
|
# One pest prediction inside a PestRiskResponse; populated field-by-field from
# the dicts returned by the pest predictor.
class PestRiskResult(BaseModel):
    pest: str
    pest_cat: str
    risk_score: int
    risk_level: str
    confidence: str
    confidence_tier: Optional[dict] = Field(default=None)
    model_auc: float
    model_score: Optional[float] = Field(default=None)
    model_version: str
    recommended_action: str
    spray: str
    weather_driver: str
    history_note: str
    feature_drivers: List[str]
    growth_stage: str
    ndvi_index: int
| ndvi_index: int |
|
|
|
|
# Response body for POST /pest-risk.
class PestRiskResponse(BaseModel):
    state: str
    district: str
    crop: str
    month: int
    predictions: List[PestRiskResult]
    model_version: str
    lead_time_note: str
    retrieval_ms: float  # wall-clock time of the prediction call, in ms
|
|
|
|
| |
def _request_to_settings(req: QueryRequest) -> dict:
    """Extract the retrieval settings (top_k, min_score, deduplicate) from *req*."""
    return dict(
        top_k=req.top_k,
        min_score=req.min_score,
        deduplicate=req.deduplicate,
    )
|
|
|
|
def _docs_to_schema(docs: List[RetrievedDoc]) -> List[SourceDoc]:
    """Convert retrieved documents into API ``SourceDoc`` models.

    Scores are rounded to four decimal places; all other fields are copied.
    """
    converted = []
    for doc in docs:
        converted.append(
            SourceDoc(
                rank=doc.rank,
                score=round(doc.score, 4),
                query=doc.query,
                answer=doc.answer,
                crop=doc.crop,
                state=doc.state,
                year=doc.year,
            )
        )
    return converted
|
|
|
|
def _get_llm_backend() -> str:
    """Return name of active LLM backend — always Groq for API mode."""
    backend_name = "groq"
    return backend_name
|
|
|
|
def _trim_prompt_for_groq(prompt: str, max_chars: int = 22000) -> str:
    """Shrink *prompt* towards *max_chars* so it fits the Groq token budget.

    Groq free tier limits: llama-4-scout 30K TPM, llama-3.3-70b 12K TPM,
    llama-3.1-8b 6K TPM. ~4 chars/token → 22000 chars ≈ 5.5K tokens.

    Behaviour:
    * Prompts within the limit are returned unchanged.
    * Without a "FARMER'S QUESTION:" marker the prompt is cut to ``max_chars``.
    * Otherwise everything from the last marker onward is kept intact and the
      text before it is cut, with a truncation notice in between.  At least
      1000 leading characters are kept, so the result can exceed
      ``max_chars`` when the question section is itself very long.
    * If there is nothing before the marker left to drop, the prompt is
      returned unchanged.  Previously the 1000-character head could overlap
      the question, duplicating it and making the prompt longer.
    """
    if len(prompt) <= max_chars:
        return prompt
    q_marker = "FARMER'S QUESTION:"
    q_idx = prompt.rfind(q_marker)
    if q_idx == -1:
        return prompt[:max_chars]

    tail = prompt[q_idx:]
    # Reserve 200 chars for the truncation notice; never keep less than 1000
    # characters of leading instructions/context.
    head_budget = max(max_chars - len(tail) - 200, 1000)
    if head_budget >= q_idx:
        # The head would reach into the question section: no context can be
        # dropped without touching the question, so leave the prompt alone.
        return prompt

    head = prompt[:head_budget]
    print(f"[_trim_prompt] truncated {len(prompt)}→{len(head)+len(tail)} chars", flush=True)
    return head + "\n\n[...context truncated to fit Groq free-tier budget...]\n\n" + tail
|
|
|
|
def _generate_full(prompt: str) -> str:
    """Generate a complete (non-streaming) answer for *prompt*.

    Cascade, first non-empty answer wins:
    Groq (configured primary model -> llama-3.3-70b -> llama-3.1-8b-instant)
    -> Gemini (gemini-2.0-flash -> gemini-1.5-flash) -> local ``kcc_llm``.

    Returns the literal "Service temporarily unavailable. Please try again."
    when every backend fails; callers test for that text.  The function does
    not raise for backend failures: each one is logged and the next is tried,
    including a missing or misconfigured Groq client.
    """
    prompt = _trim_prompt_for_groq(prompt)

    groq_models = [
        getattr(config, "GROQ_MODEL_PRIMARY",
                "meta-llama/llama-4-scout-17b-16e-instruct"),
        "llama-3.3-70b-versatile",
        "llama-3.1-8b-instant",
    ]
    try:
        import groq as _groq
        gc = _groq.Groq(api_key=config.GROQ_API_KEY)
    except Exception as exc:
        print(f"[_generate_full] groq init failed: {exc!r}", flush=True)
        gc = None

    if gc is not None:
        for model in groq_models:
            try:
                resp = gc.chat.completions.create(
                    model=model,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=600,
                    temperature=0.1,
                )
                txt = resp.choices[0].message.content or ""
                if txt.strip():
                    return txt
            except Exception as exc:
                detail = str(exc)
                print(f"[_generate_full] groq {model} failed: {type(exc).__name__}: {detail[:200]}", flush=True)
                # Brief pause before the next model when rate limiting is suspected.
                if "429" in detail or "rate" in detail.lower():
                    time.sleep(2)

    try:
        import google.generativeai as genai
        genai.configure(api_key=config.GEMINI_API_KEY)
        for gm in ["gemini-2.0-flash", "gemini-1.5-flash"]:
            try:
                resp = genai.GenerativeModel(gm).generate_content(
                    prompt,
                    generation_config={"max_output_tokens": 600, "temperature": 0.1},
                )
                text = resp.text
                # An empty Gemini reply is not an answer; try the next backend.
                if text and text.strip():
                    return text
            except Exception as exc:
                print(f"[_generate_full] gemini {gm} failed: {exc!r}", flush=True)
    except Exception as exc:
        print(f"[_generate_full] gemini init failed: {exc!r}", flush=True)

    try:
        from kcc_llm import generate
        return generate(prompt)
    except Exception as exc:
        print(f"[_generate_full] local kcc_llm failed: {exc!r}", flush=True)

    return "Service temporarily unavailable. Please try again."
|
|
|
|
@app.post("/auth/token", response_model=_TokenResponse, tags=["Auth"])
def login(req: _TokenRequest):
    """Get JWT token for B2B dashboard and enterprise API access."""
    # Raises HTTP 401 for an unknown user, an empty configured password, or a
    # password mismatch.
    expected = _B2B_USERS.get(req.username)
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII
    # str).  It also runs for unknown users, against an empty reference, so
    # both failure paths do the same work.
    password_ok = _secrets.compare_digest(
        req.password.encode("utf-8"),
        (expected or "").encode("utf-8"),
    )
    if not expected or not password_ok:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return _TokenResponse(
        access_token=_create_token(req.username),
        expires_in=_JWT_EXP_HOURS * 3600,
    )
|
|
|
|
@app.get("/auth/me", tags=["Auth"])
def me(user: str = Depends(_require_b2b)):
    """Get current authenticated user info."""
    # Role and access list are fixed values; only the username varies.
    profile = {"username": user}
    profile["role"] = "b2b_enterprise"
    profile["access"] = ["pest-risk", "price-forecast", "advisory"]
    return profile
|
|
|
|
@app.get("/health", tags=["System"])
def health(response: Response):
    """Liveness + readiness probe. Returns 503 until retriever is warmed up."""
    # Readiness is probed by calling _get_retriever(), which builds the
    # retriever on first use; any exception counts as "not ready".
    try:
        retriever_ready = _get_retriever() is not None
    except Exception:
        retriever_ready = False

    payload = {
        "status": "ok" if retriever_ready else "warming_up",
        "retriever_ready": retriever_ready,
        "llm_backend": _get_llm_backend(),
        "groq_model": config.GROQ_MODEL_PRIMARY,
        "gemini_model": config.GEMINI_MODEL,
        "api_keys_set": bool(config.GROQ_API_KEY and config.GEMINI_API_KEY),
        "version": "2.0.0",
        "environment": config.ENVIRONMENT,
        "disclaimer": "AI advisory only — always consult a qualified agronomist before acting.",
    }
    if retriever_ready:
        return payload
    return JSONResponse(content=payload, status_code=503)
|
|
|
|
@app.post("/query", response_model=QueryResponse, tags=["Chat"])
def query(req: QueryRequest):
    """
    Main chat endpoint.
    - Topic guard: rejects non-agriculture questions immediately (no LLM cost)
    - Multi-step FAISS retrieval over 16.5M KCC Q&A pairs
    - Fine-tuned Llama-3.2-3B-Instruct (KCC domain expert) or Gemini fallback
    """
    # Topic guard: answer off-topic questions without retrieval or generation.
    if not is_agriculture_query(req.query):
        return QueryResponse(
            answer=_OFF_TOPIC_MSG,
            sources=[],
            detected_crop=None,
            problem_type="off_topic",
            low_confidence=False,
            off_topic=True,
            retrieval_ms=0.0,
            generation_ms=0.0,
            model_backend="guard",
        )

    retriever = _get_retriever()
    settings = _request_to_settings(req)

    detected_crop = detect_crop(req.query)
    problem_type = classify_problem(req.query)
    normalized_q = normalize_query(req.query)

    t0 = time.perf_counter()
    docs = multi_step_retrieve(retriever, req.query, normalized_q,
                               detected_crop, problem_type, settings,
                               state=req.state, district=req.district)
    ret_ms = (time.perf_counter() - t0) * 1000

    top_score = docs[0].score if docs else 0.0
    low_confidence = top_score < LOW_CONF_THRESHOLD

    # Prepend routing metadata to the retrieved context.
    context = retriever.format_context(docs)
    meta_lines = []
    if detected_crop:
        meta_lines.append(f"DETECTED CROP: {detected_crop}")
    if problem_type != "general":
        meta_lines.append(f"PROBLEM TYPE: {problem_type.upper()}")
    safety_note = SAFETY_GUARDRAILS.get(problem_type, "")
    if safety_note:
        meta_lines.append(safety_note)
    if low_confidence:
        meta_lines.append(
            f"LOW CONFIDENCE (top score: {top_score*100:.0f}%) — "
            "ask the farmer 1-2 clarifying questions instead of guessing."
        )
    if meta_lines:
        context = "\n".join(meta_lines) + "\n\n" + context

    _lang = detect_language(req.query)
    prompt = _build_prompt(req.query, context, req.history, problem_type, _lang)
    prompt = _trim_prompt_for_groq(prompt)

    t1 = time.perf_counter()
    # Primary backend: kcc_core.  A failure is logged rather than swallowed
    # silently, then the Groq/Gemini/local cascade is used.
    answer = ""
    backend = ""
    try:
        from kcc_core.llm import generate_with_meta
        _out = generate_with_meta(prompt, max_tokens=600, temperature=0.1)
        answer = _out["text"]
        backend = _out["backend"]
    except Exception as exc:
        print(f"[query] primary LLM backend failed: {exc!r}", flush=True)
        answer = ""

    if not answer:
        try:
            answer = _generate_full(prompt)
            backend = _get_llm_backend()
        except Exception as exc2:
            raise HTTPException(status_code=502, detail=f"LLM error: {exc2}") from exc2
    gen_ms = (time.perf_counter() - t1) * 1000

    # Post-generation safety review.  Best-effort: if the guard is missing or
    # fails, the unreviewed answer is returned with no warnings.
    try:
        from kcc_core.citation_guard import review as _cite_review
        answer, safety_warnings = _cite_review(answer, problem_type=problem_type)
    except Exception as exc:
        print(f"[query] citation guard unavailable: {exc!r}", flush=True)
        safety_warnings = []

    return QueryResponse(
        answer=answer,
        sources=_docs_to_schema(docs),
        detected_crop=detected_crop,
        problem_type=problem_type,
        low_confidence=low_confidence,
        off_topic=False,
        retrieval_ms=round(ret_ms, 1),
        generation_ms=round(gen_ms, 1),
        model_backend=backend,
        safety_warnings=safety_warnings,
    )
|
|
|
|
@app.post("/query/stream", tags=["Chat"])
async def query_stream(req: QueryRequest):
    """Streaming chat (Server-Sent Events). Topic-guarded like /query."""
    # Wire format: one "data: <text>" event per piece of the answer, with
    # newlines escaped as a literal backslash-n; an optional
    # "data: SAFETY:<json list>" event; then "data: [DONE]".
    #
    # Retrieval and generation are blocking calls.  Prompt preparation runs in
    # the threadpool, and the event generator is a plain (sync) generator,
    # which StreamingResponse iterates in a threadpool, so the event loop
    # stays free for other requests.
    import json
    from fastapi.concurrency import run_in_threadpool

    def _sse(text):
        """Format *text* as one SSE data event with newlines escaped."""
        return "data: " + text.replace("\n", "\\n") + "\n\n"

    if not is_agriculture_query(req.query):
        def _off_topic():
            yield _sse(_OFF_TOPIC_MSG)
            yield "data: [DONE]\n\n"
        return StreamingResponse(_off_topic(), media_type="text/event-stream")

    def _prepare_prompt():
        """Retrieve context and build the trimmed prompt; returns (prompt, problem_type)."""
        retriever = _get_retriever()
        settings = _request_to_settings(req)

        detected_crop = detect_crop(req.query)
        kind = classify_problem(req.query)
        normalized_q = normalize_query(req.query)

        docs = multi_step_retrieve(retriever, req.query, normalized_q,
                                   detected_crop, kind, settings,
                                   state=req.state, district=req.district)

        top_score = docs[0].score if docs else 0.0
        low_confidence = top_score < LOW_CONF_THRESHOLD

        context = retriever.format_context(docs)
        meta_lines = []
        if detected_crop:
            meta_lines.append(f"DETECTED CROP: {detected_crop}")
        if kind != "general":
            meta_lines.append(f"PROBLEM TYPE: {kind.upper()}")
        safety_note = SAFETY_GUARDRAILS.get(kind, "")
        if safety_note:
            meta_lines.append(safety_note)
        if low_confidence:
            meta_lines.append(
                f"LOW CONFIDENCE (top score: {top_score*100:.0f}%) — ask clarifying questions."
            )
        if meta_lines:
            context = "\n".join(meta_lines) + "\n\n" + context

        lang = detect_language(req.query)
        built = _build_prompt(req.query, context, req.history, kind, lang)
        return _trim_prompt_for_groq(built), kind

    prompt, problem_type = await run_in_threadpool(_prepare_prompt)

    def _closing_events(full_text):
        """Return the optional SAFETY event followed by the [DONE] event."""
        warnings = []
        if full_text:
            try:
                from kcc_core.citation_guard import review as _cite_review
                _filtered, warnings = _cite_review(full_text, problem_type=problem_type)
            except Exception as exc:
                # Best-effort: a missing or failing guard must not break the stream.
                print(f"[query/stream] citation guard unavailable: {exc!r}", flush=True)
                warnings = []
        events = []
        if warnings:
            events.append("data: SAFETY:" + json.dumps(warnings) + "\n\n")
        events.append("data: [DONE]\n\n")
        return events

    def _chunk_events(text, size=48):
        """Split an already-complete answer into SSE events of *size* characters."""
        return [_sse(text[i:i + size]) for i in range(0, len(text), size)]

    def _event_stream():
        sent = []  # every piece of answer text already delivered to the client

        # 1) Groq token streaming, trying each model in turn.
        try:
            import groq as _groq
            gc = _groq.Groq(api_key=config.GROQ_API_KEY)
        except Exception as exc:
            print(f"[query/stream] groq init failed: {exc!r}", flush=True)
            gc = None

        if gc is not None:
            groq_models = [
                getattr(config, "GROQ_MODEL_PRIMARY", "meta-llama/llama-4-scout-17b-16e-instruct"),
                "llama-3.3-70b-versatile",
                "llama-3.1-8b-instant",
            ]
            for model in groq_models:
                try:
                    stream = gc.chat.completions.create(
                        model=model,
                        messages=[{"role": "user", "content": prompt}],
                        max_tokens=600,
                        stream=True,
                    )
                    for chunk in stream:
                        token = (chunk.choices[0].delta.content or "") if chunk.choices else ""
                        if token:
                            sent.append(token)
                            yield _sse(token)
                except Exception as exc:
                    print(f"[query/stream] groq stream {model} failed: {exc!r}", flush=True)
                if sent:
                    # Text has reached the client.  Do not start another
                    # model: its answer would be appended to this one.
                    break
                # No tokens (failure or empty completion): try the next model.

        if sent:
            for event in _closing_events("".join(sent)):
                yield event
            return

        # 2) Non-streaming kcc_core backend, then 3) the _generate_full cascade.
        text = ""
        try:
            from kcc_core.llm import generate_with_meta as _kcc_gen
            text = (_kcc_gen(prompt, max_tokens=600) or {}).get("text", "") or ""
        except Exception as exc:
            print(f"[query/stream] kcc_core fallback failed: {exc!r}", flush=True)

        if not text:
            try:
                candidate = _generate_full(prompt)
                # _generate_full reports total failure with this sentinel text.
                if candidate and candidate.strip() and "Service temporarily" not in candidate:
                    text = candidate
            except Exception as exc:
                print(f"[query/stream] _generate_full fallback failed: {exc!r}", flush=True)

        if text:
            for event in _chunk_events(text):
                yield event
            for event in _closing_events(text):
                yield event
            return

        # 4) Last resort: Gemini streaming through the kcc_core client.
        try:
            from kcc_core.llm import _get_gemini_client as _kcc_gemini_client
            client = _kcc_gemini_client()
            if client is None or not hasattr(client, "models"):
                raise RuntimeError("gemini_unavailable")
            for chunk in client.models.generate_content_stream(
                model=config.GEMINI_MODEL, contents=prompt
            ):
                piece = getattr(chunk, "text", None)
                if piece:
                    sent.append(piece)
                    yield _sse(piece)
        except Exception as exc:
            print(f"[query/stream] gemini stream failed: {exc!r}", flush=True)

        if not sent:
            yield ("data: Service temporarily unavailable. Please try again in a moment, "
                   "or rephrase your question." + "\n\n")

        # Emitted in normal flow, not in a ``finally``: yielding while the
        # generator is being closed (client disconnect) raises RuntimeError.
        for event in _closing_events("".join(sent)):
            yield event

    return StreamingResponse(_event_stream(), media_type="text/event-stream")
|
|
|
|
@app.post("/pest-risk", response_model=PestRiskResponse, tags=["Pest Prediction"])
def pest_risk(req: PestRiskRequest, user: _Opt[str] = Depends(_get_current_user)):
    """
    District-level pest risk prediction.

    Returns 1-month early warning: predictions represent expected pest pressure
    NEXT month based on current/upcoming weather conditions.

    Model: LGB + XGB + CatBoost stacking ensemble (AUC 0.936)
    Coverage: 475 districts across 28 states.
    Uncovered districts: automatic spatial fallback to nearest covered district.
    """
    import datetime
    t0 = time.perf_counter()

    # State-only requests (no district) require an authenticated user;
    # district lookups are open.
    if not req.district and user is None:
        raise HTTPException(
            status_code=401,
            detail="State-only pest aggregates require B2B authentication. "
                   "POST /auth/token to get a token, or supply 'district' for "
                   "a single-block lookup.",
        )

    # NOTE(review): month is only echoed in the response and req.year is not
    # used at all — neither is forwarded to predict_pest_risk().  Confirm
    # whether the predictor should receive them.
    month = req.month or datetime.datetime.now().month

    try:
        from mandi_advisor.pest_predictor import predict_pest_risk
        raw_results = predict_pest_risk(
            state=req.state,
            crop=req.crop,
            district=req.district or None,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Pest prediction error: {e}") from e

    ms = (time.perf_counter() - t0) * 1000

    # Treat a None result like an empty list instead of failing on iteration.
    raw_results = raw_results or []

    predictions = [
        PestRiskResult(
            pest=r.get("pest", ""),
            pest_cat=r.get("pest_cat", ""),
            risk_score=r.get("risk_score", 0),
            risk_level=r.get("risk_level", "LOW"),
            confidence=r.get("confidence", "Low"),
            confidence_tier=r.get("confidence_tier"),
            model_auc=r.get("model_auc", 0.0),
            model_score=r.get("model_score"),
            model_version=r.get("model_version", ""),
            recommended_action=r.get("recommended_action", ""),
            spray=r.get("spray", ""),
            weather_driver=r.get("weather_driver", ""),
            history_note=r.get("history_note", ""),
            feature_drivers=r.get("feature_drivers", []),
            growth_stage=r.get("growth_stage", ""),
            ndvi_index=r.get("ndvi_index", 50),
        )
        for r in raw_results
    ]

    # The first result's version labels the whole response.
    model_ver = raw_results[0].get("model_version", "unknown") if raw_results else "unknown"

    return PestRiskResponse(
        state=req.state,
        district=req.district or "auto (spatial fallback)",
        crop=req.crop,
        month=month,
        predictions=predictions,
        model_version=model_ver,
        lead_time_note=(
            "v2: Predictions represent NEXT month's pest risk based on current weather. "
            "1-month early warning allows preventive spray before damage occurs."
            if "v2" in model_ver else
            "v1: Current-month pest risk. Upgrade to v2 for 1-month early warning."
        ),
        retrieval_ms=round(ms, 1),
    )
|
|
|
|
@app.post("/diagnose", response_model=DiagnoseResponse, tags=["Image Diagnosis"])
async def diagnose(image: UploadFile = File(...)):
    """Plant disease diagnosis from crop photo (Gemini Vision)."""
    # Errors: 400 empty upload, 413 image over 10 MB, 502 vision model failure
    # or empty diagnosis.  The vision call, retrieval and text generation are
    # blocking, so they run in the threadpool rather than on the event loop.
    from fastapi.concurrency import run_in_threadpool

    max_bytes = 10 * 1024 * 1024
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling the whole file into memory.
    image_bytes = await image.read(max_bytes + 1)
    if len(image_bytes) > max_bytes:
        raise HTTPException(status_code=413, detail="Image too large (max 10 MB)")
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Empty image upload")

    try:
        diagnosis = await run_in_threadpool(_diagnose_image_gemini, image_bytes)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Vision model error: {exc}") from exc

    if not diagnosis or not diagnosis.get("crop"):
        raise HTTPException(status_code=502, detail="Vision model returned no diagnosis")

    # "or" (not a .get default) so an explicit None/empty value from the model
    # also falls back, instead of crashing on .lower() below.
    crop_name = diagnosis.get("crop", "Unknown")
    condition = diagnosis.get("condition") or "Unknown"
    problem_type = diagnosis.get("problem_type") or "disease"
    confidence = diagnosis.get("confidence") or "Medium"
    symptoms = diagnosis.get("visible_symptoms", "")

    def _retrieve_and_advise():
        """Retrieve KCC context for the diagnosis and generate treatment text."""
        retriever = _get_retriever()
        auto_query = f"{crop_name} {condition} control treatment"
        settings = {"top_k": 5, "min_score": 0.0, "deduplicate": True}
        found = multi_step_retrieve(retriever, auto_query, normalize_query(auto_query),
                                    crop_name, problem_type, settings) or []
        if not found:
            return found, None
        context = retriever.format_context(found)
        safety = SAFETY_GUARDRAILS.get(problem_type, "")
        meta = (
            f"DETECTED CROP: {crop_name}\nDETECTED CONDITION: {condition} "
            f"(confidence: {confidence})\nVISIBLE SYMPTOMS: {symptoms}\n"
            f"PROBLEM TYPE: {problem_type.upper()}\n{safety}\n\n"
        )
        prompt = _build_prompt(
            f"My {crop_name} has {condition}. What should I do?",
            meta + context, [], problem_type, "English"
        )
        try:
            advice = _generate_full(prompt)
        except Exception as exc:
            # Advice is optional; the diagnosis itself is still returned.
            print(f"[diagnose] treatment generation failed: {exc!r}", flush=True)
            advice = None
        return found, advice

    docs = []
    treatment_answer = None
    # Healthy plants need no treatment lookup, so the retriever is not touched.
    if problem_type != "healthy" and condition.lower() != "healthy":
        docs, treatment_answer = await run_in_threadpool(_retrieve_and_advise)

    return DiagnoseResponse(
        crop=crop_name, condition=condition, problem_type=problem_type,
        confidence=confidence, visible_symptoms=symptoms,
        treatment_advice=treatment_answer, sources=_docs_to_schema(docs),
    )
|
|
|
|
|
|
|
|
| |
| import datetime as _dt |
# Moment (UTC) this module was imported; /metrics reports uptime relative to it.
_START_TIME = _dt.datetime.now(tz=_dt.timezone.utc)

# Per-endpoint counters returned verbatim by /metrics.
# NOTE(review): nothing in this section increments them — presumably the
# handlers or a middleware elsewhere in the file do; confirm.
_REQUEST_COUNTS: dict = {"query": 0, "pest_risk": 0, "diagnose": 0, "errors": 0}


@app.get("/metrics", tags=["System"])
def metrics():
    """Lightweight metrics endpoint for enterprise monitoring."""
    utc = _dt.timezone.utc
    elapsed = (_dt.datetime.now(utc) - _START_TIME).total_seconds()

    # Any failure while obtaining the retriever is reported as "not loaded"
    # rather than failing the metrics call.
    try:
        retriever_loaded = _get_retriever() is not None
    except Exception:
        retriever_loaded = False

    whole_hours, leftover_seconds = divmod(elapsed, 3600)
    model_names = {
        "llm_primary": config.GROQ_MODEL_PRIMARY,
        "llm_fallback": config.GROQ_MODEL_FALLBACK,
        "gemini": config.GEMINI_MODEL,
        "reranker": config.RERANKER_MODEL,
        "embedding": config.EMBEDDING_MODEL,
    }

    return {
        "status": "ok",
        "uptime_seconds": round(elapsed),
        "uptime_human": f"{int(whole_hours)}h {int(leftover_seconds // 60)}m",
        "requests": _REQUEST_COUNTS,
        "retriever_loaded": retriever_loaded,
        "models": model_names,
        "version": "2.0.0",
        "environment": config.ENVIRONMENT,
        "timestamp": _dt.datetime.now(utc).isoformat(),
    }
|
|
|
|
# Flipped to True by the warm-up thread once _get_retriever() has returned.
_RETRIEVER_READY = False


@app.on_event("startup")
async def _warm_retriever():
    """Kick off retriever warm-up in a background thread at application startup.

    The startup hook itself returns immediately; the daemon thread calls
    ``_get_retriever()`` and sets ``_RETRIEVER_READY`` to True on success.
    A failure is only printed, so the flag then stays False.
    """
    from threading import Thread

    def _load():
        global _RETRIEVER_READY
        try:
            _get_retriever()
            _RETRIEVER_READY = True
            print("[API] Retriever warmed up successfully", flush=True)
        except Exception as warmup_error:
            print(f"[API] Retriever warmup failed: {warmup_error}", flush=True)

    # Daemon thread: it must not keep the process alive on shutdown.
    warmup_thread = Thread(target=_load, daemon=True)
    warmup_thread.start()
|
|
|
|
|
|
| |
@app.get("/crops/static", tags=["Data"])
def crops_static():
    """Return all static crop data for React frontend."""
    # The tables live in step4_app and are imported at call time, exactly as
    # before; they are returned as-is under the keys below.
    from step4_app import (
        _VARIETY_DATA as varieties,
        _INPUT_COST as input_costs,
        _SOIL_PREP as soil_prep,
        _MSP_2025 as msp_2025,
    )

    return dict(
        varieties=varieties,
        input_costs=input_costs,
        soil_prep=soil_prep,
        msp_2025=msp_2025,
    )
|
|
| |
class _PriceForecastRequest(BaseModel):
    # JSON body of POST /price-forecast; both fields default to "".
    crop: str = ""
    state: str = ""


@app.post("/price-forecast", tags=["Models"])
def price_forecast(req: _PriceForecastRequest):
    """Get presow_v4 price forecast for a crop+state combo (p25/p50/p75)."""
    from mandi_advisor.enterprise_engine_v2 import get_presow_signal

    # Crop names that are rewritten before the lookup; any other value is
    # passed through unchanged. Matching is exact and case-sensitive.
    # NOTE(review): presumably UI labels -> names the model was trained on; confirm.
    crop_aliases = {
        "Soybean": "Soyabean",
        "Arhar": "Arhar/Tur Dal",
        "Gram": "Bengal Gram(Gram)(Whole)",
        "Mustard": "Rapeseed/Mustard(Toria)",
    }
    model_crop = crop_aliases.get(req.crop, req.crop)

    # Failures are reported in the body ({"error": ...}) rather than raised.
    try:
        return get_presow_signal(model_crop, req.state)
    except Exception as forecast_error:
        return {"error": str(forecast_error)}
|
|
| |
@app.get("/weather", tags=["Data"])
def weather(lat: float = 20.5, lon: float = 78.9):
    """Get 3-day weather forecast from Open-Meteo (free, no key needed).

    Returns the Open-Meteo JSON payload unchanged. If the request or the
    JSON decoding fails, returns ``{"error": "<message>"}`` instead of
    raising.
    """
    import httpx

    # The separator before current_weather must be a literal "&" followed by
    # the full parameter name. It had been collapsed into "¤t_weather"
    # ("&curren" decoded as an HTML entity), which fused the parameter into
    # the `daily` list and dropped current_weather from the request.
    url = (
        "https://api.open-meteo.com/v1/forecast"
        f"?latitude={lat}&longitude={lon}"
        "&daily=temperature_2m_max,temperature_2m_min,precipitation_sum,weathercode"
        "&current_weather=true&timezone=Asia%2FKolkata&forecast_days=3"
    )
    try:
        r = httpx.get(url, timeout=10)
        return r.json()
    except Exception as e:
        # Best effort: report the failure in the body, as the original did.
        return {"error": str(e)}
|
|
|
|
|
|
| |
| import os as _os2 |
| from pathlib import Path as _Path |
|
|
# Built React frontend, expected next to this module. When the directory is
# absent, no SPA routes are registered at all.
_STATIC_DIR = _Path(__file__).parent / "static"
if _STATIC_DIR.exists():
    from fastapi.staticfiles import StaticFiles
    from fastapi.responses import FileResponse

    # StaticFiles raises at construction time when its directory is missing,
    # so only mount /assets when the bundle is actually there. This keeps the
    # app importable while static/ exists but the frontend is not built yet.
    _ASSETS_DIR = _STATIC_DIR / "assets"
    if _ASSETS_DIR.is_dir():
        app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")

    @app.get("/", include_in_schema=False)
    @app.get("/{full_path:path}", include_in_schema=False)
    def serve_spa(full_path: str = ""):
        """Serve React SPA — all non-API routes return index.html.

        Paths starting with ``api/``, ``docs`` or ``openapi`` get a 404
        instead of the SPA shell. When index.html is missing, a short JSON
        status message is returned.
        """
        if full_path.startswith(("api/", "docs", "openapi")):
            raise HTTPException(status_code=404)
        index = _STATIC_DIR / "index.html"
        if index.exists():
            return FileResponse(str(index))
        return {"message": "KCC AgriAdvisor API running. Frontend not built yet."}
| |
|
|
if __name__ == "__main__":
    # Direct execution starts a uvicorn server for this module's `app`
    # with auto-reload enabled, listening on all interfaces, port 8000.
    import uvicorn

    dev_server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("api:app", **dev_server_options)