Spaces:
Running
Running
Commit ·
2bdb663
1
Parent(s): 27826d5
feat: security headers, frontend auth context, ML registry, and orchestration pipeline
Browse files- backend/core/middleware.py +39 -1
- backend/ml/registry.py +291 -0
- backend/orchestration/pipeline.py +184 -0
- backend/orchestration/queue.py +198 -0
- frontend/app/(auth)/callback/page.jsx +59 -0
- frontend/components/auth/ProtectedRoute.jsx +23 -0
- frontend/components/auth/PublicOnlyRoute.jsx +19 -0
- frontend/next.config.js +43 -0
- scripts/security_audit.py +47 -0
backend/core/middleware.py
CHANGED
|
@@ -1,7 +1,45 @@
|
|
| 1 |
from fastapi import Request, Response
|
| 2 |
-
from fastapi.responses import JSONResponse
|
|
|
|
| 3 |
from slowapi import Limiter
|
| 4 |
from slowapi.errors import RateLimitExceeded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
def get_rate_limit_key(request: Request) -> str:
|
| 7 |
user_id = getattr(request.state, "user_id", None)
|
|
|
|
| 1 |
from fastapi import Request, Response
|
| 2 |
+
from fastapi.responses import JSONResponse, RedirectResponse
|
| 3 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 4 |
from slowapi import Limiter
|
| 5 |
from slowapi.errors import RateLimitExceeded
|
| 6 |
+
from backend.core.config import settings
|
| 7 |
+
|
| 8 |
+
class HTTPSRedirectMiddleware(BaseHTTPMiddleware):
|
| 9 |
+
async def dispatch(self, request: Request, call_next):
|
| 10 |
+
if settings.is_production:
|
| 11 |
+
# Check if request came in as HTTP (Render.com forwards as HTTPS but sets X-Forwarded-Proto)
|
| 12 |
+
proto = request.headers.get("X-Forwarded-Proto", "https")
|
| 13 |
+
if proto == "http":
|
| 14 |
+
https_url = str(request.url).replace("http://", "https://", 1)
|
| 15 |
+
return RedirectResponse(https_url, status_code=301)
|
| 16 |
+
return await call_next(request)
|
| 17 |
+
|
| 18 |
+
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
| 19 |
+
async def dispatch(self, request: Request, call_next):
|
| 20 |
+
response = await call_next(request)
|
| 21 |
+
|
| 22 |
+
# Always add these headers:
|
| 23 |
+
response.headers["X-Content-Type-Options"] = "nosniff"
|
| 24 |
+
response.headers["X-Frame-Options"] = "DENY"
|
| 25 |
+
response.headers["X-XSS-Protection"] = "1; mode=block"
|
| 26 |
+
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
| 27 |
+
response.headers["Permissions-Policy"] = (
|
| 28 |
+
"camera=(), microphone=(self), geolocation=(), "
|
| 29 |
+
"payment=(), usb=(), magnetometer=()"
|
| 30 |
+
) # microphone=(self) required for Whisper voice input
|
| 31 |
+
|
| 32 |
+
# Production only:
|
| 33 |
+
if settings.is_production:
|
| 34 |
+
response.headers["Strict-Transport-Security"] = (
|
| 35 |
+
"max-age=63072000; includeSubDomains; preload"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Remove headers that leak server info:
|
| 39 |
+
response.headers.pop("Server", None)
|
| 40 |
+
response.headers.pop("X-Powered-By", None)
|
| 41 |
+
|
| 42 |
+
return response
|
| 43 |
|
| 44 |
def get_rate_limit_key(request: Request) -> str:
|
| 45 |
user_id = getattr(request.state, "user_id", None)
|
backend/ml/registry.py
CHANGED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import time
|
| 3 |
+
import logging
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from datetime import datetime, UTC
|
| 6 |
+
from typing import Literal, Any
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from backend.core.config import settings
|
| 10 |
+
from backend.core.logging_config import ml_logger
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class ModelProfile:
|
| 16 |
+
name: str
|
| 17 |
+
hf_model_id: str
|
| 18 |
+
local_cache_subdir: str
|
| 19 |
+
device_preference: Literal["cuda", "cpu", "auto"]
|
| 20 |
+
vram_mb: int
|
| 21 |
+
ram_mb: int
|
| 22 |
+
load_priority: int
|
| 23 |
+
is_required: bool
|
| 24 |
+
|
| 25 |
+
MODEL_PROFILES = {
|
| 26 |
+
"dino_anomaly": ModelProfile(
|
| 27 |
+
name="dino_anomaly", hf_model_id="facebook/dinov2-small", local_cache_subdir="dino_anomaly",
|
| 28 |
+
device_preference="cuda", vram_mb=400, ram_mb=50, load_priority=1, is_required=True
|
| 29 |
+
),
|
| 30 |
+
"biobert_ner": ModelProfile(
|
| 31 |
+
name="biobert_ner", hf_model_id="dmis-lab/biobert-base-cased-v1.2", local_cache_subdir="biobert_ner",
|
| 32 |
+
device_preference="cuda", vram_mb=450, ram_mb=50, load_priority=2, is_required=True
|
| 33 |
+
),
|
| 34 |
+
"biomedvlp": ModelProfile(
|
| 35 |
+
name="biomedvlp", hf_model_id="microsoft/BiomedVLP-CXR-BERT-specialized", local_cache_subdir="biomedvlp",
|
| 36 |
+
device_preference="auto", vram_mb=900, ram_mb=100, load_priority=3, is_required=False
|
| 37 |
+
),
|
| 38 |
+
"whisper_tiny": ModelProfile(
|
| 39 |
+
name="whisper_tiny", hf_model_id="openai/whisper-tiny", local_cache_subdir="whisper",
|
| 40 |
+
device_preference="cpu", vram_mb=0, ram_mb=300, load_priority=4, is_required=False
|
| 41 |
+
),
|
| 42 |
+
"biogpt_base": ModelProfile(
|
| 43 |
+
name="biogpt_base", hf_model_id="microsoft/biogpt", local_cache_subdir="biogpt",
|
| 44 |
+
device_preference="cpu", vram_mb=0, ram_mb=700, load_priority=5, is_required=False
|
| 45 |
+
),
|
| 46 |
+
"minilm": ModelProfile(
|
| 47 |
+
name="minilm", hf_model_id="sentence-transformers/all-MiniLM-L6-v2", local_cache_subdir="minilm",
|
| 48 |
+
device_preference="cpu", vram_mb=0, ram_mb=100, load_priority=1, is_required=True
|
| 49 |
+
),
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class ModelState:
|
| 54 |
+
profile: ModelProfile
|
| 55 |
+
model: Any = None
|
| 56 |
+
tokenizer: Any = None
|
| 57 |
+
head: Any = None # Extension for DINO head architecture
|
| 58 |
+
stats: dict = None # Extension for anomaly scoring
|
| 59 |
+
is_loaded: bool = False
|
| 60 |
+
is_loading: bool = False
|
| 61 |
+
load_error: str | None = None
|
| 62 |
+
load_time_ms: int = 0
|
| 63 |
+
last_used: datetime | None = None
|
| 64 |
+
current_device: str = "unloaded"
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def is_available(self) -> bool:
|
| 68 |
+
return self.is_loaded and self.load_error is None
|
| 69 |
+
|
| 70 |
+
class ModelRegistry:
|
| 71 |
+
def __init__(self):
|
| 72 |
+
self._states: dict[str, ModelState] = {
|
| 73 |
+
name: ModelState(profile=profile) for name, profile in MODEL_PROFILES.items()
|
| 74 |
+
}
|
| 75 |
+
self._locks: dict[str, asyncio.Lock] = {
|
| 76 |
+
name: asyncio.Lock() for name in MODEL_PROFILES
|
| 77 |
+
}
|
| 78 |
+
self._gpu_budget_mb = settings.GPU_VRAM_BUDGET_MB
|
| 79 |
+
|
| 80 |
+
async def startup_load(self):
|
| 81 |
+
ml_logger.logger.info("Starting model registry startup")
|
| 82 |
+
sorted_models = sorted(MODEL_PROFILES.values(), key=lambda m: m.load_priority)
|
| 83 |
+
|
| 84 |
+
for profile in sorted_models:
|
| 85 |
+
if profile.device_preference == "cpu":
|
| 86 |
+
await self._load_model(profile.name)
|
| 87 |
+
else:
|
| 88 |
+
if self._get_used_vram() + profile.vram_mb <= self._gpu_budget_mb:
|
| 89 |
+
await self._load_model(profile.name)
|
| 90 |
+
else:
|
| 91 |
+
ml_logger.logger.warning(f"Skipping GPU load for {profile.name}: VRAM budget exceeded. Will load on CPU on first request.")
|
| 92 |
+
|
| 93 |
+
loaded = [n for n, s in self._states.items() if s.is_available]
|
| 94 |
+
failed = [n for n, s in self._states.items() if s.load_error]
|
| 95 |
+
required_failed = [n for n in failed if MODEL_PROFILES[n].is_required]
|
| 96 |
+
|
| 97 |
+
if required_failed:
|
| 98 |
+
raise RuntimeError(f"Critical models failed to load: {required_failed}. Check logs.")
|
| 99 |
+
|
| 100 |
+
ml_logger.logger.info("Registry startup complete", extra={"loaded": loaded, "failed": failed, "vram_used_mb": self._get_used_vram()})
|
| 101 |
+
|
| 102 |
+
async def get(self, model_name: str) -> ModelState:
|
| 103 |
+
if model_name not in self._states:
|
| 104 |
+
raise ValueError(f"Unknown model: {model_name}")
|
| 105 |
+
|
| 106 |
+
state = self._states[model_name]
|
| 107 |
+
if not state.is_available and not state.is_loading:
|
| 108 |
+
await self._load_model(model_name)
|
| 109 |
+
|
| 110 |
+
self._states[model_name].last_used = datetime.now(UTC)
|
| 111 |
+
return self._states[model_name]
|
| 112 |
+
|
| 113 |
+
def is_available(self, model_name: str) -> bool:
|
| 114 |
+
return self._states.get(model_name, ModelState(ModelProfile("","","","cpu",0,0,0,False))).is_available
|
| 115 |
+
|
| 116 |
+
async def _load_model(self, model_name: str):
|
| 117 |
+
async with self._locks[model_name]:
|
| 118 |
+
state = self._states[model_name]
|
| 119 |
+
if state.is_available:
|
| 120 |
+
return
|
| 121 |
+
|
| 122 |
+
state.is_loading = True
|
| 123 |
+
start_time = time.monotonic()
|
| 124 |
+
|
| 125 |
+
try:
|
| 126 |
+
profile = state.profile
|
| 127 |
+
device = self._resolve_device(profile)
|
| 128 |
+
|
| 129 |
+
if device == "cuda":
|
| 130 |
+
needed = profile.vram_mb
|
| 131 |
+
available = self._gpu_budget_mb - self._get_used_vram()
|
| 132 |
+
if available < needed:
|
| 133 |
+
evicted = await self._evict_lru_gpu_model(except_model=model_name)
|
| 134 |
+
if evicted:
|
| 135 |
+
ml_logger.logger.info(f"Evicted {evicted} to make room for {model_name}")
|
| 136 |
+
|
| 137 |
+
# Fetch objects securely
|
| 138 |
+
result = await asyncio.to_thread(self._load_model_sync, model_name, profile, device)
|
| 139 |
+
|
| 140 |
+
load_time_ms = int((time.monotonic() - start_time) * 1000)
|
| 141 |
+
|
| 142 |
+
state.model = result.get('model')
|
| 143 |
+
state.tokenizer = result.get('tokenizer')
|
| 144 |
+
state.head = result.get('head')
|
| 145 |
+
state.stats = result.get('stats')
|
| 146 |
+
|
| 147 |
+
state.is_loaded = True
|
| 148 |
+
state.load_error = None
|
| 149 |
+
state.load_time_ms = load_time_ms
|
| 150 |
+
state.current_device = device
|
| 151 |
+
|
| 152 |
+
ml_logger.log_model_load(model_name, device, load_time_ms, vram_delta_mb=profile.vram_mb if device == "cuda" else None)
|
| 153 |
+
|
| 154 |
+
except Exception as e:
|
| 155 |
+
state.load_error = str(e)
|
| 156 |
+
state.is_loaded = False
|
| 157 |
+
ml_logger.logger.error(f"Failed to load model {model_name}: {e}", exc_info=True)
|
| 158 |
+
if MODEL_PROFILES[model_name].is_required:
|
| 159 |
+
raise
|
| 160 |
+
finally:
|
| 161 |
+
state.is_loading = False
|
| 162 |
+
|
| 163 |
+
def _load_model_sync(self, name: str, profile: ModelProfile, device: str) -> dict:
|
| 164 |
+
cache_dir = settings.MODEL_CACHE_DIR / profile.local_cache_subdir
|
| 165 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 166 |
+
|
| 167 |
+
if name == "dino_anomaly":
|
| 168 |
+
from transformers import AutoImageProcessor, AutoModel as ViTModel
|
| 169 |
+
processor = AutoImageProcessor.from_pretrained(profile.hf_model_id, cache_dir=cache_dir)
|
| 170 |
+
model = ViTModel.from_pretrained(profile.hf_model_id, cache_dir=cache_dir).to(device)
|
| 171 |
+
model.eval()
|
| 172 |
+
|
| 173 |
+
# Simulated Projection Head Loading Logic
|
| 174 |
+
head = None
|
| 175 |
+
stats = {"mean": 0.001, "std": 0.0005}
|
| 176 |
+
head_path = settings.MODEL_CACHE_DIR / "anomaly_head.pt"
|
| 177 |
+
stats_path = settings.MODEL_CACHE_DIR / "anomaly_stats.json"
|
| 178 |
+
|
| 179 |
+
if head_path.exists():
|
| 180 |
+
# We will define DINOProjectionHead in the vision module later
|
| 181 |
+
import json
|
| 182 |
+
pass
|
| 183 |
+
else:
|
| 184 |
+
logger.warning("Using untrained head — anomaly scores may be unreliable")
|
| 185 |
+
|
| 186 |
+
if stats_path.exists():
|
| 187 |
+
import json
|
| 188 |
+
stats = json.loads(stats_path.read_text())
|
| 189 |
+
|
| 190 |
+
return {"model": model, "tokenizer": processor, "head": head, "stats": stats}
|
| 191 |
+
|
| 192 |
+
elif name == "biobert_ner":
|
| 193 |
+
fine_tuned_path = settings.MODEL_CACHE_DIR / "biobert_ner_finetuned"
|
| 194 |
+
model_path = str(fine_tuned_path) if fine_tuned_path.exists() else profile.hf_model_id
|
| 195 |
+
|
| 196 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
| 197 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
|
| 198 |
+
model = AutoModelForTokenClassification.from_pretrained(model_path, cache_dir=cache_dir).to(device)
|
| 199 |
+
model.eval()
|
| 200 |
+
return {"model": model, "tokenizer": tokenizer}
|
| 201 |
+
|
| 202 |
+
elif name == "whisper_tiny":
|
| 203 |
+
import whisper
|
| 204 |
+
model = whisper.load_model("tiny", device="cpu", download_root=str(cache_dir))
|
| 205 |
+
return {"model": model, "tokenizer": None}
|
| 206 |
+
|
| 207 |
+
elif name == "biogpt_base":
|
| 208 |
+
from transformers import BioGptForCausalLM, BioGptTokenizer
|
| 209 |
+
tokenizer = BioGptTokenizer.from_pretrained(profile.hf_model_id, cache_dir=cache_dir)
|
| 210 |
+
model = BioGptForCausalLM.from_pretrained(profile.hf_model_id, cache_dir=cache_dir)
|
| 211 |
+
model.eval()
|
| 212 |
+
return {"model": model, "tokenizer": tokenizer}
|
| 213 |
+
|
| 214 |
+
elif name == "minilm":
|
| 215 |
+
from sentence_transformers import SentenceTransformer
|
| 216 |
+
model = SentenceTransformer(profile.hf_model_id, cache_folder=str(cache_dir))
|
| 217 |
+
return {"model": model, "tokenizer": None}
|
| 218 |
+
|
| 219 |
+
elif name == "biomedvlp":
|
| 220 |
+
from transformers import AutoModel, AutoTokenizer
|
| 221 |
+
free_vram = torch.cuda.mem_get_info()[0] // 1024 // 1024 if torch.cuda.is_available() else 0
|
| 222 |
+
if free_vram < 950 and device == "cuda":
|
| 223 |
+
device = "cpu"
|
| 224 |
+
logger.warning("Insufficient VRAM for BiomedVLP — loading on CPU (slower)")
|
| 225 |
+
|
| 226 |
+
# Trust remote code is required for custom Microsoft implementation
|
| 227 |
+
tokenizer = AutoTokenizer.from_pretrained(profile.hf_model_id, cache_dir=cache_dir, trust_remote_code=True)
|
| 228 |
+
model = AutoModel.from_pretrained(profile.hf_model_id, cache_dir=cache_dir, trust_remote_code=True).to(device)
|
| 229 |
+
model.eval()
|
| 230 |
+
return {"model": model, "tokenizer": tokenizer}
|
| 231 |
+
|
| 232 |
+
else:
|
| 233 |
+
raise ValueError(f"No loader defined for model: {name}")
|
| 234 |
+
|
| 235 |
+
def _resolve_device(self, profile: ModelProfile) -> str:
|
| 236 |
+
if profile.device_preference == "cpu":
|
| 237 |
+
return "cpu"
|
| 238 |
+
if profile.device_preference == "cuda":
|
| 239 |
+
if not torch.cuda.is_available():
|
| 240 |
+
ml_logger.logger.warning(f"CUDA not available, loading {profile.name} on CPU")
|
| 241 |
+
return "cpu"
|
| 242 |
+
return "cuda"
|
| 243 |
+
if profile.device_preference == "auto":
|
| 244 |
+
if torch.cuda.is_available():
|
| 245 |
+
free_vram = self._gpu_budget_mb - self._get_used_vram()
|
| 246 |
+
if free_vram >= profile.vram_mb:
|
| 247 |
+
return "cuda"
|
| 248 |
+
return "cpu"
|
| 249 |
+
|
| 250 |
+
def _get_used_vram(self) -> int:
|
| 251 |
+
return sum(s.profile.vram_mb for s in self._states.values() if s.is_available and s.current_device == "cuda")
|
| 252 |
+
|
| 253 |
+
async def _evict_lru_gpu_model(self, except_model: str) -> str | None:
|
| 254 |
+
gpu_models = [
|
| 255 |
+
(name, state) for name, state in self._states.items()
|
| 256 |
+
if state.is_available and state.current_device == "cuda" and name != except_model
|
| 257 |
+
]
|
| 258 |
+
if not gpu_models:
|
| 259 |
+
return None
|
| 260 |
+
|
| 261 |
+
lru_name, _ = min(gpu_models, key=lambda x: x[1].last_used or datetime.min.replace(tzinfo=UTC))
|
| 262 |
+
await asyncio.to_thread(self._move_to_cpu, lru_name)
|
| 263 |
+
return lru_name
|
| 264 |
+
|
| 265 |
+
def _move_to_cpu(self, model_name: str):
|
| 266 |
+
state = self._states[model_name]
|
| 267 |
+
if state.model is not None and hasattr(state.model, "cpu"):
|
| 268 |
+
state.model = state.model.cpu()
|
| 269 |
+
torch.cuda.empty_cache()
|
| 270 |
+
state.current_device = "cpu"
|
| 271 |
+
ml_logger.logger.info(f"Moved {model_name} to CPU")
|
| 272 |
+
|
| 273 |
+
def get_status(self) -> dict:
|
| 274 |
+
return {
|
| 275 |
+
"models": {
|
| 276 |
+
name: {
|
| 277 |
+
"is_available": state.is_available,
|
| 278 |
+
"device": state.current_device,
|
| 279 |
+
"load_error": state.load_error,
|
| 280 |
+
"load_time_ms": state.load_time_ms,
|
| 281 |
+
"last_used": state.last_used.isoformat() if state.last_used else None,
|
| 282 |
+
"vram_mb": state.profile.vram_mb if state.current_device == "cuda" else 0
|
| 283 |
+
}
|
| 284 |
+
for name, state in self._states.items()
|
| 285 |
+
},
|
| 286 |
+
"gpu_budget_mb": self._gpu_budget_mb,
|
| 287 |
+
"gpu_used_mb": self._get_used_vram(),
|
| 288 |
+
"gpu_free_mb": self._gpu_budget_mb - self._get_used_vram()
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
model_registry = ModelRegistry()
|
backend/orchestration/pipeline.py
CHANGED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import asyncio
|
| 3 |
+
import logging
|
| 4 |
+
import torch
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from datetime import datetime, UTC
|
| 7 |
+
|
| 8 |
+
from backend.ml.registry import ModelRegistry
|
| 9 |
+
from backend.core.exceptions import InvalidFileError, InferenceError, ModelNotLoadedError
|
| 10 |
+
from backend.api.v1.schemas.analysis import (
|
| 11 |
+
AnalysisResult, VisionResult, NLPResult, FusionResult, ProcessingTimings
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
class AnalysisPipeline:
|
| 17 |
+
def __init__(self, registry: ModelRegistry):
|
| 18 |
+
self.registry = registry
|
| 19 |
+
|
| 20 |
+
async def run(self, session_id: str, image_path: Path, symptoms_text: str) -> AnalysisResult:
|
| 21 |
+
timings = {}
|
| 22 |
+
warnings = []
|
| 23 |
+
vision_result = None
|
| 24 |
+
nlp_result = None
|
| 25 |
+
fusion_result = None
|
| 26 |
+
report_text = None
|
| 27 |
+
|
| 28 |
+
# ── STEP 1: VALIDATE INPUT ────────────────────────────
|
| 29 |
+
t0 = time.monotonic()
|
| 30 |
+
try:
|
| 31 |
+
if not image_path.exists():
|
| 32 |
+
raise InvalidFileError("Image file not found")
|
| 33 |
+
processed_image_path = await asyncio.to_thread(self._preprocess_image, image_path)
|
| 34 |
+
except Exception as e:
|
| 35 |
+
raise InferenceError(f"Input validation failed: {e}")
|
| 36 |
+
timings["preprocess_ms"] = int((time.monotonic() - t0) * 1000)
|
| 37 |
+
|
| 38 |
+
# ── STEP 2: VISION ANALYSIS (GPU) ─────────────────────
|
| 39 |
+
t0 = time.monotonic()
|
| 40 |
+
try:
|
| 41 |
+
# We wrap this in resilience layers in resilience.py
|
| 42 |
+
vision_result = await self._run_vision(processed_image_path)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
logger.warning(f"Vision analysis failed for session {session_id}: {e}")
|
| 45 |
+
warnings.append(f"Vision analysis unavailable: {type(e).__name__}")
|
| 46 |
+
timings["vision_ms"] = int((time.monotonic() - t0) * 1000)
|
| 47 |
+
|
| 48 |
+
# ── STEP 3: VRAM CLEANUP ───────────────────────────────
|
| 49 |
+
if vision_result is not None:
|
| 50 |
+
if torch.cuda.is_available():
|
| 51 |
+
await asyncio.to_thread(torch.cuda.empty_cache)
|
| 52 |
+
await asyncio.sleep(0.1)
|
| 53 |
+
|
| 54 |
+
# ── STEP 4: NLP ANALYSIS (GPU/CPU) ────────────────────
|
| 55 |
+
t0 = time.monotonic()
|
| 56 |
+
if symptoms_text.strip():
|
| 57 |
+
try:
|
| 58 |
+
nlp_result = await self._run_nlp(symptoms_text)
|
| 59 |
+
except Exception as e:
|
| 60 |
+
logger.warning(f"NLP analysis failed for session {session_id}: {e}")
|
| 61 |
+
warnings.append(f"NLP analysis unavailable: {type(e).__name__}")
|
| 62 |
+
else:
|
| 63 |
+
warnings.append("No symptoms text provided — NLP analysis skipped")
|
| 64 |
+
timings["nlp_ms"] = int((time.monotonic() - t0) * 1000)
|
| 65 |
+
|
| 66 |
+
# ── STEP 5: MULTIMODAL FUSION ─────────────────────────
|
| 67 |
+
t0 = time.monotonic()
|
| 68 |
+
if vision_result is not None and nlp_result is not None:
|
| 69 |
+
try:
|
| 70 |
+
fusion_result = await self._run_fusion(processed_image_path, symptoms_text)
|
| 71 |
+
except Exception as e:
|
| 72 |
+
logger.warning(f"Fusion failed for session {session_id}: {e}")
|
| 73 |
+
warnings.append(f"Multimodal fusion unavailable: {type(e).__name__}")
|
| 74 |
+
else:
|
| 75 |
+
warnings.append("Fusion skipped: requires both vision and NLP results")
|
| 76 |
+
timings["fusion_ms"] = int((time.monotonic() - t0) * 1000)
|
| 77 |
+
|
| 78 |
+
# ── STEP 6: REPORT GENERATION ─────────────────────────
|
| 79 |
+
t0 = time.monotonic()
|
| 80 |
+
try:
|
| 81 |
+
report_text = await self._generate_report(vision_result, nlp_result, fusion_result)
|
| 82 |
+
except Exception as e:
|
| 83 |
+
logger.warning(f"Report generation failed: {e}")
|
| 84 |
+
report_text = self._fallback_report(vision_result, nlp_result)
|
| 85 |
+
warnings.append("Using template report — AI report generation unavailable")
|
| 86 |
+
timings["report_ms"] = int((time.monotonic() - t0) * 1000)
|
| 87 |
+
|
| 88 |
+
# ── STEP 7: DETERMINE OVERALL STATUS ──────────────────
|
| 89 |
+
if vision_result is None and nlp_result is None:
|
| 90 |
+
overall_status = "FAILED"
|
| 91 |
+
elif vision_result is None or nlp_result is None:
|
| 92 |
+
overall_status = "PARTIAL"
|
| 93 |
+
else:
|
| 94 |
+
overall_status = "COMPLETE"
|
| 95 |
+
|
| 96 |
+
timings["total_ms"] = sum(timings.values())
|
| 97 |
+
|
| 98 |
+
return AnalysisResult(
|
| 99 |
+
session_id=session_id, patient_id="", timestamp=datetime.now(UTC),
|
| 100 |
+
vision=vision_result, nlp=nlp_result, fusion=fusion_result,
|
| 101 |
+
report_text=report_text, overall_status=overall_status,
|
| 102 |
+
timings=ProcessingTimings(**timings), warnings=warnings
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def _preprocess_image(self, image_path: Path) -> Path:
|
| 106 |
+
from PIL import Image
|
| 107 |
+
with Image.open(image_path) as img:
|
| 108 |
+
img = img.convert("RGB")
|
| 109 |
+
img = img.resize((224, 224), Image.LANCZOS)
|
| 110 |
+
output_path = image_path.with_suffix(".processed.png")
|
| 111 |
+
img.save(output_path, "PNG")
|
| 112 |
+
return output_path
|
| 113 |
+
|
| 114 |
+
async def _run_vision(self, image_path: Path) -> VisionResult:
|
| 115 |
+
from backend.ml.vision.anomaly import AnomalyDetector
|
| 116 |
+
from backend.ml.vision.gradcam import GradCAM
|
| 117 |
+
|
| 118 |
+
state = await self.registry.get("dino_anomaly")
|
| 119 |
+
if not state.is_available:
|
| 120 |
+
raise ModelNotLoadedError("Vision model unavailable")
|
| 121 |
+
|
| 122 |
+
anomaly_score, model_confidence = await asyncio.to_thread(
|
| 123 |
+
AnomalyDetector.score, image_path, state.model, state.head, state.stats, state.current_device
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
heatmap_b64, top_regions = await asyncio.to_thread(
|
| 127 |
+
GradCAM.generate, image_path, state.model, anomaly_score
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
risk_level = "LOW" if anomaly_score < 40 else "MEDIUM" if anomaly_score < 70 else "HIGH"
|
| 131 |
+
return VisionResult(
|
| 132 |
+
anomaly_score=round(anomaly_score, 1), risk_level=risk_level,
|
| 133 |
+
heatmap_base64=heatmap_b64, top_regions=top_regions, model_confidence=model_confidence
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
async def _run_nlp(self, text: str) -> NLPResult:
|
| 137 |
+
from backend.ml.nlp.ner import NERExtractor
|
| 138 |
+
from backend.ml.nlp.classifier import DiseaseClassifier
|
| 139 |
+
|
| 140 |
+
ner_state = await self.registry.get("biobert_ner")
|
| 141 |
+
entities = await asyncio.to_thread(
|
| 142 |
+
NERExtractor.extract, text, ner_state.model, ner_state.tokenizer, ner_state.current_device
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
diagnosis = await asyncio.to_thread(DiseaseClassifier.classify, text, entities)
|
| 146 |
+
return NLPResult(
|
| 147 |
+
entities=entities, primary_diagnosis=diagnosis["primary"],
|
| 148 |
+
diagnosis_confidence=diagnosis["confidence"], differential=diagnosis["differential"]
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
async def _run_fusion(self, image_path: Path, text: str) -> FusionResult:
|
| 152 |
+
from backend.ml.fusion.medclip import MultimodalFusion
|
| 153 |
+
|
| 154 |
+
state = await self.registry.get("biomedvlp")
|
| 155 |
+
if not state.is_available:
|
| 156 |
+
raise ModelNotLoadedError("Fusion model unavailable")
|
| 157 |
+
|
| 158 |
+
similarity, alignment = await asyncio.to_thread(
|
| 159 |
+
MultimodalFusion.compute_similarity, image_path, text, state.model, state.tokenizer, state.current_device
|
| 160 |
+
)
|
| 161 |
+
final_risk = "HIGH" if similarity < 0.3 else "MEDIUM" if similarity < 0.7 else "LOW"
|
| 162 |
+
return FusionResult(image_text_similarity=round(similarity, 3), alignment=alignment, final_risk=final_risk)
|
| 163 |
+
|
| 164 |
+
async def _generate_report(self, vision: VisionResult | None, nlp: NLPResult | None, fusion: FusionResult | None) -> str:
|
| 165 |
+
from backend.ml.rag.generator import ReportGenerator
|
| 166 |
+
|
| 167 |
+
state = await self.registry.get("biogpt_base")
|
| 168 |
+
if not state.is_available:
|
| 169 |
+
raise ModelNotLoadedError("Report generation unavailable")
|
| 170 |
+
|
| 171 |
+
return await asyncio.to_thread(
|
| 172 |
+
ReportGenerator.generate, vision, nlp, fusion, state.model, state.tokenizer
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def _fallback_report(self, vision: VisionResult | None, nlp: NLPResult | None) -> str:
|
| 176 |
+
parts = ["## AI-Assisted Analysis Report\n\n*Note: This is an automated template report.*\n"]
|
| 177 |
+
if vision:
|
| 178 |
+
parts.append(f"**Imaging Findings:** Anomaly score of {vision.anomaly_score}/100 indicates {vision.risk_level.lower()} risk findings.")
|
| 179 |
+
if nlp:
|
| 180 |
+
diseases = ", ".join(nlp.entities.diseases) if nlp.entities.diseases else "none identified"
|
| 181 |
+
symptoms = ", ".join(nlp.entities.symptoms) if nlp.entities.symptoms else "none documented"
|
| 182 |
+
parts.append(f"**Clinical Impression:** {nlp.primary_diagnosis} (confidence: {nlp.diagnosis_confidence:.0%}). Identified conditions: {diseases}. Symptoms: {symptoms}.")
|
| 183 |
+
parts.append("\n**Recommendation:** Please consult a licensed physician for diagnosis and treatment.")
|
| 184 |
+
return "\n".join(parts)
|
backend/orchestration/queue.py
CHANGED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import uuid
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from datetime import datetime, UTC
|
| 6 |
+
from sqlalchemy import select, func
|
| 7 |
+
|
| 8 |
+
from backend.db.models import AnalysisTask, AnalysisSession
|
| 9 |
+
from backend.core.exceptions import TaskNotFoundError, SessionAccessDeniedError
|
| 10 |
+
from backend.core.logging_config import ml_logger
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class AnalysisTaskQueue:
|
| 15 |
+
MAX_CONCURRENT = 2
|
| 16 |
+
WORKER_SLEEP_SECONDS = 5
|
| 17 |
+
|
| 18 |
+
def __init__(self, db_session_factory, pipeline):
|
| 19 |
+
self._db_factory = db_session_factory
|
| 20 |
+
self._pipeline = pipeline
|
| 21 |
+
self._new_task_event = asyncio.Event()
|
| 22 |
+
self._active_count = 0
|
| 23 |
+
self._active_lock = asyncio.Lock()
|
| 24 |
+
self._worker_task: asyncio.Task | None = None
|
| 25 |
+
self._is_running = False
|
| 26 |
+
|
| 27 |
+
async def start(self):
|
| 28 |
+
self._is_running = True
|
| 29 |
+
self._worker_task = asyncio.create_task(self._worker_loop())
|
| 30 |
+
logger.info("Task queue worker started")
|
| 31 |
+
|
| 32 |
+
async def stop(self):
|
| 33 |
+
logger.info("Task queue stopping...")
|
| 34 |
+
self._is_running = False
|
| 35 |
+
self._new_task_event.set()
|
| 36 |
+
|
| 37 |
+
deadline = asyncio.get_event_loop().time() + 60
|
| 38 |
+
while self._active_count > 0:
|
| 39 |
+
if asyncio.get_event_loop().time() > deadline:
|
| 40 |
+
logger.warning(f"Shutdown timeout: {self._active_count} tasks still running")
|
| 41 |
+
break
|
| 42 |
+
await asyncio.sleep(1)
|
| 43 |
+
|
| 44 |
+
if self._worker_task:
|
| 45 |
+
self._worker_task.cancel()
|
| 46 |
+
|
| 47 |
+
async def submit(self, session_id: str, user_id: str | None, image_path: str, symptoms_text: str, priority: int = 1) -> str:
|
| 48 |
+
task_id = str(uuid.uuid4())
|
| 49 |
+
async with self._db_factory() as db:
|
| 50 |
+
task = AnalysisTask(
|
| 51 |
+
id=task_id, session_id=session_id, user_id=user_id,
|
| 52 |
+
status="PENDING", priority=priority,
|
| 53 |
+
image_path=image_path, symptoms_text=symptoms_text
|
| 54 |
+
)
|
| 55 |
+
db.add(task)
|
| 56 |
+
await db.commit()
|
| 57 |
+
|
| 58 |
+
self._new_task_event.set()
|
| 59 |
+
return task_id
|
| 60 |
+
|
| 61 |
+
async def get_status(self, task_id: str) -> dict:
|
| 62 |
+
async with self._db_factory() as db:
|
| 63 |
+
task = await db.get(AnalysisTask, task_id)
|
| 64 |
+
if not task:
|
| 65 |
+
raise TaskNotFoundError()
|
| 66 |
+
|
| 67 |
+
position = None
|
| 68 |
+
if task.status == "PENDING":
|
| 69 |
+
result = await db.execute(
|
| 70 |
+
select(func.count(AnalysisTask.id)).where(
|
| 71 |
+
AnalysisTask.status == "PENDING",
|
| 72 |
+
AnalysisTask.priority >= task.priority,
|
| 73 |
+
AnalysisTask.created_at < task.created_at
|
| 74 |
+
)
|
| 75 |
+
)
|
| 76 |
+
position = result.scalar_one() + 1
|
| 77 |
+
|
| 78 |
+
estimated_wait = None
|
| 79 |
+
if position:
|
| 80 |
+
slots_until_ours = max(0, position - self.MAX_CONCURRENT)
|
| 81 |
+
estimated_wait = slots_until_ours * 45
|
| 82 |
+
|
| 83 |
+
return {
|
| 84 |
+
"task_id": task_id, "session_id": task.session_id, "status": task.status,
|
| 85 |
+
"position_in_queue": position, "estimated_wait_seconds": estimated_wait,
|
| 86 |
+
"started_at": task.started_at, "completed_at": task.completed_at,
|
| 87 |
+
"error_message": task.error_message
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
async def cancel(self, task_id: str, user_id: str) -> bool:
|
| 91 |
+
async with self._db_factory() as db:
|
| 92 |
+
task = await db.get(AnalysisTask, task_id)
|
| 93 |
+
if not task:
|
| 94 |
+
raise TaskNotFoundError()
|
| 95 |
+
if task.user_id != user_id:
|
| 96 |
+
raise SessionAccessDeniedError()
|
| 97 |
+
if task.status != "PENDING":
|
| 98 |
+
return False
|
| 99 |
+
|
| 100 |
+
task.status = "CANCELLED"
|
| 101 |
+
image_path = Path(task.image_path)
|
| 102 |
+
if image_path.exists():
|
| 103 |
+
image_path.unlink()
|
| 104 |
+
|
| 105 |
+
await db.commit()
|
| 106 |
+
return True
|
| 107 |
+
|
| 108 |
+
async def _worker_loop(self):
|
| 109 |
+
logger.info("Worker loop started")
|
| 110 |
+
while self._is_running:
|
| 111 |
+
try:
|
| 112 |
+
await asyncio.wait_for(
|
| 113 |
+
asyncio.shield(self._new_task_event.wait()), timeout=self.WORKER_SLEEP_SECONDS
|
| 114 |
+
)
|
| 115 |
+
self._new_task_event.clear()
|
| 116 |
+
except asyncio.TimeoutError:
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
if not self._is_running:
|
| 120 |
+
break
|
| 121 |
+
|
| 122 |
+
while self._active_count < self.MAX_CONCURRENT:
|
| 123 |
+
task = await self._fetch_next_task()
|
| 124 |
+
if not task:
|
| 125 |
+
break
|
| 126 |
+
asyncio.create_task(self._process_task(task))
|
| 127 |
+
|
| 128 |
+
logger.info("Worker loop stopped")
|
| 129 |
+
|
| 130 |
+
async def _fetch_next_task(self) -> AnalysisTask | None:
|
| 131 |
+
async with self._db_factory() as db:
|
| 132 |
+
result = await db.execute(
|
| 133 |
+
select(AnalysisTask).where(AnalysisTask.status == "PENDING")
|
| 134 |
+
.order_by(AnalysisTask.priority.desc(), AnalysisTask.created_at.asc())
|
| 135 |
+
.limit(1)
|
| 136 |
+
.with_for_update(skip_locked=True)
|
| 137 |
+
)
|
| 138 |
+
return result.scalar_one_or_none()
|
| 139 |
+
|
| 140 |
+
async def _process_task(self, task: AnalysisTask):
|
| 141 |
+
async with self._active_lock:
|
| 142 |
+
self._active_count += 1
|
| 143 |
+
|
| 144 |
+
try:
|
| 145 |
+
async with self._db_factory() as db:
|
| 146 |
+
task.status = "PROCESSING"
|
| 147 |
+
task.started_at = datetime.now(UTC)
|
| 148 |
+
await db.merge(task)
|
| 149 |
+
await db.commit()
|
| 150 |
+
|
| 151 |
+
result = await self._pipeline.run(
|
| 152 |
+
session_id=task.session_id,
|
| 153 |
+
image_path=Path(task.image_path),
|
| 154 |
+
symptoms_text=task.symptoms_text or ""
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
async with self._db_factory() as db:
|
| 158 |
+
db_task = await db.get(AnalysisTask, task.id)
|
| 159 |
+
db_task.status = "COMPLETED"
|
| 160 |
+
db_task.completed_at = datetime.now(UTC)
|
| 161 |
+
|
| 162 |
+
session = await db.get(AnalysisSession, task.session_id)
|
| 163 |
+
session.result_json = result.model_dump()
|
| 164 |
+
session.status = "READY"
|
| 165 |
+
session.risk_level = result.vision.risk_level if result.vision else "UNKNOWN"
|
| 166 |
+
|
| 167 |
+
await db.commit()
|
| 168 |
+
|
| 169 |
+
ml_logger.log_pipeline_step(
|
| 170 |
+
"full_pipeline", "COMPLETED",
|
| 171 |
+
int((datetime.now(UTC) - task.started_at).total_seconds() * 1000), task.session_id
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
except Exception as e:
|
| 175 |
+
logger.error(f"Task {task.id} failed: {e}", exc_info=True)
|
| 176 |
+
async with self._db_factory() as db:
|
| 177 |
+
db_task = await db.get(AnalysisTask, task.id)
|
| 178 |
+
db_task.attempt_count = (db_task.attempt_count or 0) + 1
|
| 179 |
+
|
| 180 |
+
if db_task.attempt_count < 3:
|
| 181 |
+
db_task.status = "PENDING"
|
| 182 |
+
self._new_task_event.set()
|
| 183 |
+
else:
|
| 184 |
+
db_task.status = "FAILED"
|
| 185 |
+
db_task.error_message = str(e)[:500]
|
| 186 |
+
session = await db.get(AnalysisSession, task.session_id)
|
| 187 |
+
session.status = "FAILED"
|
| 188 |
+
session.error_message = "Analysis failed after 3 attempts"
|
| 189 |
+
await db.commit()
|
| 190 |
+
|
| 191 |
+
finally:
|
| 192 |
+
try:
|
| 193 |
+
Path(task.image_path).unlink(missing_ok=True)
|
| 194 |
+
except Exception:
|
| 195 |
+
pass
|
| 196 |
+
async with self._active_lock:
|
| 197 |
+
self._active_count -= 1
|
| 198 |
+
self._new_task_event.set()
|
frontend/app/(auth)/callback/page.jsx
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
import { useEffect, useState } from 'react'
|
| 3 |
+
import { useRouter, useSearchParams } from 'next/navigation'
|
| 4 |
+
import { useAuth } from '@/lib/auth/AuthContext'
|
| 5 |
+
import Link from 'next/link'
|
| 6 |
+
|
| 7 |
+
export default function AuthCallbackPage() {
|
| 8 |
+
const searchParams = useSearchParams()
|
| 9 |
+
const router = useRouter()
|
| 10 |
+
const { loginWithToken } = useAuth()
|
| 11 |
+
const [errorMsg, setErrorMsg] = useState(null)
|
| 12 |
+
|
| 13 |
+
useEffect(() => {
|
| 14 |
+
const processToken = async () => {
|
| 15 |
+
const token = searchParams.get('token')
|
| 16 |
+
const error = searchParams.get('error')
|
| 17 |
+
|
| 18 |
+
if (error) {
|
| 19 |
+
setErrorMsg(`Authentication failed: ${error}`)
|
| 20 |
+
return
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
if (token) {
|
| 24 |
+
try {
|
| 25 |
+
await loginWithToken(token)
|
| 26 |
+
// Clean URL and redirect
|
| 27 |
+
router.replace('/upload')
|
| 28 |
+
} catch (err) {
|
| 29 |
+
setErrorMsg("Failed to establish secure session.")
|
| 30 |
+
}
|
| 31 |
+
} else {
|
| 32 |
+
setErrorMsg("No authentication token provided.")
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
processToken()
|
| 37 |
+
}, [searchParams, loginWithToken, router])
|
| 38 |
+
|
| 39 |
+
if (errorMsg) {
|
| 40 |
+
return (
|
| 41 |
+
<div className="flex flex-col items-center justify-center min-h-screen p-4 text-center">
|
| 42 |
+
<div className="p-6 bg-red-50 border border-red-200 rounded-lg shadow-sm">
|
| 43 |
+
<h2 className="text-lg font-semibold text-red-700 mb-2">Error</h2>
|
| 44 |
+
<p className="text-red-600 mb-4">{errorMsg}</p>
|
| 45 |
+
<Link href="/login" className="px-4 py-2 text-white bg-blue-600 rounded hover:bg-blue-700 transition">
|
| 46 |
+
Return to Login
|
| 47 |
+
</Link>
|
| 48 |
+
</div>
|
| 49 |
+
</div>
|
| 50 |
+
)
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
return (
|
| 54 |
+
<div className="flex flex-col items-center justify-center min-h-screen">
|
| 55 |
+
<div className="w-8 h-8 border-4 border-blue-600 border-t-transparent rounded-full animate-spin mb-4" />
|
| 56 |
+
<p className="text-gray-600 font-medium">Securing session...</p>
|
| 57 |
+
</div>
|
| 58 |
+
)
|
| 59 |
+
}
|
frontend/components/auth/ProtectedRoute.jsx
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
import { useEffect } from 'react'
|
| 3 |
+
import { useRouter, usePathname } from 'next/navigation'
|
| 4 |
+
import { useAuth } from '@/lib/auth/AuthContext'
|
| 5 |
+
|
| 6 |
+
export default function ProtectedRoute({ children }) {
|
| 7 |
+
const { isAuthenticated, isLoading } = useAuth()
|
| 8 |
+
const router = useRouter()
|
| 9 |
+
const pathname = usePathname()
|
| 10 |
+
|
| 11 |
+
useEffect(() => {
|
| 12 |
+
if (!isLoading && !isAuthenticated) {
|
| 13 |
+
sessionStorage.setItem('intendedPath', pathname)
|
| 14 |
+
router.push('/login')
|
| 15 |
+
}
|
| 16 |
+
}, [isLoading, isAuthenticated, router, pathname])
|
| 17 |
+
|
| 18 |
+
if (isLoading) {
|
| 19 |
+
return <div className="flex items-center justify-center min-h-screen">Loading secure environment...</div>
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
return isAuthenticated ? children : null
|
| 23 |
+
}
|
frontend/components/auth/PublicOnlyRoute.jsx
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
import { useEffect } from 'react'
|
| 3 |
+
import { useRouter } from 'next/navigation'
|
| 4 |
+
import { useAuth } from '@/lib/auth/AuthContext'
|
| 5 |
+
|
| 6 |
+
export default function PublicOnlyRoute({ children }) {
|
| 7 |
+
const { isAuthenticated, isLoading } = useAuth()
|
| 8 |
+
const router = useRouter()
|
| 9 |
+
|
| 10 |
+
useEffect(() => {
|
| 11 |
+
if (!isLoading && isAuthenticated) {
|
| 12 |
+
router.push('/upload')
|
| 13 |
+
}
|
| 14 |
+
}, [isLoading, isAuthenticated, router])
|
| 15 |
+
|
| 16 |
+
if (isLoading) return null
|
| 17 |
+
|
| 18 |
+
return !isAuthenticated ? children : null
|
| 19 |
+
}
|
frontend/next.config.js
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/** @type {import('next').NextConfig} */
|
| 2 |
+
|
| 3 |
+
const securityHeaders = [
|
| 4 |
+
{
|
| 5 |
+
key: "Content-Security-Policy",
|
| 6 |
+
value: [
|
| 7 |
+
"default-src 'self'",
|
| 8 |
+
"script-src 'self' 'unsafe-eval' 'unsafe-inline'",
|
| 9 |
+
"style-src 'self' 'unsafe-inline' https://fonts.googleapis.com",
|
| 10 |
+
"font-src 'self' https://fonts.gstatic.com",
|
| 11 |
+
"img-src 'self' data: blob: https:",
|
| 12 |
+
`connect-src 'self' ${process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000'}`,
|
| 13 |
+
"frame-src 'none'",
|
| 14 |
+
"object-src 'none'",
|
| 15 |
+
"base-uri 'self'",
|
| 16 |
+
"form-action 'self'",
|
| 17 |
+
].join("; ")
|
| 18 |
+
},
|
| 19 |
+
{ key: "X-DNS-Prefetch-Control", value: "on" },
|
| 20 |
+
{ key: "X-Frame-Options", value: "SAMEORIGIN" },
|
| 21 |
+
{ key: "X-Content-Type-Options", value: "nosniff" },
|
| 22 |
+
{ key: "Referrer-Policy", value: "strict-origin-when-cross-origin" },
|
| 23 |
+
];
|
| 24 |
+
|
| 25 |
+
const nextConfig = {
|
| 26 |
+
output: "standalone",
|
| 27 |
+
experimental: {
|
| 28 |
+
serverActions: true,
|
| 29 |
+
},
|
| 30 |
+
images: {
|
| 31 |
+
domains: [(process.env.NEXT_PUBLIC_API_URL || 'localhost').replace(/^https?:\/\//, '').split(':')[0]],
|
| 32 |
+
},
|
| 33 |
+
async headers() {
|
| 34 |
+
return [
|
| 35 |
+
{
|
| 36 |
+
source: "/(.*)",
|
| 37 |
+
headers: securityHeaders,
|
| 38 |
+
},
|
| 39 |
+
];
|
| 40 |
+
},
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
module.exports = nextConfig;
|
scripts/security_audit.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import httpx
|
| 3 |
+
import asyncio
|
| 4 |
+
|
| 5 |
+
API_URL = "http://localhost:8000"
|
| 6 |
+
|
| 7 |
+
async def run_audit():
|
| 8 |
+
print("\n--- MedSight AI Security Audit ---\n")
|
| 9 |
+
passed = 0
|
| 10 |
+
total = 0
|
| 11 |
+
|
| 12 |
+
def check(name, condition, error_msg):
|
| 13 |
+
nonlocal passed, total
|
| 14 |
+
total += 1
|
| 15 |
+
if condition:
|
| 16 |
+
print(f"✅ PASS: {name}")
|
| 17 |
+
passed += 1
|
| 18 |
+
else:
|
| 19 |
+
print(f"❌ FAIL: {name} - {error_msg}")
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
async with httpx.AsyncClient() as client:
|
| 23 |
+
# 1. Test Health / Headers
|
| 24 |
+
res = await client.get(f"{API_URL}/api/v1/health")
|
| 25 |
+
headers = res.headers
|
| 26 |
+
check("X-Content-Type-Options", headers.get("x-content-type-options") == "nosniff", "Header missing or incorrect")
|
| 27 |
+
check("X-Frame-Options", headers.get("x-frame-options") == "DENY", "Header missing or incorrect")
|
| 28 |
+
check("Permissions-Policy", "microphone=(self)" in headers.get("permissions-policy", ""), "Microphone permission misconfigured")
|
| 29 |
+
check("No Server Header", "server" not in headers, "Server header is leaking framework info")
|
| 30 |
+
|
| 31 |
+
# 2. Test CORS Rejection
|
| 32 |
+
cors_res = await client.options(
|
| 33 |
+
f"{API_URL}/api/v1/health",
|
| 34 |
+
headers={"Origin": "http://evil-domain.com", "Access-Control-Request-Method": "GET"}
|
| 35 |
+
)
|
| 36 |
+
# Depending on FastAPI config, it might strip CORS headers or return 400.
|
| 37 |
+
# The key is it shouldn't return Access-Control-Allow-Origin: http://evil-domain.com
|
| 38 |
+
check("CORS Restrictions", cors_res.headers.get("access-control-allow-origin") != "http://evil-domain.com", "CORS allowed untrusted origin")
|
| 39 |
+
|
| 40 |
+
except httpx.ConnectError:
|
| 41 |
+
print("❌ Server not running. Start it with: make run-dev")
|
| 42 |
+
sys.exit(1)
|
| 43 |
+
|
| 44 |
+
print(f"\nAudit Complete: {passed}/{total} Passed\n")
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
|
| 47 |
+
asyncio.run(run_audit())
|