Upload folder using huggingface_hub
Browse files
- README.md +25 -13
- app.py +340 -169
- docs/superpowers/plans/2026-03-28-gradio-xai-merge.md +1231 -0
- docs/superpowers/specs/2026-03-28-gradio-xai-merge-design.md +156 -0
- feedback/.gitkeep +0 -0
- feedback/feedback_log.csv +2 -0
- requirements.txt +2 -1
- retrain.py +196 -0
README.md
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
---
|
| 2 |
-
title: Spam Email Classifier with XAI
|
| 3 |
emoji: 📧
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: "4.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
@@ -13,39 +13,51 @@ tags:
|
|
| 13 |
- xai
|
| 14 |
- lime
|
| 15 |
- shap
|
|
|
|
| 16 |
- scikit-learn
|
| 17 |
- nlp
|
|
|
|
| 18 |
---
|
| 19 |
|
| 20 |
# Spam Email Classifier with XAI Explanations
|
| 21 |
|
| 22 |
-
**
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
A Gradio web app that classifies emails as spam or ham and provides explainable AI (XAI) insights using LIME and SHAP.
|
| 27 |
|
| 28 |
## Features
|
| 29 |
|
| 30 |
- Paste any email and get an instant spam/ham prediction
|
| 31 |
-
- LIME explanations
|
| 32 |
-
- SHAP feature importance
|
| 33 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
## How to Run Locally
|
| 36 |
|
| 37 |
```bash
|
| 38 |
pip install -r requirements.txt
|
| 39 |
-
python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
```
|
| 41 |
|
| 42 |
## Model
|
| 43 |
|
| 44 |
-
Voting ensemble
|
| 45 |
|
| 46 |
## Tech Stack
|
| 47 |
|
| 48 |
-
- scikit-learn (
|
| 49 |
-
- LIME + SHAP (explainability)
|
| 50 |
- Gradio (web interface)
|
| 51 |
- NLTK (text preprocessing)
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Spam Email Classifier with XAI
|
| 3 |
emoji: 📧
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "4.44.0"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
|
|
| 13 |
- xai
|
| 14 |
- lime
|
| 15 |
- shap
|
| 16 |
+
- eli5
|
| 17 |
- scikit-learn
|
| 18 |
- nlp
|
| 19 |
+
- explainable-ai
|
| 20 |
---
|
| 21 |
|
| 22 |
# Spam Email Classifier with XAI Explanations
|
| 23 |
|
| 24 |
+
> **Disclaimer:** This model was created as a university course project. It is intended for **educational and research purposes only** and should not be used as a sole spam/phishing filter in production. Classification accuracy may vary. Always use established email security tools for real-world spam filtering.
|
| 25 |
|
| 26 |
+
A Gradio web app that classifies emails as spam or ham and provides explainable AI (XAI) insights using three different methods.
|
|
|
|
|
|
|
| 27 |
|
| 28 |
## Features
|
| 29 |
|
| 30 |
- Paste any email and get an instant spam/ham prediction
|
| 31 |
+
- **LIME** explanations — which words pushed the decision
|
| 32 |
+
- **SHAP** feature importance — game-theoretic attribution
|
| 33 |
+
- **ELI5** — model internal feature weights
|
| 34 |
+
- **Side-by-side comparison** of all three XAI methods
|
| 35 |
+
- **Plain English summary** of why the model made its decision
|
| 36 |
+
- **User feedback** — thumbs up/down to log corrections for batch retraining
|
| 37 |
+
- Adjustable classification threshold
|
| 38 |
|
| 39 |
## How to Run Locally
|
| 40 |
|
| 41 |
```bash
|
| 42 |
pip install -r requirements.txt
|
| 43 |
+
python train.py # train the models (first time only)
|
| 44 |
+
python app.py # launch the Gradio app
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## Retraining with Feedback
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
python retrain.py # retrain with accumulated feedback corrections
|
| 51 |
+
python retrain.py --no-feedback # retrain with original data only
|
| 52 |
```
|
| 53 |
|
| 54 |
## Model
|
| 55 |
|
| 56 |
+
Voting ensemble (Random Forest + Logistic Regression + SVM) trained on SpamAssassin + Enron email datasets using TF-IDF + 24 metadata features.
|
| 57 |
|
| 58 |
## Tech Stack
|
| 59 |
|
| 60 |
+
- scikit-learn (ensemble classifier)
|
| 61 |
+
- LIME + SHAP + ELI5 (explainability)
|
| 62 |
- Gradio (web interface)
|
| 63 |
- NLTK (text preprocessing)
|
app.py
CHANGED
|
@@ -1,29 +1,36 @@
|
|
| 1 |
# Gradio web app for the Spam Email Classifier with XAI explanations
|
| 2 |
-
#
|
| 3 |
-
#
|
| 4 |
-
#
|
| 5 |
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
-
|
|
|
|
| 11 |
import gradio as gr
|
| 12 |
import lime
|
| 13 |
import lime.lime_tabular
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
import shap
|
| 15 |
-
|
| 16 |
from scipy.sparse import hstack, csr_matrix
|
|
|
|
| 17 |
from utils import (preprocess_text, compute_metadata_features,
|
| 18 |
META_FEATURE_NAMES, FEATURE_DESCRIPTIONS)
|
| 19 |
|
| 20 |
# ---------------------------------------------------------------------------
|
| 21 |
# 1. Model Loading
|
| 22 |
# ---------------------------------------------------------------------------
|
| 23 |
-
# All trained artifacts live in the models/ folder.
|
| 24 |
-
# If any file is missing, we set it to None and show an error later.
|
| 25 |
|
| 26 |
models_dir = Path(__file__).parent / 'models'
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
try:
|
| 29 |
voting_model = joblib.load(models_dir / 'voting_model.joblib')
|
|
@@ -32,7 +39,8 @@ try:
|
|
| 32 |
feature_names = joblib.load(models_dir / 'feature_names.joblib')
|
| 33 |
optimal_threshold = joblib.load(models_dir / 'optimal_threshold.joblib')
|
| 34 |
training_sample = joblib.load(models_dir / 'training_sample.joblib')
|
| 35 |
-
|
|
|
|
| 36 |
except FileNotFoundError as e:
|
| 37 |
print(f"Model file not found: {e}")
|
| 38 |
voting_model = None
|
|
@@ -41,12 +49,11 @@ except FileNotFoundError as e:
|
|
| 41 |
feature_names = None
|
| 42 |
optimal_threshold = None
|
| 43 |
training_sample = None
|
|
|
|
| 44 |
|
| 45 |
# ---------------------------------------------------------------------------
|
| 46 |
# 2. LIME Explainer Setup
|
| 47 |
# ---------------------------------------------------------------------------
|
| 48 |
-
# LIME needs a sample of training data to understand feature distributions.
|
| 49 |
-
# We set this up once at startup so we don't recreate it every prediction.
|
| 50 |
|
| 51 |
lime_explainer = None
|
| 52 |
if training_sample is not None and feature_names is not None:
|
|
@@ -59,264 +66,387 @@ if training_sample is not None and feature_names is not None:
|
|
| 59 |
print("LIME explainer ready.")
|
| 60 |
|
| 61 |
# ---------------------------------------------------------------------------
|
| 62 |
-
# 3. classify_email
|
| 63 |
# ---------------------------------------------------------------------------
|
| 64 |
-
# Takes raw email text, runs it through the full pipeline, returns the
|
| 65 |
-
# prediction label, confidence score, and the combined feature matrix.
|
| 66 |
|
| 67 |
-
def classify_email(email_text):
|
| 68 |
-
"""Classify a single email
|
| 69 |
-
# Step 1: Preprocess the text (remove HTML, URLs, stopwords, stem)
|
| 70 |
cleaned_text = preprocess_text(email_text)
|
| 71 |
-
|
| 72 |
-
# Step 2: Convert to TF-IDF features (sparse matrix)
|
| 73 |
tfidf_features = tfidf_vectorizer.transform([cleaned_text])
|
| 74 |
-
|
| 75 |
-
# Step 3: Compute the 24 metadata features from the RAW text
|
| 76 |
-
# (metadata uses things like exclamation marks that preprocessing removes)
|
| 77 |
meta_raw = compute_metadata_features([email_text])
|
| 78 |
-
|
| 79 |
-
# Step 4: Scale the metadata features
|
| 80 |
meta_scaled = meta_scaler.transform(meta_raw)
|
| 81 |
-
|
| 82 |
-
# Step 5: Combine TF-IDF + metadata into one feature matrix
|
| 83 |
combined = hstack([tfidf_features, csr_matrix(meta_scaled)])
|
| 84 |
-
|
| 85 |
-
# Step 6: Get the spam probability from the model
|
| 86 |
spam_proba = voting_model.predict_proba(combined)[0][1]
|
| 87 |
|
| 88 |
-
|
| 89 |
-
if spam_proba >= optimal_threshold:
|
| 90 |
label = "SPAM"
|
| 91 |
confidence = spam_proba
|
| 92 |
else:
|
| 93 |
label = "HAM (Not Spam)"
|
| 94 |
confidence = 1.0 - spam_proba
|
| 95 |
|
| 96 |
-
return label, confidence, combined
|
| 97 |
-
|
| 98 |
|
| 99 |
# ---------------------------------------------------------------------------
|
| 100 |
-
# 4.
|
| 101 |
# ---------------------------------------------------------------------------
|
| 102 |
-
# Creates a human-readable markdown summary of the classification result.
|
| 103 |
-
# Uses LIME feature importance if available, otherwise falls back to
|
| 104 |
-
# showing notable metadata values.
|
| 105 |
-
|
| 106 |
-
def generate_summary(label, confidence, email_text, lime_explanation=None):
|
| 107 |
-
"""Generate a markdown summary of the classification result."""
|
| 108 |
-
# Start with the main result
|
| 109 |
-
summary = f"This email was classified as **{label}** ({confidence:.0%} confidence).\n\n"
|
| 110 |
-
|
| 111 |
-
# If we have a LIME explanation, show the top 5 features
|
| 112 |
-
if lime_explanation is not None:
|
| 113 |
-
summary += "**Key factors:**\n"
|
| 114 |
-
# Get the feature importance list from LIME
|
| 115 |
-
# Each item is (feature_name, weight) - positive = spam, negative = ham
|
| 116 |
-
feature_list = lime_explanation.as_list()
|
| 117 |
-
top_features = feature_list[:5] # take top 5
|
| 118 |
-
|
| 119 |
-
for feature_rule, weight in top_features:
|
| 120 |
-
if weight > 0:
|
| 121 |
-
direction = "pushes toward spam"
|
| 122 |
-
else:
|
| 123 |
-
direction = "pushes toward ham"
|
| 124 |
-
summary += f"- **{feature_rule}** {direction}\n"
|
| 125 |
-
else:
|
| 126 |
-
# Fallback: show notable metadata values when LIME isn't available
|
| 127 |
-
summary += "**Notable features:**\n"
|
| 128 |
-
meta_raw = compute_metadata_features([email_text])
|
| 129 |
-
meta_values = meta_raw[0]
|
| 130 |
-
|
| 131 |
-
for i, feat_name in enumerate(META_FEATURE_NAMES):
|
| 132 |
-
value = meta_values[i]
|
| 133 |
-
# Only show features with non-zero values (they're more interesting)
|
| 134 |
-
if value != 0:
|
| 135 |
-
description = FEATURE_DESCRIPTIONS.get(feat_name, feat_name)
|
| 136 |
-
summary += f"- **{feat_name}** = {value:.2f} ({description})\n"
|
| 137 |
-
|
| 138 |
-
return summary
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
# 5. generate_lime_plot Function
|
| 143 |
-
# ---------------------------------------------------------------------------
|
| 144 |
-
# Creates a LIME explanation plot for a single email.
|
| 145 |
-
# LIME shows which features pushed the prediction toward spam vs ham.
|
| 146 |
-
|
| 147 |
-
def generate_lime_plot(combined_features):
|
| 148 |
-
"""Generate a LIME explanation plot. Returns (figure, explanation) or (None, None)."""
|
| 149 |
if lime_explainer is None:
|
| 150 |
return None, None
|
| 151 |
-
|
| 152 |
-
# Convert sparse matrix to dense array (fine for a single email)
|
| 153 |
instance = combined_features.toarray()[0]
|
| 154 |
-
|
| 155 |
-
# Ask LIME to explain this prediction
|
| 156 |
-
# num_features=10 means show the 10 most important features
|
| 157 |
explanation = lime_explainer.explain_instance(
|
| 158 |
instance,
|
| 159 |
voting_model.predict_proba,
|
| 160 |
num_features=10,
|
| 161 |
)
|
| 162 |
-
|
| 163 |
-
# Get the matplotlib figure from LIME
|
| 164 |
fig = explanation.as_pyplot_figure()
|
| 165 |
fig.tight_layout()
|
| 166 |
-
|
| 167 |
return fig, explanation
|
| 168 |
|
| 169 |
-
|
| 170 |
# ---------------------------------------------------------------------------
|
| 171 |
-
#
|
| 172 |
# ---------------------------------------------------------------------------
|
| 173 |
-
# Creates a SHAP explanation plot using ONLY the 24 metadata features.
|
| 174 |
-
# We skip the TF-IDF features because SHAP KernelExplainer is slow on
|
| 175 |
-
# thousands of features - metadata-only is fast and still informative.
|
| 176 |
|
| 177 |
-
def
|
| 178 |
-
"""Generate
|
| 179 |
if training_sample is None or voting_model is None:
|
| 180 |
-
return None
|
| 181 |
|
| 182 |
-
|
| 183 |
-
num_meta = len(META_FEATURE_NAMES) # 24
|
| 184 |
background_meta = training_sample[:50, -num_meta:]
|
| 185 |
-
|
| 186 |
-
# Compute metadata features for this email
|
| 187 |
meta_raw = compute_metadata_features([email_text])
|
| 188 |
meta_scaled = meta_scaler.transform(meta_raw)
|
| 189 |
-
|
| 190 |
-
# We need a wrapper function that takes ONLY metadata features
|
| 191 |
-
# and pads zeros for the TF-IDF columns so the model can still predict
|
| 192 |
num_tfidf = training_sample.shape[1] - num_meta
|
| 193 |
|
| 194 |
def predict_with_meta_only(meta_features):
|
| 195 |
-
"""Pad zeros for TF-IDF columns, then predict with the full model."""
|
| 196 |
n_samples = meta_features.shape[0]
|
| 197 |
-
# Create a zero matrix for the TF-IDF part
|
| 198 |
tfidf_zeros = csr_matrix((n_samples, num_tfidf))
|
| 199 |
-
# Combine: TF-IDF zeros + metadata features
|
| 200 |
combined = hstack([tfidf_zeros, csr_matrix(meta_features)])
|
| 201 |
return voting_model.predict_proba(combined)
|
| 202 |
|
| 203 |
-
# Create the SHAP explainer with our wrapper function
|
| 204 |
explainer = shap.KernelExplainer(predict_with_meta_only, background_meta)
|
| 205 |
-
|
| 206 |
-
# Compute SHAP values for this email's metadata
|
| 207 |
shap_values = explainer.shap_values(meta_scaled, nsamples=100)
|
| 208 |
|
| 209 |
-
# SHAP returns different formats depending on the version:
|
| 210 |
-
# - Older versions: a list of arrays [ham_values, spam_values]
|
| 211 |
-
# - Newer versions: a single 2D or 3D array
|
| 212 |
-
# We want a flat 1D array of shape (24,) for the spam class
|
| 213 |
if isinstance(shap_values, list):
|
| 214 |
-
# List format: index 1 is the spam class
|
| 215 |
sv = np.array(shap_values[1]).flatten()
|
| 216 |
else:
|
| 217 |
sv = np.array(shap_values).flatten()
|
| 218 |
-
|
| 219 |
-
# If flatten gave us more values than features (e.g. both classes),
|
| 220 |
-
# take just the last 24 (spam class values)
|
| 221 |
if len(sv) > num_meta:
|
| 222 |
sv = sv[-num_meta:]
|
| 223 |
|
| 224 |
-
|
| 225 |
-
fig, ax = plt.subplots(figsize=(8, 6))
|
| 226 |
|
| 227 |
-
# Sort features by absolute SHAP value (most important on top)
|
| 228 |
sorted_indices = np.argsort(np.abs(sv))
|
| 229 |
sorted_names = [META_FEATURE_NAMES[idx] for idx in sorted_indices.tolist()]
|
| 230 |
sorted_values = sv[sorted_indices]
|
| 231 |
|
| 232 |
-
|
| 233 |
-
colors = []
|
| 234 |
-
for val in sorted_values:
|
| 235 |
-
if val > 0:
|
| 236 |
-
colors.append('#d62728') # red for spam
|
| 237 |
-
else:
|
| 238 |
-
colors.append('#1f77b4') # blue for ham
|
| 239 |
-
|
| 240 |
ax.barh(sorted_names, sorted_values, color=colors)
|
| 241 |
ax.set_xlabel('SHAP Value (impact on spam probability)')
|
| 242 |
ax.set_title('SHAP Feature Importance (Metadata Features)')
|
| 243 |
ax.axvline(x=0, color='black', linewidth=0.5)
|
| 244 |
fig.tight_layout()
|
| 245 |
|
| 246 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
|
|
|
| 248 |
|
| 249 |
# ---------------------------------------------------------------------------
|
| 250 |
-
# 7.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
# ---------------------------------------------------------------------------
|
| 252 |
-
# These give users something to try right away without typing their own email.
|
| 253 |
|
| 254 |
EXAMPLE_EMAILS = [
|
| 255 |
["Subject: URGENT - You Have Won $5,000,000!!!\n\nDear Friend,\n\nCONGRATULATIONS!!! You have been selected as the winner of our international lottery program!!!\nTo claim your $5,000,000 USD prize, click the link below IMMEDIATELY and provide your bank details.\n\nACT NOW - This offer expires in 24 hours!!!\n\nClick here: http://totally-legit-prize.com/claim\nSend $500 processing fee to unlock your winnings.\n\nBest regards,\nDr. Prince Mohammed"],
|
| 256 |
["Subject: Team sync Thursday 2pm\n\nHi everyone,\n\nJust a reminder that we have our weekly team sync this Thursday at 2pm in Conference Room B.\n\nAgenda:\n- Sprint review\n- Q2 planning discussion\n- New hire onboarding update\n\nPlease come prepared with your status updates.\n\nThanks,\nSarah"],
|
| 257 |
["Subject: Your account has been compromised!\n\nDear Customer,\n\nWe detected suspicious activity on your account. Click here immediately to verify your identity: http://secure-bank-login.com/verify\n\nIf you do not verify within 24 hours, your account will be permanently locked.\n\nSecurity Team"],
|
| 258 |
["Subject: Thanksgiving dinner plans\n\nHi everyone!\n\nI wanted to start planning for Thanksgiving dinner. I'm thinking we could do it at my place this year. What does everyone think about 4pm?\n\nLet me know if you have any dietary restrictions or if you want to bring a dish.\n\nLove,\nMom"],
|
|
|
|
| 259 |
]
|
| 260 |
|
| 261 |
-
|
| 262 |
# ---------------------------------------------------------------------------
|
| 263 |
-
#
|
| 264 |
# ---------------------------------------------------------------------------
|
| 265 |
-
# This is the function that Gradio calls when the user clicks "Classify".
|
| 266 |
-
# It handles file uploads, input validation, and orchestrates everything.
|
| 267 |
|
| 268 |
-
def classify_and_explain(email_text, uploaded_file):
|
| 269 |
-
"""Main function called by Gradio. Returns
|
| 270 |
|
| 271 |
-
# If user uploaded a file, read its contents
|
| 272 |
if uploaded_file is not None:
|
| 273 |
try:
|
| 274 |
file_content = Path(uploaded_file).read_text(encoding='utf-8')
|
| 275 |
email_text = file_content
|
| 276 |
except Exception:
|
| 277 |
-
|
|
|
|
| 278 |
|
| 279 |
-
# Check for empty input
|
| 280 |
if email_text is None or email_text.strip() == '':
|
| 281 |
-
|
|
|
|
| 282 |
|
| 283 |
-
# Check that models are loaded
|
| 284 |
if voting_model is None:
|
| 285 |
-
|
|
|
|
| 286 |
|
| 287 |
-
|
| 288 |
-
label, confidence, combined_features = classify_email(email_text)
|
| 289 |
|
| 290 |
-
|
| 291 |
-
lime_fig, lime_explanation = generate_lime_plot(combined_features)
|
| 292 |
|
| 293 |
-
# Step 3: Generate the SHAP plot (wrapped in try/except so SHAP errors
|
| 294 |
-
# don't prevent the Result and LIME tabs from showing)
|
| 295 |
try:
|
| 296 |
-
shap_fig =
|
| 297 |
except Exception as e:
|
| 298 |
-
print(
|
| 299 |
-
shap_fig = None
|
| 300 |
|
| 301 |
-
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
-
|
|
|
|
|
|
|
| 305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
# ---------------------------------------------------------------------------
|
| 308 |
-
#
|
| 309 |
# ---------------------------------------------------------------------------
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
with gr.Blocks(title="Spam Email Classifier with XAI") as demo:
|
| 314 |
gr.Markdown("# Spam Email Classifier with XAI Explanations")
|
| 315 |
-
gr.Markdown("
|
| 316 |
-
"
|
|
|
|
|
|
|
| 317 |
|
| 318 |
with gr.Row():
|
| 319 |
-
# Left column: inputs
|
| 320 |
with gr.Column(scale=1):
|
| 321 |
email_input = gr.Textbox(
|
| 322 |
label="Email Text",
|
|
@@ -328,6 +458,12 @@ with gr.Blocks(title="Spam Email Classifier with XAI") as demo:
|
|
| 328 |
label="Or upload a .txt file",
|
| 329 |
file_types=['.txt'],
|
| 330 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
classify_btn = gr.Button("Classify", variant="primary")
|
| 332 |
gr.Examples(
|
| 333 |
examples=EXAMPLE_EMAILS,
|
|
@@ -336,26 +472,61 @@ with gr.Blocks(title="Spam Email Classifier with XAI") as demo:
|
|
| 336 |
cache_examples=False,
|
| 337 |
)
|
| 338 |
|
| 339 |
-
# Right column: outputs in tabs
|
| 340 |
with gr.Column(scale=1):
|
| 341 |
with gr.Tabs():
|
| 342 |
with gr.Tab("Result"):
|
| 343 |
result_output = gr.Markdown(label="Classification Result")
|
| 344 |
with gr.Tab("LIME"):
|
|
|
|
|
|
|
| 345 |
lime_output = gr.Plot(label="LIME Explanation")
|
| 346 |
with gr.Tab("SHAP"):
|
|
|
|
|
|
|
| 347 |
shap_output = gr.Plot(label="SHAP Explanation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
-
# Wire up the button click to our main function
|
| 350 |
classify_btn.click(
|
| 351 |
fn=classify_and_explain,
|
| 352 |
-
inputs=[email_input, file_input],
|
| 353 |
-
outputs=[result_output, lime_output, shap_output
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
)
|
| 356 |
|
| 357 |
# ---------------------------------------------------------------------------
|
| 358 |
-
#
|
| 359 |
# ---------------------------------------------------------------------------
|
| 360 |
|
| 361 |
if __name__ == '__main__':
|
|
|
|
| 1 |
# Gradio web app for the Spam Email Classifier with XAI explanations
|
| 2 |
+
# University course project — Explainable AI for spam detection
|
| 3 |
+
# Features: LIME, SHAP, ELI5, side-by-side comparison, plain English summary,
|
| 4 |
+
# user feedback logging, and batch retrain support.
|
| 5 |
|
| 6 |
+
import csv
|
| 7 |
+
import os
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import eli5
|
| 12 |
import gradio as gr
|
| 13 |
import lime
|
| 14 |
import lime.lime_tabular
|
| 15 |
+
import matplotlib
|
| 16 |
+
matplotlib.use('Agg')
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import numpy as np
|
| 19 |
import shap
|
| 20 |
+
import joblib
|
| 21 |
from scipy.sparse import hstack, csr_matrix
|
| 22 |
+
|
| 23 |
from utils import (preprocess_text, compute_metadata_features,
|
| 24 |
META_FEATURE_NAMES, FEATURE_DESCRIPTIONS)
|
| 25 |
|
| 26 |
# ---------------------------------------------------------------------------
|
| 27 |
# 1. Model Loading
|
| 28 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
| 29 |
|
| 30 |
models_dir = Path(__file__).parent / 'models'
|
| 31 |
+
feedback_dir = Path(__file__).parent / 'feedback'
|
| 32 |
+
feedback_dir.mkdir(exist_ok=True)
|
| 33 |
+
FEEDBACK_CSV = feedback_dir / 'feedback_log.csv'
|
| 34 |
|
| 35 |
try:
|
| 36 |
voting_model = joblib.load(models_dir / 'voting_model.joblib')
|
|
|
|
| 39 |
feature_names = joblib.load(models_dir / 'feature_names.joblib')
|
| 40 |
optimal_threshold = joblib.load(models_dir / 'optimal_threshold.joblib')
|
| 41 |
training_sample = joblib.load(models_dir / 'training_sample.joblib')
|
| 42 |
+
raw_rf = voting_model.named_estimators_['rf']
|
| 43 |
+
print(f"All models loaded. Threshold = {optimal_threshold:.4f}")
|
| 44 |
except FileNotFoundError as e:
|
| 45 |
print(f"Model file not found: {e}")
|
| 46 |
voting_model = None
|
|
|
|
| 49 |
feature_names = None
|
| 50 |
optimal_threshold = None
|
| 51 |
training_sample = None
|
| 52 |
+
raw_rf = None
|
| 53 |
|
| 54 |
# ---------------------------------------------------------------------------
|
| 55 |
# 2. LIME Explainer Setup
|
| 56 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
| 57 |
|
| 58 |
lime_explainer = None
|
| 59 |
if training_sample is not None and feature_names is not None:
|
|
|
|
| 66 |
print("LIME explainer ready.")
|
| 67 |
|
| 68 |
# ---------------------------------------------------------------------------
|
| 69 |
+
# 3. classify_email
|
| 70 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
def classify_email(email_text, threshold):
    """Classify one raw email against a caller-supplied spam threshold.

    The TF-IDF block is built from the preprocessed text, the metadata block
    from the raw text; both are stacked side by side and scored by the
    voting model.

    Returns:
        (label, confidence, spam_proba, combined_features) where ``label`` is
        "SPAM" or "HAM (Not Spam)", ``confidence`` is the probability of the
        chosen label, ``spam_proba`` is the model's spam-class probability and
        ``combined_features`` is the stacked sparse feature row.
    """
    text_block = tfidf_vectorizer.transform([preprocess_text(email_text)])
    meta_block = meta_scaler.transform(compute_metadata_features([email_text]))
    combined = hstack([text_block, csr_matrix(meta_block)])

    # Column 1 of predict_proba is read as the spam class.
    spam_proba = voting_model.predict_proba(combined)[0][1]

    is_spam = spam_proba >= threshold
    label = "SPAM" if is_spam else "HAM (Not Spam)"
    confidence = spam_proba if is_spam else 1.0 - spam_proba

    return label, confidence, spam_proba, combined
|
|
|
|
| 89 |
|
| 90 |
# ---------------------------------------------------------------------------
|
| 91 |
+
# 4. LIME explanation
|
| 92 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
+
def generate_lime_explanation(combined_features):
    """Explain one feature row with the module-level LIME explainer.

    Returns:
        (figure, explanation) — the matplotlib figure LIME renders and the
        explanation object itself — or (None, None) when no explainer was
        created at startup.
    """
    if lime_explainer is None:
        return None, None

    # LIME is handed a dense 1-D row, so densify the single-row matrix first.
    dense_row = combined_features.toarray()[0]
    explanation = lime_explainer.explain_instance(
        dense_row,
        voting_model.predict_proba,
        num_features=10,
    )

    figure = explanation.as_pyplot_figure()
    figure.tight_layout()
    return figure, explanation
|
| 107 |
|
|
|
|
| 108 |
# ---------------------------------------------------------------------------
|
| 109 |
+
# 5. SHAP explanation (metadata features only — fast)
|
| 110 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
def _spam_class_shap_values(shap_values, num_meta):
    """Reduce KernelExplainer output to a 1-D array of spam-class values.

    Handles the layouts SHAP has used for a two-class ``predict_proba``:
    a list ``[ham, spam]`` of per-class arrays, a single 3-D array, or a
    plain 1-D/2-D array for a single output.
    """
    if not isinstance(shap_values, list):
        arr = np.asarray(shap_values)
        if arr.ndim == 3:
            if arr.shape[1] == num_meta:
                # (n_samples, n_features, n_outputs): take the last output
                # column directly. Flattening this layout interleaves the
                # classes, so slicing the tail would mix ham and spam values.
                return arr[0, :, -1]
            if arr.shape[-1] == num_meta:
                # (n_outputs, n_samples, n_features): last output, first row.
                return arr[-1, 0, :]
        flat = arr.reshape(-1)
    else:
        # List format: index 1 is read as the spam class.
        flat = np.asarray(shap_values[1]).reshape(-1)

    # Fallback kept from the previous behaviour: if more values than
    # features remain, keep the trailing block.
    return flat[-num_meta:] if flat.size > num_meta else flat


def generate_shap_explanation(email_text):
    """Generate a SHAP bar chart for the metadata features of one email.

    Only the metadata columns are explained: the wrapped predictor fills the
    TF-IDF columns with zeros so KernelExplainer perturbs
    ``len(META_FEATURE_NAMES)`` inputs instead of the whole feature space.
    The first 50 rows of ``training_sample`` serve as background data.

    Returns:
        (figure, shap_values, top_indices) where ``shap_values`` is a 1-D
        array with one spam-class value per metadata feature and
        ``top_indices`` holds up to 10 feature indices ordered by decreasing
        absolute SHAP value; (None, None, None) when the training sample or
        the model is unavailable.
    """
    if training_sample is None or voting_model is None:
        return None, None, None

    num_meta = len(META_FEATURE_NAMES)
    # Metadata columns are taken from the right-hand end of each row.
    background_meta = training_sample[:50, -num_meta:]
    meta_raw = compute_metadata_features([email_text])
    meta_scaled = meta_scaler.transform(meta_raw)
    num_tfidf = training_sample.shape[1] - num_meta

    def predict_with_meta_only(meta_features):
        # Pad zeros for the TF-IDF columns, then predict with the full model.
        n_samples = meta_features.shape[0]
        tfidf_zeros = csr_matrix((n_samples, num_tfidf))
        combined = hstack([tfidf_zeros, csr_matrix(meta_features)])
        return voting_model.predict_proba(combined)

    explainer = shap.KernelExplainer(predict_with_meta_only, background_meta)
    shap_values = explainer.shap_values(meta_scaled, nsamples=100)
    sv = _spam_class_shap_values(shap_values, num_meta)

    # One ascending sort serves both the chart (barh draws the last entry on
    # top) and the top-10 ranking (the reversed order).
    sorted_indices = np.argsort(np.abs(sv))
    top_idx = sorted_indices[::-1][:10]
    sorted_names = [META_FEATURE_NAMES[idx] for idx in sorted_indices.tolist()]
    sorted_values = sv[sorted_indices]

    fig, ax = plt.subplots(figsize=(8, 6))
    # Red bars push toward spam, blue bars toward ham.
    colors = ['#d62728' if val > 0 else '#1f77b4' for val in sorted_values]
    ax.barh(sorted_names, sorted_values, color=colors)
    ax.set_xlabel('SHAP Value (impact on spam probability)')
    ax.set_title('SHAP Feature Importance (Metadata Features)')
    ax.axvline(x=0, color='black', linewidth=0.5)
    fig.tight_layout()

    return fig, sv, top_idx
|
| 154 |
+
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
# 6. ELI5 explanation
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
def generate_eli5_explanation(combined_features):
    """Explain one feature row with ELI5 on the Random Forest sub-estimator.

    Returns:
        (html_string, feature_names_list): the HTML rendering of a top-10
        explanation and the feature names from a separate top-5 explanation
        (positive-weight features first, then negative-weight ones), or
        (None, None) when the forest or the feature names are unavailable.
    """
    if raw_rf is None or feature_names is None:
        return None, None

    dense_row = combined_features.toarray()[0]

    def _explain(limit):
        # Same estimator, row and feature names each time; only `top` varies.
        return eli5.explain_prediction(
            raw_rf, dense_row, feature_names=feature_names, top=limit
        )

    html = eli5.format_as_html(_explain(10))

    brief = _explain(5)
    top_names = []
    if getattr(brief, 'targets', None):
        weights = brief.targets[0].feature_weights
        top_names.extend(fw.feature for fw in weights.pos[:5])
        top_names.extend(fw.feature for fw in weights.neg[:5])

    return html, top_names
|
| 178 |
|
| 179 |
# ---------------------------------------------------------------------------
|
| 180 |
+
# 7. Plain English summary (replaces Ollama LLM)
|
| 181 |
+
# ---------------------------------------------------------------------------
|
| 182 |
+
|
| 183 |
+
def _lime_rule_feature_name(rule):
    """Return the bare feature name from a LIME entry.

    LIME may report a feature as a threshold rule such as ``"x > 0.50"`` or
    ``"0.10 < x <= 0.90"``; numeric bounds on either side are stripped.
    Entries that are already plain names are returned unchanged.
    """
    import re  # local import so this block does not rely on the file's import list

    bound = r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?"
    name = re.sub(rf"^\s*{bound}\s*<=?\s*", "", str(rule))
    name = re.sub(rf"\s*(?:<=|>=|<|>|=)\s*{bound}\s*$", "", name)
    return name.strip()


def generate_plain_summary(label, confidence, spam_proba, lime_explanation,
                           shap_sv, shap_top_idx):
    """Build a rule-based plain English summary from XAI results.

    Args:
        label: Predicted label text; any label containing "SPAM" selects the
            spam wording of the closing sentence.
        confidence: Probability of the predicted label, rendered as a
            percentage and used to pick the closing sentence (> 0.9, > 0.7,
            otherwise).
        spam_proba: Accepted for interface compatibility; not used here.
        lime_explanation: Object exposing ``as_list()`` -> [(rule, weight)],
            or None to omit the LIME section.
        shap_sv: Per-metadata-feature SHAP values indexed like
            META_FEATURE_NAMES, or None.
        shap_top_idx: Metadata feature indices ordered by importance, or None.

    Returns:
        A markdown string.
    """
    summary = f"### Classification: **{label}** ({confidence:.0%} confidence)\n\n"

    # Fetch the LIME list once; it is needed for both the listing and the
    # agreement check below.
    lime_items = lime_explanation.as_list() if lime_explanation is not None else None

    if lime_items is not None:
        summary += "**Key words driving this decision (LIME):**\n"
        for feat_rule, weight in lime_items[:3]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            summary += f"- **{feat_rule}** — {direction}\n"
        summary += "\n"

    if shap_sv is not None and shap_top_idx is not None:
        summary += "**Important email characteristics (SHAP):**\n"
        for i in shap_top_idx[:2]:
            feat_name = META_FEATURE_NAMES[i]
            description = FEATURE_DESCRIPTIONS.get(feat_name, feat_name)
            direction = "spam signal" if shap_sv[i] > 0 else "ham signal"
            summary += f"- **{feat_name}** ({description}) — {direction}\n"
        summary += "\n"

    if lime_items is not None and shap_top_idx is not None:
        # Compare on bare feature names: a LIME rule like "feat > 0.50" would
        # never equal the plain SHAP name "feat", leaving the overlap empty.
        lime_top = {_lime_rule_feature_name(rule) for rule, _ in lime_items[:10]}
        shap_top = {META_FEATURE_NAMES[i] for i in shap_top_idx[:10]}
        overlap = lime_top & shap_top
        if overlap:
            summary += f"**Method agreement:** LIME and SHAP both flag: {', '.join(sorted(overlap))}\n\n"

    if "SPAM" in label:
        if confidence > 0.9:
            summary += "The model is highly confident this email contains patterns commonly seen in spam or phishing attempts."
        elif confidence > 0.7:
            summary += "The model found several spam-like patterns in this email."
        else:
            summary += "The model leans toward spam, but the evidence is not overwhelming. Use your judgment."
    else:
        if confidence > 0.9:
            summary += "The model is highly confident this is a legitimate email."
        elif confidence > 0.7:
            summary += "The model found this email to be mostly consistent with legitimate messages."
        else:
            summary += "The model leans toward legitimate, but there are some spam-like features. Review carefully."

    return summary
|
| 228 |
+
|
| 229 |
+
# ---------------------------------------------------------------------------
|
| 230 |
+
# 8. Side-by-side comparison
|
| 231 |
+
# ---------------------------------------------------------------------------
|
| 232 |
+
|
| 233 |
+
def generate_comparison(lime_explanation, shap_sv, shap_top_idx, eli5_names):
    """Build a markdown comparison of the top features from each XAI method.

    Args:
        lime_explanation: LIME explanation object exposing ``as_list()`` (pairs
            of ``(rule_text, weight)``), or None when LIME is unavailable.
        shap_sv: Per-metadata-feature SHAP values indexable by the entries of
            ``shap_top_idx``, or None when SHAP is unavailable.
        shap_top_idx: Indices into ``META_FEATURE_NAMES`` ordered by importance,
            or None when SHAP is unavailable.
        eli5_names: List of top ELI5 feature names, or None.

    Returns:
        A markdown string: a 5-row ranking table, an optional LIME-SHAP
        agreement line, and a closing note about what each method covers.
    """
    lime_pairs = lime_explanation.as_list() if lime_explanation is not None else []

    lime_top5 = [
        f"{feat} ({'spam' if w > 0 else 'ham'}, {w:+.3f})"
        for feat, w in lime_pairs[:5]
    ]

    shap_top5 = []
    if shap_sv is not None and shap_top_idx is not None:
        for i in shap_top_idx[:5]:
            direction = "spam" if shap_sv[i] > 0 else "ham"
            shap_top5.append(f"{META_FEATURE_NAMES[i]} ({direction}, {shap_sv[i]:+.3f})")

    eli5_top5 = (eli5_names or [])[:5]

    def _cell(column, rank):
        # Methods may yield fewer than five features; pad with an em dash.
        return column[rank] if rank < len(column) else "—"

    md = "### Side-by-Side: Top Features by Method\n\n"
    md += "| Rank | LIME | SHAP (metadata) | ELI5 |\n"
    md += "|------|------|-----------------|------|\n"
    for rank in range(5):
        md += (f"| {rank+1} | {_cell(lime_top5, rank)} | {_cell(shap_top5, rank)} "
               f"| {_cell(eli5_top5, rank)} |\n")

    if lime_explanation is not None and shap_top_idx is not None:
        # LIME reports discretized rules such as "feat > 0.50" or
        # "0.10 < feat <= 0.50", never the bare feature name, so an exact set
        # intersection with the metadata names would always be empty. A
        # metadata feature counts as shared when its name appears as a whole
        # space-delimited token inside one of LIME's top-10 rules.
        padded_rules = [f" {rule} " for rule, _ in lime_pairs[:10]]
        shap_names = {META_FEATURE_NAMES[i] for i in shap_top_idx[:10]}
        overlap = sorted(
            name for name in shap_names
            if any(f" {name} " in rule for rule in padded_rules)
        )
        md += f"\n**LIME-SHAP agreement** (top 10): **{len(overlap)}** shared features"
        if overlap:
            md += f"\nShared: {', '.join(overlap)}"

    md += "\n\n*Note: LIME covers all features (words + metadata), SHAP covers only the 24 metadata features, "
    md += "ELI5 uses the Random Forest sub-estimator's internal weights.*"

    return md
|
| 271 |
+
|
| 272 |
+
# ---------------------------------------------------------------------------
|
| 273 |
+
# 9. Feedback logging
|
| 274 |
+
# ---------------------------------------------------------------------------
|
| 275 |
+
|
| 276 |
+
def log_feedback(email_text, predicted_label, predicted_confidence, threshold,
                 feedback_type, correct_label=None):
    """Append one feedback row to the CSV log and return the correction count.

    Args:
        email_text: Text of the classified email; only the first 500
            characters are stored.
        predicted_label: Label the model produced.
        predicted_confidence: Confidence of that label (written with 4 decimals).
        threshold: Spam threshold in effect (written with 4 decimals).
        feedback_type: Feedback marker stored in the ``feedback`` column
            (the handlers in this file pass ``'correct'`` or ``'wrong'``).
        correct_label: User-supplied correct label; stored as an empty string
            when not given.

    Returns:
        The value of ``count_corrections()`` after the row has been written.
    """
    # The header is needed for a brand-new file and also for a file that was
    # pre-created empty: the log is read back with csv.DictReader, which would
    # otherwise treat the first data row as the column names.
    write_header = not FEEDBACK_CSV.exists() or FEEDBACK_CSV.stat().st_size == 0
    with open(FEEDBACK_CSV, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['timestamp', 'email_text', 'predicted_label',
                             'predicted_confidence', 'feedback', 'correct_label',
                             'threshold_used'])
        writer.writerow([
            datetime.now().isoformat(),
            email_text[:500],
            predicted_label,
            f"{predicted_confidence:.4f}",
            feedback_type,
            correct_label or '',
            f"{threshold:.4f}",
        ])
    return count_corrections()
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def count_corrections():
    """Return how many rows of the feedback log are marked as 'wrong'.

    Returns 0 when the log file does not exist yet.
    """
    if not FEEDBACK_CSV.exists():
        return 0
    with open(FEEDBACK_CSV, 'r', encoding='utf-8') as log_file:
        # DictReader keys each row by the header line, so the marker is looked
        # up by column name rather than by position.
        return sum(
            1
            for entry in csv.DictReader(log_file)
            if entry.get('feedback') == 'wrong'
        )
|
| 309 |
+
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
# 10. Example Emails
|
| 312 |
# ---------------------------------------------------------------------------
|
|
|
|
| 313 |
|
| 314 |
EXAMPLE_EMAILS = [
|
| 315 |
["Subject: URGENT - You Have Won $5,000,000!!!\n\nDear Friend,\n\nCONGRATULATIONS!!! You have been selected as the winner of our international lottery program!!!\nTo claim your $5,000,000 USD prize, click the link below IMMEDIATELY and provide your bank details.\n\nACT NOW - This offer expires in 24 hours!!!\n\nClick here: http://totally-legit-prize.com/claim\nSend $500 processing fee to unlock your winnings.\n\nBest regards,\nDr. Prince Mohammed"],
|
| 316 |
["Subject: Team sync Thursday 2pm\n\nHi everyone,\n\nJust a reminder that we have our weekly team sync this Thursday at 2pm in Conference Room B.\n\nAgenda:\n- Sprint review\n- Q2 planning discussion\n- New hire onboarding update\n\nPlease come prepared with your status updates.\n\nThanks,\nSarah"],
|
| 317 |
["Subject: Your account has been compromised!\n\nDear Customer,\n\nWe detected suspicious activity on your account. Click here immediately to verify your identity: http://secure-bank-login.com/verify\n\nIf you do not verify within 24 hours, your account will be permanently locked.\n\nSecurity Team"],
|
| 318 |
["Subject: Thanksgiving dinner plans\n\nHi everyone!\n\nI wanted to start planning for Thanksgiving dinner. I'm thinking we could do it at my place this year. What does everyone think about 4pm?\n\nLet me know if you have any dietary restrictions or if you want to bring a dish.\n\nLove,\nMom"],
|
| 319 |
+
["Subject: Best prices on V1AGRA and C1ALIS!!!\n\n$$$ SAVE BIG $$$\nBuy now and get 80% OFF!!!\nNo prescription needed! Free shipping!\nOrder at http://cheap-pharma-deals.com\n\nLIMITED TIME OFFER - ACT NOW!"],
|
| 320 |
]
|
| 321 |
|
|
|
|
| 322 |
# ---------------------------------------------------------------------------
|
| 323 |
+
# 11. Main orchestration function
|
| 324 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
| 325 |
|
| 326 |
+
def classify_and_explain(email_text, uploaded_file, threshold):
    """Main function called by Gradio. Returns all outputs for all tabs + feedback state.

    Args:
        email_text: Text typed into the email box (may be None or blank).
        uploaded_file: Path of an uploaded text file, or None. When given, its
            contents replace ``email_text``.
        threshold: Spam-probability threshold passed to ``classify_email``.

    Returns:
        A 7-tuple matching the Gradio outputs, in order: result markdown,
        LIME figure, SHAP figure, ELI5 HTML, comparison markdown, summary
        markdown, and the ``|||``-delimited state string consumed by the
        feedback handlers. On invalid input the first element carries the
        message and the state string is empty.
    """
    if uploaded_file is not None:
        try:
            email_text = Path(uploaded_file).read_text(encoding='utf-8')
        except Exception as e:
            # Best-effort at the UI boundary, but leave a trace for debugging
            # instead of swallowing the cause silently.
            print(f"File read error: {e}")
            return ("Could not read file.", None, None, "Error reading file.", "", "", "")

    if email_text is None or email_text.strip() == '':
        return ("Please enter email text or upload a file.", None, None, "", "", "", "")

    if voting_model is None:
        return ("Models not found. Run `python3 train.py` first.", None, None, "", "", "", "")

    label, confidence, spam_proba, combined = classify_email(email_text, threshold)

    # Each explainer is optional: a failure in one must not take down the
    # classification result or the other explanations.
    try:
        lime_fig, lime_exp = generate_lime_explanation(combined)
    except Exception as e:
        print(f"LIME error: {e}")
        lime_fig, lime_exp = None, None

    try:
        shap_fig, shap_sv, shap_top_idx = generate_shap_explanation(email_text)
    except Exception as e:
        print(f"SHAP error: {e}")
        shap_fig, shap_sv, shap_top_idx = None, None, None

    try:
        eli5_html, eli5_names = generate_eli5_explanation(combined)
    except Exception as e:
        print(f"ELI5 error: {e}")
        eli5_html, eli5_names = None, None

    result_md = f"## {'SPAM' if 'SPAM' in label else 'HAM (Not Spam)'}\n\n"
    result_md += f"**Confidence:** {confidence:.1%}\n\n"
    result_md += f"**Threshold:** {threshold:.0%}\n\n"
    result_md += f"**Spam probability:** {spam_proba:.1%}\n\n"
    if lime_exp is not None:
        result_md += "**Key factors:**\n"
        for feat_rule, weight in lime_exp.as_list()[:5]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            result_md += f"- **{feat_rule}** {direction}\n"

    comparison_md = generate_comparison(lime_exp, shap_sv, shap_top_idx, eli5_names)
    summary_md = generate_plain_summary(label, confidence, spam_proba, lime_exp, shap_sv, shap_top_idx)
    eli5_display = eli5_html or "<p>ELI5 explanation not available.</p>"

    # State for the feedback buttons: label, confidence, threshold and the
    # first 500 characters of the email, joined with '|||'.
    return (result_md, lime_fig, shap_fig, eli5_display, comparison_md, summary_md,
            f"{label}|||{confidence:.4f}|||{threshold:.4f}|||{email_text[:500]}")
|
| 377 |
|
| 378 |
+
# ---------------------------------------------------------------------------
|
| 379 |
+
# 12. Feedback handlers
|
| 380 |
+
# ---------------------------------------------------------------------------
|
| 381 |
|
| 382 |
+
_NO_FEEDBACK_MSG = "No classification to give feedback on."


def _parse_feedback_state(hidden_state):
    """Decode the '|||'-delimited state string produced after a classification.

    Returns ``(label, confidence, threshold, email)`` or None when the state
    is empty or malformed.
    """
    if not hidden_state:
        return None
    # maxsplit=3: the email text is the last field and may itself contain
    # '|||'; an unbounded split would cut the logged email at that point.
    parts = hidden_state.split('|||', 3)
    if len(parts) < 4:
        return None
    label, conf, thresh, email = parts
    try:
        return label, float(conf), float(thresh), email
    except ValueError:
        # Non-numeric confidence/threshold: treat as "nothing to log" rather
        # than raising into the UI.
        return None


def handle_correct(hidden_state):
    """Log positive feedback for the last classification.

    Args:
        hidden_state: State string from the last classification, or empty.

    Returns:
        A status message for the feedback area.
    """
    parsed = _parse_feedback_state(hidden_state)
    if parsed is None:
        return _NO_FEEDBACK_MSG
    label, conf, thresh, email = parsed
    corrections = log_feedback(email, label, conf, thresh, 'correct')
    return f"Thanks for the feedback! ({corrections} corrections collected so far)"


def handle_wrong(hidden_state, correct_label):
    """Log negative feedback together with the user's correction.

    Args:
        hidden_state: State string from the last classification, or empty.
        correct_label: Label chosen by the user in the correction dropdown.

    Returns:
        A status message for the feedback area.
    """
    parsed = _parse_feedback_state(hidden_state)
    if parsed is None:
        return _NO_FEEDBACK_MSG
    label, conf, thresh, email = parsed
    corrections = log_feedback(email, label, conf, thresh, 'wrong', correct_label)
    return f"Correction logged! ({corrections} corrections collected so far)"
|
| 404 |
|
| 405 |
# ---------------------------------------------------------------------------
|
| 406 |
+
# 13. Gradio Blocks UI
|
| 407 |
# ---------------------------------------------------------------------------
|
| 408 |
+
|
| 409 |
+
HOW_IT_WORKS_MD = """
|
| 410 |
+
## How This App Works
|
| 411 |
+
|
| 412 |
+
### What is spam classification?
|
| 413 |
+
Spam classification automatically identifies unwanted or malicious emails (spam) vs. legitimate messages (ham). This helps protect users from phishing scams, fraudulent offers, and unwanted advertising.
|
| 414 |
+
|
| 415 |
+
### The Model
|
| 416 |
+
This app uses a **Voting Ensemble** — three different machine learning models that each "vote" on whether an email is spam:
|
| 417 |
+
- **Random Forest** — builds many decision trees and takes the majority vote
|
| 418 |
+
- **Logistic Regression** — finds a mathematical boundary between spam and ham
|
| 419 |
+
- **Support Vector Machine (SVM)** — finds the widest possible margin between classes
|
| 420 |
+
|
| 421 |
+
By combining all three, the ensemble is more accurate than any single model alone.
|
| 422 |
+
|
| 423 |
+
### Feature Extraction
|
| 424 |
+
The model looks at two types of features:
|
| 425 |
+
- **TF-IDF (Term Frequency-Inverse Document Frequency)** — measures how important each word is. Common spam words like "prize" or "click" get high scores.
|
| 426 |
+
- **24 Metadata Features** — structural patterns like exclamation mark density, dollar sign count, ALL CAPS ratio, URL count, and more.
|
| 427 |
+
|
| 428 |
+
### Explainable AI (XAI) Methods
|
| 429 |
+
This app doesn't just classify — it explains **why**:
|
| 430 |
+
|
| 431 |
+
- **LIME** — Removes words one at a time and watches how the prediction changes. Shows which words matter most.
|
| 432 |
+
- **SHAP** — Uses game theory to calculate each feature's "fair share" of the prediction. Based on Nobel Prize-winning mathematics.
|
| 433 |
+
- **ELI5** — Looks directly at the model's internal weights to show which features it relies on most.
|
| 434 |
+
|
| 435 |
+
### Feedback & Retraining
|
| 436 |
+
When you click "Correct" or "Wrong", your feedback is saved. After enough corrections accumulate, the model can be retrained with the new examples to improve over time. This is called **human-in-the-loop machine learning**.
|
| 437 |
+
|
| 438 |
+
### Disclaimer
|
| 439 |
+
This model was created as a university course project. It is intended for **educational and research purposes only** and should not be used as a sole spam filter in production. Always use established email security tools for real-world spam filtering.
|
| 440 |
+
"""
|
| 441 |
|
| 442 |
with gr.Blocks(title="Spam Email Classifier with XAI") as demo:
|
| 443 |
gr.Markdown("# Spam Email Classifier with XAI Explanations")
|
| 444 |
+
gr.Markdown("Classify emails and understand **why** using LIME, SHAP, ELI5, "
|
| 445 |
+
"and plain English summaries. Created as a university course project.")
|
| 446 |
+
|
| 447 |
+
hidden_state = gr.State("")
|
| 448 |
|
| 449 |
with gr.Row():
|
|
|
|
| 450 |
with gr.Column(scale=1):
|
| 451 |
email_input = gr.Textbox(
|
| 452 |
label="Email Text",
|
|
|
|
| 458 |
label="Or upload a .txt file",
|
| 459 |
file_types=['.txt'],
|
| 460 |
)
|
| 461 |
+
threshold_slider = gr.Slider(
|
| 462 |
+
minimum=0.0, maximum=1.0, step=0.05,
|
| 463 |
+
value=optimal_threshold if optimal_threshold else 0.5,
|
| 464 |
+
label="Classification Threshold",
|
| 465 |
+
info="Emails with spam probability above this are classified as spam.",
|
| 466 |
+
)
|
| 467 |
classify_btn = gr.Button("Classify", variant="primary")
|
| 468 |
gr.Examples(
|
| 469 |
examples=EXAMPLE_EMAILS,
|
|
|
|
| 472 |
cache_examples=False,
|
| 473 |
)
|
| 474 |
|
|
|
|
| 475 |
with gr.Column(scale=1):
|
| 476 |
with gr.Tabs():
|
| 477 |
with gr.Tab("Result"):
|
| 478 |
result_output = gr.Markdown(label="Classification Result")
|
| 479 |
with gr.Tab("LIME"):
|
| 480 |
+
gr.Markdown("*LIME perturbs the input and fits a local model "
|
| 481 |
+
"to see which features matter most.*")
|
| 482 |
lime_output = gr.Plot(label="LIME Explanation")
|
| 483 |
with gr.Tab("SHAP"):
|
| 484 |
+
gr.Markdown("*SHAP uses game theory to assign each feature "
|
| 485 |
+
"a contribution value.*")
|
| 486 |
shap_output = gr.Plot(label="SHAP Explanation")
|
| 487 |
+
with gr.Tab("ELI5"):
|
| 488 |
+
gr.Markdown("*ELI5 shows feature weights directly from the "
|
| 489 |
+
"model's internals.*")
|
| 490 |
+
eli5_output = gr.HTML(label="ELI5 Explanation")
|
| 491 |
+
with gr.Tab("Compare"):
|
| 492 |
+
compare_output = gr.Markdown(label="Method Comparison")
|
| 493 |
+
with gr.Tab("Summary"):
|
| 494 |
+
summary_output = gr.Markdown(label="Plain English Summary")
|
| 495 |
+
with gr.Tab("How It Works"):
|
| 496 |
+
gr.Markdown(HOW_IT_WORKS_MD)
|
| 497 |
+
|
| 498 |
+
gr.Markdown("---")
|
| 499 |
+
gr.Markdown("**Was this classification correct?**")
|
| 500 |
+
feedback_msg = gr.Markdown("")
|
| 501 |
+
with gr.Row():
|
| 502 |
+
correct_btn = gr.Button("Correct", variant="secondary")
|
| 503 |
+
wrong_btn = gr.Button("Wrong", variant="stop")
|
| 504 |
+
correction_dropdown = gr.Dropdown(
|
| 505 |
+
choices=["Spam", "Ham"],
|
| 506 |
+
label="What should it be?",
|
| 507 |
+
visible=True,
|
| 508 |
+
)
|
| 509 |
|
|
|
|
| 510 |
classify_btn.click(
|
| 511 |
fn=classify_and_explain,
|
| 512 |
+
inputs=[email_input, file_input, threshold_slider],
|
| 513 |
+
outputs=[result_output, lime_output, shap_output, eli5_output,
|
| 514 |
+
compare_output, summary_output, hidden_state],
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
correct_btn.click(
|
| 518 |
+
fn=handle_correct,
|
| 519 |
+
inputs=[hidden_state],
|
| 520 |
+
outputs=[feedback_msg],
|
| 521 |
+
)
|
| 522 |
+
wrong_btn.click(
|
| 523 |
+
fn=handle_wrong,
|
| 524 |
+
inputs=[hidden_state, correction_dropdown],
|
| 525 |
+
outputs=[feedback_msg],
|
| 526 |
)
|
| 527 |
|
| 528 |
# ---------------------------------------------------------------------------
|
| 529 |
+
# 14. Launch
|
| 530 |
# ---------------------------------------------------------------------------
|
| 531 |
|
| 532 |
if __name__ == '__main__':
|
docs/superpowers/plans/2026-03-28-gradio-xai-merge.md
ADDED
|
@@ -0,0 +1,1231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gradio XAI Merge Implementation Plan
|
| 2 |
+
|
| 3 |
+
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
| 4 |
+
|
| 5 |
+
**Goal:** Merge ELI5, comparison, plain English summary, How It Works, and feedback/retrain features into the existing Gradio spam classifier Space.
|
| 6 |
+
|
| 7 |
+
**Architecture:** Extend the existing `app.py` (Gradio Blocks layout) by adding 4 new tabs, a feedback row, and a retrain script. The VotingClassifier's RF sub-estimator is extracted at startup for ELI5 and TreeExplainer. Feedback logs to CSV; retrain reads it offline.
|
| 8 |
+
|
| 9 |
+
**Tech Stack:** Python, Gradio 4.44+, scikit-learn, LIME, SHAP, ELI5, matplotlib, NLTK
|
| 10 |
+
|
| 11 |
+
**Spec:** `docs/superpowers/specs/2026-03-28-gradio-xai-merge-design.md`
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## File Structure
|
| 16 |
+
|
| 17 |
+
| File | Action | Responsibility |
|
| 18 |
+
|------|--------|---------------|
|
| 19 |
+
| `requirements.txt` | Modify | Add eli5, bump gradio |
|
| 20 |
+
| `app.py` | Major rewrite | All 7 tabs, feedback UI, threshold slider, classification orchestration |
|
| 21 |
+
| `retrain.py` | Create | Batch retrain script that reads feedback CSV and augments training data |
|
| 22 |
+
| `feedback/.gitkeep` | Create | Empty directory for feedback log accumulation |
|
| 23 |
+
| `README.md` | Modify | Updated frontmatter + feature description |
|
| 24 |
+
| `utils.py` | No changes | Shared preprocessing (stays as-is) |
|
| 25 |
+
| `test_utils.py` | No changes | Existing tests (stays as-is) |
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
### Task 1: Update dependencies
|
| 30 |
+
|
| 31 |
+
**Files:**
|
| 32 |
+
- Modify: `requirements.txt`
|
| 33 |
+
|
| 34 |
+
- [ ] **Step 1: Update requirements.txt**
|
| 35 |
+
|
| 36 |
+
Replace the full contents of `requirements.txt` with:
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
numpy>=1.24.0
|
| 40 |
+
pandas>=2.0.0
|
| 41 |
+
matplotlib>=3.7.0
|
| 42 |
+
scikit-learn>=1.3.0
|
| 43 |
+
scipy>=1.11.0
|
| 44 |
+
nltk>=3.8.0
|
| 45 |
+
lime>=0.2.0
|
| 46 |
+
shap>=0.44.0
|
| 47 |
+
eli5>=0.13.0
|
| 48 |
+
gradio>=4.44.0
|
| 49 |
+
joblib>=1.3.0
|
| 50 |
+
tqdm>=4.65.0
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
- [ ] **Step 2: Create feedback directory**
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
mkdir -p feedback
|
| 57 |
+
touch feedback/.gitkeep
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
- [ ] **Step 3: Install new dependency locally**
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
cd spam-classifier-gradio
|
| 64 |
+
source venv/bin/activate
|
| 65 |
+
pip install "eli5>=0.13.0" "gradio>=4.44.0"
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
- [ ] **Step 4: Verify eli5 imports**
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
python3 -c "import eli5; print('eli5', eli5.__version__)"
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Expected: prints eli5 version without error.
|
| 75 |
+
|
| 76 |
+
- [ ] **Step 5: Commit**
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
git add requirements.txt feedback/.gitkeep
|
| 80 |
+
git commit -m "Add eli5 dependency, bump gradio, create feedback directory"
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
### Task 2: Rewrite app.py — model loading and classification core
|
| 86 |
+
|
| 87 |
+
This task replaces the existing `app.py` with the new version. We build it in stages. This task covers the imports, model loading, and the `classify_email` function (unchanged logic, but we also extract the RF sub-estimator for ELI5).
|
| 88 |
+
|
| 89 |
+
**Files:**
|
| 90 |
+
- Modify: `app.py` (rewrite from line 1)
|
| 91 |
+
|
| 92 |
+
- [ ] **Step 1: Write the new imports and model loading section**
|
| 93 |
+
|
| 94 |
+
Replace everything in `app.py` from line 1 through the end of the model loading section (lines 1-59) with:
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
# Gradio web app for the Spam Email Classifier with XAI explanations
|
| 98 |
+
# University course project — Explainable AI for spam detection
|
| 99 |
+
# Features: LIME, SHAP, ELI5, side-by-side comparison, plain English summary,
|
| 100 |
+
# user feedback logging, and batch retrain support.
|
| 101 |
+
|
| 102 |
+
import csv
|
| 103 |
+
import os
|
| 104 |
+
from datetime import datetime
|
| 105 |
+
from pathlib import Path
|
| 106 |
+
|
| 107 |
+
import eli5
|
| 108 |
+
import gradio as gr
|
| 109 |
+
import lime
|
| 110 |
+
import lime.lime_tabular
|
| 111 |
+
import matplotlib
|
| 112 |
+
matplotlib.use('Agg') # non-interactive backend so plots work on servers
|
| 113 |
+
import matplotlib.pyplot as plt
|
| 114 |
+
import numpy as np
|
| 115 |
+
import shap
|
| 116 |
+
import joblib
|
| 117 |
+
from scipy.sparse import hstack, csr_matrix
|
| 118 |
+
|
| 119 |
+
from utils import (preprocess_text, compute_metadata_features,
|
| 120 |
+
META_FEATURE_NAMES, FEATURE_DESCRIPTIONS)
|
| 121 |
+
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
# 1. Model Loading
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
models_dir = Path(__file__).parent / 'models'
|
| 127 |
+
feedback_dir = Path(__file__).parent / 'feedback'
|
| 128 |
+
feedback_dir.mkdir(exist_ok=True)
|
| 129 |
+
FEEDBACK_CSV = feedback_dir / 'feedback_log.csv'
|
| 130 |
+
|
| 131 |
+
try:
|
| 132 |
+
voting_model = joblib.load(models_dir / 'voting_model.joblib')
|
| 133 |
+
tfidf_vectorizer = joblib.load(models_dir / 'tfidf_vectorizer.joblib')
|
| 134 |
+
meta_scaler = joblib.load(models_dir / 'meta_scaler.joblib')
|
| 135 |
+
feature_names = joblib.load(models_dir / 'feature_names.joblib')
|
| 136 |
+
optimal_threshold = joblib.load(models_dir / 'optimal_threshold.joblib')
|
| 137 |
+
training_sample = joblib.load(models_dir / 'training_sample.joblib')
|
| 138 |
+
# Extract the Random Forest from the voting ensemble for ELI5 and SHAP TreeExplainer
|
| 139 |
+
raw_rf = voting_model.named_estimators_['rf']
|
| 140 |
+
print(f"All models loaded. Threshold = {optimal_threshold:.4f}")
|
| 141 |
+
except FileNotFoundError as e:
|
| 142 |
+
print(f"Model file not found: {e}")
|
| 143 |
+
voting_model = None
|
| 144 |
+
tfidf_vectorizer = None
|
| 145 |
+
meta_scaler = None
|
| 146 |
+
feature_names = None
|
| 147 |
+
optimal_threshold = None
|
| 148 |
+
training_sample = None
|
| 149 |
+
raw_rf = None
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
- [ ] **Step 2: Write the LIME explainer setup and classify_email function**
|
| 153 |
+
|
| 154 |
+
Append after the model loading section:
|
| 155 |
+
|
| 156 |
+
```python
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
# 2. LIME Explainer Setup
|
| 159 |
+
# ---------------------------------------------------------------------------
|
| 160 |
+
|
| 161 |
+
lime_explainer = None
|
| 162 |
+
if training_sample is not None and feature_names is not None:
|
| 163 |
+
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 164 |
+
training_data=training_sample,
|
| 165 |
+
feature_names=feature_names,
|
| 166 |
+
class_names=['Ham', 'Spam'],
|
| 167 |
+
mode='classification',
|
| 168 |
+
)
|
| 169 |
+
print("LIME explainer ready.")
|
| 170 |
+
|
| 171 |
+
# ---------------------------------------------------------------------------
|
| 172 |
+
# 3. classify_email — core prediction logic
|
| 173 |
+
# ---------------------------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
def classify_email(email_text, threshold):
|
| 176 |
+
"""Classify a single email. Returns (label, confidence, spam_proba, combined_features)."""
|
| 177 |
+
cleaned_text = preprocess_text(email_text)
|
| 178 |
+
tfidf_features = tfidf_vectorizer.transform([cleaned_text])
|
| 179 |
+
meta_raw = compute_metadata_features([email_text])
|
| 180 |
+
meta_scaled = meta_scaler.transform(meta_raw)
|
| 181 |
+
combined = hstack([tfidf_features, csr_matrix(meta_scaled)])
|
| 182 |
+
spam_proba = voting_model.predict_proba(combined)[0][1]
|
| 183 |
+
|
| 184 |
+
if spam_proba >= threshold:
|
| 185 |
+
label = "SPAM"
|
| 186 |
+
confidence = spam_proba
|
| 187 |
+
else:
|
| 188 |
+
label = "HAM (Not Spam)"
|
| 189 |
+
confidence = 1.0 - spam_proba
|
| 190 |
+
|
| 191 |
+
return label, confidence, spam_proba, combined
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
- [ ] **Step 3: Verify the module loads without errors**
|
| 195 |
+
|
| 196 |
+
```bash
|
| 197 |
+
python3 -c "
|
| 198 |
+
import app
|
| 199 |
+
print('voting_model:', type(app.voting_model))
|
| 200 |
+
print('raw_rf:', type(app.raw_rf))
|
| 201 |
+
print('classify_email:', callable(app.classify_email))
|
| 202 |
+
"
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
Expected: prints types without errors. `raw_rf` should be `RandomForestClassifier`.
|
| 206 |
+
|
| 207 |
+
- [ ] **Step 4: Commit**
|
| 208 |
+
|
| 209 |
+
```bash
|
| 210 |
+
git add app.py
|
| 211 |
+
git commit -m "Rewrite app.py core: imports, model loading with RF extraction, classify function"
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
---
|
| 215 |
+
|
| 216 |
+
### Task 3: Add XAI explanation functions (LIME, SHAP, ELI5)
|
| 217 |
+
|
| 218 |
+
**Files:**
|
| 219 |
+
- Modify: `app.py` (append after classify_email)
|
| 220 |
+
|
| 221 |
+
- [ ] **Step 1: Add LIME plot function**
|
| 222 |
+
|
| 223 |
+
Append to `app.py`:
|
| 224 |
+
|
| 225 |
+
```python
|
| 226 |
+
# ---------------------------------------------------------------------------
|
| 227 |
+
# 4. LIME explanation
|
| 228 |
+
# ---------------------------------------------------------------------------
|
| 229 |
+
|
| 230 |
+
def generate_lime_explanation(combined_features):
    """Explain one prediction with LIME.

    Args:
        combined_features: Single-row sparse feature matrix, as returned by
            ``classify_email``.

    Returns:
        ``(figure, explanation)``, or ``(None, None)`` when no LIME explainer
        was built at import time.
    """
    if lime_explainer is None:
        return None, None

    # The explainer is handed a dense 1-D row, so densify the single row.
    instance = combined_features.toarray()[0]
    explanation = lime_explainer.explain_instance(
        instance,
        voting_model.predict_proba,
        num_features=10,
    )
    fig = explanation.as_pyplot_figure()
    fig.tight_layout()
    return fig, explanation
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
- [ ] **Step 2: Add SHAP plot function**
|
| 246 |
+
|
| 247 |
+
Append to `app.py`:
|
| 248 |
+
|
| 249 |
+
```python
|
| 250 |
+
# ---------------------------------------------------------------------------
|
| 251 |
+
# 5. SHAP explanation (metadata features only — fast)
|
| 252 |
+
# ---------------------------------------------------------------------------
|
| 253 |
+
|
| 254 |
+
def generate_shap_explanation(email_text):
    """Build a SHAP bar chart over the metadata features only.

    The TF-IDF columns are zero-filled inside the prediction wrapper, so the
    explainer only varies the trailing metadata columns.

    Args:
        email_text: Raw email text.

    Returns:
        ``(figure, shap_values, top_indices)`` where ``shap_values`` is a 1-D
        array with one value per metadata feature for the spam class and
        ``top_indices`` holds up to ten feature indices ordered by decreasing
        absolute SHAP value. Returns ``(None, None, None)`` when the training
        sample or the model is unavailable.
    """
    if training_sample is None or voting_model is None:
        return None, None, None

    num_meta = len(META_FEATURE_NAMES)
    num_tfidf = training_sample.shape[1] - num_meta

    # Metadata occupies the trailing columns; 50 background rows are used.
    background_meta = training_sample[:50, -num_meta:]

    meta_raw = compute_metadata_features([email_text])
    meta_scaled = meta_scaler.transform(meta_raw)

    def predict_with_meta_only(meta_features):
        """Predict with all TF-IDF columns set to zero."""
        n_samples = meta_features.shape[0]
        tfidf_zeros = csr_matrix((n_samples, num_tfidf))
        combined = hstack([tfidf_zeros, csr_matrix(meta_features)])
        return voting_model.predict_proba(combined)

    explainer = shap.KernelExplainer(predict_with_meta_only, background_meta)
    shap_values = explainer.shap_values(meta_scaled, nsamples=100)

    # Pick the spam class (index 1) explicitly. Depending on the installed
    # SHAP release the result is either a list with one array per class or a
    # single array with a trailing class axis. Flattening the latter and
    # slicing would interleave ham and spam values.
    if isinstance(shap_values, list):
        spam_values = np.asarray(shap_values[1])
    else:
        spam_values = np.asarray(shap_values)
        if spam_values.ndim == 3:
            spam_values = spam_values[..., 1]
    sv = spam_values.flatten()

    # Defensive guard kept from the original: keep only the metadata tail.
    if len(sv) > num_meta:
        sv = sv[-num_meta:]

    # Ascending order by magnitude; reversed for the "top" list, as-is for
    # the horizontal bar chart (largest bar ends up at the top).
    sorted_indices = np.argsort(np.abs(sv))
    top_idx = sorted_indices[::-1][:10]

    sorted_names = [META_FEATURE_NAMES[idx] for idx in sorted_indices.tolist()]
    sorted_values = sv[sorted_indices]

    fig, ax = plt.subplots(figsize=(8, 6))
    colors = ['#d62728' if val > 0 else '#1f77b4' for val in sorted_values]
    ax.barh(sorted_names, sorted_values, color=colors)
    ax.set_xlabel('SHAP Value (impact on spam probability)')
    ax.set_title('SHAP Feature Importance (Metadata Features)')
    ax.axvline(x=0, color='black', linewidth=0.5)
    fig.tight_layout()

    return fig, sv, top_idx
|
| 301 |
+
```
|
| 302 |
+
|
| 303 |
+
- [ ] **Step 3: Add ELI5 explanation function**
|
| 304 |
+
|
| 305 |
+
Append to `app.py`:
|
| 306 |
+
|
| 307 |
+
```python
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
# 6. ELI5 explanation
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
|
| 312 |
+
def generate_eli5_explanation(combined_features):
    """Explain one prediction with ELI5 using the Random Forest estimator.

    Args:
        combined_features: Single-row sparse feature matrix, as returned by
            ``classify_email``.

    Returns:
        ``(html_string, feature_names_list)``, or ``(None, None)`` when the
        Random Forest or the feature names are unavailable. The name list
        holds the positive-weight features followed by the negative-weight
        ones from a ``top=5`` explanation.
    """
    if raw_rf is None or feature_names is None:
        return None, None

    instance = combined_features.toarray()[0]

    # Full HTML rendering (10 features)
    eli5_exp = eli5.explain_prediction(
        raw_rf, instance, feature_names=feature_names, top=10)
    html = eli5.format_as_html(eli5_exp)

    # Separate, smaller explanation whose names feed the Compare tab
    eli5_top5 = eli5.explain_prediction(
        raw_rf, instance, feature_names=feature_names, top=5)
    top_names = []
    if hasattr(eli5_top5, 'targets') and eli5_top5.targets:
        weights = eli5_top5.targets[0].feature_weights
        top_names.extend(fw.feature for fw in weights.pos[:5])
        top_names.extend(fw.feature for fw in weights.neg[:5])

    return html, top_names
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
- [ ] **Step 4: Commit**
|
| 336 |
+
|
| 337 |
+
```bash
|
| 338 |
+
git add app.py
|
| 339 |
+
git commit -m "Add LIME, SHAP, and ELI5 explanation functions"
|
| 340 |
+
```
|
| 341 |
+
|
| 342 |
+
---
|
| 343 |
+
|
| 344 |
+
### Task 4: Add summary, comparison, and feedback helper functions
|
| 345 |
+
|
| 346 |
+
**Files:**
|
| 347 |
+
- Modify: `app.py` (append after ELI5 function)
|
| 348 |
+
|
| 349 |
+
- [ ] **Step 1: Add the rule-based plain English summary function**
|
| 350 |
+
|
| 351 |
+
Append to `app.py`:
|
| 352 |
+
|
| 353 |
+
```python
|
| 354 |
+
# ---------------------------------------------------------------------------
|
| 355 |
+
# 7. Plain English summary (replaces Ollama LLM)
|
| 356 |
+
# ---------------------------------------------------------------------------
|
| 357 |
+
|
| 358 |
+
def _lime_rule_feature(rule):
    """Return the feature name embedded in a LIME rule string.

    LIME's tabular explainer typically reports discretized rules such as
    ``"0.00 < url_count <= 0.25"`` or ``"free > 0.10"`` rather than bare
    feature names. The first non-numeric operand is taken as the name; if
    every operand is numeric the stripped rule is returned unchanged.
    """
    text = str(rule)
    for operator in ('<=', '>=', '<', '>'):
        text = text.replace(operator, '\n')
    for piece in text.split('\n'):
        piece = piece.strip()
        if not piece:
            continue
        try:
            float(piece)
        except ValueError:
            return piece
    return str(rule).strip()


def generate_plain_summary(label, confidence, spam_proba, lime_explanation,
                           shap_sv, shap_top_idx):
    """Build a rule-based plain English summary from XAI results.

    Args:
        label: Predicted label; any label containing ``"SPAM"`` selects the
            spam wording of the closing sentence.
        confidence: Probability of the predicted label (0-1).
        spam_proba: Spam probability. Accepted for interface compatibility;
            not used in the text.
        lime_explanation: Object exposing ``as_list()`` that yields
            ``(rule, weight)`` pairs, or None.
        shap_sv: Per-metadata-feature SHAP values, or None.
        shap_top_idx: Metadata feature indices ordered by importance, or None.

    Returns:
        Markdown text.
    """
    summary = f"### Classification: **{label}** ({confidence:.0%} confidence)\n\n"

    lime_rules = lime_explanation.as_list() if lime_explanation is not None else None

    # LIME top features
    if lime_rules is not None:
        summary += "**Key words driving this decision (LIME):**\n"
        for feat_rule, weight in lime_rules[:3]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            summary += f"- **{feat_rule}** — {direction}\n"
        summary += "\n"

    # SHAP top metadata features
    if shap_sv is not None and shap_top_idx is not None:
        summary += "**Important email characteristics (SHAP):**\n"
        for i in shap_top_idx[:2]:
            feat_name = META_FEATURE_NAMES[i]
            description = FEATURE_DESCRIPTIONS.get(feat_name, feat_name)
            direction = "spam signal" if shap_sv[i] > 0 else "ham signal"
            summary += f"- **{feat_name}** ({description}) — {direction}\n"
        summary += "\n"

    # Agreement note. LIME rules carry numeric bounds, so compare on the
    # extracted feature name; comparing raw rule strings never matches.
    if lime_rules is not None and shap_top_idx is not None:
        lime_top = {_lime_rule_feature(rule) for rule, _ in lime_rules[:10]}
        shap_top = {META_FEATURE_NAMES[i] for i in shap_top_idx[:10]}
        overlap = lime_top & shap_top
        if overlap:
            summary += f"**Method agreement:** LIME and SHAP both flag: {', '.join(sorted(overlap))}\n\n"

    # Closing sentence
    if "SPAM" in label:
        if confidence > 0.9:
            summary += "The model is highly confident this email contains patterns commonly seen in spam or phishing attempts."
        elif confidence > 0.7:
            summary += "The model found several spam-like patterns in this email."
        else:
            summary += "The model leans toward spam, but the evidence is not overwhelming. Use your judgment."
    else:
        if confidence > 0.9:
            summary += "The model is highly confident this is a legitimate email."
        elif confidence > 0.7:
            summary += "The model found this email to be mostly consistent with legitimate messages."
        else:
            summary += "The model leans toward legitimate, but there are some spam-like features. Review carefully."

    return summary
|
| 407 |
+
```
|
| 408 |
+
|
| 409 |
+
- [ ] **Step 2: Add the comparison markdown generator**
|
| 410 |
+
|
| 411 |
+
Append to `app.py`:
|
| 412 |
+
|
| 413 |
+
```python
|
| 414 |
+
# ---------------------------------------------------------------------------
|
| 415 |
+
# 8. Side-by-side comparison
|
| 416 |
+
# ---------------------------------------------------------------------------
|
| 417 |
+
|
| 418 |
+
def _lime_rule_feature(rule):
    """Return the feature name embedded in a LIME rule string.

    LIME's tabular explainer typically reports discretized rules such as
    ``"0.00 < url_count <= 0.25"`` or ``"free > 0.10"`` rather than bare
    feature names. The first non-numeric operand is taken as the name; if
    every operand is numeric the stripped rule is returned unchanged.
    """
    text = str(rule)
    for operator in ('<=', '>=', '<', '>'):
        text = text.replace(operator, '\n')
    for piece in text.split('\n'):
        piece = piece.strip()
        if not piece:
            continue
        try:
            float(piece)
        except ValueError:
            return piece
    return str(rule).strip()


def generate_comparison(lime_explanation, shap_sv, shap_top_idx, eli5_names):
    """Build a markdown comparison of top features from each XAI method.

    Args:
        lime_explanation: Object exposing ``as_list()`` that yields
            ``(rule, weight)`` pairs, or None.
        shap_sv: Per-metadata-feature SHAP values, or None.
        shap_top_idx: Metadata feature indices ordered by importance, or None.
        eli5_names: Feature names reported by ELI5, or None.

    Returns:
        Markdown text with a five-row table (missing cells show a dash), an
        optional LIME-SHAP agreement line, and a closing note.
    """
    md = "### Side-by-Side: Top Features by Method\n\n"
    md += "| Rank | LIME | SHAP (metadata) | ELI5 |\n"
    md += "|------|------|-----------------|------|\n"

    lime_rules = lime_explanation.as_list() if lime_explanation is not None else None

    lime_top5 = []
    if lime_rules is not None:
        for feat, w in lime_rules[:5]:
            direction = "spam" if w > 0 else "ham"
            lime_top5.append(f"{feat} ({direction}, {w:+.3f})")

    shap_top5 = []
    if shap_sv is not None and shap_top_idx is not None:
        for i in shap_top_idx[:5]:
            direction = "spam" if shap_sv[i] > 0 else "ham"
            shap_top5.append(f"{META_FEATURE_NAMES[i]} ({direction}, {shap_sv[i]:+.3f})")

    eli5_top5 = (eli5_names or [])[:5]

    for rank in range(5):
        lime_cell = lime_top5[rank] if rank < len(lime_top5) else "—"
        shap_cell = shap_top5[rank] if rank < len(shap_top5) else "—"
        eli5_cell = eli5_top5[rank] if rank < len(eli5_top5) else "—"
        md += f"| {rank+1} | {lime_cell} | {shap_cell} | {eli5_cell} |\n"

    # Agreement analysis. LIME rules carry numeric bounds, so compare on the
    # extracted feature name; comparing raw rule strings never matches.
    if lime_rules is not None and shap_top_idx is not None:
        lime_set = {_lime_rule_feature(rule) for rule, _ in lime_rules[:10]}
        shap_set = {META_FEATURE_NAMES[i] for i in shap_top_idx[:10]}
        overlap = lime_set & shap_set
        md += f"\n**LIME-SHAP agreement** (top 10): **{len(overlap)}** shared features"
        if overlap:
            md += f"\nShared: {', '.join(sorted(overlap))}"

    md += "\n\n*Note: LIME covers all features (words + metadata), SHAP covers only the 24 metadata features, "
    md += "ELI5 uses the Random Forest sub-estimator's internal weights.*"

    return md
|
| 457 |
+
```
|
| 458 |
+
|
| 459 |
+
- [ ] **Step 3: Add feedback logging function**
|
| 460 |
+
|
| 461 |
+
Append to `app.py`:
|
| 462 |
+
|
| 463 |
+
```python
|
| 464 |
+
# ---------------------------------------------------------------------------
|
| 465 |
+
# 9. Feedback logging
|
| 466 |
+
# ---------------------------------------------------------------------------
|
| 467 |
+
|
| 468 |
+
def log_feedback(email_text, predicted_label, predicted_confidence, threshold,
                 feedback_type, correct_label=None):
    """Append one feedback row to the CSV log.

    Args:
        email_text: Email text; only the first 500 characters are stored.
        predicted_label: Label the model produced.
        predicted_confidence: Confidence of that label (written with 4 decimals).
        threshold: Threshold in effect for the prediction (4 decimals).
        feedback_type: Value stored in the ``feedback`` column, e.g.
            ``'correct'`` or ``'wrong'``.
        correct_label: User-supplied label for a wrong prediction; stored as
            an empty string when missing.

    Returns:
        The result of ``count_corrections()`` after the row is written.
    """
    # Make sure the target directory exists before appending.
    FEEDBACK_CSV.parent.mkdir(parents=True, exist_ok=True)

    # A header is needed for a new file and for an existing but empty one;
    # otherwise readers would treat the first data row as the header.
    write_header = not FEEDBACK_CSV.exists() or FEEDBACK_CSV.stat().st_size == 0

    with open(FEEDBACK_CSV, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['timestamp', 'email_text', 'predicted_label',
                             'predicted_confidence', 'feedback', 'correct_label',
                             'threshold_used'])
        writer.writerow([
            datetime.now().isoformat(),
            email_text[:500],  # truncate to 500 chars
            predicted_label,
            f"{predicted_confidence:.4f}",
            feedback_type,
            correct_label or '',
            f"{threshold:.4f}",
        ])

    return count_corrections()
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def count_corrections():
    """Count the rows in the feedback log whose ``feedback`` is ``'wrong'``.

    Returns:
        The number of such rows, or 0 when the log file does not exist.
    """
    if not FEEDBACK_CSV.exists():
        return 0
    # newline='' lets the csv module handle line breaks inside quoted fields
    # itself; the logged email text can span several lines.
    with open(FEEDBACK_CSV, 'r', newline='', encoding='utf-8') as f:
        return sum(1 for row in csv.DictReader(f) if row.get('feedback') == 'wrong')
|
| 504 |
+
```
|
| 505 |
+
|
| 506 |
+
- [ ] **Step 4: Commit**
|
| 507 |
+
|
| 508 |
+
```bash
|
| 509 |
+
git add app.py
|
| 510 |
+
git commit -m "Add summary, comparison, and feedback helper functions"
|
| 511 |
+
```
|
| 512 |
+
|
| 513 |
+
---
|
| 514 |
+
|
| 515 |
+
### Task 5: Build the Gradio UI (Blocks layout with 7 tabs + feedback)
|
| 516 |
+
|
| 517 |
+
**Files:**
|
| 518 |
+
- Modify: `app.py` (append the Gradio Blocks UI after all helper functions)
|
| 519 |
+
|
| 520 |
+
- [ ] **Step 1: Add example emails and the main orchestration function**
|
| 521 |
+
|
| 522 |
+
Append to `app.py`:
|
| 523 |
+
|
| 524 |
+
```python
|
| 525 |
+
# ---------------------------------------------------------------------------
|
| 526 |
+
# 10. Example Emails
|
| 527 |
+
# ---------------------------------------------------------------------------
|
| 528 |
+
|
| 529 |
+
EXAMPLE_EMAILS = [
|
| 530 |
+
["Subject: URGENT - You Have Won $5,000,000!!!\n\nDear Friend,\n\nCONGRATULATIONS!!! You have been selected as the winner of our international lottery program!!!\nTo claim your $5,000,000 USD prize, click the link below IMMEDIATELY and provide your bank details.\n\nACT NOW - This offer expires in 24 hours!!!\n\nClick here: http://totally-legit-prize.com/claim\nSend $500 processing fee to unlock your winnings.\n\nBest regards,\nDr. Prince Mohammed"],
|
| 531 |
+
["Subject: Team sync Thursday 2pm\n\nHi everyone,\n\nJust a reminder that we have our weekly team sync this Thursday at 2pm in Conference Room B.\n\nAgenda:\n- Sprint review\n- Q2 planning discussion\n- New hire onboarding update\n\nPlease come prepared with your status updates.\n\nThanks,\nSarah"],
|
| 532 |
+
["Subject: Your account has been compromised!\n\nDear Customer,\n\nWe detected suspicious activity on your account. Click here immediately to verify your identity: http://secure-bank-login.com/verify\n\nIf you do not verify within 24 hours, your account will be permanently locked.\n\nSecurity Team"],
|
| 533 |
+
["Subject: Thanksgiving dinner plans\n\nHi everyone!\n\nI wanted to start planning for Thanksgiving dinner. I'm thinking we could do it at my place this year. What does everyone think about 4pm?\n\nLet me know if you have any dietary restrictions or if you want to bring a dish.\n\nLove,\nMom"],
|
| 534 |
+
["Subject: Best prices on V1AGRA and C1ALIS!!!\n\n$$$ SAVE BIG $$$\nBuy now and get 80% OFF!!!\nNo prescription needed! Free shipping!\nOrder at http://cheap-pharma-deals.com\n\nLIMITED TIME OFFER - ACT NOW!"],
|
| 535 |
+
]
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
# ---------------------------------------------------------------------------
|
| 539 |
+
# 11. Main orchestration function
|
| 540 |
+
# ---------------------------------------------------------------------------
|
| 541 |
+
|
| 542 |
+
def classify_and_explain(email_text, uploaded_file, threshold):
    """Classify an email and collect every explanation shown in the UI.

    Args:
        email_text: Text typed into the textbox; may be None or blank.
        uploaded_file: Path of an uploaded text file, or None. When given,
            its content replaces ``email_text``.
        threshold: Spam-probability cutoff passed to ``classify_email``.

    Returns:
        A 7-tuple ``(result_md, lime_fig, shap_fig, eli5_html, comparison_md,
        summary_md, feedback_state)``. ``feedback_state`` is a
        ``'|||'``-separated string of label, confidence, threshold and the
        first 500 characters of the email. On an input problem the first
        element carries the message and the remaining slots are empty.
    """
    # Handle file upload
    if uploaded_file is not None:
        try:
            email_text = Path(uploaded_file).read_text(encoding='utf-8')
        except (OSError, UnicodeError, TypeError, ValueError):
            return ("Could not read file.", None, None, "Error reading file.", "", "", "")

    if email_text is None or email_text.strip() == '':
        return ("Please enter email text or upload a file.", None, None, "", "", "", "")

    if voting_model is None:
        return ("Models not found. Run `python3 train.py` first.", None, None, "", "", "", "")

    # Classify
    label, confidence, spam_proba, combined = classify_email(email_text, threshold)

    # Each explainer is best-effort: a failure is logged and its output is
    # left empty so the classification result is still shown.
    try:
        lime_fig, lime_exp = generate_lime_explanation(combined)
    except Exception as e:
        print(f"LIME error: {e}")
        lime_fig, lime_exp = None, None

    try:
        shap_fig, shap_sv, shap_top_idx = generate_shap_explanation(email_text)
    except Exception as e:
        print(f"SHAP error: {e}")
        shap_fig, shap_sv, shap_top_idx = None, None, None

    try:
        eli5_html, eli5_names = generate_eli5_explanation(combined)
    except Exception as e:
        print(f"ELI5 error: {e}")
        eli5_html, eli5_names = None, None

    # Result summary (enhanced)
    result_md = f"## {'SPAM' if 'SPAM' in label else 'HAM (Not Spam)'}\n\n"
    result_md += f"**Confidence:** {confidence:.1%}\n\n"
    result_md += f"**Threshold:** {threshold:.0%}\n\n"
    result_md += f"**Spam probability:** {spam_proba:.1%}\n\n"
    if lime_exp is not None:
        result_md += "**Key factors:**\n"
        for feat_rule, weight in lime_exp.as_list()[:5]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            result_md += f"- **{feat_rule}** {direction}\n"

    # Comparison
    comparison_md = generate_comparison(lime_exp, shap_sv, shap_top_idx, eli5_names)

    # Plain English summary
    summary_md = generate_plain_summary(label, confidence, spam_proba,
                                        lime_exp, shap_sv, shap_top_idx)

    # ELI5 HTML (wrap for display)
    eli5_display = eli5_html or "<p>ELI5 explanation not available.</p>"

    return (result_md, lime_fig, shap_fig, eli5_display, comparison_md, summary_md,
            f"{label}|||{confidence:.4f}|||{threshold:.4f}|||{email_text[:500]}")
|
| 605 |
+
```
|
| 606 |
+
|
| 607 |
+
- [ ] **Step 2: Add feedback handler functions**
|
| 608 |
+
|
| 609 |
+
Append to `app.py`:
|
| 610 |
+
|
| 611 |
+
```python
|
| 612 |
+
# ---------------------------------------------------------------------------
|
| 613 |
+
# 12. Feedback handlers
|
| 614 |
+
# ---------------------------------------------------------------------------
|
| 615 |
+
|
| 616 |
+
def handle_correct(hidden_state):
    """Log positive feedback for the last classification.

    Args:
        hidden_state: ``'label|||confidence|||threshold|||email'`` string
            produced by the classify step, or an empty value.

    Returns:
        A status message for the UI.
    """
    no_result = "No classification to give feedback on."
    if not hidden_state:
        return no_result

    # maxsplit=3 keeps the email intact even if it contains the delimiter.
    parts = hidden_state.split('|||', 3)
    if len(parts) < 4:
        return no_result
    label, conf_text, thresh_text, email = parts

    try:
        conf = float(conf_text)
        thresh = float(thresh_text)
    except ValueError:
        # Malformed state: treat it as "nothing to rate" instead of crashing.
        return no_result

    corrections = log_feedback(email, label, conf, thresh, 'correct')
    return f"Thanks for the feedback! ({corrections} corrections collected so far)"
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def handle_wrong(hidden_state, correct_label):
    """Log negative feedback together with the user's correction.

    Args:
        hidden_state: ``'label|||confidence|||threshold|||email'`` string
            produced by the classify step, or an empty value.
        correct_label: Label chosen by the user; nothing is logged when it
            is missing.

    Returns:
        A status message for the UI.
    """
    no_result = "No classification to give feedback on."
    if not hidden_state:
        return no_result

    # maxsplit=3 keeps the email intact even if it contains the delimiter.
    parts = hidden_state.split('|||', 3)
    if len(parts) < 4:
        return no_result
    label, conf_text, thresh_text, email = parts

    try:
        conf = float(conf_text)
        thresh = float(thresh_text)
    except ValueError:
        # Malformed state: treat it as "nothing to rate" instead of crashing.
        return no_result

    # A "wrong" row without a label cannot be used as a correction, so ask
    # for the label instead of logging an unusable row.
    if not correct_label:
        return "Please choose the correct label (Spam or Ham) before clicking Wrong."

    corrections = log_feedback(email, label, conf, thresh, 'wrong', correct_label)
    return f"Correction logged! ({corrections} corrections collected so far)"
|
| 638 |
+
```
|
| 639 |
+
|
| 640 |
+
- [ ] **Step 3: Add the Gradio Blocks layout**
|
| 641 |
+
|
| 642 |
+
Append to `app.py`:
|
| 643 |
+
|
| 644 |
+
```python
|
| 645 |
+
# ---------------------------------------------------------------------------
|
| 646 |
+
# 13. Gradio Blocks UI
|
| 647 |
+
# ---------------------------------------------------------------------------
|
| 648 |
+
|
| 649 |
+
HOW_IT_WORKS_MD = """
|
| 650 |
+
## How This App Works
|
| 651 |
+
|
| 652 |
+
### What is spam classification?
|
| 653 |
+
Spam classification automatically identifies unwanted or malicious emails (spam) vs. legitimate messages (ham). This helps protect users from phishing scams, fraudulent offers, and unwanted advertising.
|
| 654 |
+
|
| 655 |
+
### The Model
|
| 656 |
+
This app uses a **Voting Ensemble** — three different machine learning models that each "vote" on whether an email is spam:
|
| 657 |
+
- **Random Forest** — builds many decision trees and takes the majority vote
|
| 658 |
+
- **Logistic Regression** — finds a mathematical boundary between spam and ham
|
| 659 |
+
- **Support Vector Machine (SVM)** — finds the widest possible margin between classes
|
| 660 |
+
|
| 661 |
+
By combining all three, the ensemble is more accurate than any single model alone.
|
| 662 |
+
|
| 663 |
+
### Feature Extraction
|
| 664 |
+
The model looks at two types of features:
|
| 665 |
+
- **TF-IDF (Term Frequency-Inverse Document Frequency)** — measures how important each word is. Common spam words like "prize" or "click" get high scores.
|
| 666 |
+
- **24 Metadata Features** — structural patterns like exclamation mark density, dollar sign count, ALL CAPS ratio, URL count, and more.
|
| 667 |
+
|
| 668 |
+
### Explainable AI (XAI) Methods
|
| 669 |
+
This app doesn't just classify — it explains **why**:
|
| 670 |
+
|
| 671 |
+
- **LIME** — Perturbs the input features many times and fits a simple local model to see how the prediction changes. Shows which features (words and metadata) matter most.
|
| 672 |
+
- **SHAP** — Uses game theory to calculate each feature's "fair share" of the prediction. Based on Nobel Prize-winning mathematics.
|
| 673 |
+
- **ELI5** — Looks directly at the model's internal weights to show which features it relies on most.
|
| 674 |
+
|
| 675 |
+
### Feedback & Retraining
|
| 676 |
+
When you click "Correct" or "Wrong", your feedback is saved. After enough corrections accumulate, the model can be retrained with the new examples to improve over time. This is called **human-in-the-loop machine learning**.
|
| 677 |
+
|
| 678 |
+
### Disclaimer
|
| 679 |
+
This model was created as a university course project. It is intended for **educational and research purposes only** and should not be used as a sole spam filter in production. Always use established email security tools for real-world spam filtering.
|
| 680 |
+
"""
|
| 681 |
+
|
| 682 |
+
# Layout: inputs on the left, one output tab per explanation on the right,
# and a feedback row underneath. Event wiring happens inside the Blocks
# context so the handlers are registered on this app.
with gr.Blocks(title="Spam Email Classifier with XAI") as demo:
    gr.Markdown("# Spam Email Classifier with XAI Explanations")
    gr.Markdown("Classify emails and understand **why** using LIME, SHAP, ELI5, "
                "and plain English summaries. Created as a university course project.")

    # Hidden state for feedback (stores last classification info)
    hidden_state = gr.State("")

    with gr.Row():
        # Left column: inputs
        with gr.Column(scale=1):
            email_input = gr.Textbox(
                label="Email Text",
                placeholder="Paste your email here...",
                lines=12,
                autoscroll=False,
            )
            file_input = gr.File(
                label="Or upload a .txt file",
                file_types=['.txt'],
            )
            threshold_slider = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.05,
                value=optimal_threshold if optimal_threshold else 0.5,
                label="Classification Threshold",
                # classify_email labels spam when probability >= threshold.
                info="Emails with spam probability at or above this are classified as spam.",
            )
            classify_btn = gr.Button("Classify", variant="primary")
            gr.Examples(
                examples=EXAMPLE_EMAILS,
                inputs=[email_input],
                label="Try an example email",
                cache_examples=False,
            )

        # Right column: output tabs
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.Tab("Result"):
                    result_output = gr.Markdown(label="Classification Result")
                with gr.Tab("LIME"):
                    gr.Markdown("*LIME perturbs the input and fits a local model "
                                "to see which features matter most.*")
                    lime_output = gr.Plot(label="LIME Explanation")
                with gr.Tab("SHAP"):
                    gr.Markdown("*SHAP uses game theory to assign each feature "
                                "a contribution value.*")
                    shap_output = gr.Plot(label="SHAP Explanation")
                with gr.Tab("ELI5"):
                    gr.Markdown("*ELI5 shows feature weights directly from the "
                                "model's internals.*")
                    eli5_output = gr.HTML(label="ELI5 Explanation")
                with gr.Tab("Compare"):
                    compare_output = gr.Markdown(label="Method Comparison")
                with gr.Tab("Summary"):
                    summary_output = gr.Markdown(label="Plain English Summary")
                with gr.Tab("How It Works"):
                    gr.Markdown(HOW_IT_WORKS_MD)

    # Feedback row
    gr.Markdown("---")
    gr.Markdown("**Was this classification correct?**")
    feedback_msg = gr.Markdown("")
    with gr.Row():
        correct_btn = gr.Button("Correct", variant="secondary")
        wrong_btn = gr.Button("Wrong", variant="stop")
        correction_dropdown = gr.Dropdown(
            choices=["Spam", "Ham"],
            label="What should it be?",
            visible=True,
        )

    # Wire up classify button; the output order must match the 7-tuple
    # returned by classify_and_explain.
    classify_btn.click(
        fn=classify_and_explain,
        inputs=[email_input, file_input, threshold_slider],
        outputs=[result_output, lime_output, shap_output, eli5_output,
                 compare_output, summary_output, hidden_state],
    )

    # Wire up feedback buttons
    correct_btn.click(
        fn=handle_correct,
        inputs=[hidden_state],
        outputs=[feedback_msg],
    )
    wrong_btn.click(
        fn=handle_wrong,
        inputs=[hidden_state, correction_dropdown],
        outputs=[feedback_msg],
    )
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
# ---------------------------------------------------------------------------
|
| 776 |
+
# 14. Launch
|
| 777 |
+
# ---------------------------------------------------------------------------
|
| 778 |
+
|
| 779 |
+
# Start the web server only when the file is run as a script, so that
# importing the module does not launch the app.
if __name__ == '__main__':
    demo.launch()
|
| 781 |
+
```
|
| 782 |
+
|
| 783 |
+
- [ ] **Step 4: Test the app launches locally**
|
| 784 |
+
|
| 785 |
+
```bash
|
| 786 |
+
cd spam-classifier-gradio
|
| 787 |
+
source venv/bin/activate
|
| 788 |
+
python3 app.py
|
| 789 |
+
```
|
| 790 |
+
|
| 791 |
+
Open `http://127.0.0.1:7860` in a browser. Verify:
|
| 792 |
+
1. All 7 tabs are visible
|
| 793 |
+
2. Pasting an example email and clicking Classify populates all tabs
|
| 794 |
+
3. LIME and SHAP plots render
|
| 795 |
+
4. ELI5 shows an HTML table
|
| 796 |
+
5. Compare tab shows the side-by-side table
|
| 797 |
+
6. Summary tab shows plain English text
|
| 798 |
+
7. How It Works tab shows static content
|
| 799 |
+
8. Feedback buttons log to `feedback/feedback_log.csv`
|
| 800 |
+
|
| 801 |
+
Stop the server with Ctrl+C.
|
| 802 |
+
|
| 803 |
+
- [ ] **Step 5: Commit**
|
| 804 |
+
|
| 805 |
+
```bash
|
| 806 |
+
git add app.py
|
| 807 |
+
git commit -m "Build complete Gradio UI with 7 tabs, feedback, and threshold slider"
|
| 808 |
+
```
|
| 809 |
+
|
| 810 |
+
---
|
| 811 |
+
|
| 812 |
+
### Task 6: Create retrain.py with feedback CSV support
|
| 813 |
+
|
| 814 |
+
**Files:**
|
| 815 |
+
- Create: `retrain.py`
|
| 816 |
+
|
| 817 |
+
- [ ] **Step 1: Write retrain.py**
|
| 818 |
+
|
| 819 |
+
Create `retrain.py` in the project root:
|
| 820 |
+
|
| 821 |
+
```python
|
| 822 |
+
# Batch retrain script for the spam classifier
|
| 823 |
+
# Reads feedback corrections from feedback/feedback_log.csv,
|
| 824 |
+
# merges them into the original training data, and retrains
|
| 825 |
+
# the VotingClassifier ensemble.
|
| 826 |
+
#
|
| 827 |
+
# Usage:
|
| 828 |
+
# python3 retrain.py # retrain with feedback
|
| 829 |
+
# python3 retrain.py --no-feedback # retrain without feedback (original data only)
|
| 830 |
+
|
| 831 |
+
import sys
|
| 832 |
+
import csv
|
| 833 |
+
import warnings
|
| 834 |
+
from pathlib import Path
|
| 835 |
+
|
| 836 |
+
import numpy as np
|
| 837 |
+
import pandas as pd
|
| 838 |
+
from sklearn.model_selection import train_test_split
|
| 839 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 840 |
+
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
|
| 841 |
+
from sklearn.linear_model import LogisticRegression
|
| 842 |
+
from sklearn.svm import LinearSVC
|
| 843 |
+
from sklearn.calibration import CalibratedClassifierCV
|
| 844 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 845 |
+
from sklearn.metrics import classification_report, precision_recall_curve
|
| 846 |
+
from scipy.sparse import hstack, csr_matrix
|
| 847 |
+
import joblib
|
| 848 |
+
|
| 849 |
+
from utils import preprocess_text, compute_metadata_features, META_FEATURE_NAMES
|
| 850 |
+
|
| 851 |
+
warnings.filterwarnings('ignore', category=FutureWarning)
|
| 852 |
+
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
| 853 |
+
|
| 854 |
+
# All paths are resolved relative to this script's directory.
project_dir = Path(__file__).parent
data_dir = project_dir / 'data'
models_dir = project_dir / 'models'
feedback_csv = project_dir / 'feedback' / 'feedback_log.csv'
random_state = 42
# Upper bound on the number of Kaggle rows kept; larger sets are down-sampled
# per label in main().
KAGGLE_CAP = 100_000
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
def load_feedback_corrections():
    """Read the feedback CSV and return the user corrections.

    Only rows whose ``feedback`` is ``'wrong'`` and that carry a non-empty
    ``correct_label`` are used. The label is 1 when ``correct_label`` equals
    ``'spam'`` (case-insensitive) and 0 otherwise.

    Returns:
        A DataFrame with columns ``text`` and ``label``; it is empty, but
        still has both columns, when there is no file or no usable row.
    """
    columns = ['text', 'label']

    if not feedback_csv.exists():
        print("No feedback file found.")
        return pd.DataFrame(columns=columns)

    corrections = []
    # newline='' lets the csv module handle line breaks inside quoted fields
    # itself; the logged email text can span several lines.
    with open(feedback_csv, 'r', newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row.get('feedback') == 'wrong' and row.get('correct_label'):
                label = 1 if row['correct_label'].lower() == 'spam' else 0
                corrections.append({
                    'text': row['email_text'],
                    'label': label,
                })

    # Passing columns keeps the schema identical to the no-file branch even
    # when no correction was found.
    df = pd.DataFrame(corrections, columns=columns)
    print(f"Found {len(df)} corrections in feedback log.")
    return df
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def _load_kaggle_frame(kaggle_path):
    """Load the Kaggle CSV as a capped ``text``/``label`` frame, or None if its columns are unknown."""
    kaggle_df = pd.read_csv(kaggle_path)
    label_map = {'spam': 1, 'ham': 0}

    if 'label' in kaggle_df.columns and 'text' in kaggle_df.columns:
        raw_labels = kaggle_df['label']
    elif 'v1' in kaggle_df.columns and 'v2' in kaggle_df.columns:
        kaggle_df = kaggle_df.rename(columns={'v1': 'label_str', 'v2': 'text'})
        raw_labels = kaggle_df['label_str']
    else:
        print(f"  WARNING: unrecognized columns in {kaggle_path.name}; skipping.")
        return None

    # Case-insensitive mapping so 'Spam'/'HAM' are not dropped as NaN.
    kaggle_df['label'] = raw_labels.astype(str).str.strip().str.lower().map(label_map)
    kaggle_df = kaggle_df[['text', 'label']].dropna()

    if len(kaggle_df) > KAGGLE_CAP:
        # Sample each class separately (at most half the cap per class).
        # Iterating the groups keeps the 'label' column regardless of how
        # the installed pandas treats grouping columns in groupby.apply.
        kaggle_df = pd.concat([
            group.sample(min(len(group), KAGGLE_CAP // 2), random_state=random_state)
            for _, group in kaggle_df.groupby('label')
        ])
    return kaggle_df


def _load_github_frame(github_dir):
    """Read every file under the ``spam``/``ham`` subfolders into one ``text``/``label`` frame."""
    rows = []
    for label_dir in github_dir.iterdir():
        if not (label_dir.is_dir() and label_dir.name in ('spam', 'ham')):
            continue
        lbl = 1 if label_dir.name == 'spam' else 0
        for f in label_dir.iterdir():
            if not f.is_file():
                continue
            try:
                text = f.read_text(encoding='utf-8', errors='ignore')
            except OSError as exc:
                # Best effort: report the unreadable file and keep going.
                print(f"  WARNING: could not read {f}: {exc}")
                continue
            rows.append({'text': text, 'label': lbl})
    return pd.DataFrame(rows, columns=['text', 'label'])


def _select_threshold(y_true, y_scores):
    """Return the probability threshold that maximizes F1 on the given scores."""
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_scores)
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
    # precision/recall have one more entry than thresholds; the final point
    # has no threshold, so it is excluded from the argmax.
    best_idx = np.argmax(f1_scores[:-1])
    return float(thresholds[best_idx])


def main():
    """Retrain the voting ensemble and save all model artifacts.

    Loads the original training data from ``data/``, optionally appends the
    corrections from the feedback log (skipped when ``--no-feedback`` is on
    the command line), fits TF-IDF + scaled metadata features, trains the
    soft-voting ensemble, prints a classification report, picks the
    F1-optimal threshold and writes everything to ``models/``.

    Exits with status 1 when no usable training data is found.
    """
    use_feedback = '--no-feedback' not in sys.argv

    # --- Load original training data ---
    print("Loading training data...")
    kaggle_path = data_dir / 'spam_Emails_data.csv'
    github_dir = data_dir / 'email-dataset-main' / 'email-dataset-main'

    frames = []

    if kaggle_path.exists():
        kaggle_df = _load_kaggle_frame(kaggle_path)
        if kaggle_df is not None:
            frames.append(kaggle_df)
            print(f"  Kaggle: {len(kaggle_df)} emails")

    if github_dir.exists():
        github_df = _load_github_frame(github_dir)
        if len(github_df) > 0:
            frames.append(github_df)
            print(f"  GitHub: {len(github_df)} emails")

    if not frames:
        print("ERROR: No training data found in data/ directory.")
        sys.exit(1)

    df = pd.concat(frames, ignore_index=True)
    print(f"  Total original: {len(df)} emails")

    # --- Merge feedback corrections ---
    n_corrections = 0
    if use_feedback:
        feedback_df = load_feedback_corrections()
        n_corrections = len(feedback_df)
        if n_corrections > 0:
            df = pd.concat([df, feedback_df], ignore_index=True)
            print(f"  After feedback merge: {len(df)} emails")

    # --- Preprocess ---
    print("Preprocessing...")
    df['clean'] = df['text'].apply(preprocess_text)
    df = df[df['clean'].str.len() > 0]

    if df.empty:
        print("ERROR: No training data found in data/ directory.")
        sys.exit(1)

    X_text = df['clean'].values
    X_raw = df['text'].values
    y = df['label'].astype(int).values

    # --- Split ---
    # The raw text is split together with the cleaned text so that the
    # metadata features computed below belong to exactly the same emails,
    # in the same order, as the TF-IDF rows.
    (X_train_text, X_test_text,
     X_train_raw, X_test_raw,
     y_train, y_test) = train_test_split(
        X_text, X_raw, y,
        test_size=0.2, random_state=random_state, stratify=y
    )

    # --- TF-IDF ---
    print("Fitting TF-IDF...")
    tfidf = TfidfVectorizer(max_features=3000, ngram_range=(1, 3),
                            min_df=2, max_df=0.95)
    X_train_tfidf = tfidf.fit_transform(X_train_text)
    X_test_tfidf = tfidf.transform(X_test_text)

    # --- Metadata features ---
    print("Computing metadata features...")
    X_train_meta = compute_metadata_features(X_train_raw.tolist())
    X_test_meta = compute_metadata_features(X_test_raw.tolist())

    scaler = MinMaxScaler()
    X_train_meta_scaled = scaler.fit_transform(X_train_meta)
    X_test_meta_scaled = scaler.transform(X_test_meta)

    # --- Combine ---
    X_train = hstack([X_train_tfidf, csr_matrix(X_train_meta_scaled)])
    X_test = hstack([X_test_tfidf, csr_matrix(X_test_meta_scaled)])

    feature_names_list = tfidf.get_feature_names_out().tolist() + META_FEATURE_NAMES

    # --- Train ensemble ---
    print("Training VotingClassifier ensemble...")
    ensemble = VotingClassifier(
        estimators=[
            ('rf', RandomForestClassifier(
                n_estimators=200, n_jobs=-1,
                class_weight='balanced', random_state=random_state)),
            ('lr', LogisticRegression(
                max_iter=1000, class_weight='balanced', random_state=random_state)),
            ('svm', CalibratedClassifierCV(
                LinearSVC(class_weight='balanced', max_iter=2000,
                          random_state=random_state))),
        ],
        voting='soft',
    )
    ensemble.fit(X_train, y_train)

    # --- Evaluate ---
    y_pred = ensemble.predict(X_test)
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=['Ham', 'Spam']))

    # --- Optimal threshold ---
    y_scores = ensemble.predict_proba(X_test)[:, 1]
    optimal_threshold = _select_threshold(y_test, y_scores)
    print(f"Optimal threshold: {optimal_threshold:.4f}")

    # --- Save ---
    models_dir.mkdir(exist_ok=True)
    joblib.dump(ensemble, models_dir / 'voting_model.joblib')
    joblib.dump(tfidf, models_dir / 'tfidf_vectorizer.joblib')
    joblib.dump(scaler, models_dir / 'meta_scaler.joblib')
    joblib.dump(feature_names_list, models_dir / 'feature_names.joblib')
    joblib.dump(optimal_threshold, models_dir / 'optimal_threshold.joblib')

    # Training sample for LIME/SHAP background
    X_train = X_train.tocsr()  # hstack may return COO, which cannot be row-indexed
    sample_size = min(200, X_train.shape[0])
    sample_idx = np.random.RandomState(random_state).choice(
        X_train.shape[0], sample_size, replace=False)
    training_sample = X_train[sample_idx].toarray()
    joblib.dump(training_sample, models_dir / 'training_sample.joblib')

    print(f"\nAll models saved to {models_dir}/")
    if use_feedback:
        # Reuse the count from the merge step instead of re-reading the log.
        print(f"Feedback corrections incorporated: {n_corrections}")


if __name__ == '__main__':
    main()
|
| 1029 |
+
```
|
| 1030 |
+
|
| 1031 |
+
- [ ] **Step 2: Verify retrain.py runs with --no-feedback flag**
|
| 1032 |
+
|
| 1033 |
+
```bash
|
| 1034 |
+
python3 retrain.py --no-feedback
|
| 1035 |
+
```
|
| 1036 |
+
|
| 1037 |
+
Expected: Trains successfully, prints classification report, saves models.
|
| 1038 |
+
|
| 1039 |
+
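Note: This requires the `data/` directory to exist and to contain the training data the script looks for: `data/spam_Emails_data.csv` and/or the `data/email-dataset-main/email-dataset-main/` folder with `spam/` and `ham/` subfolders. If `data/` is a symlink, make sure the link resolves.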
|
| 1040 |
+
|
| 1041 |
+
- [ ] **Step 3: Commit**
|
| 1042 |
+
|
| 1043 |
+
```bash
|
| 1044 |
+
git add retrain.py
|
| 1045 |
+
git commit -m "Add retrain.py with feedback CSV ingestion for batch retraining"
|
| 1046 |
+
```
|
| 1047 |
+
|
| 1048 |
+
---
|
| 1049 |
+
|
| 1050 |
+
### Task 7: Update README.md
|
| 1051 |
+
|
| 1052 |
+
**Files:**
|
| 1053 |
+
- Modify: `README.md`
|
| 1054 |
+
|
| 1055 |
+
- [ ] **Step 1: Update README.md frontmatter and description**
|
| 1056 |
+
|
| 1057 |
+
Replace the entire contents of `README.md`:
|
| 1058 |
+
|
| 1059 |
+
```markdown
|
| 1060 |
+
---
|
| 1061 |
+
title: Spam Email Classifier with XAI
|
| 1062 |
+
emoji: 📧
|
| 1063 |
+
colorFrom: blue
|
| 1064 |
+
colorTo: red
|
| 1065 |
+
sdk: gradio
|
| 1066 |
+
sdk_version: "4.44.0"
|
| 1067 |
+
app_file: app.py
|
| 1068 |
+
pinned: false
|
| 1069 |
+
license: mit
|
| 1070 |
+
tags:
|
| 1071 |
+
- spam-detection
|
| 1072 |
+
- xai
|
| 1073 |
+
- lime
|
| 1074 |
+
- shap
|
| 1075 |
+
- eli5
|
| 1076 |
+
- scikit-learn
|
| 1077 |
+
- nlp
|
| 1078 |
+
- explainable-ai
|
| 1079 |
+
---
|
| 1080 |
+
|
| 1081 |
+
# Spam Email Classifier with XAI Explanations
|
| 1082 |
+
|
| 1083 |
+
> **Disclaimer:** This model was created as a university course project. It is intended for **educational and research purposes only** and should not be used as a sole spam/phishing filter in production. Classification accuracy may vary. Always use established email security tools for real-world spam filtering.
|
| 1084 |
+
|
| 1085 |
+
A Gradio web app that classifies emails as spam or ham and provides explainable AI (XAI) insights using three different methods.
|
| 1086 |
+
|
| 1087 |
+
## Features
|
| 1088 |
+
|
| 1089 |
+
- Paste any email and get an instant spam/ham prediction
|
| 1090 |
+
- **LIME** explanations — which words pushed the decision
|
| 1091 |
+
- **SHAP** feature importance — game-theoretic attribution
|
| 1092 |
+
- **ELI5** — model internal feature weights
|
| 1093 |
+
- **Side-by-side comparison** of all three XAI methods
|
| 1094 |
+
- **Plain English summary** of why the model made its decision
|
| 1095 |
+
- **User feedback** — thumbs up/down to log corrections for batch retraining
|
| 1096 |
+
- Adjustable classification threshold
|
| 1097 |
+
|
| 1098 |
+
## How to Run Locally
|
| 1099 |
+
|
| 1100 |
+
```bash
|
| 1101 |
+
pip install -r requirements.txt
|
| 1102 |
+
python train.py # train the models (first time only)
|
| 1103 |
+
python app.py # launch the Gradio app
|
| 1104 |
+
```
|
| 1105 |
+
|
| 1106 |
+
## Retraining with Feedback
|
| 1107 |
+
|
| 1108 |
+
```bash
|
| 1109 |
+
python retrain.py # retrain with accumulated feedback corrections
|
| 1110 |
+
python retrain.py --no-feedback # retrain with original data only
|
| 1111 |
+
```
|
| 1112 |
+
|
| 1113 |
+
## Model
|
| 1114 |
+
|
| 1115 |
+
Voting ensemble (Random Forest + Logistic Regression + SVM) trained on SpamAssassin + Enron email datasets using TF-IDF + 24 metadata features.
|
| 1116 |
+
|
| 1117 |
+
## Tech Stack
|
| 1118 |
+
|
| 1119 |
+
- scikit-learn (ensemble classifier)
|
| 1120 |
+
- LIME + SHAP + ELI5 (explainability)
|
| 1121 |
+
- Gradio (web interface)
|
| 1122 |
+
- NLTK (text preprocessing)
|
| 1123 |
+
```
|
| 1124 |
+
|
| 1125 |
+
- [ ] **Step 2: Commit**
|
| 1126 |
+
|
| 1127 |
+
```bash
|
| 1128 |
+
git add README.md
|
| 1129 |
+
git commit -m "Update README with merged XAI features and feedback system"
|
| 1130 |
+
```
|
| 1131 |
+
|
| 1132 |
+
---
|
| 1133 |
+
|
| 1134 |
+
### Task 8: Update retrain.command wrapper
|
| 1135 |
+
|
| 1136 |
+
**Files:**
|
| 1137 |
+
- Modify: `retrain.command`
|
| 1138 |
+
|
| 1139 |
+
- [ ] **Step 1: Update the retrain shell wrapper**
|
| 1140 |
+
|
| 1141 |
+
Replace `retrain.command` contents:
|
| 1142 |
+
|
| 1143 |
+
```bash
|
| 1144 |
+
#!/bin/bash
|
| 1145 |
+
cd "$(dirname "$0")"
|
| 1146 |
+
source venv/bin/activate
|
| 1147 |
+
echo "============================================================"
|
| 1148 |
+
echo " Spam Classifier — Retrain sklearn ensemble"
|
| 1149 |
+
echo " Models: Random Forest + Logistic Regression + SVM"
|
| 1150 |
+
echo " Includes: Feedback corrections from feedback/feedback_log.csv"
|
| 1151 |
+
echo "============================================================"
|
| 1152 |
+
echo ""
|
| 1153 |
+
echo "Options:"
|
| 1154 |
+
echo " 1) Retrain WITH feedback corrections (default)"
|
| 1155 |
+
echo " 2) Retrain WITHOUT feedback (original data only)"
|
| 1156 |
+
echo ""
|
| 1157 |
+
read -p "Choose [1/2]: " choice
|
| 1158 |
+
echo ""
|
| 1159 |
+
|
| 1160 |
+
if [ "$choice" = "2" ]; then
|
| 1161 |
+
python3 retrain.py --no-feedback
|
| 1162 |
+
else
|
| 1163 |
+
python3 retrain.py
|
| 1164 |
+
fi
|
| 1165 |
+
|
| 1166 |
+
echo ""
|
| 1167 |
+
echo "Done! Models saved to models/"
|
| 1168 |
+
echo ""
|
| 1169 |
+
echo "Press any key to close..."
|
| 1170 |
+
read -n 1
|
| 1171 |
+
```
|
| 1172 |
+
|
| 1173 |
+
- [ ] **Step 2: Commit**
|
| 1174 |
+
|
| 1175 |
+
```bash
|
| 1176 |
+
git add retrain.command
|
| 1177 |
+
git commit -m "Update retrain.command to support feedback-augmented retraining"
|
| 1178 |
+
```
|
| 1179 |
+
|
| 1180 |
+
---
|
| 1181 |
+
|
| 1182 |
+
### Task 9: End-to-end verification
|
| 1183 |
+
|
| 1184 |
+
- [ ] **Step 1: Run existing tests to make sure utils.py still works**
|
| 1185 |
+
|
| 1186 |
+
```bash
|
| 1187 |
+
cd spam-classifier-gradio
|
| 1188 |
+
source venv/bin/activate
|
| 1189 |
+
python3 -m pytest test_utils.py -v
|
| 1190 |
+
```
|
| 1191 |
+
|
| 1192 |
+
Expected: All tests pass. We did not modify `utils.py` or `test_utils.py`.
|
| 1193 |
+
|
| 1194 |
+
- [ ] **Step 2: Launch the app and test all features**
|
| 1195 |
+
|
| 1196 |
+
```bash
|
| 1197 |
+
python3 app.py
|
| 1198 |
+
```
|
| 1199 |
+
|
| 1200 |
+
Open `http://127.0.0.1:7860`. Test:
|
| 1201 |
+
|
| 1202 |
+
1. Paste a spam example → verify Result tab shows SPAM with confidence
|
| 1203 |
+
2. Check LIME tab → bar chart renders
|
| 1204 |
+
3. Check SHAP tab → bar chart renders
|
| 1205 |
+
4. Check ELI5 tab → HTML table renders
|
| 1206 |
+
5. Check Compare tab → side-by-side table with agreement count
|
| 1207 |
+
6. Check Summary tab → plain English explanation
|
| 1208 |
+
7. Check How It Works tab → static educational content
|
| 1209 |
+
8. Click "Correct" → feedback message appears
|
| 1210 |
+
9. Click "Wrong" with "Ham" selected → correction logged
|
| 1211 |
+
10. Verify `feedback/feedback_log.csv` exists and has entries
|
| 1212 |
+
11. Adjust threshold slider → re-classify → label may change
|
| 1213 |
+
|
| 1214 |
+
Stop the server.
|
| 1215 |
+
|
| 1216 |
+
- [ ] **Step 3: Verify feedback CSV format**
|
| 1217 |
+
|
| 1218 |
+
```bash
|
| 1219 |
+
cat feedback/feedback_log.csv
|
| 1220 |
+
```
|
| 1221 |
+
|
| 1222 |
+
Expected: CSV with headers `timestamp,email_text,predicted_label,predicted_confidence,feedback,correct_label,threshold_used` and at least 2 rows from the tests above.
|
| 1223 |
+
|
| 1224 |
+
- [ ] **Step 4: Final commit if any fixes were needed**
|
| 1225 |
+
|
| 1226 |
+
```bash
|
| 1227 |
+
git add -A
|
| 1228 |
+
git commit -m "Fix any issues found during end-to-end testing"
|
| 1229 |
+
```
|
| 1230 |
+
|
| 1231 |
+
Only run this if Step 2 or 3 revealed issues that required fixes.
|
docs/superpowers/specs/2026-03-28-gradio-xai-merge-design.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Design: Merge XAI Features into Gradio Spam Classifier
|
| 2 |
+
|
| 3 |
+
**Date:** 2026-03-28
|
| 4 |
+
**Status:** Approved
|
| 5 |
+
**Project:** spam-classifier-gradio (Hugging Face Space)
|
| 6 |
+
|
| 7 |
+
## Goal
|
| 8 |
+
|
| 9 |
+
Merge the best features from the Streamlit-based `spam-xai-project` into the existing `spam-classifier-gradio` Gradio app to create a single, unified spam classifier Space with full XAI explanations and a user feedback loop for batch retraining.
|
| 10 |
+
|
| 11 |
+
## Decisions Made
|
| 12 |
+
|
| 13 |
+
- **Primary model:** Voting Ensemble (RF + LR + SVM) from `spam-classifier-gradio`
|
| 14 |
+
- **Drop Ollama/LLM features:** No Ollama on HF Spaces — replace with rule-based plain English summary
|
| 15 |
+
- **Drop OCR:** No image upload — text paste only
|
| 16 |
+
- **Drop Liquid AI Space:** Too large for free tier — remove from Spaces, keep as model repo
|
| 17 |
+
- **Retire XAI Streamlit Space:** Single merged Gradio space replaces both
|
| 18 |
+
- **Feedback:** Thumbs up/down logging to CSV + batch retrain script (not live retraining)
|
| 19 |
+
- **Course context:** Generic "university course project" — no specific class number or semester
|
| 20 |
+
|
| 21 |
+
## App Layout
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
┌─────────────────────────────────────────────────────┐
|
| 25 |
+
│ # Spam Email Classifier with XAI Explanations │
|
| 26 |
+
│ Subtitle + educational disclaimer │
|
| 27 |
+
├──────────────────────┬──────────────────────────────┤
|
| 28 |
+
│ Email Text Input │ Tabs: │
|
| 29 |
+
│ (textarea, 12 rows) │ [Result] [LIME] [SHAP] │
|
| 30 |
+
│ │ [ELI5] [Compare] [Summary] │
|
| 31 |
+
│ Upload .txt file │ [How It Works] │
|
| 32 |
+
│ │ │
|
| 33 |
+
│ Threshold slider │ (content changes per tab) │
|
| 34 |
+
│ [0.0 ──●── 1.0] │ │
|
| 35 |
+
│ │ ┌──────────────────────┐ │
|
| 36 |
+
│ [Classify] button │ │ 👍 👎 Feedback │ │
|
| 37 |
+
│ │ │ (after classify) │ │
|
| 38 |
+
│ Example emails │ └──────────────────────┘ │
|
| 39 |
+
└──────────────────────┴──────────────────────────────┘
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## 7 Tabs
|
| 43 |
+
|
| 44 |
+
### 1. Result
|
| 45 |
+
- Spam/ham label with confidence percentage
|
| 46 |
+
- Threshold used for the classification
|
| 47 |
+
- Key factors list (top 5 LIME features with direction)
|
| 48 |
+
- Existing `generate_summary()` logic enhanced
|
| 49 |
+
|
| 50 |
+
### 2. LIME
|
| 51 |
+
- Horizontal bar chart (existing)
|
| 52 |
+
- Caption: "LIME perturbs the input and fits a local model to see which features matter most"
|
| 53 |
+
- 10 features shown, colored red (spam) / blue (ham)
|
| 54 |
+
|
| 55 |
+
### 3. SHAP
|
| 56 |
+
- Metadata feature bar chart (existing)
|
| 57 |
+
- Caption: "SHAP uses game theory to assign each feature a contribution value"
|
| 58 |
+
- Uses KernelExplainer on the 24 metadata features
|
| 59 |
+
|
| 60 |
+
### 4. ELI5 (new)
|
| 61 |
+
- HTML rendering of ELI5 feature weights via `eli5.format_as_html()`
|
| 62 |
+
- Uses the RF sub-estimator extracted from VotingClassifier: `voting_model.named_estimators_['rf']`
|
| 63 |
+
- Caption: "ELI5 shows feature weights directly from the model's internals"
|
| 64 |
+
- Displayed in an HTML component within Gradio
|
| 65 |
+
|
| 66 |
+
### 5. Compare (new)
|
| 67 |
+
- Three columns showing top-5 features from LIME, SHAP, and ELI5
|
| 68 |
+
- Each feature shows its direction (spam/ham) and weight
|
| 69 |
+
- Note: LIME covers all features (TF-IDF + metadata), SHAP covers only the 24 metadata features, ELI5 covers all features via the RF sub-estimator. The comparison highlights where methods agree despite different feature scopes.
|
| 70 |
+
- Feature agreement analysis: count of shared features between LIME and SHAP top-10
|
| 71 |
+
- Lists the shared feature names
|
| 72 |
+
|
| 73 |
+
### 6. Summary (new — replaces Ollama AI Explanation)
|
| 74 |
+
- Rule-based plain English explanation built from XAI results
|
| 75 |
+
- Template structure:
|
| 76 |
+
1. Classification statement with confidence
|
| 77 |
+
2. Top 3 LIME features with direction
|
| 78 |
+
3. Top 2 SHAP metadata features with human-readable descriptions
|
| 79 |
+
4. LIME-SHAP agreement note
|
| 80 |
+
5. Closing sentence based on spam/ham + confidence level
|
| 81 |
+
- No LLM needed — string formatting from already-computed XAI data
|
| 82 |
+
|
| 83 |
+
### 7. How It Works (new)
|
| 84 |
+
- Static Markdown content explaining:
|
| 85 |
+
- What spam classification is and why it matters
|
| 86 |
+
- The model: Voting Ensemble (Random Forest + Logistic Regression + SVM)
|
| 87 |
+
- Feature extraction: TF-IDF (word importance) + 24 metadata features
|
| 88 |
+
- What each XAI method does in beginner-friendly language:
|
| 89 |
+
- LIME: "Tests what happens when words are removed"
|
| 90 |
+
- SHAP: "Calculates each feature's fair contribution using game theory"
|
| 91 |
+
- ELI5: "Shows the model's internal feature weights directly"
|
| 92 |
+
- What the feedback buttons do and how batch retraining works
|
| 93 |
+
- Educational disclaimer
|
| 94 |
+
|
| 95 |
+
## Feedback System
|
| 96 |
+
|
| 97 |
+
### User-facing UI
|
| 98 |
+
- Appears below tabs after classification
|
| 99 |
+
- Two buttons: "Correct" / "Wrong"
|
| 100 |
+
- When "👎 Wrong" is clicked, a dropdown appears to select correct label ("Spam" / "Ham")
|
| 101 |
+
- Counter shows: "X corrections collected"
|
| 102 |
+
|
| 103 |
+
### Feedback storage
|
| 104 |
+
- Logs to `feedback/feedback_log.csv`
|
| 105 |
+
- Columns: `timestamp`, `email_text` (truncated to 500 characters), `predicted_label`, `predicted_confidence`, `feedback` (`correct` or `wrong`), `correct_label` (filled in only when `feedback` is `wrong`), `threshold_used`
|
| 106 |
+
- Directory created at app startup if missing
|
| 107 |
+
- CSV appended to, never overwritten
|
| 108 |
+
|
| 109 |
+
### Batch retrain
|
| 110 |
+
- `retrain.py` script updated to:
|
| 111 |
+
1. Read `feedback/feedback_log.csv`
|
| 112 |
+
2. Filter for "wrong" entries
|
| 113 |
+
3. Convert corrections to training examples (email text + correct label)
|
| 114 |
+
4. Append to existing training data
|
| 115 |
+
5. Retrain VotingClassifier with augmented dataset
|
| 116 |
+
6. Save new model files to `models/`
|
| 117 |
+
- Retrain runs locally, not live on the Space
|
| 118 |
+
- Workflow: pull CSV from Space -> retrain locally -> push updated models back
|
| 119 |
+
|
| 120 |
+
## Files Changed
|
| 121 |
+
|
| 122 |
+
| File | Action | Details |
|
| 123 |
+
|------|--------|---------|
|
| 124 |
+
| `app.py` | Major update | Add ELI5, Compare, Summary, How It Works tabs + feedback UI + threshold slider |
|
| 125 |
+
| `utils.py` | No changes | Shared preprocessing and feature engineering stays as-is |
|
| 126 |
+
| `requirements.txt` | Update | Add `eli5>=0.13.0`, bump `gradio>=4.44.0` |
|
| 127 |
+
| `retrain.py` | Update | Add feedback CSV ingestion for batch retraining |
|
| 128 |
+
| `README.md` | Update | New frontmatter (tags, sdk_version) + updated feature description |
|
| 129 |
+
| `feedback/.gitkeep` | New | Empty directory for feedback log accumulation |
|
| 130 |
+
|
| 131 |
+
## Dependencies
|
| 132 |
+
|
| 133 |
+
### Current (no changes)
|
| 134 |
+
- numpy, pandas, matplotlib, scikit-learn, scipy, nltk, lime, shap, gradio, joblib, tqdm
|
| 135 |
+
|
| 136 |
+
### New
|
| 137 |
+
- `eli5>=0.13.0` — feature weight explanations from model internals
|
| 138 |
+
|
| 139 |
+
### Removed (vs Streamlit app)
|
| 140 |
+
- streamlit, pytesseract, PIL, requests, wordcloud, seaborn — none of these needed
|
| 141 |
+
|
| 142 |
+
## ELI5 Integration Detail
|
| 143 |
+
|
| 144 |
+
The VotingClassifier contains a RandomForestClassifier as `voting_model.named_estimators_['rf']`. ELI5's `explain_prediction()` works with sklearn tree-based models directly. We extract the RF sub-estimator at startup alongside the existing model loading, and pass it to ELI5 and SHAP TreeExplainer where needed.
|
| 145 |
+
|
| 146 |
+
```python
|
| 147 |
+
# Extract RF from voting ensemble for ELI5 and TreeExplainer
|
| 148 |
+
raw_rf = voting_model.named_estimators_['rf']
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
## Deployment
|
| 152 |
+
|
| 153 |
+
- Single Hugging Face Space: `VoltageVagabond/spam-classifier-gradio`
|
| 154 |
+
- SDK: Gradio (free CPU tier — sklearn models are lightweight)
|
| 155 |
+
- The `spam-xai-project` Space can be paused/deleted
|
| 156 |
+
- The `spam-classifier-liquid-space` Space should be deleted (model stays as repo)
|
feedback/.gitkeep
ADDED
|
File without changes
|
feedback/feedback_log.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
timestamp,email_text,predicted_label,predicted_confidence,feedback,correct_label,threshold_used
|
| 2 |
+
2026-03-28T23:22:40.208203,test email,SPAM,0.9500,correct,,0.5000
|
requirements.txt
CHANGED
|
@@ -6,6 +6,7 @@ scipy>=1.11.0
|
|
| 6 |
nltk>=3.8.0
|
| 7 |
lime>=0.2.0
|
| 8 |
shap>=0.44.0
|
| 9 |
-
|
|
|
|
| 10 |
joblib>=1.3.0
|
| 11 |
tqdm>=4.65.0
|
|
|
|
| 6 |
nltk>=3.8.0
|
| 7 |
lime>=0.2.0
|
| 8 |
shap>=0.44.0
|
| 9 |
+
eli5>=0.13.0
|
| 10 |
+
gradio>=4.44.0
|
| 11 |
joblib>=1.3.0
|
| 12 |
tqdm>=4.65.0
|
retrain.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Batch retrain script for the spam classifier
|
| 2 |
+
# Reads feedback corrections from feedback/feedback_log.csv,
|
| 3 |
+
# merges them into the original training data, and retrains
|
| 4 |
+
# the VotingClassifier ensemble.
|
| 5 |
+
#
|
| 6 |
+
# Usage:
|
| 7 |
+
# python3 retrain.py # retrain with feedback
|
| 8 |
+
# python3 retrain.py --no-feedback # retrain with original data only
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
import csv
|
| 12 |
+
import warnings
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from sklearn.model_selection import train_test_split
|
| 18 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 19 |
+
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
|
| 20 |
+
from sklearn.linear_model import LogisticRegression
|
| 21 |
+
from sklearn.svm import LinearSVC
|
| 22 |
+
from sklearn.calibration import CalibratedClassifierCV
|
| 23 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 24 |
+
from sklearn.metrics import classification_report, precision_recall_curve
|
| 25 |
+
from scipy.sparse import hstack, csr_matrix
|
| 26 |
+
import joblib
|
| 27 |
+
|
| 28 |
+
from utils import preprocess_text, compute_metadata_features, META_FEATURE_NAMES
|
| 29 |
+
|
| 30 |
+
warnings.filterwarnings('ignore', category=FutureWarning)
|
| 31 |
+
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
| 32 |
+
|
| 33 |
+
project_dir = Path(__file__).parent
|
| 34 |
+
data_dir = project_dir / 'data'
|
| 35 |
+
models_dir = project_dir / 'models'
|
| 36 |
+
feedback_csv = project_dir / 'feedback' / 'feedback_log.csv'
|
| 37 |
+
random_state = 42
|
| 38 |
+
KAGGLE_CAP = 100_000
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_feedback_corrections():
    """Read the feedback CSV and return user corrections as training rows.

    A row is kept only when its ``feedback`` column is ``wrong`` and its
    ``correct_label`` column is ``spam`` or ``ham`` (case-insensitive,
    surrounding whitespace ignored). Rows with any other label, or with an
    empty ``email_text``, are skipped instead of being silently treated as
    ham.

    Returns:
        pandas.DataFrame with columns ``text`` and ``label`` (1 = spam,
        0 = ham). Both columns are always present, even when no
        corrections were found or the feedback file does not exist.
    """
    columns = ['text', 'label']
    if not feedback_csv.exists():
        print("No feedback file found.")
        return pd.DataFrame(columns=columns)

    label_map = {'spam': 1, 'ham': 0}
    corrections = []
    # newline='' lets the csv module handle quoted fields that contain line
    # breaks, which email bodies routinely do.
    with open(feedback_csv, 'r', encoding='utf-8', newline='') as f:
        for row in csv.DictReader(f):
            if (row.get('feedback') or '').strip().lower() != 'wrong':
                continue
            label = label_map.get((row.get('correct_label') or '').strip().lower())
            text = row.get('email_text') or ''
            if label is None or not text.strip():
                # Unknown label or no text: not usable as a training example.
                continue
            corrections.append({'text': text, 'label': label})

    # Passing columns= keeps the schema stable when the list is empty.
    df = pd.DataFrame(corrections, columns=columns)
    print(f"Found {len(df)} corrections in feedback log.")
    return df
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main():
    """Retrain the spam classifier and write all artifacts to ``models_dir``.

    Steps performed:
      1. Load the Kaggle CSV and the per-file GitHub corpus found under
         ``data_dir`` (either may be absent).
      2. Unless ``--no-feedback`` is on the command line, append the
         corrections returned by ``load_feedback_corrections``.
      3. Preprocess, split 80/20 (stratified), build TF-IDF plus scaled
         metadata features, and fit a soft-voting ensemble.
      4. Print a classification report, choose the probability threshold
         that maximises F1 on the test split, and dump model, vectorizer,
         scaler, feature names, threshold and a training sample with joblib.

    Exits with status 1 when no usable training data is found.
    """
    use_feedback = '--no-feedback' not in sys.argv

    print("Loading training data...")
    kaggle_path = data_dir / 'spam_Emails_data.csv'
    github_dir = data_dir / 'email-dataset-main' / 'email-dataset-main'

    frames = []

    if kaggle_path.exists():
        kaggle_df = pd.read_csv(kaggle_path)
        if 'label' in kaggle_df.columns and 'text' in kaggle_df.columns:
            kaggle_df['label'] = kaggle_df['label'].map({'spam': 1, 'ham': 0})
            kaggle_df = kaggle_df.dropna(subset=['label', 'text'])
        elif 'v1' in kaggle_df.columns and 'v2' in kaggle_df.columns:
            kaggle_df = kaggle_df.rename(columns={'v1': 'label_str', 'v2': 'text'})
            kaggle_df['label'] = kaggle_df['label_str'].map({'spam': 1, 'ham': 0})
            kaggle_df = kaggle_df[['text', 'label']].dropna()
        else:
            # Without this guard the code below fails with an opaque KeyError.
            print(f"ERROR: Unrecognised columns in {kaggle_path}: "
                  f"{list(kaggle_df.columns)}")
            sys.exit(1)

        if len(kaggle_df) > KAGGLE_CAP:
            # Down-sample to at most KAGGLE_CAP // 2 rows per class.
            kaggle_df = kaggle_df.groupby('label', group_keys=False).apply(
                lambda x: x.sample(min(len(x), KAGGLE_CAP // 2), random_state=random_state)
            )
        frames.append(kaggle_df[['text', 'label']])
        print(f"  Kaggle: {len(kaggle_df)} emails")

    if github_dir.exists():
        # Collect plain rows and build a single DataFrame afterwards instead
        # of allocating one single-row DataFrame per email file.
        github_rows = []
        for label_dir in github_dir.iterdir():
            if not (label_dir.is_dir() and label_dir.name in ('spam', 'ham')):
                continue
            lbl = 1 if label_dir.name == 'spam' else 0
            for f in label_dir.iterdir():
                if not f.is_file():
                    continue
                try:
                    text = f.read_text(encoding='utf-8', errors='ignore')
                except OSError as exc:
                    # Best effort: skip unreadable files, but say so.
                    print(f"  Skipping unreadable file {f}: {exc}")
                    continue
                github_rows.append({'text': text, 'label': lbl})
        if github_rows:
            frames.append(pd.DataFrame(github_rows))

    if not frames:
        print("ERROR: No training data found in data/ directory.")
        sys.exit(1)

    df = pd.concat(frames, ignore_index=True)
    print(f"  Total original: {len(df)} emails")

    n_corrections = 0
    if use_feedback:
        feedback_df = load_feedback_corrections()
        n_corrections = len(feedback_df)
        if n_corrections > 0:
            df = pd.concat([df, feedback_df], ignore_index=True)
            print(f"  After feedback merge: {len(df)} emails")

    print("Preprocessing...")
    df['clean'] = df['text'].apply(preprocess_text)
    df = df[df['clean'].str.len() > 0]

    X_text = df['clean'].values
    X_orig = df['text'].values
    y = df['label'].values

    # Split the raw text together with the cleaned text so that the metadata
    # features computed from the raw text stay row-aligned with the TF-IDF
    # rows and the labels. Looking the raw text up again via isin() returns
    # rows in DataFrame order (not split order) and mis-selects duplicates.
    (X_train_text, X_test_text,
     train_orig, test_orig,
     y_train, y_test) = train_test_split(
        X_text, X_orig, y,
        test_size=0.2, random_state=random_state, stratify=y
    )

    print("Fitting TF-IDF...")
    tfidf = TfidfVectorizer(max_features=3000, ngram_range=(1, 3),
                            min_df=2, max_df=0.95)
    X_train_tfidf = tfidf.fit_transform(X_train_text)
    X_test_tfidf = tfidf.transform(X_test_text)

    print("Computing metadata features...")
    X_train_meta = compute_metadata_features(train_orig.tolist())
    X_test_meta = compute_metadata_features(test_orig.tolist())

    # The scaler is fitted on the training split only.
    scaler = MinMaxScaler()
    X_train_meta_scaled = scaler.fit_transform(X_train_meta)
    X_test_meta_scaled = scaler.transform(X_test_meta)

    X_train = hstack([X_train_tfidf, csr_matrix(X_train_meta_scaled)])
    X_test = hstack([X_test_tfidf, csr_matrix(X_test_meta_scaled)])

    feature_names_list = tfidf.get_feature_names_out().tolist() + META_FEATURE_NAMES

    print("Training VotingClassifier ensemble...")
    ensemble = VotingClassifier(
        estimators=[
            ('rf', RandomForestClassifier(
                n_estimators=200, n_jobs=-1,
                class_weight='balanced', random_state=random_state)),
            ('lr', LogisticRegression(
                max_iter=1000, class_weight='balanced', random_state=random_state)),
            # LinearSVC has no predict_proba; the calibration wrapper supplies
            # the probabilities that soft voting needs.
            ('svm', CalibratedClassifierCV(
                LinearSVC(class_weight='balanced', max_iter=2000,
                          random_state=random_state))),
        ],
        voting='soft',
    )
    ensemble.fit(X_train, y_train)

    y_pred = ensemble.predict(X_test)
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=['Ham', 'Spam']))

    y_scores = ensemble.predict_proba(X_test)[:, 1]
    precisions, recalls, thresholds = precision_recall_curve(y_test, y_scores)
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
    # precisions/recalls carry one trailing entry with no matching threshold,
    # so only the entries that have a threshold are candidates.
    best_idx = np.argmax(f1_scores[:-1])
    optimal_threshold = float(thresholds[best_idx])
    print(f"Optimal threshold: {optimal_threshold:.4f}")

    models_dir.mkdir(exist_ok=True)
    joblib.dump(ensemble, models_dir / 'voting_model.joblib')
    joblib.dump(tfidf, models_dir / 'tfidf_vectorizer.joblib')
    joblib.dump(scaler, models_dir / 'meta_scaler.joblib')
    joblib.dump(feature_names_list, models_dir / 'feature_names.joblib')
    joblib.dump(optimal_threshold, models_dir / 'optimal_threshold.joblib')

    # Save a small dense sample (up to 200 rows) of the training matrix.
    sample_size = min(200, X_train.shape[0])
    sample_idx = np.random.RandomState(random_state).choice(
        X_train.shape[0], sample_size, replace=False)
    training_sample = X_train.tocsr()[sample_idx].toarray()
    joblib.dump(training_sample, models_dir / 'training_sample.joblib')

    print(f"\nAll models saved to {models_dir}/")
    if use_feedback:
        # Reuse the count from the merge step instead of re-reading the CSV.
        print(f"Feedback corrections incorporated: {n_corrections}")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if __name__ == '__main__':
|
| 196 |
+
main()
|