Spaces:
Running
Running
File size: 10,552 Bytes
98f9e87 bd01e5d f8615d0 bd01e5d 2739a59 bd01e5d 942bf87 51a3749 ea9a1bf 4222f98 1dcb272 bd01e5d febb4a6 2739a59 bd01e5d 2739a59 bd01e5d f0f9b27 2739a59 bd01e5d 2739a59 8a9cc7c bd01e5d b0a80bd 1dcb272 bd01e5d 6a89413 1dcb272 6a89413 bd01e5d 1dcb272 bd01e5d 1dcb272 2739a59 bd01e5d 44f5cf9 63d3a19 bd01e5d 1dcb272 bd01e5d 1dcb272 4222f98 bd01e5d 4222f98 03f381c 4222f98 bd01e5d 03f381c bd01e5d 63d3a19 bd01e5d 63d3a19 1dcb272 bd01e5d 44f5cf9 63d3a19 12675f2 bd01e5d 12675f2 bd01e5d 12675f2 44f5cf9 12675f2 63d3a19 12675f2 63d3a19 2739a59 bd01e5d 44f5cf9 63d3a19 1dcb272 bd01e5d 1dcb272 bd01e5d b206439 1dcb272 bd01e5d 63d3a19 bd01e5d 5745f40 63d3a19 44f5cf9 2739a59 bd01e5d 44f5cf9 63d3a19 44f5cf9 68ded6f 2739a59 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 | import os
# Native-library guard rails, set before the numeric/ML imports that follow.
# (Original rationale: prevents TF/PyTorch SIGSEGV when both load; harmless
# for the Random Forest path.)
# Each value is only a default: anything already exported in the environment
# is left untouched.
for _env_name, _env_default in (
    ("KMP_DUPLICATE_LIB_OK", "TRUE"),
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("OPENBLAS_NUM_THREADS", "1"),
    ("TOKENIZERS_PARALLELISM", "false"),
):
    if _env_name not in os.environ:
        os.environ[_env_name] = _env_default
del _env_name, _env_default
import sys
import json
import subprocess
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from lime.lime_tabular import LimeTabularExplainer
import gradio as gr
# ---------------------------------------------------------------------------
# Load Random Forest AMP classifier + MinMax scaler (original files)
# ---------------------------------------------------------------------------
# Classifier first, then the feature scaler; both paths are relative, so they
# resolve against the process working directory.
# joblib.load unpickles its input: these artifacts must only ever come from a
# trusted source.
model, scaler = (
    joblib.load(_artifact) for _artifact in ("RF.joblib", "norm (4).joblib")
)
# ---------------------------------------------------------------------------
# Original 138 RFE-selected features (CTD + AAC + Autocorrelation + APAAC)
# ---------------------------------------------------------------------------
# Column names handed to the classifier, in the order the model receives
# them: extract_features() indexes the scaled DataFrame with this list, so
# renaming or reordering an entry changes the model input.
# Entries run in family order: "_"-prefixed CTD descriptors, single-letter
# amino-acid composition, two-letter dipeptide composition, autocorrelation
# descriptors (MoreauBroto / Moran / Geary) and APAAC terms.
# NOTE(review): a hand count of the literal below gives 146 names
# (36 CTD + 12 single-letter + 34 dipeptide + 52 autocorrelation + 12 APAAC),
# which would make the 138 assertion fail at import time — TODO confirm the
# intended count against the number of features the saved model expects.
selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
"_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
"_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
"_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
"_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
"_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
"_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001",
"_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
"AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE",
"LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
"MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
"GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
"GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
"GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
"GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30",
"GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25",
"GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29",
"GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30",
"GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
"GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25",
"GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30",
"GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
"GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
"GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29",
"GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
"APAAC15", "APAAC18", "APAAC19", "APAAC24"]
# Import-time sanity check on the list length; note that assert statements
# are skipped entirely when Python runs with -O.
assert len(selected_features) == 138, f"Expected 138 features, got {len(selected_features)}"
# ---------------------------------------------------------------------------
# LIME explainer
# Built ONCE at startup so explanations are reproducible across requests.
# Prefers a real normalized training sample (lime_background.joblib). Falls
# back to seeded uniform noise if that file isn't present (still stable, but
# less faithful to the true feature distribution).
# ---------------------------------------------------------------------------
# Background data for LIME: use the saved sample when it loads and has one
# column per selected feature; any failure (missing file, wrong shape, load
# error) drops to the seeded synthetic fallback below.
try:
    _lime_background = joblib.load("lime_background.joblib")
    if _lime_background.shape[1] == len(selected_features):
        print(f"[LIME] Using real training sample: {_lime_background.shape}", flush=True)
    else:
        raise ValueError(
            f"lime_background.joblib has {_lime_background.shape[1]} cols, "
            f"expected {len(selected_features)}"
        )
except Exception as exc:
    print(f"[LIME] No usable lime_background.joblib ({exc}); falling back to uniform noise.", flush=True)
    # Fixed seed so the fallback background is identical on every start.
    _rng = np.random.default_rng(seed=42)
    _lime_background = _rng.uniform(low=0.0, high=1.0, size=(500, len(selected_features)))

# One explainer for the whole process; random_state pins LIME's own sampling.
explainer = LimeTabularExplainer(
    training_data=_lime_background,
    mode="classification",
    class_names=["AMP", "Non-AMP"],
    feature_names=selected_features,
    random_state=42,
    discretize_continuous=True,
)
# ---------------------------------------------------------------------------
# Feature extraction — produces the full propy feature pool, scales it with
# the saved MinMax scaler, then selects the 138 features the RF was trained on.
# ---------------------------------------------------------------------------
def extract_features(sequence):
    """Convert a raw peptide string into the scaled feature row for the model.

    The input is upper-cased and every character outside the 20 standard
    amino-acid letters is discarded before any descriptor is computed.

    Args:
        sequence: Peptide sequence as text. ``None`` is treated as an empty
            sequence; any other non-string value is converted with ``str()``.

    Returns:
        On success, a 2-D array with a single row containing the columns
        listed in ``selected_features`` (scaled, NaN replaced by 0).
        On failure, a ``str`` beginning with ``"Error"``; callers tell the
        two apart with ``isinstance(result, str)``.
    """
    # Guard against None / non-text input so the cleaning step below cannot
    # raise outside the try block.
    if sequence is None:
        sequence = ""
    elif not isinstance(sequence, str):
        sequence = str(sequence)
    sequence = ''.join(aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return "Error: Sequence too short."
    try:
        composition = AAComposition.CalculateAADipeptideComposition(sequence)
        ctd_features = CTD.CalculateCTD(sequence)
        auto_features = Autocorrelation.CalculateAutoTotal(sequence)
        pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)

        # Merge order is CTD -> first 420 composition entries -> autocorrelation
        # -> pseudo-AAC. The scaler is fed positional values, so this order
        # presumably has to match the order it was fitted with — verify before
        # changing it.
        all_features_dict = {}
        all_features_dict.update(ctd_features)
        all_features_dict.update(list(composition.items())[:420])
        all_features_dict.update(auto_features)
        all_features_dict.update(pseudo_features)

        feature_df_all = pd.DataFrame([all_features_dict])
        normalized_df = pd.DataFrame(
            scaler.transform(feature_df_all.values),
            columns=feature_df_all.columns,
        )

        # Sorted so the reported sample of missing names is deterministic.
        missing = sorted(set(selected_features) - set(normalized_df.columns))
        if missing:
            return f"Error: Missing features: {missing[:5]}..."
        return normalized_df[selected_features].fillna(0).values
    except Exception as e:
        # Boundary handler: any descriptor/scaler failure becomes an error
        # string for the UI instead of an exception.
        return f"Error in feature extraction: {e}"
# ---------------------------------------------------------------------------
# MIC prediction — runs in a SEPARATE process (mic_worker.py).
# This isolates PyTorch/ProtBert from the main process and prevents the
# native-library crash (exit 139) plus the OOM spike on the free tier.
# ---------------------------------------------------------------------------
def predictmic(sequence):
    """Predict MIC values by running ``mic_worker.py`` in a child process.

    The sequence is upper-cased and stripped of everything except the 20
    standard amino-acid letters, then passed to the worker as its single
    command-line argument. The worker's last non-empty stdout line is parsed
    as JSON.

    Args:
        sequence: Peptide sequence as text. ``None`` is treated as an empty
            sequence; any other non-string value is converted with ``str()``.

    Returns:
        A ``dict``. On success it is the JSON object printed by the worker;
        on any failure it is ``{"Error": "<message>"}``. This function does
        not raise for worker failures, timeouts or malformed output.
    """
    if sequence is None:
        sequence = ""
    elif not isinstance(sequence, str):
        sequence = str(sequence)
    sequence = ''.join(aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    try:
        # Argument list (no shell); the worker path is relative to the
        # current working directory.
        proc = subprocess.run(
            [sys.executable, "mic_worker.py", sequence],
            capture_output=True, text=True, timeout=900,
        )
    except subprocess.TimeoutExpired:
        return {"Error": "MIC prediction timed out (ProtBert may still be downloading; try again shortly)."}
    except Exception as e:
        # Deliberately broad: any launch failure is reported to the UI.
        return {"Error": f"Failed to start MIC worker: {e}"}

    if proc.returncode != 0:
        # Only the last three stderr lines are surfaced to keep the message short.
        tail = (proc.stderr or "").strip().splitlines()[-3:]
        return {"Error": f"MIC worker exited with code {proc.returncode}. {' '.join(tail)}"}

    out_lines = [ln for ln in (proc.stdout or "").splitlines() if ln.strip()]
    if not out_lines:
        return {"Error": "MIC worker produced no output."}

    # Earlier stdout lines are ignored; only the final one is the payload.
    last_line = out_lines[-1]
    try:
        parsed = json.loads(last_line)
    except (ValueError, RecursionError):
        return {"Error": f"Could not parse MIC worker output: {last_line[:200]}"}
    # Valid JSON that is not an object (list, number, string, null) would
    # break callers that iterate .items(), so report it as unparseable.
    if not isinstance(parsed, dict):
        return {"Error": f"Could not parse MIC worker output: {last_line[:200]}"}
    return parsed
# ---------------------------------------------------------------------------
# Main prediction pipeline
# ---------------------------------------------------------------------------
def full_prediction(sequence):
    """Run classification, LIME explanation and MIC prediction for one sequence.

    Args:
        sequence: Raw peptide sequence text as entered by the user.

    Returns:
        A plain-text report. If feature extraction fails, the error string
        from ``extract_features`` is returned unchanged. Otherwise the report
        contains the predicted class and its confidence, up to 10 LIME
        feature weights, and — only when the predicted label is 0 (AMP) —
        the MIC values returned by ``predictmic``.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        # extract_features signals failure with an error string.
        return features

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]
    try:
        # predict_proba columns follow model.classes_, so look the label up
        # rather than assuming its position.
        class_index = list(model.classes_).index(prediction)
        confidence_text = f"{round(probabilities[class_index] * 100, 2)}%"
    except (ValueError, IndexError, AttributeError):
        class_index = None
        # No percent sign here: the original rendered this case as "Unknown%".
        confidence_text = "Unknown"

    is_amp = prediction == 0  # label 0 is the AMP class in this model
    amp_result = "Antimicrobial Peptide (AMP)" if is_amp else "Non-AMP"
    report = [f"Prediction: {amp_result}\nConfidence: {confidence_text}\n"]

    # ---- LIME explanation ----
    # LIME explains class index 1 unless told otherwise; request the predicted
    # class so a positive weight means "supports this prediction". Falls back
    # to LIME's default label when the class index could not be determined.
    lime_label = 1 if class_index is None else class_index
    try:
        explanation = explainer.explain_instance(
            data_row=features[0],  # the single row produced for this input
            predict_fn=model.predict_proba,
            labels=(lime_label,),
            num_features=10,
            num_samples=2000,  # perturbed samples drawn around this row
        )
        weights = explanation.as_list(label=lime_label)
        report.append("\nTop Features Influencing Prediction (LIME):\n")
        report.extend(f"- {feat}: {round(weight, 4)}\n" for feat, weight in weights)
    except Exception as e:
        # Best effort: a failed explanation must not hide the prediction.
        report.append(f"\nLIME explanation failed: {e}\n")

    # ---- MIC (only for AMPs) ----
    if is_amp:
        mic_values = predictmic(sequence)
        report.append("\nPredicted MIC Values (μM):\n")
        if isinstance(mic_values, dict):
            report.extend(f"- {org}: {mic}\n" for org, mic in mic_values.items())
        else:
            # Defensive: anything other than a mapping cannot be itemised.
            report.append(f"- Error: Unexpected MIC result: {mic_values!r}\n")
    else:
        report.append("\nMIC prediction skipped for Non-AMP sequences.\n")

    return "".join(report)
# Gradio front end: one free-text input box, one text output box, both wired
# to full_prediction.
_sequence_box = gr.Textbox(label="Enter Protein Sequence")
_report_box = gr.Textbox(label="Results")
iface = gr.Interface(
    fn=full_prediction,
    inputs=_sequence_box,
    outputs=_report_box,
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
)
# Serving starts as soon as this module is executed (there is no __main__ guard).
iface.launch()