nonzeroexit committed on
Commit
12675f2
·
verified ·
1 Parent(s): 769f73d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -57
app.py CHANGED
@@ -15,18 +15,18 @@ import joblib
15
  import numpy as np
16
  import pandas as pd
17
  from propy import AAComposition, CTD
18
- from math import expm1
 
 
19
 
20
  # ---------------------------------------------------------------------------
21
  # LAZY LOADING — keeps the free 16GB Space from OOM-ing at startup.
22
- # Heavy libs (TF, torch, ProtBert) load only when first needed.
 
 
23
  # ---------------------------------------------------------------------------
24
  _amp_model = None
25
  _amp_scaler = None
26
- _protbert_tokenizer = None
27
- _protbert_model = None
28
- _torch = None
29
- _device = None
30
 
31
 
32
  def get_amp_model():
@@ -38,25 +38,6 @@ def get_amp_model():
38
  return _amp_model, _amp_scaler
39
 
40
 
41
- def get_protbert():
42
- global _protbert_tokenizer, _protbert_model, _torch, _device
43
- if _protbert_model is None:
44
- import torch
45
- from transformers import BertTokenizer, BertModel
46
- try:
47
- torch.set_num_threads(1) # reduce native threading conflicts with TF
48
- except Exception:
49
- pass
50
- _torch = torch
51
- _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
- _protbert_tokenizer = BertTokenizer.from_pretrained(
53
- "Rostlab/prot_bert", do_lower_case=False
54
- )
55
- _protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
56
- _protbert_model = _protbert_model.to(_device).eval()
57
- return _protbert_tokenizer, _protbert_model, _torch, _device
58
-
59
-
60
  # ---------------------------------------------------------------------------
61
  # The EXACT 343 features the scaler was fit on, IN THE EXACT TRAINING ORDER.
62
  # The scaler was fit on a numpy array (no stored names), so order is critical:
@@ -201,42 +182,43 @@ def extract_features(sequence):
201
 
202
 
203
  def predictmic(sequence):
 
 
 
 
 
 
204
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
205
  if len(sequence) < 10:
206
  return {"Error": "Sequence too short or invalid."}
207
 
208
- tokenizer, protbert_model, torch, device = get_protbert()
209
-
210
- seq_spaced = ' '.join(list(sequence))
211
- tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length',
212
- truncation=True, max_length=512)
213
- tokens = {k: v.to(device) for k, v in tokens.items()}
214
-
215
- with torch.no_grad():
216
- outputs = protbert_model(**tokens)
217
- embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
218
-
219
- bacteria_config = {
220
- "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
221
- "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
222
- "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
223
- "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"}
224
- }
225
-
226
- mic_results = {}
227
- for bacterium, cfg in bacteria_config.items():
228
- try:
229
- mic_scaler = joblib.load(cfg["scaler"])
230
- scaled = mic_scaler.transform(embedding)
231
- transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
232
- mic_model = joblib.load(cfg["model"])
233
- mic_log = mic_model.predict(transformed)[0]
234
- mic = round(expm1(mic_log), 3)
235
- mic_results[bacterium] = mic
236
- except Exception as e:
237
- mic_results[bacterium] = f"Error: {str(e)}"
238
-
239
- return mic_results
240
 
241
 
242
  def full_prediction(sequence):
 
15
  import numpy as np
16
  import pandas as pd
17
  from propy import AAComposition, CTD
18
+ import sys
19
+ import json
20
+ import subprocess
21
 
22
  # ---------------------------------------------------------------------------
23
  # LAZY LOADING — keeps the free 16GB Space from OOM-ing at startup.
24
+ # Only the TensorFlow AMP model is loaded in THIS process. ProtBert/PyTorch
25
+ # run in a SEPARATE process (mic_worker.py) to avoid a native-library clash
26
+ # between TensorFlow and PyTorch that caused SIGSEGV (exit 139).
27
  # ---------------------------------------------------------------------------
28
  _amp_model = None
29
  _amp_scaler = None
 
 
 
 
30
 
31
 
32
  def get_amp_model():
 
38
  return _amp_model, _amp_scaler
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # ---------------------------------------------------------------------------
42
  # The EXACT 343 features the scaler was fit on, IN THE EXACT TRAINING ORDER.
43
  # The scaler was fit on a numpy array (no stored names), so order is critical:
 
182
 
183
 
184
def predictmic(sequence):
    """Run MIC prediction in a SEPARATE process (mic_worker.py).

    Running the worker out of process isolates PyTorch/ProtBert from
    TensorFlow, preventing the native-library crash (exit 139) and keeping
    peak memory low. The worker is expected to print a JSON dict on its last
    non-empty stdout line; that dict is parsed and returned.

    Args:
        sequence: Peptide sequence. Characters other than the 20 standard
            amino-acid letters are dropped (case-insensitively) before use.

    Returns:
        The dict printed by the worker on success, otherwise a dict with a
        single ``"Error"`` key describing what went wrong. This function
        reports failures through the returned dict rather than raising.
    """
    # Local import so this block does not depend on a file-level `import os`.
    import os

    # Reject non-string input (e.g. None) with the same message as an
    # invalid sequence instead of failing with AttributeError on .upper().
    if not isinstance(sequence, str):
        return {"Error": "Sequence too short or invalid."}

    valid_residues = frozenset("ACDEFGHIKLMNPQRSTVWY")
    sequence = "".join(aa for aa in sequence.upper() if aa in valid_residues)
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    # Resolve the worker next to this module rather than relying on the
    # current working directory, which is only correct when the app happens
    # to be launched from its own folder.
    module_file = globals().get("__file__")
    app_dir = os.path.dirname(os.path.abspath(module_file)) if module_file else os.getcwd()
    worker_path = os.path.join(app_dir, "mic_worker.py")

    try:
        # First run downloads ProtBert (~1.6GB), so allow a generous timeout.
        proc = subprocess.run(
            [sys.executable, worker_path, sequence],
            capture_output=True,
            text=True,
            timeout=900,  # 15 minutes; mostly for the one-time model download
            # NOTE(review): assumes the worker loads its model files relative
            # to the app folder, as launching from that folder did before.
            cwd=app_dir,
        )
    except subprocess.TimeoutExpired:
        return {"Error": "MIC prediction timed out (model download may still be in progress; try again shortly)."}
    except Exception as e:
        # Deliberately broad: this is the UI boundary and any launch failure
        # should be reported to the user rather than crash the request.
        return {"Error": f"Failed to start MIC worker: {str(e)}"}

    if proc.returncode != 0:
        # Worker crashed; surface the last few stderr lines for debugging.
        tail = " ".join((proc.stderr or "").strip().splitlines()[-3:])
        message = f"MIC worker exited with code {proc.returncode}."
        if tail:
            message = f"{message} {tail}"
        return {"Error": message}

    # The result is the last non-empty stdout line; earlier lines may be logs.
    out_lines = [ln for ln in (proc.stdout or "").splitlines() if ln.strip()]
    if not out_lines:
        return {"Error": "MIC worker produced no output."}

    last_line = out_lines[-1]
    try:
        result = json.loads(last_line)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {"Error": f"Could not parse MIC worker output: {last_line[:200]}"}

    # Valid JSON that is not an object (list, number, null) would break the
    # dict contract of this function, so treat it as unparseable output.
    if not isinstance(result, dict):
        return {"Error": f"Could not parse MIC worker output: {last_line[:200]}"}
    return result
 
 
 
 
 
222
 
223
 
224
  def full_prediction(sequence):