File size: 3,562 Bytes
e2b9a0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87f0d31
 
 
e2b9a0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87f0d31
e2b9a0c
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
MIC prediction worker — runs in a SEPARATE process from the main app.

Why: the main app loads TensorFlow (for the AMP Keras model). Loading PyTorch +
ProtBert into that same process causes a native-library (OpenMP/MKL) clash and a
SIGSEGV (exit 139), plus a large memory spike. By running ProtBert here in its own
short-lived process, TensorFlow and PyTorch never coexist. This process loads
torch, computes the MIC values, prints them as JSON to stdout, and exits — which
frees all of its memory.

Usage:
    python mic_worker.py "<AMINO_ACID_SEQUENCE>"
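Output (stdout, last line):
    {"E.coli": 12.3, "S.aureus": 4.5, ...}   on success
    {"error": "..."}                          on failure
    If prediction fails for a single organism only, that organism's value is
    the string "Error: ..." and the other organisms keep their numeric values.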
"""
import os

# Environment defaults, applied in the order listed. ``setdefault`` is used so
# that a value already present in the environment is never overwritten.
# NOTE(review): this runs ahead of the numpy/joblib imports below and of the
# torch import inside main(), presumably because those libraries read the
# variables when they load -- keep this block first.
_ENV_DEFAULTS = (
    # Keep threading modest to limit memory/CPU on the free tier.
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("TOKENIZERS_PARALLELISM", "false"),
    # Cache ProtBert in a stable location so it is downloaded only ONCE and
    # reused on every subsequent prediction (instead of re-downloading each run).
    ("HF_HOME", "/app/.cache/huggingface"),
)
for _env_name, _env_value in _ENV_DEFAULTS:
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value

import sys
import json
from math import expm1

import joblib
import numpy as np


def _emit(payload):
    """Print *payload* to stdout as one line of JSON."""
    print(json.dumps(payload))


def main():
    """Predict MIC values for the sequence in ``sys.argv[1]`` and print JSON.

    Steps:
      1. Upper-case the argument and keep only the 20 letters
         ``ACDEFGHIKLMNPQRSTVWY``; fewer than 10 remaining letters is rejected.
      2. Embed the cleaned sequence with ProtBert (``Rostlab/prot_bert``) on CPU
         and average the last hidden state over the token axis.
      3. For each organism, load its scaler, optional PCA and model with
         joblib, predict, and convert the prediction with ``expm1``.

    Every outcome is reported on stdout as a single JSON object:
      * ``{"error": ...}`` when no argument is given, the sequence is too
        short, or the embedding step raises;
      * otherwise one key per organism, whose value is either the MIC rounded
        to 3 decimals or the string ``"Error: <message>"`` if that organism's
        pipeline raised.

    Returns:
        None on every handled path (success or reported failure).
    """
    argv = sys.argv
    if len(argv) < 2:
        _emit({"error": "No sequence provided"})
        return

    # Case-insensitive input: normalise to upper case, then drop every
    # character that is not one of the accepted residue letters.
    allowed_residues = "ACDEFGHIKLMNPQRSTVWY"
    sequence = "".join(ch for ch in argv[1].upper() if ch in allowed_residues)
    if len(sequence) < 10:
        _emit({"error": "Sequence too short or invalid."})
        return

    try:
        # Imported here rather than at module level, so the validation
        # failures above return without ever loading torch/transformers.
        import torch
        from transformers import BertModel, BertTokenizer

        torch.set_num_threads(1)
        cpu = torch.device("cpu")  # free tier has no GPU

        tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        protbert = BertModel.from_pretrained("Rostlab/prot_bert")
        protbert = protbert.to(cpu).eval()

        # The tokenizer is fed one residue per whitespace-separated token.
        encoded = tokenizer(
            " ".join(sequence),
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )
        model_inputs = {name: tensor.to(cpu) for name, tensor in encoded.items()}

        with torch.no_grad():
            hidden = protbert(**model_inputs).last_hidden_state
            # Mean over dim 1 of the padded-to-512 output, then flattened to a
            # single row for the scikit-learn style ``transform``/``predict``.
            pooled = hidden.mean(dim=1).squeeze()
            embedding = pooled.cpu().numpy().reshape(1, -1)
    except Exception as exc:
        _emit({"error": f"ProtBert embedding failed: {exc!s}"})
        return

    # (organism, model file, scaler file, PCA file or None). File names are
    # relative, so they resolve against the current working directory.
    pipelines = (
        ("E.coli", "coli_xgboost_model.pkl", "coli_scaler.pkl", None),
        ("S.aureus", "aur_xgboost_model.pkl", "aur_scaler.pkl", None),
        ("P.aeruginosa", "arg_xgboost_model.pkl", "arg_scaler.pkl", None),
        ("K.Pneumonia", "pne_mlp_model.pkl", "pne_scaler.pkl", "pne_pca.pkl"),
    )

    mic_results = {}
    for organism, model_file, scaler_file, pca_file in pipelines:
        # One organism failing must not prevent the others from reporting.
        try:
            features = joblib.load(scaler_file).transform(embedding)
            if pca_file:
                features = joblib.load(pca_file).transform(features)
            log_mic = joblib.load(model_file).predict(features)[0]
            # NOTE(review): expm1 suggests the models were trained on
            # log1p(MIC) -- confirm against the training code.
            mic_results[organism] = round(expm1(float(log_mic)), 3)
        except Exception as exc:
            mic_results[organism] = f"Error: {exc!s}"

    # Final line of stdout = JSON result
    _emit(mic_results)


# Script entry point: run the prediction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()