File size: 10,994 Bytes
f01ec67
 
 
 
 
 
af005a8
f01ec67
 
 
 
 
 
 
 
 
 
 
 
af005a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f01ec67
 
 
af005a8
f01ec67
 
 
 
 
af005a8
f01ec67
 
af005a8
f01ec67
 
af005a8
28fc2e0
 
 
 
 
 
 
af005a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f01ec67
af005a8
f01ec67
af005a8
f01ec67
 
af005a8
 
 
 
 
 
 
 
 
 
 
f01ec67
 
 
af005a8
f01ec67
 
 
 
 
 
 
 
 
 
af005a8
f01ec67
af005a8
 
f01ec67
 
 
 
 
af005a8
 
 
 
 
f01ec67
 
 
 
 
1f3192e
f01ec67
 
 
 
 
 
 
 
af005a8
 
f01ec67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af005a8
 
 
 
 
f01ec67
 
 
 
af005a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f01ec67
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from contextlib import asynccontextmanager
from datetime import datetime, timezone
import logging
import uvicorn

from backend.core.config import settings, startup_validation
from backend.core.logging_config import setup_logging
from backend.core.exceptions import MedSightException
from backend.core.middleware import (
    RequestIDMiddleware, SecurityHeadersMiddleware, AccessLogMiddleware, rate_limit_handler, limiter
)
from backend.db.session import init_db, engine
from slowapi.errors import RateLimitExceeded

# ═══ Router Imports ═══
from backend.api.v1.routers.auth import router as auth_router
from backend.api.v1.routers.users import router as users_router

# ML-dependent routers: import gracefully (they need torch, transformers, etc.)
# A failed import is recorded in the two names below instead of aborting
# application start-up.
_ml_routers_loaded = True
# Always bound (None on success) so that code reading it can never hit a
# NameError; holds the failure text when the routers could not be imported.
_ml_import_error = None
try:
    from backend.api.v1.routers.analyze import router as analyze_router
    from backend.api.v1.routers.chat import router as chat_router
    from backend.api.v1.routers.report import router as report_router
except (ImportError, OSError) as e:
    _ml_routers_loaded = False
    _ml_import_error = str(e)

# ML system imports (optional — graceful degradation)
_ml_system_loaded = True
# Same convention as _ml_import_error: None unless the import below failed.
_ml_system_error = None
try:
    from backend.ml.registry import ModelRegistry
    from backend.orchestration.queue import task_queue
    from backend.orchestration.scheduler import start_scheduler, stop_scheduler
except (ImportError, OSError) as e:
    _ml_system_loaded = False
    _ml_system_error = str(e)

# Module-level logger used by the lifespan hook and the exception handlers.
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application start-up, yield while the app serves, then clean up.

    Start-up configures logging, validates configuration, initialises the
    database, attempts the HuggingFace model download, creates the ML model
    registry (the heavy ``startup_load`` runs in a background task so it does
    not block start-up), and starts the task queue and the scheduler. Every
    ML-related step is best-effort: a failure is logged and the API still
    starts.

    Shutdown cancels a still-running model-loading task, stops the task queue
    and the scheduler (each only if it was actually started), and disposes the
    database engine. The shutdown code runs in a ``finally`` block so it is not
    skipped when an exception is thrown into the generator at ``yield``.

    Args:
        app: The FastAPI application. ``app.state`` receives
            ``model_registry`` (``None`` when unavailable), ``task_queue``
            (only when the queue started), ``start_time`` and ``is_ready``.
    """
    import asyncio

    # Handles for what was actually started, so shutdown only undoes those.
    model_load_task = None
    queue_started = False
    scheduler_started = False

    # ── STARTUP ──────────────────────────────────────────
    setup_logging()
    logger.info("🚀 MedSight AI starting up...")

    startup_validation()
    logger.info("✅ Configuration validated")

    await init_db()
    logger.info("✅ Database initialized")

    # Auto-download models from HuggingFace if not present locally
    try:
        from backend.ml.vision.hf_download import ensure_models
        ensure_models()
    except Exception as e:
        logger.warning(f"⚠️ HuggingFace model download skipped: {e}")

    # ML model registry - Load in background to prevent blocking the API
    if _ml_system_loaded:
        try:
            registry = ModelRegistry()
            app.state.model_registry = registry
            task_queue._pipeline.registry = registry

            # Heavy loading happens in a background task.
            async def load_models():
                if not settings.LOAD_ML_MODELS:
                    logger.info("⏩ Skipping ML models loading (LOAD_ML_MODELS=false)")
                    return
                try:
                    await registry.startup_load()
                    logger.info("✅ ML models background loading complete")
                except Exception as e:
                    logger.error(f"❌ ML background loading failed: {e}")

            # Keep a strong reference: the event loop only holds weak
            # references to tasks, so a discarded task may be garbage
            # collected before it finishes.
            model_load_task = asyncio.create_task(load_models())
            logger.info("🚀 ML models loading initiated in background")
        except Exception as e:
            logger.warning(f"⚠️ ML registry initialization failed: {e}")
            app.state.model_registry = None
    else:
        logger.warning(f"⚠️ ML system not available: {_ml_system_error}")
        app.state.model_registry = None

    # Task queue
    if _ml_system_loaded:
        try:
            await task_queue.start()
            app.state.task_queue = task_queue
            queue_started = True
            logger.info("✅ Task queue started")
        except Exception as e:
            logger.warning(f"⚠️ Task queue failed to start: {e}")
    else:
        logger.warning("⚠️ Task queue not available (ML dependencies missing)")

    # Scheduler
    if _ml_system_loaded:
        try:
            start_scheduler()
            scheduler_started = True
            logger.info("✅ Scheduler started")
        except Exception as e:
            logger.warning(f"⚠️ Scheduler failed to start: {e}")

    app.state.start_time = datetime.now(timezone.utc)
    app.state.is_ready = True
    logger.info("🟢 MedSight AI ready to serve requests")

    try:
        yield  # ── APP RUNS HERE ──
    finally:
        # ── SHUTDOWN ─────────────────────────────────────────
        logger.info("⏳ MedSight AI shutting down...")

        # Do not let model loading outlive the resources disposed below.
        if model_load_task is not None and not model_load_task.done():
            model_load_task.cancel()
            try:
                await model_load_task
            except asyncio.CancelledError:
                pass

        # The queue and the scheduler are stopped independently: a queue that
        # failed to start must not prevent the scheduler from being stopped.
        if queue_started:
            try:
                await task_queue.stop()
            except Exception as e:
                logger.warning(f"⚠️ Task queue failed to stop cleanly: {e}")

        if scheduler_started:
            try:
                stop_scheduler()
            except Exception as e:
                logger.warning(f"⚠️ Scheduler failed to stop cleanly: {e}")

        await engine.dispose()
        logger.info("✅ Graceful shutdown complete")


# Application instance. The interactive documentation endpoints are switched
# off (URL set to None) when settings.is_production is true.
app = FastAPI(
    title="MedSight AI",
    description="Multimodal Medical Diagnostic Platform powered by deep learning",
    version=settings.VERSION,
    docs_url=None if settings.is_production else "/docs",
    redoc_url=None if settings.is_production else "/redoc",
    lifespan=lifespan,
    # One entry per tag used by the routes and routers registered in this file.
    openapi_tags=[
        {"name": tag_name, "description": tag_description}
        for tag_name, tag_description in (
            ("Health", "System health checks"),
            ("Authentication", "Login, registration, and tokens"),
            ("Users", "User profiles and session history"),
            ("Analysis", "Multimodal analysis endpoints"),
            ("Chat", "RAG-powered medical Q&A"),
            ("Reports", "PDF report generation"),
        )
    ],
)

# Expose the shared limiter on the application state.
# NOTE(review): presumably this is where slowapi looks the limiter up — confirm.
app.state.limiter = limiter

# ═══ MIDDLEWARE LAYER ═══
# NOTE(review): registration order matters here. Starlette is documented to
# place the most recently added middleware outermost (it sees the request
# first) — confirm the order below is intended before reordering anything.
# SessionMiddleware is required by authlib for OAuth state management
from starlette.middleware.sessions import SessionMiddleware
# The session middleware is given the application's SECRET_KEY setting.
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
# Allowed Host values come from settings.TRUSTED_HOSTS.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=settings.TRUSTED_HOSTS)
# Project middleware imported from backend.core.middleware.
app.add_middleware(RequestIDMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
# CORS: origins (explicit list and regex) come from settings; credentials are
# allowed, and the methods, request headers and exposed response headers are
# restricted to the lists below.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.ALLOWED_ORIGINS,
    allow_origin_regex=settings.ALLOWED_ORIGIN_REGEX,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization", "X-API-Key", "X-Request-ID", "Accept"],
    expose_headers=["X-Request-ID", "X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"],
    # 86400 s = 24 h; presumably the preflight cache lifetime — confirm.
    max_age=86400,
)
# Added last. NOTE(review): if the ordering note above holds, this makes the
# access log the outermost layer — confirm.
app.add_middleware(AccessLogMiddleware)


# ═══ EXCEPTION HANDLERS ═══
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Convert a request-validation failure into a 422 JSON error response.

    The payload carries ``error_code``, ``message`` and ``request_id`` like
    the other exception handlers in this module, plus ``details``: one
    ``{"field", "message"}`` entry per validation error, where ``field`` is
    the error location joined with dots.

    Args:
        request: The request that failed validation; its ``state.request_id``
            (if set) is echoed back to the client.
        exc: The validation error raised by FastAPI.

    Returns:
        A ``JSONResponse`` with status code 422.
    """
    details = [
        {
            # "loc" may mix strings and integers (e.g. list indexes).
            "field": ".".join(str(part) for part in error["loc"]),
            "message": error["msg"],
        }
        for error in exc.errors()
    ]
    return JSONResponse(
        status_code=422,
        content={
            "error_code": "VALIDATION_ERROR",
            "message": "Request validation failed",
            "details": details,
            "request_id": getattr(request.state, "request_id", None),
        },
    )

@app.exception_handler(MedSightException)
async def medsight_exception_handler(request: Request, exc: MedSightException):
    """Render a ``MedSightException`` as a JSON error response.

    The HTTP status, error code and message are taken from the exception;
    the request id is read from ``request.state`` and is ``None`` when it has
    not been set.
    """
    status = exc.status_code
    payload = {
        "error_code": exc.error_code,
        "message": exc.message,
        "request_id": getattr(request.state, "request_id", None),
    }
    return JSONResponse(content=payload, status_code=status)

@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the exception and answer with a generic 500.

    The exception text and traceback go to the log only; the client receives
    a fixed message together with the request id (``None`` when unset).
    """
    logger.error(f"Unhandled Exception: {exc}", exc_info=True)
    body = {
        "error_code": "INTERNAL_ERROR",
        "message": "An unexpected error occurred",
        "request_id": getattr(request.state, "request_id", None),
    }
    return JSONResponse(content=body, status_code=500)

# Rate-limit violations (slowapi's RateLimitExceeded) are delegated to the
# rate_limit_handler imported from backend.core.middleware.
app.add_exception_handler(RateLimitExceeded, rate_limit_handler)


# ═══════════════════════════════════════════════════════════════════════
# ROUTE REGISTRATION
# ═══════════════════════════════════════════════════════════════════════

# Root
@app.get("/")
async def root():
    # Basic service metadata. "docs" reports the application's actual docs URL
    # (app.docs_url), which is None when the interactive docs are disabled in
    # production, instead of always advertising "/docs".
    return {
        "name": "MedSight AI",
        "version": settings.VERSION,
        "status": "operational",
        "docs": app.docs_url,
    }

# Health (always inline — no heavy deps)
@app.get("/api/v1/health", tags=["Health"])
async def health_check():
    """Health check endpoint for frontend status monitoring."""
    state = app.state

    # Uptime is only known once the lifespan hook has recorded start_time.
    uptime = (
        (datetime.now(timezone.utc) - state.start_time).total_seconds()
        if hasattr(state, "start_time")
        else None
    )

    # ML counts as available only if the ML modules imported and a registry
    # object was stored on the application state.
    registry = getattr(state, "model_registry", None)
    ml_available = _ml_system_loaded and registry is not None

    return {
        "status": "healthy",
        "version": settings.VERSION,
        "environment": settings.ENVIRONMENT,
        "uptime_seconds": uptime,
        "ml_available": ml_available,
    }

# Auth (real router — no ML deps)
app.include_router(auth_router, prefix="/api/v1/auth", tags=["Authentication"])

# Users (real router — no ML deps)
app.include_router(users_router, prefix="/api/v1/users", tags=["Users"])

# ML-dependent routers: real if loaded, stubs if not
if _ml_routers_loaded:
    app.include_router(analyze_router, prefix="/api/v1/analyze", tags=["Analysis"])
    app.include_router(chat_router, prefix="/api/v1/chat", tags=["Chat"])
    app.include_router(report_router, prefix="/api/v1/report", tags=["Reports"])
    logger.info("✅ All ML routers registered")
else:
    # Make the degraded mode visible in the logs: without this, a failed ML
    # import only shows up as 503 responses at request time.
    logger.warning(f"⚠️ ML routers unavailable, registering 503 fallback stubs: {_ml_import_error}")

    # Fallback stubs for ML endpoints when dependencies are missing.
    # Each answers 503 with error_code "ML_UNAVAILABLE".
    @app.post("/api/v1/analyze", tags=["Analysis"])
    async def analyze_fallback():
        return JSONResponse(status_code=503, content={
            "error_code": "ML_UNAVAILABLE",
            "message": f"ML pipeline not available: {_ml_import_error}"
        })

    @app.get("/api/v1/analyze/status/{task_id}", tags=["Analysis"])
    async def analyze_status_fallback(task_id: str):
        return JSONResponse(status_code=503, content={"error_code": "ML_UNAVAILABLE"})

    @app.get("/api/v1/analyze/result/{session_id}", tags=["Analysis"])
    async def analyze_result_fallback(session_id: str):
        return JSONResponse(status_code=503, content={"error_code": "ML_UNAVAILABLE"})

    @app.post("/api/v1/chat", tags=["Chat"])
    async def chat_fallback():
        return JSONResponse(status_code=503, content={"error_code": "ML_UNAVAILABLE"})

    @app.get("/api/v1/report/{session_id}", tags=["Reports"])
    async def report_fallback(session_id: str):
        return JSONResponse(status_code=503, content={"error_code": "ML_UNAVAILABLE"})


if __name__ == "__main__":
    # Direct launch: serve the app on all interfaces, port 8000.
    # Auto-reload follows settings.DEBUG.
    # NOTE(review): log_config=None / access_log=False presumably leave
    # logging to setup_logging() and AccessLogMiddleware — confirm.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": settings.DEBUG,
        "log_config": None,
        "access_log": False,
    }
    uvicorn.run("backend.main:app", **server_options)