# kcc-agri / mandi_advisor / enterprise_engine_v2.py
# Uploaded by hritikm15
# KCC AgriAdvisor v1 — code deploy
# Commit ca88a2c (verified)
"""
Enterprise Engine v2 — Mandi Advisor
Production routing engine: weekly perishable model → district v2 monthly model → v4 state fallback.
"""
from __future__ import annotations
import pickle
import warnings
from datetime import date, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import duckdb
import numpy as np
import pandas as pd
warnings.filterwarnings("ignore")
# ---------------------------------------------------------------------------
# Paths & Constants
# ---------------------------------------------------------------------------
# Directory that holds this module; every data file and model artefact below
# is resolved relative to it.
BASE = Path(__file__).parent.resolve() # relative — works on any machine
# Parquet inputs that the helpers in this module query through DuckDB.
FEATURE_STORE = BASE / "feature_store.parquet"
MANDI_DATA = BASE / "mandi_data_clean.parquet"
# Crops that get_signal routes to the weekly perishable model (matched
# case-insensitively); all other crops go to the district v2 monthly model.
PERISHABLE_CROPS = [
    "Tomato", "Onion", "Potato", "Green Chilli", "Brinjal",
    "Cabbage", "Cauliflower", "Bhindi (Okra)", "Banana",
    "Cucumber", "Bitter Gourd", "Bottle Gourd", "Pumpkin",
]
# Minimum support price by crop and year; get_presow_signal reports these
# values as ₹/qtl.
MSP_TABLE: Dict[str, Dict[int, int]] = {
    "Wheat": {2024: 2275, 2025: 2425, 2026: 2600},
    "Paddy": {2024: 2300, 2025: 2369, 2026: 2500},
    "Mustard": {2024: 5650, 2025: 5950, 2026: 6200},
    "Gram": {2024: 5440, 2025: 5650, 2026: 5800},
    "Bengal Gram": {2024: 5440, 2025: 5650, 2026: 5800},
    "Arhar/Tur": {2024: 7000, 2025: 7550, 2026: 8000},
    "Moong": {2024: 8682, 2025: 8908, 2026: 9500},
}
# Pickle artefact paths keyed by logical model name, then by artefact part.
# _load_model loads every part listed for a name.
MODEL_FILES = {
    "weekly_perishable": {
        "model": BASE / "weekly_perishable_model.pkl",
        "meta": BASE / "weekly_perishable_meta.pkl",
        "winrates": BASE / "weekly_perishable_winrates.pkl",
    },
    "district_v2": {
        "model": BASE / "district_sellhold_v2_model.pkl",
        "meta": BASE / "district_sellhold_v2_meta.pkl",
        "winrates": BASE / "district_sellhold_v2_winrates.pkl",
    },
    "presow_v3": {
        "model": BASE / "presow_v3_model.pkl",
        "meta": BASE / "presow_v3_meta.pkl",
    },
    "presow_v4": {
        "model": BASE / "presow_v4_model.pkl",
        "meta": BASE / "presow_v4_meta.pkl",
    },
    "sellhold_v4": {
        "model": BASE / "sellhold_v4_model.pkl",
        "meta": BASE / "sellhold_v4_meta.pkl",
    },
}
# Confidence scores below this value are labelled LOW, and get_signal adds a
# warning for them.
LOW_CONF_THRESHOLD = 0.55
# Relative price change above which _check_volatility reports high volatility.
VOLATILITY_THRESHOLD = 0.35 # 35% price move in last 7 days
# ---------------------------------------------------------------------------
# Lazy model cache
# ---------------------------------------------------------------------------
_CACHE: Dict[str, Any] = {}


def _load_model(name: str) -> Any:
    """Return the artefact bundle for *name*, unpickling it on first use.

    The bundle is a dict with one entry per part listed under ``name`` in
    MODEL_FILES (for example ``"model"`` and ``"meta"``). Loaded bundles are
    kept in the module-level ``_CACHE`` so each file is read at most once.

    Raises:
        KeyError: *name* is not a key of MODEL_FILES.
        OSError: an artefact file cannot be opened.
    """
    if name not in _CACHE:
        spec = MODEL_FILES.get(name)
        if spec is None:
            raise KeyError(f"Unknown model name: {name!r}")
        bundle: Dict[str, Any] = {}
        for part, path in spec.items():
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # these artefacts must only ever come from a trusted build.
            with open(path, "rb") as fh:
                bundle[part] = pickle.load(fh)
        # Cache only after every part loaded, so a failure leaves no
        # half-populated entry behind.
        _CACHE[name] = bundle
    return _CACHE[name]
# ---------------------------------------------------------------------------
# DuckDB helper
# ---------------------------------------------------------------------------
def _sanitize(val: str) -> str:
    """Remove SQL-sensitive characters from a commodity/state/district name.

    Single quotes, double quotes and semicolons are dropped, then any ``--``
    comment marker, and finally surrounding whitespace. ``None`` becomes the
    empty string. The values normally come from an internal AGMARKNET
    commodity map; this is a safety net in case user-supplied text ever
    reaches the query helpers.
    """
    if val is None:
        return ""
    # One pass for the single-character removals, then the two-character marker.
    without_quotes = val.translate({ord("'"): None, ord('"'): None, ord(";"): None})
    return without_quotes.replace("--", "").strip()
def _duck(sql: str, params: Optional[list] = None) -> pd.DataFrame:
    """Run *sql* on a throw-away in-memory DuckDB connection.

    Args:
        sql: Query text; may contain positional placeholders.
        params: Values bound to the placeholders, or None/empty for none.

    Returns:
        The full result set as a pandas DataFrame.

    Raises:
        Whatever DuckDB raises for an invalid query or unreadable file.
    """
    con = duckdb.connect(database=":memory:")
    try:
        con.execute("SET memory_limit='1GB';")
        if params:
            return con.execute(sql, params).df()
        return con.execute(sql).df()
    finally:
        # The DataFrame is fully materialised by .df() before we get here, so
        # closing is safe; without it every query leaked a connection.
        con.close()
# ---------------------------------------------------------------------------
# Name resolution (fuzzy match against LabelEncoder classes)
# ---------------------------------------------------------------------------
def _resolve_name(name: str, le) -> Optional[str]:
    """Resolve *name* against the classes of a fitted label encoder.

    Matching priority: exact, then case-insensitive, then substring in either
    direction (first hit in ``le.classes_`` order).

    Args:
        name: Name to look up. ``None``, empty and whitespace-only values never
            match.
        le: Object exposing ``classes_`` (e.g. a sklearn LabelEncoder).

    Returns:
        The matching element of ``le.classes_``, or None when nothing matches.
    """
    if name is None:
        return None
    wanted = str(name).strip()
    if not wanted:
        # An empty string is a substring of every class, so the substring
        # rule below would otherwise "resolve" it to the first class.
        return None

    classes: List[str] = list(le.classes_)

    # 1. Exact match (untrimmed first, so previously-exact inputs are unchanged)
    if name in classes:
        return name
    if wanted in classes:
        return wanted

    wanted_lower = wanted.lower()
    lowered = [(c, str(c).lower()) for c in classes]

    # 2. Case-insensitive
    for original, lower in lowered:
        if lower == wanted_lower:
            return original

    # 3. Substring (name in class or class in name)
    for original, lower in lowered:
        if not lower:
            # An empty class would be a substring of any name.
            continue
        if wanted_lower in lower or lower in wanted_lower:
            return original
    return None
# ---------------------------------------------------------------------------
# Feature retrieval helpers
# ---------------------------------------------------------------------------
def _get_feature_row(
    crop: str,
    state: str,
    district: Optional[str],
    month: int,
    year: int,
) -> Optional[pd.Series]:
    """Fetch the most recent feature-store row for a crop/state(/district).

    The first query requires the requested month and takes the latest year.
    If it returns nothing, a second query drops the month filter and takes the
    latest year/month available. The district filter is applied only when a
    non-empty district is supplied. ``year`` is accepted but not used by the
    lookup.

    Returns:
        The matching row as a Series, or None when nothing matches or the
        query fails for any reason (best-effort lookup).
    """
    try:
        crop_key = _sanitize(crop).lower()
        state_key = _sanitize(state).lower()
        district_key = _sanitize(district).lower() if district else None
        month_num = int(month)

        if district_key:
            month_sql = f"""
                SELECT * FROM read_parquet('{FEATURE_STORE}')
                WHERE lower(commodity) = $1
                AND lower(state) = $2
                AND lower(district) = $3
                AND month = $4
                ORDER BY year DESC
                LIMIT 1
            """
            month_args = [crop_key, state_key, district_key, month_num]
            any_sql = f"""
                SELECT * FROM read_parquet('{FEATURE_STORE}')
                WHERE lower(commodity) = $1
                AND lower(state) = $2
                AND lower(district) = $3
                ORDER BY year DESC, month DESC
                LIMIT 1
            """
            any_args = [crop_key, state_key, district_key]
        else:
            month_sql = f"""
                SELECT * FROM read_parquet('{FEATURE_STORE}')
                WHERE lower(commodity) = $1
                AND lower(state) = $2
                AND month = $3
                ORDER BY year DESC
                LIMIT 1
            """
            month_args = [crop_key, state_key, month_num]
            any_sql = f"""
                SELECT * FROM read_parquet('{FEATURE_STORE}')
                WHERE lower(commodity) = $1
                AND lower(state) = $2
                ORDER BY year DESC, month DESC
                LIMIT 1
            """
            any_args = [crop_key, state_key]

        frame = _duck(month_sql, month_args)
        if frame.empty:
            # fallback: any month/year
            frame = _duck(any_sql, any_args)
        if frame.empty:
            return None
        return frame.iloc[0]
    except Exception:
        return None
def _get_recent_mandi_prices(crop: str, state: str, days: int = 7) -> Optional[pd.DataFrame]:
    """Return the newest mandi price rows for a crop/state.

    Rows are ordered by ``arrival_date`` descending and capped at
    ``days * 20`` rows; no date-range filter is applied.

    Returns:
        DataFrame with ``arrival_date`` and ``modal_price`` columns, or None
        if the query fails for any reason.
    """
    try:
        crop_key = _sanitize(crop).lower()
        state_key = _sanitize(state).lower()
        # int() keeps the interpolated LIMIT strictly numeric.
        row_cap = int(days) * 20
        query = f"""
            SELECT arrival_date, modal_price
            FROM read_parquet('{MANDI_DATA}')
            WHERE lower(commodity) = $1
            AND lower(state) = $2
            ORDER BY arrival_date DESC
            LIMIT {row_cap}
        """
        return _duck(query, [crop_key, state_key])
    except Exception:
        return None
def _check_volatility(crop: str, state: str) -> tuple[bool, str]:
    """Flag a crop/state whose recent prices swung more than the threshold.

    Compares the newest non-null ``modal_price`` with the oldest one in the
    rows returned by ``_get_recent_mandi_prices(..., days=7)``.

    Returns:
        ``(True, warning_message)`` when the relative change exceeds
        VOLATILITY_THRESHOLD, otherwise ``(False, "")`` — also when data is
        missing, too short, or any error occurs.
    """
    # NOTE(review): the rows come back capped by count, not by date, so the
    # "last 7 days" wording holds only if the cap spans about a week — confirm.
    calm = (False, "")
    try:
        recent = _get_recent_mandi_prices(crop, state, days=7)
        if recent is None or recent.empty or len(recent) < 2:
            return calm
        series = recent["modal_price"].dropna().values
        if len(series) < 2:
            return calm
        newest = float(series[0])
        oldest = float(series[-1])
        if oldest == 0:
            return calm
        pct_change = abs(newest - oldest) / oldest
        if pct_change > VOLATILITY_THRESHOLD:
            return True, f"⚠️ High volatility — signal suppressed (price moved {pct_change*100:.1f}% in last 7 days)"
        return calm
    except Exception:
        return calm
# ---------------------------------------------------------------------------
# Confidence helpers
# ---------------------------------------------------------------------------
def _score_to_label(score: float) -> str:
    """Map a confidence score to "HIGH", "MEDIUM" or "LOW".

    Scores of at least 0.70 are HIGH; scores of at least LOW_CONF_THRESHOLD
    are MEDIUM; everything else is LOW.
    """
    if score >= 0.70:
        return "HIGH"
    return "MEDIUM" if score >= LOW_CONF_THRESHOLD else "LOW"
def _lookup_win_rate(winrates: Any, crop: str, state: str,
                     district: Optional[str] = None,
                     month: Optional[int] = None,
                     iso_week: Optional[int] = None) -> Optional[float]:
    """Best-effort historical win-rate lookup.

    Three artefact layouts are understood:

    * tiered dict ``{"national": df, "state": df, "district": df}`` — tried
      district, then state, then national;
    * flat dict keyed by ``(crop, state, district)``, ``(crop, state)``,
      ``(crop,)`` or ``("overall",)``;
    * a single DataFrame with commodity/state(/district) columns.

    Rows are filtered to the requested ``iso_week`` or ``month`` (defaulting
    to today's) when the table has such a column. The win-rate column is the
    first column whose name contains "win".

    Returns:
        The win rate as a float, or None when nothing matches or anything
        goes wrong.
    """
    try:
        today = date.today()
        cur_month = month or today.month
        cur_isoweek = iso_week or today.isocalendar()[1]

        def _period(frame):
            """Time mask: iso_week if the column exists, else month, else all rows."""
            if "iso_week" in frame.columns:
                return frame["iso_week"] == cur_isoweek
            if "month" in frame.columns:
                return frame["month"] == cur_month
            return pd.Series([True] * len(frame), index=frame.index)

        def _same(frame, column, wanted):
            """Case-insensitive equality mask on a text column."""
            return frame[column].str.lower() == wanted.lower()

        def _first_win(frame, selector):
            """Win rate of the first selected row, or None to keep searching."""
            hits = frame[selector]
            if hits.empty:
                return None
            win_cols = [c for c in frame.columns if "win" in c.lower()]
            if not win_cols:
                return None
            return float(hits.iloc[0][win_cols[0]])

        # ── New structure: {"national": df, "state": df, "district": df} ──
        if isinstance(winrates, dict) and any(
            k in winrates for k in ("national", "state", "district")
        ):
            # 1. District level (only when a district was requested)
            table = winrates.get("district") if district else None
            if isinstance(table, pd.DataFrame) and "district" in table.columns:
                found = _first_win(
                    table,
                    _same(table, "commodity", crop)
                    & _same(table, "state", state)
                    & _same(table, "district", district)
                    & _period(table),
                )
                if found is not None:
                    return found

            # 2. State level
            table = winrates.get("state")
            if isinstance(table, pd.DataFrame):
                found = _first_win(
                    table,
                    _same(table, "commodity", crop)
                    & _same(table, "state", state)
                    & _period(table),
                )
                if found is not None:
                    return found

            # 3. National fallback
            table = winrates.get("national")
            if isinstance(table, pd.DataFrame):
                found = _first_win(
                    table,
                    _same(table, "commodity", crop) & _period(table),
                )
                if found is not None:
                    return found
            return None

        # ── Old flat-dict structure: {(crop, state): value} ──
        if isinstance(winrates, dict):
            for key in ((crop, state, district), (crop, state), (crop,), ("overall",)):
                if key in winrates:
                    value = winrates[key]
                    return None if value is None else float(value)

        # ── Plain DataFrame ──
        if isinstance(winrates, pd.DataFrame):
            lowered = [c.lower() for c in winrates.columns]
            month_col = next((c for c in winrates.columns if c.lower() == "month"), None)
            month_mask = (winrates[month_col] == cur_month) if month_col else True

            def _col(name):
                """Lower-cased values of the column whose name matches case-insensitively."""
                return winrates.iloc[:, lowered.index(name)].str.lower()

            has_crop_state = "commodity" in lowered and "state" in lowered
            if has_crop_state and "district" in lowered and district:
                found = _first_win(
                    winrates,
                    (_col("commodity") == crop.lower())
                    & (_col("state") == state.lower())
                    & (_col("district") == district.lower())
                    & month_mask,
                )
                if found is not None:
                    return found
            if has_crop_state:
                found = _first_win(
                    winrates,
                    (_col("commodity") == crop.lower())
                    & (_col("state") == state.lower())
                    & month_mask,
                )
                if found is not None:
                    return found
        return None
    except Exception:
        return None
# ---------------------------------------------------------------------------
# Weekly perishable model
# ---------------------------------------------------------------------------
def _predict_weekly(
    crop: str,
    state: str,
    district: Optional[str],
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Build features and predict with the weekly perishable model.

    Args:
        crop: Commodity name; resolved against the model's crop encoder.
        state: State name; resolved against the model's state encoder.
        district: Optional district, used only to pick the feature-store row.
        current_price: Latest price, used for the price-vs-seasonal figure.
        ref_date: Date the prediction refers to.

    Returns:
        Dict with recommendation, confidence_score, confidence, win_rate,
        seasonal_avg, price_vs_seasonal_pct, granularity, model_version,
        data_level and iso_week.

    Raises:
        ValueError: crop or state is unknown to the model's encoders.
    """
    artefacts = _load_model("weekly_perishable")
    model = artefacts["model"]
    meta = artefacts["meta"]
    winrates = artefacts.get("winrates", {})
    # meta uses le_crop / le_state keys (from build_weekly_perishable_model.py)
    crop_le = meta.get("le_crop") or meta.get("crop_le")
    state_le = meta.get("le_state") or meta.get("state_le")
    ml_features = meta.get("features", meta.get("ML_FEATURES", []))

    # Resolve encoded names
    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state
    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in weekly model encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in weekly model encoder")
    crop_enc = int(crop_le.transform([crop_name])[0]) if crop_le else 0
    state_enc = int(state_le.transform([state_name])[0]) if state_le else 0

    iso_week = ref_date.isocalendar()[1]
    month = ref_date.month
    year = ref_date.year

    # Trigonometric encodings
    wk_sin = np.sin(2 * np.pi * iso_week / 52)
    wk_cos = np.cos(2 * np.pi * iso_week / 52)
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)

    # Feature-store row for proxy values
    row = _get_feature_row(crop, state, district, month, year)

    def _fs(col, default=0.0):
        """Numeric feature-store value, or *default* when absent/unusable.

        A ``None`` default is returned as-is; converting it with float()
        used to raise TypeError whenever the row or column was missing.
        """
        if row is not None and col in row.index and pd.notna(row[col]):
            try:
                return float(row[col])
            except (ValueError, TypeError):
                pass
        return None if default is None else float(default)

    def _fs_str(col, default="state"):
        """String feature-store value, or *default* when absent."""
        if row is not None and col in row.index and pd.notna(row[col]):
            return str(row[col])
        return default

    # Map feature-store columns to weekly model feature names (proxy mapping)
    feature_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "iso_week": iso_week,
        "month": month,
        "year": year,
        "wk_sin": wk_sin,
        "wk_cos": wk_cos,
        "mth_sin": mth_sin,
        "mth_cos": mth_cos,
        # Momentum proxies from feature store (monthly as proxy for weekly)
        "price_vs_4w": _fs("district_vs_state_pct"),
        "price_vs_13w": _fs("state_vs_national_pct"),
        "price_vs_52w": _fs("dist_lag_1yr", 0),
        "price_vs_seas_w": _fs("seasonal_idx", 1.0) - 1.0,
        "price_vs_st_pct": _fs("district_vs_state_pct"),
        "price_vs_nat_pct": _fs("district_vs_national_pct"),
        "mom_1w": 0.0,
        "mom_2w": 0.0,
        "mom_4w": _fs("district_vs_state_pct"),
        "pctile_52w": _fs("price_pctile", 50),
        "nat_win_rate_w": 0.5,
        "st_win_rate_w": 0.5,
        "price_spread_pct": 0.0,
        "n_markets": _fs("district_records", 1),
    }
    # Features the model expects but this map lacks are filled with 0.0.
    feature_values = [feature_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    # Predict
    pred_class = int(model.predict(X)[0])
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)[0]
        # predict_proba columns follow model.classes_, so translate the
        # predicted label to its column position before indexing.
        classes = list(getattr(model, "classes_", range(len(proba))))
        col_idx = classes.index(pred_class) if pred_class in classes else pred_class
        confidence_score = float(proba[col_idx])
    else:
        confidence_score = 0.6  # default when no proba

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")
    win_rate = _lookup_win_rate(winrates, crop_name, state_name, iso_week=iso_week)

    # Missing and zero averages both mean "no usable seasonal average".
    seasonal_avg = _fs("st_avg", None) or None
    price_vs_seasonal_pct = None
    if current_price and seasonal_avg:
        price_vs_seasonal_pct = round((current_price - seasonal_avg) / seasonal_avg * 100, 1)

    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": round(win_rate, 3) if win_rate is not None else None,
        "seasonal_avg": round(seasonal_avg, 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "weekly",
        "model_version": "weekly_perishable_v1",
        "data_level": _fs_str("data_level", "state"),
        "iso_week": iso_week,
    }
# ---------------------------------------------------------------------------
# District v2 monthly model
# ---------------------------------------------------------------------------
def _predict_district_v2(
    crop: str,
    state: str,
    district: Optional[str],
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Build features and predict with the district sell/hold v2 model.

    Args:
        crop: Commodity name; resolved against the model's crop encoder.
        state: State name; resolved against the model's state encoder.
        district: Optional district; an unresolved district is encoded as 0.
        current_price: Latest price, used for the price-vs-seasonal figure.
        ref_date: Date the prediction refers to (its month drives the lookup).

    Returns:
        Dict with recommendation, confidence_score, confidence, win_rate,
        seasonal_avg, price_vs_seasonal_pct, granularity, model_version and
        data_level.

    Raises:
        ValueError: crop or state is unknown to the model's encoders.
    """
    artefacts = _load_model("district_v2")
    model = artefacts["model"]
    meta = artefacts["meta"]
    winrates = artefacts.get("winrates", {})
    crop_le = meta.get("crop_le")
    state_le = meta.get("state_le")
    district_le = meta.get("district_le")
    ml_features = meta.get("ML_FEATURES", meta.get("features", []))

    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state
    district_name = _resolve_name(district, district_le) if district_le and district else None
    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in district_v2 encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in district_v2 encoder")
    crop_enc = int(crop_le.transform([crop_name])[0]) if crop_le else 0
    state_enc = int(state_le.transform([state_name])[0]) if state_le else 0
    district_enc = (
        int(district_le.transform([district_name])[0])
        if district_le and district_name else 0
    )

    month = ref_date.month
    year = ref_date.year
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)
    row = _get_feature_row(crop, state, district, month, year)

    def _fs(col, default=0.0):
        """Numeric feature-store value, or *default* when absent/unusable.

        A ``None`` default is returned as-is; converting it with float()
        used to raise TypeError whenever the row or column was missing.
        """
        if row is not None and col in row.index and pd.notna(row[col]):
            try:
                return float(row[col])
            except (ValueError, TypeError):
                pass
        return None if default is None else float(default)

    def _fs_str(col, default="state"):
        """String feature-store value, or *default* when absent."""
        if row is not None and col in row.index and pd.notna(row[col]):
            return str(row[col])
        return default

    feature_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "district_enc": district_enc,
        "month": month,
        "year": year,
        "month_sin": mth_sin,
        "month_cos": mth_cos,
        "price_vs_dist_seas_avg": _fs("seasonal_idx", 1.0) - 1.0,
        "price_vs_dist_seas_std": 0.0,
        "price_vs_st_seas_avg": _fs("state_vs_national_pct"),
        "state_vs_national_pct": _fs("state_vs_national_pct"),
        "district_vs_state_pct": _fs("district_vs_state_pct"),
        "district_vs_national_pct": _fs("district_vs_national_pct"),
        "price_pctile": _fs("price_pctile", 50),
        "mom_1m": 0.0,
        "mom_3m": 0.0,
        "mom_6m": 0.0,
        "seasonal_idx": _fs("seasonal_idx", 1.0),
        "nat_win_rate": 0.5,
        "st_win_rate": 0.5,
        "dist_win_rate": 0.5,
        "eff_win_rate": 0.5,
        "area_vs_5yr_avg": 0.0,
        "yield_vs_5yr_avg": 0.0,
        "area_yoy_pct": 0.0,
        "district_records": _fs("district_records", 1),
    }
    # Features the model expects but this map lacks are filled with 0.0.
    feature_values = [feature_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    pred_class = int(model.predict(X)[0])
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)[0]
        # predict_proba columns follow model.classes_, so translate the
        # predicted label to its column position before indexing.
        classes = list(getattr(model, "classes_", range(len(proba))))
        col_idx = classes.index(pred_class) if pred_class in classes else pred_class
        confidence_score = float(proba[col_idx])
    else:
        confidence_score = 0.6  # default when no proba

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")
    win_rate = _lookup_win_rate(winrates, crop_name, state_name, district_name, month=ref_date.month)

    # State average first, then national; missing or zero means "unknown".
    seasonal_avg = _fs("st_avg", None) or _fs("nat_avg", None) or None
    price_vs_seasonal_pct = None
    if current_price and seasonal_avg:
        price_vs_seasonal_pct = round((current_price - seasonal_avg) / seasonal_avg * 100, 1)

    data_level = _fs_str("data_level", "district" if district_name else "state")
    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": round(win_rate, 3) if win_rate is not None else None,
        "seasonal_avg": round(seasonal_avg, 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "monthly",
        "model_version": "district_sellhold_v2",
        "data_level": data_level,
    }
# ---------------------------------------------------------------------------
# Fallback: v4 state-level model
# ---------------------------------------------------------------------------
def _predict_v4_fallback(
    crop: str,
    state: str,
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Predict with the v4 state-level monthly model (fallback path).

    Args:
        crop: Commodity name; resolved against the model's crop encoder.
        state: State name; resolved against the model's state encoder.
        current_price: Latest price, used for the price-vs-seasonal figure.
        ref_date: Date the prediction refers to (its month drives the lookup).

    Returns:
        Dict with recommendation, confidence_score, confidence, win_rate
        (always None here), seasonal_avg, price_vs_seasonal_pct, granularity,
        model_version and data_level.

    Raises:
        ValueError: crop or state is unknown to the model's encoders.
    """
    artefacts = _load_model("sellhold_v4")
    model = artefacts["model"]
    meta = artefacts["meta"]
    crop_le = meta.get("crop_le")
    state_le = meta.get("state_le")
    ml_features = meta.get("ML_FEATURES", meta.get("features", []))

    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state
    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in v4 encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in v4 encoder")
    crop_enc = int(crop_le.transform([crop_name])[0]) if crop_le else 0
    state_enc = int(state_le.transform([state_name])[0]) if state_le else 0

    month = ref_date.month
    year = ref_date.year
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)
    row = _get_feature_row(crop, state, None, month, year)

    def _fs(col, default=0.0):
        """Numeric feature-store value, or *default* when absent/unusable.

        A ``None`` default is returned as-is; converting it with float()
        used to raise TypeError whenever the row or column was missing.
        """
        if row is not None and col in row.index and pd.notna(row[col]):
            try:
                return float(row[col])
            except (ValueError, TypeError):
                pass
        return None if default is None else float(default)

    base_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "month": month,
        "year": year,
        "month_sin": mth_sin,
        "month_cos": mth_cos,
        "state_vs_national_pct": _fs("state_vs_national_pct"),
        "seasonal_idx": _fs("seasonal_idx", 1.0),
        "price_pctile": _fs("price_pctile", 50),
        "state_lag_1yr": _fs("state_lag_1yr"),
        "state_lag_2yr": _fs("state_lag_2yr"),
        "nat_lag_1yr": _fs("nat_lag_1yr"),
        "nat_lag_2yr": _fs("nat_lag_2yr"),
    }
    # Features the model expects but this map lacks are filled with 0.0.
    feature_values = [base_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    pred_class = int(model.predict(X)[0])
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)[0]
        # predict_proba columns follow model.classes_, so translate the
        # predicted label to its column position before indexing.
        classes = list(getattr(model, "classes_", range(len(proba))))
        col_idx = classes.index(pred_class) if pred_class in classes else pred_class
        confidence_score = float(proba[col_idx])
    else:
        confidence_score = 0.55  # default when no proba

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")

    # State average first, then national; missing or zero means "unknown".
    seasonal_avg = _fs("st_avg", None) or _fs("nat_avg", None) or None
    price_vs_seasonal_pct = None
    if current_price and seasonal_avg:
        price_vs_seasonal_pct = round((current_price - seasonal_avg) / seasonal_avg * 100, 1)

    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": None,
        "seasonal_avg": round(seasonal_avg, 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "monthly",
        "model_version": "sellhold_v4_fallback",
        "data_level": "state",
    }
# ---------------------------------------------------------------------------
# Public API: get_signal
# ---------------------------------------------------------------------------
def get_signal(
    crop: str,
    state: str,
    district: Optional[str] = None,
    current_price: Optional[float] = None,
    date_input: Optional[Union[str, date, datetime]] = None,
) -> Dict[str, Any]:
    """
    Produce a sell/hold signal, routing the request to the right model.

    Perishable crops (see PERISHABLE_CROPS) use the weekly model, all other
    crops the district v2 monthly model; if the chosen model raises, the v4
    state-level model is used instead. A volatility guardrail may override
    the recommendation with "HOLD".

    Parameters
    ----------
    crop : Commodity name (e.g. "Tomato", "Wheat")
    state : State name
    district : District name (optional)
    current_price : Latest market price (₹/quintal)
    date_input : Reference date ("YYYY-MM-DD" string, date or datetime);
        defaults to today

    Returns
    -------
    dict with keys: crop, state, district, date, granularity, recommendation,
    confidence, win_rate, current_price, seasonal_avg,
    price_vs_seasonal_pct, reasoning, model_version, data_level,
    warning (optional).
    On failure: dict with keys error, confidence ("LOW"), crop, state,
    district. This function does not raise.
    """
    try:
        # Normalise date (datetime is checked before the generic case because
        # it is also a date)
        if date_input is None:
            ref_date = date.today()
        elif isinstance(date_input, datetime):
            ref_date = date_input.date()
        elif isinstance(date_input, str):
            ref_date = datetime.strptime(date_input[:10], "%Y-%m-%d").date()
        else:
            ref_date = date_input

        warnings_list: List[str] = []
        used_fallback = False

        # --- Route ---
        crop_lower = crop.lower()
        is_perishable = any(name.lower() == crop_lower for name in PERISHABLE_CROPS)
        try:
            if is_perishable:
                result = _predict_weekly(crop, state, district, current_price, ref_date)
            else:
                result = _predict_district_v2(crop, state, district, current_price, ref_date)
        except Exception as exc:
            if is_perishable:
                warnings_list.append(f"Weekly model failed ({exc}); falling back to v4")
            else:
                warnings_list.append(f"District v2 model failed ({exc}); falling back to v4")
            # A failure here propagates to the outer handler, which returns
            # the error dict.
            result = _predict_v4_fallback(crop, state, current_price, ref_date)
            used_fallback = True

        # --- Confidence filter ---
        score = result.get("confidence_score", 0.5)
        if score < LOW_CONF_THRESHOLD:
            result["confidence"] = "LOW"
            warnings_list.append(
                f"Model confidence {score:.2f} below threshold {LOW_CONF_THRESHOLD}"
            )

        # --- Volatility / anomaly guardrail ---
        volatile, volatility_note = _check_volatility(crop, state)
        if volatile:
            result["recommendation"] = "HOLD"
            warnings_list.append(volatility_note)

        # --- Build reasoning ---
        rec = result.get("recommendation", "HOLD")
        conf = result.get("confidence", "MEDIUM")
        p_vs_s = result.get("price_vs_seasonal_pct")
        seas = result.get("seasonal_avg")
        sentences = [f"Prediction: {rec} ({conf} confidence)."]
        if p_vs_s is not None:
            direction = "above" if p_vs_s >= 0 else "below"
            sentences.append(
                f"Current price is {abs(p_vs_s):.1f}% {direction} seasonal average"
                + (f" (₹{seas}/qtl)" if seas else "") + "."
            )
        if used_fallback:
            sentences.append("State-level fallback model used.")
        sentences.extend(warnings_list)

        # --- Assemble final result (the internal confidence_score is not exposed) ---
        out: Dict[str, Any] = {
            "crop": crop,
            "state": state,
            "district": district,
            "date": ref_date.isoformat(),
            "granularity": result.get("granularity", "monthly"),
            "recommendation": result.get("recommendation", "HOLD"),
            "confidence": result.get("confidence", "LOW"),
            "win_rate": result.get("win_rate"),
            "current_price": current_price,
            "seasonal_avg": result.get("seasonal_avg"),
            "price_vs_seasonal_pct": result.get("price_vs_seasonal_pct"),
            "reasoning": " ".join(sentences),
            "model_version": result.get("model_version", "unknown"),
            "data_level": result.get("data_level", "state"),
        }
        if warnings_list:
            out["warning"] = " | ".join(warnings_list)
        return out
    except Exception as e:
        return {
            "error": str(e),
            "confidence": "LOW",
            "crop": crop,
            "state": state,
            "district": district,
        }
# ---------------------------------------------------------------------------
# Public API: get_presow_signal
# ---------------------------------------------------------------------------
def _match_crop_key(crop_name, keys):
    """Return the entry of *keys* that best matches *crop_name*, else None.

    An exact case-insensitive match wins over a substring match, so that
    "Bengal Gram" is not captured by an earlier "Gram" entry.
    """
    wanted = str(crop_name or "").strip().lower()
    if not wanted:
        return None
    candidates = list(keys)
    for key in candidates:
        if str(key).lower() == wanted:
            return key
    for key in candidates:
        key_lower = str(key).lower()
        if key_lower and (key_lower in wanted or wanted in key_lower):
            return key
    return None


def _default_harvest_label(ref_month, ref_year):
    """Label ("Month YYYY") of the month four months after sowing."""
    months_ahead = ref_month + 3  # zero-based sowing month (ref_month - 1) plus 4
    harvest_month = months_ahead % 12 + 1
    # Roll into the following year when the window passes December.
    harvest_year = ref_year + months_ahead // 12
    return datetime(harvest_year, harvest_month, 1).strftime("%B %Y")


def get_presow_signal(
    crop: str,
    state: str,
    district: Optional[str] = None,
    sowing_month: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Return a harvest-price range forecast for pre-sowing decisions.

    Uses the presow_v4 artefacts when they load, otherwise presow_v3.

    Parameters
    ----------
    crop : Commodity name
    state : State name
    district : District (optional)
    sowing_month : Month of sowing (1-12); defaults to current month

    Returns
    -------
    dict: crop, state, district, sowing_month, p25, p50, p75, harvest_window,
    msp, profit_probability, confidence, mape_pct, reasoning, model_version.
    On failure: dict with error and confidence ("LOW"), plus crop and state
    except for unknown-name errors. This function does not raise.
    """
    try:
        today = date.today()
        ref_month = int(sowing_month) if sowing_month else today.month
        ref_year = today.year
        if not 1 <= ref_month <= 12:
            return {
                "error": f"sowing_month must be between 1 and 12, got {sowing_month!r}",
                "confidence": "LOW",
                "crop": crop,
                "state": state,
            }

        # Try v4 first, fall back to v3
        try:
            artefacts = _load_model("presow_v4")
            _ver = "presow_v4"
        except Exception:
            artefacts = _load_model("presow_v3")
            _ver = "presow_v3"
        model = artefacts["model"]
        meta = artefacts["meta"]

        crop_le = meta.get("le_crop") or meta.get("crop_le")
        state_le = meta.get("le_state") or meta.get("state_le")
        ml_features = (meta.get("features_v4") or meta.get("features_v3")
                       or meta.get("ML_FEATURES") or meta.get("features", []))

        crop_name = _resolve_name(crop, crop_le) if crop_le else crop
        state_name = _resolve_name(state, state_le) if state_le else state
        if crop_name is None:
            return {"error": f"Crop '{crop}' not found in {_ver} encoder", "confidence": "LOW"}
        if state_name is None:
            return {"error": f"State '{state}' not found in {_ver} encoder", "confidence": "LOW"}
        crop_enc = int(crop_le.transform([crop_name])[0]) if crop_le else 0
        state_enc = int(state_le.transform([state_name])[0]) if state_le else 0

        mth_sin = np.sin(2 * np.pi * ref_month / 12)
        mth_cos = np.cos(2 * np.pi * ref_month / 12)
        row = _get_feature_row(crop_name, state_name, district, ref_month, ref_year)

        def _fs(col, default=0.0):
            """Numeric feature-store value, or *default* when absent."""
            if row is not None and col in row.index and pd.notna(row[col]):
                return float(row[col])
            return float(default)

        # Compute v4 new features using available data as proxies
        sl1 = _fs("state_lag_1yr")  # state-level price 1 year ago at this month
        sl2 = _fs("state_lag_2yr")
        ap = _fs("avg_price")
        # harvest_price_lag1: best proxy = state_lag_1yr (state price 1yr ago)
        hpl1 = sl1 if sl1 > 0 else ap * 0.95
        hpl2 = sl2 if sl2 > 0 else ap * 0.90

        # price_cv: derived from the per-crop confidence tier in the model
        # metadata (0.15 when the tier is unrecognised).
        conf_map = meta.get("confidence_map", {}) or {}
        confidence_label = conf_map.get(crop_name, "MEDIUM")
        _price_cv = {"HIGH": 0.10, "MEDIUM": 0.18, "LOW": 0.30}.get(confidence_label, 0.15)

        # harvest_to_sow_ratio: last-year seasonal lift, clamped to [0.3, 5.0]
        _h2s = (hpl1 / ap) if ap > 10 else 1.0
        _h2s = max(0.3, min(5.0, _h2s))

        base_map = {
            # identifiers
            "crop_enc": crop_enc,
            "state_enc": state_enc,
            # time
            "month": ref_month,
            "year": ref_year,
            "month_sin": mth_sin,
            "month_cos": mth_cos,
            # price levels
            "nat_avg": _fs("nat_avg"),
            "st_avg": _fs("st_avg"),
            "avg_price": ap,
            "state_vs_national_pct": _fs("state_vs_national_pct"),
            # price lags (feature store)
            "dist_lag_1yr": _fs("dist_lag_1yr"),
            "dist_lag_2yr": _fs("dist_lag_2yr"),
            "dist_lag_3yr": _fs("dist_lag_3yr"),
            "state_lag_1yr": sl1,
            "state_lag_2yr": sl2,
            "nat_lag_1yr": _fs("nat_lag_1yr"),
            "nat_lag_2yr": _fs("nat_lag_2yr"),
            "dist_roll3yr": _fs("dist_roll3yr"),
            # seasonal
            "seasonal_idx": _fs("seasonal_idx", 1.0),
            "price_pctile": _fs("price_pctile", 50),
            # v4 new features
            "harvest_price_lag1": hpl1,
            "harvest_price_lag2": hpl2,
            "price_cv": _price_cv,
            "harvest_to_sow_ratio": _h2s,
            # v3 legacy features (filled neutral when v4)
            "temp_mean": 0.0,
            "rainfall_28d": 0.0,
            "ndvi_mean": 0.0,
            "area_yoy_pct": 0.0,
            "area_2yr_pct": 0.0,
            "area_vs_3yr": 0.0,
            "supply_up": 0.0,
            "supply_down": 0.0,
        }
        feature_values = [base_map.get(f, 0.0) for f in ml_features]
        X = np.array(feature_values, dtype=float).reshape(1, -1)

        # Predict using quantile models (p25, p50, p75) if available
        if isinstance(model, dict) and "p25" in model and "p50" in model and "p75" in model:
            p50_pred = float(model["p50"].predict(X)[0])
            p25_pred = float(model["p25"].predict(X)[0])
            p75_pred = float(model["p75"].predict(X)[0])
            # Ensure monotonicity
            p25_pred = min(p25_pred, p50_pred)
            p75_pred = max(p75_pred, p50_pred)
        else:
            # Legacy: single model + ±15% spread
            p50_pred = float(model.predict(X)[0])
            p25_pred = p50_pred * 0.85
            p75_pred = p50_pred * 1.15
        p25 = round(p25_pred, 0)
        p50 = round(p50_pred, 0)
        p75 = round(p75_pred, 0)

        # Harvest window from calendar
        harvest_cal = meta.get("harvest_calendar", {}) or {}
        cal_key = _match_crop_key(crop_name, harvest_cal)
        hw_months = harvest_cal[cal_key] if cal_key is not None else None
        if hw_months and hw_months != (0, 0):
            hw_start = hw_months[0]
            try:
                # A start month earlier than the sowing month means next year.
                harvest_label = datetime(ref_year + (1 if hw_start < ref_month else 0),
                                         hw_start, 1).strftime("%B %Y")
            except Exception:
                harvest_label = "---"
        else:
            harvest_label = _default_harvest_label(ref_month, ref_year)

        # MSP lookup (current year, else the latest year in the table)
        msp_entry = None
        msp_key = _match_crop_key(crop_name, MSP_TABLE)
        if msp_key is not None:
            msp_by_year = MSP_TABLE[msp_key]
            msp_entry = msp_by_year.get(ref_year) or msp_by_year.get(max(msp_by_year))

        # Profit probability
        if msp_entry:
            if p25 > msp_entry:
                profit_probability = "HIGH"
            elif p50 > msp_entry:
                profit_probability = "MEDIUM"
            else:
                profit_probability = "LOW"
        else:
            profit_probability = "MEDIUM"

        # Accuracy metadata for UI
        val_meta = meta.get("validation", {})
        mape_pct = val_meta.get("P50", {}).get("mape", 7.6)

        reasoning_parts = [
            f"Forecast for {crop_name} in {state_name}.",
            f"Expected harvest window: {harvest_label}.",
            f"Price range: ₹{p25:.0f}–₹{p75:.0f}/qtl (median ₹{p50:.0f}).",
        ]
        if msp_entry:
            reasoning_parts.append(
                f"MSP ({ref_year}): ₹{msp_entry}/qtl — profit probability: {profit_probability}."
            )
        return {
            "crop": crop,
            "state": state,
            "district": district,
            "sowing_month": ref_month,
            "p25": p25,
            "p50": p50,
            "p75": p75,
            "harvest_window": harvest_label,
            "msp": msp_entry,
            "profit_probability": profit_probability,
            "confidence": confidence_label,
            "mape_pct": mape_pct,
            "reasoning": " ".join(reasoning_parts),
            "model_version": _ver,
        }
    except Exception as e:
        return {
            "error": str(e),
            "confidence": "LOW",
            "crop": crop,
            "state": state,
        }
# ---------------------------------------------------------------------------
# Public API: get_batch_signals
# ---------------------------------------------------------------------------
def get_batch_signals(requests_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Process a list of signal requests in sequence.

    Parameters
    ----------
    requests_list : list of dicts, each with keys:
        crop, state, district (opt), current_price (opt), date (opt).
        ``None`` is treated as an empty list.

    Returns
    -------
    list of get_signal() result dicts (same order as input). A request that
    is not dict-like, or whose processing raises, yields an error dict with
    keys error, confidence ("LOW"), crop and state instead of aborting the
    batch.
    """
    results: List[Dict[str, Any]] = []
    for req in requests_list or []:
        if not callable(getattr(req, "get", None)):
            # Previously a malformed entry raised from inside the error
            # handler and took the whole batch down with it.
            results.append({
                "error": f"Invalid request: expected a dict, got {type(req).__name__}",
                "confidence": "LOW",
                "crop": None,
                "state": None,
            })
            continue
        try:
            out = get_signal(
                crop=req.get("crop", ""),
                state=req.get("state", ""),
                district=req.get("district"),
                current_price=req.get("current_price"),
                date_input=req.get("date"),
            )
        except Exception as e:
            out = {
                "error": str(e),
                "confidence": "LOW",
                "crop": req.get("crop"),
                "state": req.get("state"),
            }
        results.append(out)
    return results
# ---------------------------------------------------------------------------
# Public API: get_coverage
# ---------------------------------------------------------------------------
def get_coverage() -> Dict[str, Any]:
    """
    Summarise which crops, states and districts the feature store covers.

    Returns
    -------
    dict: crops, total_crops, states, total_districts, districts_count,
    weekly_crops, monthly_crops_count. If a query fails, the same keys are
    returned with empty/zero values, ``weekly_crops`` set to the full
    PERISHABLE_CROPS list, and an extra ``error`` message.
    """
    try:
        crops_df = _duck(
            f"SELECT DISTINCT commodity FROM read_parquet('{FEATURE_STORE}') ORDER BY commodity"
        )
        states_df = _duck(
            f"SELECT DISTINCT state FROM read_parquet('{FEATURE_STORE}') ORDER BY state"
        )
        dist_count_df = _duck(
            f"SELECT COUNT(DISTINCT district) AS cnt FROM read_parquet('{FEATURE_STORE}')"
        )

        all_crops = [] if crops_df.empty else crops_df["commodity"].tolist()
        all_states = [] if states_df.empty else states_df["state"].tolist()
        dist_count = 0 if dist_count_df.empty else int(dist_count_df["cnt"].iloc[0])

        # Perishable crops actually present in the feature store.
        weekly_crops = [
            perishable
            for perishable in PERISHABLE_CROPS
            if any(perishable.lower() == known.lower() for known in all_crops)
        ]
        # Every covered crop that is not on the perishable list.
        monthly_count = sum(
            1
            for known in all_crops
            if not any(known.lower() == perishable.lower() for perishable in PERISHABLE_CROPS)
        )
        return {
            "crops": all_crops,
            "total_crops": len(all_crops),
            "states": all_states,
            "total_districts": dist_count,
            "districts_count": dist_count,
            "weekly_crops": weekly_crops,
            "monthly_crops_count": monthly_count,
        }
    except Exception as e:
        return {
            "error": str(e),
            "weekly_crops": PERISHABLE_CROPS,
            "monthly_crops_count": 0,
            "total_crops": 0,
            "total_districts": 0,
            "crops": [],
            "states": [],
            "districts_count": 0,
        }
# ---------------------------------------------------------------------------
# CLI smoke test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Manual smoke test: exercises each public entry point once and prints
    # the results. Requires the parquet files and model artefacts next to
    # this module.
    print("=== Enterprise Engine v2 — Smoke Test ===\n")

    print("[1] get_signal (perishable — Tomato)")
    r = get_signal("Tomato", "Maharashtra", "Pune", current_price=1200)
    for k, v in r.items():
        print(f" {k}: {v}")

    print("\n[2] get_signal (non-perishable — Wheat)")
    r2 = get_signal("Wheat", "Punjab", "Ludhiana", current_price=2100)
    for k, v in r2.items():
        print(f" {k}: {v}")

    print("\n[3] get_presow_signal (Wheat)")
    r3 = get_presow_signal("Wheat", "Haryana", sowing_month=11)
    for k, v in r3.items():
        print(f" {k}: {v}")

    print("\n[4] get_batch_signals")
    batch = get_batch_signals([
        {"crop": "Onion", "state": "Maharashtra", "current_price": 800},
        {"crop": "Mustard", "state": "Rajasthan", "current_price": 5800},
    ])
    for i, b in enumerate(batch, 1):
        # Separator added: crop and recommendation used to be printed fused
        # together (e.g. "OnionHOLD").
        print(f" [{i}] {b.get('crop')} → {b.get('recommendation')} ({b.get('confidence')})")

    print("\n[5] get_coverage")
    cov = get_coverage()
    print(f" Crops: {len(cov.get('crops', []))}, States: {len(cov.get('states', []))}, "
          f"Districts: {cov.get('districts_count')}, Weekly: {len(cov.get('weekly_crops', []))}")