nonzeroexit commited on
Commit
2739a59
·
verified ·
1 Parent(s): 98f9e87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -39
app.py CHANGED
@@ -1,30 +1,60 @@
1
  import os
2
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
3
- os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
 
4
  import gradio as gr
5
  import joblib
6
  import numpy as np
7
  import pandas as pd
8
  from propy import AAComposition, CTD
9
- import tensorflow as tf
10
- from tensorflow.keras.models import load_model
11
- import torch
12
- from transformers import BertTokenizer, BertModel
13
- from lime.lime_tabular import LimeTabularExplainer
14
  from math import expm1
15
 
16
- # Load AMP Classifier (Keras) and Scaler
17
- model = load_model("Comb1_aac_ctd_RFE_selected_features_model.keras")
18
- scaler = joblib.load("Comb1_aac_ctd_RFE_selected_features_scaler.joblib")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Load ProtBert (for MIC prediction)
21
- tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
22
- protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- protbert_model = protbert_model.to(device).eval()
25
 
26
- # Define selected features (AAC + CTD, RFE-selected)
27
- # Note: 'Activity' is the target label and is excluded from input features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  selected_features = [
29
  '_PolarizabilityC1', '_PolarizabilityC2', '_PolarizabilityC3',
30
  '_SolventAccessibilityC1', '_SolventAccessibilityC2', '_SolventAccessibilityC3',
@@ -84,32 +114,27 @@ selected_features = [
84
  'VL', 'VK', 'VM', 'VF', 'VP', 'VS', 'VT', 'VW', 'VY', 'VV'
85
  ]
86
 
87
- # Wrapper to make Keras model behave like a sklearn classifier for LIME
88
  def keras_predict_proba(X):
89
- """Return probabilities for both classes as [P(Non-AMP), P(AMP)]."""
90
- preds = model.predict(X, verbose=0)
 
91
  if preds.ndim == 1 or preds.shape[1] == 1:
92
  preds = preds.reshape(-1, 1)
93
  # Assuming sigmoid output = P(AMP); adjust if your model is reversed.
94
  return np.hstack([1 - preds, preds])
95
  return preds
96
 
97
- # Dummy data for LIME
98
- sample_data = np.random.rand(100, len(selected_features))
99
- explainer = LimeTabularExplainer(
100
- training_data=sample_data,
101
- feature_names=selected_features,
102
- class_names=["Non-AMP", "AMP"],
103
- mode="classification"
104
- )
105
 
106
- # Feature extraction function (AAC + CTD only)
107
  def extract_features(sequence):
 
108
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
109
  if len(sequence) < 10:
110
  return "Error: Sequence too short."
111
 
112
  try:
 
 
113
  # AAC: 20 single AAs + 400 dipeptides = 420 features
114
  dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
115
  filtered_aac = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
@@ -122,7 +147,7 @@ def extract_features(sequence):
122
  all_features_dict.update(filtered_aac)
123
 
124
  feature_df_all = pd.DataFrame([all_features_dict])
125
- normalized_array = scaler.transform(feature_df_all.values)
126
  normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
127
 
128
  if not set(selected_features).issubset(normalized_df.columns):
@@ -134,14 +159,18 @@ def extract_features(sequence):
134
  except Exception as e:
135
  return f"Error in feature extraction: {str(e)}"
136
 
137
- # MIC prediction function (unchanged)
138
  def predictmic(sequence):
 
139
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
140
  if len(sequence) < 10:
141
  return {"Error": "Sequence too short or invalid."}
142
 
 
 
143
  seq_spaced = ' '.join(list(sequence))
144
- tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
 
145
  tokens = {k: v.to(device) for k, v in tokens.items()}
146
 
147
  with torch.no_grad():
@@ -170,14 +199,14 @@ def predictmic(sequence):
170
 
171
  return mic_results
172
 
173
- # Main prediction function
174
  def full_prediction(sequence):
175
  features = extract_features(sequence)
176
  if isinstance(features, str):
177
  return features
178
 
179
- # Keras prediction
180
- raw_pred = model.predict(features, verbose=0)
181
 
182
  # Handle sigmoid (1 output) vs softmax (>=2 outputs)
183
  if raw_pred.ndim == 1 or raw_pred.shape[1] == 1:
@@ -193,7 +222,7 @@ def full_prediction(sequence):
193
  prediction = class_idx
194
  confidence = round(float(raw_pred[0][class_idx]) * 100, 2)
195
 
196
- # Label convention: 1 = AMP, 0 = Non-AMP (swap if your model uses the opposite)
197
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 1 else "Non-AMP"
198
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
199
 
@@ -205,14 +234,21 @@ def full_prediction(sequence):
205
  else:
206
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
207
 
208
- # LIME explanation
209
  try:
 
 
 
 
 
 
 
 
210
  explanation = explainer.explain_instance(
211
  data_row=features[0],
212
  predict_fn=keras_predict_proba,
213
  num_features=10
214
  )
215
-
216
  result += "\nTop Features Influencing Prediction:\n"
217
  for feat, weight in explanation.as_list():
218
  result += f"- {feat}: {round(weight, 4)}\n"
@@ -221,6 +257,7 @@ def full_prediction(sequence):
221
 
222
  return result
223
 
 
224
  # Gradio UI
225
  iface = gr.Interface(
226
  fn=full_prediction,
@@ -230,4 +267,4 @@ iface = gr.Interface(
230
  description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
231
  )
232
 
233
- iface.launch(share=True)
 
1
  import os
2
+ # Quiet TensorFlow logs and disable oneDNN nondeterminism notice
3
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
4
+ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
5
+
6
  import gradio as gr
7
  import joblib
8
  import numpy as np
9
  import pandas as pd
10
  from propy import AAComposition, CTD
 
 
 
 
 
11
  from math import expm1
12
 
13
+ # ---------------------------------------------------------------------------
14
+ # LAZY LOADING
15
+ # On the free 16GB Space, loading TensorFlow + PyTorch + ProtBert all at once
16
+ # at import time causes an out-of-memory crash. We therefore load each heavy
17
+ # component only when it is first needed, and cache it after that.
18
+ # ---------------------------------------------------------------------------
19
+
20
+ _amp_model = None # Keras AMP classifier
21
+ _amp_scaler = None # joblib scaler for AMP features
22
+ _protbert_tokenizer = None
23
+ _protbert_model = None
24
+ _torch = None # torch module, imported lazily
25
+ _device = None
26
+
27
+
28
+ def get_amp_model():
29
+ """Load the Keras AMP classifier + scaler on first use."""
30
+ global _amp_model, _amp_scaler
31
+ if _amp_model is None:
32
+ from tensorflow.keras.models import load_model
33
+ _amp_model = load_model("Comb1_aac_ctd_RFE_selected_features_model.keras")
34
+ _amp_scaler = joblib.load("Comb1_aac_ctd_RFE_selected_features_scaler.joblib")
35
+ return _amp_model, _amp_scaler
36
 
 
 
 
 
 
37
 
38
+ def get_protbert():
39
+ """Load ProtBert tokenizer + model on first use (only needed for MIC)."""
40
+ global _protbert_tokenizer, _protbert_model, _torch, _device
41
+ if _protbert_model is None:
42
+ import torch
43
+ from transformers import BertTokenizer, BertModel
44
+ _torch = torch
45
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ _protbert_tokenizer = BertTokenizer.from_pretrained(
47
+ "Rostlab/prot_bert", do_lower_case=False
48
+ )
49
+ _protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
50
+ _protbert_model = _protbert_model.to(_device).eval()
51
+ return _protbert_tokenizer, _protbert_model, _torch, _device
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Selected features (AAC + CTD, RFE-selected). 'Activity' is the target label
56
+ # and is intentionally excluded from the input features.
57
+ # ---------------------------------------------------------------------------
58
  selected_features = [
59
  '_PolarizabilityC1', '_PolarizabilityC2', '_PolarizabilityC3',
60
  '_SolventAccessibilityC1', '_SolventAccessibilityC2', '_SolventAccessibilityC3',
 
114
  'VL', 'VK', 'VM', 'VF', 'VP', 'VS', 'VT', 'VW', 'VY', 'VV'
115
  ]
116
 
117
+
118
  def keras_predict_proba(X):
119
+ """Return probabilities as [P(Non-AMP), P(AMP)] for LIME."""
120
+ amp_model, _ = get_amp_model()
121
+ preds = amp_model.predict(X, verbose=0)
122
  if preds.ndim == 1 or preds.shape[1] == 1:
123
  preds = preds.reshape(-1, 1)
124
  # Assuming sigmoid output = P(AMP); adjust if your model is reversed.
125
  return np.hstack([1 - preds, preds])
126
  return preds
127
 
 
 
 
 
 
 
 
 
128
 
 
129
  def extract_features(sequence):
130
+ """Compute AAC (420) + CTD features, scale, and select RFE features."""
131
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
132
  if len(sequence) < 10:
133
  return "Error: Sequence too short."
134
 
135
  try:
136
+ _, amp_scaler = get_amp_model()
137
+
138
  # AAC: 20 single AAs + 400 dipeptides = 420 features
139
  dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
140
  filtered_aac = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
 
147
  all_features_dict.update(filtered_aac)
148
 
149
  feature_df_all = pd.DataFrame([all_features_dict])
150
+ normalized_array = amp_scaler.transform(feature_df_all.values)
151
  normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
152
 
153
  if not set(selected_features).issubset(normalized_df.columns):
 
159
  except Exception as e:
160
  return f"Error in feature extraction: {str(e)}"
161
 
162
+
163
  def predictmic(sequence):
164
+ """Predict MIC values using ProtBert embeddings + per-bacterium models."""
165
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
166
  if len(sequence) < 10:
167
  return {"Error": "Sequence too short or invalid."}
168
 
169
+ tokenizer, protbert_model, torch, device = get_protbert()
170
+
171
  seq_spaced = ' '.join(list(sequence))
172
+ tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length',
173
+ truncation=True, max_length=512)
174
  tokens = {k: v.to(device) for k, v in tokens.items()}
175
 
176
  with torch.no_grad():
 
199
 
200
  return mic_results
201
 
202
+
203
  def full_prediction(sequence):
204
  features = extract_features(sequence)
205
  if isinstance(features, str):
206
  return features
207
 
208
+ amp_model, _ = get_amp_model()
209
+ raw_pred = amp_model.predict(features, verbose=0)
210
 
211
  # Handle sigmoid (1 output) vs softmax (>=2 outputs)
212
  if raw_pred.ndim == 1 or raw_pred.shape[1] == 1:
 
222
  prediction = class_idx
223
  confidence = round(float(raw_pred[0][class_idx]) * 100, 2)
224
 
225
+ # Label convention: 1 = AMP, 0 = Non-AMP (swap if your model is opposite)
226
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 1 else "Non-AMP"
227
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
228
 
 
234
  else:
235
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
236
 
237
+ # LIME explanation (lazy import keeps startup light)
238
  try:
239
+ from lime.lime_tabular import LimeTabularExplainer
240
+ sample_data = np.random.rand(100, len(selected_features))
241
+ explainer = LimeTabularExplainer(
242
+ training_data=sample_data,
243
+ feature_names=selected_features,
244
+ class_names=["Non-AMP", "AMP"],
245
+ mode="classification"
246
+ )
247
  explanation = explainer.explain_instance(
248
  data_row=features[0],
249
  predict_fn=keras_predict_proba,
250
  num_features=10
251
  )
 
252
  result += "\nTop Features Influencing Prediction:\n"
253
  for feat, weight in explanation.as_list():
254
  result += f"- {feat}: {round(weight, 4)}\n"
 
257
 
258
  return result
259
 
260
+
261
  # Gradio UI
262
  iface = gr.Interface(
263
  fn=full_prediction,
 
267
  description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
268
  )
269
 
270
+ iface.launch()