| import gradio as gr |
| import pandas as pd |
| import numpy as np |
| from sklearn.preprocessing import LabelEncoder |
| from src import random_forest_core |
| import vlai_template |
|
|
| |
# Branding, colour palette and typography for the shared VLAI page template.
vlai_template.configure(
    project_name="Random Forest Demo",
    module="03",
    year="2025",
    description=(
        "Interactive demonstration of Random Forest algorithms for classification and regression tasks. "
        "Explore ensemble learning with decision trees through dynamic parameter adjustment and comprehensive visualizations."
    ),
    font_family="'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif",
    colors=dict(
        primary="#2E7D32",
        accent="#8BC34A",
        bg1="#F1F8E9",
        bg2="#E8F5E8",
        bg3="#C8E6C9",
    ),
)
|
|
| |
# DataFrame most recently loaded by the app; None until a load succeeds.
# Handlers below read this rather than the 5-row preview component.
current_dataframe = None


# Sample dataset name -> preset target column and problem type.
# Insertion order is kept: it drives the order of the sample-dataset dropdown.
SAMPLE_DATA_CONFIG = {
    name: {"target_column": target, "problem_type": task}
    for name, target, task in (
        ("Iris", "target", "classification"),
        ("Wine", "target", "classification"),
        ("Breast Cancer", "target", "classification"),
        ("Diabetes", "target", "regression"),
        ("Titanic", "survived", "classification"),
    )
}


# Client-side snippet: if the URL has no `__theme` query parameter, add
# `__theme=light` and reload so the page always renders in light mode.
force_light_theme_js = """
() => {
    const params = new URLSearchParams(window.location.search);
    if (!params.has('__theme')) {
        params.set('__theme', 'light');
        window.location.search = params.toString();
    }
}
"""
|
|
def validate_config(df, target_col):
    """Validate the chosen target column and infer the problem type.

    Args:
        df: The loaded pandas DataFrame.
        target_col: Name of the column to predict.

    Returns:
        A tuple ``(is_valid, message, problem_type)``. ``problem_type`` is
        ``"classification"`` or ``"regression"`` when valid, and ``None``
        whenever ``is_valid`` is False.
    """
    if not target_col or target_col not in df.columns:
        return False, "❌ Please select a valid target column from the dropdown.", None

    target_series = df[target_col]
    unique_vals = target_series.nunique()

    # Non-numeric targets (object, pandas "string", categorical) and booleans
    # are class labels no matter how many rows there are. Comparing the dtype
    # to "object" alone missed category/string/bool columns, which then fell
    # into the regression branch and were rejected as "too few unique values".
    is_label_dtype = (
        not pd.api.types.is_numeric_dtype(target_series)
        or pd.api.types.is_bool_dtype(target_series)
    )
    # Numeric targets count as labels only when they have few distinct values.
    few_distinct = unique_vals <= min(20, len(target_series) * 0.1)

    if is_label_dtype or few_distinct:
        problem_type = "classification"
        if unique_vals > 50:
            return False, f"⚠️ Too many classes ({unique_vals}). Consider another target.", None
        if target_series.isnull().any():
            return False, "⚠️ Target column has missing values. Please clean your data.", None
    else:
        problem_type = "regression"
        if unique_vals < 5:
            return False, f"⚠️ Too few unique values ({unique_vals}). Consider another target.", None

    noun = "classes" if problem_type == "classification" else "values"
    return True, f"\n✅ Configuration is valid! Ready for {unique_vals} {noun}.", problem_type
|
|
|
|
def get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg):
    """Build the markdown status line shown under the data controls.

    Args:
        is_sample: True when a built-in sample dataset is loaded.
        dataset_choice: Name of the sample dataset (shown only for samples).
        target_col: Selected target column, or None.
        problem_type: "classification", "regression", or None.
        is_valid: Whether the target configuration passed validation.
        validation_msg: Validation text appended for custom data.

    Returns:
        A markdown string.
    """
    if is_sample:
        # problem_type can be None (e.g. a sample dataset without a preset
        # config); show a placeholder instead of crashing on None.title().
        type_label = problem_type.title() if problem_type else "Unknown"
        return f"✅ **Selected Dataset**: {dataset_choice} | **Target**: {target_col} | **Type**: {type_label}"
    if target_col and problem_type:
        status_icon = "✅" if is_valid else "⚠️"
        return f"{status_icon} **Custom Data** | **Target**: {target_col} | **Type**: {problem_type.title()} | {validation_msg}"
    return "π **Custom data uploaded!** π Please select target column above to continue."
|
|
|
|
def load_and_configure_data(file_obj=None, dataset_choice="Iris"):
    """Load a dataset and compute every UI update that depends on it.

    Args:
        file_obj: Uploaded file, or None to load the sample ``dataset_choice``.
        dataset_choice: Sample dataset name, used when ``file_obj`` is None.

    Returns:
        A flat list: the 5-row preview (rounded to 2 decimals), the target
        dropdown, the status markdown, 40 feature-widget updates (a
        number/dropdown pair for each of up to 20 features), the inputs-group
        visibility update and the input status text. If loading fails,
        ``current_dataframe`` is reset to None and an empty layout is returned.
    """
    global current_dataframe

    try:
        frame = random_forest_core.load_data(file_obj, dataset_choice)
        current_dataframe = frame

        columns = frame.columns.tolist()
        from_sample = file_obj is None

        # Sample datasets carry a preset target; uploads wait for the user to pick one.
        preset = SAMPLE_DATA_CONFIG.get(dataset_choice, {}) if from_sample else {}
        target_col = preset.get("target_column")
        problem_type = preset.get("problem_type")

        is_valid, validation_msg = False, ""
        if target_col:
            is_valid, validation_msg, detected = validate_config(frame, target_col)
            if detected:
                problem_type = detected
        status_msg = get_status_message(from_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg)

        widget_updates = [gr.update(visible=False)] * 40
        group_update = gr.update(visible=False)
        input_status = "βοΈ Configure target column above to enable feature inputs."

        # Uploaded data only gets here once a target exists; samples must also validate.
        if target_col and problem_type and (not from_sample or is_valid):
            try:
                components_info = random_forest_core.create_input_components(frame, target_col)
                # Slot k owns widget 2k (Number) and widget 2k+1 (Dropdown).
                for slot, comp in zip(range(20), components_info):
                    number_slot, dropdown_slot = 2 * slot, 2 * slot + 1
                    if comp["type"] == "number":
                        number_kwargs = {"visible": True, "label": comp["name"], "value": comp["value"]}
                        for bound in ("minimum", "maximum"):
                            if comp[bound] is not None:
                                number_kwargs[bound] = comp[bound]
                        widget_updates[number_slot] = gr.update(**number_kwargs)
                        widget_updates[dropdown_slot] = gr.update(visible=False)
                    else:
                        widget_updates[number_slot] = gr.update(visible=False)
                        widget_updates[dropdown_slot] = gr.update(
                            visible=True, label=comp["name"], choices=comp["choices"], value=comp["value"]
                        )
                group_update = gr.update(visible=True)
                input_status = f"π **Ready!** Enter values for {len(components_info)} features below, then click Run prediction. | {validation_msg}"
            except Exception as gen_err:
                input_status = f"β Error generating inputs: {str(gen_err)}"

        preview = frame.head(5).round(2)
        target_dropdown = gr.Dropdown(choices=columns, value=target_col)
        return [preview, target_dropdown, status_msg] + widget_updates + [group_update, input_status]

    except Exception as load_err:
        current_dataframe = None
        failed = [
            pd.DataFrame(),
            gr.Dropdown(choices=[], value=None),
            f"β **Error loading data**: {str(load_err)} | Please try a different file or dataset.",
        ]
        return failed + [gr.update(visible=False)] * 40 + [gr.update(visible=False), "No data loaded."]
|
|
|
|
def update_configuration(df_preview, target_col):
    """Rebuild the feature-input widgets after the target column changes.

    Args:
        df_preview: Value of the data preview component. Unused; the full
            DataFrame is read from the module-level ``current_dataframe``.
        target_col: Newly selected target column name.

    Returns:
        A flat list: 40 feature-widget updates (a number/dropdown pair for each
        of up to 20 features), the inputs-group visibility update, the input
        status text and the dataset status text.
    """
    hidden_inputs = [gr.update(visible=False)] * 40

    def _failure(message):
        # The same message is shown as both the input status and the data status.
        return hidden_inputs + [gr.update(visible=False), message, message]

    df = current_dataframe
    if df is None or df.empty:
        return _failure("No data available.")
    if not target_col:
        return _failure("Select target column.")

    try:
        is_valid, validation_msg, problem_type = validate_config(df, target_col)
        if not is_valid:
            # validate_config messages already start with their own ❌/⚠️ icon,
            # so do not prepend a second one.
            return _failure(validation_msg)

        components_info = random_forest_core.create_input_components(df, target_col)
        input_updates = list(hidden_inputs)
        # Only the first 20 features get widgets; slot i owns indices 2i (Number) and 2i+1 (Dropdown).
        for i, comp in enumerate(components_info[:20]):
            number_idx, dropdown_idx = i * 2, i * 2 + 1
            if comp["type"] == "number":
                upd = {"visible": True, "label": comp["name"], "value": comp["value"]}
                if comp["minimum"] is not None:
                    upd["minimum"] = comp["minimum"]
                if comp["maximum"] is not None:
                    upd["maximum"] = comp["maximum"]
                input_updates[number_idx] = gr.update(**upd)
                input_updates[dropdown_idx] = gr.update(visible=False)
            else:
                input_updates[number_idx] = gr.update(visible=False)
                input_updates[dropdown_idx] = gr.update(
                    visible=True, label=comp["name"], choices=comp["choices"], value=comp["value"]
                )

        input_status = f"π Enter values for {len(components_info)} features | {validation_msg}"
        # This handler also runs when a sample-dataset load sets the target
        # dropdown, and it cannot tell which dataset is loaded, so the status
        # must not label the data as "Custom Data".
        status_msg = f"✅ **Target**: {target_col} | **Type**: {problem_type.title()}"
        return input_updates + [gr.update(visible=True), input_status, status_msg]

    except Exception as e:
        return _failure(f"❌ Error: {str(e)}")
|
|
|
|
| |
# Allowed `criterion` values, keyed by problem type: classification vs. regression.
CLASS_CRITS = set(("gini", "entropy", "log_loss"))
REGR_CRITS = set(("squared_error", "absolute_error", "friedman_mse", "poisson"))
|
|
def update_criterion_choices(problem_type):
    """Return a criterion dropdown for ``problem_type``.

    "classification" yields the classification criteria with "gini" selected;
    any other value yields the regression criteria with "squared_error".
    """
    is_classification = problem_type == "classification"
    pool = CLASS_CRITS if is_classification else REGR_CRITS
    default = "gini" if is_classification else "squared_error"
    return gr.Dropdown(choices=sorted(pool), value=default)
|
|
|
|
def update_criterion_on_target_change(df_preview, target_col):
    """Return the criterion dropdown matching the problem type of ``target_col``.

    ``df_preview`` is unused; the full DataFrame comes from ``current_dataframe``.
    The classification criteria are returned when no target is selected, no
    data is loaded, or detection raises. Any detected type other than
    "classification" (including None for an invalid target) gives the
    regression criteria.
    """
    def _classification():
        return gr.Dropdown(choices=sorted(CLASS_CRITS), value="gini")

    if not target_col:
        return _classification()
    df = current_dataframe
    if df is None or df.empty:
        return _classification()

    try:
        _, _, detected = validate_config(df, target_col)
        if detected != "classification":
            return gr.Dropdown(choices=sorted(REGR_CRITS), value="squared_error")
        return _classification()
    except Exception:
        return _classification()
|
|
|
|
def _prediction_error_outputs(summary, detail="", confidence_fig=None, importance_fig=None):
    """Build the five click-handler outputs for a failed prediction.

    Args:
        summary: Short plain-text description shown in the voting card.
        detail: Optional extra text (validation or exception message). It is
            HTML-escaped because gr.HTML renders raw markup.
        confidence_fig: Optional tree-confidence figure to keep on screen.
        importance_fig: Optional feature-importance figure to keep on screen.

    Returns:
        (confidence plot, individual tree plot, importance plot, voting HTML,
        tree selector dropdown) -- the same five values as the success path.
    """
    from html import escape

    body = f"{summary}<br>{escape(str(detail))}" if detail else summary
    card = (
        "<div style='background:#FFF4F4;border-left:6px solid #C4314B;padding:14px 16px;border-radius:10px;'>"
        f"<strong>π³οΈ Voting Results</strong><br><br>{body}</div>"
    )
    return (confidence_fig, None, importance_fig, card, gr.Dropdown(choices=["Tree 1"], value="Tree 1"))


def execute_prediction(df_preview, target_col, n_estimators, max_depth, min_samples_split, min_samples_leaf, criterion, max_features, *input_values):
    """Train the random forest, predict the entered point and build all outputs.

    Args:
        df_preview: Preview component value (unused; ``current_dataframe`` is used).
        target_col: Target column name.
        n_estimators, max_depth, min_samples_split, min_samples_leaf,
        criterion, max_features: Forest hyper-parameters from the UI.
        *input_values: The 40 feature widgets, interleaved as
            (number, dropdown) per feature slot.

    Returns:
        Always exactly five values, matching the click handler's outputs:
        (tree confidence plot, first tree plot, feature importance plot,
        voting/aggregation HTML, tree selector dropdown update).
    """
    df = current_dataframe

    if df is None or df.empty:
        return _prediction_error_outputs(
            "No data available.", "Please select a sample dataset or upload a file first."
        )
    if not target_col:
        return _prediction_error_outputs("Configuration incomplete.", "Please select target column above.")

    is_valid, validation_msg, problem_type = validate_config(df, target_col)
    if not is_valid:
        return _prediction_error_outputs("Configuration issue.", validation_msg)

    # The criterion dropdown may still hold a value from the other problem type.
    if problem_type == "classification":
        if criterion not in CLASS_CRITS:
            criterion = "gini"
    elif criterion not in REGR_CRITS:
        criterion = "squared_error"

    try:
        components_info = random_forest_core.create_input_components(df, target_col)
        new_point_dict = {}
        for i, comp in enumerate(components_info):
            # Feature i uses widget 2i (Number) or 2i+1 (Dropdown); missing or
            # empty widget values fall back to the component's default.
            slot = 2 * i if comp["type"] == "number" else 2 * i + 1
            value = input_values[slot] if slot < len(input_values) else None
            new_point_dict[comp["name"]] = comp["value"] if value is None else value

        (
            tree_confidence_fig,
            importance_fig,
            _prediction,
            _pred_details,
            _summary,
            aggregation_html,
            error,
        ) = random_forest_core.run_random_forest_and_visualize(
            df, target_col, new_point_dict, n_estimators, max_depth, min_samples_split,
            min_samples_leaf, criterion, max_features, problem_type,
        )

        if error:
            return _prediction_error_outputs(
                "Error occurred during prediction.", error,
                tree_confidence_fig or None, importance_fig or None,
            )

        feature_cols = [c for c in df.columns if c != target_col]
        first_tree_fig = random_forest_core.get_individual_tree_visualization(
            random_forest_core._get_current_model(), 0, feature_cols, problem_type
        )
        updated_tree_selector = update_tree_selector_choices(n_estimators)
        return (tree_confidence_fig, first_tree_fig, importance_fig, aggregation_html, updated_tree_selector)

    except Exception as e:
        return _prediction_error_outputs("Execution error occurred.", str(e))
|
|
|
|
def update_tree_selector_choices(n_estimators):
    """Return a tree-selector dropdown with one "Tree N" entry per tree.

    The list is capped at 20 entries, and "Tree 1" is always selected.
    """
    tree_count = int(n_estimators)
    if tree_count > 20:
        tree_count = 20
    labels = ["Tree %d" % number for number in range(1, tree_count + 1)]
    return gr.Dropdown(choices=labels, value="Tree 1")
|
|
|
|
def update_tree_visualization(tree_selector, target_col=None):
    """Render the tree chosen in the tree selector dropdown.

    Args:
        tree_selector: Dropdown value of the form "Tree N" (1-based).
        target_col: Optional target column name. When omitted or unknown, the
            target is inferred from the fitted model (see
            ``_infer_target_column``) rather than assumed to be the last column.

    Returns:
        The figure from ``random_forest_core.get_individual_tree_visualization``,
        or None when there is no data or model, or when anything fails.
    """
    df = current_dataframe
    if df is None or df.empty:
        return None

    try:
        model = random_forest_core._get_current_model()
        if model is None:
            return None

        tree_index = int(tree_selector.split()[-1]) - 1

        # The real target is not necessarily the last column (e.g. Titanic's
        # "survived"), so do not assume columns[-1].
        if not target_col or target_col not in df.columns:
            target_col = _infer_target_column(model, df)

        _, _, problem_type = validate_config(df, target_col)
        feature_cols = [c for c in df.columns if c != target_col]
        return random_forest_core.get_individual_tree_visualization(model, tree_index, feature_cols, problem_type)
    except Exception:
        # Best-effort refresh: an empty plot is preferable to an error popup.
        return None


def _infer_target_column(model, df):
    """Guess which column of ``df`` the model was trained to predict.

    If the model exposes ``feature_names_in_`` (presumably a scikit-learn
    estimator fitted on a DataFrame -- confirm against random_forest_core) and
    exactly one column of ``df`` is absent from it, that column is returned.
    Otherwise the last column of ``df`` is returned.
    """
    trained_on = getattr(model, "feature_names_in_", None)
    if trained_on is not None:
        known = {str(name) for name in trained_on}
        leftover = [c for c in df.columns if str(c) not in known]
        if len(leftover) == 1:
            return leftover[0]
    return df.columns[-1]
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# UI layout and event wiring.
# NOTE(review): the source lost its indentation; the nesting below (notably
# where run_prediction_btn, the tips Markdown and the footer live) was
# reconstructed -- confirm against the intended layout.
# ---------------------------------------------------------------------------
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
    vlai_template.create_header()

    gr.HTML(vlai_template.render_info_card(
        icon="π²",
        title="About this Random Forest Demo",
        description="This interactive demo showcases Random Forest algorithms for both classification and regression tasks. Explore ensemble learning with decision trees through dynamic parameter adjustment, comprehensive visualizations, and real-time predictions."
    ))

    gr.Markdown("### π² **How to Use**: Select data β Configure target β Set forest parameters β Enter new point β Run prediction!")

    with gr.Row(equal_height=False, variant="panel"):
        # Left column: data selection, forest parameters and the new-point inputs.
        with gr.Column(scale=45):
            with gr.Accordion("πΏ Data & Configuration", open=True):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("Start with sample datasets or upload your own CSV/Excel files.")
                        file_upload = gr.File(label="π Upload Your Data", file_types=[".csv", ".xlsx", ".xls"])
                    with gr.Column(scale=3):
                        sample_dataset = gr.Dropdown(choices=list(SAMPLE_DATA_CONFIG.keys()), value="Titanic", label="ποΈ Sample Datasets")

                with gr.Row():
                    # Choices are filled in by the load/upload/sample handlers.
                    target_column = gr.Dropdown(choices=[], label="π― Target Column", interactive=True)

                status_message = gr.Markdown("π Loading sample data...")
                data_preview = gr.DataFrame(label="π Data Preview (First 5 Rows)", row_count=5, interactive=False, max_height=250)

            with gr.Accordion("π³ Forest Parameters & Input", open=True):
                gr.Markdown("**π² Random Forest Parameters**")
                with gr.Row():
                    n_estimators = gr.Number(
                        label="Number of Trees",
                        value=10, minimum=1, maximum=20, precision=0,
                        info="Number of trees in the forest (limited to 20)"
                    )
                    max_depth = gr.Number(
                        label="Max Depth",
                        value=5, minimum=0, maximum=50, precision=0,
                        info="Set to 0 for unlimited depth"
                    )
                with gr.Row():
                    min_samples_split = gr.Number(
                        label="Min Samples Split",
                        value=2, minimum=2, maximum=100, precision=0,
                        info="Minimum samples required to split an internal node"
                    )
                    min_samples_leaf = gr.Number(
                        label="Min Samples Leaf",
                        value=1, minimum=1, maximum=50, precision=0,
                        info="Minimum samples required to be at a leaf node"
                    )
                with gr.Row():
                    # Starts with classification criteria; the target-change
                    # handlers swap in regression criteria when needed.
                    criterion = gr.Dropdown(
                        choices=sorted(CLASS_CRITS), value="gini", label="π― Criterion",
                        info="Objective to measure split quality (auto-switched for regression)"
                    )
                    # NOTE(review): scikit-learn removed max_features="auto" for
                    # random forests in 1.3; confirm random_forest_core maps it
                    # (e.g. to None / 1.0) before fitting.
                    max_features = gr.Dropdown(
                        choices=["sqrt", "log2", "auto"], value="sqrt", label="Max Features",
                        info="Number of features to consider for best split"
                    )

                # Hidden until a valid target is configured.
                inputs_group = gr.Group(visible=False)
                with inputs_group:
                    input_status = gr.Markdown("Configure inputs above.")
                    gr.Markdown("**π New Data Point** - Enter feature values for prediction:")
                    # 20 feature slots laid out as 5 rows x 4 columns. Each slot
                    # gets a hidden Number and a hidden Dropdown, appended as
                    # [number, dropdown], so slot i lives at indices 2i and 2i+1.
                    input_components = []
                    for row in range(5):
                        with gr.Row():
                            for col in range(4):
                                idx = row * 4 + col
                                if idx < 20:
                                    number_comp = gr.Number(label=f"Feature {idx+1}", visible=False)
                                    dropdown_comp = gr.Dropdown(label=f"Feature {idx+1}", visible=False)
                                    input_components.extend([number_comp, dropdown_comp])

                run_prediction_btn = gr.Button("π Run Prediction", variant="primary", size="lg")

        # Right column: prediction results and visualizations.
        with gr.Column(scale=55):
            gr.Markdown("### π² **Random Forest Results & Visualization**")

            tree_confidence_chart = gr.Plot(label="Individual Tree Predictions & Confidence Scores", visible=True)

            with gr.Row():
                # Choices are replaced after each prediction run.
                tree_selector = gr.Dropdown(
                    choices=["Tree 1"],
                    value="Tree 1",
                    label="π³ Select Tree to Visualize",
                    interactive=True
                )

            individual_tree_plot = gr.Plot(label="Individual Tree Structure", visible=True)

            feature_importance_plot = gr.Plot(label="Feature Importance", visible=True)

            aggregation_display = gr.HTML("**π³οΈ Voting Results**<br><br>Voting details will appear here for classification tasks.", label="π³οΈ Voting Process")

            gr.Markdown("""π² **Random Forest Tips**:
- **π Tree Confidence Chart**: Shows confidence scores and predictions for each individual tree in the forest.
- **π³ Individual Tree Visualization**: Select any tree from the dropdown to see its detailed structure and decision paths.
- **π Feature Importance**: Displays which features matter most across all trees in the forest.
- **π― Parameter Tuning**: Try different **number of trees** (limited to 20) and **max depth** (5-15) to see changes.
- **πΏ Diversity Control**: **Max features** controls tree diversity - 'sqrt' is often optimal for balanced performance.
- **π‘οΈ Overfitting Prevention**: **Min samples split/leaf** parameters help control complexity and reduce overfitting.
- **π Interactive Analysis**: Use the tree selector to explore different trees and understand their decision patterns.
""")

    vlai_template.create_footer()

    # Initial page load: load the Titanic sample, then sync the criterion dropdown.
    load_evt = demo.load(
        fn=lambda: load_and_configure_data(None, "Titanic"),
        outputs=[data_preview, target_column, status_message] + input_components + [inputs_group, input_status],
    )
    load_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])

    # File upload. The "Titanic" argument is the unused sample name, since a file is given.
    upload_evt = file_upload.upload(
        fn=lambda file: load_and_configure_data(file, "Titanic"),
        inputs=[file_upload],
        outputs=[data_preview, target_column, status_message] + input_components + [inputs_group, input_status],
    )
    upload_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])

    # Switching sample datasets reloads everything.
    sample_evt = sample_dataset.change(
        fn=lambda choice: load_and_configure_data(None, choice),
        inputs=[sample_dataset],
        outputs=[data_preview, target_column, status_message] + input_components + [inputs_group, input_status],
    )
    sample_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])

    # Target change: rebuild the feature widgets and the criterion choices.
    target_column.change(
        fn=update_configuration, inputs=[data_preview, target_column],
        outputs=input_components + [inputs_group, input_status, status_message],
    )
    target_column.change(
        fn=update_criterion_on_target_change, inputs=[data_preview, target_column],
        outputs=[criterion],
    )

    # Prediction: five outputs, in this order.
    run_prediction_btn.click(
        fn=execute_prediction,
        inputs=[data_preview, target_column, n_estimators, max_depth, min_samples_split, min_samples_leaf, criterion, max_features] + input_components,
        outputs=[tree_confidence_chart, individual_tree_plot, feature_importance_plot, aggregation_display, tree_selector],
    )

    # Selecting another tree re-renders the individual tree plot.
    tree_selector.change(
        fn=update_tree_visualization,
        inputs=[tree_selector],
        outputs=[individual_tree_plot],
    )


if __name__ == "__main__":
    # Allow Gradio to serve the logo files under static/.
    demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static"])
|
|