nonzeroexit committed on
Commit
133e26c
·
verified ·
1 Parent(s): 5d02ae6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -65
app.py CHANGED
@@ -1,19 +1,5 @@
1
- import joblib
2
- _scaler = joblib.load("Comb1_aac_ctd_RFE_selected_features_scaler.joblib")
3
- print("SCALER n_features_in_:", getattr(_scaler, "n_features_in_", "N/A"), flush=True)
4
- _names = getattr(_scaler, "feature_names_in_", None)
5
- if _names is not None:
6
- print("SCALER FEATURE NAMES (%d):" % len(_names), flush=True)
7
- print(list(_names), flush=True)
8
- else:
9
- print("SCALER has NO feature_names_in_ (fit on numpy array)", flush=True)
10
-
11
- from tensorflow.keras.models import load_model
12
- _m = load_model("Comb1_aac_ctd_RFE_selected_features_model.keras")
13
- print("MODEL input_shape:", _m.input_shape, "output_shape:", _m.output_shape, flush=True)
14
-
15
  import os
16
- # Quiet TensorFlow logs and disable oneDNN nondeterminism notice
17
  os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
18
  os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
19
 
@@ -25,22 +11,18 @@ from propy import AAComposition, CTD
25
  from math import expm1
26
 
27
  # ---------------------------------------------------------------------------
28
- # LAZY LOADING
29
- # On the free 16GB Space, loading TensorFlow + PyTorch + ProtBert all at once
30
- # at import time causes an out-of-memory crash. We therefore load each heavy
31
- # component only when it is first needed, and cache it after that.
32
  # ---------------------------------------------------------------------------
33
-
34
- _amp_model = None # Keras AMP classifier
35
- _amp_scaler = None # joblib scaler for AMP features
36
  _protbert_tokenizer = None
37
  _protbert_model = None
38
- _torch = None # torch module, imported lazily
39
  _device = None
40
 
41
 
42
  def get_amp_model():
43
- """Load the Keras AMP classifier + scaler on first use."""
44
  global _amp_model, _amp_scaler
45
  if _amp_model is None:
46
  from tensorflow.keras.models import load_model
@@ -50,7 +32,6 @@ def get_amp_model():
50
 
51
 
52
  def get_protbert():
53
- """Load ProtBert tokenizer + model on first use (only needed for MIC)."""
54
  global _protbert_tokenizer, _protbert_model, _torch, _device
55
  if _protbert_model is None:
56
  import torch
@@ -66,8 +47,9 @@ def get_protbert():
66
 
67
 
68
  # ---------------------------------------------------------------------------
69
- # Selected features (AAC + CTD, RFE-selected). 'Activity' is the target label
70
- # and is intentionally excluded from the input features.
 
71
  # ---------------------------------------------------------------------------
72
  selected_features = [
73
  "_PolarizabilityC1", "_PolarizabilityC2", "_PolarizabilityC3",
@@ -77,7 +59,6 @@ selected_features = [
77
  "_PolarityC1", "_PolarityC2", "_PolarityC3",
78
  "_NormalizedVDWVC1", "_NormalizedVDWVC2", "_NormalizedVDWVC3",
79
  "_HydrophobicityC1", "_HydrophobicityC2", "_HydrophobicityC3",
80
-
81
  "_PolarizabilityT12", "_PolarizabilityT13", "_PolarizabilityT23",
82
  "_SolventAccessibilityT12", "_SolventAccessibilityT13", "_SolventAccessibilityT23",
83
  "_SecondaryStrT12", "_SecondaryStrT13", "_SecondaryStrT23",
@@ -85,14 +66,12 @@ selected_features = [
85
  "_PolarityT12", "_PolarityT13", "_PolarityT23",
86
  "_NormalizedVDWVT12", "_NormalizedVDWVT13", "_NormalizedVDWVT23",
87
  "_HydrophobicityT12", "_HydrophobicityT13", "_HydrophobicityT23",
88
-
89
  "_PolarizabilityD1001", "_PolarizabilityD1025", "_PolarizabilityD1050",
90
  "_PolarizabilityD1075", "_PolarizabilityD1100",
91
  "_PolarizabilityD2001", "_PolarizabilityD2025", "_PolarizabilityD2050",
92
  "_PolarizabilityD2075", "_PolarizabilityD2100",
93
  "_PolarizabilityD3001", "_PolarizabilityD3025", "_PolarizabilityD3050",
94
  "_PolarizabilityD3075", "_PolarizabilityD3100",
95
-
96
  "_SolventAccessibilityD1001", "_SolventAccessibilityD1025",
97
  "_SolventAccessibilityD1050", "_SolventAccessibilityD1075",
98
  "_SolventAccessibilityD1100",
@@ -102,28 +81,24 @@ selected_features = [
102
  "_SolventAccessibilityD3001", "_SolventAccessibilityD3025",
103
  "_SolventAccessibilityD3050", "_SolventAccessibilityD3075",
104
  "_SolventAccessibilityD3100",
105
-
106
  "_SecondaryStrD1001", "_SecondaryStrD1025", "_SecondaryStrD1050",
107
  "_SecondaryStrD1075", "_SecondaryStrD1100",
108
  "_SecondaryStrD2001", "_SecondaryStrD2025", "_SecondaryStrD2050",
109
  "_SecondaryStrD2075", "_SecondaryStrD2100",
110
  "_SecondaryStrD3001", "_SecondaryStrD3025", "_SecondaryStrD3050",
111
  "_SecondaryStrD3075", "_SecondaryStrD3100",
112
-
113
  "_ChargeD1001", "_ChargeD1025", "_ChargeD1050",
114
  "_ChargeD1075", "_ChargeD1100",
115
  "_ChargeD2001", "_ChargeD2025", "_ChargeD2050",
116
  "_ChargeD2075",
117
  "_ChargeD3001", "_ChargeD3025", "_ChargeD3050",
118
  "_ChargeD3075", "_ChargeD3100",
119
-
120
  "_PolarityD1001", "_PolarityD1025", "_PolarityD1050",
121
  "_PolarityD1075", "_PolarityD1100",
122
  "_PolarityD2001", "_PolarityD2025", "_PolarityD2050",
123
  "_PolarityD2075", "_PolarityD2100",
124
  "_PolarityD3001", "_PolarityD3025", "_PolarityD3050",
125
  "_PolarityD3075", "_PolarityD3100",
126
-
127
  "_NormalizedVDWVD1001", "_NormalizedVDWVD1025",
128
  "_NormalizedVDWVD1050", "_NormalizedVDWVD1075",
129
  "_NormalizedVDWVD1100",
@@ -133,7 +108,6 @@ selected_features = [
133
  "_NormalizedVDWVD3001", "_NormalizedVDWVD3025",
134
  "_NormalizedVDWVD3050", "_NormalizedVDWVD3075",
135
  "_NormalizedVDWVD3100",
136
-
137
  "_HydrophobicityD1001", "_HydrophobicityD1025",
138
  "_HydrophobicityD1050", "_HydrophobicityD1075",
139
  "_HydrophobicityD1100",
@@ -143,10 +117,8 @@ selected_features = [
143
  "_HydrophobicityD3001", "_HydrophobicityD3025",
144
  "_HydrophobicityD3050", "_HydrophobicityD3075",
145
  "_HydrophobicityD3100",
146
-
147
  "A", "R", "N", "D", "C", "E", "Q", "G", "H", "I",
148
  "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
149
-
150
  "AR", "AD", "AQ", "AG", "AL", "AK", "AF", "AP", "AT", "AV",
151
  "RA", "RC", "RE", "RG", "RI", "RL", "RS", "RT", "RV",
152
  "NR", "NC", "NG", "NI", "NP", "NS", "NY", "NV",
@@ -171,21 +143,21 @@ selected_features = [
171
  "VA", "VR", "VD", "VC", "VE", "VG", "VI", "VL", "VK",
172
  "VS", "VT", "VY", "VV"
173
  ]
 
174
 
175
 
176
  def keras_predict_proba(X):
177
- """Return probabilities as [P(Non-AMP), P(AMP)] for LIME."""
178
  amp_model, _ = get_amp_model()
179
  preds = amp_model.predict(X, verbose=0)
180
  if preds.ndim == 1 or preds.shape[1] == 1:
181
  preds = preds.reshape(-1, 1)
182
- # Assuming sigmoid output = P(AMP); adjust if your model is reversed.
183
- return np.hstack([1 - preds, preds])
184
  return preds
185
 
186
 
187
  def extract_features(sequence):
188
- """Compute AAC (420) + CTD features, scale, and select RFE features."""
189
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
190
  if len(sequence) < 10:
191
  return "Error: Sequence too short."
@@ -193,33 +165,31 @@ def extract_features(sequence):
193
  try:
194
  _, amp_scaler = get_amp_model()
195
 
196
- # AAC: 20 single AAs + 400 dipeptides = 420 features
197
- dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
198
- filtered_aac = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
199
-
200
- # CTD: Composition, Transition, Distribution
201
  ctd_features = CTD.CalculateCTD(sequence)
 
202
 
203
- all_features_dict = {}
204
- all_features_dict.update(ctd_features)
205
- all_features_dict.update(filtered_aac)
 
206
 
207
- feature_df_all = pd.DataFrame([all_features_dict])
208
- normalized_array = amp_scaler.transform(feature_df_all.values)
209
- normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
 
210
 
211
- if not set(selected_features).issubset(normalized_df.columns):
212
- missing = set(selected_features) - set(normalized_df.columns)
213
- return f"Error: Missing features: {list(missing)[:5]}..."
214
 
215
- selected_df = normalized_df[selected_features].fillna(0)
216
- return selected_df.values.astype(np.float32)
217
  except Exception as e:
218
  return f"Error in feature extraction: {str(e)}"
219
 
220
 
221
  def predictmic(sequence):
222
- """Predict MIC values using ProtBert embeddings + per-bacterium models."""
223
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
224
  if len(sequence) < 10:
225
  return {"Error": "Sequence too short or invalid."}
@@ -266,21 +236,20 @@ def full_prediction(sequence):
266
  amp_model, _ = get_amp_model()
267
  raw_pred = amp_model.predict(features, verbose=0)
268
 
269
- # Handle sigmoid (1 output) vs softmax (>=2 outputs)
270
  if raw_pred.ndim == 1 or raw_pred.shape[1] == 1:
271
- prob_amp = float(raw_pred.flatten()[0]) # assume output = P(AMP)
272
  if prob_amp >= 0.5:
273
- prediction = 1 # AMP
274
  confidence = round(prob_amp * 100, 2)
275
  else:
276
- prediction = 0 # Non-AMP
277
  confidence = round((1 - prob_amp) * 100, 2)
278
  else:
279
  class_idx = int(np.argmax(raw_pred[0]))
280
  prediction = class_idx
281
  confidence = round(float(raw_pred[0][class_idx]) * 100, 2)
282
 
283
- # Label convention: 1 = AMP, 0 = Non-AMP (swap if your model is opposite)
284
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 1 else "Non-AMP"
285
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
286
 
@@ -292,7 +261,6 @@ def full_prediction(sequence):
292
  else:
293
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
294
 
295
- # LIME explanation (lazy import keeps startup light)
296
  try:
297
  from lime.lime_tabular import LimeTabularExplainer
298
  sample_data = np.random.rand(100, len(selected_features))
@@ -316,7 +284,6 @@ def full_prediction(sequence):
316
  return result
317
 
318
 
319
- # Gradio UI
320
  iface = gr.Interface(
321
  fn=full_prediction,
322
  inputs=gr.Textbox(label="Enter Protein Sequence"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ # Quiet TensorFlow logs (must be set before importing tensorflow)
3
  os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
4
  os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
5
 
 
11
  from math import expm1
12
 
13
  # ---------------------------------------------------------------------------
14
+ # LAZY LOADING — keeps the free 16GB Space from OOM-ing at startup.
15
+ # Heavy libs (TF, torch, ProtBert) load only when first needed.
 
 
16
  # ---------------------------------------------------------------------------
17
+ _amp_model = None
18
+ _amp_scaler = None
 
19
  _protbert_tokenizer = None
20
  _protbert_model = None
21
+ _torch = None
22
  _device = None
23
 
24
 
25
  def get_amp_model():
 
26
  global _amp_model, _amp_scaler
27
  if _amp_model is None:
28
  from tensorflow.keras.models import load_model
 
32
 
33
 
34
  def get_protbert():
 
35
  global _protbert_tokenizer, _protbert_model, _torch, _device
36
  if _protbert_model is None:
37
  import torch
 
47
 
48
 
49
  # ---------------------------------------------------------------------------
50
+ # The EXACT 343 features the scaler was fit on, IN THE EXACT TRAINING ORDER.
51
+ # The scaler was fit on a numpy array (no stored names), so order is critical:
52
+ # we must select these columns in this order BEFORE calling scaler.transform().
53
  # ---------------------------------------------------------------------------
54
  selected_features = [
55
  "_PolarizabilityC1", "_PolarizabilityC2", "_PolarizabilityC3",
 
59
  "_PolarityC1", "_PolarityC2", "_PolarityC3",
60
  "_NormalizedVDWVC1", "_NormalizedVDWVC2", "_NormalizedVDWVC3",
61
  "_HydrophobicityC1", "_HydrophobicityC2", "_HydrophobicityC3",
 
62
  "_PolarizabilityT12", "_PolarizabilityT13", "_PolarizabilityT23",
63
  "_SolventAccessibilityT12", "_SolventAccessibilityT13", "_SolventAccessibilityT23",
64
  "_SecondaryStrT12", "_SecondaryStrT13", "_SecondaryStrT23",
 
66
  "_PolarityT12", "_PolarityT13", "_PolarityT23",
67
  "_NormalizedVDWVT12", "_NormalizedVDWVT13", "_NormalizedVDWVT23",
68
  "_HydrophobicityT12", "_HydrophobicityT13", "_HydrophobicityT23",
 
69
  "_PolarizabilityD1001", "_PolarizabilityD1025", "_PolarizabilityD1050",
70
  "_PolarizabilityD1075", "_PolarizabilityD1100",
71
  "_PolarizabilityD2001", "_PolarizabilityD2025", "_PolarizabilityD2050",
72
  "_PolarizabilityD2075", "_PolarizabilityD2100",
73
  "_PolarizabilityD3001", "_PolarizabilityD3025", "_PolarizabilityD3050",
74
  "_PolarizabilityD3075", "_PolarizabilityD3100",
 
75
  "_SolventAccessibilityD1001", "_SolventAccessibilityD1025",
76
  "_SolventAccessibilityD1050", "_SolventAccessibilityD1075",
77
  "_SolventAccessibilityD1100",
 
81
  "_SolventAccessibilityD3001", "_SolventAccessibilityD3025",
82
  "_SolventAccessibilityD3050", "_SolventAccessibilityD3075",
83
  "_SolventAccessibilityD3100",
 
84
  "_SecondaryStrD1001", "_SecondaryStrD1025", "_SecondaryStrD1050",
85
  "_SecondaryStrD1075", "_SecondaryStrD1100",
86
  "_SecondaryStrD2001", "_SecondaryStrD2025", "_SecondaryStrD2050",
87
  "_SecondaryStrD2075", "_SecondaryStrD2100",
88
  "_SecondaryStrD3001", "_SecondaryStrD3025", "_SecondaryStrD3050",
89
  "_SecondaryStrD3075", "_SecondaryStrD3100",
 
90
  "_ChargeD1001", "_ChargeD1025", "_ChargeD1050",
91
  "_ChargeD1075", "_ChargeD1100",
92
  "_ChargeD2001", "_ChargeD2025", "_ChargeD2050",
93
  "_ChargeD2075",
94
  "_ChargeD3001", "_ChargeD3025", "_ChargeD3050",
95
  "_ChargeD3075", "_ChargeD3100",
 
96
  "_PolarityD1001", "_PolarityD1025", "_PolarityD1050",
97
  "_PolarityD1075", "_PolarityD1100",
98
  "_PolarityD2001", "_PolarityD2025", "_PolarityD2050",
99
  "_PolarityD2075", "_PolarityD2100",
100
  "_PolarityD3001", "_PolarityD3025", "_PolarityD3050",
101
  "_PolarityD3075", "_PolarityD3100",
 
102
  "_NormalizedVDWVD1001", "_NormalizedVDWVD1025",
103
  "_NormalizedVDWVD1050", "_NormalizedVDWVD1075",
104
  "_NormalizedVDWVD1100",
 
108
  "_NormalizedVDWVD3001", "_NormalizedVDWVD3025",
109
  "_NormalizedVDWVD3050", "_NormalizedVDWVD3075",
110
  "_NormalizedVDWVD3100",
 
111
  "_HydrophobicityD1001", "_HydrophobicityD1025",
112
  "_HydrophobicityD1050", "_HydrophobicityD1075",
113
  "_HydrophobicityD1100",
 
117
  "_HydrophobicityD3001", "_HydrophobicityD3025",
118
  "_HydrophobicityD3050", "_HydrophobicityD3075",
119
  "_HydrophobicityD3100",
 
120
  "A", "R", "N", "D", "C", "E", "Q", "G", "H", "I",
121
  "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
 
122
  "AR", "AD", "AQ", "AG", "AL", "AK", "AF", "AP", "AT", "AV",
123
  "RA", "RC", "RE", "RG", "RI", "RL", "RS", "RT", "RV",
124
  "NR", "NC", "NG", "NI", "NP", "NS", "NY", "NV",
 
143
  "VA", "VR", "VD", "VC", "VE", "VG", "VI", "VL", "VK",
144
  "VS", "VT", "VY", "VV"
145
  ]
146
+ assert len(selected_features) == 343, f"Expected 343 features, got {len(selected_features)}"
147
 
148
 
149
  def keras_predict_proba(X):
150
+ """Return probabilities as [P(Non-AMP), P(AMP)] for LIME (X already scaled)."""
151
  amp_model, _ = get_amp_model()
152
  preds = amp_model.predict(X, verbose=0)
153
  if preds.ndim == 1 or preds.shape[1] == 1:
154
  preds = preds.reshape(-1, 1)
155
+ return np.hstack([1 - preds, preds]) # sigmoid output assumed = P(AMP)
 
156
  return preds
157
 
158
 
159
  def extract_features(sequence):
160
+ """Compute CTD + AAC, select the 343 training columns IN ORDER, then scale."""
161
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
162
  if len(sequence) < 10:
163
  return "Error: Sequence too short."
 
165
  try:
166
  _, amp_scaler = get_amp_model()
167
 
168
+ # Compute full feature pool
 
 
 
 
169
  ctd_features = CTD.CalculateCTD(sequence)
170
+ aac = AAComposition.CalculateAADipeptideComposition(sequence)
171
 
172
+ # Merge everything into one lookup dict
173
+ pool = {}
174
+ pool.update(ctd_features)
175
+ pool.update(aac)
176
 
177
+ # Verify all needed features are present
178
+ missing = [f for f in selected_features if f not in pool]
179
+ if missing:
180
+ return f"Error: Missing features from propy: {missing[:5]}..."
181
 
182
+ # Build the 343-wide row IN THE EXACT TRAINING ORDER, THEN scale.
183
+ ordered_values = [pool[f] for f in selected_features]
184
+ feature_row = np.array(ordered_values, dtype=np.float64).reshape(1, -1)
185
 
186
+ scaled = amp_scaler.transform(feature_row) # scaler expects exactly 343 cols
187
+ return scaled.astype(np.float32)
188
  except Exception as e:
189
  return f"Error in feature extraction: {str(e)}"
190
 
191
 
192
  def predictmic(sequence):
 
193
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
194
  if len(sequence) < 10:
195
  return {"Error": "Sequence too short or invalid."}
 
236
  amp_model, _ = get_amp_model()
237
  raw_pred = amp_model.predict(features, verbose=0)
238
 
 
239
  if raw_pred.ndim == 1 or raw_pred.shape[1] == 1:
240
+ prob_amp = float(raw_pred.flatten()[0]) # sigmoid output assumed = P(AMP)
241
  if prob_amp >= 0.5:
242
+ prediction = 1
243
  confidence = round(prob_amp * 100, 2)
244
  else:
245
+ prediction = 0
246
  confidence = round((1 - prob_amp) * 100, 2)
247
  else:
248
  class_idx = int(np.argmax(raw_pred[0]))
249
  prediction = class_idx
250
  confidence = round(float(raw_pred[0][class_idx]) * 100, 2)
251
 
252
+ # Label convention: 1 = AMP, 0 = Non-AMP (swap if your model is reversed)
253
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 1 else "Non-AMP"
254
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
255
 
 
261
  else:
262
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
263
 
 
264
  try:
265
  from lime.lime_tabular import LimeTabularExplainer
266
  sample_data = np.random.rand(100, len(selected_features))
 
284
  return result
285
 
286
 
 
287
  iface = gr.Interface(
288
  fn=full_prediction,
289
  inputs=gr.Textbox(label="Enter Protein Sequence"),