File size: 13,016 Bytes
98f9e87
f8615d0
 
 
 
 
 
 
133e26c
2739a59
 
 
85c36de
942bf87
51a3749
ea9a1bf
b0a80bd
12675f2
 
 
febb4a6
2739a59
133e26c
12675f2
 
 
2739a59
133e26c
 
2739a59
 
 
 
 
 
 
 
 
942bf87
f0f9b27
2739a59
133e26c
 
 
2739a59
b0a80bd
5d02ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0a80bd
133e26c
b0a80bd
2739a59
b0a80bd
133e26c
2739a59
 
b0a80bd
 
133e26c
b0a80bd
b206439
f776418
44f5cf9
133e26c
63d3a19
 
 
 
 
2739a59
 
133e26c
63d3a19
133e26c
63d3a19
133e26c
 
 
 
44f5cf9
133e26c
 
 
 
63d3a19
133e26c
 
 
63d3a19
133e26c
 
63d3a19
 
 
2739a59
44f5cf9
12675f2
 
 
 
 
 
63d3a19
 
 
 
12675f2
 
 
 
 
 
 
 
 
 
 
 
44f5cf9
12675f2
 
 
 
63d3a19
12675f2
 
 
 
63d3a19
12675f2
 
 
 
63d3a19
2739a59
44f5cf9
769f73d
63d3a19
 
769f73d
63d3a19
769f73d
63d3a19
2739a59
769f73d
2739a59
769f73d
b0a80bd
 
133e26c
b0a80bd
133e26c
b0a80bd
 
133e26c
b0a80bd
 
 
 
 
b206439
133e26c
b0a80bd
63d3a19
 
b0a80bd
769f73d
63d3a19
769f73d
63d3a19
 
 
 
 
 
b0a80bd
2739a59
 
 
 
 
 
 
 
b0a80bd
 
 
 
 
 
 
 
 
 
63d3a19
 
44f5cf9
2739a59
44f5cf9
63d3a19
 
 
 
 
44f5cf9
68ded6f
2739a59
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import os
# --- Prevent SIGSEGV (exit 139) from TensorFlow + PyTorch native lib clashes ---
# TF and torch each bundle their own OpenMP/MKL; loaded together they can collide
# and crash at the C level. These settings make them coexist and reduce memory.
#
# Every value is applied with setdefault(), so a variable that is already set in
# the environment (for example by the deployment) takes precedence over the
# default given here.
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE is presumably the usual workaround that
# tolerates more than one OpenMP runtime in a process rather than removing the
# duplicate -- confirm it is still required now that ProtBert/PyTorch run in a
# separate worker process.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
# Cap each numeric thread pool at one thread.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
# Quiet TensorFlow logs (must be set before importing tensorflow)
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")

import gradio as gr
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, CTD
import sys
import json
import subprocess

# ---------------------------------------------------------------------------
# LAZY LOADING — keeps the free 16GB Space from OOM-ing at startup.
# Only the TensorFlow AMP model is loaded in THIS process. ProtBert/PyTorch
# run in a SEPARATE process (mic_worker.py) to avoid a native-library clash
# between TensorFlow and PyTorch that caused SIGSEGV (exit 139).
# ---------------------------------------------------------------------------
# Process-wide cache for the AMP classifier and its feature scaler. Both stay
# None until get_amp_model() has loaded BOTH artifacts successfully.
_amp_model = None
_amp_scaler = None


def get_amp_model():
    """Return the cached ``(model, scaler)`` pair, loading both on first use.

    The Keras model is read from
    ``Comb1_aac_ctd_RFE_selected_features_model.keras`` and the scaler from
    ``Comb1_aac_ctd_RFE_selected_features_scaler.joblib``; both are relative
    paths, so they are looked up from the current working directory.

    Returns:
        A 2-tuple ``(model, scaler)``. Neither element is ever ``None``.

    Raises:
        Whatever the TensorFlow import, ``load_model`` or ``joblib.load``
        raises. A failed load leaves the cache untouched, so the next call
        retries from scratch instead of handing back a half-loaded pair.
    """
    global _amp_model, _amp_scaler
    if _amp_model is None or _amp_scaler is None:
        # TensorFlow is imported here rather than at module level so that it is
        # only loaded once a prediction is actually requested.
        from tensorflow.keras.models import load_model

        # Load into locals and publish both globals together. Assigning the
        # model first would leave (model, None) cached if the scaler load
        # failed, and every later call would then skip loading entirely.
        model = load_model("Comb1_aac_ctd_RFE_selected_features_model.keras")
        scaler = joblib.load("Comb1_aac_ctd_RFE_selected_features_scaler.joblib")
        _amp_model, _amp_scaler = model, scaler
    return _amp_model, _amp_scaler


# ---------------------------------------------------------------------------
# The EXACT 343 features the scaler was fit on, IN THE EXACT TRAINING ORDER.
# The scaler was fit on a numpy array (no stored names), so order is critical:
# we must select these columns in this order BEFORE calling scaler.transform().
# ---------------------------------------------------------------------------
selected_features = [
    "_PolarizabilityC1", "_PolarizabilityC2", "_PolarizabilityC3",
    "_SolventAccessibilityC1", "_SolventAccessibilityC2", "_SolventAccessibilityC3",
    "_SecondaryStrC1", "_SecondaryStrC2", "_SecondaryStrC3",
    "_ChargeC1", "_ChargeC2", "_ChargeC3",
    "_PolarityC1", "_PolarityC2", "_PolarityC3",
    "_NormalizedVDWVC1", "_NormalizedVDWVC2", "_NormalizedVDWVC3",
    "_HydrophobicityC1", "_HydrophobicityC2", "_HydrophobicityC3",
    "_PolarizabilityT12", "_PolarizabilityT13", "_PolarizabilityT23",
    "_SolventAccessibilityT12", "_SolventAccessibilityT13", "_SolventAccessibilityT23",
    "_SecondaryStrT12", "_SecondaryStrT13", "_SecondaryStrT23",
    "_ChargeT12", "_ChargeT13", "_ChargeT23",
    "_PolarityT12", "_PolarityT13", "_PolarityT23",
    "_NormalizedVDWVT12", "_NormalizedVDWVT13", "_NormalizedVDWVT23",
    "_HydrophobicityT12", "_HydrophobicityT13", "_HydrophobicityT23",
    "_PolarizabilityD1001", "_PolarizabilityD1025", "_PolarizabilityD1050",
    "_PolarizabilityD1075", "_PolarizabilityD1100",
    "_PolarizabilityD2001", "_PolarizabilityD2025", "_PolarizabilityD2050",
    "_PolarizabilityD2075", "_PolarizabilityD2100",
    "_PolarizabilityD3001", "_PolarizabilityD3025", "_PolarizabilityD3050",
    "_PolarizabilityD3075", "_PolarizabilityD3100",
    "_SolventAccessibilityD1001", "_SolventAccessibilityD1025",
    "_SolventAccessibilityD1050", "_SolventAccessibilityD1075",
    "_SolventAccessibilityD1100",
    "_SolventAccessibilityD2001", "_SolventAccessibilityD2025",
    "_SolventAccessibilityD2050", "_SolventAccessibilityD2075",
    "_SolventAccessibilityD2100",
    "_SolventAccessibilityD3001", "_SolventAccessibilityD3025",
    "_SolventAccessibilityD3050", "_SolventAccessibilityD3075",
    "_SolventAccessibilityD3100",
    "_SecondaryStrD1001", "_SecondaryStrD1025", "_SecondaryStrD1050",
    "_SecondaryStrD1075", "_SecondaryStrD1100",
    "_SecondaryStrD2001", "_SecondaryStrD2025", "_SecondaryStrD2050",
    "_SecondaryStrD2075", "_SecondaryStrD2100",
    "_SecondaryStrD3001", "_SecondaryStrD3025", "_SecondaryStrD3050",
    "_SecondaryStrD3075", "_SecondaryStrD3100",
    "_ChargeD1001", "_ChargeD1025", "_ChargeD1050",
    "_ChargeD1075", "_ChargeD1100",
    "_ChargeD2001", "_ChargeD2025", "_ChargeD2050",
    "_ChargeD2075",
    "_ChargeD3001", "_ChargeD3025", "_ChargeD3050",
    "_ChargeD3075", "_ChargeD3100",
    "_PolarityD1001", "_PolarityD1025", "_PolarityD1050",
    "_PolarityD1075", "_PolarityD1100",
    "_PolarityD2001", "_PolarityD2025", "_PolarityD2050",
    "_PolarityD2075", "_PolarityD2100",
    "_PolarityD3001", "_PolarityD3025", "_PolarityD3050",
    "_PolarityD3075", "_PolarityD3100",
    "_NormalizedVDWVD1001", "_NormalizedVDWVD1025",
    "_NormalizedVDWVD1050", "_NormalizedVDWVD1075",
    "_NormalizedVDWVD1100",
    "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
    "_NormalizedVDWVD2050", "_NormalizedVDWVD2075",
    "_NormalizedVDWVD2100",
    "_NormalizedVDWVD3001", "_NormalizedVDWVD3025",
    "_NormalizedVDWVD3050", "_NormalizedVDWVD3075",
    "_NormalizedVDWVD3100",
    "_HydrophobicityD1001", "_HydrophobicityD1025",
    "_HydrophobicityD1050", "_HydrophobicityD1075",
    "_HydrophobicityD1100",
    "_HydrophobicityD2001", "_HydrophobicityD2025",
    "_HydrophobicityD2050", "_HydrophobicityD2075",
    "_HydrophobicityD2100",
    "_HydrophobicityD3001", "_HydrophobicityD3025",
    "_HydrophobicityD3050", "_HydrophobicityD3075",
    "_HydrophobicityD3100",
    "A", "R", "N", "D", "C", "E", "Q", "G", "H", "I",
    "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
    "AR", "AD", "AQ", "AG", "AL", "AK", "AF", "AP", "AT", "AV",
    "RA", "RC", "RE", "RG", "RI", "RL", "RS", "RT", "RV",
    "NR", "NC", "NG", "NI", "NP", "NS", "NY", "NV",
    "DR", "DN", "DC", "DE", "DG", "DF", "DS", "DT", "DY",
    "CR", "CN", "CD", "CC", "CI", "CL", "CK", "CT", "CY", "CV",
    "EA", "ER", "ED", "EC", "EE", "EG", "EI", "EL", "EK",
    "EF", "EP", "ET", "EV",
    "QN", "QF", "QV",
    "GA", "GR", "GC", "GE", "GG", "GI", "GL", "GK", "GF", "GP", "GY",
    "HA", "HP", "HT",
    "IA", "IR", "ID", "II", "IL", "IF", "IP", "IS", "IV",
    "LA", "LR", "LD", "LC", "LG", "LI", "LK", "LM", "LF",
    "LS", "LT", "LY", "LV",
    "KA", "KN", "KC", "KG", "KI", "KL", "KK", "KP", "KY",
    "MA", "MD", "ME", "MI", "MK", "MF", "MP", "MS", "MV",
    "FR", "FE", "FQ", "FG", "FL", "FF", "FS", "FT", "FY", "FV",
    "PA", "PR", "PC", "PE", "PL", "PK", "PP", "PS", "PV",
    "SA", "SR", "SD", "SC", "SG", "SH", "SI", "SL", "SP", "ST", "SY",
    "TA", "TR", "TC", "TE", "TQ", "TG", "TI", "TL", "TP", "TS", "TV",
    "WA",
    "YN", "YD", "YC", "YQ", "YG", "YP",
    "VA", "VR", "VD", "VC", "VE", "VG", "VI", "VL", "VK",
    "VS", "VT", "VY", "VV"
]
assert len(selected_features) == 343, f"Expected 343 features, got {len(selected_features)}"


def keras_predict_proba(X):
    """Adapt the AMP model's output to the two-column form LIME expects.

    Args:
        X: Batch of already-scaled feature rows, passed to ``predict`` as-is.

    Returns:
        The model output unchanged when it already has more than one column.
        When it is 1-D or has a single column, that value is treated as
        P(AMP) and a two-column array ``[1 - p, p]`` is returned, i.e.
        ``[P(Non-AMP), P(AMP)]`` per row.
    """
    classifier = get_amp_model()[0]
    scores = classifier.predict(X, verbose=0)

    single_output = scores.ndim == 1 or scores.shape[1] == 1
    if not single_output:
        return scores

    # Single sigmoid-style output: assumed to be P(AMP); derive the other class.
    p_amp = scores.reshape(-1, 1)
    return np.hstack([1 - p_amp, p_amp])


def extract_features(sequence):
    """Turn a raw sequence string into one scaled feature row for the AMP model.

    The input is upper-cased and every character outside the 20 standard
    amino-acid letters is discarded. The CTD and AADipeptideComposition
    outputs from propy are pooled into one dict, the ``selected_features``
    names are read from it in list order, and the resulting single row is
    passed through the cached scaler.

    Args:
        sequence: Amino-acid sequence as text.

    Returns:
        A ``float32`` array of shape ``(1, len(selected_features))`` on
        success. On failure a ``str`` starting with ``"Error"`` is returned
        instead (too-short input, a feature name missing from the propy
        output, or any exception raised while computing/scaling); callers
        distinguish the two cases by type.
    """
    residues = "".join(ch for ch in sequence.upper() if ch in "ACDEFGHIKLMNPQRSTVWY")
    if len(residues) < 10:
        return "Error: Sequence too short."

    try:
        scaler = get_amp_model()[1]

        # One lookup table for everything propy can compute; the dipeptide
        # call's entries are applied second and therefore win on a key clash.
        descriptor_pool = dict(CTD.CalculateCTD(residues))
        descriptor_pool.update(AAComposition.CalculateAADipeptideComposition(residues))

        absent = [name for name in selected_features if name not in descriptor_pool]
        if absent:
            return f"Error: Missing features from propy: {absent[:5]}..."

        # Column order must follow selected_features exactly, and selection
        # happens BEFORE scaling because the scaler expects exactly these columns.
        row = np.array(
            [descriptor_pool[name] for name in selected_features],
            dtype=np.float64,
        ).reshape(1, -1)

        return scaler.transform(row).astype(np.float32)
    except Exception as e:
        return f"Error in feature extraction: {e!s}"


def predictmic(sequence):
    """Run MIC prediction in a SEPARATE process (mic_worker.py).

    This isolates PyTorch/ProtBert from TensorFlow, preventing the native-library
    crash (exit 139) and keeping peak memory low. The worker is expected to print
    a JSON object on its last non-empty stdout line; that object is parsed and
    returned.

    Args:
        sequence: Amino-acid sequence as text. It is upper-cased and stripped of
            every character outside the 20 standard amino-acid letters before
            being passed to the worker as a single command-line argument.

    Returns:
        Always a ``dict``. On success it is the worker's JSON object, returned
        verbatim. On any failure (input shorter than 10 residues after
        cleaning, timeout, worker could not be started, non-zero exit code, no
        output, unparsable output, or output that is valid JSON but not an
        object) it is ``{"Error": "<message>"}``.
    """
    sequence = "".join(aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    try:
        # First run downloads ProtBert (~1.6GB), so allow a generous timeout.
        # Arguments are passed as a list (no shell), and the sequence has been
        # reduced to amino-acid letters above.
        proc = subprocess.run(
            [sys.executable, "mic_worker.py", sequence],
            capture_output=True,
            text=True,
            timeout=900,  # 15 minutes; mostly for the one-time model download
        )
    except subprocess.TimeoutExpired:
        return {"Error": "MIC prediction timed out (model download may still be in progress; try again shortly)."}
    except Exception as e:
        return {"Error": f"Failed to start MIC worker: {str(e)}"}

    if proc.returncode != 0:
        # Worker crashed; surface stderr tail for debugging
        tail = (proc.stderr or "").strip().splitlines()[-3:]
        return {"Error": f"MIC worker exited with code {proc.returncode}. {' '.join(tail)}"}

    # Only the last non-empty stdout line is the result; anything the worker
    # printed before it (progress, warnings) is ignored.
    out_lines = [ln for ln in (proc.stdout or "").splitlines() if ln.strip()]
    if not out_lines:
        return {"Error": "MIC worker produced no output."}

    last_line = out_lines[-1]
    try:
        parsed = json.loads(last_line)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass.
        return {"Error": f"Could not parse MIC worker output: {last_line[:200]}"}

    # A bare number, string or list is valid JSON but breaks callers that
    # iterate the result with .items(); report it instead of passing it on.
    if not isinstance(parsed, dict):
        return {"Error": f"MIC worker output was not a JSON object: {last_line[:200]}"}
    return parsed


def _classify_amp(raw_pred):
    """Reduce raw model output to ``(class_index, confidence_percent)``."""
    if raw_pred.ndim == 1 or raw_pred.shape[1] == 1:
        prob_amp = float(raw_pred.flatten()[0])  # sigmoid output assumed = P(AMP)
        if prob_amp >= 0.5:
            return 1, round(prob_amp * 100, 2)
        return 0, round((1 - prob_amp) * 100, 2)

    # Multi-column output: take the highest-scoring column of the first row.
    class_idx = int(np.argmax(raw_pred[0]))
    return class_idx, round(float(raw_pred[0][class_idx]) * 100, 2)


def _mic_report(mic_values):
    """Format the MIC result as report text, tolerating a malformed result."""
    if not isinstance(mic_values, dict):
        # Keep the classification that was already computed instead of letting
        # .items() raise and fail the whole request.
        return f"\nMIC prediction failed: unexpected result {mic_values!r}\n"
    lines = ["\nPredicted MIC Values (μM):\n"]
    lines.extend(f"- {org}: {mic}\n" for org, mic in mic_values.items())
    return "".join(lines)


def _lime_report(features):
    """Return the LIME top-feature section, or a one-line failure note."""
    try:
        # Imported lazily so that a missing/broken lime install only costs the
        # explanation section, not the prediction.
        from lime.lime_tabular import LimeTabularExplainer

        # Fixed seeds make the explanation repeatable: without them the same
        # sequence gets different "top features" on every request.
        # NOTE(review): the background is synthetic uniform noise, not real
        # scaled training rows, so the weights are only a rough indication --
        # replace with a sample of the training data if one is available.
        background = np.random.RandomState(0).rand(100, len(selected_features))
        explainer = LimeTabularExplainer(
            training_data=background,
            feature_names=selected_features,
            class_names=["Non-AMP", "AMP"],
            mode="classification",
            random_state=0,
        )
        explanation = explainer.explain_instance(
            data_row=features[0],
            predict_fn=keras_predict_proba,
            num_features=10,
        )
        lines = ["\nTop Features Influencing Prediction:\n"]
        lines.extend(f"- {feat}: {round(weight, 4)}\n" for feat, weight in explanation.as_list())
        return "".join(lines)
    except Exception as e:
        # Best-effort section: report the failure in the output, never raise.
        return f"\nLIME explanation failed: {str(e)}\n"


def full_prediction(sequence):
    """Produce the full text report for one sequence (the Gradio handler).

    Steps: extract and scale features, classify with the AMP model, run the MIC
    worker only when the predicted class is 1 (AMP), then append a LIME
    top-10 feature section. Progress is printed to stdout as ``[CHECKPOINT]``
    lines.

    Args:
        sequence: Amino-acid sequence as text.

    Returns:
        A multi-line ``str``. If feature extraction fails, the error string
        from ``extract_features`` is returned unchanged. MIC and LIME failures
        are reported inside the text rather than raised.
    """
    print("[CHECKPOINT] full_prediction called", flush=True)
    features = extract_features(sequence)
    if isinstance(features, str):
        # extract_features signals failure by returning a message, not an array.
        print("[CHECKPOINT] extract_features returned error:", features, flush=True)
        return features
    print("[CHECKPOINT] features extracted OK, shape:", features.shape, flush=True)

    amp_model, _ = get_amp_model()
    print("[CHECKPOINT] AMP model loaded, running predict...", flush=True)
    raw_pred = amp_model.predict(features, verbose=0)
    print("[CHECKPOINT] AMP predict done:", raw_pred, flush=True)

    prediction, confidence = _classify_amp(raw_pred)

    # Label convention: 1 = AMP, 0 = Non-AMP (swap if your model is reversed)
    amp_result = "Antimicrobial Peptide (AMP)" if prediction == 1 else "Non-AMP"
    sections = [f"Prediction: {amp_result}\nConfidence: {confidence}%\n"]

    if prediction == 1:
        print("[CHECKPOINT] AMP detected, starting MIC (ProtBert)...", flush=True)
        mic_values = predictmic(sequence)
        print("[CHECKPOINT] MIC done:", mic_values, flush=True)
        sections.append(_mic_report(mic_values))
    else:
        sections.append("\nMIC prediction skipped for Non-AMP sequences.\n")

    sections.append(_lime_report(features))
    return "".join(sections)


# Single text-in / text-out UI around full_prediction.
sequence_input = gr.Textbox(label="Enter Protein Sequence")
results_output = gr.Textbox(label="Results")

iface = gr.Interface(
    fn=full_prediction,
    inputs=sequence_input,
    outputs=results_output,
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
)

# Launch happens at module level, i.e. whenever this file is executed or imported.
iface.launch()