Spaces:
Sleeping
Sleeping
import os

# --- Prevent SIGSEGV (exit 139) from TensorFlow + PyTorch native lib clashes ---
# TF and torch each bundle their own OpenMP/MKL; loaded together they can
# collide and crash at the C level. The thread-count and duplicate-lib settings
# below make them coexist and reduce memory. The two TF_* entries quiet
# TensorFlow logging and must be set before tensorflow is imported.
#
# setdefault() is used throughout, so a value already present in the
# environment always wins over the default listed here.
for _env_name, _env_default in (
    ("KMP_DUPLICATE_LIB_OK", "TRUE"),
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("OPENBLAS_NUM_THREADS", "1"),
    ("TF_CPP_MIN_LOG_LEVEL", "3"),
    ("TF_ENABLE_ONEDNN_OPTS", "0"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
| import gradio as gr | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| from propy import AAComposition, CTD | |
| from math import expm1 | |
# ---------------------------------------------------------------------------
# LAZY LOADING — keeps the free 16GB Space from OOM-ing at startup.
# Heavy libs (TF, torch, ProtBert) load only when first needed.
#
# These module-level slots start empty and are filled on first use by
# get_amp_model() and get_protbert().
# ---------------------------------------------------------------------------
# AMP classifier and the scaler applied to its input features.
_amp_model = _amp_scaler = None
# ProtBert tokenizer/model, plus the torch module and device they run on.
_protbert_tokenizer = _protbert_model = None
_torch = _device = None
def get_amp_model():
    """Return the cached ``(model, scaler)`` pair for AMP classification.

    On first use the Keras model and the joblib-serialised scaler are loaded
    from files in the current working directory and stored in the module-level
    ``_amp_model`` / ``_amp_scaler`` slots; later calls return the cached
    objects.

    Returns:
        tuple: ``(model, scaler)``.

    Raises:
        Whatever ``load_model`` or ``joblib.load`` raise when an artifact is
        missing or unreadable. Nothing is cached in that case, so the next
        call retries the load.
    """
    global _amp_model, _amp_scaler
    if _amp_model is None or _amp_scaler is None:
        # Imported here so TensorFlow is only loaded when first needed.
        from tensorflow.keras.models import load_model

        model = load_model("Comb1_aac_ctd_RFE_selected_features_model.keras")
        scaler = joblib.load("Comb1_aac_ctd_RFE_selected_features_scaler.joblib")
        # Publish both together: caching the model before the scaler is loaded
        # would let a failed scaler load leave a (model, None) pair behind
        # that every later call returns without retrying.
        _amp_model, _amp_scaler = model, scaler
    return _amp_model, _amp_scaler
def get_protbert():
    """Return the cached ProtBert tokenizer, model, torch module and device.

    On first use this imports torch and transformers, loads the
    ``Rostlab/prot_bert`` tokenizer and model, moves the model to CUDA when
    available (CPU otherwise) and switches it to eval mode. The results are
    stored in module-level slots and reused by later calls.

    Returns:
        tuple: ``(tokenizer, model, torch, device)``.

    Raises:
        Whatever the torch / transformers loaders raise. Nothing is cached in
        that case, so the next call retries.
    """
    global _protbert_tokenizer, _protbert_model, _torch, _device
    if _protbert_model is None:
        # Imported here so torch/transformers are only loaded when first needed.
        import torch
        from transformers import BertTokenizer, BertModel

        try:
            torch.set_num_threads(1)  # reduce native threading conflicts with TF
        except Exception:
            # Deliberately best-effort: a failure to cap threads must not
            # prevent the model from loading.
            pass

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = BertTokenizer.from_pretrained(
            "Rostlab/prot_bert", do_lower_case=False
        )
        model = BertModel.from_pretrained("Rostlab/prot_bert")
        model = model.to(device).eval()

        # Publish only after every step succeeded, with the model last because
        # it is the "already loaded" marker. Assigning it earlier would cache
        # a model that never reached the device / eval mode if .to() failed.
        _torch, _device, _protbert_tokenizer = torch, device, tokenizer
        _protbert_model = model
    return _protbert_tokenizer, _protbert_model, _torch, _device
# ---------------------------------------------------------------------------
# The EXACT 343 features the scaler was fit on, IN THE EXACT TRAINING ORDER.
# The scaler was fit on a numpy array (no stored names), so order is critical:
# we must select these columns in this order BEFORE calling scaler.transform().
#
# Layout of the list:
#   1. CTD composition   (C1..C3)       for each property      ->  21 names
#   2. CTD transition    (T12/T13/T23)  for each property      ->  21 names
#   3. CTD distribution  (D<group><position>) per property     -> 104 names
#   4. single amino-acid composition                           ->  20 names
#   5. selected dipeptide compositions                         -> 177 names
# ---------------------------------------------------------------------------
_CTD_PROPERTIES = (
    "Polarizability",
    "SolventAccessibility",
    "SecondaryStr",
    "Charge",
    "Polarity",
    "NormalizedVDWV",
    "Hydrophobicity",
)

selected_features = []

# Composition block first, then transition block; properties vary fastest
# inside each block, suffixes fastest inside each property.
for _suffixes in (("C1", "C2", "C3"), ("T12", "T13", "T23")):
    for _prop in _CTD_PROPERTIES:
        selected_features.extend(f"_{_prop}{_suffix}" for _suffix in _suffixes)

# Distribution block: property -> group (1..3) -> position.
for _prop in _CTD_PROPERTIES:
    for _group in ("1", "2", "3"):
        for _position in ("001", "025", "050", "075", "100"):
            selected_features.append(f"_{_prop}D{_group}{_position}")

# This single distribution descriptor is not part of the training feature list.
selected_features.remove("_ChargeD2100")

# Single amino-acid composition, in training order.
selected_features.extend("ARNDCEQGHILKMFPSTWYV")

# Selected dipeptide compositions, in training order (one row per first residue).
selected_features.extend(
    """
    AR AD AQ AG AL AK AF AP AT AV
    RA RC RE RG RI RL RS RT RV
    NR NC NG NI NP NS NY NV
    DR DN DC DE DG DF DS DT DY
    CR CN CD CC CI CL CK CT CY CV
    EA ER ED EC EE EG EI EL EK EF EP ET EV
    QN QF QV
    GA GR GC GE GG GI GL GK GF GP GY
    HA HP HT
    IA IR ID II IL IF IP IS IV
    LA LR LD LC LG LI LK LM LF LS LT LY LV
    KA KN KC KG KI KL KK KP KY
    MA MD ME MI MK MF MP MS MV
    FR FE FQ FG FL FF FS FT FY FV
    PA PR PC PE PL PK PP PS PV
    SA SR SD SC SG SH SI SL SP ST SY
    TA TR TC TE TQ TG TI TL TP TS TV
    WA
    YN YD YC YQ YG YP
    VA VR VD VC VE VG VI VL VK VS VT VY VV
    """.split()
)

del _suffixes, _prop, _group, _position

assert len(selected_features) == 343, f"Expected 343 features, got {len(selected_features)}"
def keras_predict_proba(X):
    """Adapt the AMP model's output to the two-column form LIME expects.

    Args:
        X: Already-scaled feature rows, passed straight to ``model.predict``.

    Returns:
        The model output unchanged when it already has more than one column.
        Otherwise a two-column array ``[1 - p, p]`` per row, where ``p`` is
        the single model output (assumed to be a sigmoid giving P(AMP)).
    """
    model = get_amp_model()[0]
    scores = model.predict(X, verbose=0)

    # Multi-column output is returned as the model produced it.
    if scores.ndim != 1 and scores.shape[1] != 1:
        return scores

    # Single-output case: expand to [P(Non-AMP), P(AMP)].
    positive = scores.reshape(-1, 1)
    return np.hstack([1 - positive, positive])  # sigmoid output assumed = P(AMP)
def extract_features(sequence):
    """Turn a raw sequence into the scaled feature row the AMP model consumes.

    Characters outside the 20 standard amino-acid letters are dropped
    (case-insensitively) before anything else happens.

    Args:
        sequence: Raw user-supplied sequence string.

    Returns:
        A ``float32`` array of shape ``(1, len(selected_features))`` on
        success, or a ``str`` starting with ``"Error"`` when the cleaned
        sequence is shorter than 10 residues, a required descriptor is
        missing, or any step raises.
    """
    standard_residues = frozenset("ACDEFGHIKLMNPQRSTVWY")
    sequence = "".join(ch for ch in sequence.upper() if ch in standard_residues)
    if len(sequence) < 10:
        return "Error: Sequence too short."

    try:
        scaler = get_amp_model()[1]

        # One lookup table holding every descriptor propy computes; later
        # entries win on a key clash, matching successive dict.update calls.
        descriptors = {
            **CTD.CalculateCTD(sequence),
            **AAComposition.CalculateAADipeptideComposition(sequence),
        }

        absent = [name for name in selected_features if name not in descriptors]
        if absent:
            return f"Error: Missing features from propy: {absent[:5]}..."

        # Column order must follow selected_features: the scaler is applied
        # positionally to exactly these columns.
        row = np.array(
            [descriptors[name] for name in selected_features], dtype=np.float64
        )
        scaled_row = scaler.transform(row.reshape(1, -1))
        return scaled_row.astype(np.float32)
    except Exception as e:
        return f"Error in feature extraction: {str(e)}"
def predictmic(sequence):
    """Predict a MIC value per organism from a ProtBert embedding of the sequence.

    Characters outside the 20 standard amino-acid letters are dropped
    (case-insensitively). The cleaned sequence is embedded with ProtBert by
    averaging ``last_hidden_state`` over the token axis, and the embedding is
    passed through each organism's scaler, optional PCA and regressor, all
    loaded with joblib from the current working directory on every call.

    Args:
        sequence: Raw user-supplied sequence string.

    Returns:
        dict: organism name -> ``round(expm1(prediction), 3)``, or an
        ``"Error: ..."`` string for an organism whose pipeline raised.
        ``{"Error": "Sequence too short or invalid."}`` when the cleaned
        sequence has fewer than 10 residues.
    """
    standard_residues = frozenset("ACDEFGHIKLMNPQRSTVWY")
    sequence = "".join(ch for ch in sequence.upper() if ch in standard_residues)
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    tokenizer, protbert_model, torch, device = get_protbert()

    # The tokenizer is fed residues separated by single spaces.
    encoded = tokenizer(
        " ".join(sequence),
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden_states = protbert_model(**encoded).last_hidden_state
        # Mean over the token axis (all 512 positions, as padded above).
        pooled = hidden_states.mean(dim=1)
    embedding = pooled.squeeze().cpu().numpy().reshape(1, -1)

    # organism -> (model file, scaler file, optional PCA file)
    artifacts = {
        "E.coli": ("coli_xgboost_model.pkl", "coli_scaler.pkl", None),
        "S.aureus": ("aur_xgboost_model.pkl", "aur_scaler.pkl", None),
        "P.aeruginosa": ("arg_xgboost_model.pkl", "arg_scaler.pkl", None),
        "K.Pneumonia": ("pne_mlp_model.pkl", "pne_scaler.pkl", "pne_pca.pkl"),
    }

    mic_results = {}
    for bacterium, (model_file, scaler_file, pca_file) in artifacts.items():
        try:
            inputs = joblib.load(scaler_file).transform(embedding)
            if pca_file:
                inputs = joblib.load(pca_file).transform(inputs)
            log_mic = joblib.load(model_file).predict(inputs)[0]
            # expm1 undoes what is presumably a log1p target transform
            # (NOTE(review): confirm against the training code).
            mic_results[bacterium] = round(expm1(log_mic), 3)
        except Exception as e:
            # One organism failing must not hide the others' results.
            mic_results[bacterium] = f"Error: {str(e)}"
    return mic_results
def full_prediction(sequence):
    """Classify a sequence as AMP / Non-AMP and build the text report for the UI.

    The report always contains the predicted class and a confidence
    percentage. For AMP predictions it adds per-organism MIC values from
    ``predictmic``; it then appends the ten LIME feature weights for the
    prediction. A failure in the MIC step or the LIME step is reported inside
    the text instead of aborting, so the classification is never lost.

    Args:
        sequence: Raw user-supplied sequence string.

    Returns:
        str: the formatted report, or the error message produced by
        ``extract_features`` when the sequence cannot be featurised.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        # extract_features signals failure by returning an error message.
        return features

    amp_model, _ = get_amp_model()
    raw_pred = amp_model.predict(features, verbose=0)
    if raw_pred.ndim == 1 or raw_pred.shape[1] == 1:
        prob_amp = float(raw_pred.flatten()[0])  # sigmoid output assumed = P(AMP)
        prediction = 1 if prob_amp >= 0.5 else 0
        # Confidence is the probability of whichever class was chosen.
        chosen_prob = prob_amp if prediction == 1 else 1 - prob_amp
        confidence = round(chosen_prob * 100, 2)
    else:
        prediction = int(np.argmax(raw_pred[0]))
        confidence = round(float(raw_pred[0][prediction]) * 100, 2)

    # Label convention: 1 = AMP, 0 = Non-AMP (swap if your model is reversed)
    amp_result = "Antimicrobial Peptide (AMP)" if prediction == 1 else "Non-AMP"
    report = [f"Prediction: {amp_result}\nConfidence: {confidence}%\n"]

    if prediction == 1:
        try:
            mic_values = predictmic(sequence)
        except Exception as e:
            # e.g. ProtBert could not be loaded; keep the classification and
            # report the failure the same way LIME failures are reported.
            report.append(f"\nMIC prediction failed: {str(e)}\n")
        else:
            report.append("\nPredicted MIC Values (μM):\n")
            report.extend(f"- {org}: {mic}\n" for org, mic in mic_values.items())
    else:
        report.append("\nMIC prediction skipped for Non-AMP sequences.\n")

    try:
        # Imported here so LIME is only loaded when an explanation is built.
        from lime.lime_tabular import LimeTabularExplainer

        # The explainer's background data is synthetic uniform noise, not the
        # real training set. Seeding it (and the explainer) makes the same
        # sequence yield the same explanation on every request, and a local
        # generator leaves numpy's global random state untouched.
        # NOTE(review): uniform [0, 1) data may not match the scaled feature
        # distribution; consider shipping real training statistics.
        rng = np.random.RandomState(0)
        sample_data = rng.rand(100, len(selected_features))
        explainer = LimeTabularExplainer(
            training_data=sample_data,
            feature_names=selected_features,
            class_names=["Non-AMP", "AMP"],
            mode="classification",
            random_state=0,
        )
        explanation = explainer.explain_instance(
            data_row=features[0],
            predict_fn=keras_predict_proba,
            num_features=10,
        )
        lime_lines = [
            f"- {feat}: {round(weight, 4)}\n"
            for feat, weight in explanation.as_list()
        ]
    except Exception as e:
        report.append(f"\nLIME explanation failed: {str(e)}\n")
    else:
        report.append("\nTop Features Influencing Prediction:\n")
        report.extend(lime_lines)

    return "".join(report)
# Single-textbox UI: the submitted sequence goes to full_prediction and the
# returned report is shown in a read-only results textbox.
iface = gr.Interface(
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
    fn=full_prediction,
    inputs=gr.Textbox(label="Enter Protein Sequence"),
    outputs=gr.Textbox(label="Results"),
)

# Start the web server as soon as this module is executed.
iface.launch()