| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| import torch |
| import numpy as np |
| import os |
| import time |
| import joblib |
| from pathlib import Path |
| from datetime import datetime, timezone |
| from typing import Optional |
| from contextlib import asynccontextmanager |
| from dotenv import load_dotenv |
| import shutil |
| from huggingface_hub import hf_hub_download |
|
|
| |
| from transformers import BertTokenizer, BertModel |
|
|
| |
| from utils.model_classes import MHSA_GRU |
|
|
# Read variables from a local .env file into the process environment
# (HF_TOKEN is looked up later via os.getenv when downloading model files).
load_dotenv()
|
|
| |
|
|
| |
# Hugging Face Hub repository holding the custom-trained artifacts, and the
# file name of each artifact inside that repository.
MODEL_REPO = dict(
    repo_id="camlas/toxicity",
    files=dict(
        classifier="mhsa_gru_classifier.pth",
        scaler="scaler.pkl",
    ),
)
|
|
| |
# Pretrained feature extractor: the checkpoint name plus the tokenizer/model
# classes whose ``from_pretrained`` is used to load it.
TRANSFORMER_CONFIG = dict(
    model_name="Rostlab/prot_bert",
    model_type="ProtBERT",
    tokenizer_class=BertTokenizer,
    model_class=BertModel,
)
|
|
# Output labels, indexed by the predicted class id (0 or 1).
CLASSES = [
    "Non-Toxic",
    "Toxic",
]
# Version strings reported by the API metadata and the /health endpoint.
API_VERSION = "2.0.0-protbert"
MODEL_VERSION = "ProtBERT-MHSA-GRU-v1"
|
|
| |
# Process-wide registry of loaded components. Every slot starts as None and
# is filled in by the loader functions at startup.
models = dict.fromkeys(("transformer", "tokenizer", "classifier", "scaler"))
|
|
| |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
| |
|
|
def ensure_models_directory():
    """Make sure the local ``models`` directory exists and return its name.

    The directory is resolved relative to the current working directory; an
    already existing directory is left untouched.
    """
    directory_name = "models"
    target = Path(directory_name)
    target.mkdir(exist_ok=True)
    return directory_name
|
|
def download_model_from_hub(model_key: str) -> Optional[str]:
    """Fetch a custom-trained artifact (classifier or scaler) from the HF repo.

    A copy already present in the local models directory is reused without
    contacting the Hub. Otherwise the file is downloaded with
    ``hf_hub_download`` (authenticated with ``HF_TOKEN`` when set) and copied
    into the models directory.

    Args:
        model_key: Key into ``MODEL_REPO["files"]`` ("classifier" or "scaler").

    Returns:
        The local path of the artifact, or ``None`` if anything went wrong.
        Errors are printed, never raised.
    """
    try:
        filename = MODEL_REPO["files"][model_key]
        repo_id = MODEL_REPO["repo_id"]
        models_dir = ensure_models_directory()
        local_path = os.path.join(models_dir, filename)

        if os.path.exists(local_path):
            print(f"✅ Found {model_key} locally: {local_path}")
            return local_path

        print(f"📥 Downloading {model_key} from {repo_id}...")
        token = os.getenv("HF_TOKEN")

        if not token:
            print("⚠️ Warning: HF_TOKEN not found in .env. Private repos will fail.")

        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="model",
            token=token,
        )

        # Copy to a sibling ".part" file first and rename it into place only
        # once the copy is complete. Copying straight to local_path would let
        # an interrupted copy leave a truncated file that the existence check
        # above would accept as a valid cached model on the next start.
        partial_path = f"{local_path}.part"
        try:
            shutil.copy2(cached_path, partial_path)
            os.replace(partial_path, local_path)
        finally:
            # Only still present when the copy or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        return local_path
    except Exception as e:
        # Deliberately best-effort: callers treat None as "artifact missing".
        print(f"❌ Error downloading {model_key}: {e}")
        return None
|
|
def load_feature_extractor():
    """Load the tokenizer and transformer named in ``TRANSFORMER_CONFIG``.

    On success the model is moved to ``device``, switched to eval mode, and
    both objects are stored in the ``models`` registry.

    Returns:
        True when both objects loaded, False if any step raised (the error is
        printed and the registry is left unchanged).
    """
    print(f"🔄 Loading Transformer: {TRANSFORMER_CONFIG['model_name']}...")
    try:
        checkpoint_name = TRANSFORMER_CONFIG['model_name']
        tokenizer_cls = TRANSFORMER_CONFIG['tokenizer_class']
        model_cls = TRANSFORMER_CONFIG['model_class']

        # do_lower_case=False keeps the upper-case residue letters intact.
        loaded_tokenizer = tokenizer_cls.from_pretrained(
            checkpoint_name,
            do_lower_case=False,
        )
        encoder = model_cls.from_pretrained(checkpoint_name)
        encoder.to(device)
        encoder.eval()

        # Publish both only after every step above has succeeded.
        models.update(tokenizer=loaded_tokenizer, transformer=encoder)
        print("✅ ProtBERT Transformer loaded successfully")
        return True
    except Exception as e:
        print(f"❌ Error loading Transformer: {e}")
        return False
|
|
def load_classifier_and_scaler():
    """Load the feature scaler and the MHSA-GRU classifier into ``models``.

    Each artifact is obtained through ``download_model_from_hub``. A missing
    scaler does not stop the classifier from being attempted, so the registry
    can end up partially filled.

    Returns:
        True only when both the scaler and the classifier are present in the
        registry afterwards; False otherwise or if any step raised.
    """
    try:
        scaler_file = download_model_from_hub("scaler")
        if scaler_file:
            # NOTE(review): joblib.load unpickles the file; only safe as long
            # as the source repository is trusted.
            models["scaler"] = joblib.load(scaler_file)
            print("✅ Scaler loaded")

        weights_file = download_model_from_hub("classifier")
        if weights_file:
            # Feature width fed to the classifier — presumably ProtBERT's
            # hidden size; TODO confirm against the training setup.
            input_dim = 1024

            print(f"ℹ️ Initializing MHSA_GRU with input_dim={input_dim} (ProtBERT)")

            network = MHSA_GRU(
                input_dim=input_dim,
                hidden_dim=256,
                num_heads=8,
                num_gru_layers=2,
                dropout=0.3,
            )

            # NOTE(review): torch.load may unpickle arbitrary objects depending
            # on the torch version's weights_only default — confirm.
            network.load_state_dict(torch.load(weights_file, map_location=device))
            network.to(device)
            network.eval()
            models["classifier"] = network
            print("✅ Classifier loaded")

        return not (models["scaler"] is None or models["classifier"] is None)
    except Exception as e:
        print(f"❌ Error loading custom models: {e}")
        return False
|
|
def preprocess_sequence(sequence: str):
    """Format a raw peptide sequence the way ProtBERT expects it.

    ProtBERT expects one space between amino acids: 'M K T A Y...'.
    All whitespace in the input (leading/trailing, line breaks from wrapped
    FASTA-style text, tabs, or spaces in an already space-separated sequence)
    is removed first, so the result always has exactly one space between
    residues.

    Args:
        sequence: Raw amino-acid sequence, in any letter case.

    Returns:
        The upper-cased residues joined by single spaces; an empty string
        when the input contains no non-whitespace characters.
    """
    # str.split() with no argument splits on any run of whitespace, so
    # re-joining the pieces drops every whitespace character in one pass.
    residues = "".join(sequence.split()).upper()
    return " ".join(residues)
|
|
def extract_features(sequence: str):
    """Embed one sequence with the loaded transformer.

    The sequence is formatted by ``preprocess_sequence``, tokenized (truncated
    to at most 512 tokens), and passed through the transformer without
    gradient tracking. The last-layer hidden state at token position 0 (the
    [CLS] position) is used as the feature vector.

    Returns:
        A NumPy array on the CPU holding that embedding for the single input
        — presumably of shape (1, hidden_size); TODO confirm.
    """
    tokenizer = models["tokenizer"]
    encoder = models["transformer"]

    spaced = preprocess_sequence(sequence)

    encoded = tokenizer(
        [spaced],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Move every tensor of the batch to the device the model lives on.
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = encoder(**batch)

    cls_embedding = outputs.last_hidden_state[:, 0, :]
    return cls_embedding.cpu().numpy()
|
|
| |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load all models before serving requests.

    Load failures are only reported on stdout; the application still starts.
    """
    print("🚀 Starting Toxicity Detection API (ProtBERT Edition)...")

    # Early hint for a missing classifier definition module (path is relative
    # to the current working directory).
    classifier_module = "utils/model_classes.py"
    if not os.path.exists(classifier_module):
        print("❌ Error: utils/model_classes.py not found. Please create it.")

    # Both loaders always run, even if the first one fails.
    transformer_ready = load_feature_extractor()
    custom_ready = load_classifier_and_scaler()

    if not transformer_ready or not custom_ready:
        print("⚠️ Warning: Not all models loaded successfully")

    yield

    print("🔄 Shutting down API...")
|
|
# The ASGI application. The lifespan handler loads the models at startup;
# the remaining arguments are metadata for the generated API documentation.
app = FastAPI(
    lifespan=lifespan,
    title="Peptide Toxicity Detection API",
    description="API using ProtBERT features + MHSA-GRU classifier",
    version=API_VERSION,
)
|
|
| |
|
|
# Request body for POST /predict.
class SequenceRequest(BaseModel):
    # Raw peptide sequence to classify.
    sequence: str
| |
# Response body returned by POST /predict.
class PredictionResponse(BaseModel):
    # The submitted sequence, cut to 20 characters plus "..." when longer.
    sequence_preview: str
    # True when the predicted class is the toxic one.
    is_toxic: bool
    # Human-readable class name taken from CLASSES.
    label: str
    # Raw classifier output used for the 0.5 decision threshold.
    score: float
    # "High", "Medium" or "Low".
    confidence_level: str
    # Description of the model pipeline that produced the prediction.
    model_used: str
    # Server-side handling time in milliseconds.
    processing_time_ms: float
    # UTC time of the prediction in ISO 8601 format.
    timestamp: str
|
|
| |
|
|
# Landing endpoint: confirms the service is up and points to /predict.
@app.get("/")
async def root():
    banner = "Toxicity Detection API is running. Use /predict to analyze sequences."
    return {"message": banner}
|
|
# Health endpoint: reports which registry slots are filled, the compute
# device, and model metadata. Status is "healthy" only when every slot of
# the models registry is non-None, otherwise "degraded".
@app.get("/health")
async def health_check():
    load_state = {name: component is not None for name, component in models.items()}
    everything_loaded = all(load_state.values())
    return {
        "status": "healthy" if everything_loaded else "degraded",
        "models_loaded": load_state,
        "device": str(device),
        "model_version": MODEL_VERSION,
        "feature_extractor": TRANSFORMER_CONFIG["model_name"],
    }
|
|
# Prediction endpoint: embeds the submitted sequence, scales the features,
# runs the classifier and reports label, score and a coarse confidence level.
#
# Responses:
#   503 - at least one slot of the models registry is still None.
#   400 - the sequence is empty or contains only whitespace.
#   500 - any error raised while extracting features or classifying.
#
# NOTE(review): feature extraction and classification are called
# synchronously inside this coroutine, so the event loop is blocked while
# they run — confirm this is acceptable for the expected load.
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: SequenceRequest):
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # (or go negative) when the system clock is adjusted mid-request.
    started = time.perf_counter()

    # Same "is not None" readiness rule as /health, instead of relying on the
    # truthiness of the loaded objects.
    if any(component is None for component in models.values()):
        raise HTTPException(status_code=503, detail="Models are not fully initialized.")

    # A whitespace-only sequence carries no residues; reject it like an
    # empty one instead of scoring an empty input.
    if not request.sequence or not request.sequence.strip():
        raise HTTPException(status_code=400, detail="Empty sequence provided.")

    try:
        raw_features = extract_features(request.sequence)
        scaled_features = models["scaler"].transform(raw_features)
        features_tensor = torch.FloatTensor(scaled_features).to(device)

        with torch.no_grad():
            # .item() requires the classifier to return exactly one value.
            probability = models["classifier"](features_tensor).item()

        prediction_class = 1 if probability > 0.5 else 0
        predicted_label = CLASSES[prediction_class]

        # Distance from the 0.5 decision boundary, rescaled to [0, 1] for
        # outputs within [0, 1].
        confidence_score = abs(probability - 0.5) * 2
        if confidence_score > 0.8:
            confidence_level = "High"
        elif confidence_score > 0.5:
            confidence_level = "Medium"
        else:
            confidence_level = "Low"

        sequence_preview = request.sequence
        if len(sequence_preview) > 20:
            sequence_preview = sequence_preview[:20] + "..."

        processing_time = round((time.perf_counter() - started) * 1000, 2)

        return PredictionResponse(
            sequence_preview=sequence_preview,
            is_toxic=(prediction_class == 1),
            label=predicted_label,
            score=probability,
            confidence_level=confidence_level,
            model_used="ProtBERT + MHSA-GRU",
            processing_time_ms=processing_time,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )

    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
|
|
if __name__ == "__main__":
    # Running the module as a script starts a uvicorn server on all
    # interfaces, port 8000. uvicorn is imported lazily so that merely
    # importing this module does not require it.
    import uvicorn

    uvicorn.run(app=app, port=8000, host="0.0.0.0")