"""Gradio demo showcasing four SHAP explanation methods.

Tabs:
  1. Pixel-level explanations of a small MNIST CNN (DeepExplainer).
  2. Image-region explanations of ResNet50 on ImageNet (Partition explainer).
  3. Tabular explanations of a Random Forest on Adult Income (TreeExplainer).
  4. Token-level explanations of a sentiment pipeline (Partition explainer).
"""
import gradio as gr
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import io
import itertools
import json
import urllib.request
from PIL import Image
import warnings

warnings.filterwarnings("ignore")

# Configure TensorFlow to avoid GPU issues. These variables must be set before
# tensorflow is imported below.
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow warnings
# NOTE(review): CUDA_VISIBLE_DEVICES is process-wide, so this presumably also
# hides GPUs from PyTorch (CUDA is not initialised before DEVICE is computed
# below). Confirm that forcing PyTorch onto the CPU is intended.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Force TensorFlow to use CPU only

import tensorflow as tf  # noqa: E402  (must follow the environment setup above)

# Disable GPU for TensorFlow
tf.config.set_visible_devices([], 'GPU')

from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input  # noqa: E402

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)


# ============================================================================
# MNIST Model Definition (for Pixel-level SHAP)
# ============================================================================

class MNISTNet(nn.Module):
    """Small CNN for 1x28x28 MNIST digits.

    ``forward`` returns softmax probabilities over the 10 digit classes (not
    logits); these probabilities are what DeepExplainer attributes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # Plain Dropout: this layer acts on the flattened (N, 128) activations,
        # where Dropout2d (channel dropout) does not apply. Dropout layers have
        # no parameters, so saved state_dicts are unaffected.
        self.dropout2 = nn.Dropout(0.5)
        # 28x28 -> conv3 -> 26x26 -> conv3 -> 24x24 -> 2x2 pool -> 12x12;
        # 64 channels * 12 * 12 = 9216 features.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.softmax(x, dim=1)
        return output


# ============================================================================
# Global Variables and Model Loading
# ============================================================================

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST preprocessing (tensor conversion + normalisation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

MNIST_DATA_DIR = './data'
MNIST_BACKGROUND_SIZE = 100   # first N test images form the SHAP background
MNIST_NUM_EXAMPLES = 10       # the next N test images are selectable in the UI
MNIST_TRAIN_BATCHES = 300     # mini-batches used for the quick demo training run

IMAGENET_CLASS_INDEX_URL = (
    "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
)

# Sample texts for the text-explanation tab (referenced by index from the UI).
SAMPLE_TEXTS = [
    "This movie was absolutely fantastic! I loved every minute of it.",
    "Terrible experience. The worst product I've ever bought.",
    "It's okay, nothing special but not bad either.",
    "Amazing quality and great customer service. Highly recommend!",
    "Disappointing and overpriced. Would not buy again.",
    "The best decision I ever made. Exceeded all expectations!",
    "Boring and uninspired. Complete waste of time and money.",
    "Pretty good overall, with some minor issues here and there.",
    "Absolutely horrible. Save your money and avoid this.",
    "Outstanding performance! This is exactly what I needed.",
]

# Initialize models (will be loaded on first use)
mnist_model = None
mnist_background = None
mnist_test_images = None
mnist_test_targets = None
resnet_model = None
resnet_explainer = None
tabular_model = None
tabular_explainer = None
tabular_data = None
text_model = None
text_tokenizer = None
text_explainer = None


# ============================================================================
# Helper Functions
# ============================================================================

def _clamp_index(value, length):
    """Convert a UI value (Gradio sliders may send floats) to an index in [0, length)."""
    return min(max(int(value), 0), length - 1)


def _figure_to_image(fig=None):
    """Render ``fig`` (default: the current pyplot figure) to a PIL image and close it."""
    if fig is None:
        fig = plt.gcf()
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        # pyplot keeps every open figure alive until closed; always release it.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now rather than lazily
    return img


def _load_mnist_test_batch():
    """Return (and cache) the first 200 MNIST test images and their labels."""
    global mnist_test_images, mnist_test_targets
    if mnist_test_images is None:
        test_dataset = datasets.MNIST(MNIST_DATA_DIR, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=200, shuffle=False)
        mnist_test_images, mnist_test_targets = next(iter(test_loader))
    return mnist_test_images, mnist_test_targets


def _train_mnist_model(model, max_batches=MNIST_TRAIN_BATCHES, batch_size=64, lr=1e-3):
    """Briefly train ``model`` on the MNIST training split (at most ``max_batches`` batches).

    Without training the network has random weights, so its predictions and
    SHAP maps would be meaningless. The model is left in eval mode.
    """
    print("Training MNIST model...")
    train_dataset = datasets.MNIST(MNIST_DATA_DIR, train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for data, target in itertools.islice(train_loader, max_batches):
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        # forward() returns probabilities, so take the log for NLL loss;
        # the clamp guards against log(0).
        loss = F.nll_loss(torch.log(model(data).clamp_min(1e-12)), target)
        loss.backward()
        optimizer.step()
    model.eval()
    print("✓ MNIST model trained")
    return model


def initialize_mnist_model():
    """Initialize (once) the MNIST model and the DeepExplainer background data.

    Returns:
        (model, background): the trained ``MNISTNet`` on ``DEVICE`` in eval
        mode, and the first ``MNIST_BACKGROUND_SIZE`` MNIST test images (CPU).
    """
    global mnist_model, mnist_background
    if mnist_model is None:
        images, _ = _load_mnist_test_batch()
        mnist_background = images[:MNIST_BACKGROUND_SIZE]
        model = MNISTNet().to(DEVICE)
        _train_mnist_model(model)
        # Publish only after training succeeds so a failed run is retried.
        mnist_model = model
    return mnist_model, mnist_background


def _read_imagenet_class_names(json_path):
    """Read the 1000 human-readable class names from an imagenet_class_index.json file."""
    with open(json_path, encoding="utf-8") as fh:
        class_idx = json.load(fh)
    return [class_idx[str(i)][1] for i in range(1000)]


def _load_imagenet_class_names(json_path="imagenet_class_index.json"):
    """Load ImageNet class names from disk, downloading them if needed.

    Falls back to placeholder names (``class_0`` ...) if neither works.
    """
    class_names = None

    # Try to load from file
    if os.path.exists(json_path):
        try:
            class_names = _read_imagenet_class_names(json_path)
            print(f"✓ Loaded {len(class_names)} ImageNet class names")
        except Exception as e:
            print(f"⚠ Error loading class names: {e}")

    # If not found, try to download
    if class_names is None:
        print("Downloading ImageNet class names...")
        try:
            urllib.request.urlretrieve(IMAGENET_CLASS_INDEX_URL, json_path)
            class_names = _read_imagenet_class_names(json_path)
            print(f"✓ Downloaded and loaded {len(class_names)} ImageNet class names")
        except Exception as e:
            print(f"⚠ Could not download class names: {e}")
            print("Using placeholder names...")
            class_names = [f"class_{i}" for i in range(1000)]

    return class_names


def initialize_resnet_model():
    """Initialize (once) ResNet50 and a SHAP Partition explainer with an inpainting masker.

    Returns:
        (resnet_model, resnet_explainer)
    """
    global resnet_model, resnet_explainer
    if resnet_model is None:
        model = ResNet50(weights="imagenet")
        class_names = _load_imagenet_class_names()

        def f(x):
            # Preprocess a float32 copy and use the *returned* array:
            # preprocess_input allocates a new array for non-float input, so
            # relying on in-place mutation could feed raw pixels to the model.
            return model(preprocess_input(np.array(x, dtype=np.float32)))

        masker = shap.maskers.Image("inpaint_telea", (224, 224, 3))
        resnet_explainer = shap.Explainer(f, masker, output_names=class_names)
        resnet_model = model
    return resnet_model, resnet_explainer


def initialize_tabular_model():
    """Initialize (once) a Random Forest on the Adult Income data plus a TreeExplainer.

    Returns:
        (tabular_model, tabular_explainer, (X_test, y_test))
    """
    global tabular_model, tabular_explainer, tabular_data
    if tabular_model is None:
        # Load adult income dataset (returns DataFrame and Series)
        X, y = shap.datasets.adult()

        # Convert to pandas DataFrame if it's not already
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        if not isinstance(y, pd.Series):
            y = pd.Series(y)

        # Keep as DataFrame after split
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # Train Random Forest
        tabular_model = RandomForestClassifier(n_estimators=100, random_state=42)
        tabular_model.fit(X_train, y_train)

        # Create explainer
        tabular_explainer = shap.TreeExplainer(tabular_model)
        tabular_data = (X_test, y_test)

    return tabular_model, tabular_explainer, tabular_data


def _split_multi_output(shap_values, input_ndim):
    """Return DeepExplainer output as a list with one attribution array per model output.

    Depending on the SHAP version, multi-output attributions arrive either as
    such a list or as one array with the outputs stacked on a trailing axis.
    """
    if isinstance(shap_values, np.ndarray) and shap_values.ndim == input_ndim + 1:
        return [shap_values[..., k] for k in range(shap_values.shape[-1])]
    return shap_values


# ============================================================================
# SHAP Explanation Functions
# ============================================================================

def explain_mnist_digit(digit_index):
    """Generate a pixel-level SHAP explanation for one MNIST test digit.

    Args:
        digit_index: position (0-based) within the selectable test images.

    Returns:
        (PIL image or None, status message)
    """
    try:
        model, background = initialize_mnist_model()

        images, targets = _load_mnist_test_batch()
        start = MNIST_BACKGROUND_SIZE
        test_images = images[start:start + MNIST_NUM_EXAMPLES]
        test_targets = targets[start:start + MNIST_NUM_EXAMPLES].numpy()

        # Select image (keep the batch dimension) and move it to the model's device
        idx = _clamp_index(digit_index, len(test_images))
        test_image = test_images[[idx]].to(DEVICE)
        background_device = background.to(DEVICE)
        actual = test_targets[idx]

        # Get prediction
        with torch.no_grad():
            pred = int(model(test_image).argmax(dim=1).item())

        # Create explainer and get SHAP values (one array per digit class)
        explainer = shap.DeepExplainer(model, background_device)
        shap_values = _split_multi_output(explainer.shap_values(test_image), test_image.dim())

        # NCHW -> NHWC for plotting
        shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
        test_numpy = np.swapaxes(np.swapaxes(test_image.cpu().numpy(), 1, -1), 1, 2)

        # image_plot creates its own figure; the title is added to that figure.
        shap.image_plot(shap_numpy, -test_numpy, show=False)
        plt.suptitle(f'Actual: {actual}, Predicted: {pred}', fontsize=14, y=1.02)
        img = _figure_to_image()

        return img, f"Prediction: {pred} (Actual: {actual})"

    except Exception as e:
        return None, f"Error: {str(e)}"


def explain_imagenet_image(image):
    """Generate a SHAP explanation of ResNet50's top-4 classes for an uploaded image.

    Args:
        image: uploaded image as a numpy array (as delivered by ``gr.Image``), or None.

    Returns:
        (PIL image or None, status message)
    """
    try:
        if image is None:
            return None, "Please upload an image"

        model, explainer = initialize_resnet_model()

        # convert("RGB") replicates grayscale channels and drops alpha.
        img = Image.fromarray(image).convert("RGB").resize((224, 224))
        # Float pixel values in [0, 255] with a leading batch axis.
        # NOTE(review): uint8 pixel data can break in-place rescaling inside
        # shap.image_plot in some SHAP versions, hence float32.
        img_array = np.asarray(img, dtype=np.float32)[np.newaxis, ...]

        # Calculate SHAP values for the 4 highest-scoring classes
        shap_values = explainer(img_array, max_evals=100, batch_size=50,
                                outputs=shap.Explanation.argsort.flip[:4])

        # image_plot creates its own figure, so none is pre-created here.
        shap.image_plot(shap_values, show=False)
        result_img = _figure_to_image()

        return result_img, "SHAP explanation generated successfully"

    except Exception as e:
        return None, f"Error: {str(e)}"


def explain_tabular_sample(sample_index):
    """Generate a SHAP waterfall plot for one Adult Income test sample.

    Args:
        sample_index: row position within the test split (clamped to range).

    Returns:
        (PIL image or None, status message)
    """
    try:
        model, explainer, (X_test, y_test) = initialize_tabular_model()

        idx = _clamp_index(sample_index, len(X_test))

        # Select the sample - handle both DataFrame and numpy array
        if hasattr(X_test, 'iloc'):  # DataFrame/Series
            X_sample = X_test.iloc[[idx]]
            actual = y_test.iloc[idx]
        else:  # Numpy array
            X_sample = X_test[idx:idx + 1]
            actual = y_test[idx]

        # Explain only the selected row; class index 1 is the positive class.
        shap_values = explainer(X_sample)

        fig = plt.figure(figsize=(10, 8))
        shap.plots.waterfall(shap_values[0, :, 1], show=False)
        img = _figure_to_image()
        plt.close(fig)  # harmless if it was the figure just rendered

        pred = model.predict(X_sample)[0]

        return img, f"Prediction: {pred} (Actual: {actual})"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return None, f"Error: {str(e)}\n\nDetails:\n{error_details}"


# ============================================================================
# Text Model Functions (for Text SHAP)
# ============================================================================

def initialize_text_model():
    """Initialize (once) the DistilBERT SST-2 sentiment pipeline."""
    global text_model, text_tokenizer, text_explainer
    if text_model is None:
        try:
            from transformers import pipeline

            # Use a lightweight sentiment analysis model
            print("Loading text sentiment model...")
            text_model = pipeline(
                "sentiment-analysis",
                model="distilbert-base-uncased-finetuned-sst-2-english",
                return_all_scores=True,
            )
            print("✓ Text model loaded successfully")
        except Exception as e:
            print(f"⚠ Error loading text model: {e}")
            raise
    return text_model


def _label_scores(pred):
    """Flatten a text-classification pipeline result for one input into ``{label: score}``.

    Depending on the transformers version and call style, one input's result
    may be a dict, a list of dicts, or a list wrapping that list.
    """
    if isinstance(pred, dict):
        pred = [pred]
    if pred and isinstance(pred[0], list):
        pred = pred[0]
    return {p['label']: p['score'] for p in pred}


def _plot_token_attributions(explanation, title):
    """Draw per-token SHAP values as a horizontal bar chart and return it as an image.

    Red bars (positive values) push the explained class up; blue bars push it down.
    """
    tokens = [str(tok) for tok in explanation.data]
    values = np.asarray(explanation.values, dtype=float)
    labels = [tok.strip() or repr(tok) for tok in tokens]
    positions = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(12, max(3.0, 0.4 * len(labels) + 1.0)))
    ax.barh(positions, values, color=["#ff0051" if v > 0 else "#008bfb" for v in values])
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # first token at the top, in reading order
    ax.axvline(0, color="#999999", linewidth=0.8)
    ax.set_xlabel("SHAP value")
    ax.set_title(title)
    return _figure_to_image(fig)


def explain_text_sample(text_input, use_sample=False, sample_index=0):
    """Generate a token-level SHAP explanation of the sentiment prediction for a text.

    Args:
        text_input: custom text (used when ``use_sample`` is False and it is non-blank).
        use_sample: if True, explain ``SAMPLE_TEXTS[sample_index]`` instead.
        sample_index: index into ``SAMPLE_TEXTS`` (clamped to range).

    Returns:
        (PIL image or None, markdown-ish result text)
    """
    try:
        model = initialize_text_model()

        # Use sample or custom text
        if use_sample:
            text = SAMPLE_TEXTS[_clamp_index(sample_index, len(SAMPLE_TEXTS))]
        else:
            text = text_input if text_input and text_input.strip() else SAMPLE_TEXTS[0]

        # Get prediction: the label with the highest score
        scores = _label_scores(model(text))
        sentiment = max(scores, key=scores.get)
        confidence = scores[sentiment]

        def predict_fn(texts):
            """Return [P(NEGATIVE), P(POSITIVE)] per (masked) text, batched through the pipeline."""
            texts = [str(t) for t in texts]
            # Empty text - return neutral scores
            results = np.full((len(texts), 2), 0.5)
            non_empty = [i for i, t in enumerate(texts) if t.strip()]
            if non_empty:
                outputs = model([texts[i] for i in non_empty])
                for i, out in zip(non_empty, outputs):
                    label_scores = _label_scores(out)
                    results[i] = [
                        label_scores.get('NEGATIVE', 0.0),
                        label_scores.get('POSITIVE', 0.0),
                    ]
            return results

        # Use SHAP's Partition explainer for text
        print("Computing SHAP values for text...")
        explainer = shap.Explainer(
            predict_fn, shap.maskers.Text(r"\W+"), output_names=["NEGATIVE", "POSITIVE"]
        )
        shap_values = explainer([text])

        # Explain the predicted class (index 1 = POSITIVE, 0 = NEGATIVE).
        # shap.plots.text only produces HTML, so draw a matplotlib chart instead.
        class_idx = 1 if sentiment == "POSITIVE" else 0
        img = _plot_token_attributions(
            shap_values[0, :, class_idx], f"Token contributions to {sentiment}"
        )

        result_text = f"**Text:** {text}\n\n**Prediction:** {sentiment}\n**Confidence:** {confidence:.2%}\n\n"
        result_text += f"**Explanation:** Words shown in red increase the {sentiment} sentiment, "
        result_text += "while words in blue decrease it."

        return img, result_text

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return None, f"Error: {str(e)}\n\nDetails:\n{error_details}"


# ============================================================================
# Gradio Interface
# ============================================================================

def create_demo():
    """Create the Gradio Blocks interface with one tab per SHAP method."""
    with gr.Blocks(title="SHAP Explanations Demo") as demo:
        gr.Markdown("# SHAP (SHapley Additive exPlanations) Demo")
        gr.Markdown("This demo showcases four different SHAP explanation methods for machine learning models.")

        with gr.Tabs():
            # Tab 1: MNIST Pixel-level Explanations
            with gr.Tab("1. Pixel-level (MNIST Digits)"):
                gr.Markdown("""
                ### Pixel-level SHAP Explanations
                This method uses **DeepExplainer** to show which pixels contribute to the model's prediction.
                - **Red pixels**: Increase the probability of the predicted class
                - **Blue pixels**: Decrease the probability of the predicted class
                """)

                with gr.Row():
                    with gr.Column():
                        mnist_slider = gr.Slider(minimum=0, maximum=MNIST_NUM_EXAMPLES - 1, step=1, value=0,
                                                 label="Select Test Image Index")
                        mnist_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        mnist_output = gr.Image(label="SHAP Explanation")
                        mnist_text = gr.Textbox(label="Prediction Result")

                mnist_button.click(
                    fn=explain_mnist_digit,
                    inputs=[mnist_slider],
                    outputs=[mnist_output, mnist_text]
                )

            # Tab 2: ImageNet Image Explanations
            with gr.Tab("2. Image Segmentation (ImageNet)"):
                gr.Markdown("""
                ### Image Segmentation SHAP Explanations
                This method uses **Partition Explainer** with image masking to explain ResNet50 predictions.
                Upload an image to see which regions contribute to the top predicted classes.
                """)

                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(label="Upload Image")
                        image_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        image_output = gr.Image(label="SHAP Explanation")
                        image_text = gr.Textbox(label="Status")

                image_button.click(
                    fn=explain_imagenet_image,
                    inputs=[image_input],
                    outputs=[image_output, image_text]
                )

            # Tab 3: Tabular Data Explanations
            with gr.Tab("3. Tabular Data (Adult Income)"):
                gr.Markdown("""
                ### Tabular Data SHAP Explanations
                This method uses **TreeExplainer** to explain Random Forest predictions on the Adult Income dataset.
                The waterfall plot shows how each feature contributes to the prediction.
                """)

                with gr.Row():
                    with gr.Column():
                        tabular_slider = gr.Slider(minimum=0, maximum=99, step=1, value=0,
                                                   label="Select Sample Index")
                        tabular_button = gr.Button("Generate Explanation", variant="primary")
                    with gr.Column():
                        tabular_output = gr.Image(label="SHAP Waterfall Plot")
                        tabular_text = gr.Textbox(label="Prediction Result")

                tabular_button.click(
                    fn=explain_tabular_sample,
                    inputs=[tabular_slider],
                    outputs=[tabular_output, tabular_text]
                )

            # Tab 4: Text Data Explanations
            with gr.Tab("4. Text Data (Sentiment Analysis)"):
                gr.Markdown("""
                ### Text SHAP Explanations
                This method uses **Partition Explainer** to explain sentiment analysis predictions.
                Enter your own text or select from sample texts to see which words contribute to the sentiment prediction.
                - **Red words**: Increase the predicted sentiment
                - **Blue words**: Decrease the predicted sentiment
                """)

                with gr.Row():
                    with gr.Column():
                        text_mode = gr.Radio(
                            choices=["Custom Text", "Sample Text"],
                            value="Sample Text",
                            label="Input Mode"
                        )
                        text_input = gr.Textbox(
                            label="Enter Your Text",
                            placeholder="Type your text here...",
                            lines=3,
                            visible=False
                        )
                        text_slider = gr.Slider(
                            minimum=0,
                            maximum=len(SAMPLE_TEXTS) - 1,
                            step=1,
                            value=0,
                            label="Select Sample Text Index",
                            visible=True
                        )
                        text_button = gr.Button("Generate Explanation", variant="primary")

                        gr.Markdown("""
                        **Sample texts include:**
                        - Positive reviews (indices 0, 3, 5, 9)
                        - Negative reviews (indices 1, 4, 6, 8)
                        - Neutral reviews (indices 2, 7)
                        """)

                    with gr.Column():
                        text_output = gr.Image(label="SHAP Text Explanation")
                        text_result = gr.Textbox(label="Prediction Result", lines=5)

                # Toggle visibility based on mode
                def update_visibility(mode):
                    if mode == "Custom Text":
                        return gr.update(visible=True), gr.update(visible=False)
                    else:
                        return gr.update(visible=False), gr.update(visible=True)

                text_mode.change(
                    fn=update_visibility,
                    inputs=[text_mode],
                    outputs=[text_input, text_slider]
                )

                # Generate explanation
                def generate_text_explanation(mode, custom_text, sample_idx):
                    use_sample = (mode == "Sample Text")
                    return explain_text_sample(custom_text, use_sample, sample_idx)

                text_button.click(
                    fn=generate_text_explanation,
                    inputs=[text_mode, text_input, text_slider],
                    outputs=[text_output, text_result]
                )

        gr.Markdown("""
        ---
        ### About SHAP
        SHAP (SHapley Additive exPlanations) is a unified approach to explain the output of machine learning models.
        It connects game theory with local explanations and provides consistent and locally accurate feature attributions.
        """)

    return demo


# ============================================================================
# Main
# ============================================================================

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)