nonzeroexit commited on
Commit
0ff9972
·
verified ·
1 Parent(s): 8a9cc7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -102
app.py CHANGED
@@ -1,31 +1,25 @@
1
- import os
2
- # Native-lib hygiene (prevents TF/PyTorch SIGSEGV when both load; harmless for RF)
3
- os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
4
- os.environ.setdefault("OMP_NUM_THREADS", "1")
5
- os.environ.setdefault("MKL_NUM_THREADS", "1")
6
- os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
7
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
8
-
9
- import sys
10
- import json
11
- import subprocess
12
  import joblib
13
  import numpy as np
14
  import pandas as pd
15
  from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
 
 
 
16
  from lime.lime_tabular import LimeTabularExplainer
 
17
 
18
- import gradio as gr
19
-
20
- # ---------------------------------------------------------------------------
21
- # Load Random Forest AMP classifier + MinMax scaler (original files)
22
- # ---------------------------------------------------------------------------
23
  model = joblib.load("RF.joblib")
24
  scaler = joblib.load("norm (4).joblib")
25
 
26
- # ---------------------------------------------------------------------------
27
- # Original 138 RFE-selected features (CTD + AAC + Autocorrelation + APAAC)
28
- # ---------------------------------------------------------------------------
 
 
 
 
29
  selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
30
  "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
31
  "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
@@ -53,49 +47,22 @@ selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondarySt
53
  "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
54
  "APAAC15", "APAAC18", "APAAC19", "APAAC24"]
55
 
56
- assert len(selected_features) == 138, f"Expected 138 features, got {len(selected_features)}"
57
-
58
- # ---------------------------------------------------------------------------
59
- # LIME explainer
60
- # Built ONCE at startup so explanations are reproducible across requests.
61
- # Prefers a real normalized training sample (lime_background.joblib). Falls
62
- # back to seeded uniform noise if that file isn't present (still stable, but
63
- # less faithful to the true feature distribution).
64
- # ---------------------------------------------------------------------------
65
- try:
66
- _lime_background = joblib.load("lime_background.joblib")
67
- if _lime_background.shape[1] != len(selected_features):
68
- raise ValueError(
69
- f"lime_background.joblib has {_lime_background.shape[1]} cols, "
70
- f"expected {len(selected_features)}"
71
- )
72
- print(f"[LIME] Using real training sample: {_lime_background.shape}", flush=True)
73
- except Exception as e:
74
- print(f"[LIME] No usable lime_background.joblib ({e}); falling back to uniform noise.", flush=True)
75
- _rng = np.random.default_rng(seed=42)
76
- _lime_background = _rng.uniform(low=0.0, high=1.0, size=(500, len(selected_features)))
77
-
78
  explainer = LimeTabularExplainer(
79
- training_data=_lime_background,
80
  feature_names=selected_features,
81
  class_names=["AMP", "Non-AMP"],
82
- mode="classification",
83
- discretize_continuous=True,
84
- random_state=42, # stable explanations
85
  )
86
 
87
-
88
- # ---------------------------------------------------------------------------
89
- # Feature extraction — produces the full propy feature pool, scales it with
90
- # the saved MinMax scaler, then selects the 138 features the RF was trained on.
91
- # ---------------------------------------------------------------------------
92
  def extract_features(sequence):
93
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
94
  if len(sequence) < 10:
95
  return "Error: Sequence too short."
96
 
97
  try:
98
- # Original full pool: CTD + AAC(first 420) + Autocorrelation + PseudoAAC
99
  dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
100
  filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
101
  ctd_features = CTD.CalculateCTD(sequence)
@@ -113,51 +80,50 @@ def extract_features(sequence):
113
  normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
114
 
115
  if not set(selected_features).issubset(normalized_df.columns):
116
- missing = set(selected_features) - set(normalized_df.columns)
117
- return f"Error: Missing features: {list(missing)[:5]}..."
118
 
119
  selected_df = normalized_df[selected_features].fillna(0)
120
  return selected_df.values
121
  except Exception as e:
122
  return f"Error in feature extraction: {str(e)}"
123
 
124
-
125
- # ---------------------------------------------------------------------------
126
- # MIC prediction — runs in a SEPARATE process (mic_worker.py).
127
- # This isolates PyTorch/ProtBert from the main process and prevents the
128
- # native-library crash (exit 139) plus the OOM spike on the free tier.
129
- # ---------------------------------------------------------------------------
130
  def predictmic(sequence):
131
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
132
  if len(sequence) < 10:
133
  return {"Error": "Sequence too short or invalid."}
134
 
135
- try:
136
- proc = subprocess.run(
137
- [sys.executable, "mic_worker.py", sequence],
138
- capture_output=True, text=True, timeout=900
139
- )
140
- except subprocess.TimeoutExpired:
141
- return {"Error": "MIC prediction timed out (ProtBert may still be downloading; try again shortly)."}
142
- except Exception as e:
143
- return {"Error": f"Failed to start MIC worker: {str(e)}"}
144
-
145
- if proc.returncode != 0:
146
- tail = (proc.stderr or "").strip().splitlines()[-3:]
147
- return {"Error": f"MIC worker exited with code {proc.returncode}. {' '.join(tail)}"}
148
-
149
- out_lines = [ln for ln in (proc.stdout or "").splitlines() if ln.strip()]
150
- if not out_lines:
151
- return {"Error": "MIC worker produced no output."}
152
- try:
153
- return json.loads(out_lines[-1])
154
- except Exception:
155
- return {"Error": f"Could not parse MIC worker output: {out_lines[-1][:200]}"}
156
-
157
-
158
- # ---------------------------------------------------------------------------
159
- # Main prediction pipeline
160
- # ---------------------------------------------------------------------------
 
 
 
 
 
161
  def full_prediction(sequence):
162
  features = extract_features(sequence)
163
  if isinstance(features, str):
@@ -175,21 +141,6 @@ def full_prediction(sequence):
175
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
176
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
177
 
178
- # ---- LIME first (per your spec: LIME before SHAP in the report) ----
179
- try:
180
- explanation = explainer.explain_instance(
181
- data_row=features[0], # <-- explicitly the single input sequence
182
- predict_fn=model.predict_proba,
183
- num_features=10,
184
- num_samples=2000, # perturbations around this single input
185
- )
186
- result += "\nTop Features Influencing Prediction (LIME):\n"
187
- for feat, weight in explanation.as_list():
188
- result += f"- {feat}: {round(weight, 4)}\n"
189
- except Exception as e:
190
- result += f"\nLIME explanation failed: {str(e)}\n"
191
-
192
- # ---- MIC (only for AMPs) ----
193
  if prediction == 0:
194
  mic_values = predictmic(sequence)
195
  result += "\nPredicted MIC Values (μM):\n"
@@ -198,8 +149,17 @@ def full_prediction(sequence):
198
  else:
199
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
200
 
201
- return result
 
 
 
 
202
 
 
 
 
 
 
203
 
204
  # Gradio UI
205
  iface = gr.Interface(
@@ -210,4 +170,4 @@ iface = gr.Interface(
210
  description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
211
  )
212
 
213
- iface.launch()
 
1
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
2
  import joblib
3
  import numpy as np
4
  import pandas as pd
5
  from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ import torch
8
+ from transformers import BertTokenizer, BertModel
9
  from lime.lime_tabular import LimeTabularExplainer
10
+ from math import expm1
11
 
12
+ # Load AMP Classifier and Scaler
 
 
 
 
13
  model = joblib.load("RF.joblib")
14
  scaler = joblib.load("norm (4).joblib")
15
 
16
+ # Load ProtBert
17
+ tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
18
+ protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ protbert_model = protbert_model.to(device).eval()
21
+
22
+ # Define selected features (put your complete list here)
23
  selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
24
  "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
25
  "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
 
47
  "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
48
  "APAAC15", "APAAC18", "APAAC19", "APAAC24"]
49
 
50
+ # Dummy data for LIME
51
+ sample_data = np.random.rand(100, len(selected_features))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  explainer = LimeTabularExplainer(
53
+ training_data=sample_data,
54
  feature_names=selected_features,
55
  class_names=["AMP", "Non-AMP"],
56
+ mode="classification"
 
 
57
  )
58
 
59
+ # Feature extraction function
 
 
 
 
60
  def extract_features(sequence):
61
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
62
  if len(sequence) < 10:
63
  return "Error: Sequence too short."
64
 
65
  try:
 
66
  dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
67
  filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
68
  ctd_features = CTD.CalculateCTD(sequence)
 
80
  normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
81
 
82
  if not set(selected_features).issubset(normalized_df.columns):
83
+ return "Error: Some selected features are missing."
 
84
 
85
  selected_df = normalized_df[selected_features].fillna(0)
86
  return selected_df.values
87
  except Exception as e:
88
  return f"Error in feature extraction: {str(e)}"
89
 
90
+ # MIC prediction function
 
 
 
 
 
91
  def predictmic(sequence):
92
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
93
  if len(sequence) < 10:
94
  return {"Error": "Sequence too short or invalid."}
95
 
96
+ seq_spaced = ' '.join(list(sequence))
97
+ tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
98
+ tokens = {k: v.to(device) for k, v in tokens.items()}
99
+
100
+ with torch.no_grad():
101
+ outputs = protbert_model(**tokens)
102
+ embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
103
+
104
+ bacteria_config = {
105
+ "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
106
+ "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
107
+ "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
108
+ "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"}
109
+ }
110
+
111
+ mic_results = {}
112
+ for bacterium, cfg in bacteria_config.items():
113
+ try:
114
+ scaler = joblib.load(cfg["scaler"])
115
+ scaled = scaler.transform(embedding)
116
+ transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
117
+ model = joblib.load(cfg["model"])
118
+ mic_log = model.predict(transformed)[0]
119
+ mic = round(expm1(mic_log), 3)
120
+ mic_results[bacterium] = mic
121
+ except Exception as e:
122
+ mic_results[bacterium] = f"Error: {str(e)}"
123
+
124
+ return mic_results
125
+
126
+ # Main prediction function
127
  def full_prediction(sequence):
128
  features = extract_features(sequence)
129
  if isinstance(features, str):
 
141
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
142
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  if prediction == 0:
145
  mic_values = predictmic(sequence)
146
  result += "\nPredicted MIC Values (μM):\n"
 
149
  else:
150
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
151
 
152
+ explanation = explainer.explain_instance(
153
+ data_row=features[0],
154
+ predict_fn=model.predict_proba,
155
+ num_features=10
156
+ )
157
 
158
+ result += "\nTop Features Influencing Prediction:\n"
159
+ for feat, weight in explanation.as_list():
160
+ result += f"- {feat}: {round(weight, 4)}\n"
161
+
162
+ return result
163
 
164
  # Gradio UI
165
  iface = gr.Interface(
 
170
  description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
171
  )
172
 
173
+ iface.launch(share=True)