File size: 10,552 Bytes
98f9e87
bd01e5d
f8615d0
 
 
 
bd01e5d
2739a59
bd01e5d
 
 
942bf87
51a3749
ea9a1bf
4222f98
1dcb272
bd01e5d
 
febb4a6
2739a59
bd01e5d
2739a59
bd01e5d
 
f0f9b27
2739a59
bd01e5d
2739a59
8a9cc7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd01e5d
b0a80bd
1dcb272
bd01e5d
 
6a89413
 
 
1dcb272
6a89413
 
 
 
 
 
 
 
 
 
 
 
bd01e5d
 
1dcb272
 
bd01e5d
 
 
 
1dcb272
 
2739a59
bd01e5d
 
 
 
44f5cf9
63d3a19
 
 
 
 
bd01e5d
1dcb272
bd01e5d
 
 
 
1dcb272
4222f98
 
bd01e5d
4222f98
 
03f381c
4222f98
bd01e5d
 
03f381c
bd01e5d
 
 
63d3a19
bd01e5d
 
63d3a19
 
 
1dcb272
bd01e5d
 
 
 
 
44f5cf9
63d3a19
 
 
 
12675f2
 
 
bd01e5d
12675f2
 
bd01e5d
12675f2
 
44f5cf9
12675f2
 
 
63d3a19
12675f2
 
 
 
 
 
 
63d3a19
2739a59
bd01e5d
 
 
44f5cf9
63d3a19
 
 
1dcb272
bd01e5d
 
1dcb272
bd01e5d
 
 
 
 
b206439
1dcb272
bd01e5d
63d3a19
bd01e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5745f40
63d3a19
 
 
 
 
 
 
 
44f5cf9
2739a59
bd01e5d
44f5cf9
63d3a19
 
 
 
 
44f5cf9
68ded6f
2739a59
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import os
# Native-lib hygiene, applied before the numeric libraries below are imported.
# Per the original note: prevents a TF/PyTorch SIGSEGV when both load and is
# harmless for the RF. ``setdefault`` means a value already present in the
# environment is left untouched.
for _env_name, _env_default in (
    ("KMP_DUPLICATE_LIB_OK", "TRUE"),
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("OPENBLAS_NUM_THREADS", "1"),
    ("TOKENIZERS_PARALLELISM", "false"),
):
    os.environ.setdefault(_env_name, _env_default)

import sys
import json
import subprocess
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from lime.lime_tabular import LimeTabularExplainer

import gradio as gr

# ---------------------------------------------------------------------------
# Trained artifacts, read from the working directory at import time.
#   RF.joblib        -> Random Forest AMP classifier (bound to ``model``)
#   norm (4).joblib  -> MinMax scaler, original file  (bound to ``scaler``)
# ---------------------------------------------------------------------------
model, scaler = (
    joblib.load(_artifact_path)
    for _artifact_path in ("RF.joblib", "norm (4).joblib")
)

# ---------------------------------------------------------------------------
# Names of the classifier's input features. The order of this list is the
# column order used when slicing the normalized feature frame in
# extract_features() and the feature_names order given to the LIME explainer.
# Per the original note these are RFE-selected descriptors drawn from
# CTD + AAC + Autocorrelation + APAAC.
# NOTE(review): the original header said "138" features, but the literal below
# appears to hold 146 entries (36 CTD, 12 amino-acid, 34 dipeptide,
# 52 autocorrelation, 12 APAAC) — confirm against the trained model.
# ---------------------------------------------------------------------------
selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
"_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
"_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
"_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
"_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
"_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
"_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001",
"_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
"AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE",
"LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
"MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
"GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
"GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
"GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
"GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30",
"GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25",
"GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29",
"GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30",
"GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
"GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25",
"GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30",
"GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
"GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
"GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29",
"GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
"APAAC15", "APAAC18", "APAAC19", "APAAC24"]

# Fail fast at import if the feature list does not match the fitted classifier.
# The expected width is read from the model itself (scikit-learn exposes it as
# ``n_features_in_``; older releases used ``n_features_``) instead of a
# hard-coded count that can drift away from the list above. An explicit raise
# is used because ``assert`` statements are stripped under ``python -O``.
# If the model exposes neither attribute the check is skipped.
_expected_n_features = getattr(
    model, "n_features_in_", getattr(model, "n_features_", None)
)
if _expected_n_features is not None and len(selected_features) != _expected_n_features:
    raise ValueError(
        f"Expected {_expected_n_features} features, got {len(selected_features)}"
    )

# ---------------------------------------------------------------------------
# LIME explainer
# Constructed a single time at import so every request shares the same
# explainer and explanations stay reproducible.
# Background data: lime_background.joblib (a real normalized training sample)
# when it loads and has one column per selected feature; otherwise seeded
# uniform noise in [0, 1) — still stable, but less faithful to the true
# feature distribution.
# ---------------------------------------------------------------------------
_n_selected = len(selected_features)

try:
    _lime_background = joblib.load("lime_background.joblib")
    _background_cols = _lime_background.shape[1]
    if _background_cols != _n_selected:
        # Raised on purpose so the handler below switches to the fallback.
        raise ValueError(
            f"lime_background.joblib has {_background_cols} cols, "
            f"expected {_n_selected}"
        )
    print(f"[LIME] Using real training sample: {_lime_background.shape}", flush=True)
except Exception as e:
    # Any failure (missing file, wrong shape, unreadable pickle) is reported
    # and replaced by 500 rows of seeded noise.
    print(f"[LIME] No usable lime_background.joblib ({e}); falling back to uniform noise.", flush=True)
    _rng = np.random.default_rng(seed=42)
    _lime_background = _rng.uniform(low=0.0, high=1.0, size=(500, _n_selected))

_explainer_settings = dict(
    training_data=_lime_background,
    feature_names=selected_features,
    class_names=["AMP", "Non-AMP"],
    mode="classification",
    discretize_continuous=True,
    random_state=42,  # fixed seed -> stable explanations
)
explainer = LimeTabularExplainer(**_explainer_settings)


# ---------------------------------------------------------------------------
# Feature extraction — produces the full propy feature pool, scales it with
# the saved MinMax scaler, then selects the 138 features the RF was trained on.
# ---------------------------------------------------------------------------
def extract_features(sequence):
    sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
    if len(sequence) < 10:
        return "Error: Sequence too short."

    try:
        # Original full pool: CTD + AAC(first 420) + Autocorrelation + PseudoAAC
        dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
        filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
        ctd_features = CTD.CalculateCTD(sequence)
        auto_features = Autocorrelation.CalculateAutoTotal(sequence)
        pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)

        all_features_dict = {}
        all_features_dict.update(ctd_features)
        all_features_dict.update(filtered_dipeptide_features)
        all_features_dict.update(auto_features)
        all_features_dict.update(pseudo_features)

        feature_df_all = pd.DataFrame([all_features_dict])
        normalized_array = scaler.transform(feature_df_all.values)
        normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)

        if not set(selected_features).issubset(normalized_df.columns):
            missing = set(selected_features) - set(normalized_df.columns)
            return f"Error: Missing features: {list(missing)[:5]}..."

        selected_df = normalized_df[selected_features].fillna(0)
        return selected_df.values
    except Exception as e:
        return f"Error in feature extraction: {str(e)}"


# ---------------------------------------------------------------------------
# MIC prediction — delegated to a SEPARATE process (mic_worker.py).
# Per the original design note, this isolates PyTorch/ProtBert from the main
# process and prevents the native-library crash (exit 139) plus the OOM spike
# on the free tier.
# ---------------------------------------------------------------------------
def predictmic(sequence):
    """Run ``mic_worker.py`` on a peptide and return its JSON result.

    The sequence is upper-cased and stripped of every character outside the
    20 standard amino-acid letters, then passed to the worker as its single
    command-line argument. The worker is expected to print a JSON object on
    its last non-blank stdout line; earlier lines are ignored.

    Args:
        sequence: Raw sequence text. Non-string input is rejected.

    Returns:
        The decoded JSON object (a dict) on success, otherwise a dict with a
        single ``"Error"`` key describing what went wrong: input too short,
        timeout (900 s), failure to start, non-zero exit, no output, or
        output that is not a JSON object.
    """
    if not isinstance(sequence, str):
        return {"Error": "Sequence too short or invalid."}
    sequence = "".join(aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    try:
        # List-form argv (no shell); the argument contains only A-Z letters.
        proc = subprocess.run(
            [sys.executable, "mic_worker.py", sequence],
            capture_output=True, text=True, timeout=900
        )
    except subprocess.TimeoutExpired:
        return {"Error": "MIC prediction timed out (ProtBert may still be downloading; try again shortly)."}
    except Exception as e:
        # Deliberately broad: any launch/decoding problem becomes a message
        # in the report instead of crashing the request.
        return {"Error": f"Failed to start MIC worker: {str(e)}"}

    if proc.returncode != 0:
        # Surface only the last few stderr lines to keep the report readable.
        tail = (proc.stderr or "").strip().splitlines()[-3:]
        return {"Error": f"MIC worker exited with code {proc.returncode}. {' '.join(tail)}"}

    out_lines = [ln for ln in (proc.stdout or "").splitlines() if ln.strip()]
    if not out_lines:
        return {"Error": "MIC worker produced no output."}

    last_line = out_lines[-1]
    unparsable = {"Error": f"Could not parse MIC worker output: {last_line[:200]}"}
    try:
        parsed = json.loads(last_line)
    except (ValueError, RecursionError):
        return unparsable
    # Callers iterate ``.items()``; valid JSON that is not an object (a list,
    # number, string, null) is therefore treated as unparsable as well.
    if not isinstance(parsed, dict):
        return unparsable
    return parsed


# ---------------------------------------------------------------------------
# Main prediction pipeline
# ---------------------------------------------------------------------------
def full_prediction(sequence):
    """Classify a peptide, explain the call with LIME, and append MIC values.

    Args:
        sequence: Raw sequence text from the UI; ``None`` is treated as an
            empty string.

    Returns:
        A plain-text report with the predicted class and confidence, the top
        LIME feature weights, and (for AMP predictions only) the MIC values
        returned by ``predictmic``. When features cannot be computed, the
        error string from ``extract_features`` is returned unchanged.
    """
    features = extract_features(sequence or "")
    if isinstance(features, str):
        # extract_features signals failure with a message string.
        return features

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    # Locate the predict_proba column that belongs to the predicted class.
    try:
        class_index = list(model.classes_).index(prediction)
        confidence = f"{round(probabilities[class_index] * 100, 2)}%"
    except (AttributeError, ValueError, IndexError):
        class_index = None
        confidence = "Unknown"  # no "%" suffix after a non-numeric value

    is_amp = prediction == 0  # label 0 is treated as AMP throughout this file
    amp_result = "Antimicrobial Peptide (AMP)" if is_amp else "Non-AMP"
    report = [f"Prediction: {amp_result}\nConfidence: {confidence}\n"]

    # ---- LIME explanation for this single input ----
    try:
        # LIME explains label index 1 unless told otherwise. Ask for the
        # predicted class so a positive weight means "supports the reported
        # prediction"; fall back to LIME's default if the index is unknown.
        lime_label = 1 if class_index is None else class_index
        explanation = explainer.explain_instance(
            data_row=features[0],
            predict_fn=model.predict_proba,
            labels=(lime_label,),
            num_features=10,
            num_samples=2000,  # perturbations drawn around this one input
        )
        report.append("\nTop Features Influencing Prediction (LIME):\n")
        report.extend(
            f"- {feat}: {round(weight, 4)}\n"
            for feat, weight in explanation.as_list(label=lime_label)
        )
    except Exception as e:
        # Best effort: an explanation failure must not hide the prediction.
        report.append(f"\nLIME explanation failed: {str(e)}\n")

    # ---- MIC (only for AMPs) ----
    if is_amp:
        mic_values = predictmic(sequence)
        report.append("\nPredicted MIC Values (μM):\n")
        report.extend(f"- {org}: {mic}\n" for org, mic in mic_values.items())
    else:
        report.append("\nMIC prediction skipped for Non-AMP sequences.\n")

    return "".join(report)


# Gradio UI: a single text box in, a single text box out, wired to
# full_prediction.
_sequence_input = gr.Textbox(label="Enter Protein Sequence")
_results_output = gr.Textbox(label="Results")

iface = gr.Interface(
    fn=full_prediction,
    inputs=_sequence_input,
    outputs=_results_output,
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
)

# Launch the interface as soon as this script is executed.
iface.launch()