# AMP-Classifier2 / app.py
# Author: nonzeroexit — last commit f8615d0 ("Update app.py"), 13.2 kB
import os
# --- Prevent SIGSEGV (exit 139) from TensorFlow + PyTorch native-library clashes ---
# TensorFlow and torch each bundle their own OpenMP/MKL runtimes. Loaded in the
# same process they can collide and crash at the C level. The defaults below
# let the runtimes coexist, limit native thread pools to one thread (less
# memory), and quiet TensorFlow's logging. They must be in the environment
# before tensorflow/torch are first imported. ``setdefault`` keeps any value
# the deployment already provides.
for _env_name, _env_default in (
    ("KMP_DUPLICATE_LIB_OK", "TRUE"),
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("OPENBLAS_NUM_THREADS", "1"),
    ("TF_CPP_MIN_LOG_LEVEL", "3"),
    ("TF_ENABLE_ONEDNN_OPTS", "0"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
import gradio as gr
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, CTD
from math import expm1
# ---------------------------------------------------------------------------
# LAZY LOADING — keeps the free 16GB Space from OOM-ing at startup.
# Heavy libs (TF, torch, ProtBert) load only when first needed.
# ---------------------------------------------------------------------------
# Module-level caches. All start empty; get_amp_model() / get_protbert() fill them.
_amp_model = _amp_scaler = None  # Keras AMP classifier and its feature scaler
_protbert_tokenizer = _protbert_model = None  # ProtBert tokenizer and encoder
_torch = _device = None  # torch module handle and the device ProtBert runs on
def get_amp_model():
    """Return the cached ``(keras_model, scaler)`` pair, loading on first use.

    TensorFlow is imported inside the function so it is only loaded when an
    AMP prediction is actually requested.

    Returns:
        tuple: ``(model, scaler)``. ``model`` is the Keras classifier from
        ``Comb1_aac_ctd_RFE_selected_features_model.keras``. ``scaler`` is
        the ``joblib``-loaded object from the matching ``.joblib`` file.
    """
    global _amp_model, _amp_scaler
    # Each artifact is checked on its own. If one load failed on an earlier
    # call, the next call retries it. Otherwise the cache could stay
    # half-populated and hand callers a ``None`` scaler forever.
    if _amp_model is None:
        from tensorflow.keras.models import load_model
        _amp_model = load_model("Comb1_aac_ctd_RFE_selected_features_model.keras")
    if _amp_scaler is None:
        _amp_scaler = joblib.load("Comb1_aac_ctd_RFE_selected_features_scaler.joblib")
    return _amp_model, _amp_scaler
def get_protbert():
    """Return ``(tokenizer, model, torch, device)`` for ProtBert, loading on first use.

    ``torch`` and ``transformers`` are imported here rather than at module
    level, so they are only loaded when MIC prediction is requested. The
    model is placed on CUDA when available (otherwise CPU) and switched to
    eval mode.
    """
    global _protbert_tokenizer, _protbert_model, _torch, _device
    if _protbert_model is None:
        import torch
        from transformers import BertModel, BertTokenizer

        try:
            torch.set_num_threads(1)  # best effort: reduce native threading conflicts with TF
        except Exception:
            pass
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()
        # Publish only after every step succeeded. ``_protbert_model`` doubles
        # as the "already loaded" flag, so it must never point at a model that
        # was not yet moved to ``device`` or switched to eval mode.
        _torch, _device, _protbert_tokenizer = torch, device, tokenizer
        _protbert_model = model
    return _protbert_tokenizer, _protbert_model, _torch, _device
# ---------------------------------------------------------------------------
# The EXACT 343 features the scaler was fit on, IN THE EXACT TRAINING ORDER.
# The scaler was fit on a numpy array (no stored names), so order is critical:
# we must select these columns in this order BEFORE calling scaler.transform().
# Each name is a key looked up in the dict built from propy's CTD and
# amino-acid composition output (see extract_features). The order is: CTD
# descriptors (leading "_"), then single-letter amino-acid composition, then
# two-letter dipeptide composition. Do not reorder, add or remove entries
# without refitting the scaler and model.
# ---------------------------------------------------------------------------
selected_features = [
# CTD "C" features: <Property>C1..C3 for each of the 7 properties.
"_PolarizabilityC1", "_PolarizabilityC2", "_PolarizabilityC3",
"_SolventAccessibilityC1", "_SolventAccessibilityC2", "_SolventAccessibilityC3",
"_SecondaryStrC1", "_SecondaryStrC2", "_SecondaryStrC3",
"_ChargeC1", "_ChargeC2", "_ChargeC3",
"_PolarityC1", "_PolarityC2", "_PolarityC3",
"_NormalizedVDWVC1", "_NormalizedVDWVC2", "_NormalizedVDWVC3",
"_HydrophobicityC1", "_HydrophobicityC2", "_HydrophobicityC3",
# CTD "T" features: <Property>T12, T13, T23 for each property.
"_PolarizabilityT12", "_PolarizabilityT13", "_PolarizabilityT23",
"_SolventAccessibilityT12", "_SolventAccessibilityT13", "_SolventAccessibilityT23",
"_SecondaryStrT12", "_SecondaryStrT13", "_SecondaryStrT23",
"_ChargeT12", "_ChargeT13", "_ChargeT23",
"_PolarityT12", "_PolarityT13", "_PolarityT23",
"_NormalizedVDWVT12", "_NormalizedVDWVT13", "_NormalizedVDWVT23",
"_HydrophobicityT12", "_HydrophobicityT13", "_HydrophobicityT23",
# CTD "D" features: <Property>D<class 1-3><001|025|050|075|100>.
"_PolarizabilityD1001", "_PolarizabilityD1025", "_PolarizabilityD1050",
"_PolarizabilityD1075", "_PolarizabilityD1100",
"_PolarizabilityD2001", "_PolarizabilityD2025", "_PolarizabilityD2050",
"_PolarizabilityD2075", "_PolarizabilityD2100",
"_PolarizabilityD3001", "_PolarizabilityD3025", "_PolarizabilityD3050",
"_PolarizabilityD3075", "_PolarizabilityD3100",
"_SolventAccessibilityD1001", "_SolventAccessibilityD1025",
"_SolventAccessibilityD1050", "_SolventAccessibilityD1075",
"_SolventAccessibilityD1100",
"_SolventAccessibilityD2001", "_SolventAccessibilityD2025",
"_SolventAccessibilityD2050", "_SolventAccessibilityD2075",
"_SolventAccessibilityD2100",
"_SolventAccessibilityD3001", "_SolventAccessibilityD3025",
"_SolventAccessibilityD3050", "_SolventAccessibilityD3075",
"_SolventAccessibilityD3100",
"_SecondaryStrD1001", "_SecondaryStrD1025", "_SecondaryStrD1050",
"_SecondaryStrD1075", "_SecondaryStrD1100",
"_SecondaryStrD2001", "_SecondaryStrD2025", "_SecondaryStrD2050",
"_SecondaryStrD2075", "_SecondaryStrD2100",
"_SecondaryStrD3001", "_SecondaryStrD3025", "_SecondaryStrD3050",
"_SecondaryStrD3075", "_SecondaryStrD3100",
"_ChargeD1001", "_ChargeD1025", "_ChargeD1050",
"_ChargeD1075", "_ChargeD1100",
"_ChargeD2001", "_ChargeD2025", "_ChargeD2050",
# NOTE(review): "_ChargeD2100" is absent here, unlike every other D group.
# Presumably it was dropped by the feature selection (the artifact filenames
# mention RFE). The 343 count depends on its absence, so do not "fix" it.
"_ChargeD2075",
"_ChargeD3001", "_ChargeD3025", "_ChargeD3050",
"_ChargeD3075", "_ChargeD3100",
"_PolarityD1001", "_PolarityD1025", "_PolarityD1050",
"_PolarityD1075", "_PolarityD1100",
"_PolarityD2001", "_PolarityD2025", "_PolarityD2050",
"_PolarityD2075", "_PolarityD2100",
"_PolarityD3001", "_PolarityD3025", "_PolarityD3050",
"_PolarityD3075", "_PolarityD3100",
"_NormalizedVDWVD1001", "_NormalizedVDWVD1025",
"_NormalizedVDWVD1050", "_NormalizedVDWVD1075",
"_NormalizedVDWVD1100",
"_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
"_NormalizedVDWVD2050", "_NormalizedVDWVD2075",
"_NormalizedVDWVD2100",
"_NormalizedVDWVD3001", "_NormalizedVDWVD3025",
"_NormalizedVDWVD3050", "_NormalizedVDWVD3075",
"_NormalizedVDWVD3100",
"_HydrophobicityD1001", "_HydrophobicityD1025",
"_HydrophobicityD1050", "_HydrophobicityD1075",
"_HydrophobicityD1100",
"_HydrophobicityD2001", "_HydrophobicityD2025",
"_HydrophobicityD2050", "_HydrophobicityD2075",
"_HydrophobicityD2100",
"_HydrophobicityD3001", "_HydrophobicityD3025",
"_HydrophobicityD3050", "_HydrophobicityD3075",
"_HydrophobicityD3100",
# Amino-acid composition: all 20 standard residues.
"A", "R", "N", "D", "C", "E", "Q", "G", "H", "I",
"L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
# Dipeptide composition: a selected subset (177) of the 400 possible pairs.
"AR", "AD", "AQ", "AG", "AL", "AK", "AF", "AP", "AT", "AV",
"RA", "RC", "RE", "RG", "RI", "RL", "RS", "RT", "RV",
"NR", "NC", "NG", "NI", "NP", "NS", "NY", "NV",
"DR", "DN", "DC", "DE", "DG", "DF", "DS", "DT", "DY",
"CR", "CN", "CD", "CC", "CI", "CL", "CK", "CT", "CY", "CV",
"EA", "ER", "ED", "EC", "EE", "EG", "EI", "EL", "EK",
"EF", "EP", "ET", "EV",
"QN", "QF", "QV",
"GA", "GR", "GC", "GE", "GG", "GI", "GL", "GK", "GF", "GP", "GY",
"HA", "HP", "HT",
"IA", "IR", "ID", "II", "IL", "IF", "IP", "IS", "IV",
"LA", "LR", "LD", "LC", "LG", "LI", "LK", "LM", "LF",
"LS", "LT", "LY", "LV",
"KA", "KN", "KC", "KG", "KI", "KL", "KK", "KP", "KY",
"MA", "MD", "ME", "MI", "MK", "MF", "MP", "MS", "MV",
"FR", "FE", "FQ", "FG", "FL", "FF", "FS", "FT", "FY", "FV",
"PA", "PR", "PC", "PE", "PL", "PK", "PP", "PS", "PV",
"SA", "SR", "SD", "SC", "SG", "SH", "SI", "SL", "SP", "ST", "SY",
"TA", "TR", "TC", "TE", "TQ", "TG", "TI", "TL", "TP", "TS", "TV",
"WA",
"YN", "YD", "YC", "YQ", "YG", "YP",
"VA", "VR", "VD", "VC", "VE", "VG", "VI", "VL", "VK",
"VS", "VT", "VY", "VV"
]
# Import-time sanity check that the list still matches the scaler's 343 input
# columns. An explicit raise (not ``assert``) keeps it active under
# ``python -O``. Duplicates are rejected too: a repeated name would mean some
# real training column is missing even though the count still reads 343.
if len(selected_features) != 343:
    raise RuntimeError(f"Expected 343 features, got {len(selected_features)}")
if len(set(selected_features)) != len(selected_features):
    raise RuntimeError("selected_features contains duplicate feature names")
def keras_predict_proba(X):
    """Adapt the Keras AMP model to LIME's ``predict_fn`` interface.

    ``X`` must already be scaled. A single-column (sigmoid) output is read as
    P(AMP) and expanded to two columns ``[P(Non-AMP), P(AMP)]``. Any wider
    output is returned untouched.
    """
    model, _scaler = get_amp_model()
    raw = model.predict(X, verbose=0)
    if raw.ndim != 1 and raw.shape[1] != 1:
        return raw
    p_amp = raw.reshape(-1, 1)  # sigmoid output assumed = P(AMP)
    return np.hstack([1 - p_amp, p_amp])
def extract_features(sequence):
    """Build the scaled feature row the AMP classifier expects.

    The input is upper-cased and every character outside the 20 standard
    amino acids is dropped. The steps are:

    1. Compute CTD descriptors and amino-acid / dipeptide composition with
       propy.
    2. Pick the ``selected_features`` columns in training order.
    3. Pass the row through the saved scaler.

    Args:
        sequence: Raw user input. ``None`` is treated as an empty string.

    Returns:
        On success, a ``float32`` ``np.ndarray`` with one row (the scaler
        output for the 343 selected columns). On failure, an error message
        string starting with ``"Error"``; callers detect this with
        ``isinstance(result, str)``.
    """
    sequence = "".join(aa for aa in (sequence or "").upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return "Error: Sequence too short."
    try:
        _, amp_scaler = get_amp_model()
        # Merge all propy outputs into one name -> value lookup.
        pool = {}
        pool.update(CTD.CalculateCTD(sequence))
        # Only 1-letter and 2-letter composition keys are ever selected. So
        # call the two dedicated functions instead of
        # CalculateAADipeptideComposition, which also computes the
        # 8000-entry tripeptide spectrum whose keys are never used.
        pool.update(AAComposition.CalculateAAComposition(sequence))
        pool.update(AAComposition.CalculateDipeptideComposition(sequence))

        missing = [name for name in selected_features if name not in pool]
        if missing:
            return f"Error: Missing features from propy: {missing[:5]}..."

        # Build the row IN THE EXACT TRAINING ORDER, then scale. The scaler
        # expects exactly these 343 columns.
        feature_row = np.array(
            [pool[name] for name in selected_features], dtype=np.float64
        ).reshape(1, -1)
        return amp_scaler.transform(feature_row).astype(np.float32)
    except Exception as e:
        # The UI shows this text instead of crashing the request.
        return f"Error in feature extraction: {str(e)}"
# Per-organism MIC regressors. Each entry lists the joblib files for the
# regressor, the scaler applied to the ProtBert embedding, and an optional
# PCA applied after scaling (None = no PCA).
_MIC_MODEL_CONFIG = {
    "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
    "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
    "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
    "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"},
}

# joblib artifacts already loaded, keyed by file path. Filled on first use,
# like the other lazily loaded models in this app, so a request no longer
# re-reads every file from disk.
_mic_artifact_cache = {}


def _load_mic_artifact(path):
    """Return the joblib object stored at *path*, reading the file only once."""
    artifact = _mic_artifact_cache.get(path)
    if artifact is None:
        artifact = joblib.load(path)
        _mic_artifact_cache[path] = artifact
    return artifact


def predictmic(sequence):
    """Predict MIC values for *sequence* against each configured organism.

    The steps are:

    1. Clean the sequence: upper-case it and drop non-standard residues.
    2. Embed it with ProtBert by averaging the last hidden layer over all
       512 token positions.
    3. For each organism, scale the embedding, apply PCA where configured,
       and run that organism's regressor.
    4. Invert each raw prediction with ``expm1`` and round it to 3 decimals.

    Args:
        sequence: Raw user input. ``None`` is treated as an empty string.

    Returns:
        dict: organism name -> MIC value. An organism whose artifacts fail
        to load or predict maps to an ``"Error: ..."`` string instead.
        Returns ``{"Error": ...}`` when fewer than 10 valid residues remain.
    """
    sequence = "".join(aa for aa in (sequence or "").upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    tokenizer, protbert_model, torch, device = get_protbert()
    # The ProtBert tokenizer is fed space-separated residues.
    tokens = tokenizer(" ".join(sequence), return_tensors="pt", padding="max_length",
                       truncation=True, max_length=512)
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        outputs = protbert_model(**tokens)
    # NOTE(review): with padding="max_length" this mean also covers the [PAD]
    # positions. Keep it as-is unless the MIC models are retrained: their
    # scalers were presumably fit on embeddings pooled the same way.
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)

    mic_results = {}
    for bacterium, cfg in _MIC_MODEL_CONFIG.items():
        try:
            scaled = _load_mic_artifact(cfg["scaler"]).transform(embedding)
            transformed = _load_mic_artifact(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
            # The regressor output is named mic_log and inverted with expm1,
            # presumably because the models were trained on log1p(MIC).
            mic_log = _load_mic_artifact(cfg["model"]).predict(transformed)[0]
            mic_results[bacterium] = round(expm1(mic_log), 3)
        except Exception as e:
            # Best effort per organism: one failing model does not hide the others.
            mic_results[bacterium] = f"Error: {str(e)}"
    return mic_results
# Fixed seed for LIME's background sample and perturbation sampling, so the
# same sequence always yields the same explanation.
_LIME_RANDOM_SEED = 0


def _interpret_amp_output(raw_pred):
    """Map raw Keras output to ``(prediction, confidence_percent)``.

    Returns 1 for AMP and 0 for Non-AMP, with confidence rounded to 2 decimals.
    """
    if raw_pred.ndim == 1 or raw_pred.shape[1] == 1:
        prob_amp = float(raw_pred.flatten()[0])  # sigmoid output assumed = P(AMP)
        if prob_amp >= 0.5:
            return 1, round(prob_amp * 100, 2)
        return 0, round((1 - prob_amp) * 100, 2)
    # Several output columns: take the arg-max class of the first row.
    class_idx = int(np.argmax(raw_pred[0]))
    return class_idx, round(float(raw_pred[0][class_idx]) * 100, 2)


def _mic_report(sequence, prediction):
    """Return the MIC section of the report. MIC is only predicted for AMPs."""
    if prediction != 1:
        return "\nMIC prediction skipped for Non-AMP sequences.\n"
    lines = ["\nPredicted MIC Values (μM):\n"]
    lines.extend(f"- {org}: {mic}\n" for org, mic in predictmic(sequence).items())
    return "".join(lines)


def _lime_report(features):
    """Return the LIME section. Failures are reported in the text, never raised."""
    try:
        from lime.lime_tabular import LimeTabularExplainer

        # NOTE(review): the background sample is uniform noise in [0, 1), not
        # real (scaled) training data. LIME's feature ranges and weights are
        # therefore relative to that noise. Replace it with a sample of the
        # scaled training matrix if one can be shipped with the app.
        background = np.random.default_rng(_LIME_RANDOM_SEED).random(
            (100, len(selected_features))
        )
        explainer = LimeTabularExplainer(
            training_data=background,
            feature_names=selected_features,
            class_names=["Non-AMP", "AMP"],
            mode="classification",
            random_state=_LIME_RANDOM_SEED,
        )
        explanation = explainer.explain_instance(
            data_row=features[0],
            predict_fn=keras_predict_proba,
            num_features=10,
        )
        lines = ["\nTop Features Influencing Prediction:\n"]
        lines.extend(
            f"- {feat}: {round(weight, 4)}\n" for feat, weight in explanation.as_list()
        )
    except Exception as e:
        return f"\nLIME explanation failed: {str(e)}\n"
    return "".join(lines)


def full_prediction(sequence):
    """Classify *sequence* as AMP / Non-AMP and return a plain-text report.

    The report has three parts:

    - the predicted label and confidence;
    - MIC predictions (AMPs only);
    - the top-10 LIME feature weights.

    If feature extraction fails, its error string is returned unchanged.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        return features
    amp_model, _ = get_amp_model()
    prediction, confidence = _interpret_amp_output(amp_model.predict(features, verbose=0))
    # Label convention: 1 = AMP, 0 = Non-AMP (swap if your model is reversed)
    amp_result = "Antimicrobial Peptide (AMP)" if prediction == 1 else "Non-AMP"
    # Sections are evaluated left to right: MIC first, then LIME, as before.
    return (
        f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
        + _mic_report(sequence, prediction)
        + _lime_report(features)
    )
# Gradio UI: one sequence textbox in, full_prediction's plain-text report out.
iface = gr.Interface(
    fn=full_prediction,
    inputs=gr.Textbox(label="Enter Protein Sequence"),
    outputs=gr.Textbox(label="Results"),
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
)

# Start the server only when the file is executed as a script. The module can
# then be imported (e.g. to test full_prediction) without launching the UI.
# NOTE(review): assumes the deployment starts the app with `python app.py`;
# confirm before relying on this.
if __name__ == "__main__":
    iface.launch()