nonzeroexit commited on
Commit
3deecbd
·
verified ·
1 Parent(s): 5aeb6c5

Update mic_worker.py

Browse files
Files changed (1) hide show
  1. mic_worker.py +0 -90
mic_worker.py CHANGED
@@ -1,90 +0,0 @@
1
- """
2
- MIC prediction worker — runs in a SEPARATE process from the main app.
3
-
4
- Why: the main app loads TensorFlow (for the AMP Keras model). Loading PyTorch +
5
- ProtBert into that same process causes a native-library (OpenMP/MKL) clash and a
6
- SIGSEGV (exit 139), plus a large memory spike. By running ProtBert here in its own
7
- short-lived process, TensorFlow and PyTorch never coexist. This process loads
8
- torch, computes the MIC values, prints them as JSON to stdout, and exits — which
9
- frees all of its memory.
10
-
11
- Usage:
12
- python mic_worker.py "<AMINO_ACID_SEQUENCE>"
13
- Output (stdout, last line):
14
- {"E.coli": 12.3, "S.aureus": 4.5, ...} on success
15
- {"error": "..."} on failure
16
- """
17
- import os
18
- # Keep threading modest to limit memory/CPU on the free tier.
19
- os.environ.setdefault("OMP_NUM_THREADS", "1")
20
- os.environ.setdefault("MKL_NUM_THREADS", "1")
21
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
22
- # Cache ProtBert in a stable location so it is downloaded only ONCE and reused
23
- # on every subsequent prediction (instead of re-downloading each run).
24
- os.environ.setdefault("HF_HOME", "/app/.cache/huggingface")
25
-
26
- import sys
27
- import json
28
- from math import expm1
29
-
30
- import joblib
31
- import numpy as np
32
-
33
-
34
- def main():
35
- if len(sys.argv) < 2:
36
- print(json.dumps({"error": "No sequence provided"}))
37
- return
38
-
39
- sequence = sys.argv[1]
40
- sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
41
- if len(sequence) < 10:
42
- print(json.dumps({"error": "Sequence too short or invalid."}))
43
- return
44
-
45
- try:
46
- import torch
47
- from transformers import BertTokenizer, BertModel
48
- torch.set_num_threads(1)
49
-
50
- device = torch.device("cpu") # free tier has no GPU
51
- tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
52
- protbert = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()
53
-
54
- seq_spaced = ' '.join(list(sequence))
55
- tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length',
56
- truncation=True, max_length=512)
57
- tokens = {k: v.to(device) for k, v in tokens.items()}
58
-
59
- with torch.no_grad():
60
- outputs = protbert(**tokens)
61
- embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
62
- except Exception as e:
63
- print(json.dumps({"error": f"ProtBert embedding failed: {str(e)}"}))
64
- return
65
-
66
- bacteria_config = {
67
- "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
68
- "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
69
- "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
70
- "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"}
71
- }
72
-
73
- mic_results = {}
74
- for bacterium, cfg in bacteria_config.items():
75
- try:
76
- mic_scaler = joblib.load(cfg["scaler"])
77
- scaled = mic_scaler.transform(embedding)
78
- transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
79
- mic_model = joblib.load(cfg["model"])
80
- mic_log = mic_model.predict(transformed)[0]
81
- mic_results[bacterium] = round(expm1(float(mic_log)), 3)
82
- except Exception as e:
83
- mic_results[bacterium] = f"Error: {str(e)}"
84
-
85
- # Final line of stdout = JSON result
86
- print(json.dumps(mic_results))
87
-
88
-
89
- if __name__ == "__main__":
90
- main()