"""
MIC prediction worker — runs in a SEPARATE process from the main app.
Why: the main app loads TensorFlow (for the AMP Keras model). Loading PyTorch +
ProtBert into that same process causes a native-library (OpenMP/MKL) clash and a
SIGSEGV (exit 139), plus a large memory spike. By running ProtBert here in its own
short-lived process, TensorFlow and PyTorch never coexist. This process loads
torch, computes the MIC values, prints them as JSON to stdout, and exits — which
frees all of its memory.
Usage:
python mic_worker.py "<AMINO_ACID_SEQUENCE>"
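Output (stdout, last line):
{"E.coli": 12.3, "S.aureus": 4.5, ...} on success (a value is the string
"Error: ..." instead of a number if that organism's model failed)
{"error": "..."} on failure (bad input or ProtBert embedding error)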
"""
import os

# Keep threading modest to limit memory/CPU on the free tier.
# setdefault() only fills in a value when the variable is not already set, so
# anything exported by the parent process takes precedence over these defaults.
# NOTE(review): presumably these are set before the numpy/joblib imports below
# (and the lazy torch import further down) so the native libraries see them
# when they load — keep this ordering.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Cache ProtBert in a stable location so it is downloaded only ONCE and reused
# on every subsequent prediction (instead of re-downloading each run).
os.environ.setdefault("HF_HOME", "/app/.cache/huggingface")

import sys
import json
from math import expm1

import joblib
# NOTE(review): `np` is not referenced anywhere in the visible code — confirm
# whether this import is still needed.
import numpy as np
# Residues accepted from the input; every other character is silently dropped.
_VALID_RESIDUES = frozenset("ACDEFGHIKLMNPQRSTVWY")

# Sequences with fewer valid residues than this are rejected.
_MIN_SEQUENCE_LENGTH = 10

# Per-organism artifacts. Paths are relative, so they are resolved against the
# current working directory of this process. "pca" is None when the model
# consumes the scaled embedding directly.
_BACTERIA_CONFIG = {
    "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
    "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
    "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
    "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"},
}


def _clean_sequence(raw):
    """Upper-case *raw* and keep only the 20 standard amino-acid letters."""
    return "".join(aa for aa in raw.upper() if aa in _VALID_RESIDUES)


def _embed_sequence(sequence):
    """Return the ProtBert embedding of *sequence* as a 2-D array with one row.

    torch and transformers are imported here, not at module import time, so a
    failure to load them is reported by the caller as an embedding error.
    Because the tokenizer and model are locals of this helper, nothing keeps
    them referenced once the embedding has been returned.
    """
    import torch
    from transformers import BertTokenizer, BertModel

    torch.set_num_threads(1)
    device = torch.device("cpu")  # free tier has no GPU

    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    protbert = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()

    # The tokenizer is fed one residue per whitespace-separated token.
    seq_spaced = " ".join(sequence)
    tokens = tokenizer(
        seq_spaced,
        return_tensors="pt",
        padding="max_length",
        truncation=True,  # inputs longer than max_length are cut, not rejected
        max_length=512,
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}

    with torch.no_grad():
        outputs = protbert(**tokens)

    # The mean over dim 1 is taken without an attention mask, so with
    # padding="max_length" the padded positions contribute to the average.
    # NOTE(review): presumably the scalers/models were fit on embeddings pooled
    # exactly this way — confirm before changing the pooling or padding.
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)


def _predict_mic(embedding, cfg):
    """Return the MIC for one organism, rounded to 3 decimals.

    Applies the organism's scaler, then its PCA when one is configured, then
    its model; the model output is mapped back with expm1.

    Raises:
        ValueError: if the resulting MIC is NaN or infinite, which json.dumps
            would otherwise emit as the non-standard tokens NaN/Infinity.
        Exception: anything raised while loading or applying the artifacts.
    """
    from math import isfinite

    # joblib.load unpickles its input: only point these paths at trusted files.
    scaler = joblib.load(cfg["scaler"])
    features = scaler.transform(embedding)
    if cfg["pca"]:
        features = joblib.load(cfg["pca"]).transform(features)
    model = joblib.load(cfg["model"])

    mic_log = float(model.predict(features)[0])
    mic = expm1(mic_log)
    if not isfinite(mic):
        raise ValueError(f"non-finite MIC prediction ({mic})")
    return round(mic, 3)


def main():
    """Predict MIC values for the sequence in sys.argv[1] and print them as JSON.

    Exactly one JSON object is printed by this function:
      * {"error": "..."} when no argument is given, when fewer than 10 valid
        residues remain after cleaning, or when the ProtBert embedding fails;
      * otherwise one key per organism, whose value is the MIC as a number or
        the string "Error: ..." if that organism's prediction failed.

    The function always returns None; failures are reported through the JSON
    payload rather than by raising.
    """
    if len(sys.argv) < 2:
        print(json.dumps({"error": "No sequence provided"}))
        return

    sequence = _clean_sequence(sys.argv[1])
    if len(sequence) < _MIN_SEQUENCE_LENGTH:
        print(json.dumps({"error": "Sequence too short or invalid."}))
        return

    try:
        embedding = _embed_sequence(sequence)
    except Exception as e:
        # Top-level boundary: any import/download/inference failure becomes a
        # single structured error for the parent process.
        print(json.dumps({"error": f"ProtBert embedding failed: {str(e)}"}))
        return

    mic_results = {}
    for bacterium, cfg in _BACTERIA_CONFIG.items():
        try:
            mic_results[bacterium] = _predict_mic(embedding, cfg)
        except Exception as e:
            # Best effort: one organism failing must not hide the others.
            mic_results[bacterium] = f"Error: {str(e)}"

    # Final line of stdout = JSON result
    print(json.dumps(mic_results))
# Run the worker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()