hoshikrana committed on
Commit
2bdb663
·
1 Parent(s): 27826d5

feat: security headers, frontend auth context, ML registry, and orchestration pipeline

Browse files
backend/core/middleware.py CHANGED
@@ -1,7 +1,45 @@
1
  from fastapi import Request, Response
2
- from fastapi.responses import JSONResponse
 
3
  from slowapi import Limiter
4
  from slowapi.errors import RateLimitExceeded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def get_rate_limit_key(request: Request) -> str:
7
  user_id = getattr(request.state, "user_id", None)
 
1
  from fastapi import Request, Response
2
+ from fastapi.responses import JSONResponse, RedirectResponse
3
+ from starlette.middleware.base import BaseHTTPMiddleware
4
  from slowapi import Limiter
5
  from slowapi.errors import RateLimitExceeded
6
+ from backend.core.config import settings
7
+
8
class HTTPSRedirectMiddleware(BaseHTTPMiddleware):
    """Redirect plain-HTTP requests to HTTPS when running in production.

    The client-facing scheme is read from the ``X-Forwarded-Proto`` request
    header; a request without that header is treated as already being HTTPS.
    Outside production every request is passed through untouched.
    """

    async def dispatch(self, request: Request, call_next):
        """Return a 301 redirect to the HTTPS URL, or forward the request."""
        if settings.is_production:
            forwarded = request.headers.get("X-Forwarded-Proto", "https")
            # The header may carry a comma-separated chain ("http, https") and
            # its casing is not guaranteed, so compare only the normalized
            # first entry instead of the raw header value.
            proto = forwarded.split(",", 1)[0].strip().lower()
            if proto == "http":
                # Replace only the scheme component; a textual replace of
                # "http://" could rewrite a URL embedded in the query string
                # when the request URL itself does not start with "http://".
                https_url = str(request.url.replace(scheme="https"))
                return RedirectResponse(https_url, status_code=301)
        return await call_next(request)
17
+
18
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach security-related response headers to every response.

    Adds a fixed set of hardening headers, adds HSTS in production only, and
    strips response headers that advertise server software.
    """

    # Response headers that reveal server/framework details.
    _LEAKY_HEADERS = ("Server", "X-Powered-By")

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler and decorate its response headers."""
        response = await call_next(request)

        # Always add these headers:
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        response.headers["Permissions-Policy"] = (
            "camera=(), microphone=(self), geolocation=(), "
            "payment=(), usb=(), magnetometer=()"
        )  # microphone=(self) required for Whisper voice input

        # Production only:
        if settings.is_production:
            response.headers["Strict-Transport-Security"] = (
                "max-age=63072000; includeSubDomains; preload"
            )

        # Remove headers that leak server info. Starlette's MutableHeaders is
        # not a dict and has no pop(); calling it raises AttributeError, so
        # delete explicitly when the header is present.
        # NOTE(review): a Server header added by the ASGI server itself after
        # this middleware runs would not be visible here - confirm server config.
        for header in self._LEAKY_HEADERS:
            if header in response.headers:
                del response.headers[header]

        return response
43
 
44
  def get_rate_limit_key(request: Request) -> str:
45
  user_id = getattr(request.state, "user_id", None)
backend/ml/registry.py CHANGED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import time
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from datetime import datetime, UTC
6
+ from typing import Literal, Any
7
+ import torch
8
+
9
+ from backend.core.config import settings
10
+ from backend.core.logging_config import ml_logger
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
@dataclass
class ModelProfile:
    """Static description of one model the registry can load."""

    # Registry key for the model.
    name: str
    # Hugging Face model identifier passed to the loaders.
    hf_model_id: str
    # Sub-directory of the model cache directory used for this model.
    local_cache_subdir: str
    # Requested placement; "auto" lets the registry choose.
    device_preference: Literal["cuda", "cpu", "auto"]
    # Budgeted GPU memory (MB) counted against the VRAM budget when on CUDA.
    vram_mb: int
    # Budgeted host memory (MB).
    ram_mb: int
    # Startup ordering; lower values are loaded first.
    load_priority: int
    # When True, a load failure is treated as fatal.
    is_required: bool
24
+
25
# Every model the registry knows about, keyed by profile name. Insertion order
# matters: it breaks ties between profiles that share a load_priority.
MODEL_PROFILES = {
    profile.name: profile
    for profile in (
        ModelProfile(
            name="dino_anomaly",
            hf_model_id="facebook/dinov2-small",
            local_cache_subdir="dino_anomaly",
            device_preference="cuda",
            vram_mb=400,
            ram_mb=50,
            load_priority=1,
            is_required=True,
        ),
        ModelProfile(
            name="biobert_ner",
            hf_model_id="dmis-lab/biobert-base-cased-v1.2",
            local_cache_subdir="biobert_ner",
            device_preference="cuda",
            vram_mb=450,
            ram_mb=50,
            load_priority=2,
            is_required=True,
        ),
        ModelProfile(
            name="biomedvlp",
            hf_model_id="microsoft/BiomedVLP-CXR-BERT-specialized",
            local_cache_subdir="biomedvlp",
            device_preference="auto",
            vram_mb=900,
            ram_mb=100,
            load_priority=3,
            is_required=False,
        ),
        ModelProfile(
            name="whisper_tiny",
            hf_model_id="openai/whisper-tiny",
            local_cache_subdir="whisper",
            device_preference="cpu",
            vram_mb=0,
            ram_mb=300,
            load_priority=4,
            is_required=False,
        ),
        ModelProfile(
            name="biogpt_base",
            hf_model_id="microsoft/biogpt",
            local_cache_subdir="biogpt",
            device_preference="cpu",
            vram_mb=0,
            ram_mb=700,
            load_priority=5,
            is_required=False,
        ),
        ModelProfile(
            name="minilm",
            hf_model_id="sentence-transformers/all-MiniLM-L6-v2",
            local_cache_subdir="minilm",
            device_preference="cpu",
            vram_mb=0,
            ram_mb=100,
            load_priority=1,
            is_required=True,
        ),
    )
}
51
+
52
+ @dataclass
53
+ class ModelState:
54
+ profile: ModelProfile
55
+ model: Any = None
56
+ tokenizer: Any = None
57
+ head: Any = None # Extension for DINO head architecture
58
+ stats: dict = None # Extension for anomaly scoring
59
+ is_loaded: bool = False
60
+ is_loading: bool = False
61
+ load_error: str | None = None
62
+ load_time_ms: int = 0
63
+ last_used: datetime | None = None
64
+ current_device: str = "unloaded"
65
+
66
+ @property
67
+ def is_available(self) -> bool:
68
+ return self.is_loaded and self.load_error is None
69
+
70
+ class ModelRegistry:
71
+ def __init__(self):
72
+ self._states: dict[str, ModelState] = {
73
+ name: ModelState(profile=profile) for name, profile in MODEL_PROFILES.items()
74
+ }
75
+ self._locks: dict[str, asyncio.Lock] = {
76
+ name: asyncio.Lock() for name in MODEL_PROFILES
77
+ }
78
+ self._gpu_budget_mb = settings.GPU_VRAM_BUDGET_MB
79
+
80
+ async def startup_load(self):
81
+ ml_logger.logger.info("Starting model registry startup")
82
+ sorted_models = sorted(MODEL_PROFILES.values(), key=lambda m: m.load_priority)
83
+
84
+ for profile in sorted_models:
85
+ if profile.device_preference == "cpu":
86
+ await self._load_model(profile.name)
87
+ else:
88
+ if self._get_used_vram() + profile.vram_mb <= self._gpu_budget_mb:
89
+ await self._load_model(profile.name)
90
+ else:
91
+ ml_logger.logger.warning(f"Skipping GPU load for {profile.name}: VRAM budget exceeded. Will load on CPU on first request.")
92
+
93
+ loaded = [n for n, s in self._states.items() if s.is_available]
94
+ failed = [n for n, s in self._states.items() if s.load_error]
95
+ required_failed = [n for n in failed if MODEL_PROFILES[n].is_required]
96
+
97
+ if required_failed:
98
+ raise RuntimeError(f"Critical models failed to load: {required_failed}. Check logs.")
99
+
100
+ ml_logger.logger.info("Registry startup complete", extra={"loaded": loaded, "failed": failed, "vram_used_mb": self._get_used_vram()})
101
+
102
+ async def get(self, model_name: str) -> ModelState:
103
+ if model_name not in self._states:
104
+ raise ValueError(f"Unknown model: {model_name}")
105
+
106
+ state = self._states[model_name]
107
+ if not state.is_available and not state.is_loading:
108
+ await self._load_model(model_name)
109
+
110
+ self._states[model_name].last_used = datetime.now(UTC)
111
+ return self._states[model_name]
112
+
113
+ def is_available(self, model_name: str) -> bool:
114
+ return self._states.get(model_name, ModelState(ModelProfile("","","","cpu",0,0,0,False))).is_available
115
+
116
+ async def _load_model(self, model_name: str):
117
+ async with self._locks[model_name]:
118
+ state = self._states[model_name]
119
+ if state.is_available:
120
+ return
121
+
122
+ state.is_loading = True
123
+ start_time = time.monotonic()
124
+
125
+ try:
126
+ profile = state.profile
127
+ device = self._resolve_device(profile)
128
+
129
+ if device == "cuda":
130
+ needed = profile.vram_mb
131
+ available = self._gpu_budget_mb - self._get_used_vram()
132
+ if available < needed:
133
+ evicted = await self._evict_lru_gpu_model(except_model=model_name)
134
+ if evicted:
135
+ ml_logger.logger.info(f"Evicted {evicted} to make room for {model_name}")
136
+
137
+ # Fetch objects securely
138
+ result = await asyncio.to_thread(self._load_model_sync, model_name, profile, device)
139
+
140
+ load_time_ms = int((time.monotonic() - start_time) * 1000)
141
+
142
+ state.model = result.get('model')
143
+ state.tokenizer = result.get('tokenizer')
144
+ state.head = result.get('head')
145
+ state.stats = result.get('stats')
146
+
147
+ state.is_loaded = True
148
+ state.load_error = None
149
+ state.load_time_ms = load_time_ms
150
+ state.current_device = device
151
+
152
+ ml_logger.log_model_load(model_name, device, load_time_ms, vram_delta_mb=profile.vram_mb if device == "cuda" else None)
153
+
154
+ except Exception as e:
155
+ state.load_error = str(e)
156
+ state.is_loaded = False
157
+ ml_logger.logger.error(f"Failed to load model {model_name}: {e}", exc_info=True)
158
+ if MODEL_PROFILES[model_name].is_required:
159
+ raise
160
+ finally:
161
+ state.is_loading = False
162
+
163
+ def _load_model_sync(self, name: str, profile: ModelProfile, device: str) -> dict:
164
+ cache_dir = settings.MODEL_CACHE_DIR / profile.local_cache_subdir
165
+ cache_dir.mkdir(parents=True, exist_ok=True)
166
+
167
+ if name == "dino_anomaly":
168
+ from transformers import AutoImageProcessor, AutoModel as ViTModel
169
+ processor = AutoImageProcessor.from_pretrained(profile.hf_model_id, cache_dir=cache_dir)
170
+ model = ViTModel.from_pretrained(profile.hf_model_id, cache_dir=cache_dir).to(device)
171
+ model.eval()
172
+
173
+ # Simulated Projection Head Loading Logic
174
+ head = None
175
+ stats = {"mean": 0.001, "std": 0.0005}
176
+ head_path = settings.MODEL_CACHE_DIR / "anomaly_head.pt"
177
+ stats_path = settings.MODEL_CACHE_DIR / "anomaly_stats.json"
178
+
179
+ if head_path.exists():
180
+ # We will define DINOProjectionHead in the vision module later
181
+ import json
182
+ pass
183
+ else:
184
+ logger.warning("Using untrained head — anomaly scores may be unreliable")
185
+
186
+ if stats_path.exists():
187
+ import json
188
+ stats = json.loads(stats_path.read_text())
189
+
190
+ return {"model": model, "tokenizer": processor, "head": head, "stats": stats}
191
+
192
+ elif name == "biobert_ner":
193
+ fine_tuned_path = settings.MODEL_CACHE_DIR / "biobert_ner_finetuned"
194
+ model_path = str(fine_tuned_path) if fine_tuned_path.exists() else profile.hf_model_id
195
+
196
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
197
+ tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
198
+ model = AutoModelForTokenClassification.from_pretrained(model_path, cache_dir=cache_dir).to(device)
199
+ model.eval()
200
+ return {"model": model, "tokenizer": tokenizer}
201
+
202
+ elif name == "whisper_tiny":
203
+ import whisper
204
+ model = whisper.load_model("tiny", device="cpu", download_root=str(cache_dir))
205
+ return {"model": model, "tokenizer": None}
206
+
207
+ elif name == "biogpt_base":
208
+ from transformers import BioGptForCausalLM, BioGptTokenizer
209
+ tokenizer = BioGptTokenizer.from_pretrained(profile.hf_model_id, cache_dir=cache_dir)
210
+ model = BioGptForCausalLM.from_pretrained(profile.hf_model_id, cache_dir=cache_dir)
211
+ model.eval()
212
+ return {"model": model, "tokenizer": tokenizer}
213
+
214
+ elif name == "minilm":
215
+ from sentence_transformers import SentenceTransformer
216
+ model = SentenceTransformer(profile.hf_model_id, cache_folder=str(cache_dir))
217
+ return {"model": model, "tokenizer": None}
218
+
219
+ elif name == "biomedvlp":
220
+ from transformers import AutoModel, AutoTokenizer
221
+ free_vram = torch.cuda.mem_get_info()[0] // 1024 // 1024 if torch.cuda.is_available() else 0
222
+ if free_vram < 950 and device == "cuda":
223
+ device = "cpu"
224
+ logger.warning("Insufficient VRAM for BiomedVLP — loading on CPU (slower)")
225
+
226
+ # Trust remote code is required for custom Microsoft implementation
227
+ tokenizer = AutoTokenizer.from_pretrained(profile.hf_model_id, cache_dir=cache_dir, trust_remote_code=True)
228
+ model = AutoModel.from_pretrained(profile.hf_model_id, cache_dir=cache_dir, trust_remote_code=True).to(device)
229
+ model.eval()
230
+ return {"model": model, "tokenizer": tokenizer}
231
+
232
+ else:
233
+ raise ValueError(f"No loader defined for model: {name}")
234
+
235
+ def _resolve_device(self, profile: ModelProfile) -> str:
236
+ if profile.device_preference == "cpu":
237
+ return "cpu"
238
+ if profile.device_preference == "cuda":
239
+ if not torch.cuda.is_available():
240
+ ml_logger.logger.warning(f"CUDA not available, loading {profile.name} on CPU")
241
+ return "cpu"
242
+ return "cuda"
243
+ if profile.device_preference == "auto":
244
+ if torch.cuda.is_available():
245
+ free_vram = self._gpu_budget_mb - self._get_used_vram()
246
+ if free_vram >= profile.vram_mb:
247
+ return "cuda"
248
+ return "cpu"
249
+
250
+ def _get_used_vram(self) -> int:
251
+ return sum(s.profile.vram_mb for s in self._states.values() if s.is_available and s.current_device == "cuda")
252
+
253
+ async def _evict_lru_gpu_model(self, except_model: str) -> str | None:
254
+ gpu_models = [
255
+ (name, state) for name, state in self._states.items()
256
+ if state.is_available and state.current_device == "cuda" and name != except_model
257
+ ]
258
+ if not gpu_models:
259
+ return None
260
+
261
+ lru_name, _ = min(gpu_models, key=lambda x: x[1].last_used or datetime.min.replace(tzinfo=UTC))
262
+ await asyncio.to_thread(self._move_to_cpu, lru_name)
263
+ return lru_name
264
+
265
+ def _move_to_cpu(self, model_name: str):
266
+ state = self._states[model_name]
267
+ if state.model is not None and hasattr(state.model, "cpu"):
268
+ state.model = state.model.cpu()
269
+ torch.cuda.empty_cache()
270
+ state.current_device = "cpu"
271
+ ml_logger.logger.info(f"Moved {model_name} to CPU")
272
+
273
+ def get_status(self) -> dict:
274
+ return {
275
+ "models": {
276
+ name: {
277
+ "is_available": state.is_available,
278
+ "device": state.current_device,
279
+ "load_error": state.load_error,
280
+ "load_time_ms": state.load_time_ms,
281
+ "last_used": state.last_used.isoformat() if state.last_used else None,
282
+ "vram_mb": state.profile.vram_mb if state.current_device == "cuda" else 0
283
+ }
284
+ for name, state in self._states.items()
285
+ },
286
+ "gpu_budget_mb": self._gpu_budget_mb,
287
+ "gpu_used_mb": self._get_used_vram(),
288
+ "gpu_free_mb": self._gpu_budget_mb - self._get_used_vram()
289
+ }
290
+
291
+ model_registry = ModelRegistry()
backend/orchestration/pipeline.py CHANGED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import asyncio
3
+ import logging
4
+ import torch
5
+ from pathlib import Path
6
+ from datetime import datetime, UTC
7
+
8
+ from backend.ml.registry import ModelRegistry
9
+ from backend.core.exceptions import InvalidFileError, InferenceError, ModelNotLoadedError
10
+ from backend.api.v1.schemas.analysis import (
11
+ AnalysisResult, VisionResult, NLPResult, FusionResult, ProcessingTimings
12
+ )
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class AnalysisPipeline:
17
+ def __init__(self, registry: ModelRegistry):
18
+ self.registry = registry
19
+
20
+ async def run(self, session_id: str, image_path: Path, symptoms_text: str) -> AnalysisResult:
21
+ timings = {}
22
+ warnings = []
23
+ vision_result = None
24
+ nlp_result = None
25
+ fusion_result = None
26
+ report_text = None
27
+
28
+ # ── STEP 1: VALIDATE INPUT ────────────────────────────
29
+ t0 = time.monotonic()
30
+ try:
31
+ if not image_path.exists():
32
+ raise InvalidFileError("Image file not found")
33
+ processed_image_path = await asyncio.to_thread(self._preprocess_image, image_path)
34
+ except Exception as e:
35
+ raise InferenceError(f"Input validation failed: {e}")
36
+ timings["preprocess_ms"] = int((time.monotonic() - t0) * 1000)
37
+
38
+ # ── STEP 2: VISION ANALYSIS (GPU) ─────────────────────
39
+ t0 = time.monotonic()
40
+ try:
41
+ # We wrap this in resilience layers in resilience.py
42
+ vision_result = await self._run_vision(processed_image_path)
43
+ except Exception as e:
44
+ logger.warning(f"Vision analysis failed for session {session_id}: {e}")
45
+ warnings.append(f"Vision analysis unavailable: {type(e).__name__}")
46
+ timings["vision_ms"] = int((time.monotonic() - t0) * 1000)
47
+
48
+ # ── STEP 3: VRAM CLEANUP ───────────────────────────────
49
+ if vision_result is not None:
50
+ if torch.cuda.is_available():
51
+ await asyncio.to_thread(torch.cuda.empty_cache)
52
+ await asyncio.sleep(0.1)
53
+
54
+ # ── STEP 4: NLP ANALYSIS (GPU/CPU) ────────────────────
55
+ t0 = time.monotonic()
56
+ if symptoms_text.strip():
57
+ try:
58
+ nlp_result = await self._run_nlp(symptoms_text)
59
+ except Exception as e:
60
+ logger.warning(f"NLP analysis failed for session {session_id}: {e}")
61
+ warnings.append(f"NLP analysis unavailable: {type(e).__name__}")
62
+ else:
63
+ warnings.append("No symptoms text provided — NLP analysis skipped")
64
+ timings["nlp_ms"] = int((time.monotonic() - t0) * 1000)
65
+
66
+ # ── STEP 5: MULTIMODAL FUSION ─────────────────────────
67
+ t0 = time.monotonic()
68
+ if vision_result is not None and nlp_result is not None:
69
+ try:
70
+ fusion_result = await self._run_fusion(processed_image_path, symptoms_text)
71
+ except Exception as e:
72
+ logger.warning(f"Fusion failed for session {session_id}: {e}")
73
+ warnings.append(f"Multimodal fusion unavailable: {type(e).__name__}")
74
+ else:
75
+ warnings.append("Fusion skipped: requires both vision and NLP results")
76
+ timings["fusion_ms"] = int((time.monotonic() - t0) * 1000)
77
+
78
+ # ── STEP 6: REPORT GENERATION ─────────────────────────
79
+ t0 = time.monotonic()
80
+ try:
81
+ report_text = await self._generate_report(vision_result, nlp_result, fusion_result)
82
+ except Exception as e:
83
+ logger.warning(f"Report generation failed: {e}")
84
+ report_text = self._fallback_report(vision_result, nlp_result)
85
+ warnings.append("Using template report — AI report generation unavailable")
86
+ timings["report_ms"] = int((time.monotonic() - t0) * 1000)
87
+
88
+ # ── STEP 7: DETERMINE OVERALL STATUS ──────────────────
89
+ if vision_result is None and nlp_result is None:
90
+ overall_status = "FAILED"
91
+ elif vision_result is None or nlp_result is None:
92
+ overall_status = "PARTIAL"
93
+ else:
94
+ overall_status = "COMPLETE"
95
+
96
+ timings["total_ms"] = sum(timings.values())
97
+
98
+ return AnalysisResult(
99
+ session_id=session_id, patient_id="", timestamp=datetime.now(UTC),
100
+ vision=vision_result, nlp=nlp_result, fusion=fusion_result,
101
+ report_text=report_text, overall_status=overall_status,
102
+ timings=ProcessingTimings(**timings), warnings=warnings
103
+ )
104
+
105
+ def _preprocess_image(self, image_path: Path) -> Path:
106
+ from PIL import Image
107
+ with Image.open(image_path) as img:
108
+ img = img.convert("RGB")
109
+ img = img.resize((224, 224), Image.LANCZOS)
110
+ output_path = image_path.with_suffix(".processed.png")
111
+ img.save(output_path, "PNG")
112
+ return output_path
113
+
114
+ async def _run_vision(self, image_path: Path) -> VisionResult:
115
+ from backend.ml.vision.anomaly import AnomalyDetector
116
+ from backend.ml.vision.gradcam import GradCAM
117
+
118
+ state = await self.registry.get("dino_anomaly")
119
+ if not state.is_available:
120
+ raise ModelNotLoadedError("Vision model unavailable")
121
+
122
+ anomaly_score, model_confidence = await asyncio.to_thread(
123
+ AnomalyDetector.score, image_path, state.model, state.head, state.stats, state.current_device
124
+ )
125
+
126
+ heatmap_b64, top_regions = await asyncio.to_thread(
127
+ GradCAM.generate, image_path, state.model, anomaly_score
128
+ )
129
+
130
+ risk_level = "LOW" if anomaly_score < 40 else "MEDIUM" if anomaly_score < 70 else "HIGH"
131
+ return VisionResult(
132
+ anomaly_score=round(anomaly_score, 1), risk_level=risk_level,
133
+ heatmap_base64=heatmap_b64, top_regions=top_regions, model_confidence=model_confidence
134
+ )
135
+
136
+ async def _run_nlp(self, text: str) -> NLPResult:
137
+ from backend.ml.nlp.ner import NERExtractor
138
+ from backend.ml.nlp.classifier import DiseaseClassifier
139
+
140
+ ner_state = await self.registry.get("biobert_ner")
141
+ entities = await asyncio.to_thread(
142
+ NERExtractor.extract, text, ner_state.model, ner_state.tokenizer, ner_state.current_device
143
+ )
144
+
145
+ diagnosis = await asyncio.to_thread(DiseaseClassifier.classify, text, entities)
146
+ return NLPResult(
147
+ entities=entities, primary_diagnosis=diagnosis["primary"],
148
+ diagnosis_confidence=diagnosis["confidence"], differential=diagnosis["differential"]
149
+ )
150
+
151
+ async def _run_fusion(self, image_path: Path, text: str) -> FusionResult:
152
+ from backend.ml.fusion.medclip import MultimodalFusion
153
+
154
+ state = await self.registry.get("biomedvlp")
155
+ if not state.is_available:
156
+ raise ModelNotLoadedError("Fusion model unavailable")
157
+
158
+ similarity, alignment = await asyncio.to_thread(
159
+ MultimodalFusion.compute_similarity, image_path, text, state.model, state.tokenizer, state.current_device
160
+ )
161
+ final_risk = "HIGH" if similarity < 0.3 else "MEDIUM" if similarity < 0.7 else "LOW"
162
+ return FusionResult(image_text_similarity=round(similarity, 3), alignment=alignment, final_risk=final_risk)
163
+
164
+ async def _generate_report(self, vision: VisionResult | None, nlp: NLPResult | None, fusion: FusionResult | None) -> str:
165
+ from backend.ml.rag.generator import ReportGenerator
166
+
167
+ state = await self.registry.get("biogpt_base")
168
+ if not state.is_available:
169
+ raise ModelNotLoadedError("Report generation unavailable")
170
+
171
+ return await asyncio.to_thread(
172
+ ReportGenerator.generate, vision, nlp, fusion, state.model, state.tokenizer
173
+ )
174
+
175
+ def _fallback_report(self, vision: VisionResult | None, nlp: NLPResult | None) -> str:
176
+ parts = ["## AI-Assisted Analysis Report\n\n*Note: This is an automated template report.*\n"]
177
+ if vision:
178
+ parts.append(f"**Imaging Findings:** Anomaly score of {vision.anomaly_score}/100 indicates {vision.risk_level.lower()} risk findings.")
179
+ if nlp:
180
+ diseases = ", ".join(nlp.entities.diseases) if nlp.entities.diseases else "none identified"
181
+ symptoms = ", ".join(nlp.entities.symptoms) if nlp.entities.symptoms else "none documented"
182
+ parts.append(f"**Clinical Impression:** {nlp.primary_diagnosis} (confidence: {nlp.diagnosis_confidence:.0%}). Identified conditions: {diseases}. Symptoms: {symptoms}.")
183
+ parts.append("\n**Recommendation:** Please consult a licensed physician for diagnosis and treatment.")
184
+ return "\n".join(parts)
backend/orchestration/queue.py CHANGED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import uuid
3
+ import logging
4
+ from pathlib import Path
5
+ from datetime import datetime, UTC
6
+ from sqlalchemy import select, func
7
+
8
+ from backend.db.models import AnalysisTask, AnalysisSession
9
+ from backend.core.exceptions import TaskNotFoundError, SessionAccessDeniedError
10
+ from backend.core.logging_config import ml_logger
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class AnalysisTaskQueue:
    """Database-backed queue that runs analysis tasks with bounded concurrency.

    Tasks are rows of ``AnalysisTask``; a single worker loop claims PENDING
    rows and runs each through the pipeline in its own asyncio task.
    """

    MAX_CONCURRENT = 2
    WORKER_SLEEP_SECONDS = 5
    # A task is marked FAILED once it has failed this many times.
    MAX_ATTEMPTS = 3
    # Rough per-task duration used only for the wait estimate in get_status.
    ESTIMATED_SECONDS_PER_TASK = 45
    # How long stop() waits for in-flight tasks before giving up.
    SHUTDOWN_TIMEOUT_SECONDS = 60

    def __init__(self, db_session_factory, pipeline):
        self._db_factory = db_session_factory
        self._pipeline = pipeline
        self._new_task_event = asyncio.Event()
        self._active_count = 0
        self._active_lock = asyncio.Lock()
        self._worker_task: asyncio.Task | None = None
        # Strong references to in-flight processing tasks; the event loop only
        # keeps weak references, so an unreferenced task may be collected.
        self._running_tasks: set[asyncio.Task] = set()
        self._is_running = False

    async def start(self):
        """Start the background worker loop."""
        self._is_running = True
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info("Task queue worker started")

    async def stop(self):
        """Stop accepting work and wait (bounded) for in-flight tasks."""
        logger.info("Task queue stopping...")
        self._is_running = False
        self._new_task_event.set()

        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.SHUTDOWN_TIMEOUT_SECONDS
        while self._active_count > 0:
            if loop.time() > deadline:
                logger.warning(f"Shutdown timeout: {self._active_count} tasks still running")
                break
            await asyncio.sleep(1)

        if self._worker_task:
            self._worker_task.cancel()
            # Wait for the cancellation to land; gather() turns the worker's
            # CancelledError (or any crash) into a result instead of raising.
            await asyncio.gather(self._worker_task, return_exceptions=True)

    async def submit(self, session_id: str, user_id: str | None, image_path: str, symptoms_text: str, priority: int = 1) -> str:
        """Persist a new PENDING task, wake the worker and return the task id."""
        task_id = str(uuid.uuid4())
        async with self._db_factory() as db:
            task = AnalysisTask(
                id=task_id, session_id=session_id, user_id=user_id,
                status="PENDING", priority=priority,
                image_path=image_path, symptoms_text=symptoms_text
            )
            db.add(task)
            await db.commit()

        self._new_task_event.set()
        return task_id

    async def get_status(self, task_id: str) -> dict:
        """Return status, queue position and a wait estimate for a task.

        Raises:
            TaskNotFoundError: if no task has this id.
        """
        async with self._db_factory() as db:
            task = await db.get(AnalysisTask, task_id)
            if not task:
                raise TaskNotFoundError()

            position = None
            if task.status == "PENDING":
                # Mirror the dequeue order (priority desc, created_at asc): a
                # task is ahead if it has a higher priority, or the same
                # priority and an earlier creation time.
                is_ahead = (AnalysisTask.priority > task.priority) | (
                    (AnalysisTask.priority == task.priority)
                    & (AnalysisTask.created_at < task.created_at)
                )
                result = await db.execute(
                    select(func.count(AnalysisTask.id)).where(
                        AnalysisTask.status == "PENDING",
                        is_ahead,
                    )
                )
                position = result.scalar_one() + 1

            estimated_wait = None
            if position is not None:
                slots_until_ours = max(0, position - self.MAX_CONCURRENT)
                estimated_wait = slots_until_ours * self.ESTIMATED_SECONDS_PER_TASK

            return {
                "task_id": task_id, "session_id": task.session_id, "status": task.status,
                "position_in_queue": position, "estimated_wait_seconds": estimated_wait,
                "started_at": task.started_at, "completed_at": task.completed_at,
                "error_message": task.error_message
            }

    async def cancel(self, task_id: str, user_id: str) -> bool:
        """Cancel a PENDING task owned by ``user_id``.

        Returns False when the task is no longer PENDING.

        Raises:
            TaskNotFoundError: if no task has this id.
            SessionAccessDeniedError: if the task belongs to another user.
        """
        async with self._db_factory() as db:
            task = await db.get(AnalysisTask, task_id)
            if not task:
                raise TaskNotFoundError()
            if task.user_id != user_id:
                raise SessionAccessDeniedError()
            if task.status != "PENDING":
                return False

            task.status = "CANCELLED"
            image_path = Path(task.image_path)
            # Commit first so a failed commit does not leave a still-PENDING
            # task whose image has already been deleted.
            await db.commit()
            image_path.unlink(missing_ok=True)
            return True

    async def _worker_loop(self):
        """Wake on new work (or every WORKER_SLEEP_SECONDS) and dispatch tasks."""
        logger.info("Worker loop started")
        while self._is_running:
            try:
                # No shield: on timeout the inner wait is cancelled rather than
                # left behind as one more pending waiter per iteration.
                await asyncio.wait_for(
                    self._new_task_event.wait(), timeout=self.WORKER_SLEEP_SECONDS
                )
                self._new_task_event.clear()
            except asyncio.TimeoutError:
                pass

            if not self._is_running:
                break

            try:
                await self._dispatch_pending()
            except Exception as e:
                # A transient database error must not kill the worker loop.
                logger.error(f"Task dispatch failed: {e}", exc_info=True)

        logger.info("Worker loop stopped")

    async def _dispatch_pending(self):
        """Claim and start tasks until all slots are used or the queue is empty."""
        while self._active_count < self.MAX_CONCURRENT:
            task = await self._fetch_next_task()
            if not task:
                break
            # Reserve the slot before spawning: the spawned task does not run
            # until this coroutine yields, so counting inside it would let the
            # loop start more than MAX_CONCURRENT tasks.
            async with self._active_lock:
                self._active_count += 1
            runner = asyncio.create_task(self._process_task(task))
            self._running_tasks.add(runner)
            runner.add_done_callback(self._running_tasks.discard)

    async def _fetch_next_task(self) -> AnalysisTask | None:
        """Claim the next PENDING task by marking it PROCESSING.

        The status change is committed in the same transaction that holds the
        row lock, so the row cannot be fetched a second time.
        """
        async with self._db_factory() as db:
            result = await db.execute(
                select(AnalysisTask).where(AnalysisTask.status == "PENDING")
                .order_by(AnalysisTask.priority.desc(), AnalysisTask.created_at.asc())
                .limit(1)
                .with_for_update(skip_locked=True)
            )
            task = result.scalar_one_or_none()
            if task is None:
                return None

            task.status = "PROCESSING"
            task.started_at = datetime.now(UTC)
            await db.commit()
            # The object is used after this session closes; reload it in case
            # the session factory expires attributes on commit.
            # NOTE(review): confirm the factory's expire_on_commit setting.
            await db.refresh(task)
            return task

    async def _process_task(self, task: AnalysisTask):
        """Run one claimed task through the pipeline and persist the outcome.

        The caller has already reserved a concurrency slot and marked the task
        PROCESSING; this method always releases the slot.
        """
        started = datetime.now(UTC)
        keep_image = False

        try:
            result = await self._pipeline.run(
                session_id=task.session_id,
                image_path=Path(task.image_path),
                symptoms_text=task.symptoms_text or ""
            )

            async with self._db_factory() as db:
                db_task = await db.get(AnalysisTask, task.id)
                db_task.status = "COMPLETED"
                db_task.completed_at = datetime.now(UTC)

                session = await db.get(AnalysisSession, task.session_id)
                session.result_json = result.model_dump()
                session.status = "READY"
                session.risk_level = result.vision.risk_level if result.vision else "UNKNOWN"

                await db.commit()

        except Exception as e:
            logger.error(f"Task {task.id} failed: {e}", exc_info=True)
            try:
                keep_image = await self._record_failure(task, e)
            except Exception as record_error:
                logger.error(f"Could not record failure of task {task.id}: {record_error}", exc_info=True)

        else:
            # Logged outside the try body so a logging error cannot send an
            # already completed task down the retry path.
            ml_logger.log_pipeline_step(
                "full_pipeline", "COMPLETED",
                int((datetime.now(UTC) - started).total_seconds() * 1000), task.session_id
            )

        finally:
            # Keep the upload while the task is queued for another attempt;
            # deleting it would make every retry fail on a missing file.
            if not keep_image:
                try:
                    Path(task.image_path).unlink(missing_ok=True)
                except OSError as e:
                    logger.warning(f"Could not delete image for task {task.id}: {e}")
            async with self._active_lock:
                self._active_count -= 1
            self._new_task_event.set()

    async def _record_failure(self, task: AnalysisTask, error: Exception) -> bool:
        """Count a failed attempt; return True if the task was re-queued."""
        async with self._db_factory() as db:
            db_task = await db.get(AnalysisTask, task.id)
            if db_task is None:
                logger.error(f"Task {task.id} no longer exists; cannot record failure")
                return False

            db_task.attempt_count = (db_task.attempt_count or 0) + 1
            will_retry = db_task.attempt_count < self.MAX_ATTEMPTS

            if will_retry:
                db_task.status = "PENDING"
            else:
                db_task.status = "FAILED"
                db_task.error_message = str(error)[:500]
                session = await db.get(AnalysisSession, task.session_id)
                if session is not None:
                    session.status = "FAILED"
                    session.error_message = f"Analysis failed after {self.MAX_ATTEMPTS} attempts"
            await db.commit()

        return will_retry
frontend/app/(auth)/callback/page.jsx ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client'
2
+ import { useEffect, useState } from 'react'
3
+ import { useRouter, useSearchParams } from 'next/navigation'
4
+ import { useAuth } from '@/lib/auth/AuthContext'
5
+ import Link from 'next/link'
6
+
7
+ export default function AuthCallbackPage() {
8
+ const searchParams = useSearchParams()
9
+ const router = useRouter()
10
+ const { loginWithToken } = useAuth()
11
+ const [errorMsg, setErrorMsg] = useState(null)
12
+
13
+ useEffect(() => {
14
+ const processToken = async () => {
15
+ const token = searchParams.get('token')
16
+ const error = searchParams.get('error')
17
+
18
+ if (error) {
19
+ setErrorMsg(`Authentication failed: ${error}`)
20
+ return
21
+ }
22
+
23
+ if (token) {
24
+ try {
25
+ await loginWithToken(token)
26
+ // Clean URL and redirect
27
+ router.replace('/upload')
28
+ } catch (err) {
29
+ setErrorMsg("Failed to establish secure session.")
30
+ }
31
+ } else {
32
+ setErrorMsg("No authentication token provided.")
33
+ }
34
+ }
35
+
36
+ processToken()
37
+ }, [searchParams, loginWithToken, router])
38
+
39
+ if (errorMsg) {
40
+ return (
41
+ <div className="flex flex-col items-center justify-center min-h-screen p-4 text-center">
42
+ <div className="p-6 bg-red-50 border border-red-200 rounded-lg shadow-sm">
43
+ <h2 className="text-lg font-semibold text-red-700 mb-2">Error</h2>
44
+ <p className="text-red-600 mb-4">{errorMsg}</p>
45
+ <Link href="/login" className="px-4 py-2 text-white bg-blue-600 rounded hover:bg-blue-700 transition">
46
+ Return to Login
47
+ </Link>
48
+ </div>
49
+ </div>
50
+ )
51
+ }
52
+
53
+ return (
54
+ <div className="flex flex-col items-center justify-center min-h-screen">
55
+ <div className="w-8 h-8 border-4 border-blue-600 border-t-transparent rounded-full animate-spin mb-4" />
56
+ <p className="text-gray-600 font-medium">Securing session...</p>
57
+ </div>
58
+ )
59
+ }
frontend/components/auth/ProtectedRoute.jsx ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client'
2
+ import { useEffect } from 'react'
3
+ import { useRouter, usePathname } from 'next/navigation'
4
+ import { useAuth } from '@/lib/auth/AuthContext'
5
+
6
+ export default function ProtectedRoute({ children }) {
7
+ const { isAuthenticated, isLoading } = useAuth()
8
+ const router = useRouter()
9
+ const pathname = usePathname()
10
+
11
+ useEffect(() => {
12
+ if (!isLoading && !isAuthenticated) {
13
+ sessionStorage.setItem('intendedPath', pathname)
14
+ router.push('/login')
15
+ }
16
+ }, [isLoading, isAuthenticated, router, pathname])
17
+
18
+ if (isLoading) {
19
+ return <div className="flex items-center justify-center min-h-screen">Loading secure environment...</div>
20
+ }
21
+
22
+ return isAuthenticated ? children : null
23
+ }
frontend/components/auth/PublicOnlyRoute.jsx ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client'
2
+ import { useEffect } from 'react'
3
+ import { useRouter } from 'next/navigation'
4
+ import { useAuth } from '@/lib/auth/AuthContext'
5
+
6
/**
 * Client-side guard for pages that only make sense when signed out
 * (e.g. login / signup).
 *
 * Renders nothing while the auth context is loading. Once loading has
 * finished, authenticated users are sent to `/upload`; unauthenticated
 * users get `children`.
 *
 * @param {{ children: React.ReactNode }} props
 */
export default function PublicOnlyRoute({ children }) {
  const { isAuthenticated, isLoading } = useAuth()
  const router = useRouter()

  useEffect(() => {
    if (isLoading || !isAuthenticated) return

    // replace() instead of push(): otherwise the public page stays in
    // history and pressing Back immediately redirects forward again.
    router.replace('/upload')
  }, [isLoading, isAuthenticated, router])

  if (isLoading) return null

  return isAuthenticated ? null : children
}
frontend/next.config.js ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Origin the browser is allowed to reach via fetch/XHR/WebSocket in addition
// to this site itself; falls back to the local development API.
const apiOrigin = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000';

// Content-Security-Policy directives, in the order they are emitted.
const cspDirectives = {
  "default-src": "'self'",
  "script-src": "'self' 'unsafe-eval' 'unsafe-inline'",
  "style-src": "'self' 'unsafe-inline' https://fonts.googleapis.com",
  "font-src": "'self' https://fonts.gstatic.com",
  "img-src": "'self' data: blob: https:",
  "connect-src": `'self' ${apiOrigin}`,
  "frame-src": "'none'",
  "object-src": "'none'",
  "base-uri": "'self'",
  "form-action": "'self'",
};

// Serialised as "name sources; name sources; ..." with no trailing separator.
const contentSecurityPolicy = Object.entries(cspDirectives)
  .map(([directive, sources]) => `${directive} ${sources}`)
  .join("; ");

/**
 * Response headers applied to every route by the `headers()` hook of the
 * Next.js config.
 *
 * @type {{ key: string, value: string }[]}
 */
const securityHeaders = [
  { key: "Content-Security-Policy", value: contentSecurityPolicy },
  { key: "X-DNS-Prefetch-Control", value: "on" },
  { key: "X-Frame-Options", value: "SAMEORIGIN" },
  { key: "X-Content-Type-Options", value: "nosniff" },
  { key: "Referrer-Policy", value: "strict-origin-when-cross-origin" },
];
24
+
25
// Bare host name of the API, used to whitelist it for next/image.
// The scheme is stripped and the value is cut at the first ':', '/', '?' or
// '#', so a URL that carries a port *or* a path (e.g.
// "https://api.example.com/v1") still yields just "api.example.com".
const apiImageHost = (process.env.NEXT_PUBLIC_API_URL || 'localhost')
  .replace(/^https?:\/\//, '')
  .split(/[:\/?#]/)[0];

/** @type {import('next').NextConfig} */
const nextConfig = {
  output: "standalone",
  experimental: {
    // NOTE(review): newer Next.js releases enable Server Actions by default
    // and warn about this flag — confirm against the installed version
    // before removing it.
    serverActions: true,
  },
  images: {
    domains: [apiImageHost],
  },
  // Attach the security headers to every route.
  async headers() {
    return [
      {
        source: "/(.*)",
        headers: securityHeaders,
      },
    ];
  },
};

module.exports = nextConfig;
scripts/security_audit.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import httpx
3
+ import asyncio
4
+
5
# Base URL of the API instance the audit is run against.
API_URL = "http://localhost:8000"


async def run_audit():
    """Probe the running API for basic security hardening and print a report.

    Two groups of checks are performed against ``{API_URL}/api/v1/health``:

    1. Security-related response headers on a plain GET.
    2. A CORS preflight sent from an untrusted origin, which must not be
       granted access (neither by echoing the origin nor via a wildcard).

    Returns:
        bool: ``True`` when every check passed, ``False`` otherwise.

    Raises:
        SystemExit: with status 1 when the server cannot be reached.
    """
    print("\n--- MedSight AI Security Audit ---\n")
    passed = 0
    total = 0

    def check(name, condition, error_msg):
        """Record one check result and print a PASS/FAIL line for it."""
        nonlocal passed, total
        total += 1
        if condition:
            print(f"✅ PASS: {name}")
            passed += 1
        else:
            print(f"❌ FAIL: {name} - {error_msg}")

    untrusted_origin = "http://evil-domain.com"

    try:
        async with httpx.AsyncClient() as client:
            # 1. Test Health / Headers
            res = await client.get(f"{API_URL}/api/v1/health")
            headers = res.headers
            check(
                "X-Content-Type-Options",
                headers.get("x-content-type-options") == "nosniff",
                "Header missing or incorrect",
            )
            check(
                "X-Frame-Options",
                headers.get("x-frame-options") == "DENY",
                "Header missing or incorrect",
            )
            check(
                "Permissions-Policy",
                "microphone=(self)" in headers.get("permissions-policy", ""),
                "Microphone permission misconfigured",
            )
            check(
                "No Server Header",
                "server" not in headers,
                "Server header is leaking framework info",
            )

            # 2. Test CORS Rejection
            cors_res = await client.options(
                f"{API_URL}/api/v1/health",
                headers={
                    "Origin": untrusted_origin,
                    "Access-Control-Request-Method": "GET",
                },
            )
            # Depending on the server's CORS config the preflight may come
            # back without CORS headers or with an error status. Either is
            # fine; what must NOT happen is the untrusted origin being
            # allowed — by echoing it back or through the "*" wildcard,
            # which admits every origin.
            allowed_origin = cors_res.headers.get("access-control-allow-origin")
            check(
                "CORS Restrictions",
                allowed_origin not in (untrusted_origin, "*"),
                "CORS allowed untrusted origin",
            )

    except (httpx.ConnectError, httpx.ConnectTimeout):
        # Refused connection or connect timeout: nothing to audit.
        print("❌ Server not running. Start it with: make run-dev")
        sys.exit(1)

    print(f"\nAudit Complete: {passed}/{total} Passed\n")
    return passed == total


if __name__ == "__main__":
    # Propagate the outcome as the exit status so CI / make targets fail
    # when any check fails instead of always reporting success.
    sys.exit(0 if asyncio.run(run_audit()) else 1)