| """ |
| Enterprise Engine v2 — Mandi Advisor |
| Production routing engine: weekly perishable model → district v2 monthly model → v4 state fallback. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import pickle |
| import warnings |
| from datetime import date, datetime |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import duckdb |
| import numpy as np |
| import pandas as pd |
|
|
# Silence every warning, process-wide, at import time.
# NOTE(review): this is a global side effect that also hides warnings raised by
# importers of this module and by libraries it calls (e.g. anything emitted
# while unpickling the model artefacts). Consider narrowing it to specific
# categories/modules — confirm which warnings it was meant to suppress first.
warnings.filterwarnings("ignore")
|
|
| |
| |
| |
|
|
# All data files and model artefacts are resolved relative to this module's directory.
BASE = Path(__file__).parent.resolve()

# Feature store read by _get_feature_row (filtered on commodity/state/district/month,
# ordered by year/month).
FEATURE_STORE = BASE / "feature_store.parquet"
# Mandi price records (arrival_date, modal_price, …) read by the volatility check.
MANDI_DATA = BASE / "mandi_data_clean.parquet"

# Crops that get_signal routes to the weekly perishable model (compared
# case-insensitively); every other crop goes to the district v2 monthly model.
PERISHABLE_CROPS = [
    "Tomato", "Onion", "Potato", "Green Chilli", "Brinjal",
    "Cabbage", "Cauliflower", "Bhindi (Okra)", "Banana",
    "Cucumber", "Bitter Gourd", "Bottle Gourd", "Pumpkin",
]
|
|
# Minimum Support Price (₹/quintal) per crop, keyed by year.
# get_presow_signal matches a crop to these keys by case-insensitive substring
# and, when the requested year is absent, falls back to the latest year listed.
# NOTE(review): the 2026 figures look like projections rather than announced
# MSPs — confirm their source. Also confirm whether the year key means the
# calendar year or the marketing season.
MSP_TABLE: Dict[str, Dict[int, int]] = {
    "Wheat": {2024: 2275, 2025: 2425, 2026: 2600},
    "Paddy": {2024: 2300, 2025: 2369, 2026: 2500},
    "Mustard": {2024: 5650, 2025: 5950, 2026: 6200},
    "Gram": {2024: 5440, 2025: 5650, 2026: 5800},
    "Bengal Gram": {2024: 5440, 2025: 5650, 2026: 5800},
    "Arhar/Tur": {2024: 7000, 2025: 7550, 2026: 8000},
    "Moong": {2024: 8682, 2025: 8908, 2026: 9500},
}
|
|
# Pickle artefacts per logical model name, loaded lazily by _load_model.
# Each entry maps a part name ("model", "meta", and optionally "winrates") to
# its file. presow_v3 is used by get_presow_signal only when presow_v4 cannot
# be loaded; sellhold_v4 is the state-level fallback used by get_signal.
MODEL_FILES = {
    "weekly_perishable": {
        "model": BASE / "weekly_perishable_model.pkl",
        "meta": BASE / "weekly_perishable_meta.pkl",
        "winrates": BASE / "weekly_perishable_winrates.pkl",
    },
    "district_v2": {
        "model": BASE / "district_sellhold_v2_model.pkl",
        "meta": BASE / "district_sellhold_v2_meta.pkl",
        "winrates": BASE / "district_sellhold_v2_winrates.pkl",
    },
    "presow_v3": {
        "model": BASE / "presow_v3_model.pkl",
        "meta": BASE / "presow_v3_meta.pkl",
    },
    "presow_v4": {
        "model": BASE / "presow_v4_model.pkl",
        "meta": BASE / "presow_v4_meta.pkl",
    },
    "sellhold_v4": {
        "model": BASE / "sellhold_v4_model.pkl",
        "meta": BASE / "sellhold_v4_meta.pkl",
    },
}
|
|
# Model probability below which a signal is labelled "LOW" confidence and
# get_signal adds a warning (see _score_to_label / get_signal).
LOW_CONF_THRESHOLD = 0.55
# Relative price move (as a fraction; 0.35 == 35%) above which _check_volatility
# flags the market, after which get_signal forces the recommendation to HOLD.
VOLATILITY_THRESHOLD = 0.35


# Module-level cache of loaded artefacts: model name -> {part name: object}.
# There is no locking around it.
_CACHE: Dict[str, Any] = {}
|
|
|
|
def _load_model(name: str) -> Any:
    """Return the artefact dict registered under *name* in MODEL_FILES.

    On first use every part (model / meta / winrates …) is unpickled and the
    resulting ``{part: object}`` dict is memoised in ``_CACHE``; subsequent
    calls return that same dict object.

    Raises
    ------
    KeyError
        If *name* is not a key of MODEL_FILES.
    OSError / unpickling errors
        If an artefact file is missing or unreadable (nothing is cached then).
    """
    try:
        return _CACHE[name]
    except KeyError:
        pass

    paths = MODEL_FILES.get(name)
    if paths is None:
        raise KeyError(f"Unknown model name: {name!r}")

    # SECURITY: pickle.load can execute arbitrary code. These paths are fixed
    # files next to this module; never point MODEL_FILES at untrusted data.
    artefacts: Dict[str, Any] = {}
    for part_name, artefact_path in paths.items():
        with open(artefact_path, "rb") as handle:
            artefacts[part_name] = pickle.load(handle)

    _CACHE[name] = artefacts
    return artefacts
|
|
|
|
| |
| |
| |
|
|
| def _sanitize(val: str) -> str: |
| """Strip SQL-injection characters from a commodity/state/district name. |
| These values come from an internal AGMARKNET commodity map, but sanitizing |
| ensures safety if the call path ever receives user-supplied input. |
| """ |
| if val is None: |
| return "" |
| |
| return val.replace("'", "").replace('"', "").replace(";", "").replace("--", "").strip() |
|
|
|
|
def _duck(sql: str, params: Optional[list] = None) -> pd.DataFrame:
    """Execute *sql* on a fresh in-memory DuckDB connection and return a DataFrame.

    Each call opens its own connection (memory capped at 1 GB) and always
    closes it afterwards, so repeated queries do not accumulate open
    connections. When *params* is non-empty its values are bound positionally
    to the ``$1, $2, …`` placeholders in *sql*.
    """
    con = duckdb.connect(database=":memory:")
    try:
        con.execute("SET memory_limit='1GB';")
        result = con.execute(sql, params) if params else con.execute(sql)
        # .df() materialises the full result before the connection is closed.
        return result.df()
    finally:
        con.close()
|
|
|
|
| |
| |
| |
|
|
| def _resolve_name(name: str, le) -> Optional[str]: |
| """ |
| Attempt to resolve *name* against the classes of a sklearn LabelEncoder. |
| Priority: exact → case-insensitive → substring. |
| Returns the matched class string, or None if no match. |
| """ |
| if name is None: |
| return None |
| classes: List[str] = list(le.classes_) |
|
|
| |
| if name in classes: |
| return name |
|
|
| |
| name_lower = name.lower() |
| for c in classes: |
| if c.lower() == name_lower: |
| return c |
|
|
| |
| for c in classes: |
| if name_lower in c.lower() or c.lower() in name_lower: |
| return c |
|
|
| return None |
|
|
|
|
| |
| |
| |
|
|
def _get_feature_row(
    crop: str,
    state: str,
    district: Optional[str],
    month: int,
    year: int,
) -> Optional[pd.Series]:
    """Pull the most relevant feature-store row for a crop/state(/district) key.

    Commodity, state and (when given) district are matched case-insensitively.
    Two passes are made:

    1. rows for the requested calendar *month*, newest year first;
    2. if none exist, rows for any month, newest year/month first.

    In both passes, rows whose ``year`` is not later than *year* are ranked
    ahead of later years. As a result, a request for a past date does not pick
    up features built from future data. When only later years exist, the
    newest row is still returned.

    Returns the single best row, or None when nothing matches or the query
    fails. This is best-effort; callers fall back to default feature values.
    """
    try:
        _crop = _sanitize(crop).lower()
        _state = _sanitize(state).lower()
        _district = _sanitize(district).lower() if district else None
        _month = int(month)
        # int() makes the year safe to inline as a literal in ORDER BY.
        _year = int(year)

        filters = ["lower(commodity) = $1", "lower(state) = $2"]
        params: List[Any] = [_crop, _state]
        if _district:
            filters.append("lower(district) = $3")
            params.append(_district)

        def _best(extra_filter: Optional[str], extra_params: List[Any], order_by: str) -> pd.DataFrame:
            """Run one ranked lookup and return at most one row."""
            where = " AND ".join(filters + ([extra_filter] if extra_filter else []))
            sql = f"""
                SELECT * FROM read_parquet('{FEATURE_STORE}')
                WHERE {where}
                ORDER BY (year <= {_year}) DESC, {order_by}
                LIMIT 1
            """
            return _duck(sql, params + extra_params)

        # Pass 1: same calendar month (its parameter follows the key parameters).
        df = _best(f"month = ${len(params) + 1}", [_month], "year DESC")
        if df.empty:
            # Pass 2: any month for this key.
            df = _best(None, [], "year DESC, month DESC")
        return df.iloc[0] if not df.empty else None
    except Exception:
        return None
|
|
|
|
def _get_recent_mandi_prices(crop: str, state: str, days: int = 7) -> Optional[pd.DataFrame]:
    """Fetch the newest ``arrival_date``/``modal_price`` rows for a crop in a state.

    This is a row cap, not a date filter: at most ``days * 20`` rows are
    returned, newest ``arrival_date`` first, across every market in the state.
    Returns None if the query fails for any reason.
    """
    try:
        commodity = _sanitize(crop).lower()
        state_name = _sanitize(state).lower()
        row_cap = int(days) * 20
        query = (
            "SELECT arrival_date, modal_price "
            f"FROM read_parquet('{MANDI_DATA}') "
            "WHERE lower(commodity) = $1 AND lower(state) = $2 "
            "ORDER BY arrival_date DESC "
            f"LIMIT {row_cap}"
        )
        return _duck(query, [commodity, state_name])
    except Exception:
        return None
|
|
|
|
def _check_volatility(crop: str, state: str) -> tuple[bool, str]:
    """
    Returns (is_volatile: bool, message: str).

    Flags the market as volatile when the recent price moved by more than
    VOLATILITY_THRESHOLD (35%) across the fetched window (about the last
    7 days of data).

    The recent rows arrive newest-first and mix many mandis. Comparing two
    raw rows would compare two different markets and report cross-market
    spread as volatility. Instead, rows are collapsed to one median modal price
    per ``arrival_date``, and the newest day is compared with the oldest day.

    The message is empty when not volatile. Missing data or any failure is
    treated as "not volatile" (best-effort check).
    """
    try:
        df = _get_recent_mandi_prices(crop, state, days=7)
        if df is None or df.empty or len(df) < 2:
            return False, ""
        prices = pd.to_numeric(df["modal_price"], errors="coerce")
        # sort=False keeps first-appearance order, i.e. the query's newest-first order.
        daily = prices.groupby(df["arrival_date"], sort=False).median().dropna()
        if len(daily) < 2:
            return False, ""
        p_latest = float(daily.iloc[0])
        p_old = float(daily.iloc[-1])
        if p_old <= 0:
            return False, ""
        pct_change = abs(p_latest - p_old) / p_old
        if pct_change > VOLATILITY_THRESHOLD:
            return True, f"⚠️ High volatility — signal suppressed (price moved {pct_change*100:.1f}% in last 7 days)"
        return False, ""
    except Exception:
        return False, ""
|
|
|
|
| |
| |
| |
|
|
def _score_to_label(score: float) -> str:
    """Bucket a model probability: >= 0.70 is "HIGH", >= LOW_CONF_THRESHOLD is "MEDIUM", else "LOW"."""
    for floor, label in ((0.70, "HIGH"), (LOW_CONF_THRESHOLD, "MEDIUM")):
        if score >= floor:
            return label
    return "LOW"
|
|
|
|
| def _lookup_win_rate(winrates: Any, crop: str, state: str, |
| district: Optional[str] = None, |
| month: Optional[int] = None, |
| iso_week: Optional[int] = None) -> Optional[float]: |
| """Best-effort win-rate lookup from the winrates dict/DataFrame artefact.""" |
| try: |
| cur_month = month or date.today().month |
| cur_isoweek = iso_week or date.today().isocalendar()[1] |
|
|
| def _time_mask(df): |
| """Return time-based mask: iso_week if available, else month.""" |
| if "iso_week" in df.columns: |
| return df["iso_week"] == cur_isoweek |
| elif "month" in df.columns: |
| return df["month"] == cur_month |
| return pd.Series([True] * len(df), index=df.index) |
|
|
| |
| if isinstance(winrates, dict) and any(k in winrates for k in ("national", "state", "district")): |
| |
| if district and "district" in winrates: |
| df = winrates["district"] |
| if isinstance(df, pd.DataFrame) and "district" in df.columns: |
| mask = ( |
| (df["commodity"].str.lower() == crop.lower()) & |
| (df["state"].str.lower() == state.lower()) & |
| (df["district"].str.lower() == district.lower()) & |
| _time_mask(df) |
| ) |
| row = df[mask] |
| if not row.empty: |
| wr_col = [c for c in df.columns if "win" in c.lower()] |
| if wr_col: |
| return float(row.iloc[0][wr_col[0]]) |
| |
| if "state" in winrates: |
| df = winrates["state"] |
| if isinstance(df, pd.DataFrame): |
| mask = ( |
| (df["commodity"].str.lower() == crop.lower()) & |
| (df["state"].str.lower() == state.lower()) & |
| _time_mask(df) |
| ) |
| row = df[mask] |
| if not row.empty: |
| wr_col = [c for c in df.columns if "win" in c.lower()] |
| if wr_col: |
| return float(row.iloc[0][wr_col[0]]) |
| |
| if "national" in winrates: |
| df = winrates["national"] |
| if isinstance(df, pd.DataFrame): |
| mask = ( |
| (df["commodity"].str.lower() == crop.lower()) & |
| _time_mask(df) |
| ) |
| row = df[mask] |
| if not row.empty: |
| wr_col = [c for c in df.columns if "win" in c.lower()] |
| if wr_col: |
| return float(row.iloc[0][wr_col[0]]) |
| return None |
|
|
| |
| if isinstance(winrates, dict): |
| for key in [(crop, state, district), (crop, state), (crop,), ("overall",)]: |
| if key in winrates: |
| v = winrates[key] |
| return float(v) if v is not None else None |
|
|
| |
| if isinstance(winrates, pd.DataFrame): |
| cols = [c.lower() for c in winrates.columns] |
| month_col = next((c for c in winrates.columns if c.lower() == "month"), None) |
| month_mask = (winrates[month_col] == cur_month) if month_col else True |
| if all(k in cols for k in ["commodity", "state", "district"]) and district: |
| row = winrates[ |
| (winrates.iloc[:, cols.index("commodity")].str.lower() == crop.lower()) & |
| (winrates.iloc[:, cols.index("state")].str.lower() == state.lower()) & |
| (winrates.iloc[:, cols.index("district")].str.lower() == district.lower()) & |
| month_mask |
| ] |
| if not row.empty: |
| wr_col = [c for c in winrates.columns if "win" in c.lower()] |
| if wr_col: |
| return float(row.iloc[0][wr_col[0]]) |
| if all(k in cols for k in ["commodity", "state"]): |
| row = winrates[ |
| (winrates.iloc[:, cols.index("commodity")].str.lower() == crop.lower()) & |
| (winrates.iloc[:, cols.index("state")].str.lower() == state.lower()) & |
| month_mask |
| ] |
| if not row.empty: |
| wr_col = [c for c in winrates.columns if "win" in c.lower()] |
| if wr_col: |
| return float(row.iloc[0][wr_col[0]]) |
| return None |
| except Exception: |
| return None |
|
|
|
|
| |
| |
| |
|
|
def _predict_weekly(
    crop: str,
    state: str,
    district: Optional[str],
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Build features and predict using the weekly perishable model.

    Parameters
    ----------
    crop, state, district : names; crop/state are resolved against the model's
                            label encoders, and the raw names are used for the
                            feature-store lookup.
    current_price         : latest price (₹/qtl). It is only used for the
                            price-vs-seasonal comparison.
    ref_date              : fixes the ISO-week / month / year features.

    Returns
    -------
    Partial signal dict (recommendation, confidence_score, confidence,
    win_rate, seasonal_avg, price_vs_seasonal_pct, granularity,
    model_version, data_level, iso_week) post-processed by get_signal.

    Raises
    ------
    ValueError if crop or state is unknown to the encoders. Loading and
    prediction errors also propagate, so get_signal can fall back to v4.
    """
    artefacts = _load_model("weekly_perishable")
    model = artefacts["model"]
    meta = artefacts["meta"]
    winrates = artefacts.get("winrates", {})

    # Meta key spellings differ between training runs; accept both.
    crop_le = meta.get("le_crop") or meta.get("crop_le")
    state_le = meta.get("le_state") or meta.get("state_le")
    ml_features = meta.get("features", meta.get("ML_FEATURES", []))

    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state
    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in weekly model encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in weekly model encoder")

    crop_enc = int(crop_le.transform([crop_name])[0]) if crop_le else 0
    state_enc = int(state_le.transform([state_name])[0]) if state_le else 0

    iso_week = ref_date.isocalendar()[1]
    month = ref_date.month
    year = ref_date.year

    # Cyclical encodings; week-of-year uses a 52-week period.
    wk_sin = np.sin(2 * np.pi * iso_week / 52)
    wk_cos = np.cos(2 * np.pi * iso_week / 52)
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)

    row = _get_feature_row(crop, state, district, month, year)

    def _fs(col, default=0.0):
        """Numeric feature from *row*; *default* if absent/NaN/non-numeric (a None default stays None)."""
        if row is not None and col in row.index and pd.notna(row[col]):
            try:
                return float(row[col])
            except (ValueError, TypeError):
                pass
        return None if default is None else float(default)

    def _fs_str(col, default="state"):
        if row is not None and col in row.index and pd.notna(row[col]):
            return str(row[col])
        return default

    # NOTE(review): several weekly features are stand-ins read from monthly
    # feature-store columns:
    #   - price_vs_4w, price_vs_st_pct and mom_4w all read district_vs_state_pct;
    #   - price_vs_52w reads a raw price lag.
    # Others are fixed constants (momentum, win rates, spread). Confirm this
    # matches how the model was trained.
    feature_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "iso_week": iso_week,
        "month": month,
        "year": year,
        "wk_sin": wk_sin,
        "wk_cos": wk_cos,
        "mth_sin": mth_sin,
        "mth_cos": mth_cos,
        "price_vs_4w": _fs("district_vs_state_pct"),
        "price_vs_13w": _fs("state_vs_national_pct"),
        "price_vs_52w": _fs("dist_lag_1yr", 0),
        "price_vs_seas_w": _fs("seasonal_idx", 1.0) - 1.0,
        "price_vs_st_pct": _fs("district_vs_state_pct"),
        "price_vs_nat_pct": _fs("district_vs_national_pct"),
        "mom_1w": 0.0,
        "mom_2w": 0.0,
        "mom_4w": _fs("district_vs_state_pct"),
        "pctile_52w": _fs("price_pctile", 50),
        "nat_win_rate_w": 0.5,
        "st_win_rate_w": 0.5,
        "price_spread_pct": 0.0,
        "n_markets": _fs("district_records", 1),
    }

    # Unknown feature names default to 0.0, in the order the model expects.
    feature_values = [feature_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    raw_pred = model.predict(X)[0]
    pred_class = int(raw_pred)
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)[0]
        # predict_proba columns follow model.classes_ order, so locate the
        # predicted label's column rather than indexing by the label value.
        classes = list(getattr(model, "classes_", []))
        col = classes.index(raw_pred) if raw_pred in classes else pred_class
        confidence_score = float(proba[col])
    else:
        confidence_score = 0.6

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")

    # Pass the reference month too, for win-rate tables bucketed by month.
    win_rate = _lookup_win_rate(winrates, crop_name, state_name, month=month, iso_week=iso_week)

    # A missing or zero state average means "unknown".
    seasonal_avg = _fs("st_avg", None) or None

    price_vs_seasonal_pct = None
    if current_price and seasonal_avg:
        price_vs_seasonal_pct = round((current_price - seasonal_avg) / seasonal_avg * 100, 1)

    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": round(win_rate, 3) if win_rate is not None else None,
        "seasonal_avg": round(seasonal_avg, 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "weekly",
        "model_version": "weekly_perishable_v1",
        "data_level": _fs_str("data_level", "state"),
        "iso_week": iso_week,
    }
|
|
|
|
| |
| |
| |
|
|
def _predict_district_v2(
    crop: str,
    state: str,
    district: Optional[str],
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Build features and predict using the district sell/hold v2 monthly model.

    Parameters
    ----------
    crop, state, district : names resolved against the model's label encoders.
                            An unresolvable district is encoded as 0, not an error.
    current_price         : latest price (₹/qtl). It is only used for the
                            price-vs-seasonal comparison.
    ref_date              : fixes the month / year features.

    Returns
    -------
    Partial signal dict (recommendation, confidence_score, confidence,
    win_rate, seasonal_avg, price_vs_seasonal_pct, granularity,
    model_version, data_level) post-processed by get_signal.

    Raises
    ------
    ValueError if crop or state is unknown to the encoders. Other errors
    propagate, so get_signal can fall back to v4.
    """
    artefacts = _load_model("district_v2")
    model = artefacts["model"]
    meta = artefacts["meta"]
    winrates = artefacts.get("winrates", {})

    crop_le = meta.get("crop_le")
    state_le = meta.get("state_le")
    district_le = meta.get("district_le")
    ml_features = meta.get("ML_FEATURES", meta.get("features", []))

    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state
    district_name = _resolve_name(district, district_le) if district_le and district else None

    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in district_v2 encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in district_v2 encoder")

    crop_enc = int(crop_le.transform([crop_name])[0]) if crop_le else 0
    state_enc = int(state_le.transform([state_name])[0]) if state_le else 0
    # NOTE(review): 0 is also a real LabelEncoder code (the first district), so
    # an unknown district is indistinguishable from it — confirm this matches training.
    district_enc = (
        int(district_le.transform([district_name])[0])
        if district_le and district_name else 0
    )

    month = ref_date.month
    year = ref_date.year
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)

    row = _get_feature_row(crop, state, district, month, year)

    def _fs(col, default=0.0):
        """Numeric feature from *row*; *default* if absent/NaN/non-numeric (a None default stays None)."""
        if row is not None and col in row.index and pd.notna(row[col]):
            try:
                return float(row[col])
            except (ValueError, TypeError):
                pass
        return None if default is None else float(default)

    def _fs_str(col, default="state"):
        if row is not None and col in row.index and pd.notna(row[col]):
            return str(row[col])
        return default

    # NOTE(review): price_vs_st_seas_avg reuses state_vs_national_pct, and the
    # momentum, area/yield and win-rate features are fixed constants here.
    # Confirm this matches how the model was trained.
    feature_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "district_enc": district_enc,
        "month": month,
        "year": year,
        "month_sin": mth_sin,
        "month_cos": mth_cos,
        "price_vs_dist_seas_avg": _fs("seasonal_idx", 1.0) - 1.0,
        "price_vs_dist_seas_std": 0.0,
        "price_vs_st_seas_avg": _fs("state_vs_national_pct"),
        "state_vs_national_pct": _fs("state_vs_national_pct"),
        "district_vs_state_pct": _fs("district_vs_state_pct"),
        "district_vs_national_pct": _fs("district_vs_national_pct"),
        "price_pctile": _fs("price_pctile", 50),
        "mom_1m": 0.0,
        "mom_3m": 0.0,
        "mom_6m": 0.0,
        "seasonal_idx": _fs("seasonal_idx", 1.0),
        "nat_win_rate": 0.5,
        "st_win_rate": 0.5,
        "dist_win_rate": 0.5,
        "eff_win_rate": 0.5,
        "area_vs_5yr_avg": 0.0,
        "yield_vs_5yr_avg": 0.0,
        "area_yoy_pct": 0.0,
        "district_records": _fs("district_records", 1),
    }

    feature_values = [feature_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    raw_pred = model.predict(X)[0]
    pred_class = int(raw_pred)
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)[0]
        # predict_proba columns follow model.classes_ order, not label values.
        classes = list(getattr(model, "classes_", []))
        col = classes.index(raw_pred) if raw_pred in classes else pred_class
        confidence_score = float(proba[col])
    else:
        confidence_score = 0.6

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")

    win_rate = _lookup_win_rate(winrates, crop_name, state_name, district_name, month=ref_date.month)

    # State average, else national average; missing or zero means "unknown".
    seasonal_avg = _fs("st_avg", None) or _fs("nat_avg", None) or None

    price_vs_seasonal_pct = None
    if current_price and seasonal_avg:
        price_vs_seasonal_pct = round((current_price - seasonal_avg) / seasonal_avg * 100, 1)

    data_level = _fs_str("data_level", "district" if district_name else "state")

    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": round(win_rate, 3) if win_rate is not None else None,
        "seasonal_avg": round(seasonal_avg, 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "monthly",
        "model_version": "district_sellhold_v2",
        "data_level": data_level,
    }
|
|
|
|
| |
| |
| |
|
|
def _predict_v4_fallback(
    crop: str,
    state: str,
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Fallback to the v4 state-level monthly sell/hold model.

    Uses state-level features only (no district) and reports no win rate.
    The returned dict has the same keys as the primary predictors
    (win_rate is always None and data_level is always "state").

    Raises
    ------
    ValueError if crop or state is unknown to the v4 encoders.
    """
    artefacts = _load_model("sellhold_v4")
    model = artefacts["model"]
    meta = artefacts["meta"]

    crop_le = meta.get("crop_le")
    state_le = meta.get("state_le")
    ml_features = meta.get("ML_FEATURES", meta.get("features", []))

    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state
    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in v4 encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in v4 encoder")

    crop_enc = int(crop_le.transform([crop_name])[0]) if crop_le else 0
    state_enc = int(state_le.transform([state_name])[0]) if state_le else 0

    month = ref_date.month
    year = ref_date.year
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)

    row = _get_feature_row(crop, state, None, month, year)

    def _fs(col, default=0.0):
        """Numeric feature from *row*; *default* if absent/NaN/non-numeric (a None default stays None)."""
        if row is not None and col in row.index and pd.notna(row[col]):
            try:
                return float(row[col])
            except (ValueError, TypeError):
                pass
        return None if default is None else float(default)

    base_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "month": month,
        "year": year,
        "month_sin": mth_sin,
        "month_cos": mth_cos,
        "state_vs_national_pct": _fs("state_vs_national_pct"),
        "seasonal_idx": _fs("seasonal_idx", 1.0),
        "price_pctile": _fs("price_pctile", 50),
        "state_lag_1yr": _fs("state_lag_1yr"),
        "state_lag_2yr": _fs("state_lag_2yr"),
        "nat_lag_1yr": _fs("nat_lag_1yr"),
        "nat_lag_2yr": _fs("nat_lag_2yr"),
    }

    feature_values = [base_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    raw_pred = model.predict(X)[0]
    pred_class = int(raw_pred)
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)[0]
        # predict_proba columns follow model.classes_ order, not label values.
        classes = list(getattr(model, "classes_", []))
        col = classes.index(raw_pred) if raw_pred in classes else pred_class
        confidence_score = float(proba[col])
    else:
        confidence_score = 0.55

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")

    # State average, else national average; missing or zero means "unknown".
    seasonal_avg = _fs("st_avg", None) or _fs("nat_avg", None) or None
    price_vs_seasonal_pct = None
    if current_price and seasonal_avg:
        price_vs_seasonal_pct = round((current_price - seasonal_avg) / seasonal_avg * 100, 1)

    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": None,
        "seasonal_avg": round(seasonal_avg, 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "monthly",
        "model_version": "sellhold_v4_fallback",
        "data_level": "state",
    }
|
|
|
|
| |
| |
| |
|
|
def get_signal(
    crop: str,
    state: str,
    district: Optional[str] = None,
    current_price: Optional[float] = None,
    date_input: Optional[Union[str, date, datetime]] = None,
) -> Dict[str, Any]:
    """
    Route a price-signal request to the right model and return a unified dict.

    Perishable crops go to the weekly model and everything else to the
    district v2 monthly model. If that model raises, the v4 state-level model
    is used instead. The result is then post-processed:
      - a confidence below LOW_CONF_THRESHOLD forces "LOW";
      - a volatile market forces "HOLD".
    Both overrides add a warning.

    Parameters
    ----------
    crop          : Commodity name (e.g. "Tomato", "Wheat")
    state         : State name
    district      : District name (optional)
    current_price : Latest market price (₹/quintal)
    date_input    : Reference date (date, datetime, or "YYYY-MM-DD…" string); defaults to today

    Returns
    -------
    dict with keys: crop, state, district, date, granularity, recommendation,
                    confidence, win_rate, current_price, seasonal_avg,
                    price_vs_seasonal_pct, reasoning, model_version, data_level,
                    warning (optional).
    On any failure: {"error", "confidence": "LOW", "crop", "state", "district"}.
    """
    try:
        if date_input is None:
            ref_date = date.today()
        elif isinstance(date_input, datetime):
            ref_date = date_input.date()
        elif isinstance(date_input, str):
            ref_date = datetime.strptime(date_input[:10], "%Y-%m-%d").date()
        else:
            ref_date = date_input

        notes: List[str] = []
        used_fallback = False
        is_perishable = any(c.lower() == crop.lower() for c in PERISHABLE_CROPS)

        try:
            if is_perishable:
                result = _predict_weekly(crop, state, district, current_price, ref_date)
            else:
                result = _predict_district_v2(crop, state, district, current_price, ref_date)
        except Exception as primary_err:
            if is_perishable:
                notes.append(f"Weekly model failed ({primary_err}); falling back to v4")
            else:
                notes.append(f"District v2 model failed ({primary_err}); falling back to v4")
            # If v4 also fails, the outer handler returns the error dict.
            result = _predict_v4_fallback(crop, state, current_price, ref_date)
            used_fallback = True

        score = result.get("confidence_score", 0.5)
        if score < LOW_CONF_THRESHOLD:
            result["confidence"] = "LOW"
            notes.append(
                f"Model confidence {score:.2f} below threshold {LOW_CONF_THRESHOLD}"
            )

        volatile, volatility_note = _check_volatility(crop, state)
        if volatile:
            result["recommendation"] = "HOLD"
            notes.append(volatility_note)

        rec = result.get("recommendation", "HOLD")
        conf = result.get("confidence", "MEDIUM")
        reasoning = [f"Prediction: {rec} ({conf} confidence)."]

        pct = result.get("price_vs_seasonal_pct")
        if pct is not None:
            seas = result.get("seasonal_avg")
            direction = "above" if pct >= 0 else "below"
            seas_suffix = f" (₹{seas}/qtl)" if seas else ""
            reasoning.append(
                f"Current price is {abs(pct):.1f}% {direction} seasonal average" + seas_suffix + "."
            )
        if used_fallback:
            reasoning.append("State-level fallback model used.")
        reasoning.extend(notes)

        signal: Dict[str, Any] = {
            "crop": crop,
            "state": state,
            "district": district,
            "date": ref_date.isoformat(),
            "granularity": result.get("granularity", "monthly"),
            "recommendation": result.get("recommendation", "HOLD"),
            "confidence": result.get("confidence", "LOW"),
            "win_rate": result.get("win_rate"),
            "current_price": current_price,
            "seasonal_avg": result.get("seasonal_avg"),
            "price_vs_seasonal_pct": result.get("price_vs_seasonal_pct"),
            "reasoning": " ".join(reasoning),
            "model_version": result.get("model_version", "unknown"),
            "data_level": result.get("data_level", "state"),
        }
        if notes:
            signal["warning"] = " | ".join(notes)
        return signal

    except Exception as e:
        return {
            "error": str(e),
            "confidence": "LOW",
            "crop": crop,
            "state": state,
            "district": district,
        }
|
|
|
|
| |
| |
| |
|
|
def _harvest_window_label(crop_name: str, harvest_cal: Any, ref_month: int, ref_year: int) -> str:
    """Return the expected harvest month as "Month YYYY" for a crop sown in ref_month/ref_year.

    The harvest start month is the first element of the first
    ``harvest_calendar`` entry whose key matches *crop_name* by
    case-insensitive substring (either direction).

    Without a usable entry, harvest is assumed four months after sowing. In
    both cases the year rolls over when the harvest month falls before the
    sowing month. Returns "---" when the calendar's month is invalid.
    """
    crop_l = crop_name.lower()
    hw_months = None
    for key in harvest_cal:
        if key.lower() in crop_l or crop_l in key.lower():
            hw_months = harvest_cal[key]
            break

    if hw_months and hw_months != (0, 0):
        hw_start = hw_months[0]
        try:
            return datetime(ref_year + (1 if hw_start < ref_month else 0),
                            hw_start, 1).strftime("%B %Y")
        except Exception:
            return "---"

    # Default: sowing month + 4, wrapping past December into the next year.
    harvest_month = ((ref_month + 3) % 12) + 1
    harvest_year = ref_year + (1 if harvest_month < ref_month else 0)
    return datetime(harvest_year, harvest_month, 1).strftime("%B %Y")


def _msp_for_crop(crop_name: str, year: int) -> Optional[int]:
    """Return the MSP (₹/qtl) of *crop_name* for *year*, else for the latest year listed; None if not an MSP crop.

    Crops are matched to MSP_TABLE keys by case-insensitive substring in
    either direction, and the first key in table order wins.
    NOTE(review): substring matching means e.g. "Green Gram" would match "Gram".
    """
    crop_l = crop_name.lower()
    for key, by_year in MSP_TABLE.items():
        if key.lower() in crop_l or crop_l in key.lower():
            return by_year.get(year) or by_year.get(max(by_year))
    return None


def get_presow_signal(
    crop: str,
    state: str,
    district: Optional[str] = None,
    sowing_month: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Return harvest-price range forecast for pre-sowing decisions.

    Uses the presow_v4 artefacts, or presow_v3 when v4 cannot be loaded.

    Parameters
    ----------
    crop         : Commodity name
    state        : State name
    district     : District (optional)
    sowing_month : Month of sowing (1-12); defaults to current month

    Returns
    -------
    dict: crop, state, district, sowing_month, p25, p50, p75, harvest_window,
          msp, profit_probability, confidence, mape_pct, reasoning, model_version.
    On failure (including an out-of-range sowing_month):
          {"error", "confidence": "LOW", "crop", "state"}.
    """
    try:
        try:
            artefacts = _load_model("presow_v4")
            _ver = "presow_v4"
        except Exception:
            # v4 artefacts missing/unreadable. A v3 failure propagates to the outer handler.
            artefacts = _load_model("presow_v3")
            _ver = "presow_v3"
        model = artefacts["model"]
        meta = artefacts["meta"]

        crop_le = meta.get("le_crop") or meta.get("crop_le")
        state_le = meta.get("le_state") or meta.get("state_le")
        ml_features = (meta.get("features_v4") or meta.get("features_v3")
                       or meta.get("ML_FEATURES") or meta.get("features", []))

        crop_name = _resolve_name(crop, crop_le) if crop_le else crop
        state_name = _resolve_name(state, state_le) if state_le else state
        if crop_name is None:
            return {"error": f"Crop '{crop}' not found in {_ver} encoder", "confidence": "LOW",
                    "crop": crop, "state": state}
        if state_name is None:
            return {"error": f"State '{state}' not found in {_ver} encoder", "confidence": "LOW",
                    "crop": crop, "state": state}

        crop_enc = int(crop_le.transform([crop_name])[0]) if crop_le else 0
        state_enc = int(state_le.transform([state_name])[0]) if state_le else 0

        ref_month = int(sowing_month or date.today().month)
        if not 1 <= ref_month <= 12:
            raise ValueError(f"sowing_month must be between 1 and 12, got {sowing_month!r}")
        ref_year = date.today().year

        mth_sin = np.sin(2 * np.pi * ref_month / 12)
        mth_cos = np.cos(2 * np.pi * ref_month / 12)

        row = _get_feature_row(crop_name, state_name, district, ref_month, ref_year)

        def _fs(col, default=0.0):
            """Numeric feature from *row*; *default* when absent or NaN."""
            if row is not None and col in row.index and pd.notna(row[col]):
                return float(row[col])
            return float(default)

        sl1 = _fs("state_lag_1yr")
        sl2 = _fs("state_lag_2yr")
        ap = _fs("avg_price")

        # Harvest-price lags: use the state lags, else a discounted current average.
        hpl1 = sl1 if sl1 > 0 else ap * 0.95
        hpl2 = sl2 if sl2 > 0 else ap * 0.90

        conf_map = meta.get("confidence_map", {})
        confidence_label = conf_map.get(crop_name, "MEDIUM")
        # Assumed price dispersion per confidence tier, fed to the price_cv feature.
        price_cv = {"HIGH": 0.10, "MEDIUM": 0.18, "LOW": 0.30}.get(confidence_label, 0.15)

        # Harvest-to-sowing price ratio, clamped; neutral when the average is tiny.
        h2s = (hpl1 / ap) if ap > 10 else 1.0
        h2s = max(0.3, min(5.0, h2s))

        base_map = {
            "crop_enc": crop_enc,
            "state_enc": state_enc,
            "month": ref_month,
            "year": ref_year,
            "month_sin": mth_sin,
            "month_cos": mth_cos,
            "nat_avg": _fs("nat_avg"),
            "st_avg": _fs("st_avg"),
            "avg_price": ap,
            "state_vs_national_pct": _fs("state_vs_national_pct"),
            "dist_lag_1yr": _fs("dist_lag_1yr"),
            "dist_lag_2yr": _fs("dist_lag_2yr"),
            "dist_lag_3yr": _fs("dist_lag_3yr"),
            "state_lag_1yr": sl1,
            "state_lag_2yr": sl2,
            "nat_lag_1yr": _fs("nat_lag_1yr"),
            "nat_lag_2yr": _fs("nat_lag_2yr"),
            "dist_roll3yr": _fs("dist_roll3yr"),
            "seasonal_idx": _fs("seasonal_idx", 1.0),
            "price_pctile": _fs("price_pctile", 50),
            "harvest_price_lag1": hpl1,
            "harvest_price_lag2": hpl2,
            "price_cv": price_cv,
            "harvest_to_sow_ratio": h2s,
            # Weather / satellite / acreage features are not wired up; neutral zeros.
            "temp_mean": 0.0,
            "rainfall_28d": 0.0,
            "ndvi_mean": 0.0,
            "area_yoy_pct": 0.0,
            "area_2yr_pct": 0.0,
            "area_vs_3yr": 0.0,
            "supply_up": 0.0,
            "supply_down": 0.0,
        }

        feature_values = [base_map.get(f, 0.0) for f in ml_features]
        X = np.array(feature_values, dtype=float).reshape(1, -1)

        if isinstance(model, dict) and "p25" in model and "p50" in model and "p75" in model:
            # Separate quantile models; enforce p25 <= p50 <= p75.
            p50_pred = float(model["p50"].predict(X)[0])
            p25_pred = min(float(model["p25"].predict(X)[0]), p50_pred)
            p75_pred = max(float(model["p75"].predict(X)[0]), p50_pred)
        else:
            # Single point model: fixed ±15% band around the median.
            p50_pred = float(model.predict(X)[0])
            p25_pred = p50_pred * 0.85
            p75_pred = p50_pred * 1.15

        p25 = round(p25_pred, 0)
        p50 = round(p50_pred, 0)
        p75 = round(p75_pred, 0)

        harvest_label = _harvest_window_label(
            crop_name, meta.get("harvest_calendar", {}), ref_month, ref_year
        )

        msp_entry = _msp_for_crop(crop_name, ref_year)
        if msp_entry:
            if p25 > msp_entry:
                profit_probability = "HIGH"
            elif p50 > msp_entry:
                profit_probability = "MEDIUM"
            else:
                profit_probability = "LOW"
        else:
            profit_probability = "MEDIUM"

        val_meta = meta.get("validation", {})
        mape_pct = val_meta.get("P50", {}).get("mape", 7.6)

        reasoning_parts = [
            f"Forecast for {crop_name} in {state_name}.",
            f"Expected harvest window: {harvest_label}.",
            f"Price range: ₹{p25:.0f}–₹{p75:.0f}/qtl (median ₹{p50:.0f}).",
        ]
        if msp_entry:
            reasoning_parts.append(
                f"MSP ({ref_year}): ₹{msp_entry}/qtl — profit probability: {profit_probability}."
            )

        return {
            "crop": crop,
            "state": state,
            "district": district,
            "sowing_month": ref_month,
            "p25": p25,
            "p50": p50,
            "p75": p75,
            "harvest_window": harvest_label,
            "msp": msp_entry,
            "profit_probability": profit_probability,
            "confidence": confidence_label,
            "mape_pct": mape_pct,
            "reasoning": " ".join(reasoning_parts),
            "model_version": _ver,
        }

    except Exception as e:
        return {
            "error": str(e),
            "confidence": "LOW",
            "crop": crop,
            "state": state,
        }
|
|
|
|
| |
| |
| |
|
|
def get_batch_signals(requests_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Process a list of signal requests in sequence.

    Each request is forwarded to ``get_signal()``. A failure in one request
    never aborts the batch. The failing slot is filled with an error dict
    containing ``error``, ``confidence="LOW"``, ``crop`` and ``state``, so the
    output always lines up one-to-one with the input.

    Parameters
    ----------
    requests_list : list of dicts, each with keys:
        crop, state, district (opt), current_price (opt), date (opt).
        ``None`` is treated as an empty batch.

    Returns
    -------
    list of get_signal() result dicts (same order as input)
    """
    if requests_list is None:
        return []

    results: List[Dict[str, Any]] = []
    for req in requests_list:
        # An item with no .get() (None, a string, a number, ...) cannot be
        # read as a request. Reject it here instead of inside the except
        # handler, which would call req.get() again and crash the batch.
        if not callable(getattr(req, "get", None)):
            results.append({
                "error": f"invalid request: expected a dict, got {type(req).__name__}",
                "confidence": "LOW",
                "crop": None,
                "state": None,
            })
            continue

        try:
            out = get_signal(
                crop=req.get("crop", ""),
                state=req.get("state", ""),
                district=req.get("district"),
                current_price=req.get("current_price"),
                date_input=req.get("date"),
            )
        except Exception as e:
            # Per-request isolation: record the failure and keep going.
            out = {
                "error": str(e),
                "confidence": "LOW",
                "crop": req.get("crop"),
                "state": req.get("state"),
            }
        results.append(out)
    return results
|
|
|
|
| |
| |
| |
|
|
def get_coverage() -> Dict[str, Any]:
    """
    Return metadata about what crops, states, and districts are covered.

    Queries the feature store parquet (``FEATURE_STORE``) through the
    module's ``_duck`` helper. ``_duck`` is presumably a function that runs
    a SQL string and returns a DataFrame; verify against its definition.
    On any failure, a dict with the same keys plus ``error`` is returned,
    with empty/zero values. In that case ``weekly_crops`` falls back to a
    copy of the static perishable list.

    Returns
    -------
    dict: crops, total_crops, states, total_districts, districts_count,
    weekly_crops, monthly_crops_count (plus ``error`` on failure)
    """
    try:
        # Double any single quotes so a path like ".../O'Brien/..." cannot
        # terminate the SQL string literal early.
        src = "read_parquet('" + str(FEATURE_STORE).replace("'", "''") + "')"

        # NULLs are excluded. A NULL commodity would otherwise surface as None
        # and break the case-insensitive comparisons below.
        sql_crops = (
            f"SELECT DISTINCT commodity FROM {src} "
            "WHERE commodity IS NOT NULL ORDER BY commodity"
        )
        sql_states = (
            f"SELECT DISTINCT state FROM {src} "
            "WHERE state IS NOT NULL ORDER BY state"
        )
        # COUNT(DISTINCT ...) already ignores NULLs.
        sql_dist_count = f"SELECT COUNT(DISTINCT district) AS cnt FROM {src}"

        crops_df = _duck(sql_crops)
        states_df = _duck(sql_states)
        dist_count_df = _duck(sql_dist_count)

        all_crops = crops_df["commodity"].tolist() if not crops_df.empty else []
        all_states = states_df["state"].tolist() if not states_df.empty else []
        dist_count = int(dist_count_df["cnt"].iloc[0]) if not dist_count_df.empty else 0

        # Case-insensitive matching via precomputed sets (one pass each side).
        available_lc = {str(c).lower() for c in all_crops}
        perishable_lc = {p.lower() for p in PERISHABLE_CROPS}
        weekly_crops = [c for c in PERISHABLE_CROPS if c.lower() in available_lc]
        monthly_count = sum(1 for c in all_crops if str(c).lower() not in perishable_lc)

        return {
            "crops": all_crops,
            "total_crops": len(all_crops),
            "states": all_states,
            "total_districts": dist_count,
            "districts_count": dist_count,
            "weekly_crops": weekly_crops,
            "monthly_crops_count": monthly_count,
        }

    except Exception as e:
        return {
            "error": str(e),
            # Copy so callers cannot mutate the module-level constant.
            "weekly_crops": list(PERISHABLE_CROPS),
            "monthly_crops_count": 0,
            "total_crops": 0,
            "total_districts": 0,
            "crops": [],
            "states": [],
            "districts_count": 0,
        }
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Manual smoke test: calls each public entry point once and prints the
    # raw results so a human can eyeball them.

    def _show(result: Dict[str, Any]) -> None:
        """Print every field of a result dict on its own line."""
        for field, value in result.items():
            print(f" {field}: {value}")

    print("=== Enterprise Engine v2 — Smoke Test ===\n")

    print("[1] get_signal (perishable — Tomato)")
    _show(get_signal("Tomato", "Maharashtra", "Pune", current_price=1200))

    print("\n[2] get_signal (non-perishable — Wheat)")
    _show(get_signal("Wheat", "Punjab", "Ludhiana", current_price=2100))

    print("\n[3] get_presow_signal (Wheat)")
    _show(get_presow_signal("Wheat", "Haryana", sowing_month=11))

    print("\n[4] get_batch_signals")
    batch_requests = [
        {"crop": "Onion", "state": "Maharashtra", "current_price": 800},
        {"crop": "Mustard", "state": "Rajasthan", "current_price": 5800},
    ]
    for idx, item in enumerate(get_batch_signals(batch_requests), start=1):
        crop_name = item.get("crop")
        advice = item.get("recommendation")
        conf_level = item.get("confidence")
        print(f" [{idx}] {crop_name} → {advice} ({conf_level})")

    print("\n[5] get_coverage")
    coverage = get_coverage()
    n_crops = len(coverage.get("crops", []))
    n_states = len(coverage.get("states", []))
    n_districts = coverage.get("districts_count")
    n_weekly = len(coverage.get("weekly_crops", []))
    print(f" Crops: {n_crops}, States: {n_states}, "
          f"Districts: {n_districts}, Weekly: {n_weekly}")
|
|