import os

# --- Prevent SIGSEGV (exit 139) from TensorFlow + PyTorch native lib clashes ---
# These must be set BEFORE any library that pulls in TensorFlow/BLAS is imported,
# which is why they sit between `import os` and the remaining imports.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")

import gradio as gr
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from lime.lime_tabular import LimeTabularExplainer
import sys
import json
import subprocess

# ---------------------------------------------------------------------------
# LAZY LOADING — keeps the free 16GB Space from OOM-ing at startup.
# Only the TensorFlow AMP model is loaded in THIS process. ProtBert/PyTorch
# run in a SEPARATE process (mic_worker.py) to avoid a native-library clash
# between TensorFlow and PyTorch that caused SIGSEGV (exit 139).
# ---------------------------------------------------------------------------
_amp_model = None
_amp_scaler = None


def get_amp_model():
    """Return the ``(amp_model, amp_scaler)`` pair, loading both on first use.

    The Keras model and the joblib-serialised scaler are read from files in the
    current working directory. TensorFlow is imported here, not at module
    level, so it is only loaded when a prediction is actually requested.

    Both objects are published to the module cache only after BOTH loads
    succeed. If either load raises, the exception propagates and the next call
    retries from scratch instead of returning a half-initialised
    ``(model, None)`` pair forever.
    """
    global _amp_model, _amp_scaler
    if _amp_model is None or _amp_scaler is None:
        from tensorflow.keras.models import load_model

        model = load_model("Comb1_aac_ctd_RFE_selected_features_model (1).keras")
        scaler = joblib.load("norm (4).joblib")
        _amp_model, _amp_scaler = model, scaler
    return _amp_model, _amp_scaler


# ---------------------------------------------------------------------------
# The EXACT 343 features the model was trained on, IN THE EXACT TRAINING ORDER.
# ---------------------------------------------------------------------------
selected_features = [
    # CTD features
    "_PolarizabilityC1", "_PolarizabilityC2", "_PolarizabilityC3",
    "_SolventAccessibilityC1", "_SolventAccessibilityC2", "_SolventAccessibilityC3",
    "_SecondaryStrC1", "_SecondaryStrC2", "_SecondaryStrC3",
    "_ChargeC1", "_ChargeC2", "_ChargeC3",
    "_PolarityC1", "_PolarityC2", "_PolarityC3",
    "_NormalizedVDWVC1", "_NormalizedVDWVC2", "_NormalizedVDWVC3",
    "_HydrophobicityC1", "_HydrophobicityC2", "_HydrophobicityC3",
    "_PolarizabilityT12", "_PolarizabilityT13", "_PolarizabilityT23",
    "_SolventAccessibilityT12", "_SolventAccessibilityT13", "_SolventAccessibilityT23",
    "_SecondaryStrT12", "_SecondaryStrT13", "_SecondaryStrT23",
    "_ChargeT12", "_ChargeT13", "_ChargeT23",
    "_PolarityT12", "_PolarityT13", "_PolarityT23",
    "_NormalizedVDWVT12", "_NormalizedVDWVT13", "_NormalizedVDWVT23",
    "_HydrophobicityT12", "_HydrophobicityT13", "_HydrophobicityT23",
    "_PolarizabilityD1001", "_PolarizabilityD1025", "_PolarizabilityD1050",
    "_PolarizabilityD1075", "_PolarizabilityD1100",
    "_PolarizabilityD2001", "_PolarizabilityD2025", "_PolarizabilityD2050",
    "_PolarizabilityD2075", "_PolarizabilityD2100",
    "_PolarizabilityD3001", "_PolarizabilityD3025", "_PolarizabilityD3050",
    "_PolarizabilityD3075", "_PolarizabilityD3100",
    "_SolventAccessibilityD1001", "_SolventAccessibilityD1025", "_SolventAccessibilityD1050",
    "_SolventAccessibilityD1075", "_SolventAccessibilityD1100",
    "_SolventAccessibilityD2001", "_SolventAccessibilityD2025", "_SolventAccessibilityD2050",
    "_SolventAccessibilityD2075", "_SolventAccessibilityD2100",
    "_SolventAccessibilityD3001", "_SolventAccessibilityD3025", "_SolventAccessibilityD3050",
    "_SolventAccessibilityD3075", "_SolventAccessibilityD3100",
    "_SecondaryStrD1001", "_SecondaryStrD1025", "_SecondaryStrD1050",
    "_SecondaryStrD1075", "_SecondaryStrD1100",
    "_SecondaryStrD2001", "_SecondaryStrD2025", "_SecondaryStrD2050",
    "_SecondaryStrD2075", "_SecondaryStrD2100",
    "_SecondaryStrD3001", "_SecondaryStrD3025", "_SecondaryStrD3050",
    "_SecondaryStrD3075", "_SecondaryStrD3100",
    "_ChargeD1001", "_ChargeD1025", "_ChargeD1050", "_ChargeD1075", "_ChargeD1100",
    "_ChargeD2001", "_ChargeD2025", "_ChargeD2050", "_ChargeD2075",
    "_ChargeD3001", "_ChargeD3025", "_ChargeD3050", "_ChargeD3075", "_ChargeD3100",
    "_PolarityD1001", "_PolarityD1025", "_PolarityD1050", "_PolarityD1075", "_PolarityD1100",
    "_PolarityD2001", "_PolarityD2025", "_PolarityD2050", "_PolarityD2075", "_PolarityD2100",
    "_PolarityD3001", "_PolarityD3025", "_PolarityD3050", "_PolarityD3075", "_PolarityD3100",
    "_NormalizedVDWVD1001", "_NormalizedVDWVD1025", "_NormalizedVDWVD1050",
    "_NormalizedVDWVD1075", "_NormalizedVDWVD1100",
    "_NormalizedVDWVD2001", "_NormalizedVDWVD2025", "_NormalizedVDWVD2050",
    "_NormalizedVDWVD2075", "_NormalizedVDWVD2100",
    "_NormalizedVDWVD3001", "_NormalizedVDWVD3025", "_NormalizedVDWVD3050",
    "_NormalizedVDWVD3075", "_NormalizedVDWVD3100",
    "_HydrophobicityD1001", "_HydrophobicityD1025", "_HydrophobicityD1050",
    "_HydrophobicityD1075", "_HydrophobicityD1100",
    "_HydrophobicityD2001", "_HydrophobicityD2025", "_HydrophobicityD2050",
    "_HydrophobicityD2075", "_HydrophobicityD2100",
    "_HydrophobicityD3001", "_HydrophobicityD3025", "_HydrophobicityD3050",
    "_HydrophobicityD3075", "_HydrophobicityD3100",
    # single-residue composition
    "A", "R", "N", "D", "C", "E", "Q", "G", "H", "I",
    "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
    # selected dipeptide composition
    "AR", "AD", "AQ", "AG", "AL", "AK", "AF", "AP", "AT", "AV",
    "RA", "RC", "RE", "RG", "RI", "RL", "RS", "RT", "RV",
    "NR", "NC", "NG", "NI", "NP", "NS", "NY", "NV",
    "DR", "DN", "DC", "DE", "DG", "DF", "DS", "DT", "DY",
    "CR", "CN", "CD", "CC", "CI", "CL", "CK", "CT", "CY", "CV",
    "EA", "ER", "ED", "EC", "EE", "EG", "EI", "EL", "EK", "EF", "EP", "ET", "EV",
    "QN", "QF", "QV",
    "GA", "GR", "GC", "GE", "GG", "GI", "GL", "GK", "GF", "GP", "GY",
    "HA", "HP", "HT",
    "IA", "IR", "ID", "II", "IL", "IF", "IP", "IS", "IV",
    "LA", "LR", "LD", "LC", "LG", "LI", "LK", "LM", "LF", "LS", "LT", "LY", "LV",
    "KA", "KN", "KC", "KG", "KI", "KL", "KK", "KP", "KY",
    "MA", "MD", "ME", "MI", "MK", "MF", "MP", "MS", "MV",
    "FR", "FE", "FQ", "FG", "FL", "FF", "FS", "FT", "FY", "FV",
    "PA", "PR", "PC", "PE", "PL", "PK", "PP", "PS", "PV",
    "SA", "SR", "SD", "SC", "SG", "SH", "SI", "SL", "SP", "ST", "SY",
    "TA", "TR", "TC", "TE", "TQ", "TG", "TI", "TL", "TP", "TS", "TV",
    "WA",
    "YN", "YD", "YC", "YQ", "YG", "YP",
    "VA", "VR", "VD", "VC", "VE", "VG", "VI", "VL", "VK", "VS", "VT", "VY", "VV",
]
assert len(selected_features) == 343, f"Expected 343 features, got {len(selected_features)}"

# ---------------------------------------------------------------------------
# LIME explainer — built ONCE at startup with uniform [0,1) background data.
# This assumes every selected feature is MinMax-scaled to roughly [0,1]
# (NOTE(review): confirm "norm (4).joblib" is a [0,1] MinMaxScaler; unseen
# sequences can still land slightly outside that range).
# class_names: index 0 = AMP, index 1 = Non-AMP (matches training: AMP=0, Non-AMP=1)
# We always explain label=0 (AMP class) so weights are consistent across all
# sequences — positive weight = pushes TOWARD AMP, negative = pushes AWAY.
# The background is drawn from a fixed seed. LIME discretises continuous
# features using statistics of this background, and those bins appear in the
# reported feature labels. Seeding keeps the labels identical across restarts.
# ---------------------------------------------------------------------------
_LIME_BACKGROUND_SEED = 0
_lime_background = np.random.default_rng(_LIME_BACKGROUND_SEED).random(
    (100, len(selected_features))
)
_explainer = LimeTabularExplainer(
    training_data=_lime_background,
    feature_names=selected_features,
    class_names=["AMP", "Non-AMP"],  # index 0=AMP, index 1=Non-AMP
    mode="classification",
)


def _to_two_class_proba(preds):
    """Convert raw AMP-model output into ``[P(AMP), P(Non-AMP)]`` columns.

    Training labels were AMP=0, Non-AMP=1, so a single-unit (sigmoid) output is
    P(Non-AMP). It is reshaped to a column and expanded to ``[1 - p, p]``. Any
    other output shape is assumed to already be in ``[P(AMP), P(Non-AMP)]``
    column order and is returned unchanged.

    Shared by ``keras_predict_proba`` (LIME) and ``full_prediction`` (headline
    verdict) so both interpret the model's columns identically.
    """
    if preds.ndim == 1 or preds.shape[1] == 1:
        preds = preds.reshape(-1, 1)  # preds = P(Non-AMP)
        return np.hstack([1 - preds, preds])  # [P(AMP), P(Non-AMP)]
    return preds


def keras_predict_proba(X):
    """Return [P(AMP), P(Non-AMP)] for LIME.

    Training labels: AMP=0, Non-AMP=1. Sigmoid output = P(Non-AMP=1), so
    P(AMP) = 1 - sigmoid.
    Column order must match class_names: col0=P(AMP), col1=P(Non-AMP).
    """
    amp_model, _ = get_amp_model()
    return _to_two_class_proba(amp_model.predict(X, verbose=0))


def extract_features(sequence):
    """Compute the full 1325-feature pool, scale it, then select the 343 model features.

    Characters that are not one of the 20 standard amino-acid letters are
    dropped (after upper-casing) before featurisation.

    Returns:
        A float32 array of shape (1, 343) on success. On failure, an error
        string starting with "Error" (callers test ``isinstance(..., str)``).
    """
    sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
    if len(sequence) < 10:
        return "Error: Sequence too short."

    try:
        _, amp_scaler = get_amp_model()

        # Replicate the EXACT feature pool the scaler was fit on (1325 features).
        # Merge order must match training: CTD → dipeptide(420) → autocorr → pseudoAAC
        ctd_features = CTD.CalculateCTD(sequence)
        dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
        # Keep only the first 420 entries of the dipeptide result, in its own order.
        filtered_dipeptide = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
        auto_features = Autocorrelation.CalculateAutoTotal(sequence)
        pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)

        all_features_dict = {}
        all_features_dict.update(ctd_features)
        all_features_dict.update(filtered_dipeptide)
        all_features_dict.update(auto_features)
        all_features_dict.update(pseudo_features)

        # Build full-pool DataFrame (~1325 columns) and scale.
        # `.values` is passed positionally, so column ORDER (not names) must
        # match what the scaler was fit on.
        feature_df_all = pd.DataFrame([all_features_dict])
        scaled_array = amp_scaler.transform(feature_df_all.values)
        scaled_df = pd.DataFrame(scaled_array, columns=feature_df_all.columns)

        # Verify all 343 selected features are present
        missing = [f for f in selected_features if f not in scaled_df.columns]
        if missing:
            return f"Error: Missing features after scaling: {missing[:5]}..."

        # Select 343 features in model training order
        selected_df = scaled_df[selected_features].fillna(0)
        return selected_df.values.astype(np.float32)

    except Exception as e:
        return f"Error in feature extraction: {str(e)}"


# Locate mic_worker.py next to this file rather than relative to whatever the
# current working directory happens to be. Fall back to the CWD when __file__
# is undefined (e.g. when the code is run from a notebook or REPL).
try:
    _APP_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _APP_DIR = os.getcwd()
_MIC_WORKER_PATH = os.path.join(_APP_DIR, "mic_worker.py")


def predictmic(sequence):
    """Run MIC prediction in a SEPARATE process (mic_worker.py).

    Isolates PyTorch/ProtBert from TensorFlow to prevent SIGSEGV (exit 139).
    The worker prints a JSON dict on its last stdout line.

    Returns:
        The worker's JSON object as a dict. On any failure it returns
        ``{"Error": <message>}``. The result is always a dict, because
        ``full_prediction`` iterates it with ``.items()``.
    """
    sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    try:
        # List-form argv (no shell). The sequence is already restricted to
        # amino-acid letters above.
        proc = subprocess.run(
            [sys.executable, _MIC_WORKER_PATH, sequence],
            capture_output=True, text=True, timeout=900
        )
    except subprocess.TimeoutExpired:
        return {"Error": "MIC prediction timed out (model download may still be in progress; try again shortly)."}
    except Exception as e:
        return {"Error": f"Failed to start MIC worker: {str(e)}"}

    if proc.returncode != 0:
        tail = (proc.stderr or "").strip().splitlines()[-3:]
        return {"Error": f"MIC worker exited with code {proc.returncode}. {' '.join(tail)}"}

    out_lines = [ln for ln in (proc.stdout or "").splitlines() if ln.strip()]
    if not out_lines:
        return {"Error": "MIC worker produced no output."}

    try:
        result = json.loads(out_lines[-1])
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {"Error": f"Could not parse MIC worker output: {out_lines[-1][:200]}"}
    if not isinstance(result, dict):
        # Valid JSON but not an object (e.g. a list or number). Passing it on
        # would crash the `.items()` loop in full_prediction.
        return {"Error": f"Unexpected MIC worker output: {out_lines[-1][:200]}"}
    return result


def full_prediction(sequence):
    """Gradio callback: classify *sequence*, add MIC values for AMPs, explain with LIME.

    Returns a multi-line, human-readable string. If feature extraction fails,
    that step's error string is returned unchanged.
    """
    print("[CHECKPOINT] full_prediction called", flush=True)
    features = extract_features(sequence)
    if isinstance(features, str):
        print("[CHECKPOINT] extract_features error:", features, flush=True)
        return features
    print("[CHECKPOINT] features extracted OK, shape:", features.shape, flush=True)

    amp_model, _ = get_amp_model()
    raw_pred = amp_model.predict(features, verbose=0)
    print("[CHECKPOINT] raw sigmoid output:", raw_pred, flush=True)

    # Interpret the output exactly as keras_predict_proba does: column 1 is
    # P(Non-AMP), since training labels were AMP=0, Non-AMP=1. This way the
    # headline verdict and the LIME explanation can never disagree about
    # which column is which class.
    prob_non_amp = float(_to_two_class_proba(raw_pred)[0, 1])
    prob_amp = 1.0 - prob_non_amp

    if prob_amp >= 0.5:
        prediction = 0  # AMP (class 0)
        confidence = round(prob_amp * 100, 2)
    else:
        prediction = 1  # Non-AMP (class 1)
        confidence = round(prob_non_amp * 100, 2)

    amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
    result = f"Prediction: {amp_result}\n"
    result += f"Confidence: {confidence}%\n"

    if prediction == 0:  # AMP → run MIC
        print("[CHECKPOINT] AMP detected, starting MIC (ProtBert)...", flush=True)
        mic_values = predictmic(sequence)
        print("[CHECKPOINT] MIC done:", mic_values, flush=True)
        result += "\nPredicted MIC Values (μM):\n"
        for org, mic in mic_values.items():
            result += f"- {org}: {mic}\n"
    else:
        result += "\nMIC prediction skipped for Non-AMP sequences.\n"

    # ------------------------------------------------------------------
    # LIME — always explains class 0 (AMP) so weights are consistent:
    #   weight > 0 → feature pushes TOWARD AMP classification
    #   weight < 0 → feature pushes AWAY from AMP classification
    # This is meaningful for both AMPs and Non-AMPs:
    #   AMP sequence     → top positive weights explain why it's an AMP
    #   Non-AMP sequence → top negative weights explain why it's NOT an AMP
    # ------------------------------------------------------------------
    try:
        explanation = _explainer.explain_instance(
            data_row=features[0],
            predict_fn=keras_predict_proba,
            num_features=10,
            labels=(0,)  # always explain AMP class (index 0)
        )
        result += "\nTop Features Influencing AMP Classification:\n"
        for feat, weight in explanation.as_list(label=0):
            direction = "↑ AMP" if weight > 0 else "↓ AMP"
            result += f"- {feat}: {round(weight, 4)} ({direction})\n"
    except Exception as e:
        result += f"\nLIME explanation failed: {str(e)}\n"

    return result


iface = gr.Interface(
    fn=full_prediction,
    inputs=gr.Textbox(label="Enter Protein Sequence"),
    outputs=gr.Textbox(label="Results"),
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
)

iface.launch()