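# AMP-Classifier / app.py
# Header captured from the Hugging Face file viewer: author nonzeroexit,
# commit "Update app.py" (bd01e5d, verified), file size 10.3 kB.
# Kept as comments so this file remains valid Python.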
import os
# Native-library hygiene (original note: prevents TF/PyTorch SIGSEGV when both
# load; harmless for RF). These defaults sit between the imports, ahead of
# numpy/pandas/propy/lime/gradio — presumably so the native libraries see them
# when they initialise. `setdefault` means a value already present in the
# environment always wins over the default listed here.
_NATIVE_ENV_DEFAULTS = {
    "KMP_DUPLICATE_LIB_OK": "TRUE",
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
    "TOKENIZERS_PARALLELISM": "false",
}
for _env_name, _env_value in _NATIVE_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value
import sys
import json
import subprocess
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from lime.lime_tabular import LimeTabularExplainer
import gradio as gr
# ---------------------------------------------------------------------------
# Load Random Forest AMP classifier + MinMax scaler (original files)
# ---------------------------------------------------------------------------
def _resolve_app_file(filename):
    """Return the path to *filename* next to this module, if it exists there.

    Otherwise return *filename* unchanged, so it is resolved against the
    current working directory exactly as before. This keeps startup working
    when the app is launched from a different directory.
    """
    candidate = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)
    return candidate if os.path.isfile(candidate) else filename


# joblib.load unpickles its input: only ever point it at trusted, locally
# shipped artifacts, never at user-supplied files.
model = joblib.load(_resolve_app_file("RF.joblib"))
scaler = joblib.load(_resolve_app_file("norm (4).joblib"))
# ---------------------------------------------------------------------------
# RFE-selected features (CTD + AAC + dipeptide composition + Autocorrelation
# + APAAC). These names index the scaled propy feature pool built in
# extract_features(); presumably they are in the RF's training column order.
# NOTE: a previous hard-coded `assert` expected 138 names, but this list holds
# 146, so the module failed on import (unless run with -O, which strips
# asserts). The expected count is now taken from the fitted model instead.
# ---------------------------------------------------------------------------
selected_features = [
    "_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
    "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
    "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
    "_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
    "_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
    "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
    "_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001",
    "_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
    "AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE",
    "LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
    "MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
    "GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
    "GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
    "GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
    "GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30",
    "GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25",
    "GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29",
    "GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30",
    "GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
    "GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25",
    "GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30",
    "GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
    "GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
    "GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29",
    "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
    "APAAC15", "APAAC18", "APAAC19", "APAAC24"
]
# Explicit checks instead of `assert`, which `python -O` removes.
if len(set(selected_features)) != len(selected_features):
    raise ValueError("selected_features contains duplicate feature names")
# scikit-learn estimators record `n_features_in_` at fit time; when the loaded
# model exposes it, the selected list must have exactly that many columns or
# every prediction would fail later with a less helpful error.
_model_n_features = getattr(model, "n_features_in_", None)
if _model_n_features is not None and _model_n_features != len(selected_features):
    raise ValueError(
        f"Model expects {_model_n_features} features, "
        f"but selected_features lists {len(selected_features)}"
    )
# ---------------------------------------------------------------------------
# LIME explainer, constructed once at import time with fixed seeds.
# LIME perturbs features around each input using statistics of its
# `training_data`. Every selected feature is MinMax-scaled (nominally into
# [0, 1]), so a seeded uniform sample over that box stands in for real
# training rows here. Swapping `_lime_background` for a saved sample of real
# normalized training rows would make explanations follow the true feature
# distribution.
# NOTE(review): LIME keeps the seeded RandomState and advances it on every
# explain_instance call, so two requests for the same sequence in one process
# may get slightly different weights — reseed per call if strict
# reproducibility is required.
# ---------------------------------------------------------------------------
_LIME_BACKGROUND_ROWS = 500
_rng = np.random.default_rng(seed=42)
_lime_background = _rng.uniform(0.0, 1.0, size=(_LIME_BACKGROUND_ROWS, len(selected_features)))

explainer = LimeTabularExplainer(
    training_data=_lime_background,
    mode="classification",
    feature_names=selected_features,
    # Index i names predict_proba column i; assumes label 0 is AMP, as
    # full_prediction does.
    class_names=["AMP", "Non-AMP"],
    discretize_continuous=True,
    random_state=42,
)
# ---------------------------------------------------------------------------
# Feature extraction — builds the full propy feature pool, scales it with the
# saved MinMax scaler, then keeps only the selected features the RF uses.
# ---------------------------------------------------------------------------
def extract_features(sequence):
    """Turn a raw protein sequence into the model's scaled feature row.

    The input is upper-cased and reduced to the 20 standard amino-acid
    letters; anything else is silently dropped. Sequences with fewer than
    10 remaining residues are rejected.

    Returns:
        On success, a numpy array with one row and one column per entry of
        ``selected_features``. On failure, an error string beginning with
        "Error" — callers tell the two apart with ``isinstance(..., str)``.
    """
    cleaned = "".join(ch for ch in sequence.upper() if ch in "ACDEFGHIKLMNPQRSTVWY")
    if len(cleaned) < 10:
        return "Error: Sequence too short."

    try:
        composition = AAComposition.CalculateAADipeptideComposition(cleaned)
        composition_420 = {key: composition[key] for key in list(composition)[:420]}
        ctd = CTD.CalculateCTD(cleaned)
        autocorr = Autocorrelation.CalculateAutoTotal(cleaned)
        apaac = PseudoAAC.GetAPseudoAAC(cleaned, lamda=9)

        # Merge order fixes the column order handed to the scaler as a bare
        # array (presumably the order it was fitted on): CTD, composition,
        # autocorrelation, then APAAC.
        pool = pd.DataFrame([{**ctd, **composition_420, **autocorr, **apaac}])
        scaled = pd.DataFrame(scaler.transform(pool.values), columns=pool.columns)

        absent = set(selected_features) - set(scaled.columns)
        if absent:
            return f"Error: Missing features: {list(absent)[:5]}..."
        # NaNs left in the scaled row are replaced with 0 before prediction.
        return scaled[selected_features].fillna(0).values
    except Exception as e:
        return f"Error in feature extraction: {e}"
# ---------------------------------------------------------------------------
# MIC prediction — runs in a SEPARATE process (mic_worker.py).
# This isolates PyTorch/ProtBert from the main process and prevents the
# native-library crash (exit 139) plus the OOM spike on the free tier.
# ---------------------------------------------------------------------------
def predictmic(sequence):
    """Predict MIC values for *sequence* by running ``mic_worker.py``.

    The sequence is upper-cased and reduced to the 20 standard amino-acid
    letters (same cleaning as feature extraction). It is then passed as a
    single argv item to the worker, which is run with the current
    interpreter. The worker is given up to 900 s.

    Returns:
        The JSON object printed on the worker's last non-empty stdout line,
        as a dict. On any failure, returns ``{"Error": message}``.
    """
    sequence = "".join(aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    # Look for the worker next to this module so the call does not depend on
    # the process's working directory; fall back to the old cwd-relative name.
    worker = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mic_worker.py")
    if not os.path.isfile(worker):
        worker = "mic_worker.py"

    try:
        # List argv, no shell: the (already sanitised) sequence is one argument.
        proc = subprocess.run(
            [sys.executable, worker, sequence],
            capture_output=True, text=True, timeout=900,
        )
    except subprocess.TimeoutExpired:
        return {"Error": "MIC prediction timed out (ProtBert may still be downloading; try again shortly)."}
    except Exception as e:
        return {"Error": f"Failed to start MIC worker: {e}"}

    if proc.returncode != 0:
        tail = (proc.stderr or "").strip().splitlines()[-3:]
        return {"Error": f"MIC worker exited with code {proc.returncode}. {' '.join(tail)}".rstrip()}

    out_lines = [ln for ln in (proc.stdout or "").splitlines() if ln.strip()]
    if not out_lines:
        return {"Error": "MIC worker produced no output."}

    # Only the last line is the result; earlier lines may be library noise.
    last_line = out_lines[-1]
    try:
        result = json.loads(last_line)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {"Error": f"Could not parse MIC worker output: {last_line[:200]}"}
    # Callers iterate `.items()`, so anything but a JSON object is an error.
    if not isinstance(result, dict):
        return {"Error": f"MIC worker returned {type(result).__name__}, expected a JSON object: {last_line[:200]}"}
    return result
# ---------------------------------------------------------------------------
# Main prediction pipeline
# ---------------------------------------------------------------------------
def full_prediction(sequence):
    """Classify *sequence*, explain the result with LIME, and add MIC values.

    Args:
        sequence: raw text from the Gradio textbox.

    Returns:
        A multi-line plain-text report, or the error string returned by
        ``extract_features`` when feature extraction fails.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        return features

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]
    # predict_proba columns follow model.classes_; locate the predicted
    # label's column, falling back to the most probable column.
    try:
        class_index = list(model.classes_).index(prediction)
    except (AttributeError, ValueError):
        class_index = int(np.argmax(probabilities))
    confidence = round(probabilities[class_index] * 100, 2)

    # This app treats label 0 as AMP (class_names=["AMP", "Non-AMP"]).
    is_amp = prediction == 0
    amp_result = "Antimicrobial Peptide (AMP)" if is_amp else "Non-AMP"
    report = [f"Prediction: {amp_result}\nConfidence: {confidence}%\n"]

    # LIME: explain the *predicted* class. explain_instance/as_list default to
    # label 1 ("Non-AMP"), which reported Non-AMP weights even for AMP
    # predictions (for a binary model, roughly every sign flipped).
    try:
        explanation = explainer.explain_instance(
            data_row=features[0],            # the single input row
            predict_fn=model.predict_proba,
            labels=(class_index,),
            num_features=10,
            num_samples=2000,                # perturbations around this input
        )
        report.append("\nTop Features Influencing Prediction (LIME, weights toward predicted class):\n")
        report.extend(
            f"- {feat}: {round(weight, 4)}\n"
            for feat, weight in explanation.as_list(label=class_index)
        )
    except Exception as e:  # best effort: the classification is still reported
        report.append(f"\nLIME explanation failed: {e}\n")

    # MIC values are only requested for sequences classified as AMP.
    if is_amp:
        mic_values = predictmic(sequence)
        report.append("\nPredicted MIC Values (μM):\n")
        if isinstance(mic_values, dict):
            report.extend(f"- {org}: {mic}\n" for org, mic in mic_values.items())
        else:
            report.append(f"- Error: unexpected MIC result: {mic_values!r}\n")
    else:
        report.append("\nMIC prediction skipped for Non-AMP sequences.\n")
    return "".join(report)
# ---------------------------------------------------------------------------
# Gradio UI: a single text input wired to full_prediction, whose report is
# shown in a single text output. The server is launched at import time.
# ---------------------------------------------------------------------------
_sequence_box = gr.Textbox(label="Enter Protein Sequence")
_results_box = gr.Textbox(label="Results")

iface = gr.Interface(
    fn=full_prediction,
    inputs=_sequence_box,
    outputs=_results_box,
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
)
iface.launch()