# Gradio web app for the Spam Email Classifier with XAI explanations
# University course project — Explainable AI for spam detection
# Features: LIME, SHAP, ELI5, side-by-side comparison, plain English summary,
# user feedback logging, and batch retrain support.

import csv
import os
from datetime import datetime
from pathlib import Path

import nltk

nltk.download('stopwords', quiet=True)

import eli5
import gradio as gr
import lime
import lime.lime_tabular
import matplotlib

# Select the non-interactive backend before pyplot is imported: figures are
# only rendered to images for the web UI.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import shap
import joblib
from scipy.sparse import hstack, csr_matrix

from utils import (preprocess_text, compute_metadata_features,
                   META_FEATURE_NAMES, FEATURE_DESCRIPTIONS)

# ---------------------------------------------------------------------------
# 1. Model Loading
# ---------------------------------------------------------------------------
_APP_DIR = Path(__file__).parent
models_dir = _APP_DIR / 'models'
feedback_dir = _APP_DIR / 'feedback'
feedback_dir.mkdir(exist_ok=True)
FEEDBACK_CSV = feedback_dir / 'feedback_log.csv'


def _load_artifact(stem):
    """Load ``models_dir / '<stem>.joblib'`` (FileNotFoundError if absent)."""
    return joblib.load(models_dir / f'{stem}.joblib')


try:
    voting_model = _load_artifact('voting_model')
    tfidf_vectorizer = _load_artifact('tfidf_vectorizer')
    meta_scaler = _load_artifact('meta_scaler')
    feature_names = _load_artifact('feature_names')
    optimal_threshold = _load_artifact('optimal_threshold')
    training_sample = _load_artifact('training_sample')
    # The fitted Random Forest member of the ensemble, explained by ELI5.
    raw_rf = voting_model.named_estimators_['rf']
    print(f"All models loaded. Threshold = {optimal_threshold:.4f}")
except FileNotFoundError as e:
    print(f"Model file not found: {e}")
    # Every artifact falls back to None; the UI checks this and tells the
    # user to train first.
    voting_model = tfidf_vectorizer = meta_scaler = None
    feature_names = optimal_threshold = training_sample = None
    raw_rf = None
# ---------------------------------------------------------------------------
# 2. LIME Explainer Setup
# ---------------------------------------------------------------------------
lime_explainer = None
if training_sample is not None and feature_names is not None:
    # LIME draws its perturbations from the statistics of this sample, so it
    # lives in the same combined (TF-IDF + scaled metadata) feature space
    # that classify_email() produces.
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=training_sample,
        feature_names=feature_names,
        class_names=['Ham', 'Spam'],
        mode='classification',
    )
    print("LIME explainer ready.")


# ---------------------------------------------------------------------------
# 3. classify_email
# ---------------------------------------------------------------------------
def classify_email(email_text, threshold):
    """Classify a single email.

    Args:
        email_text: Raw email text as entered by the user.
        threshold: Spam probability at or above which the email is SPAM.

    Returns:
        ``(label, confidence, spam_proba, combined_features)``: label is
        "SPAM" or "HAM (Not Spam)", confidence is the probability of that
        label, spam_proba is the model's class-1 probability, and
        combined_features is the 1-row sparse TF-IDF + metadata matrix.
    """
    cleaned_text = preprocess_text(email_text)
    tfidf_features = tfidf_vectorizer.transform([cleaned_text])

    # Metadata features are computed from the *raw* text (presumably because
    # preprocessing would remove the signals they measure — verify in utils)
    # and scaled with the scaler fitted at training time.
    meta_raw = compute_metadata_features([email_text])
    meta_scaled = meta_scaler.transform(meta_raw)
    combined = hstack([tfidf_features, csr_matrix(meta_scaled)])

    spam_proba = voting_model.predict_proba(combined)[0][1]
    if spam_proba >= threshold:
        label = "SPAM"
        confidence = spam_proba
    else:
        label = "HAM (Not Spam)"
        confidence = 1.0 - spam_proba
    return label, confidence, spam_proba, combined


# ---------------------------------------------------------------------------
# 4. LIME explanation
# ---------------------------------------------------------------------------
def generate_lime_explanation(combined_features):
    """Generate a LIME explanation for the spam class.

    Args:
        combined_features: 1-row sparse matrix from classify_email().

    Returns:
        ``(figure, explanation)``, or ``(None, None)`` when no LIME explainer
        could be built (models not loaded).
    """
    if lime_explainer is None:
        return None, None

    # LimeTabularExplainer works on a dense 1-D row.
    instance = combined_features.toarray()[0]
    explanation = lime_explainer.explain_instance(
        instance,
        voting_model.predict_proba,
        num_features=10,
    )
    fig = explanation.as_pyplot_figure()
    fig.tight_layout()
    # as_pyplot_figure() registers the figure with pyplot. Detach it so every
    # request does not leave another open figure behind; the Figure object
    # itself remains renderable by gr.Plot.
    plt.close(fig)
    return fig, explanation
# ---------------------------------------------------------------------------
# 5. SHAP explanation (metadata features only — fast)
# ---------------------------------------------------------------------------
def _spam_shap_vector(shap_values, num_meta):
    """Reduce KernelExplainer output to a flat vector of spam-class values.

    Older SHAP releases return one array per class (a list). Newer releases
    return a single array with a trailing class axis, e.g.
    ``(n_samples, n_features, n_classes)``. Flattening that 3-D array
    directly would interleave Ham and Spam values, so the spam (class 1)
    slice is selected first.

    Args:
        shap_values: Return value of ``KernelExplainer.shap_values`` for one row.
        num_meta: Number of metadata features that were explained.

    Returns:
        1-D ``np.ndarray`` with (at most) ``num_meta`` values.
    """
    if isinstance(shap_values, list):
        values = np.asarray(shap_values[1])
    else:
        values = np.asarray(shap_values)
        if values.ndim == 3:
            values = values[..., 1]
    values = values.reshape(-1)
    if values.size > num_meta:
        values = values[-num_meta:]
    return values


def generate_shap_explanation(email_text):
    """Generate a SHAP bar chart for the metadata features of one email.

    TF-IDF columns are held at zero so that KernelExplainer only has to
    attribute the prediction across the metadata features (fast).

    Args:
        email_text: Raw email text.

    Returns:
        ``(figure, shap_values, top_indices)`` where shap_values is a 1-D
        array aligned with META_FEATURE_NAMES and top_indices are the (up to)
        10 indices with the largest |SHAP value|, or ``(None, None, None)``
        when the models are not loaded.
    """
    if training_sample is None or voting_model is None:
        return None, None, None

    num_meta = len(META_FEATURE_NAMES)
    # Metadata columns are the last num_meta columns (same hstack order as
    # classify_email()).
    num_tfidf = training_sample.shape[1] - num_meta
    background_meta = training_sample[:50, -num_meta:]

    meta_raw = compute_metadata_features([email_text])
    meta_scaled = meta_scaler.transform(meta_raw)

    def predict_with_meta_only(meta_features):
        n_samples = meta_features.shape[0]
        tfidf_zeros = csr_matrix((n_samples, num_tfidf))
        combined = hstack([tfidf_zeros, csr_matrix(meta_features)])
        return voting_model.predict_proba(combined)

    explainer = shap.KernelExplainer(predict_with_meta_only, background_meta)
    shap_values = explainer.shap_values(meta_scaled, nsamples=100)
    sv = _spam_shap_vector(shap_values, num_meta)

    top_idx = np.argsort(np.abs(sv))[::-1][:10]

    # barh draws bottom-up, so ascending |value| puts the largest bar on top.
    sorted_indices = np.argsort(np.abs(sv))
    sorted_names = [META_FEATURE_NAMES[idx] for idx in sorted_indices.tolist()]
    sorted_values = sv[sorted_indices]

    fig, ax = plt.subplots(figsize=(8, 6))
    colors = ['#d62728' if val > 0 else '#1f77b4' for val in sorted_values]
    ax.barh(sorted_names, sorted_values, color=colors)
    ax.set_xlabel('SHAP Value (impact on spam probability)')
    ax.set_title('SHAP Feature Importance (Metadata Features)')
    ax.axvline(x=0, color='black', linewidth=0.5)
    fig.tight_layout()
    # Detach from pyplot so figures do not accumulate across requests.
    plt.close(fig)
    return fig, sv, top_idx


# ---------------------------------------------------------------------------
# 6. ELI5 explanation
# ---------------------------------------------------------------------------
def generate_eli5_explanation(combined_features):
    """Generate ELI5 HTML and the top feature names for the Random Forest.

    Args:
        combined_features: 1-row sparse matrix from classify_email().

    Returns:
        ``(html_string, feature_names_list)`` where the list holds up to five
        feature names ordered by |contribution|, or ``(None, None)`` when the
        Random Forest / feature names are not loaded.
    """
    if raw_rf is None or feature_names is None:
        return None, None

    instance = combined_features.toarray()[0]
    eli5_exp = eli5.explain_prediction(raw_rf, instance,
                                       feature_names=feature_names, top=10)
    html = eli5.format_as_html(eli5_exp)

    # ELI5 selects its `top` features by |weight|, so the five largest of
    # this top-10 explanation are what a separate top=5 call would return;
    # reuse them instead of explaining the prediction a second time.
    top_names = []
    targets = getattr(eli5_exp, 'targets', None)
    if targets:
        weights = targets[0].feature_weights
        if weights is not None:
            ranked = sorted(list(weights.pos) + list(weights.neg),
                            key=lambda fw: abs(fw.weight), reverse=True)
            top_names = [fw.feature for fw in ranked[:5]]
    return html, top_names


# ---------------------------------------------------------------------------
# 7. Plain English summary (replaces Ollama LLM)
# ---------------------------------------------------------------------------
def _lime_feature_name(rule):
    """Return the metadata feature a LIME rule refers to, else the rule itself.

    With its default continuous-feature discretisation LIME reports rules
    such as ``"0.00 < some_feature <= 1.00"`` rather than bare names, so the
    name must be extracted before comparing with SHAP's metadata names.
    """
    for token in rule.split():
        if token in META_FEATURE_NAMES:
            return token
    return rule


def _lime_shap_overlap(lime_explanation, shap_top_idx, k=10):
    """Return the metadata feature names in both LIME's and SHAP's top k."""
    lime_names = {_lime_feature_name(rule)
                  for rule, _ in lime_explanation.as_list()[:k]}
    shap_names = {META_FEATURE_NAMES[i] for i in shap_top_idx[:k]}
    return lime_names & shap_names


def _confidence_sentence(label, confidence):
    """Return the closing sentence describing how sure the model is."""
    if "SPAM" in label:
        if confidence > 0.9:
            return ("The model is highly confident this email contains patterns "
                    "commonly seen in spam or phishing attempts.")
        if confidence > 0.7:
            return "The model found several spam-like patterns in this email."
        return ("The model leans toward spam, but the evidence is not "
                "overwhelming. Use your judgment.")
    if confidence > 0.9:
        return "The model is highly confident this is a legitimate email."
    if confidence > 0.7:
        return ("The model found this email to be mostly consistent with "
                "legitimate messages.")
    return ("The model leans toward legitimate, but there are some spam-like "
            "features. Review carefully.")


def generate_plain_summary(label, confidence, spam_proba, lime_explanation,
                           shap_sv, shap_top_idx):
    """Build a rule-based plain English summary from XAI results.

    Args:
        label: "SPAM" or "HAM (Not Spam)" from classify_email().
        confidence: Probability of the predicted label.
        spam_proba: Spam probability (currently unused; kept for callers).
        lime_explanation: LIME Explanation, or None.
        shap_sv: SHAP values aligned with META_FEATURE_NAMES, or None.
        shap_top_idx: Indices into META_FEATURE_NAMES by |SHAP value|, or None.

    Returns:
        Markdown string.
    """
    summary = f"### Classification: **{label}** ({confidence:.0%} confidence)\n\n"

    if lime_explanation is not None:
        summary += "**Key words driving this decision (LIME):**\n"
        for feat_rule, weight in lime_explanation.as_list()[:3]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            summary += f"- **{feat_rule}** — {direction}\n"
        summary += "\n"

    if shap_sv is not None and shap_top_idx is not None:
        summary += "**Important email characteristics (SHAP):**\n"
        for i in shap_top_idx[:2]:
            feat_name = META_FEATURE_NAMES[i]
            description = FEATURE_DESCRIPTIONS.get(feat_name, feat_name)
            direction = "spam signal" if shap_sv[i] > 0 else "ham signal"
            summary += f"- **{feat_name}** ({description}) — {direction}\n"
        summary += "\n"

    if lime_explanation is not None and shap_top_idx is not None:
        overlap = _lime_shap_overlap(lime_explanation, shap_top_idx)
        if overlap:
            summary += (f"**Method agreement:** LIME and SHAP both flag: "
                        f"{', '.join(sorted(overlap))}\n\n")

    summary += _confidence_sentence(label, confidence)
    return summary


# ---------------------------------------------------------------------------
# 8. Side-by-side comparison
# ---------------------------------------------------------------------------
def _md_cell(text):
    """Escape text for a markdown table cell.

    ELI5 reports a ``<BIAS>`` feature that would otherwise be treated as an
    HTML tag and vanish, and a literal ``|`` would split the cell.
    """
    return (str(text).replace('&', '&amp;').replace('<', '&lt;')
            .replace('>', '&gt;').replace('|', '\\|'))


def generate_comparison(lime_explanation, shap_sv, shap_top_idx, eli5_names):
    """Build a markdown comparison of the top features from each XAI method.

    Args:
        lime_explanation: LIME Explanation, or None.
        shap_sv: SHAP values aligned with META_FEATURE_NAMES, or None.
        shap_top_idx: Indices into META_FEATURE_NAMES by |SHAP value|, or None.
        eli5_names: Top ELI5 feature names, or None.

    Returns:
        Markdown string with a 5-row table and an agreement note.
    """
    md = "### Side-by-Side: Top Features by Method\n\n"
    md += "| Rank | LIME | SHAP (metadata) | ELI5 |\n"
    md += "|------|------|-----------------|------|\n"

    lime_top5 = []
    if lime_explanation is not None:
        for feat, w in lime_explanation.as_list()[:5]:
            direction = "spam" if w > 0 else "ham"
            lime_top5.append(f"{feat} ({direction}, {w:+.3f})")

    shap_top5 = []
    if shap_sv is not None and shap_top_idx is not None:
        for i in shap_top_idx[:5]:
            direction = "spam" if shap_sv[i] > 0 else "ham"
            shap_top5.append(f"{META_FEATURE_NAMES[i]} ({direction}, {shap_sv[i]:+.3f})")

    eli5_top5 = (eli5_names or [])[:5]

    for rank in range(5):
        lime_cell = _md_cell(lime_top5[rank]) if rank < len(lime_top5) else "—"
        shap_cell = _md_cell(shap_top5[rank]) if rank < len(shap_top5) else "—"
        eli5_cell = _md_cell(eli5_top5[rank]) if rank < len(eli5_top5) else "—"
        md += f"| {rank+1} | {lime_cell} | {shap_cell} | {eli5_cell} |\n"

    if lime_explanation is not None and shap_top_idx is not None:
        overlap = _lime_shap_overlap(lime_explanation, shap_top_idx)
        md += f"\n**LIME-SHAP agreement** (top 10): **{len(overlap)}** shared features"
        if overlap:
            md += f"\nShared: {', '.join(sorted(overlap))}"

    md += ("\n\n*Note: LIME covers all features (words + metadata), SHAP covers "
           f"only the {len(META_FEATURE_NAMES)} metadata features, ")
    md += "ELI5 uses the Random Forest sub-estimator's internal weights.*"
    return md


# ---------------------------------------------------------------------------
# 9. Feedback logging
# ---------------------------------------------------------------------------
def log_feedback(email_text, predicted_label, predicted_confidence, threshold,
                 feedback_type, correct_label=None):
    """Append one feedback row to the CSV log.

    Args:
        email_text: Email text (only the first 500 characters are stored).
        predicted_label: Label shown to the user.
        predicted_confidence: Confidence of that label.
        threshold: Threshold used for the classification.
        feedback_type: 'correct' or 'wrong'.
        correct_label: User-supplied true label, if any.

    Returns:
        Number of 'wrong' entries now in the log (see count_corrections).
    """
    # Write the header for a new file and also for an existing empty one.
    write_header = (not FEEDBACK_CSV.exists()
                    or FEEDBACK_CSV.stat().st_size == 0)
    with open(FEEDBACK_CSV, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['timestamp', 'email_text', 'predicted_label',
                             'predicted_confidence', 'feedback',
                             'correct_label', 'threshold_used'])
        writer.writerow([
            datetime.now().isoformat(),
            email_text[:500],
            predicted_label,
            f"{predicted_confidence:.4f}",
            feedback_type,
            correct_label or '',
            f"{threshold:.4f}",
        ])
    return count_corrections()


def count_corrections():
    """Count the number of 'wrong' entries in the feedback log (0 if absent)."""
    if not FEEDBACK_CSV.exists():
        return 0
    with open(FEEDBACK_CSV, 'r', encoding='utf-8') as f:
        return sum(1 for row in csv.DictReader(f) if row.get('feedback') == 'wrong')


# ---------------------------------------------------------------------------
# 10. Example Emails
# ---------------------------------------------------------------------------
EXAMPLE_EMAILS = [
    ["Subject: URGENT - You Have Won $5,000,000!!!\n\nDear Friend,\n\nCONGRATULATIONS!!! You have been selected as the winner of our international lottery program!!!\nTo claim your $5,000,000 USD prize, click the link below IMMEDIATELY and provide your bank details.\n\nACT NOW - This offer expires in 24 hours!!!\n\nClick here: http://totally-legit-prize.com/claim\nSend $500 processing fee to unlock your winnings.\n\nBest regards,\nDr. Prince Mohammed"],
    ["Subject: Team sync Thursday 2pm\n\nHi everyone,\n\nJust a reminder that we have our weekly team sync this Thursday at 2pm in Conference Room B.\n\nAgenda:\n- Sprint review\n- Q2 planning discussion\n- New hire onboarding update\n\nPlease come prepared with your status updates.\n\nThanks,\nSarah"],
    ["Subject: Your account has been compromised!\n\nDear Customer,\n\nWe detected suspicious activity on your account. Click here immediately to verify your identity: http://secure-bank-login.com/verify\n\nIf you do not verify within 24 hours, your account will be permanently locked.\n\nSecurity Team"],
    ["Subject: Thanksgiving dinner plans\n\nHi everyone!\n\nI wanted to start planning for Thanksgiving dinner. I'm thinking we could do it at my place this year. What does everyone think about 4pm?\n\nLet me know if you have any dietary restrictions or if you want to bring a dish.\n\nLove,\nMom"],
    ["Subject: Best prices on V1AGRA and C1ALIS!!!\n\n$$$ SAVE BIG $$$\nBuy now and get 80% OFF!!!\nNo prescription needed! Free shipping!\nOrder at http://cheap-pharma-deals.com\n\nLIMITED TIME OFFER - ACT NOW!"],
]


# ---------------------------------------------------------------------------
# 11. Main orchestration function
# ---------------------------------------------------------------------------
_STATE_SEP = '|||'


def _empty_outputs(result_message, eli5_message=""):
    """Return the 7 UI outputs used when no classification can be produced."""
    return (result_message, None, None, eli5_message, "", "", "")


def _read_uploaded_file(uploaded_file):
    """Read an uploaded .txt file as UTF-8 text.

    Gradio 4 passes a file path string; Gradio 3.x passes a tempfile wrapper
    whose ``.name`` is the path. Both are accepted.
    """
    if isinstance(uploaded_file, (str, os.PathLike)):
        path = uploaded_file
    else:
        path = getattr(uploaded_file, 'name', uploaded_file)
    return Path(path).read_text(encoding='utf-8')


def classify_and_explain(email_text, uploaded_file, threshold):
    """Main function called by Gradio.

    An uploaded file, when present, takes precedence over the text box.

    Returns:
        7-tuple: result markdown, LIME figure, SHAP figure, ELI5 HTML,
        comparison markdown, summary markdown, and the encoded hidden state
        ("label|||confidence|||threshold|||email[:500]") used by the
        feedback handlers.
    """
    if uploaded_file is not None:
        try:
            email_text = _read_uploaded_file(uploaded_file)
        except (OSError, TypeError, ValueError) as e:
            # ValueError includes UnicodeDecodeError for non-UTF-8 uploads.
            print(f"File read error: {e}")
            return _empty_outputs("Could not read file.", "Error reading file.")

    if not email_text or not email_text.strip():
        return _empty_outputs("Please enter email text or upload a file.")

    if voting_model is None:
        return _empty_outputs("Models not found. Run `python3 train.py` first.")

    label, confidence, spam_proba, combined = classify_email(email_text, threshold)

    # Each explainer is best-effort: a failure in one must not hide the
    # classification result or the other explanations.
    try:
        lime_fig, lime_exp = generate_lime_explanation(combined)
    except Exception as e:
        print(f"LIME error: {e}")
        lime_fig, lime_exp = None, None

    try:
        shap_fig, shap_sv, shap_top_idx = generate_shap_explanation(email_text)
    except Exception as e:
        print(f"SHAP error: {e}")
        shap_fig, shap_sv, shap_top_idx = None, None, None

    try:
        eli5_html, eli5_names = generate_eli5_explanation(combined)
    except Exception as e:
        print(f"ELI5 error: {e}")
        eli5_html, eli5_names = None, None

    # classify_email() only ever returns "SPAM" or "HAM (Not Spam)".
    result_md = f"## {label}\n\n"
    result_md += f"**Confidence:** {confidence:.1%}\n\n"
    result_md += f"**Threshold:** {threshold:.0%}\n\n"
    result_md += f"**Spam probability:** {spam_proba:.1%}\n\n"
    if lime_exp is not None:
        result_md += "**Key factors:**\n"
        for feat_rule, weight in lime_exp.as_list()[:5]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            result_md += f"- **{feat_rule}** {direction}\n"

    comparison_md = generate_comparison(lime_exp, shap_sv, shap_top_idx, eli5_names)
    summary_md = generate_plain_summary(label, confidence, spam_proba,
                                        lime_exp, shap_sv, shap_top_idx)
    eli5_display = eli5_html or "<p>ELI5 explanation not available.</p>"

    state = _STATE_SEP.join([label, f"{confidence:.4f}", f"{threshold:.4f}",
                             email_text[:500]])
    return (result_md, lime_fig, shap_fig, eli5_display,
            comparison_md, summary_md, state)


# ---------------------------------------------------------------------------
# 12. Feedback handlers
# ---------------------------------------------------------------------------
def _parse_state(hidden_state):
    """Decode the hidden state into (label, confidence, threshold, email).

    Returns None when there is no usable classification. maxsplit=3 keeps
    any separator that occurs inside the (last) email field intact.
    """
    if not hidden_state:
        return None
    parts = hidden_state.split(_STATE_SEP, 3)
    if len(parts) < 4:
        return None
    label, conf, thresh, email = parts
    try:
        return label, float(conf), float(thresh), email
    except ValueError:
        return None


def handle_correct(hidden_state):
    """Log positive feedback for the last classification."""
    state = _parse_state(hidden_state)
    if state is None:
        return "No classification to give feedback on."
    label, conf, thresh, email = state
    corrections = log_feedback(email, label, conf, thresh, 'correct')
    return f"Thanks for the feedback! ({corrections} corrections collected so far)"


def handle_wrong(hidden_state, correct_label):
    """Log negative feedback with the user's correction.

    If no label was chosen in the dropdown, the opposite of the predicted
    class is recorded (the task is binary).
    """
    state = _parse_state(hidden_state)
    if state is None:
        return "No classification to give feedback on."
    label, conf, thresh, email = state
    if not correct_label:
        correct_label = "Ham" if "SPAM" in label else "Spam"
    corrections = log_feedback(email, label, conf, thresh, 'wrong', correct_label)
    return f"Correction logged! ({corrections} corrections collected so far)"


# ---------------------------------------------------------------------------
# 13. Gradio Blocks UI
# ---------------------------------------------------------------------------
HOW_IT_WORKS_MD = f"""
## How This App Works

### What is spam classification?
Spam classification automatically identifies unwanted or malicious emails (spam) vs. legitimate messages (ham). This helps protect users from phishing scams, fraudulent offers, and unwanted advertising.

### The Model
This app uses a **Voting Ensemble** — three different machine learning models that each "vote" on whether an email is spam:
- **Random Forest** — builds many decision trees and takes the majority vote
- **Logistic Regression** — finds a mathematical boundary between spam and ham
- **Support Vector Machine (SVM)** — finds the widest possible margin between classes

Combining all three usually makes the ensemble more robust than any single model alone.

### Feature Extraction
The model looks at two types of features:
- **TF-IDF (Term Frequency-Inverse Document Frequency)** — measures how important each word is. Common spam words like "prize" or "click" get high scores.
- **{len(META_FEATURE_NAMES)} Metadata Features** — structural patterns like exclamation mark density, dollar sign count, ALL CAPS ratio, URL count, and more.

### Explainable AI (XAI) Methods
This app doesn't just classify — it explains **why**:
- **LIME** — Generates many slightly perturbed versions of the email's features, watches how the prediction changes, and fits a simple local model. Shows which words and characteristics matter most for this email.
- **SHAP** — Uses Shapley values from cooperative game theory to calculate each feature's "fair share" of the prediction. Here it is applied to the metadata features only.
- **ELI5** — Inspects the Random Forest member of the ensemble to show which features contributed most to its prediction.

### Feedback & Retraining
When you click "Correct" or "Wrong", your feedback is saved. After enough corrections accumulate, the model can be retrained with the new examples to improve over time. This is called **human-in-the-loop machine learning**.

### Disclaimer
This model was created as a university course project. It is intended for **educational and research purposes only** and should not be used as a sole spam filter in production. Always use established email security tools for real-world spam filtering.
"""

theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="red",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("IBM Plex Mono"),
)

custom_css = """
/* ── Container ── */
.gradio-container {
    max-width: 1180px !important;
    margin: 0 auto !important;
    padding: 1.5rem 2rem !important;
}

/* ── Top bar ── */
.topbar {
    background: linear-gradient(135deg, #f8fafc 0%, #eef2ff 100%);
    border: 1px solid #e2e8f0;
    border-radius: 14px;
    padding: 1.4rem 1.8rem 1.2rem;
    margin-bottom: 1.2rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.06);
    text-align: center;
}
.topbar-title {
    font-size: 22px;
    font-weight: 700;
    color: #1e293b;
    margin: 0 0 0.3rem;
}
.topbar-subtitle {
    font-size: 13px;
    color: #64748b;
    margin: 0 0 0.7rem;
}
.topbar-badges {
    display: flex;
    justify-content: center;
    gap: 0.5rem;
    flex-wrap: wrap;
}
.topbar-badge {
    display: inline-block;
    background: #e0e7ff;
    color: #3730a3;
    font-size: 11.5px;
    font-weight: 600;
    padding: 0.25rem 0.7rem;
    border-radius: 999px;
    letter-spacing: 0.02em;
}

/* ── Input panel (left column) ── */
.input-panel {
    background: linear-gradient(180deg, #ffffff 0%, #f8fafc 100%);
    border: 1px solid #e2e8f0;
    border-radius: 14px;
    padding: 1.2rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.04);
}

/* ── Output panel (right column) ── */
.output-panel {
    background: linear-gradient(180deg, #ffffff 0%, #f8fafc 100%);
    border: 1px solid #e2e8f0;
    border-radius: 14px;
    padding: 1.2rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.04);
}
.output-panel .plot-container {
    max-height: 420px;
    overflow-y: auto;
}

/* ── Feedback card ── */
.feedback-card {
    background: linear-gradient(135deg, #f8fafc 0%, #f1f5f9 100%);
    border: 1px solid #e2e8f0;
    border-radius: 14px;
    padding: 1rem 1.4rem;
    margin-top: 1rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.04);
}

/* ── Classify button ── */
.classify-btn button {
    border-radius: 10px !important;
}

/* ── Responsive ── */
@media (max-width: 980px) {
    .gradio-container { padding: 1rem !important; }
    .topbar { padding: 1rem 1.2rem; }
    .input-panel, .output-panel { min-width: 0 !important; }
}
"""

# NOTE(review): the visible topbar string was blank although the .topbar*
# CSS classes above exist; this markup uses those classes and the app's own
# title/feature list.
TOPBAR_HTML = """
<div class="topbar">
  <div class="topbar-title">Spam Email Classifier with XAI</div>
  <div class="topbar-subtitle">University course project — Explainable AI for spam detection</div>
  <div class="topbar-badges">
    <span class="topbar-badge">Voting Ensemble</span>
    <span class="topbar-badge">LIME</span>
    <span class="topbar-badge">SHAP</span>
    <span class="topbar-badge">ELI5</span>
  </div>
</div>
"""

# A saved threshold of 0.0 is still a valid value, so test for None
# rather than truthiness.
_default_threshold = float(optimal_threshold) if optimal_threshold is not None else 0.5

with gr.Blocks(title="Spam Email Classifier with XAI", theme=theme, css=custom_css) as demo:
    gr.HTML(TOPBAR_HTML)
    # Encoded last classification, consumed by the feedback handlers.
    hidden_state = gr.State("")

    with gr.Row(equal_height=False):
        with gr.Column(scale=2, min_width=360, elem_classes="input-panel"):
            email_input = gr.Textbox(
                label="Email Text",
                placeholder="Paste your email here...",
                lines=8,
                autoscroll=False,
            )
            file_input = gr.File(
                label="Or upload a .txt file",
                file_types=['.txt'],
            )
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=_default_threshold,
                label="Classification Threshold",
                info="Emails with spam probability at or above this are classified as spam.",
            )
            classify_btn = gr.Button("Classify", variant="primary", size="lg",
                                     elem_classes="classify-btn")
            with gr.Accordion("Example Emails", open=False):
                gr.Examples(
                    examples=EXAMPLE_EMAILS,
                    inputs=[email_input],
                    label="Click to load an example",
                    cache_examples=False,
                )

        with gr.Column(scale=3, min_width=480, elem_classes="output-panel"):
            with gr.Tabs():
                with gr.Tab("Result"):
                    result_output = gr.Markdown(label="Classification Result")
                with gr.Tab("LIME"):
                    gr.Markdown("*LIME perturbs the input and fits a local model "
                                "to see which features matter most.*")
                    lime_output = gr.Plot(label="LIME Explanation")
                with gr.Tab("SHAP"):
                    gr.Markdown("*SHAP uses game theory to assign each feature "
                                "a contribution value.*")
                    shap_output = gr.Plot(label="SHAP Explanation")
                with gr.Tab("ELI5"):
                    gr.Markdown("*ELI5 shows feature weights directly from the "
                                "model's internals.*")
                    eli5_output = gr.HTML(label="ELI5 Explanation")
                with gr.Tab("Compare"):
                    compare_output = gr.Markdown(label="Method Comparison")
                with gr.Tab("Summary"):
                    summary_output = gr.Markdown(label="Plain English Summary")
                with gr.Tab("How It Works"):
                    gr.Markdown(HOW_IT_WORKS_MD)

    with gr.Group(elem_classes="feedback-card"):
        with gr.Row():
            feedback_msg = gr.Markdown("**Was this classification correct?**")
            correct_btn = gr.Button("Correct", variant="secondary", scale=0, min_width=100)
            wrong_btn = gr.Button("Wrong", variant="stop", scale=0, min_width=100)
            correction_dropdown = gr.Dropdown(
                choices=["Spam", "Ham"],
                label="Correct label",
                scale=0,
                min_width=120,
            )

    classify_btn.click(
        fn=classify_and_explain,
        inputs=[email_input, file_input, threshold_slider],
        outputs=[result_output, lime_output, shap_output, eli5_output,
                 compare_output, summary_output, hidden_state],
    )
    correct_btn.click(
        fn=handle_correct,
        inputs=[hidden_state],
        outputs=[feedback_msg],
    )
    wrong_btn.click(
        fn=handle_wrong,
        inputs=[hidden_state, correction_dropdown],
        outputs=[feedback_msg],
    )


# ---------------------------------------------------------------------------
# 14. Launch
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)