kcc-agri / api.py
hritikm15's picture
Pest: require district selection + restore B2B auth gate
16ed560 verified
"""
api.py
======
FastAPI REST wrapper around the KCC RAG pipeline.
Endpoints
---------
POST /query — main chat endpoint (JSON in, JSON out)
POST /query/stream — streaming chat (Server-Sent Events)
POST /pest-risk — district-level pest risk prediction (NEW v2)
POST /diagnose — image disease diagnosis (upload image file)
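GET /health — liveness + readiness probe (503 until the retriever is loaded)
POST /auth/token — issue a JWT for B2B / enterprise access
GET /auth/me — info about the authenticated B2B user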
Usage (development)
-------------------
uvicorn api:app --host 0.0.0.0 --port 8000 --reload
Usage (production)
------------------
uvicorn api:app --host 0.0.0.0 --port 8000 --workers 2
"""
import os as _os, glob as _glob
from pathlib import Path as _Path
# ── HF Spaces: download large files from private HF Dataset on cold start ─────
def _reassemble_chunks(chunks, dest):
    """Concatenate *chunks* (in the order given) into *dest* atomically.

    The data is written to a sibling ``<dest>.<pid>.part`` file first and then
    moved into place with ``os.replace``. An interrupted start-up therefore
    never leaves a truncated file at *dest* that a later size check could
    mistake for a complete one, and two worker processes booting at the same
    time do not write into the same file.
    """
    import shutil as _shutil

    _dest = _Path(dest)
    _tmp = _dest.with_name(f"{_dest.name}.{_os.getpid()}.part")
    try:
        with open(_tmp, "wb") as _out:
            for _chunk in chunks:
                with open(_chunk, "rb") as _f:
                    # 64 MiB copy buffer keeps memory bounded for multi-GB chunks.
                    _shutil.copyfileobj(_f, _out, 64 * 1024 * 1024)
        _os.replace(_tmp, _dest)
    finally:
        # Only still present when the copy failed part-way.
        _tmp.unlink(missing_ok=True)


def _ensure_large_files():
    """Make sure the large index / model files exist next to this module.

    Any file from the list below that is missing (or smaller than 1000 bytes)
    is downloaded from the Hugging Face dataset repo using ``HF_TOKEN`` from
    the environment. Afterwards ``index/bm25.db`` is rebuilt from the
    ``index/bm25_chunk_*`` pieces when it is missing or smaller than 1 MB.
    """
    _BASE = _Path(_os.path.abspath(__file__)).parent
    _TOKEN = _os.environ.get("HF_TOKEN")
    _DATASET = "hritikm15/kcc-data"
    _LARGE_FILES = [
        "index/faiss.index",
        "index/metadata.parquet",
        "index/bm25_chunk_aa",
        "index/bm25_chunk_ab",
        "index/bm25_chunk_ac",
        "index/bm25_chunk_ad",
        "mandi_advisor/presow_v4_model.pkl",
        "mandi_advisor/presow_v4_meta.pkl",
        "mandi_advisor/feature_store.parquet",
        "pest_model/pest_risk_model_district_v2.pkl",
    ]

    def _is_missing(_rel):
        _p = _BASE / _rel
        return not _p.exists() or _p.stat().st_size < 1000

    _missing = [f for f in _LARGE_FILES if _is_missing(f)]
    if _missing:
        print(f"[startup] {len(_missing)} large files missing — downloading from HF Dataset...", flush=True)
        # Imported lazily so a fully provisioned deployment does not need the package.
        from huggingface_hub import hf_hub_download
        for _repo_file in _missing:
            _local = _BASE / _repo_file
            _local.parent.mkdir(parents=True, exist_ok=True)
            print(f"[startup] Downloading {_repo_file}...", flush=True)
            hf_hub_download(
                repo_id=_DATASET,
                filename=_repo_file,
                repo_type="dataset",
                token=_TOKEN,
                local_dir=str(_BASE),
            )
            print(f"[startup] {_repo_file} done ({_local.stat().st_size // 1_000_000}MB)", flush=True)

    # Reassemble bm25.db from chunks if needed
    _BM25 = _BASE / "index" / "bm25.db"
    _CHUNKS = sorted((_BASE / "index").glob("bm25_chunk_*"))
    if _CHUNKS and (not _BM25.exists() or _BM25.stat().st_size < 1_000_000):
        print(f"[startup] Reassembling bm25.db from {len(_CHUNKS)} chunks...", flush=True)
        _reassemble_chunks(_CHUNKS, _BM25)
        print(f"[startup] bm25.db ready ({_BM25.stat().st_size // 1_000_000}MB)", flush=True)
_ensure_large_files()
# ─────────────────────────────────────────────────────────────────────────────
import io
import re
import sys
import time
from pathlib import Path
from typing import AsyncIterator, List, Optional
from fastapi import FastAPI, File, HTTPException, Response, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
sys.path.insert(0, str(Path(__file__).parent))
import config
from step3_retrieval import KCCRetriever, RetrievedDoc, get_retriever
from step4_app import (
SAFETY_GUARDRAILS,
LOW_CONF_THRESHOLD,
classify_problem,
detect_crop,
multi_step_retrieve,
normalize_query,
_build_prompt,
detect_language,
_diagnose_image_gemini,
_get_gemini_client,
)
# ── app ───────────────────────────────────────────────────────────────────────
app = FastAPI(
title="KCC Agricultural Chatbot API",
description=(
"Enterprise REST API over 16.5 M Kisan Call Center Q&A pairs. "
"Pest predictions: district-level stacking model (AUC 0.936), 1-month lead time. "
"Supports Hindi, English, Telugu, Kannada, Marathi."
),
version="2.0.0",
)
# Load config for allowed origins
import config as _cfg
import os
# ── JWT Authentication ─────────────────────────────────────────────────────────
import hashlib, secrets as _secrets
from datetime import timedelta as _timedelta, datetime as _datetime, timezone as _timezone
from typing import Optional as _Opt
try:
from jose import JWTError, jwt as _jwt
_JWT_AVAILABLE = True
except ImportError:
_JWT_AVAILABLE = False
# NOTE(review): when JWT_SECRET is unset a fresh random secret is generated at
# import time, i.e. once per process. Tokens then do not survive a restart and
# are not shared between processes started with ``--workers N``.
_JWT_SECRET = os.environ.get("JWT_SECRET", _secrets.token_hex(32))
_JWT_ALGO = "HS256"
_JWT_EXP_HOURS = 24  # token lifetime in hours
# username -> password table checked by POST /auth/token.
# NOTE(review): the fallback passwords are hard-coded in source and apply
# whenever the environment variables are not set — confirm both variables are
# always provided in deployed environments.
_B2B_USERS = {
    "enterprise": os.environ.get("B2B_DEMO_PASSWORD", "KCC@Demo2025"),
    "admin": os.environ.get("ADMIN_PASSWORD", "Admin@KCC2025"),
}
# Request body for POST /auth/token.
class _TokenRequest(BaseModel):
    username: str  # looked up as a key in _B2B_USERS
    password: str  # compared against the configured password for that user
# Response body for POST /auth/token.
class _TokenResponse(BaseModel):
    access_token: str  # value produced by _create_token()
    token_type: str = "bearer"
    expires_in: int = _JWT_EXP_HOURS * 3600  # lifetime in seconds
def _stdlib_jwt_encode(claims: dict) -> str:
    """Build a compact HS256 JWT signed with ``_JWT_SECRET`` using only the stdlib."""
    import base64
    import hashlib as _hashlib
    import hmac
    import json

    def _b64(raw: bytes) -> str:
        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")

    header = _b64(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode("utf-8"))
    body = _b64(json.dumps(claims, separators=(",", ":")).encode("utf-8"))
    signing_input = f"{header}.{body}".encode("ascii")
    digest = hmac.new(_JWT_SECRET.encode("utf-8"), signing_input, _hashlib.sha256).digest()
    return f"{header}.{body}.{_b64(digest)}"


def _stdlib_jwt_decode(token: str) -> _Opt[dict]:
    """Verify signature and expiry of a token from ``_stdlib_jwt_encode``.

    Returns the claims dict, or ``None`` for anything malformed, forged or expired.
    """
    import base64
    import hashlib as _hashlib
    import hmac
    import json

    try:
        header, body, signature = token.split(".")
        signing_input = f"{header}.{body}".encode("utf-8")
        digest = hmac.new(_JWT_SECRET.encode("utf-8"), signing_input, _hashlib.sha256).digest()
        expected = base64.urlsafe_b64encode(digest).rstrip(b"=")
        # Constant-time comparison on bytes (str inputs must be ASCII-only).
        if not hmac.compare_digest(signature.encode("utf-8"), expected):
            return None
        padded = body + "=" * (-len(body) % 4)
        claims = json.loads(base64.urlsafe_b64decode(padded.encode("ascii")))
    except (ValueError, TypeError, AttributeError):
        return None
    if not isinstance(claims, dict):
        return None
    exp = claims.get("exp")
    if not isinstance(exp, (int, float)) or exp < _datetime.now(_timezone.utc).timestamp():
        return None
    return claims


def _create_token(username: str) -> str:
    """Issue a signed access token for *username*, valid for ``_JWT_EXP_HOURS`` hours.

    Uses python-jose when it is installed. Otherwise falls back to a stdlib
    HS256 implementation instead of the former unsigned ``mock-token-<name>``
    string, which anyone could forge.
    """
    exp = _datetime.now(_timezone.utc) + _timedelta(hours=_JWT_EXP_HOURS)
    if _JWT_AVAILABLE:
        return _jwt.encode({"sub": username, "exp": exp}, _JWT_SECRET, algorithm=_JWT_ALGO)
    return _stdlib_jwt_encode({"sub": username, "exp": int(exp.timestamp())})


def _verify_token(token: str) -> _Opt[str]:
    """Return the username (``sub`` claim) of a valid token, else ``None``."""
    if not _JWT_AVAILABLE:
        claims = _stdlib_jwt_decode(token)
        return claims.get("sub") if claims else None
    try:
        payload = _jwt.decode(token, _JWT_SECRET, algorithms=[_JWT_ALGO])
        return payload.get("sub")
    except Exception:
        # Authentication must fail closed: any decode problem means "no user".
        return None
from fastapi import Depends, Header
from fastapi.security import HTTPBearer as _HTTPBearer, HTTPAuthorizationCredentials as _Creds
_security = _HTTPBearer(auto_error=False)
def _get_current_user(creds: _Opt[_Creds] = Depends(_security)) -> _Opt[str]:
    """Resolve the optional bearer credentials to a username.

    Returns ``None`` both when no Authorization header was sent and when the
    supplied token does not verify; this dependency never raises.
    """
    return None if creds is None else _verify_token(creds.credentials)
def _require_b2b(user: _Opt[str] = Depends(_get_current_user)):
    """Dependency for B2B-only routes: pass the username through or answer 401."""
    if user is not None:
        return user
    raise HTTPException(
        status_code=401,
        detail="B2B endpoint requires authentication. POST /auth/token to get a token.",
    )
_ALLOWED_ORIGINS = _cfg.ALLOWED_ORIGINS or ["http://localhost:8501", "http://localhost:8000"]
app.add_middleware(
CORSMiddleware,
allow_origins=_ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["Content-Type", "Authorization", "X-API-Key", "X-Request-ID"],
expose_headers=["X-Request-ID"],
)
# ── Rate limiting (slowapi) ──────────────────────────────────────────────────
# Soft-import so the app boots even when slowapi isn't installed.
try:
    from slowapi import Limiter
    from slowapi.util import get_remote_address
    from slowapi.errors import RateLimitExceeded
    from fastapi import Request as _SlowReq
    _RATE_LIMIT_OK = True
except ImportError:
    # slowapi not installed: the app runs without any rate limiting.
    _RATE_LIMIT_OK = False
if _RATE_LIMIT_OK:
    # Limits are keyed on the client address; both defaults can be overridden
    # through the RATE_LIMIT_DEFAULT / RATE_LIMIT_HOUR environment variables.
    # NOTE(review): this section only creates the limiter and registers the 429
    # handler. No SlowAPIMiddleware or @_limiter.limit decorator is visible
    # here — confirm the default limits are actually enforced elsewhere.
    _limiter = Limiter(
        key_func=get_remote_address,
        default_limits=[
            os.environ.get("RATE_LIMIT_DEFAULT", "120/minute"),
            os.environ.get("RATE_LIMIT_HOUR", "2000/hour"),
        ],
    )
    app.state.limiter = _limiter
    @app.exception_handler(RateLimitExceeded)
    async def _rate_limit_handler(request: _SlowReq, exc: RateLimitExceeded):
        # Translate slowapi's exception into a JSON 429 response.
        return JSONResponse(status_code=429,
        content={"detail": f"Rate limit exceeded: {exc.detail}"})
# ── Request-ID middleware ────────────────────────────────────────────────────
# Attaches a UUID to every request and echoes it back via X-Request-ID
# header. Lets ops correlate logs across services.
import uuid as _uuid
from fastapi import Request as _ReqIDReq
@app.middleware("http")
async def _request_id_middleware(request: _ReqIDReq, call_next):
    """Tag each request with an ID and echo it back in ``X-Request-ID``.

    A non-empty incoming ``X-Request-ID`` header is reused as-is; otherwise a
    new UUID4 string is generated. The ID is also stored on ``request.state``.
    """
    rid = request.headers.get("X-Request-ID")
    if not rid:
        rid = str(_uuid.uuid4())
    request.state.request_id = rid

    downstream = await call_next(request)
    downstream.headers["X-Request-ID"] = rid
    return downstream
# ── singleton retriever ───────────────────────────────────────────────────────
# Process-wide retriever instance, created on first use.
_retriever: Optional[KCCRetriever] = None


def _get_retriever() -> KCCRetriever:
    """Return the shared retriever, building it from the configured files on first call."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = get_retriever(
        index_path=config.FAISS_INDEX_FILE,
        metadata_path=config.METADATA_FILE,
    )
    return _retriever
# ── ★ Agriculture topic guard ────────────────────────────────────────────────
# Fast regex + signal word check. Blocks non-agriculture queries before
# expensive FAISS retrieval and LLM generation.
_AGRI_SIGNALS = {
# Crops (incl. inflected forms — fixes "sown"/"grown"/"growing" false off-topic)
"crop","crops","plant","plants","planting","seed","seeds","soil","fertilizer",
"pesticide","pest","disease","farm","farmer","farming","kisan","kheti","fasal",
"gehu","gehun","dhan","kapas","tamatar","aloo","pyaz","mandi","bhav","rate",
"harvest","harvesting","sow","sown","sowing","grow","grows","growing","grown",
"variety","varieties","cultivar","cultivars","season","seasonal","kharif","rabi","zaid",
"irrigation","spray","insecticide","fungicide","organic","yield","yields","blight",
"wheat","rice","paddy","cotton","maize","sugarcane","soybean","mustard",
"chilli","brinjal","onion","tomato","potato","groundnut","gram","pulses","pulse",
"barley","jowar","bajra","ragi","arhar","tur","moong","urad","masoor","chana",
# Pests / diseases
"aphid","borer","mildew","rust","wilt","thrips","mite","whitefly",
"caterpillar","jassid","leaf spot","mosaic","virus","rot","fungus",
"blast","armyworm","bollworm","helicoverpa","spodoptera","girdle",
# Hindi / regional
"khaad","dawai","beej","sinchai","pattiya","pattiyan","keeda","bimari",
"rog","upchar","khet","paidavar","safed makhi","tela","mahu","tikda",
# Weather for crops
"pala","thand","frost","baarish","drought","sukha","heat stress",
"ola","flood","andhi",
# Schemes / advisory
"kvk","icar","pm kisan","fasal bima","kcc","kisan call","advisory",
# Chemicals (banned-chemical questions ARE agri queries)
"endosulfan","monocrotophos","chlorpyrifos","imidacloprid","emamectin",
"mancozeb","propiconazole","thiamethoxam","rhizobium","urea","dap",
# Devanagari script keywords (Indian farmers typing in Hindi)
"फसल","खेत","किसान","खाद","बीज","धान","गेहूं","कपास",
"मक्का","सरसों","चना","अरहर","मूंग","प्याज","आलू",
"कीड़ा","बीमारी","दवाइ","स्प्रे","उपचार","सिंचाई",
}
_NON_AGRI_RE = re.compile(
r"\b(stock market|share market|bitcoin|crypto|politics|election|"
r"movie|cricket|football|recipe|cooking|exam|bank account|insurance claim|"
r"marriage|divorce|job|salary|relationship|celebrity|news|tv show|"
r"web series|ipl|bollywood|actor|actress|pakistan|china|war|army|"
r"love|dating|girlfriend|boyfriend|password|hack|code)\b",
re.IGNORECASE,
)
_SHORT_FOLLOWUP_RE = re.compile(r"^(ha|haan|yes|ok|okay|nahi|no|theek|accha|sahi|nahi ji|ji|acha)\b", re.IGNORECASE)
def is_agriculture_query(query: str) -> bool:
    """Return True when *query* should be treated as agriculture-related.

    Decision order:
      1. Queries of at most three words pass unless they contain a hard
         non-agriculture keyword (covers contextual replies such as "yes"/"ha").
      2. Queries containing Devanagari characters pass unless they contain a
         hard non-agriculture keyword.
      3. Everything else passes only if it contains an agriculture signal.
         An agriculture signal overrides a non-agriculture keyword
         (e.g. "cotton futures").

    Signals in ``_AGRI_SIGNALS`` may be single words or multi-word phrases
    ("leaf spot", "pm kisan"). Phrases are matched against the
    whitespace-normalised token sequence, because a plain per-word set
    intersection can never match an entry that contains a space.
    """
    stripped = query.strip()
    q_lower = stripped.lower()
    has_non_agri = bool(_NON_AGRI_RE.search(q_lower))

    # ("best Bollywood movies" is 3 words — must NOT bypass.)
    if len(stripped.split()) <= 3 and not has_non_agri:
        return True

    # Devanagari block U+0900..U+097F
    if any("\u0900" <= c <= "\u097f" for c in stripped):
        return not has_non_agri

    tokens = re.findall(r"\b\w+\b", q_lower)
    if not set(tokens).isdisjoint(_AGRI_SIGNALS):
        return True

    # Multi-word signals: look for the phrase on token boundaries.
    padded = " " + " ".join(tokens) + " "
    return any(" " in sig and f" {sig} " in padded for sig in _AGRI_SIGNALS)
_OFF_TOPIC_MSG = (
"I can only help with agriculture-related questions — crops, pests, diseases, "
"fertilizers, mandi prices, irrigation, seeds, and farming advice. "
"Please ask a farming question and I'll be happy to help! 🌾\n\n"
"Main sirf kheti-baadi, fasal, keede-makode, bimari, khaad, mandi bhav, "
"aur kisan-sambandhit sawalon ka jawab de sakta hoon."
)
# ── request / response schemas ────────────────────────────────────────────────
# Request body shared by POST /query and POST /query/stream.
class QueryRequest(BaseModel):
    query: str = Field(..., min_length=2, description="Farmer's question (any language)")
    # top_k / min_score / deduplicate are forwarded to retrieval as a settings dict.
    top_k: int = Field(5, ge=1, le=20)
    min_score: float = Field(0.0, ge=0.0, le=1.0)
    deduplicate: bool = Field(True)
    # Prior conversation turns, passed unchanged to the prompt builder.
    history: List[dict] = Field(default_factory=list)
    state: str = Field("", description="Farmer's state (e.g. Madhya Pradesh) — boosts regional results")
    district: str = Field("", description="Farmer's district (e.g. Barwani) — boosts hyper-local results")
# One retrieved Q&A pair as exposed to API clients (built by _docs_to_schema).
class SourceDoc(BaseModel):
    rank: int
    score: float  # retrieval score, rounded to 4 decimals by _docs_to_schema
    query: str
    answer: str
    crop: Optional[str]
    state: Optional[str]
    year: Optional[int]
# Response body of POST /query.
class QueryResponse(BaseModel):
    answer: str
    sources: List[SourceDoc]
    detected_crop: Optional[str]
    problem_type: str  # "off_topic" when the topic guard rejected the query
    low_confidence: bool  # top retrieval score was below LOW_CONF_THRESHOLD
    off_topic: bool
    retrieval_ms: float
    generation_ms: float
    model_backend: str  # "guard" for topic-guard rejections
    safety_warnings: List[str] = Field(default_factory=list)
# Response body of POST /diagnose.
class DiagnoseResponse(BaseModel):
    crop: str
    condition: str
    problem_type: str
    confidence: str
    visible_symptoms: Optional[str]
    treatment_advice: Optional[str]  # None when no treatment text was generated
    sources: List[SourceDoc]
# ── ★ Pest risk request/response schemas ─────────────────────────────────────
# Request body of POST /pest-risk.
class PestRiskRequest(BaseModel):
    state: str = Field(..., description="Indian state name (English)")
    # An empty district makes /pest-risk require an authenticated B2B user.
    district: str = Field("", description="District name (optional — uses spatial fallback)")
    crop: str = Field(..., description="Crop name (Hindi or English)")
    month: Optional[int] = Field(None, ge=1, le=12, description="Month (1-12). Default: current month.")
    year: Optional[int] = Field(None, description="Year. Default: current year.")
# One pest prediction row inside PestRiskResponse.predictions.
class PestRiskResult(BaseModel):
    pest: str
    pest_cat: str
    risk_score: int # 0-100
    risk_level: str # NEGLIGIBLE / LOW / MEDIUM / HIGH / CRITICAL
    confidence: str # Very High / High / Moderate / Low / Very Low
    confidence_tier: Optional[dict] = None # {label, color, auc, note}
    model_auc: float
    model_score: Optional[float] = None # raw model probability
    model_version: str
    recommended_action: str
    spray: str
    weather_driver: str
    history_note: str
    feature_drivers: List[str]
    growth_stage: str
    ndvi_index: int
# Response body of POST /pest-risk.
class PestRiskResponse(BaseModel):
    state: str
    district: str  # "auto (spatial fallback)" when the request had no district
    crop: str
    month: int
    predictions: List[PestRiskResult]
    model_version: str  # taken from the first prediction, "unknown" if there are none
    lead_time_note: str
    retrieval_ms: float  # time spent in the predictor call, in milliseconds
# ── helpers ───────────────────────────────────────────────────────────────────
def _request_to_settings(req: QueryRequest) -> dict:
    """Extract the retrieval knobs of *req* into the settings dict used by retrieval."""
    wanted = ("top_k", "min_score", "deduplicate")
    return {name: getattr(req, name) for name in wanted}
def _docs_to_schema(docs: List[RetrievedDoc]) -> List[SourceDoc]:
    """Convert retrieved documents into API ``SourceDoc`` objects (score rounded to 4 dp)."""
    converted = []
    for doc in docs:
        converted.append(
            SourceDoc(
                rank=doc.rank,
                score=round(doc.score, 4),
                query=doc.query,
                answer=doc.answer,
                crop=doc.crop,
                state=doc.state,
                year=doc.year,
            )
        )
    return converted
def _get_llm_backend() -> str:
    """Name of the LLM backend reported for the legacy generation path (fixed: Groq)."""
    backend_name = "groq"
    return backend_name
def _trim_prompt_for_groq(prompt: str, max_chars: int = 22000) -> str:
    """Shrink an over-long prompt so it fits the Groq per-minute token budget.

    Prompts of at most *max_chars* characters are returned unchanged. Longer
    prompts are cut in the MIDDLE: everything from the last
    ``FARMER'S QUESTION:`` marker onwards is kept intact, and the leading part
    (instructions + retrieved context) is shortened. At least 1000 leading
    characters are kept when that much text exists before the marker, so the
    result may exceed *max_chars* if the question section alone is very long.
    When the marker is absent the prompt is simply cut to *max_chars*.

    Budget rationale from the original author: ~4 chars/token, so 22000 chars
    is roughly 5.5K tokens, under the smallest free-tier model limit.
    """
    if len(prompt) <= max_chars:
        return prompt
    q_marker = "FARMER'S QUESTION:"
    q_idx = prompt.rfind(q_marker)
    if q_idx == -1:
        # No marker — just chop the tail
        return prompt[:max_chars]
    tail = prompt[q_idx:]
    # Leave 200 chars of headroom for the truncation notice; keep >= 1000 chars of head.
    head_budget = max(max_chars - len(tail) - 200, 1000)
    # Never let the head run into the tail, or the question would appear twice.
    head_budget = min(head_budget, q_idx)
    head = prompt[:head_budget]
    trimmed = head + "\n\n[...context truncated to fit Groq free-tier budget...]\n\n" + tail
    print(f"[_trim_prompt] truncated {len(prompt)} -> {len(trimmed)} chars", flush=True)
    return trimmed
def _generate_full(prompt: str) -> str:
    """Generate a complete (non-streamed) answer, trying backends in order.

    Cascade: Groq primary model -> llama-3.3-70b -> llama-3.1-8b-instant ->
    Gemini -> local ``kcc_llm``. The first backend that returns non-empty text
    wins. Every failure is logged and the next tier is tried, including a
    failure to import or initialise the Groq SDK. If nothing produces text, a
    fixed "Service temporarily unavailable" message is returned, so this
    function does not raise for backend errors.
    """
    import time as _time

    prompt = _trim_prompt_for_groq(prompt)

    _GROQ_MODELS = [
        getattr(config, "GROQ_MODEL_PRIMARY",
                "meta-llama/llama-4-scout-17b-16e-instruct"),
        "llama-3.3-70b-versatile",
        "llama-3.1-8b-instant",  # was: gemma2-9b-it (decommissioned 2026-05)
    ]
    try:
        import groq as _groq
        gc = _groq.Groq(api_key=config.GROQ_API_KEY)
    except Exception as exc:
        # Missing SDK / API key must not prevent the Gemini and local fallbacks.
        print(f"[_generate_full] groq init failed: {type(exc).__name__}: {str(exc)[:200]}", flush=True)
        gc = None
        _GROQ_MODELS = []

    for model in _GROQ_MODELS:
        try:
            resp = gc.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=600,
                temperature=0.1,
            )
            txt = resp.choices[0].message.content or ""
            if txt.strip():
                return txt
        except Exception as exc:
            print(f"[_generate_full] groq {model} failed: {type(exc).__name__}: {str(exc)[:200]}", flush=True)
            # Brief pause on rate limiting before moving to the next model.
            if "429" in str(exc) or "rate" in str(exc).lower():
                _time.sleep(2)

    # Gemini fallback (genai v1 style)
    try:
        import google.generativeai as genai
        genai.configure(api_key=config.GEMINI_API_KEY)
        for gm in ["gemini-2.0-flash", "gemini-1.5-flash"]:
            try:
                resp = genai.GenerativeModel(gm).generate_content(
                    prompt,
                    generation_config={"max_output_tokens": 600, "temperature": 0.1},
                )
                txt = resp.text or ""
                # An empty reply is treated like a failure, same as for Groq.
                if txt.strip():
                    return txt
            except Exception as exc:
                print(f"[_generate_full] gemini {gm} failed: {type(exc).__name__}: {str(exc)[:200]}", flush=True)
    except Exception as exc:
        print(f"[_generate_full] gemini init failed: {type(exc).__name__}: {str(exc)[:200]}", flush=True)

    # Last resort: local fine-tuned Llama
    try:
        from kcc_llm import generate
        return generate(prompt)
    except Exception as exc:
        print(f"[_generate_full] local kcc_llm failed: {type(exc).__name__}: {str(exc)[:200]}", flush=True)

    return "Service temporarily unavailable. Please try again."
@app.post("/auth/token", response_model=_TokenResponse, tags=["Auth"])
def login(req: _TokenRequest):
    """Get JWT token for B2B dashboard and enterprise API access.

    Responds 401 "Invalid credentials" for an unknown user, a user without a
    configured password, or a wrong password.
    """
    expected = _B2B_USERS.get(req.username)
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    # Unknown users are compared against an empty reference so the comparison
    # still runs; they are rejected by the ``not expected`` check below.
    password_ok = _secrets.compare_digest(
        req.password.encode("utf-8"),
        (expected or "").encode("utf-8"),
    )
    if not expected or not password_ok:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    token = _create_token(req.username)
    return _TokenResponse(access_token=token, expires_in=_JWT_EXP_HOURS * 3600)
@app.get("/auth/me", tags=["Auth"])
def me(user: str = Depends(_require_b2b)):
    """Describe the authenticated B2B user (401 via the dependency otherwise)."""
    granted = ["pest-risk", "price-forecast", "advisory"]
    return dict(username=user, role="b2b_enterprise", access=granted)
@app.get("/health", tags=["System"])
def health(response: Response):
    """Liveness + readiness probe.

    Answers 200 with status "ok" once the retriever can be obtained, and 503
    with status "warming_up" (same payload shape) when obtaining it fails.
    """
    try:
        ready = _get_retriever() is not None
    except Exception:
        ready = False

    body = {
        "status": "ok" if ready else "warming_up",
        "retriever_ready": ready,
        "llm_backend": _get_llm_backend(),
        "groq_model": config.GROQ_MODEL_PRIMARY,
        "gemini_model": config.GEMINI_MODEL,
        "api_keys_set": bool(config.GROQ_API_KEY and config.GEMINI_API_KEY),
        "version": "2.0.0",
        "environment": config.ENVIRONMENT,
        "disclaimer": "AI advisory only — always consult a qualified agronomist before acting.",
    }
    if ready:
        return body
    return JSONResponse(content=body, status_code=503)
@app.post("/query", response_model=QueryResponse, tags=["Chat"])
def query(req: QueryRequest):
    """
    Main chat endpoint.

    - Topic guard: rejects non-agriculture questions immediately (no LLM cost)
    - Multi-step retrieval over the KCC Q&A index
    - Generation through ``kcc_core.llm`` (reports the backend that actually
      answered), falling back to the legacy ``_generate_full`` cascade when
      that path fails or returns no text
    - Post-generation citation / safety review; its warnings are returned in
      ``safety_warnings``

    Raises HTTP 502 only if the legacy fallback itself raises.
    """
    # ── ★ Topic guard ─────────────────────────────────────────────────────────
    if not is_agriculture_query(req.query):
        return QueryResponse(
            answer=_OFF_TOPIC_MSG,
            sources=[],
            detected_crop=None,
            problem_type="off_topic",
            low_confidence=False,
            off_topic=True,
            retrieval_ms=0.0,
            generation_ms=0.0,
            model_backend="guard",
        )

    retriever = _get_retriever()
    settings = _request_to_settings(req)
    detected_crop = detect_crop(req.query)
    problem_type = classify_problem(req.query)
    normalized_q = normalize_query(req.query)

    t0 = time.perf_counter()
    docs = multi_step_retrieve(retriever, req.query, normalized_q,
                               detected_crop, problem_type, settings,
                               state=req.state, district=req.district)
    ret_ms = (time.perf_counter() - t0) * 1000

    top_score = docs[0].score if docs else 0.0
    low_confidence = top_score < LOW_CONF_THRESHOLD

    # Steering lines placed in front of the retrieved context.
    context = retriever.format_context(docs)
    meta_lines = []
    if detected_crop:
        meta_lines.append(f"DETECTED CROP: {detected_crop}")
    if problem_type != "general":
        meta_lines.append(f"PROBLEM TYPE: {problem_type.upper()}")
    safety_note = SAFETY_GUARDRAILS.get(problem_type, "")
    if safety_note:
        meta_lines.append(safety_note)
    if low_confidence:
        meta_lines.append(
            f"LOW CONFIDENCE (top score: {top_score*100:.0f}%) — "
            "ask the farmer 1-2 clarifying questions instead of guessing."
        )
    if meta_lines:
        context = "\n".join(meta_lines) + "\n\n" + context

    _lang = detect_language(req.query)
    prompt = _build_prompt(req.query, context, req.history, problem_type, _lang)
    prompt = _trim_prompt_for_groq(prompt)

    t1 = time.perf_counter()
    # v4 merge: route through kcc_core.llm so we get the actual backend that
    # answered. Only the LLM module is imported here; the citation guard is
    # imported separately below so that a missing guard cannot disable this path.
    answer = ""
    backend = ""
    try:
        from kcc_core.llm import generate_with_meta
        _out = generate_with_meta(prompt, max_tokens=600, temperature=0.1)
        answer = _out["text"]
        backend = _out["backend"]
    except Exception as exc:
        print(f"[query] kcc_core.llm failed, using legacy cascade: {exc!r}", flush=True)
        answer = ""

    if not answer:
        # kcc_core failed or every tier returned nothing — fall back to v2's older path
        try:
            answer = _generate_full(prompt)
            backend = _get_llm_backend()
        except Exception as exc2:
            raise HTTPException(status_code=502, detail=f"LLM error: {exc2}") from exc2
    gen_ms = (time.perf_counter() - t1) * 1000

    # Post-generation safety pass: strike banned chemicals, hard overrides,
    # citation enforcement. Negation-aware (do NOT use Endosulfan ⇒ no flag).
    # Best-effort: if the guard is unavailable the raw answer is returned.
    try:
        from kcc_core.citation_guard import review as _cite_review
        answer, safety_warnings = _cite_review(answer, problem_type=problem_type)
    except Exception as exc:
        print(f"[query] citation guard skipped: {exc!r}", flush=True)
        safety_warnings = []

    return QueryResponse(
        answer=answer,
        sources=_docs_to_schema(docs),
        detected_crop=detected_crop,
        problem_type=problem_type,
        low_confidence=low_confidence,
        off_topic=False,
        retrieval_ms=round(ret_ms, 1),
        generation_ms=round(gen_ms, 1),
        model_backend=backend,
        safety_warnings=safety_warnings,
    )
# Terminal SSE event understood by the chat clients.
_STREAM_DONE_EVENT = "data: [DONE]\n\n"


def _stream_sse_event(text: str) -> str:
    """Frame *text* as one SSE ``data:`` event; real newlines become a literal backslash-n."""
    return "data: " + text.replace(chr(10), chr(92) + "n") + chr(10) + chr(10)


def _stream_safety_event(answer_text: str, problem_type: str) -> Optional[str]:
    """Run the citation guard on the full answer; return a ``SAFETY:`` event or None.

    Best-effort: an unavailable or failing guard is logged and yields no event.
    """
    try:
        from kcc_core.citation_guard import review as _cite_review
        import json as _json
        _filtered, _warnings = _cite_review(answer_text, problem_type=problem_type)
        if _warnings:
            return "data: SAFETY:" + _json.dumps(_warnings) + chr(10) + chr(10)
    except Exception as _e:
        print(f"[query/stream] citation guard skipped: {_e!r}", flush=True)
    return None


def _stream_prepare_prompt(req: QueryRequest):
    """Do retrieval for *req* and build the LLM prompt; returns ``(prompt, problem_type)``."""
    retriever = _get_retriever()
    settings = _request_to_settings(req)
    detected_crop = detect_crop(req.query)
    problem_type = classify_problem(req.query)
    normalized_q = normalize_query(req.query)
    docs = multi_step_retrieve(retriever, req.query, normalized_q,
                               detected_crop, problem_type, settings,
                               state=req.state, district=req.district)
    top_score = docs[0].score if docs else 0.0
    low_confidence = top_score < LOW_CONF_THRESHOLD

    context = retriever.format_context(docs)
    meta_lines = []
    if detected_crop:
        meta_lines.append(f"DETECTED CROP: {detected_crop}")
    if problem_type != "general":
        meta_lines.append(f"PROBLEM TYPE: {problem_type.upper()}")
    safety_note = SAFETY_GUARDRAILS.get(problem_type, "")
    if safety_note:
        meta_lines.append(safety_note)
    if low_confidence:
        meta_lines.append(
            f"LOW CONFIDENCE (top score: {top_score*100:.0f}%) — ask clarifying questions."
        )
    if meta_lines:
        context = "\n".join(meta_lines) + "\n\n" + context

    _lang = detect_language(req.query)
    prompt = _build_prompt(req.query, context, req.history, problem_type, _lang)
    return _trim_prompt_for_groq(prompt), problem_type


@app.post("/query/stream", tags=["Chat"])
async def query_stream(req: QueryRequest):
    """Streaming chat (Server-Sent Events). Topic-guarded like /query.

    Event protocol: zero or more ``data: <text>`` events (newlines escaped as a
    literal backslash-n), optionally one ``data: SAFETY:<json list>`` event,
    then ``data: [DONE]``.

    Backend order: Groq token streaming (three models), then the non-streaming
    ``kcc_core`` cascade, then ``_generate_full``, then Gemini streaming. A
    later backend is only tried while NOTHING has been sent to the client, so a
    stream that breaks half-way is never followed by a second, different answer.

    The event generator is a plain (sync) generator: Starlette iterates sync
    iterables in a worker thread, so the blocking SDK calls do not stall the
    event loop. Retrieval runs in the thread pool for the same reason.
    """
    # ── ★ Topic guard ─────────────────────────────────────────────────────────
    if not is_agriculture_query(req.query):
        return StreamingResponse(
            iter([_stream_sse_event(_OFF_TOPIC_MSG), _STREAM_DONE_EVENT]),
            media_type="text/event-stream",
        )

    from fastapi.concurrency import run_in_threadpool
    prompt, problem_type = await run_in_threadpool(_stream_prepare_prompt, req)

    def _event_stream():
        # Everything sent so far; also the input for the final safety review.
        _parts: list[str] = []

        # ── 1. Groq token streaming ──────────────────────────────────────────
        try:
            import groq as _groq
            _models = [
                getattr(config, "GROQ_MODEL_PRIMARY", "meta-llama/llama-4-scout-17b-16e-instruct"),
                "llama-3.3-70b-versatile",
                "llama-3.1-8b-instant",  # was: gemma2-9b-it (decommissioned 2026-05)
            ]
            gc = _groq.Groq(api_key=config.GROQ_API_KEY)
        except Exception as _e:
            print(f"[query/stream] groq init failed: {_e!r}", flush=True)
            _models, gc = [], None

        for _gm in _models:
            try:
                stream = gc.chat.completions.create(
                    model=_gm,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=600,
                    stream=True,
                )
                for chunk in stream:
                    token = (chunk.choices[0].delta.content or "") if chunk.choices else ""
                    if token:
                        _parts.append(token)
                        yield _stream_sse_event(token)
            except Exception as _e:
                print(f"[query/stream] groq stream {_gm} failed: {_e!r}", flush=True)
            if _parts:
                # Text is already on the wire (complete or partial): retrying
                # with another model would append a second answer.
                break

        # ── 2. Non-streaming fallbacks, re-chunked into 48-char events ───────
        if not _parts:
            _text = ""
            try:
                from kcc_core.llm import generate_with_meta as _kcc_gen
                _text = (_kcc_gen(prompt, max_tokens=600) or {}).get("text", "") or ""
            except Exception as _e:
                print(f"[query/stream] groq non-stream fallback failed: {_e!r}", flush=True)
            if not _text:
                # Same v2 path that /query uses when kcc_core returns empty.
                try:
                    _candidate = _generate_full(prompt)
                    if _candidate and _candidate.strip() and "Service temporarily" not in _candidate:
                        _text = _candidate
                except Exception as _e:
                    print(f"[query/stream] _generate_full fallback failed: {_e!r}", flush=True)
            if _text:
                _parts.append(_text)
                for _i in range(0, len(_text), 48):
                    yield _stream_sse_event(_text[_i:_i + 48])

        # ── 3. Gemini streaming via kcc_core (new google.genai SDK with .models) ─
        if not _parts:
            try:
                from kcc_core.llm import _get_gemini_client as _kcc_gemini_client
                client = _kcc_gemini_client()
                if client is None or not hasattr(client, "models"):
                    raise RuntimeError("gemini_unavailable")
                for chunk in client.models.generate_content_stream(
                    model=config.GEMINI_MODEL, contents=prompt
                ):
                    _piece = getattr(chunk, "text", None)
                    if _piece:
                        _parts.append(_piece)
                        yield _stream_sse_event(_piece)
            except Exception as _e:
                # Logged only — never leak raw exception text to the user.
                print(f"[query/stream] gemini stream failed: {_e!r}", flush=True)

        # ── 4. Close the stream ──────────────────────────────────────────────
        if not _parts:
            yield _stream_sse_event(
                "Service temporarily unavailable. Please try again in a moment, "
                "or rephrase your question."
            )
        else:
            # Safety warnings come as a final event (older clients ignore it).
            _safety = _stream_safety_event("".join(_parts), problem_type)
            if _safety:
                yield _safety
        yield _STREAM_DONE_EVENT

    return StreamingResponse(_event_stream(), media_type="text/event-stream")
@app.post("/pest-risk", response_model=PestRiskResponse, tags=["Pest Prediction"])
def pest_risk(req: PestRiskRequest, user: _Opt[str] = Depends(_get_current_user)):
    """
    District-level pest risk prediction.

    Returns 1-month early warning: predictions represent expected pest pressure
    NEXT month based on current/upcoming weather conditions.
    Model: LGB + XGB + CatBoost stacking ensemble (AUC 0.936)
    Coverage: 475 districts across 28 states.
    Uncovered districts: automatic spatial fallback to nearest covered district.

    Raises HTTP 401 for a state-only request without a valid token and
    HTTP 500 when the predictor raises.

    NOTE(review): ``req.month`` is only echoed in the response and ``req.year``
    is not used at all — neither is passed to ``predict_pest_risk``. Confirm
    whether the predictor is meant to receive them.
    """
    import datetime
    t0 = time.perf_counter()

    # ── B2B aggregate guard ────────────────────────────────────────────────────
    # State-only queries (no district) return state-wide aggregates — those are
    # the high-value bulk views B2B clients pay for. Require auth. Single-block
    # queries (with district) stay open for retail / consumer demo use.
    if not req.district and user is None:
        raise HTTPException(
            status_code=401,
            detail="State-only pest aggregates require B2B authentication. "
                   "POST /auth/token to get a token, or supply 'district' for "
                   "a single-block lookup.",
        )

    month = req.month or datetime.datetime.now().month

    try:
        from mandi_advisor.pest_predictor import predict_pest_risk
        raw_results = predict_pest_risk(
            state=req.state,
            crop=req.crop,
            district=req.district or None,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Pest prediction error: {e}") from e
    ms = (time.perf_counter() - t0) * 1000

    # A predictor that returns None is treated as "no predictions".
    raw_results = raw_results or []

    predictions = [
        PestRiskResult(
            pest=r.get("pest", ""),
            pest_cat=r.get("pest_cat", ""),
            risk_score=r.get("risk_score", 0),
            risk_level=r.get("risk_level", "LOW"),
            confidence=r.get("confidence", "Low"),
            confidence_tier=r.get("confidence_tier"),
            model_auc=r.get("model_auc", 0.0),
            model_score=r.get("model_score"),
            model_version=r.get("model_version", ""),
            recommended_action=r.get("recommended_action", ""),
            spray=r.get("spray", ""),
            weather_driver=r.get("weather_driver", ""),
            history_note=r.get("history_note", ""),
            feature_drivers=r.get("feature_drivers", []),
            growth_stage=r.get("growth_stage", ""),
            ndvi_index=r.get("ndvi_index", 50),
        )
        for r in raw_results
    ]

    model_ver = raw_results[0].get("model_version", "unknown") if raw_results else "unknown"
    return PestRiskResponse(
        state=req.state,
        district=req.district or "auto (spatial fallback)",
        crop=req.crop,
        month=month,
        predictions=predictions,
        model_version=model_ver,
        lead_time_note=(
            "v2: Predictions represent NEXT month's pest risk based on current weather. "
            "1-month early warning allows preventive spray before damage occurs."
            if "v2" in model_ver else
            "v1: Current-month pest risk. Upgrade to v2 for 1-month early warning."
        ),
        retrieval_ms=round(ms, 1),
    )
@app.post("/diagnose", response_model=DiagnoseResponse, tags=["Image Diagnosis"])
async def diagnose(image: UploadFile = File(...)):
    """Plant disease diagnosis from crop photo (Gemini Vision).

    Steps: size check (max 10 MB) -> vision diagnosis -> for non-healthy
    plants, retrieval of matching KCC answers and an LLM treatment summary.

    Raises HTTP 413 for oversized uploads and HTTP 502 when the vision model
    fails or returns no crop.
    """
    max_bytes = 10 * 1024 * 1024
    # Read at most one byte past the limit so an oversized upload is rejected
    # without being loaded into memory in full.
    image_bytes = await image.read(max_bytes + 1)
    if len(image_bytes) > max_bytes:
        raise HTTPException(status_code=413, detail="Image too large (max 10 MB)")

    retriever = _get_retriever()

    try:
        diagnosis = _diagnose_image_gemini(image_bytes)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Vision model error: {exc}") from exc
    if not diagnosis or not diagnosis.get("crop"):
        raise HTTPException(status_code=502, detail="Vision model returned no diagnosis")

    # ``or`` instead of a .get() default: a key that is present with a null
    # value would otherwise crash .lower() / fail response validation.
    crop_name = diagnosis.get("crop") or "Unknown"
    condition = diagnosis.get("condition") or "Unknown"
    problem_type = diagnosis.get("problem_type") or "disease"
    confidence = diagnosis.get("confidence") or "Medium"
    symptoms = diagnosis.get("visible_symptoms") or ""

    treatment_answer = None
    docs: list[RetrievedDoc] = []
    if problem_type != "healthy" and condition.lower() != "healthy":
        auto_query = f"{crop_name} {condition} control treatment"
        settings = {"top_k": 5, "min_score": 0.0, "deduplicate": True}
        docs = multi_step_retrieve(retriever, auto_query, normalize_query(auto_query),
                                   crop_name, problem_type, settings)
        if docs:
            context = retriever.format_context(docs)
            safety = SAFETY_GUARDRAILS.get(problem_type, "")
            meta = (
                f"DETECTED CROP: {crop_name}\nDETECTED CONDITION: {condition} "
                f"(confidence: {confidence})\nVISIBLE SYMPTOMS: {symptoms}\n"
                f"PROBLEM TYPE: {problem_type.upper()}\n{safety}\n\n"
            )
            prompt = _build_prompt(
                f"My {crop_name} has {condition}. What should I do?",
                meta + context, [], problem_type, "English"
            )
            try:
                treatment_answer = _generate_full(prompt)
            except Exception as exc:
                # Treatment text is optional; the diagnosis itself is still returned.
                print(f"[diagnose] treatment generation failed: {exc!r}", flush=True)
                treatment_answer = None

    return DiagnoseResponse(
        crop=crop_name, condition=condition, problem_type=problem_type,
        confidence=confidence, visible_symptoms=symptoms,
        treatment_advice=treatment_answer, sources=_docs_to_schema(docs),
    )
# ── /metrics — model status & performance snapshot ───────────────────────────
import datetime as _dt
# Timezone-aware UTC timestamp captured once when this module is imported;
# the /metrics endpoint subtracts it from "now" to report uptime.
_START_TIME = _dt.datetime.now(_dt.timezone.utc)
# Counters returned as-is under the "requests" key of /metrics.
# NOTE(review): the code that increments these is not in this section —
# confirm they are actually updated somewhere.
_REQUEST_COUNTS: dict = {"query": 0, "pest_risk": 0, "diagnose": 0, "errors": 0}
@app.get("/metrics", tags=["System"])
def metrics():
    """Report uptime, request counters, retriever status and model names.

    The retriever probe never fails the endpoint: any exception raised while
    obtaining the retriever is reported as ``retriever_loaded: False``.
    """
    now_utc = _dt.datetime.now(_dt.timezone.utc)
    uptime_s = (now_utc - _START_TIME).total_seconds()

    try:
        retriever_ok = _get_retriever() is not None
    except Exception:
        retriever_ok = False

    whole_hours, leftover_s = divmod(uptime_s, 3600)
    model_names = {
        "llm_primary": config.GROQ_MODEL_PRIMARY,
        "llm_fallback": config.GROQ_MODEL_FALLBACK,
        "gemini": config.GEMINI_MODEL,
        "reranker": config.RERANKER_MODEL,
        "embedding": config.EMBEDDING_MODEL,
    }
    return {
        "status": "ok",
        "uptime_seconds": round(uptime_s),
        "uptime_human": f"{int(whole_hours)}h {int(leftover_s // 60)}m",
        "requests": _REQUEST_COUNTS,
        "retriever_loaded": retriever_ok,
        "models": model_names,
        "version": "2.0.0",
        "environment": config.ENVIRONMENT,
        # Taken after the probe so it reflects when the response was built.
        "timestamp": _dt.datetime.now(_dt.timezone.utc).isoformat(),
    }
_RETRIEVER_READY = False  # flipped to True by the warm-up thread on success


@app.on_event("startup")
async def _warm_retriever():
    """Start loading the retriever (FAISS+BM25 index) in the background.

    The load runs on a daemon thread so application startup is not blocked.
    On success the module-level ``_RETRIEVER_READY`` flag becomes True; on
    failure the error is printed and the flag stays False.
    """
    import threading

    def _preload():
        global _RETRIEVER_READY
        try:
            _get_retriever()
            _RETRIEVER_READY = True
            print("[API] Retriever warmed up successfully", flush=True)
        except Exception as err:
            print(f"[API] Retriever warmup failed: {err}", flush=True)

    warmup_thread = threading.Thread(target=_preload, daemon=True)
    warmup_thread.start()
# ── Static crop data endpoint ──────────────────────────────────────────────
@app.get("/crops/static", tags=["Data"])
def crops_static():
    """Expose the static crop tables from ``step4_app`` to the React frontend.

    The tables are imported on each call and returned unmodified under the
    keys ``varieties``, ``input_costs``, ``soil_prep`` and ``msp_2025``.
    """
    from step4_app import (
        _VARIETY_DATA as variety_table,
        _INPUT_COST as input_cost_table,
        _SOIL_PREP as soil_prep_table,
        _MSP_2025 as msp_table,
    )

    payload = {}
    payload["varieties"] = variety_table
    payload["input_costs"] = input_cost_table
    payload["soil_prep"] = soil_prep_table
    payload["msp_2025"] = msp_table
    return payload
# ── Price forecast endpoint ────────────────────────────────────────────────
# Request body for POST /price-forecast. Both fields default to an empty
# string, so an empty JSON object is accepted.
class _PriceForecastRequest(BaseModel):
    # Crop name as sent by the client; the handler translates a few UI names
    # to AGMARKNET commodity names and passes any other value through as-is.
    crop: str = ""
    # State name, forwarded unchanged to get_presow_signal by the handler.
    state: str = ""
@app.post("/price-forecast", tags=["Models"])
def price_forecast(req: _PriceForecastRequest):
    """Return the presow_v4 price forecast (p25/p50/p75) for a crop and state.

    Args:
        req: Request body carrying ``crop`` and ``state``.

    Returns:
        Whatever ``get_presow_signal`` returns, or ``{"error": <message>}``
        when that call raises.
    """
    from mandi_advisor.enterprise_engine_v2 import get_presow_signal

    # UI label -> AGMARKNET commodity name used by presow_v4. Crops that are
    # not listed here are forwarded under the name the client supplied.
    ui_to_agmarknet = {
        "Soybean": "Soyabean",
        "Arhar": "Arhar/Tur Dal",
        "Gram": "Bengal Gram(Gram)(Whole)",
        "Mustard": "Rapeseed/Mustard(Toria)",
    }
    commodity = ui_to_agmarknet.get(req.crop, req.crop)

    try:
        return get_presow_signal(commodity, req.state)
    except Exception as err:
        return {"error": str(err)}
# ── Weather endpoint ───────────────────────────────────────────────────────
@app.get("/weather", tags=["Data"])
def weather(lat: float = 20.5, lon: float = 78.9):
    """Fetch a 3-day forecast plus current weather from Open-Meteo.

    Args:
        lat: Latitude in decimal degrees.
        lon: Longitude in decimal degrees.

    Returns:
        The decoded JSON body of the Open-Meteo response, or
        ``{"error": <message>}`` if the request or JSON decoding fails.
    """
    import httpx

    forecast_url = (
        f"https://api.open-meteo.com/v1/forecast"
        f"?latitude={lat}&longitude={lon}"
        f"&daily=temperature_2m_max,temperature_2m_min,precipitation_sum,weathercode"
        f"&current_weather=true&timezone=Asia%2FKolkata&forecast_days=3"
    )
    try:
        # Request and decode in one guarded step so both failure modes
        # end up in the same error payload.
        return httpx.get(forecast_url, timeout=10).json()
    except Exception as err:
        return {"error": str(err)}
# ── Serve React frontend (HF Spaces / production) ─────────────────────────────
import os as _os2
from pathlib import Path as _Path
_STATIC_DIR = _Path(__file__).parent / "static"
if _STATIC_DIR.exists():
    from fastapi.staticfiles import StaticFiles
    from fastapi.responses import FileResponse

    # StaticFiles raises at construction time when its directory is missing,
    # which would abort application import for a build that has no assets/
    # folder. Mount only when the folder is really there, in line with the
    # handler below, which also tolerates a missing index.html.
    _ASSETS_DIR = _STATIC_DIR / "assets"
    if _ASSETS_DIR.is_dir():
        app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")

    @app.get("/", include_in_schema=False)
    @app.get("/{full_path:path}", include_in_schema=False)
    def serve_spa(full_path: str = ""):
        """Serve the React SPA entry page for any path no other route handles.

        Args:
            full_path: The requested path without its leading slash
                (empty for ``/``).

        Returns:
            ``index.html`` as a file response when it exists, otherwise a
            small JSON message saying the frontend has not been built.

        Raises:
            HTTPException 404: the path starts with ``api/``, ``docs`` or
                ``openapi``, so those are never answered with the SPA page.
        """
        # Don't intercept API / documentation routes.
        if full_path.startswith(("api/", "docs", "openapi")):
            raise HTTPException(status_code=404)
        index = _STATIC_DIR / "index.html"
        if index.exists():
            return FileResponse(str(index))
        return {"message": "KCC AgriAdvisor API running. Frontend not built yet."}
# ─────────────────────────────────────────────────────────────────────────────
# Development entry point: running this file directly starts uvicorn for
# "api:app" on all interfaces, port 8000, with auto-reload enabled.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)