File size: 2,614 Bytes
4466c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c9dd7
 
 
6dbd5a0
a4c9dd7
 
 
 
4466c5e
 
5100cb5
 
 
 
 
 
4466c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import pathlib
import numpy as np
import torch
import streamlit as st
from torch import nn

# Model Definition
class FastMLP(nn.Module):
    """
    Small feed-forward network that emits one unnormalised score per input row.

    Architecture: Linear(input_dim->512) -> ReLU -> Dropout(0.3)
    -> Linear(512->128) -> ReLU -> Linear(128->1), with no final activation.

    The submodule order inside ``self.layers`` fixes the state_dict keys
    (``layers.0``, ``layers.3``, ``layers.5``); changing it would stop
    previously saved weights from loading.
    """

    def __init__(self, input_dim=1024):
        super().__init__()
        first_hidden, second_hidden = 512, 128
        self.layers = nn.Sequential(
            nn.Linear(input_dim, first_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            # One raw output for binary classification; callers apply sigmoid.
            nn.Linear(second_hidden, 1),
        )

    def forward(self, x):
        """Map inputs of shape (..., input_dim) to raw scores of shape (..., 1)."""
        return self.layers(x)

# Model Loader
@st.cache_resource
def load_model():
    """
    Locate the trained weights file and return a ready-to-use FastMLP.

    Candidate locations are resolved relative to this source file rather than
    the process working directory, and are tried in order; the first one that
    exists is used.

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned model
    object instead of reloading the weights on every rerun.

    Returns:
        FastMLP: model with weights loaded onto the CPU, switched to eval mode
        (disables the Dropout layer).

    Raises:
        FileNotFoundError: if ``ampMLModel.pt`` exists at none of the candidate
            paths; the message lists every path that was checked.
    """
    filename = "ampMLModel.pt"

    # Always resolve relative to the StreamlitApp folder (two levels above this
    # file), not the process CWD.
    streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
    repo_root = streamlitapp_dir.parent

    candidates = [
        repo_root / "MLModels" / filename,
        repo_root / "models" / filename,
        streamlitapp_dir / "models" / filename,
    ]
    model_path = next((p for p in candidates if p.exists()), None)

    if model_path is None:
        # Derive the listing from ``candidates`` so the message can never
        # disagree with the paths that were actually searched.
        searched = "".join(f"- {p}\n" for p in candidates)
        raise FileNotFoundError(
            f"Model file '{filename}' not found in any of:\n{searched}"
        )

    # Build model and load weights.
    model = FastMLP(input_dim=1024)
    # NOTE(review): torch.load unpickles the file, which can execute code.
    # Only ship trusted weights; since this is a plain state_dict, passing
    # weights_only=True (torch >= 1.13) would be a safe hardening step.
    state_dict = torch.load(str(model_path), map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Sequence Encoder
def encode_sequence(seq, max_len=51, input_dim=1024):
    """
    Convert an amino acid sequence into a flat one-hot vector for FastMLP.

    Each of the first ``max_len`` residues becomes a row of 20 columns, one per
    standard amino acid in the order "ACDEFGHIKLMNPQRSTVWY". The resulting
    ``max_len x 20`` matrix is flattened row by row. It is then zero-padded or
    truncated so the result has exactly ``input_dim`` entries, matching the
    model's input layer. With the defaults, 51 * 20 = 1020 encoded values are
    followed by 4 zeros.

    Before encoding, the sequence is upper-cased and all whitespace is removed.
    As a result, pasted input containing spaces, line breaks, or lower-case
    letters encodes the same as the clean upper-case sequence. Any other
    character (e.g. "X") still occupies a position but leaves its row all
    zeros.

    Args:
        seq: amino acid sequence string.
        max_len: number of residues encoded; longer sequences are truncated,
            shorter ones leave their trailing rows zero.
        input_dim: exact length of the returned vector (model input size).

    Returns:
        numpy.ndarray: 1-D float vector of length ``input_dim``.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    aa_to_idx = {aa: i for i, aa in enumerate(amino_acids)}

    # Whitespace is not a residue: drop it so it cannot shift later residues
    # into the wrong row. Upper-case because the lookup table is upper-case.
    cleaned = "".join(seq.split()).upper()

    one_hot = np.zeros((max_len, len(amino_acids)))  # max_len x 20
    for i, aa in enumerate(cleaned[:max_len]):
        idx = aa_to_idx.get(aa)
        if idx is not None:
            one_hot[i, idx] = 1

    # Enforce the documented contract in both directions: truncate when
    # max_len * 20 exceeds input_dim, pad with zeros when it falls short.
    flat = one_hot.flatten()[:input_dim]
    if flat.size < input_dim:
        flat = np.pad(flat, (0, input_dim - flat.size))

    return flat

# Prediction Function
def predict_amp(sequence, model):
    """
    Classify one amino acid sequence with an already-loaded model.

    The sequence is encoded with ``encode_sequence`` and passed as a batch of
    one. The model output is squashed with a sigmoid, without tracking
    gradients.

    Args:
        sequence: amino acid sequence string.
        model: callable module mapping a (1, features) float32 tensor to a
            single-element output (e.g. the FastMLP from load_model).

    Returns:
        tuple[str, float]: ("AMP" if the probability is >= 0.5, else
        "Non-AMP", probability rounded to 3 decimal places). The label is
        decided on the unrounded probability.
    """
    features = encode_sequence(sequence)
    batch = torch.tensor(features, dtype=torch.float32)[None, :]

    with torch.no_grad():
        probability = model(batch).sigmoid().item()

    if probability >= 0.5:
        return "AMP", round(probability, 3)
    return "Non-AMP", round(probability, 3)