Spaces:
Running
Running
| """ | |
MIC prediction worker — runs in a SEPARATE process from the main app.

Why: the main app loads TensorFlow (for the AMP Keras model). Loading PyTorch +
ProtBert into that same process causes a native-library (OpenMP/MKL) clash and a
SIGSEGV (exit 139), plus a large memory spike. By running ProtBert here in its own
short-lived process, TensorFlow and PyTorch never coexist. This process loads
torch, computes the MIC values, prints them as JSON to stdout, and exits — which
frees all of its memory.

Usage:
    python mic_worker.py "<AMINO_ACID_SEQUENCE>"

Output (stdout, last line):
    {"E.coli": 12.3, "S.aureus": 4.5, ...}   on success
    {"error": "..."}                          on failure

If prediction fails for a single organism, its value in the success object is an
"Error: ..." string instead of a number.
| """ | |
| import os | |
# Environment defaults, applied only when the variable is not already set.
# They are set here, before the lazy torch/transformers imports in main(),
# so those libraries see them at import time.
#   * OMP/MKL thread counts: keep threading modest to limit memory/CPU on the
#     free tier.
#   * TOKENIZERS_PARALLELISM: switch off tokenizer parallelism.
#   * HF_HOME: cache ProtBert in a stable location so it is downloaded only
#     ONCE and reused on every subsequent prediction.
for _env_name, _env_default in (
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("TOKENIZERS_PARALLELISM", "false"),
    ("HF_HOME", "/app/.cache/huggingface"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
| import sys | |
| import json | |
| from math import expm1 | |
| import joblib | |
| import numpy as np | |
# One-letter codes accepted in the input; every other character is dropped.
_VALID_RESIDUES = frozenset("ACDEFGHIKLMNPQRSTVWY")

# Sequences shorter than this (after filtering) are rejected.
_MIN_SEQUENCE_LENGTH = 10

# Per-organism artifacts, loaded with joblib relative to the current working
# directory. "pca" is None when no PCA step is applied before the model.
_BACTERIA_CONFIG = {
    "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
    "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
    "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
    "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"},
}


def _embed_sequence(sequence):
    """Return the mean-pooled ProtBert embedding of *sequence* as a 2-D array with one row."""
    # Imported lazily so the cheap validation failures in main() never pay for
    # loading torch/transformers.
    import torch
    from transformers import BertTokenizer, BertModel

    torch.set_num_threads(1)
    device = torch.device("cpu")  # free tier has no GPU
    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    protbert = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()

    # The tokenizer is fed the residues separated by single spaces.
    seq_spaced = " ".join(sequence)
    tokens = tokenizer(
        seq_spaced,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    tokens = {key: tensor.to(device) for key, tensor in tokens.items()}
    with torch.no_grad():
        outputs = protbert(**tokens)

    # NOTE(review): with padding='max_length' the mean runs over all 512
    # positions, padding included. Presumably the scalers/models below were
    # fitted on embeddings pooled the same way — confirm before changing it.
    pooled = outputs.last_hidden_state.mean(dim=1)
    return pooled.squeeze().cpu().numpy().reshape(1, -1)


def _predict_mic(embedding, cfg):
    """Return the MIC predicted from *embedding* using the artifacts named in *cfg*.

    Raises:
        ValueError: if the prediction is NaN or infinite.
        Exception: anything raised while loading or applying the artifacts.
    """
    from math import isfinite

    # NOTE(review): joblib.load unpickles its input; this assumes the .pkl files
    # are trusted artifacts shipped with the app — never point it at user data.
    scaled = joblib.load(cfg["scaler"]).transform(embedding)
    features = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
    mic_log = joblib.load(cfg["model"]).predict(features)[0]

    # The model output is mapped back with expm1 (presumably the training
    # target was log1p(MIC) — confirm against the training code).
    mic = expm1(float(mic_log))
    if not isfinite(mic):
        # json.dumps would otherwise emit a bare NaN/Infinity token, which is
        # not valid JSON; report it through the per-organism error string.
        raise ValueError(f"non-finite MIC prediction ({mic})")
    return round(mic, 3)


def main():
    """Predict MIC values for the sequence in ``sys.argv[1]`` and print them as JSON.

    Exactly one JSON object is printed to stdout:

    * ``{"error": "..."}`` when no argument is given, when fewer than 10 valid
      residues remain after filtering, or when the ProtBert embedding fails;
    * otherwise a mapping from organism name to the predicted MIC rounded to
      3 decimals, or to an ``"Error: ..."`` string if that organism's
      prediction failed.

    The function always returns ``None``; failures are reported in the JSON
    payload rather than raised.
    """
    if len(sys.argv) < 2:
        print(json.dumps({"error": "No sequence provided"}))
        return

    # Upper-case the input and silently drop anything that is not one of the
    # 20 accepted residue codes (whitespace, digits, ambiguity codes, ...).
    sequence = "".join(aa for aa in sys.argv[1].upper() if aa in _VALID_RESIDUES)
    if len(sequence) < _MIN_SEQUENCE_LENGTH:
        print(json.dumps({"error": "Sequence too short or invalid."}))
        return

    # Deliberately broad: any failure here (missing package, download error,
    # out-of-memory, ...) must still reach the caller as a JSON error line.
    try:
        embedding = _embed_sequence(sequence)
    except Exception as e:
        print(json.dumps({"error": f"ProtBert embedding failed: {e}"}))
        return

    mic_results = {}
    for bacterium, cfg in _BACTERIA_CONFIG.items():
        # Best effort per organism: one broken artifact must not hide the
        # predictions for the others.
        try:
            mic_results[bacterium] = _predict_mic(embedding, cfg)
        except Exception as e:
            mic_results[bacterium] = f"Error: {e}"

    # Final line of stdout = JSON result
    print(json.dumps(mic_results))
# Run the worker only when this file is executed as a script
# (python mic_worker.py "<SEQUENCE>"), not when it is imported.
if __name__ == "__main__":
    main()