import html

import gradio as gr
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder

from src import random_forest_core
import vlai_template

# Configure theme for Random Forest Demo
vlai_template.configure(
    project_name="Random Forest Demo",
    year="2025",
    module="03",
    description="Interactive demonstration of Random Forest algorithms for classification and regression tasks. Explore ensemble learning with decision trees through dynamic parameter adjustment and comprehensive visualizations.",
    colors={
        "primary": "#2E7D32",  # Forest green
        "accent": "#8BC34A",  # Light green
        "bg1": "#F1F8E9",  # Very light green
        "bg2": "#E8F5E8",  # Light green background
        "bg3": "#C8E6C9",  # Pale green
    },
    font_family="'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif",
)

# ---- Global state (module-level, so shared by every session of this process) ----
# DataFrame most recently loaded by load_and_configure_data (None after a load failure).
current_dataframe = None
# Label used in status messages: the sample dataset name, or "Custom Data" for uploads.
current_dataset_name = None
# Feature columns and problem type used by the last successful prediction, so the
# per-tree visualization always matches the model that was actually trained.
_last_fit_info = None

# Number of feature-input slots in the UI. Each slot owns a Number and a Dropdown
# component (indices 2*i and 2*i + 1), hence twice as many input components.
MAX_FEATURE_INPUTS = 20
N_INPUT_COMPONENTS = 2 * MAX_FEATURE_INPUTS

# Dataset configurations
SAMPLE_DATA_CONFIG = {
    "Iris": {"target_column": "target", "problem_type": "classification"},
    "Wine": {"target_column": "target", "problem_type": "classification"},
    "Breast Cancer": {"target_column": "target", "problem_type": "classification"},
    "Diabetes": {"target_column": "target", "problem_type": "regression"},
    "Titanic": {"target_column": "survived", "problem_type": "classification"},
}

# Reloads the page with ?__theme=light unless a theme was explicitly requested.
force_light_theme_js = """
() => {
  const params = new URLSearchParams(window.location.search);
  if (!params.has('__theme')) {
    params.set('__theme', 'light');
    window.location.search = params.toString();
  }
}
"""


def validate_config(df, target_col):
    """Validate the target column and auto-detect the problem type.

    Args:
        df: The loaded DataFrame.
        target_col: Name of the column to predict.

    Returns:
        ``(is_valid, message, problem_type)``; ``problem_type`` is
        ``"classification"``, ``"regression"``, or ``None`` when invalid.
    """
    if not target_col or target_col not in df.columns:
        return False, "❌ Please select a valid target column from the dropdown.", None

    target_series = df[target_col]
    unique_vals = target_series.nunique()
    if unique_vals == 0:
        # Empty or all-NaN target: nothing to learn from.
        return False, "⚠️ Target column has no values. Please choose another target.", None

    # Non-numeric targets (object, string, category, bool) are always labels; numeric
    # targets with few distinct values relative to the row count are treated as labels too.
    is_label_dtype = (
        pd.api.types.is_bool_dtype(target_series)
        or not pd.api.types.is_numeric_dtype(target_series)
    )
    if is_label_dtype or unique_vals <= min(20, len(target_series) * 0.1):
        problem_type = "classification"
        if unique_vals > 50:
            return False, f"⚠️ Too many classes ({unique_vals}). Consider another target.", None
        if target_series.isnull().any():
            return False, "⚠️ Target column has missing values. Please clean your data.", None
    else:
        problem_type = "regression"
        if unique_vals < 5:
            return False, f"⚠️ Too few unique values ({unique_vals}). Consider another target.", None
        if target_series.isnull().any():
            return False, "⚠️ Target column has missing values. Please clean your data.", None

    kind = "classes" if problem_type == "classification" else "values"
    return True, f"\n✅ Configuration is valid! Ready for {unique_vals} {kind}.", problem_type


def get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg):
    """Build the markdown status line shown under the target dropdown."""
    if is_sample:
        return f"✅ **Selected Dataset**: {dataset_choice} | **Target**: {target_col} | **Type**: {problem_type.title()}"
    if target_col and problem_type:
        status_icon = "✅" if is_valid else "⚠️"
        return f"{status_icon} **Custom Data** | **Target**: {target_col} | **Type**: {problem_type.title()} | {validation_msg}"
    return "📁 **Custom data uploaded!** 👆 Please select target column above to continue."


def _hidden_inputs():
    """Return updates that hide every feature-input component."""
    return [gr.update(visible=False)] * N_INPUT_COMPONENTS


def _build_feature_input_updates(components_info):
    """Translate create_input_components() output into updates for the input slots.

    Slot i shows either its Number (index 2*i) or its Dropdown (index 2*i + 1),
    depending on ``comp["type"]``; features beyond MAX_FEATURE_INPUTS get no slot.
    """
    updates = _hidden_inputs()
    for i, comp in enumerate(components_info):
        if i >= MAX_FEATURE_INPUTS:
            break
        number_idx, dropdown_idx = 2 * i, 2 * i + 1
        if comp["type"] == "number":
            upd = {"visible": True, "label": comp["name"], "value": comp["value"]}
            if comp["minimum"] is not None:
                upd["minimum"] = comp["minimum"]
            if comp["maximum"] is not None:
                upd["maximum"] = comp["maximum"]
            updates[number_idx] = gr.update(**upd)
            updates[dropdown_idx] = gr.update(visible=False)
        else:
            updates[number_idx] = gr.update(visible=False)
            updates[dropdown_idx] = gr.update(
                visible=True, label=comp["name"], choices=comp["choices"], value=comp["value"]
            )
    return updates


def load_and_configure_data(file_obj=None, dataset_choice="Iris"):
    """Load data and prepare target/problem type + feature inputs.

    Returns a list matching the load/upload/sample-change outputs:
    ``[preview, target dropdown, status] + N_INPUT_COMPONENTS input updates
    + [inputs group visibility, input status]``.
    """
    global current_dataframe, current_dataset_name
    try:
        df = random_forest_core.load_data(file_obj, dataset_choice)
        current_dataframe = df
        target_options = df.columns.tolist()
        is_sample = file_obj is None
        current_dataset_name = dataset_choice if is_sample else "Custom Data"

        if is_sample:
            cfg = SAMPLE_DATA_CONFIG.get(dataset_choice, {})
            target_col = cfg.get("target_column")
            problem_type = cfg.get("problem_type")
        else:
            # Uploaded data: the user must pick the target column first.
            target_col, problem_type = None, None

        # Validate & status
        is_valid, validation_msg = False, ""
        if target_col:
            is_valid, validation_msg, detected = validate_config(df, target_col)
            if detected:
                problem_type = detected
        status_msg = get_status_message(
            is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg
        )

        # Build feature input widgets
        input_updates = _hidden_inputs()
        inputs_visible = gr.update(visible=False)
        input_status = "⚙️ Configure target column above to enable feature inputs."
        if target_col and problem_type and (not is_sample or is_valid):
            try:
                components_info = random_forest_core.create_input_components(df, target_col)
                input_updates = _build_feature_input_updates(components_info)
                inputs_visible = gr.update(visible=True)
                input_status = (
                    f"📝 **Ready!** Enter values for {len(components_info)} features below, "
                    f"then click Run prediction. | {validation_msg}"
                )
            except Exception as e:
                input_updates = _hidden_inputs()
                input_status = f"❌ Error generating inputs: {str(e)}"

        return (
            [df.head(5).round(2), gr.Dropdown(choices=target_options, value=target_col), status_msg]
            + input_updates
            + [inputs_visible, input_status]
        )
    except Exception as e:
        current_dataframe = None
        current_dataset_name = None
        empty = [
            pd.DataFrame(),
            gr.Dropdown(choices=[], value=None),
            f"❌ **Error loading data**: {str(e)} | Please try a different file or dataset.",
        ]
        return empty + _hidden_inputs() + [gr.update(visible=False), "No data loaded."]


def update_configuration(df_preview, target_col):
    """Rebuild feature widgets when the target changes.

    Returns N_INPUT_COMPONENTS input updates followed by
    ``[inputs group visibility, input status, dataset status]``.
    ``df_preview`` is unused; the loaded data comes from ``current_dataframe``.
    """
    df = current_dataframe
    hidden = _hidden_inputs() + [gr.update(visible=False)]
    if df is None or df.empty:
        return hidden + ["No data available.", "No data available."]
    if not target_col:
        return hidden + ["Select target column.", "Select target column."]

    try:
        is_valid, validation_msg, problem_type = validate_config(df, target_col)
        if not is_valid:
            # validation_msg already carries its own ❌/⚠️ icon.
            return hidden + [validation_msg, validation_msg]

        components_info = random_forest_core.create_input_components(df, target_col)
        input_updates = _build_feature_input_updates(components_info)
        input_status = f"📝 Enter values for {len(components_info)} features | {validation_msg}"
        # This handler also fires after a sample load sets the dropdown, so report the
        # real dataset name instead of assuming custom data.
        dataset_label = current_dataset_name or "Custom Data"
        status_msg = (
            f"✅ **Selected Dataset**: {dataset_label} | **Target**: {target_col} "
            f"| **Type**: {problem_type.title()}"
        )
        return input_updates + [gr.update(visible=True), input_status, status_msg]
    except Exception as e:
        return hidden + [f"❌ Error: {str(e)}", f"❌ Error: {str(e)}"]


# ---- criterion helpers ----
CLASS_CRITS = {"gini", "entropy", "log_loss"}
REGR_CRITS = {"squared_error", "absolute_error", "friedman_mse", "poisson"}


def update_criterion_choices(problem_type):
    """Return the criterion dropdown for *problem_type* (anything but "classification" -> regression)."""
    if problem_type == "classification":
        return gr.Dropdown(choices=sorted(CLASS_CRITS), value="gini")
    return gr.Dropdown(choices=sorted(REGR_CRITS), value="squared_error")


def update_criterion_on_target_change(df_preview, target_col):
    """Recompute problem type from current df + target and return the right dropdown config.

    Falls back to the classification criteria whenever the problem type cannot be
    determined (no target, no data, invalid target, or an error).
    """
    df = current_dataframe
    if not target_col or df is None or df.empty:
        return update_criterion_choices("classification")
    try:
        _, _, problem_type = validate_config(df, target_col)
    except Exception:
        problem_type = None
    return update_criterion_choices("regression" if problem_type == "regression" else "classification")


def _voting_html(message):
    """Return the voting-results panel HTML with *message* (HTML-escaped) as its body."""
    return f"<div><h4>🗳️ Voting Results</h4><p>{html.escape(str(message))}</p></div>"


def _error_outputs(message, confidence_fig=None, importance_fig=None):
    """Return execute_prediction's 5 outputs for a run that produced no prediction.

    Order matches the run button's outputs: tree-confidence plot, individual-tree
    plot, feature-importance plot, voting-results HTML, tree selector.
    """
    return (
        confidence_fig,
        None,
        importance_fig,
        _voting_html(message),
        gr.Dropdown(choices=["Tree 1"], value="Tree 1"),
    )


def _collect_new_point(components_info, input_values):
    """Map each feature name to the user's input, falling back to the component default.

    Number features read slot 2*i, categorical features read slot 2*i + 1; missing
    slots (features beyond MAX_FEATURE_INPUTS) or empty inputs use ``comp["value"]``.
    """
    point = {}
    for i, comp in enumerate(components_info):
        idx = 2 * i if comp["type"] == "number" else 2 * i + 1
        value = input_values[idx] if idx < len(input_values) else None
        point[comp["name"]] = comp["value"] if value is None else value
    return point


def execute_prediction(df_preview, target_col, n_estimators, max_depth, min_samples_split,
                       min_samples_leaf, criterion, max_features, *input_values):
    """Run the random forest and produce all outputs.

    Always returns exactly 5 values, matching the run button's outputs:
    ``(tree_confidence_fig, first_tree_fig, importance_fig, voting_html, tree_selector)``.
    ``input_values`` are the N_INPUT_COMPONENTS feature-input component values.
    """
    global _last_fit_info
    df = current_dataframe

    if df is None or df.empty:
        return _error_outputs("❌ No data loaded! 📊 Please select a sample dataset or upload a file first.")
    if not target_col:
        return _error_outputs("❌ Configuration incomplete! 🎯 Please select target column above.")

    is_valid, validation_msg, problem_type = validate_config(df, target_col)
    if not is_valid:
        return _error_outputs(f"❌ Configuration issue: {validation_msg}")

    # Normalize criterion defensively: the dropdown may still hold a value from the
    # other problem type if the criterion update has not run yet.
    if problem_type == "classification":
        valid_crits, default_crit = CLASS_CRITS, "gini"
    else:
        valid_crits, default_crit = REGR_CRITS, "squared_error"
    if criterion not in valid_crits:
        criterion = default_crit

    try:
        components_info = random_forest_core.create_input_components(df, target_col)
        new_point_dict = _collect_new_point(components_info, input_values)

        (tree_confidence_fig, importance_fig, _prediction, _pred_details, _summary,
         aggregation_html, error) = random_forest_core.run_random_forest_and_visualize(
            df, target_col, new_point_dict, n_estimators, max_depth, min_samples_split,
            min_samples_leaf, criterion, max_features, problem_type
        )
        if error:
            return _error_outputs(
                f"❌ Prediction failed: {error}",
                confidence_fig=tree_confidence_fig,
                importance_fig=importance_fig,
            )

        # Remember what the model was trained on so tree-selector changes plot it correctly.
        feature_cols = [c for c in df.columns if c != target_col]
        _last_fit_info = {"feature_cols": feature_cols, "problem_type": problem_type}

        first_tree_fig = random_forest_core.get_individual_tree_visualization(
            random_forest_core._get_current_model(), 0, feature_cols, problem_type
        )
        return (
            tree_confidence_fig,
            first_tree_fig,
            importance_fig,
            aggregation_html,
            update_tree_selector_choices(n_estimators),
        )
    except Exception as e:
        return _error_outputs(f"❌ Execution error: {str(e)}")


def update_tree_selector_choices(n_estimators):
    """Update the tree selector choices based on number of trees (capped at 20)."""
    n_trees = min(int(n_estimators), 20)
    choices = [f"Tree {i + 1}" for i in range(n_trees)]
    return gr.Dropdown(choices=choices, value="Tree 1")


def update_tree_visualization(tree_selector):
    """Plot the selected tree ("Tree N") of the current model.

    Uses the feature columns and problem type recorded by the last successful
    prediction. Returns None (an empty plot) when there is no data, no model yet,
    or the selection cannot be parsed/rendered.
    """
    if current_dataframe is None or current_dataframe.empty or _last_fit_info is None:
        return None
    try:
        model = random_forest_core._get_current_model()
        if model is None:
            return None
        tree_index = int(tree_selector.split()[-1]) - 1
        return random_forest_core.get_individual_tree_visualization(
            model, tree_index, _last_fit_info["feature_cols"], _last_fit_info["problem_type"]
        )
    except Exception:
        # Best-effort UI refresh: an unrenderable selection just clears the plot.
        return None


# ==========================
# Gradio UI
# ==========================
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True,
               js=force_light_theme_js) as demo:
    vlai_template.create_header()

    # Info card with description
    gr.HTML(vlai_template.render_info_card(
        icon="🌲",
        title="About this Random Forest Demo",
        description="This interactive demo showcases Random Forest algorithms for both classification and regression tasks. Explore ensemble learning with decision trees through dynamic parameter adjustment, comprehensive visualizations, and real-time predictions.",
    ))

    gr.Markdown("### 🌲 **How to Use**: Select data → Configure target → Set forest parameters → Enter new point → Run prediction!")

    with gr.Row(equal_height=False, variant="panel"):
        with gr.Column(scale=45):
            with gr.Accordion("🌿 Data & Configuration", open=True):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("Start with sample datasets or upload your own CSV/Excel files.")
                        file_upload = gr.File(label="📁 Upload Your Data", file_types=[".csv", ".xlsx", ".xls"])
                    with gr.Column(scale=3):
                        sample_dataset = gr.Dropdown(
                            choices=list(SAMPLE_DATA_CONFIG.keys()), value="Titanic", label="🗂️ Sample Datasets"
                        )
                        with gr.Row():
                            target_column = gr.Dropdown(choices=[], label="🎯 Target Column", interactive=True)
                status_message = gr.Markdown("🔄 Loading sample data...")
                data_preview = gr.DataFrame(
                    label="📋 Data Preview (First 5 Rows)", row_count=5, interactive=False, max_height=250
                )

            with gr.Accordion("🌳 Forest Parameters & Input", open=True):
                gr.Markdown("**🌲 Random Forest Parameters**")
                with gr.Row():
                    n_estimators = gr.Number(
                        label="Number of Trees", value=10, minimum=1, maximum=20, precision=0,
                        info="Number of trees in the forest (limited to 20)",
                    )
                    max_depth = gr.Number(
                        label="Max Depth", value=5, minimum=0, maximum=50, precision=0,
                        info="Set to 0 for unlimited depth",
                    )
                with gr.Row():
                    min_samples_split = gr.Number(
                        label="Min Samples Split", value=2, minimum=2, maximum=100, precision=0,
                        info="Minimum samples required to split an internal node",
                    )
                    min_samples_leaf = gr.Number(
                        label="Min Samples Leaf", value=1, minimum=1, maximum=50, precision=0,
                        info="Minimum samples required to be at a leaf node",
                    )
                with gr.Row():
                    criterion = gr.Dropdown(
                        choices=sorted(CLASS_CRITS), value="gini", label="🎯 Criterion",
                        info="Objective to measure split quality (auto-switched for regression)",
                    )
                    # NOTE(review): scikit-learn >= 1.3 no longer accepts max_features="auto"
                    # for random forests; confirm random_forest_core maps it before fitting.
                    max_features = gr.Dropdown(
                        choices=["sqrt", "log2", "auto"], value="sqrt", label="Max Features",
                        info="Number of features to consider for best split",
                    )

                inputs_group = gr.Group(visible=False)
                with inputs_group:
                    input_status = gr.Markdown("Configure inputs above.")
                    gr.Markdown("**📝 New Data Point** - Enter feature values for prediction:")
                    # Pre-create every slot (hidden); load/target handlers reveal the ones needed.
                    input_components = []
                    for row_start in range(0, MAX_FEATURE_INPUTS, 4):
                        with gr.Row():
                            for idx in range(row_start, min(row_start + 4, MAX_FEATURE_INPUTS)):
                                input_components.append(gr.Number(label=f"Feature {idx + 1}", visible=False))
                                input_components.append(gr.Dropdown(label=f"Feature {idx + 1}", visible=False))

            run_prediction_btn = gr.Button("🚀 Run Prediction", variant="primary", size="lg")

        with gr.Column(scale=55):
            gr.Markdown("### 🌲 **Random Forest Results & Visualization**")

            # First visualization: Tree confidence chart
            tree_confidence_chart = gr.Plot(label="Individual Tree Predictions & Confidence Scores", visible=True)

            # Second visualization: Individual tree details
            with gr.Row():
                tree_selector = gr.Dropdown(
                    choices=["Tree 1"], value="Tree 1", label="🌳 Select Tree to Visualize", interactive=True
                )
            individual_tree_plot = gr.Plot(label="Individual Tree Structure", visible=True)

            # Third visualization: Feature importance
            feature_importance_plot = gr.Plot(label="Feature Importance", visible=True)

            # Classification aggregation display
            aggregation_display = gr.HTML(
                _voting_html("Voting details will appear here for classification tasks."),
                label="🗳️ Voting Process",
            )

            gr.Markdown("""🌲 **Random Forest Tips**:
- **📊 Tree Confidence Chart**: Shows confidence scores and predictions for each individual tree in the forest.
- **🌳 Individual Tree Visualization**: Select any tree from the dropdown to see its detailed structure and decision paths.
- **📈 Feature Importance**: Displays which features matter most across all trees in the forest.
- **🎯 Parameter Tuning**: Try different **number of trees** (limited to 20) and **max depth** (5-15) to see changes.
- **🌿 Diversity Control**: **Max features** controls tree diversity - 'sqrt' is often optimal for balanced performance.
- **🛡️ Overfitting Prevention**: **Min samples split/leaf** parameters help control complexity and reduce overfitting.
- **🔍 Interactive Analysis**: Use the tree selector to explore different trees and understand their decision patterns.
""")

    vlai_template.create_footer()

    # ---- Event bindings ----
    # Shared outputs of load_and_configure_data (order must match its return list).
    data_outputs = [data_preview, target_column, status_message] + input_components + [inputs_group, input_status]

    load_evt = demo.load(
        fn=lambda: load_and_configure_data(None, "Titanic"),
        outputs=data_outputs,
    )
    load_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])

    upload_evt = file_upload.upload(
        fn=lambda file: load_and_configure_data(file, "Titanic"),
        inputs=[file_upload],
        outputs=data_outputs,
    )
    upload_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])

    sample_evt = sample_dataset.change(
        fn=lambda choice: load_and_configure_data(None, choice),
        inputs=[sample_dataset],
        outputs=data_outputs,
    )
    sample_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])

    target_column.change(
        fn=update_configuration,
        inputs=[data_preview, target_column],
        outputs=input_components + [inputs_group, input_status, status_message],
    )
    target_column.change(
        fn=update_criterion_on_target_change,
        inputs=[data_preview, target_column],
        outputs=[criterion],
    )

    run_prediction_btn.click(
        fn=execute_prediction,
        inputs=[data_preview, target_column, n_estimators, max_depth, min_samples_split,
                min_samples_leaf, criterion, max_features] + input_components,
        outputs=[tree_confidence_chart, individual_tree_plot, feature_importance_plot,
                 aggregation_display, tree_selector],
    )

    # Re-plot the chosen tree whenever the selector changes.
    tree_selector.change(
        fn=update_tree_visualization,
        inputs=[tree_selector],
        outputs=[individual_tree_plot],
    )


if __name__ == "__main__":
    demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static"])