""" Enterprise Engine v2 — Mandi Advisor Production routing engine: weekly perishable model → district v2 monthly model → v4 state fallback. """ from __future__ import annotations import pickle import warnings from datetime import date, datetime from pathlib import Path from typing import Any, Dict, List, Optional, Union import duckdb import numpy as np import pandas as pd warnings.filterwarnings("ignore") # --------------------------------------------------------------------------- # Paths & Constants # --------------------------------------------------------------------------- BASE = Path(__file__).parent.resolve() # relative — works on any machine FEATURE_STORE = BASE / "feature_store.parquet" MANDI_DATA = BASE / "mandi_data_clean.parquet" PERISHABLE_CROPS = [ "Tomato", "Onion", "Potato", "Green Chilli", "Brinjal", "Cabbage", "Cauliflower", "Bhindi (Okra)", "Banana", "Cucumber", "Bitter Gourd", "Bottle Gourd", "Pumpkin", ] MSP_TABLE: Dict[str, Dict[int, int]] = { "Wheat": {2024: 2275, 2025: 2425, 2026: 2600}, "Paddy": {2024: 2300, 2025: 2369, 2026: 2500}, "Mustard": {2024: 5650, 2025: 5950, 2026: 6200}, "Gram": {2024: 5440, 2025: 5650, 2026: 5800}, "Bengal Gram": {2024: 5440, 2025: 5650, 2026: 5800}, "Arhar/Tur": {2024: 7000, 2025: 7550, 2026: 8000}, "Moong": {2024: 8682, 2025: 8908, 2026: 9500}, } MODEL_FILES = { "weekly_perishable": { "model": BASE / "weekly_perishable_model.pkl", "meta": BASE / "weekly_perishable_meta.pkl", "winrates": BASE / "weekly_perishable_winrates.pkl", }, "district_v2": { "model": BASE / "district_sellhold_v2_model.pkl", "meta": BASE / "district_sellhold_v2_meta.pkl", "winrates": BASE / "district_sellhold_v2_winrates.pkl", }, "presow_v3": { "model": BASE / "presow_v3_model.pkl", "meta": BASE / "presow_v3_meta.pkl", }, "presow_v4": { "model": BASE / "presow_v4_model.pkl", "meta": BASE / "presow_v4_meta.pkl", }, "sellhold_v4": { "model": BASE / "sellhold_v4_model.pkl", "meta": BASE / "sellhold_v4_meta.pkl", }, } LOW_CONF_THRESHOLD = 0.55 
VOLATILITY_THRESHOLD = 0.35  # 35% price move across the recent-price window

# ---------------------------------------------------------------------------
# Lazy model cache
# ---------------------------------------------------------------------------
_CACHE: Dict[str, Any] = {}


def _load_model(name: str) -> Any:
    """Lazy-load and cache the pickle artefacts registered under *name*.

    Parameters
    ----------
    name : key in ``MODEL_FILES``.

    Returns
    -------
    dict mapping each artefact part (e.g. ``"model"``, ``"meta"``) to the
    unpickled object. The same dict is returned on later calls.

    Raises
    ------
    KeyError if *name* is not registered; any error from ``open`` or
    ``pickle.load`` propagates and nothing is cached in that case.
    """
    if name in _CACHE:
        return _CACHE[name]

    entry = MODEL_FILES.get(name)
    if entry is None:
        raise KeyError(f"Unknown model name: {name!r}")

    loaded: Dict[str, Any] = {}
    for part, path in entry.items():
        # SECURITY: pickle.load executes code embedded in the file. The paths
        # come from MODEL_FILES (next to this module); never point them at
        # files from an untrusted source.
        with open(path, "rb") as fh:
            loaded[part] = pickle.load(fh)

    # Cache only after every part loaded, so a partial failure is retried.
    _CACHE[name] = loaded
    return loaded


# ---------------------------------------------------------------------------
# DuckDB helper
# ---------------------------------------------------------------------------
def _sanitize(val: Optional[str]) -> str:
    """Strip quote, semicolon and ``--`` sequences from a lookup name.

    Queries in this module bind these values as parameters; the stripping is
    an extra safety net in case a call path ever receives user-supplied
    input. ``None`` becomes the empty string.
    """
    if val is None:
        return ""
    # Remove single quotes, double quotes, semicolons, comment markers
    return val.replace("'", "").replace('"', "").replace(";", "").replace("--", "").strip()


def _sql_path(path: Any) -> str:
    """Return *path* as text that is safe inside a single-quoted SQL literal."""
    # A doubled single quote is the SQL escape for a quote inside a literal.
    return str(path).replace("'", "''")


def _duck(sql: str, params: Optional[list] = None) -> pd.DataFrame:
    """Run *sql* on a fresh in-memory DuckDB connection and return a DataFrame.

    The connection is always closed, including when the query raises.
    """
    con = duckdb.connect(database=":memory:")
    try:
        con.execute("SET memory_limit='1GB';")
        if params:
            return con.execute(sql, params).df()
        return con.execute(sql).df()
    finally:
        con.close()


# ---------------------------------------------------------------------------
# Name resolution (fuzzy match against LabelEncoder classes)
# ---------------------------------------------------------------------------
def _resolve_name(name: Optional[str], le) -> Optional[str]:
    """Resolve *name* against ``le.classes_``.

    Priority: exact → case-insensitive → substring (either direction).
    Returns the matched class string, or ``None`` if there is no match or
    *name* is ``None``/blank.
    """
    if name is None:
        return None

    classes: List[str] = list(le.classes_)

    # 1. Exact match
    if name in classes:
        return name

    needle = name.strip().lower()
    if not needle:
        # An empty string is a substring of every class; without this guard
        # a blank name would silently resolve to the first class.
        return None

    # 2. Case-insensitive
    for candidate in classes:
        if candidate.lower() == needle:
            return candidate

    # 3. Substring (name in class or class in name)
    for candidate in classes:
        lowered = candidate.lower()
        if needle in lowered or lowered in needle:
            return candidate

    return None


# ---------------------------------------------------------------------------
# Feature retrieval helpers
# ---------------------------------------------------------------------------
def _get_feature_row(
    crop: str,
    state: str,
    district: Optional[str],
    month: int,
    year: int,
) -> Optional[pd.Series]:
    """Pull the most relevant feature-store row for the given key.

    First tries the newest row for the requested month; if none exists,
    falls back to the newest row of any month. The district filter is applied
    only when a non-empty district is given. *year* is accepted for interface
    symmetry but is not used in the query.

    Returns ``None`` when nothing matches or when the query fails
    (best-effort: callers substitute default feature values).
    """
    try:
        crop_key = _sanitize(crop).lower()
        state_key = _sanitize(state).lower()
        district_key = _sanitize(district).lower() if district else None
        month_key = int(month)

        filters = ["lower(commodity) = $1", "lower(state) = $2"]
        params: List[Any] = [crop_key, state_key]
        if district_key:
            filters.append("lower(district) = $3")
            params.append(district_key)

        source = f"read_parquet('{_sql_path(FEATURE_STORE)}')"
        where = " AND ".join(filters)

        sql = (
            f"SELECT * FROM {source} WHERE {where} "
            f"AND month = ${len(params) + 1} ORDER BY year DESC LIMIT 1"
        )
        df = _duck(sql, params + [month_key])

        if df.empty:
            # fallback: any month, newest first
            sql = (
                f"SELECT * FROM {source} WHERE {where} "
                "ORDER BY year DESC, month DESC LIMIT 1"
            )
            df = _duck(sql, params)

        return df.iloc[0] if not df.empty else None
    except Exception:
        return None


def _feature_float(row: Optional[pd.Series], col: str, default: Optional[float] = 0.0) -> Optional[float]:
    """Read ``row[col]`` as a float, else return *default*.

    *default* is returned (as float, or ``None`` when it is ``None``) if the
    row is missing, the column is absent, the value is NA, or the value
    cannot be converted.
    """
    if row is not None and col in row.index and pd.notna(row[col]):
        try:
            return float(row[col])
        except (ValueError, TypeError):
            pass
    # float(None) would raise, so a None default is passed through untouched.
    return None if default is None else float(default)


def _feature_str(row: Optional[pd.Series], col: str, default: str = "state") -> str:
    """Read ``row[col]`` as a string, else return *default*."""
    if row is not None and col in row.index and pd.notna(row[col]):
        return str(row[col])
    return default


def _seasonal_average(row: Optional[pd.Series], columns=("st_avg", "nat_avg")) -> Optional[float]:
    """Return the first non-zero value among *columns* of *row*, else ``None``."""
    for col in columns:
        value = _feature_float(row, col, None)
        if value:
            return value
    return None


def _price_vs_average_pct(current_price: Optional[float], seasonal_avg: Optional[float]) -> Optional[float]:
    """Percent difference of *current_price* from *seasonal_avg* (1 decimal).

    Returns ``None`` when either value is missing or zero.
    """
    if current_price and seasonal_avg:
        return round((current_price - seasonal_avg) / seasonal_avg * 100, 1)
    return None


def _encode(le, name: Optional[str]) -> int:
    """Return the integer code of *name* under encoder *le* (0 if no encoder)."""
    return int(le.transform([name])[0]) if le else 0


def _predict_with_confidence(model, X, default_confidence: float):
    """Return ``(predicted_class, confidence)`` for a single-row input.

    Confidence is the predicted class's probability, located through
    ``model.classes_`` so that non-contiguous class labels still map to the
    right probability column. Models without ``predict_proba`` get
    *default_confidence*.
    """
    pred_class = int(model.predict(X)[0])
    if not hasattr(model, "predict_proba"):
        return pred_class, default_confidence

    proba = model.predict_proba(X)[0]
    classes = list(getattr(model, "classes_", range(len(proba))))
    try:
        column = classes.index(pred_class)
    except ValueError:
        # Label not listed in classes_: keep the historical behaviour of
        # treating the label itself as the column index.
        column = pred_class
    return pred_class, float(proba[column])


def _get_recent_mandi_prices(crop: str, state: str, days: int = 7) -> Optional[pd.DataFrame]:
    """Return recent ``arrival_date``/``modal_price`` rows for the volatility check.

    Fetches the ``days * 20`` most recent rows for the crop/state, newest
    first. Returns ``None`` if the query fails.
    """
    try:
        crop_key = _sanitize(crop).lower()
        state_key = _sanitize(state).lower()
        row_limit = int(days) * 20
        sql = f"""
            SELECT arrival_date, modal_price
            FROM read_parquet('{_sql_path(MANDI_DATA)}')
            WHERE lower(commodity) = $1 AND lower(state) = $2
            ORDER BY arrival_date DESC
            LIMIT {row_limit}
        """
        return _duck(sql, [crop_key, state_key])
    except Exception:
        return None


def _check_volatility(crop: str, state: str) -> tuple[bool, str]:
    """Return ``(is_volatile, message)`` for the crop/state.

    Compares the newest and oldest non-null ``modal_price`` among the rows
    returned by ``_get_recent_mandi_prices(..., days=7)`` and flags a relative
    move greater than ``VOLATILITY_THRESHOLD``. Any failure yields
    ``(False, "")``.

    NOTE(review): the window is the 140 most recent rows, not a calendar
    filter, so "last 7 days" in the message holds only if the data has about
    20 rows per day for this crop/state — confirm against the dataset.
    """
    try:
        df = _get_recent_mandi_prices(crop, state, days=7)
        if df is None or df.empty or len(df) < 2:
            return False, ""

        prices = df["modal_price"].dropna().values
        if len(prices) < 2:
            return False, ""

        p_latest = float(prices[0])
        p_old = float(prices[-1])
        if p_old == 0:
            return False, ""

        pct_change = abs(p_latest - p_old) / p_old
        if pct_change > VOLATILITY_THRESHOLD:
            return True, f"⚠️ High volatility — signal suppressed (price moved {pct_change*100:.1f}% in last 7 days)"
        return False, ""
    except Exception:
        return False, ""


# ---------------------------------------------------------------------------
# Confidence helpers
# ---------------------------------------------------------------------------
def _score_to_label(score: float) -> str:
    """Map a confidence score to ``"HIGH"`` / ``"MEDIUM"`` / ``"LOW"``."""
    if score >= 0.70:
        return "HIGH"
    if score >= LOW_CONF_THRESHOLD:
        return "MEDIUM"
    return "LOW"


def _first_win_rate(df: pd.DataFrame, mask) -> Optional[float]:
    """Return the first "win" column of the first row selected by *mask*.

    ``None`` when no row matches or the frame has no column whose name
    contains "win".
    """
    rows = df[mask]
    if rows.empty:
        return None
    win_cols = [c for c in df.columns if "win" in c.lower()]
    if not win_cols:
        return None
    return float(rows.iloc[0][win_cols[0]])


def _lookup_win_rate(
    winrates: Any,
    crop: str,
    state: str,
    district: Optional[str] = None,
    month: Optional[int] = None,
    iso_week: Optional[int] = None,
) -> Optional[float]:
    """Best-effort win-rate lookup from the winrates artefact.

    Supported artefact shapes:

    * ``{"national": df, "state": df, "district": df}`` — tried district →
      state → national, each filtered by ``iso_week`` if the frame has that
      column, else by ``month`` if present;
    * flat dict keyed by ``(crop, state, district)``, ``(crop, state)``,
      ``(crop,)`` or ``("overall",)``;
    * a single DataFrame with commodity/state[/district] columns, filtered by
      its month column when it has one.

    *month* / *iso_week* default to today's values. Returns ``None`` when
    nothing matches or on any error.
    """
    try:
        cur_month = month or date.today().month
        cur_isoweek = iso_week or date.today().isocalendar()[1]
        crop_l = crop.lower()
        state_l = state.lower()
        district_l = district.lower() if district else None

        def _time_mask(df):
            """Return time-based mask: iso_week if available, else month."""
            if "iso_week" in df.columns:
                return df["iso_week"] == cur_isoweek
            if "month" in df.columns:
                return df["month"] == cur_month
            return pd.Series([True] * len(df), index=df.index)

        # ── New structure: {"national": df, "state": df, "district": df} ──
        if isinstance(winrates, dict) and any(
            k in winrates for k in ("national", "state", "district")
        ):
            # 1. Try district-level first
            if district_l and "district" in winrates:
                df = winrates["district"]
                if isinstance(df, pd.DataFrame) and "district" in df.columns:
                    found = _first_win_rate(
                        df,
                        (df["commodity"].str.lower() == crop_l)
                        & (df["state"].str.lower() == state_l)
                        & (df["district"].str.lower() == district_l)
                        & _time_mask(df),
                    )
                    if found is not None:
                        return found

            # 2. State-level
            if "state" in winrates:
                df = winrates["state"]
                if isinstance(df, pd.DataFrame):
                    found = _first_win_rate(
                        df,
                        (df["commodity"].str.lower() == crop_l)
                        & (df["state"].str.lower() == state_l)
                        & _time_mask(df),
                    )
                    if found is not None:
                        return found

            # 3. National fallback
            if "national" in winrates:
                df = winrates["national"]
                if isinstance(df, pd.DataFrame):
                    found = _first_win_rate(
                        df,
                        (df["commodity"].str.lower() == crop_l) & _time_mask(df),
                    )
                    if found is not None:
                        return found

            return None

        # ── Old flat-dict structure: {(crop, state): value} ──
        if isinstance(winrates, dict):
            for key in [(crop, state, district), (crop, state), (crop,), ("overall",)]:
                if key in winrates:
                    value = winrates[key]
                    return float(value) if value is not None else None

        # ── Plain DataFrame ──
        if isinstance(winrates, pd.DataFrame):
            cols = [c.lower() for c in winrates.columns]
            month_col = next((c for c in winrates.columns if c.lower() == "month"), None)
            month_mask = (winrates[month_col] == cur_month) if month_col else True

            def _lowered(col_name):
                # Column names are matched case-insensitively by position.
                return winrates.iloc[:, cols.index(col_name)].str.lower()

            if all(k in cols for k in ["commodity", "state", "district"]) and district:
                found = _first_win_rate(
                    winrates,
                    (_lowered("commodity") == crop_l)
                    & (_lowered("state") == state_l)
                    & (_lowered("district") == district_l)
                    & month_mask,
                )
                if found is not None:
                    return found

            if all(k in cols for k in ["commodity", "state"]):
                found = _first_win_rate(
                    winrates,
                    (_lowered("commodity") == crop_l)
                    & (_lowered("state") == state_l)
                    & month_mask,
                )
                if found is not None:
                    return found

        return None
    except Exception:
        return None


# ---------------------------------------------------------------------------
# Weekly perishable model
# ---------------------------------------------------------------------------
def _predict_weekly(
    crop: str,
    state: str,
    district: Optional[str],
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Build features and predict using the weekly perishable model.

    Raises ``ValueError`` when the crop or state cannot be resolved against
    the model's encoders; artefact-loading errors propagate. The caller
    (``get_signal``) treats any exception as "fall back to v4".
    """
    artefacts = _load_model("weekly_perishable")
    model = artefacts["model"]
    meta = artefacts["meta"]
    winrates = artefacts.get("winrates", {})

    # meta uses le_crop / le_state keys (from build_weekly_perishable_model.py)
    crop_le = meta.get("le_crop") or meta.get("crop_le")
    state_le = meta.get("le_state") or meta.get("state_le")
    ml_features = meta.get("features", meta.get("ML_FEATURES", []))

    # Resolve encoded names
    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state
    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in weekly model encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in weekly model encoder")

    crop_enc = _encode(crop_le, crop_name)
    state_enc = _encode(state_le, state_name)

    iso_week = ref_date.isocalendar()[1]
    month = ref_date.month
    year = ref_date.year

    # Trigonometric encodings
    wk_sin = np.sin(2 * np.pi * iso_week / 52)
    wk_cos = np.cos(2 * np.pi * iso_week / 52)
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)

    # Feature-store row for proxy values
    row = _get_feature_row(crop, state, district, month, year)

    def fs(col, default=0.0):
        return _feature_float(row, col, default)

    # Map feature-store columns to weekly model feature names (proxy mapping)
    feature_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "iso_week": iso_week,
        "month": month,
        "year": year,
        "wk_sin": wk_sin,
        "wk_cos": wk_cos,
        "mth_sin": mth_sin,
        "mth_cos": mth_cos,
        # Momentum proxies from feature store (monthly as proxy for weekly)
        "price_vs_4w": fs("district_vs_state_pct"),
        "price_vs_13w": fs("state_vs_national_pct"),
        "price_vs_52w": fs("dist_lag_1yr", 0),
        "price_vs_seas_w": fs("seasonal_idx", 1.0) - 1.0,
        "price_vs_st_pct": fs("district_vs_state_pct"),
        "price_vs_nat_pct": fs("district_vs_national_pct"),
        "mom_1w": 0.0,
        "mom_2w": 0.0,
        "mom_4w": fs("district_vs_state_pct"),
        "pctile_52w": fs("price_pctile", 50),
        "nat_win_rate_w": 0.5,
        "st_win_rate_w": 0.5,
        "price_spread_pct": 0.0,
        "n_markets": fs("district_records", 1),
    }

    # Features the model expects but this map lacks are filled with 0.0.
    feature_values = [feature_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    # Predict (0.6 is the default confidence when the model has no proba)
    pred_class, confidence_score = _predict_with_confidence(model, X, 0.6)

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")

    # Pass the reference month too, so frames keyed by month (no iso_week
    # column) are filtered by the request date rather than by today's month.
    win_rate = _lookup_win_rate(
        winrates, crop_name, state_name, month=month, iso_week=iso_week
    )

    # The weekly model only uses the state average (no national fallback).
    seasonal_avg = _seasonal_average(row, columns=("st_avg",))
    price_vs_seasonal_pct = _price_vs_average_pct(current_price, seasonal_avg)

    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": round(win_rate, 3) if win_rate is not None else None,
        "seasonal_avg": round(seasonal_avg, 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "weekly",
        "model_version": "weekly_perishable_v1",
        "data_level": _feature_str(row, "data_level", "state"),
        "iso_week": iso_week,
    }


# ---------------------------------------------------------------------------
# District v2 monthly model
# ---------------------------------------------------------------------------
def _predict_district_v2(
    crop: str,
    state: str,
    district: Optional[str],
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Build features and predict using the district sellhold v2 model.

    Raises ``ValueError`` when the crop or state cannot be resolved; an
    unresolved district is encoded as 0 instead of failing.
    """
    artefacts = _load_model("district_v2")
    model = artefacts["model"]
    meta = artefacts["meta"]
    winrates = artefacts.get("winrates", {})

    crop_le = meta.get("crop_le")
    state_le = meta.get("state_le")
    district_le = meta.get("district_le")
    ml_features = meta.get("ML_FEATURES", meta.get("features", []))

    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state
    district_name = (
        _resolve_name(district, district_le) if district_le and district else None
    )

    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in district_v2 encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in district_v2 encoder")

    crop_enc = _encode(crop_le, crop_name)
    state_enc = _encode(state_le, state_name)
    district_enc = _encode(district_le, district_name) if district_name else 0

    month = ref_date.month
    year = ref_date.year
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)

    row = _get_feature_row(crop, state, district, month, year)

    def fs(col, default=0.0):
        return _feature_float(row, col, default)

    feature_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "district_enc": district_enc,
        "month": month,
        "year": year,
        "month_sin": mth_sin,
        "month_cos": mth_cos,
        "price_vs_dist_seas_avg": fs("seasonal_idx", 1.0) - 1.0,
        "price_vs_dist_seas_std": 0.0,
        "price_vs_st_seas_avg": fs("state_vs_national_pct"),
        "state_vs_national_pct": fs("state_vs_national_pct"),
        "district_vs_state_pct": fs("district_vs_state_pct"),
        "district_vs_national_pct": fs("district_vs_national_pct"),
        "price_pctile": fs("price_pctile", 50),
        "mom_1m": 0.0,
        "mom_3m": 0.0,
        "mom_6m": 0.0,
        "seasonal_idx": fs("seasonal_idx", 1.0),
        "nat_win_rate": 0.5,
        "st_win_rate": 0.5,
        "dist_win_rate": 0.5,
        "eff_win_rate": 0.5,
        "area_vs_5yr_avg": 0.0,
        "yield_vs_5yr_avg": 0.0,
        "area_yoy_pct": 0.0,
        "district_records": fs("district_records", 1),
    }

    feature_values = [feature_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    pred_class, confidence_score = _predict_with_confidence(model, X, 0.6)

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")

    win_rate = _lookup_win_rate(
        winrates, crop_name, state_name, district_name, month=ref_date.month
    )

    # State average first, national average as the fallback.
    seasonal_avg = _seasonal_average(row)
    price_vs_seasonal_pct = _price_vs_average_pct(current_price, seasonal_avg)

    data_level = _feature_str(
        row, "data_level", "district" if district_name else "state"
    )

    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": round(win_rate, 3) if win_rate is not None else None,
        "seasonal_avg": round(seasonal_avg, 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "monthly",
        "model_version": "district_sellhold_v2",
        "data_level": data_level,
    }


# ---------------------------------------------------------------------------
# Fallback: v4 state-level model
# ---------------------------------------------------------------------------
def _predict_v4_fallback(
    crop: str,
    state: str,
    current_price: Optional[float],
    ref_date: date,
) -> Dict[str, Any]:
    """Fallback to v4 state-level monthly model.

    Raises ``ValueError`` when the crop or state cannot be resolved. No
    win rate is available at this level (``win_rate`` is always ``None``).
    """
    artefacts = _load_model("sellhold_v4")
    model = artefacts["model"]
    meta = artefacts["meta"]

    crop_le = meta.get("crop_le")
    state_le = meta.get("state_le")
    ml_features = meta.get("ML_FEATURES", meta.get("features", []))

    crop_name = _resolve_name(crop, crop_le) if crop_le else crop
    state_name = _resolve_name(state, state_le) if state_le else state

    if crop_name is None:
        raise ValueError(f"Crop '{crop}' not found in v4 encoder")
    if state_name is None:
        raise ValueError(f"State '{state}' not found in v4 encoder")

    crop_enc = _encode(crop_le, crop_name)
    state_enc = _encode(state_le, state_name)

    month = ref_date.month
    year = ref_date.year
    mth_sin = np.sin(2 * np.pi * month / 12)
    mth_cos = np.cos(2 * np.pi * month / 12)

    row = _get_feature_row(crop, state, None, month, year)

    def fs(col, default=0.0):
        return _feature_float(row, col, default)

    base_map = {
        "crop_enc": crop_enc,
        "state_enc": state_enc,
        "month": month,
        "year": year,
        "month_sin": mth_sin,
        "month_cos": mth_cos,
        "state_vs_national_pct": fs("state_vs_national_pct"),
        "seasonal_idx": fs("seasonal_idx", 1.0),
        "price_pctile": fs("price_pctile", 50),
        "state_lag_1yr": fs("state_lag_1yr"),
        "state_lag_2yr": fs("state_lag_2yr"),
        "nat_lag_1yr": fs("nat_lag_1yr"),
        "nat_lag_2yr": fs("nat_lag_2yr"),
    }

    feature_values = [base_map.get(f, 0.0) for f in ml_features]
    X = np.array(feature_values, dtype=float).reshape(1, -1)

    pred_class, confidence_score = _predict_with_confidence(model, X, 0.55)

    label_map = meta.get("label_map", {0: "SELL", 1: "HOLD", 2: "STRONG_HOLD"})
    recommendation = label_map.get(pred_class, "HOLD")

    seasonal_avg = _seasonal_average(row)
    price_vs_seasonal_pct = _price_vs_average_pct(current_price, seasonal_avg)

    return {
        "recommendation": recommendation,
        "confidence_score": confidence_score,
        "confidence": _score_to_label(confidence_score),
        "win_rate": None,
        "seasonal_avg": round(float(seasonal_avg), 1) if seasonal_avg else None,
        "price_vs_seasonal_pct": price_vs_seasonal_pct,
        "granularity": "monthly",
        "model_version": "sellhold_v4_fallback",
        "data_level": "state",
    }


# ---------------------------------------------------------------------------
# Public API: get_signal
# ---------------------------------------------------------------------------
def _normalise_date(date_input: Optional[Union[str, date, datetime]]) -> date:
    """Convert *date_input* to a ``date`` (today when ``None``).

    Strings are parsed from their first 10 characters as ``YYYY-MM-DD``.
    """
    if date_input is None:
        return date.today()
    if isinstance(date_input, datetime):
        return date_input.date()
    if isinstance(date_input, str):
        return datetime.strptime(date_input[:10], "%Y-%m-%d").date()
    return date_input


def get_signal(
    crop: str,
    state: str,
    district: Optional[str] = None,
    current_price: Optional[float] = None,
    date_input: Optional[Union[str, date, datetime]] = None,
) -> Dict[str, Any]:
    """
    Route a price-signal request to the correct model and return a unified dict.

    Perishable crops (``PERISHABLE_CROPS``) go to the weekly model, all others
    to the district v2 model; if the primary model raises, the v4 state-level
    model is used instead.

    Parameters
    ----------
    crop          : Commodity name (e.g. "Tomato", "Wheat")
    state         : State name
    district      : District name (optional)
    current_price : Latest market price (₹/quintal)
    date_input    : Reference date; defaults to today

    Returns
    -------
    dict with keys:
        crop, state, district, date, granularity, recommendation,
        confidence, win_rate, current_price, seasonal_avg,
        price_vs_seasonal_pct, reasoning, model_version, data_level,
        warning (optional)

    On failure (including failure of the fallback model) the function does
    not raise; it returns ``{"error", "confidence": "LOW", "crop", "state",
    "district"}``.
    """
    try:
        ref_date = _normalise_date(date_input)

        warnings_list: List[str] = []
        used_fallback = False

        # --- Route ---
        is_perishable = any(c.lower() == crop.lower() for c in PERISHABLE_CROPS)
        if is_perishable:
            primary, primary_label = _predict_weekly, "Weekly"
        else:
            primary, primary_label = _predict_district_v2, "District v2"

        try:
            result: Dict[str, Any] = primary(crop, state, district, current_price, ref_date)
        except Exception as exc:
            warnings_list.append(
                f"{primary_label} model failed ({exc}); falling back to v4"
            )
            # A failure here reaches the outer handler, which returns the
            # standard error dict.
            result = _predict_v4_fallback(crop, state, current_price, ref_date)
            used_fallback = True

        # --- Confidence filter ---
        confidence_score = result.get("confidence_score", 0.5)
        if confidence_score < LOW_CONF_THRESHOLD:
            result["confidence"] = "LOW"
            warnings_list.append(
                f"Model confidence {confidence_score:.2f} below threshold {LOW_CONF_THRESHOLD}"
            )

        # --- Volatility / anomaly guardrail ---
        is_volatile, vol_msg = _check_volatility(crop, state)
        if is_volatile:
            result["recommendation"] = "HOLD"
            warnings_list.append(vol_msg)

        # --- Build reasoning ---
        rec = result.get("recommendation", "HOLD")
        conf = result.get("confidence", "MEDIUM")
        p_vs_s = result.get("price_vs_seasonal_pct")
        seas = result.get("seasonal_avg")

        reasoning_parts = [
            f"Prediction: {rec} ({conf} confidence).",
        ]
        if p_vs_s is not None:
            direction = "above" if p_vs_s >= 0 else "below"
            reasoning_parts.append(
                f"Current price is {abs(p_vs_s):.1f}% {direction} seasonal average"
                + (f" (₹{seas}/qtl)" if seas else "")
                + "."
            )
        if used_fallback:
            reasoning_parts.append("State-level fallback model used.")
        reasoning_parts += warnings_list

        # --- Assemble final result (the internal confidence_score is not exposed) ---
        out: Dict[str, Any] = {
            "crop": crop,
            "state": state,
            "district": district,
            "date": ref_date.isoformat(),
            "granularity": result.get("granularity", "monthly"),
            "recommendation": result.get("recommendation", "HOLD"),
            "confidence": result.get("confidence", "LOW"),
            "win_rate": result.get("win_rate"),
            "current_price": current_price,
            "seasonal_avg": result.get("seasonal_avg"),
            "price_vs_seasonal_pct": result.get("price_vs_seasonal_pct"),
            "reasoning": " ".join(reasoning_parts),
            "model_version": result.get("model_version", "unknown"),
            "data_level": result.get("data_level", "state"),
        }
        if warnings_list:
            out["warning"] = " | ".join(warnings_list)

        return out

    except Exception as e:
        return {
            "error": str(e),
            "confidence": "LOW",
            "crop": crop,
            "state": state,
            "district": district,
        }


# ---------------------------------------------------------------------------
# Public API: get_presow_signal
# ---------------------------------------------------------------------------
def get_presow_signal(
    crop: str,
    state: str,
    district: Optional[str] = None,
    sowing_month: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Return harvest-price range forecast for pre-sowing decisions.

    Uses the presow v4 artefacts when they load, otherwise presow v3.

    Parameters
    ----------
    crop         : Commodity name
    state        : State name
    district     : District (optional)
    sowing_month : Month of sowing (1-12); defaults to current month

    Returns
    -------
    dict: crop, state, district, sowing_month, p25, p50, p75, harvest_window,
          msp, profit_probability, confidence, mape_pct, reasoning,
          model_version.
    On failure returns a dict with ``error`` and ``confidence: "LOW"``
    instead of raising.
    """
    try:
        # Try v4 first, fall back to v3
        try:
            artefacts = _load_model("presow_v4")
            version = "presow_v4"
        except Exception:
            artefacts = _load_model("presow_v3")
            version = "presow_v3"

        model = artefacts["model"]
        meta = artefacts["meta"]

        crop_le = meta.get("le_crop") or meta.get("crop_le")
        state_le = meta.get("le_state") or meta.get("state_le")
        ml_features = (
            meta.get("features_v4")
            or meta.get("features_v3")
            or meta.get("ML_FEATURES")
            or meta.get("features", [])
        )

        crop_name = _resolve_name(crop, crop_le) if crop_le else crop
        state_name = _resolve_name(state, state_le) if state_le else state

        if crop_name is None:
            return {"error": f"Crop '{crop}' not found in {version} encoder", "confidence": "LOW"}
        if state_name is None:
            return {"error": f"State '{state}' not found in {version} encoder", "confidence": "LOW"}

        crop_enc = _encode(crop_le, crop_name)
        state_enc = _encode(state_le, state_name)

        ref_month = sowing_month or date.today().month
        ref_year = date.today().year
        mth_sin = np.sin(2 * np.pi * ref_month / 12)
        mth_cos = np.cos(2 * np.pi * ref_month / 12)

        row = _get_feature_row(crop_name, state_name, district, ref_month, ref_year)

        def fs(col, default=0.0):
            return _feature_float(row, col, default)

        # Compute v4 new features using available data as proxies
        sl1 = fs("state_lag_1yr")  # state-level price 1 year ago at this month
        sl2 = fs("state_lag_2yr")
        ap = fs("avg_price")
        # harvest_price_lag1: best proxy = state_lag_1yr (state price 1yr ago)
        hpl1 = sl1 if sl1 > 0 else ap * 0.95
        hpl2 = sl2 if sl2 > 0 else ap * 0.90

        # Confidence from per-crop map; it also selects the price_cv proxy
        # (0.15 when the label is not one of HIGH/MEDIUM/LOW).
        conf_map = meta.get("confidence_map", {})
        confidence_label = conf_map.get(crop_name, "MEDIUM")
        try:
            price_cv = {"HIGH": 0.10, "MEDIUM": 0.18, "LOW": 0.30}.get(confidence_label, 0.15)
        except TypeError:
            # Unhashable label in the artefact: use the generic default.
            price_cv = 0.15

        # harvest_to_sow_ratio: last-year seasonal lift, clamped to [0.3, 5.0]
        h2s = (hpl1 / ap) if ap > 10 else 1.0
        h2s = max(0.3, min(5.0, h2s))

        base_map = {
            # identifiers
            "crop_enc": crop_enc,
            "state_enc": state_enc,
            # time
            "month": ref_month,
            "year": ref_year,
            "month_sin": mth_sin,
            "month_cos": mth_cos,
            # price levels
            "nat_avg": fs("nat_avg"),
            "st_avg": fs("st_avg"),
            "avg_price": ap,
            "state_vs_national_pct": fs("state_vs_national_pct"),
            # price lags (feature store)
            "dist_lag_1yr": fs("dist_lag_1yr"),
            "dist_lag_2yr": fs("dist_lag_2yr"),
            "dist_lag_3yr": fs("dist_lag_3yr"),
            "state_lag_1yr": sl1,
            "state_lag_2yr": sl2,
            "nat_lag_1yr": fs("nat_lag_1yr"),
            "nat_lag_2yr": fs("nat_lag_2yr"),
            "dist_roll3yr": fs("dist_roll3yr"),
            # seasonal
            "seasonal_idx": fs("seasonal_idx", 1.0),
            "price_pctile": fs("price_pctile", 50),
            # v4 new features
            "harvest_price_lag1": hpl1,
            "harvest_price_lag2": hpl2,
            "price_cv": price_cv,
            "harvest_to_sow_ratio": h2s,
            # v3 legacy features (filled neutral when v4)
            "temp_mean": 0.0,
            "rainfall_28d": 0.0,
            "ndvi_mean": 0.0,
            "area_yoy_pct": 0.0,
            "area_2yr_pct": 0.0,
            "area_vs_3yr": 0.0,
            "supply_up": 0.0,
            "supply_down": 0.0,
        }

        feature_values = [base_map.get(f, 0.0) for f in ml_features]
        X = np.array(feature_values, dtype=float).reshape(1, -1)

        # Predict using quantile models (p25, p50, p75) if available
        if isinstance(model, dict) and all(q in model for q in ("p25", "p50", "p75")):
            p50_pred = float(model["p50"].predict(X)[0])
            p25_pred = float(model["p25"].predict(X)[0])
            p75_pred = float(model["p75"].predict(X)[0])
            # Ensure monotonicity
            p25_pred = min(p25_pred, p50_pred)
            p75_pred = max(p75_pred, p50_pred)
        else:
            # Legacy: single model + ±15% spread
            p50_pred = float(model.predict(X)[0])
            p25_pred = p50_pred * 0.85
            p75_pred = p50_pred * 1.15

        p25 = round(p25_pred, 0)
        p50 = round(p50_pred, 0)
        p75 = round(p75_pred, 0)

        # Harvest window from calendar (substring match on crop name)
        harvest_cal = meta.get("harvest_calendar", {})
        crop_lower = crop_name.lower()
        hw_months = None
        for key in harvest_cal:
            if key.lower() in crop_lower or crop_lower in key.lower():
                hw_months = harvest_cal[key]
                break

        if hw_months and hw_months != (0, 0):
            hw_start = hw_months[0]
            try:
                harvest_label = datetime(
                    ref_year + (1 if hw_start < ref_month else 0), hw_start, 1
                ).strftime("%B %Y")
            except Exception:
                harvest_label = "---"
        else:
            # No calendar entry: assume harvest four months after sowing and
            # roll into the next year when that wraps past December.
            harvest_month = ((ref_month + 3) % 12) + 1
            harvest_year = ref_year + (1 if harvest_month < ref_month else 0)
            harvest_label = datetime(harvest_year, harvest_month, 1).strftime("%B %Y")

        # MSP lookup (substring match; falls back to the latest tabulated year)
        msp_entry = None
        msp_year = ref_year
        crop_key = crop_name
        for key in MSP_TABLE:
            if key.lower() in crop_lower or crop_lower in key.lower():
                crop_key = key
                break
        if crop_key in MSP_TABLE:
            msp_by_year = MSP_TABLE[crop_key]
            msp_entry = msp_by_year.get(ref_year)
            if not msp_entry:
                msp_year = max(msp_by_year)
                msp_entry = msp_by_year.get(msp_year)

        # Profit probability
        if msp_entry:
            if p25 > msp_entry:
                profit_probability = "HIGH"
            elif p50 > msp_entry:
                profit_probability = "MEDIUM"
            else:
                profit_probability = "LOW"
        else:
            profit_probability = "MEDIUM"

        # Accuracy metadata for UI
        val_meta = meta.get("validation", {})
        mape_pct = val_meta.get("P50", {}).get("mape", 7.6)

        reasoning_parts = [
            f"Forecast for {crop_name} in {state_name}.",
            f"Expected harvest window: {harvest_label}.",
            f"Price range: ₹{p25:.0f}–₹{p75:.0f}/qtl (median ₹{p50:.0f}).",
        ]
        if msp_entry:
            # Report the year the MSP value actually belongs to.
            reasoning_parts.append(
                f"MSP ({msp_year}): ₹{msp_entry}/qtl — profit probability: {profit_probability}."
            )

        return {
            "crop": crop,
            "state": state,
            "district": district,
            "sowing_month": ref_month,
            "p25": p25,
            "p50": p50,
            "p75": p75,
            "harvest_window": harvest_label,
            "msp": msp_entry,
            "profit_probability": profit_probability,
            "confidence": confidence_label,
            "mape_pct": mape_pct,
            "reasoning": " ".join(reasoning_parts),
            "model_version": version,
        }

    except Exception as e:
        return {
            "error": str(e),
            "confidence": "LOW",
            "crop": crop,
            "state": state,
        }


# ---------------------------------------------------------------------------
# Public API: get_batch_signals
# ---------------------------------------------------------------------------
def get_batch_signals(requests_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Process a list of signal requests in sequence.

    Parameters
    ----------
    requests_list : list of dicts, each with keys:
                    crop, state, district (opt), current_price (opt), date (opt)

    Returns
    -------
    list of get_signal() result dicts (same order as input). A request that
    fails yields an error dict at its position rather than aborting the batch.
    """
    results = []
    for req in requests_list:
        try:
            out = get_signal(
                crop=req.get("crop", ""),
                state=req.get("state", ""),
                district=req.get("district"),
                current_price=req.get("current_price"),
                date_input=req.get("date"),
            )
        except Exception as e:
            out = {
                "error": str(e),
                "confidence": "LOW",
                "crop": req.get("crop"),
                "state": req.get("state"),
            }
        results.append(out)
    return results


# ---------------------------------------------------------------------------
# Public API: get_coverage
# ---------------------------------------------------------------------------
def get_coverage() -> Dict[str, Any]:
    """
    Return metadata about what crops, states, and districts are covered.

    Returns
    -------
    dict: crops, total_crops, states, total_districts, districts_count,
          weekly_crops, monthly_crops_count. If the feature store cannot be
          queried, the same keys are returned with empty/zero values plus an
          ``error`` message.
    """
    try:
        source = f"read_parquet('{_sql_path(FEATURE_STORE)}')"
        crops_df = _duck(f"SELECT DISTINCT commodity FROM {source} ORDER BY commodity")
        states_df = _duck(f"SELECT DISTINCT state FROM {source} ORDER BY state")
        dist_count_df = _duck(f"SELECT COUNT(DISTINCT district) AS cnt FROM {source}")

        all_crops = crops_df["commodity"].tolist() if not crops_df.empty else []
        all_states = states_df["state"].tolist() if not states_df.empty else []
        dist_count = int(dist_count_df["cnt"].iloc[0]) if not dist_count_df.empty else 0

        weekly_crops = [
            c for c in PERISHABLE_CROPS
            if any(c.lower() == a.lower() for a in all_crops)
        ]
        monthly_count = len([
            c for c in all_crops
            if not any(c.lower() == p.lower() for p in PERISHABLE_CROPS)
        ])

        return {
            "crops": all_crops,
            "total_crops": len(all_crops),
            "states": all_states,
            "total_districts": dist_count,
            "districts_count": dist_count,
            "weekly_crops": weekly_crops,
            "monthly_crops_count": monthly_count,
        }
    except Exception as e:
        return {
            "error": str(e),
            "weekly_crops": PERISHABLE_CROPS,
            "monthly_crops_count": 0,
            "total_crops": 0,
            "total_districts": 0,
            "crops": [],
            "states": [],
            "districts_count": 0,
        }


# ---------------------------------------------------------------------------
# CLI smoke test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    print("=== Enterprise Engine v2 — Smoke Test ===\n")

    print("[1] get_signal (perishable — Tomato)")
    r = get_signal("Tomato", "Maharashtra", "Pune", current_price=1200)
    for k, v in r.items():
        print(f"  {k}: {v}")

    print("\n[2] get_signal (non-perishable — Wheat)")
    r2 = get_signal("Wheat", "Punjab", "Ludhiana", current_price=2100)
    for k, v in r2.items():
        print(f"  {k}: {v}")

    print("\n[3] get_presow_signal (Wheat)")
    r3 = get_presow_signal("Wheat", "Haryana", sowing_month=11)
    for k, v in r3.items():
        print(f"  {k}: {v}")

    print("\n[4] get_batch_signals")
    batch = get_batch_signals([
        {"crop": "Onion", "state": "Maharashtra", "current_price": 800},
        {"crop": "Mustard", "state": "Rajasthan", "current_price": 5800},
    ])
    for i, b in enumerate(batch, 1):
        print(f"  [{i}] {b.get('crop')} → {b.get('recommendation')} ({b.get('confidence')})")

    print("\n[5] get_coverage")
    cov = get_coverage()
    print(f"  Crops: {len(cov.get('crops', []))}, States: {len(cov.get('states', []))}, "
          f"Districts: {cov.get('districts_count')}, Weekly: {len(cov.get('weekly_crops', []))}")