File size: 2,614 Bytes
4466c5e a4c9dd7 6dbd5a0 a4c9dd7 4466c5e 5100cb5 4466c5e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 | import pathlib
import numpy as np
import torch
import streamlit as st
from torch import nn
# Model Definition
class FastMLP(nn.Module):
    """Small feed-forward network that emits one logit per input row.

    The stack is ``input_dim -> 512 -> 128 -> 1`` with ReLU activations and
    dropout (p=0.3) after the first hidden layer.

    Args:
        input_dim: Number of features in each input row (default 1024).
    """

    def __init__(self, input_dim=1024):
        super().__init__()
        # The order of the stages fixes their indices inside the Sequential
        # container, and therefore the ``layers.<n>.*`` names in the
        # state dict; keep it stable so saved weights continue to load.
        stages = [
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # single output for binary classification
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through the stacked layers and return the raw output."""
        logits = self.layers(x)
        return logits
# Model Loader
@st.cache_resource
def load_model():
    """Find ``ampMLModel.pt`` on disk, load it into a ``FastMLP`` and return it.

    The weight file is looked up in three fixed locations, in order; the
    first one that exists is used. All locations are derived from this
    file's own path, so the process working directory does not matter.

    Returns:
        The ``FastMLP`` instance with weights loaded on CPU, in eval mode.

    Raises:
        FileNotFoundError: If none of the candidate locations exists.
    """
    # Two levels above this file is the app folder; its parent is the repo root.
    app_dir = pathlib.Path(__file__).resolve().parent.parent
    root_dir = app_dir.parent
    weight_file = "ampMLModel.pt"
    search_paths = (
        root_dir / "MLModels" / weight_file,
        root_dir / "models" / weight_file,
        app_dir / "models" / weight_file,
    )

    for location in search_paths:
        if location.exists():
            checkpoint = location
            break
    else:
        # Nothing matched: report every location that was tried.
        listing = "".join(f"- {location}\n" for location in search_paths)
        raise FileNotFoundError(
            "Model file 'ampMLModel.pt' not found in any of:\n" + listing
        )

    # Build the network, then restore its weights from the checkpoint.
    network = FastMLP(input_dim=1024)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(str(checkpoint), map_location="cpu")
    network.load_state_dict(state)
    network.eval()
    return network
# Sequence Encoder
def encode_sequence(seq, max_len=51, input_dim=1024):
    """Encode an amino-acid sequence as a flat one-hot vector of fixed length.

    Each of the first ``max_len`` residues occupies a slot of 20 positions
    (one per letter of ``ACDEFGHIKLMNPQRSTVWY``); residue ``i`` with alphabet
    index ``k`` sets element ``i * 20 + k`` to 1. Characters outside the
    alphabet (including lowercase letters) leave their slot all zeros.

    The result always has exactly ``input_dim`` elements: it is zero-padded
    when ``max_len * 20 < input_dim`` and truncated when it would be longer,
    so the vector length never depends on ``max_len``.

    Args:
        seq: Amino-acid sequence string.
        max_len: Maximum number of leading residues to encode.
        input_dim: Length of the returned vector.

    Returns:
        1-D ``numpy.ndarray`` (float64) of length ``input_dim``.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    aa_to_idx = {aa: i for i, aa in enumerate(amino_acids)}
    width = len(amino_acids)

    # Write straight into the final buffer: equivalent to building a
    # (max_len, 20) one-hot matrix, flattening it row-major, then
    # padding/truncating to input_dim.
    flat = np.zeros(input_dim)
    for i, aa in enumerate(seq[:max_len]):
        idx = aa_to_idx.get(aa)
        if idx is None:
            continue  # unknown residue: leave its slot as zeros
        pos = i * width + idx
        if pos < input_dim:  # drop anything past the fixed vector length
            flat[pos] = 1
    return flat
# Prediction Function
def predict_amp(sequence, model):
    """Classify an amino-acid sequence as "AMP" or "Non-AMP".

    Args:
        sequence: Amino-acid sequence string; it is encoded with
            ``encode_sequence`` before being fed to the model.
        model: Loaded model. It is called on a float32 tensor with a
            leading batch axis of size 1 and must return a single-element
            tensor, which is treated as a logit.

    Returns:
        Tuple ``(label, probability)``: ``label`` is "AMP" when the sigmoid
        of the model output is at least 0.5 and "Non-AMP" otherwise;
        ``probability`` is that sigmoid value rounded to 3 decimals.
    """
    features = encode_sequence(sequence)
    batch = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # add batch axis

    # Inference only: no gradients are needed.
    with torch.no_grad():
        score = torch.sigmoid(model(batch)).item()

    if score >= 0.5:
        verdict = "AMP"
    else:
        verdict = "Non-AMP"
    return verdict, round(score, 3)