Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,31 +1,25 @@
|
|
| 1 |
-
import
|
| 2 |
-
# Native-lib hygiene (prevents TF/PyTorch SIGSEGV when both load; harmless for RF)
|
| 3 |
-
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
|
| 4 |
-
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 5 |
-
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
| 6 |
-
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
|
| 7 |
-
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 8 |
-
|
| 9 |
-
import sys
|
| 10 |
-
import json
|
| 11 |
-
import subprocess
|
| 12 |
import joblib
|
| 13 |
import numpy as np
|
| 14 |
import pandas as pd
|
| 15 |
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
|
|
|
|
|
|
|
|
|
|
| 16 |
from lime.lime_tabular import LimeTabularExplainer
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
# ---------------------------------------------------------------------------
|
| 21 |
-
# Load Random Forest AMP classifier + MinMax scaler (original files)
|
| 22 |
-
# ---------------------------------------------------------------------------
|
| 23 |
model = joblib.load("RF.joblib")
|
| 24 |
scaler = joblib.load("norm (4).joblib")
|
| 25 |
|
| 26 |
-
#
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
|
| 30 |
"_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
|
| 31 |
"_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
|
|
@@ -53,49 +47,22 @@ selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondarySt
|
|
| 53 |
"GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
|
| 54 |
"APAAC15", "APAAC18", "APAAC19", "APAAC24"]
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
# ---------------------------------------------------------------------------
|
| 59 |
-
# LIME explainer
|
| 60 |
-
# Built ONCE at startup so explanations are reproducible across requests.
|
| 61 |
-
# Prefers a real normalized training sample (lime_background.joblib). Falls
|
| 62 |
-
# back to seeded uniform noise if that file isn't present (still stable, but
|
| 63 |
-
# less faithful to the true feature distribution).
|
| 64 |
-
# ---------------------------------------------------------------------------
|
| 65 |
-
try:
|
| 66 |
-
_lime_background = joblib.load("lime_background.joblib")
|
| 67 |
-
if _lime_background.shape[1] != len(selected_features):
|
| 68 |
-
raise ValueError(
|
| 69 |
-
f"lime_background.joblib has {_lime_background.shape[1]} cols, "
|
| 70 |
-
f"expected {len(selected_features)}"
|
| 71 |
-
)
|
| 72 |
-
print(f"[LIME] Using real training sample: {_lime_background.shape}", flush=True)
|
| 73 |
-
except Exception as e:
|
| 74 |
-
print(f"[LIME] No usable lime_background.joblib ({e}); falling back to uniform noise.", flush=True)
|
| 75 |
-
_rng = np.random.default_rng(seed=42)
|
| 76 |
-
_lime_background = _rng.uniform(low=0.0, high=1.0, size=(500, len(selected_features)))
|
| 77 |
-
|
| 78 |
explainer = LimeTabularExplainer(
|
| 79 |
-
training_data=
|
| 80 |
feature_names=selected_features,
|
| 81 |
class_names=["AMP", "Non-AMP"],
|
| 82 |
-
mode="classification"
|
| 83 |
-
discretize_continuous=True,
|
| 84 |
-
random_state=42, # stable explanations
|
| 85 |
)
|
| 86 |
|
| 87 |
-
|
| 88 |
-
# ---------------------------------------------------------------------------
|
| 89 |
-
# Feature extraction — produces the full propy feature pool, scales it with
|
| 90 |
-
# the saved MinMax scaler, then selects the 138 features the RF was trained on.
|
| 91 |
-
# ---------------------------------------------------------------------------
|
| 92 |
def extract_features(sequence):
|
| 93 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
| 94 |
if len(sequence) < 10:
|
| 95 |
return "Error: Sequence too short."
|
| 96 |
|
| 97 |
try:
|
| 98 |
-
# Original full pool: CTD + AAC(first 420) + Autocorrelation + PseudoAAC
|
| 99 |
dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
|
| 100 |
filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
|
| 101 |
ctd_features = CTD.CalculateCTD(sequence)
|
|
@@ -113,51 +80,50 @@ def extract_features(sequence):
|
|
| 113 |
normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
|
| 114 |
|
| 115 |
if not set(selected_features).issubset(normalized_df.columns):
|
| 116 |
-
|
| 117 |
-
return f"Error: Missing features: {list(missing)[:5]}..."
|
| 118 |
|
| 119 |
selected_df = normalized_df[selected_features].fillna(0)
|
| 120 |
return selected_df.values
|
| 121 |
except Exception as e:
|
| 122 |
return f"Error in feature extraction: {str(e)}"
|
| 123 |
|
| 124 |
-
|
| 125 |
-
# ---------------------------------------------------------------------------
|
| 126 |
-
# MIC prediction — runs in a SEPARATE process (mic_worker.py).
|
| 127 |
-
# This isolates PyTorch/ProtBert from the main process and prevents the
|
| 128 |
-
# native-library crash (exit 139) plus the OOM spike on the free tier.
|
| 129 |
-
# ---------------------------------------------------------------------------
|
| 130 |
def predictmic(sequence):
|
| 131 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
| 132 |
if len(sequence) < 10:
|
| 133 |
return {"Error": "Sequence too short or invalid."}
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
def full_prediction(sequence):
|
| 162 |
features = extract_features(sequence)
|
| 163 |
if isinstance(features, str):
|
|
@@ -175,21 +141,6 @@ def full_prediction(sequence):
|
|
| 175 |
amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
|
| 176 |
result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
|
| 177 |
|
| 178 |
-
# ---- LIME first (per your spec: LIME before SHAP in the report) ----
|
| 179 |
-
try:
|
| 180 |
-
explanation = explainer.explain_instance(
|
| 181 |
-
data_row=features[0], # <-- explicitly the single input sequence
|
| 182 |
-
predict_fn=model.predict_proba,
|
| 183 |
-
num_features=10,
|
| 184 |
-
num_samples=2000, # perturbations around this single input
|
| 185 |
-
)
|
| 186 |
-
result += "\nTop Features Influencing Prediction (LIME):\n"
|
| 187 |
-
for feat, weight in explanation.as_list():
|
| 188 |
-
result += f"- {feat}: {round(weight, 4)}\n"
|
| 189 |
-
except Exception as e:
|
| 190 |
-
result += f"\nLIME explanation failed: {str(e)}\n"
|
| 191 |
-
|
| 192 |
-
# ---- MIC (only for AMPs) ----
|
| 193 |
if prediction == 0:
|
| 194 |
mic_values = predictmic(sequence)
|
| 195 |
result += "\nPredicted MIC Values (μM):\n"
|
|
@@ -198,8 +149,17 @@ def full_prediction(sequence):
|
|
| 198 |
else:
|
| 199 |
result += "\nMIC prediction skipped for Non-AMP sequences.\n"
|
| 200 |
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
# Gradio UI
|
| 205 |
iface = gr.Interface(
|
|
@@ -210,4 +170,4 @@ iface = gr.Interface(
|
|
| 210 |
description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
|
| 211 |
)
|
| 212 |
|
| 213 |
-
iface.launch()
|
|
|
|
| 1 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import joblib
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
|
| 6 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import BertTokenizer, BertModel
|
| 9 |
from lime.lime_tabular import LimeTabularExplainer
|
| 10 |
+
from math import expm1
|
| 11 |
|
| 12 |
+
# Load AMP Classifier and Scaler
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
model = joblib.load("RF.joblib")
|
| 14 |
scaler = joblib.load("norm (4).joblib")
|
| 15 |
|
| 16 |
+
# Load ProtBert
|
| 17 |
+
# ProtBert tokenizer and encoder, created once at import time so every
# request reuses the same objects. The encoder is moved to the GPU when
# CUDA is available (CPU otherwise) and switched to eval mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
protbert_model = protbert_model.to(device)
protbert_model = protbert_model.eval()
|
| 21 |
+
|
| 22 |
+
# Define selected features (put your complete list here)
|
| 23 |
selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
|
| 24 |
"_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
|
| 25 |
"_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
|
|
|
|
| 47 |
"GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
|
| 48 |
"APAAC15", "APAAC18", "APAAC19", "APAAC24"]
|
| 49 |
|
| 50 |
+
# Dummy data for LIME
|
| 51 |
+
# Synthetic background data for LIME: uniform values in [0, 1), one column per
# selected feature. Drawn from a dedicated, seeded generator so the background
# (and therefore the explanations) is identical on every start-up instead of
# depending on NumPy's global random state.
_lime_rng = np.random.default_rng(seed=42)
sample_data = _lime_rng.uniform(low=0.0, high=1.0, size=(100, len(selected_features)))

# Built once at import time. `random_state` fixes the perturbation sampling
# performed by `explain_instance`, so the same input sequence yields the same
# feature weights across requests.
explainer = LimeTabularExplainer(
    training_data=sample_data,
    feature_names=selected_features,
    class_names=["AMP", "Non-AMP"],
    mode="classification",
    random_state=42,
)
|
| 58 |
|
| 59 |
+
# Feature extraction function
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def extract_features(sequence):
|
| 61 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
| 62 |
if len(sequence) < 10:
|
| 63 |
return "Error: Sequence too short."
|
| 64 |
|
| 65 |
try:
|
|
|
|
| 66 |
dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
|
| 67 |
filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
|
| 68 |
ctd_features = CTD.CalculateCTD(sequence)
|
|
|
|
| 80 |
normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
|
| 81 |
|
| 82 |
if not set(selected_features).issubset(normalized_df.columns):
|
| 83 |
+
return "Error: Some selected features are missing."
|
|
|
|
| 84 |
|
| 85 |
selected_df = normalized_df[selected_features].fillna(0)
|
| 86 |
return selected_df.values
|
| 87 |
except Exception as e:
|
| 88 |
return f"Error in feature extraction: {str(e)}"
|
| 89 |
|
| 90 |
+
# MIC prediction function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
from functools import lru_cache

# Residues accepted by the sequence filter (the 20 standard amino acids).
_VALID_RESIDUES = frozenset("ACDEFGHIKLMNPQRSTVWY")

# Artifact files per target organism. "pca" is None for pipelines that have
# no dimensionality-reduction step between scaling and prediction.
_BACTERIA_CONFIG = {
    "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
    "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
    "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
    "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"},
}


@lru_cache(maxsize=None)  # key space is the fixed set of paths in _BACTERIA_CONFIG
def _load_mic_artifact(path):
    """Load a joblib artifact from *path* once and reuse it on later calls.

    A failed load raises and is not cached, so a missing file is retried
    (and reported) on every request.
    """
    return joblib.load(path)


def predictmic(sequence):
    """Predict a MIC value per organism from the ProtBert embedding of *sequence*.

    The input is upper-cased and every character that is not one of the 20
    standard amino-acid letters is dropped before anything else happens.

    Args:
        sequence: Raw amino-acid string as entered by the user.

    Returns:
        ``{"Error": "Sequence too short or invalid."}`` when fewer than 10
        valid residues remain. Otherwise a dict keyed by organism name whose
        value is either the predicted MIC rounded to 3 decimals, or an
        ``"Error: ..."`` string if that organism's pipeline raised.
    """
    sequence = "".join(aa for aa in sequence.upper() if aa in _VALID_RESIDUES)
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    # Space-separate the residues before tokenizing (presumably so the
    # ProtBert tokenizer sees one token per residue -- confirm against the
    # model card).
    seq_spaced = " ".join(sequence)
    tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    tokens = {k: v.to(device) for k, v in tokens.items()}

    # NOTE(review): with padding='max_length' the mean below averages over all
    # 512 positions, padding included. Left unchanged because the downstream
    # regressors were presumably fitted on embeddings built the same way --
    # confirm before switching to an attention-mask-weighted mean.
    with torch.no_grad():
        outputs = protbert_model(**tokens)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)

    mic_results = {}
    for bacterium, cfg in _BACTERIA_CONFIG.items():
        # Each organism is handled independently: a failure in one pipeline
        # is reported in its own slot and does not abort the others.
        try:
            # Distinct local names: `scaler` and `model` at module level are
            # the AMP classifier's objects and must not be confused with these.
            mic_scaler = _load_mic_artifact(cfg["scaler"])
            scaled = mic_scaler.transform(embedding)
            if cfg["pca"]:
                transformed = _load_mic_artifact(cfg["pca"]).transform(scaled)
            else:
                transformed = scaled
            regressor = _load_mic_artifact(cfg["model"])
            # The regressor output is treated as log1p-scaled; expm1 inverts it.
            mic_log = regressor.predict(transformed)[0]
            mic_results[bacterium] = round(expm1(mic_log), 3)
        except Exception as e:
            mic_results[bacterium] = f"Error: {str(e)}"

    return mic_results
|
| 125 |
+
|
| 126 |
+
# Main prediction function
|
| 127 |
def full_prediction(sequence):
|
| 128 |
features = extract_features(sequence)
|
| 129 |
if isinstance(features, str):
|
|
|
|
| 141 |
amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
|
| 142 |
result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
if prediction == 0:
|
| 145 |
mic_values = predictmic(sequence)
|
| 146 |
result += "\nPredicted MIC Values (μM):\n"
|
|
|
|
| 149 |
else:
|
| 150 |
result += "\nMIC prediction skipped for Non-AMP sequences.\n"
|
| 151 |
|
| 152 |
+
explanation = explainer.explain_instance(
|
| 153 |
+
data_row=features[0],
|
| 154 |
+
predict_fn=model.predict_proba,
|
| 155 |
+
num_features=10
|
| 156 |
+
)
|
| 157 |
|
| 158 |
+
result += "\nTop Features Influencing Prediction:\n"
|
| 159 |
+
for feat, weight in explanation.as_list():
|
| 160 |
+
result += f"- {feat}: {round(weight, 4)}\n"
|
| 161 |
+
|
| 162 |
+
return result
|
| 163 |
|
| 164 |
# Gradio UI
|
| 165 |
iface = gr.Interface(
|
|
|
|
| 170 |
description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
|
| 171 |
)
|
| 172 |
|
| 173 |
+
iface.launch(share=True)
|