hoshikrana commited on
Commit
27826d5
·
1 Parent(s): 47dce21

feat: security, JWT, API keys, validation, and auth routes

Browse files
alembic.ini DELETED
@@ -1,41 +0,0 @@
1
- [alembic]
2
- script_location = backend/db/migrations
3
- prepend_sys_path = .
4
- version_path_separator = os # Use os.linesep
5
-
6
- [post_write_hooks]
7
- # Post-write hooks to format auto-generated migrations
8
-
9
- [loggers]
10
- keys = root,sqlalchemy,alembic
11
-
12
- [handlers]
13
- keys = console
14
-
15
- [formatters]
16
- keys = generic
17
-
18
- [logger_root]
19
- level = WARN
20
- handlers = console
21
- qualname =
22
-
23
- [logger_sqlalchemy]
24
- level = WARN
25
- handlers =
26
- qualname = sqlalchemy.engine
27
-
28
- [logger_alembic]
29
- level = INFO
30
- handlers =
31
- qualname = alembic
32
-
33
- [handler_console]
34
- class = StreamHandler
35
- args = (sys.stderr,)
36
- level = NOTSET
37
- formatter = generic
38
-
39
- [formatter_generic]
40
- format = %(levelname)-5.5s [%(name)s] %(message)s
41
- datefmt = %H:%M:%S
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/api/v1/routers/auth.py CHANGED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from fastapi import APIRouter, Request, Depends, BackgroundTasks
3
+ from fastapi.responses import JSONResponse, RedirectResponse
4
+ from fastapi.security import OAuth2PasswordRequestForm
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+ from authlib.integrations.starlette_client import OAuth, OAuthError
7
+
8
+ from backend.db.session import get_db
9
+ from backend.db.models import User
10
+ from backend.db.utils import exists, get_or_404
11
+ from backend.core.middleware import limiter
12
+ from backend.core.security import (
13
+ validate_password_strength, hash_password, verify_password,
14
+ create_access_token, create_refresh_token, set_refresh_cookie, clear_auth_cookies,
15
+ verify_token, get_refresh_token_from_cookie, blacklist_token, generate_verification_token
16
+ )
17
+ from backend.core.brute_force import brute_force_protector
18
+ from backend.core.dependencies import get_client_ip, get_current_user
19
+ from backend.core.exceptions import (
20
+ ValidationError, EmailAlreadyExistsError, AuthenticationError, AccountInactiveError
21
+ )
22
+ from backend.api.v1.schemas.auth import (
23
+ RegisterRequest, TokenResponse, UserResponse, AuthResponse, MessageResponse
24
+ )
25
+ from backend.core.config import settings
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ oauth = OAuth()
30
+ oauth.register(
31
+ name="google",
32
+ server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
33
+ client_id=settings.GOOGLE_CLIENT_ID,
34
+ client_secret=settings.GOOGLE_CLIENT_SECRET,
35
+ client_kwargs={"scope": "openid email profile"},
36
+ )
37
+
38
+ router = APIRouter()
39
+
40
+ # --- Google OAuth Segment ---
41
+ @router.get("/google/login")
42
+ async def google_login(request: Request):
43
+ """Redirects to Google consent screen."""
44
+ redirect_uri = settings.GOOGLE_REDIRECT_URI
45
+ return await oauth.google.authorize_redirect(request, redirect_uri)
46
+
47
+ @router.get("/google/callback")
48
+ async def google_callback(request: Request, db: AsyncSession = Depends(get_db)):
49
+ """Handles Google OAuth callback."""
50
+ # Step 1: Exchange code for token
51
+ try:
52
+ google_token = await oauth.google.authorize_access_token(request)
53
+ except OAuthError as e:
54
+ raise AuthenticationError(f"Google OAuth failed: {e.description}")
55
+
56
+ # Step 2: Get user info
57
+ user_info = google_token.get("userinfo")
58
+ if not user_info or not user_info.get("email_verified"):
59
+ raise AuthenticationError("Google email not verified")
60
+
61
+ google_id = user_info.get("sub")
62
+ email = user_info.get("email")
63
+ full_name = user_info.get("name", "Google User")
64
+ picture = user_info.get("picture")
65
+
66
+ # Step 3: Find or create user
67
+ user = await User.get_by_google_id(db, google_id)
68
+ if not user:
69
+ user = await User.get_by_email(db, email)
70
+ if user:
71
+ # Link existing account to Google
72
+ user.google_id = google_id
73
+ user.profile_picture_url = picture
74
+ user.is_active = True
75
+ else:
76
+ # Create new user
77
+ user = User(
78
+ email=email, full_name=full_name,
79
+ google_id=google_id, profile_picture_url=picture,
80
+ is_active=True, is_verified=True
81
+ )
82
+ db.add(user)
83
+
84
+ await db.commit()
85
+ await db.refresh(user)
86
+
87
+ # Step 4: Issue tokens
88
+ access_token = create_access_token(user.id)
89
+ refresh_token = create_refresh_token(user.id)
90
+
91
+ # Step 5: Redirect to frontend
92
+ response = RedirectResponse(
93
+ url=f"{settings.FRONTEND_URL}/auth/callback?token={access_token}"
94
+ )
95
+ set_refresh_cookie(response, refresh_token)
96
+ return response
97
+
98
+ # --- Email/Password Segment ---
99
+ async def send_verification_email(email: str, token: str):
100
+ # TODO: Implement actual email sending logic
101
+ logger.info(f"MOCK EMAIL to {email}: Your verification token is {token}")
102
+
103
+ async def update_login_stats(user_id: str, ip: str):
104
+ # TODO: Update last_login_at in DB
105
+ pass
106
+
107
+ @router.post("/register", response_model=MessageResponse)
108
+ @limiter.limit("3/hour")
109
+ async def register(
110
+ request: Request,
111
+ body: RegisterRequest,
112
+ background_tasks: BackgroundTasks,
113
+ db: AsyncSession = Depends(get_db)
114
+ ):
115
+ failures = validate_password_strength(body.password)
116
+ if failures:
117
+ raise ValidationError(f"Password too weak: {'; '.join(failures)}")
118
+
119
+ if await exists(db, User, email=body.email):
120
+ raise EmailAlreadyExistsError()
121
+
122
+ user = User(
123
+ email=body.email,
124
+ full_name=body.full_name,
125
+ hashed_password=hash_password(body.password),
126
+ is_active=False # Requires email verification
127
+ )
128
+ db.add(user)
129
+ await db.commit()
130
+ await db.refresh(user)
131
+
132
+ plain_token, token_hash = generate_verification_token()
133
+ background_tasks.add_task(send_verification_email, user.email, plain_token)
134
+
135
+ logger.info("User registered", extra={"user_id": str(user.id), "email": user.email})
136
+ return MessageResponse(message="Account created. Check your email to verify.")
137
+
138
+ @router.post("/login", response_model=AuthResponse)
139
+ @limiter.limit("5/minute")
140
+ async def login(
141
+ request: Request,
142
+ background_tasks: BackgroundTasks,
143
+ form_data: OAuth2PasswordRequestForm = Depends(),
144
+ db: AsyncSession = Depends(get_db),
145
+ client_ip: str = Depends(get_client_ip)
146
+ ):
147
+ await brute_force_protector.check_and_record_failure(client_ip)
148
+
149
+ user = await User.get_by_email(db, form_data.username)
150
+ if not user or not user.hashed_password or not verify_password(form_data.password, user.hashed_password):
151
+ await brute_force_protector.check_and_record_failure(client_ip)
152
+ raise AuthenticationError("Invalid email or password")
153
+
154
+ if not user.is_active:
155
+ raise AccountInactiveError("Please verify your email before logging in")
156
+
157
+ brute_force_protector.record_success(client_ip)
158
+
159
+ access_token = create_access_token(user.id)
160
+ refresh_token = create_refresh_token(user.id)
161
+
162
+ response = JSONResponse(content=AuthResponse(
163
+ token=TokenResponse(access_token=access_token, token_type="bearer",
164
+ expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60),
165
+ user=UserResponse.model_validate(user)
166
+ ).model_dump())
167
+
168
+ set_refresh_cookie(response, refresh_token)
169
+ background_tasks.add_task(update_login_stats, user.id, client_ip)
170
+
171
+ logger.info("User logged in", extra={"user_id": str(user.id)})
172
+ return response
173
+
174
+ @router.post("/logout")
175
+ async def logout(
176
+ request: Request,
177
+ current_user: User = Depends(get_current_user)
178
+ ):
179
+ auth_header = request.headers.get("Authorization", "")
180
+ token = auth_header.replace("Bearer ", "")
181
+ if token:
182
+ try:
183
+ payload = verify_token(token, "access")
184
+ await blacklist_token(payload.jti, payload.exp)
185
+ except Exception:
186
+ pass # Ignore errors on logout
187
+
188
+ response = JSONResponse(content={"message": "Logged out successfully"})
189
+ clear_auth_cookies(response)
190
+ logger.info("User logged out", extra={"user_id": str(current_user.id)})
191
+ return response
192
+
193
+ @router.post("/refresh", response_model=TokenResponse)
194
+ async def refresh(request: Request, db: AsyncSession = Depends(get_db)):
195
+ raw_refresh = get_refresh_token_from_cookie(request)
196
+ payload = verify_token(raw_refresh, "refresh")
197
+
198
+ user = await get_or_404(db, User, payload.sub)
199
+ if not user.is_active:
200
+ raise AccountInactiveError()
201
+
202
+ new_access_token = create_access_token(user.id)
203
+ new_refresh_token = create_refresh_token(user.id)
204
+
205
+ await blacklist_token(payload.jti, payload.exp)
206
+
207
+ response = JSONResponse(content=TokenResponse(
208
+ access_token=new_access_token, token_type="bearer",
209
+ expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
210
+ ).model_dump())
211
+ set_refresh_cookie(response, new_refresh_token)
212
+ return response
213
+
214
+ @router.get("/me", response_model=UserResponse)
215
+ async def get_me(current_user: User = Depends(get_current_user)):
216
+ return UserResponse.model_validate(current_user)
backend/core/api_keys.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import secrets
3
+ from datetime import datetime, UTC
4
+ from sqlalchemy import select
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+ from backend.db.models import APIKey
7
+
8
+ def generate_api_key() -> tuple[str, str]:
9
+ """Returns (plain_key, hashed_key)"""
10
+ prefix = "ms_live"
11
+ random_part = secrets.token_hex(32)
12
+ plain_key = f"{prefix}_{random_part}"
13
+ hashed = hashlib.sha256(plain_key.encode()).hexdigest()
14
+ return plain_key, hashed
15
+
16
+ def hash_api_key(plain_key: str) -> str:
17
+ return hashlib.sha256(plain_key.encode()).hexdigest()
18
+
19
+ async def verify_api_key(plain_key: str, db: AsyncSession) -> APIKey | None:
20
+ hashed = hash_api_key(plain_key)
21
+ result = await db.execute(
22
+ select(APIKey).where(APIKey.key_hash == hashed, APIKey.is_active == True)
23
+ )
24
+ api_key = result.scalar_one_or_none()
25
+
26
+ if not api_key:
27
+ return None
28
+ if api_key.expires_at and api_key.expires_at < datetime.now(UTC).replace(tzinfo=None):
29
+ return None
30
+
31
+ return api_key
32
+
33
+ class APIKeyRateLimiter:
34
+ """In-memory rate limiter for API Keys."""
35
+ def __init__(self):
36
+ self._counters: dict[str, dict] = {}
37
+
38
+ async def check_and_increment(self, key_id: str, limit: int) -> bool:
39
+ current_hour = datetime.now(UTC).strftime("%Y-%m-%dT%H")
40
+
41
+ if key_id not in self._counters or self._counters[key_id]["hour"] != current_hour:
42
+ self._counters[key_id] = {"hour": current_hour, "count": 0}
43
+
44
+ if self._counters[key_id]["count"] >= limit:
45
+ return False
46
+
47
+ self._counters[key_id]["count"] += 1
48
+ return True
49
+
50
+ def get_remaining(self, key_id: str, limit: int) -> int:
51
+ current_hour = datetime.now(UTC).strftime("%Y-%m-%dT%H")
52
+ if key_id not in self._counters or self._counters[key_id]["hour"] != current_hour:
53
+ return limit
54
+ return max(0, limit - self._counters[key_id]["count"])
55
+
56
+ api_key_limiter = APIKeyRateLimiter()
backend/core/brute_force.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from datetime import datetime, timedelta, UTC
4
+ from backend.core.exceptions import AccountLockedError
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ class BruteForceProtector:
9
+ """In-memory brute force protection for login attempts"""
10
+
11
+ THRESHOLDS = [
12
+ (3, 30), # 3 failures -> 30 second lockout
13
+ (5, 300), # 5 failures -> 5 minute lockout
14
+ (10, 3600), # 10 failures -> 1 hour lockout
15
+ ]
16
+
17
+ def __init__(self):
18
+ self._attempts: dict[str, list[datetime]] = {}
19
+ self._lockouts: dict[str, datetime] = {}
20
+
21
+ async def check_and_record_failure(self, ip: str):
22
+ now = datetime.now(UTC)
23
+
24
+ if ip in self._lockouts and self._lockouts[ip] > now:
25
+ remaining = int((self._lockouts[ip] - now).total_seconds())
26
+ raise AccountLockedError(f"Too many failed attempts. Try in {remaining}s")
27
+
28
+ if ip not in self._attempts:
29
+ self._attempts[ip] = []
30
+
31
+ # Clean old attempts
32
+ self._attempts[ip] = [t for t in self._attempts[ip] if (now - t).seconds < 3600]
33
+ self._attempts[ip].append(now)
34
+
35
+ count = len(self._attempts[ip])
36
+ for threshold, lockout_seconds in reversed(self.THRESHOLDS):
37
+ if count >= threshold:
38
+ self._lockouts[ip] = now + timedelta(seconds=lockout_seconds)
39
+ logger.warning(f"Login lockout: {ip}, attempts: {count}, lockout: {lockout_seconds}s")
40
+ raise AccountLockedError(f"Account locked for {lockout_seconds}s")
41
+
42
+ def record_success(self, ip: str):
43
+ self._attempts.pop(ip, None)
44
+ self._lockouts.pop(ip, None)
45
+
46
+ def is_locked(self, ip: str) -> bool:
47
+ return ip in self._lockouts and self._lockouts[ip] > datetime.now(UTC)
48
+
49
+ async def cleanup(self):
50
+ now = datetime.now(UTC)
51
+ expired = [ip for ip, until in self._lockouts.items() if until < now]
52
+ for ip in expired:
53
+ del self._lockouts[ip]
54
+ self._attempts.pop(ip, None)
55
+
56
+ brute_force_protector = BruteForceProtector()
backend/core/middleware.py CHANGED
@@ -1,26 +1,33 @@
1
- from fastapi import Request
2
- from starlette.middleware.base import BaseHTTPMiddleware
3
- from slowapi import Limiter, _rate_limit_exceeded_handler
4
- from slowapi.util import get_remote_address
5
 
6
- # Setup rate limiter
7
- limiter = Limiter(key_func=get_remote_address)
8
- rate_limit_handler = _rate_limit_exceeded_handler
 
 
 
 
 
 
9
 
10
- class RequestIDMiddleware(BaseHTTPMiddleware):
11
- async def dispatch(self, request: Request, call_next):
12
- # Stub implementation
13
- response = await call_next(request)
14
- return response
15
 
16
- class SecurityHeadersMiddleware(BaseHTTPMiddleware):
17
- async def dispatch(self, request: Request, call_next):
18
- # Stub implementation
19
- response = await call_next(request)
20
- return response
21
-
22
- class AccessLogMiddleware(BaseHTTPMiddleware):
23
- async def dispatch(self, request: Request, call_next):
24
- # Stub implementation
25
- response = await call_next(request)
26
- return response
 
 
1
+ from fastapi import Request, Response
2
+ from fastapi.responses import JSONResponse
3
+ from slowapi import Limiter
4
+ from slowapi.errors import RateLimitExceeded
5
 
6
+ def get_rate_limit_key(request: Request) -> str:
7
+ user_id = getattr(request.state, "user_id", None)
8
+ if user_id:
9
+ return f"user:{user_id}"
10
+
11
+ forwarded_for = request.headers.get("X-Forwarded-For")
12
+ if forwarded_for:
13
+ return f"ip:{forwarded_for.split(',')[0].strip()}"
14
+ return f"ip:{request.client.host}"
15
 
16
+ limiter = Limiter(
17
+ key_func=get_rate_limit_key,
18
+ default_limits=["1000/minute"],
19
+ storage_uri="memory://"
20
+ )
21
 
22
+ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> Response:
23
+ retry_after = exc.retry_after if hasattr(exc, "retry_after") else 60
24
+ return JSONResponse(
25
+ status_code=429,
26
+ content={
27
+ "error_code": "RATE_LIMIT_EXCEEDED",
28
+ "message": f"Too many requests. Try again in {retry_after} seconds.",
29
+ "retry_after_seconds": retry_after,
30
+ "limit": str(exc.limit)
31
+ },
32
+ headers={"Retry-After": str(retry_after)}
33
+ )
backend/core/security.py CHANGED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import secrets
3
+ import asyncio
4
+ import re
5
+ from datetime import datetime, timedelta, UTC
6
+ from typing import Literal, Tuple, List, Dict
7
+ from jose import jwt, JWTError, ExpiredSignatureError
8
+ from passlib.context import CryptContext
9
+ from pydantic import BaseModel
10
+ from fastapi import Response, Request
11
+
12
+ from backend.core.config import settings
13
+ from backend.core.exceptions import (
14
+ ExpiredTokenError, InvalidTokenError, BlacklistedTokenError, AuthenticationError
15
+ )
16
+
17
+ # 1. Password Context
18
+ pwd_context = CryptContext(
19
+ schemes=["bcrypt"],
20
+ deprecated="auto",
21
+ bcrypt__rounds=12
22
+ )
23
+
24
+ def hash_password(plain: str) -> str:
25
+ return pwd_context.hash(plain)
26
+
27
+ def verify_password(plain: str, hashed: str) -> bool:
28
+ return pwd_context.verify(plain, hashed)
29
+
30
+ def validate_password_strength(password: str) -> List[str]:
31
+ """Returns list of failure reasons. Empty list = strong password."""
32
+ failures = []
33
+ if len(password) < 8:
34
+ failures.append("Must be at least 8 characters long")
35
+ if not re.search(r"[A-Z]", password):
36
+ failures.append("Must contain at least one uppercase letter")
37
+ if not re.search(r"[a-z]", password):
38
+ failures.append("Must contain at least one lowercase letter")
39
+ if not re.search(r"\d", password):
40
+ failures.append("Must contain at least one number")
41
+ if not re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
42
+ failures.append("Must contain at least one special character")
43
+ return failures
44
+
45
+ # 2. Token Payload Model
46
+ class TokenPayload(BaseModel):
47
+ sub: str # user_id
48
+ type: Literal["access", "refresh"]
49
+ jti: str # unique token ID
50
+ iat: datetime
51
+ exp: datetime
52
+
53
+ # 3 & 4. Token Creation
54
+ def create_token(user_id: str, token_type: Literal["access", "refresh"], extra_claims: dict = None) -> str:
55
+ claims = extra_claims or {}
56
+ now = datetime.now(UTC)
57
+
58
+ if token_type == "access":
59
+ expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
60
+ else:
61
+ expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
62
+
63
+ payload = {
64
+ "sub": str(user_id),
65
+ "type": token_type,
66
+ "jti": secrets.token_hex(16),
67
+ "iat": now,
68
+ "exp": now + expires_delta,
69
+ **claims
70
+ }
71
+ return jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
72
+
73
+ def create_access_token(user_id: str, extra_claims: dict = None) -> str:
74
+ return create_token(user_id, "access", extra_claims)
75
+
76
+ def create_refresh_token(user_id: str) -> str:
77
+ return create_token(user_id, "refresh")
78
+
79
+ # 5. Token Verification
80
+ def verify_token(token: str, expected_type: Literal["access", "refresh"]) -> TokenPayload:
81
+ try:
82
+ payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
83
+ except ExpiredSignatureError:
84
+ raise ExpiredTokenError("Token has expired")
85
+ except JWTError:
86
+ raise InvalidTokenError("Token is invalid")
87
+
88
+ if payload.get("type") != expected_type:
89
+ raise InvalidTokenError(f"Expected {expected_type} token")
90
+
91
+ if is_token_blacklisted(payload.get("jti")):
92
+ raise BlacklistedTokenError()
93
+
94
+ return TokenPayload(**payload)
95
+
96
+ # 6. In-Memory Blacklist
97
+ _blacklist: Dict[str, datetime] = {}
98
+ _blacklist_lock = asyncio.Lock()
99
+
100
+ async def blacklist_token(jti: str, expires_at: datetime):
101
+ async with _blacklist_lock:
102
+ _blacklist[jti] = expires_at
103
+
104
+ def is_token_blacklisted(jti: str) -> bool:
105
+ return jti in _blacklist
106
+
107
+ async def cleanup_expired_blacklist() -> int:
108
+ """Call hourly from scheduler. Returns count of removed entries."""
109
+ now = datetime.now(UTC)
110
+ async with _blacklist_lock:
111
+ expired = [jti for jti, exp in _blacklist.items() if exp < now]
112
+ for jti in expired:
113
+ del _blacklist[jti]
114
+ return len(expired)
115
+
116
+ # 7, 8 & 9. Cookie Management
117
+ def set_refresh_cookie(response: Response, token: str):
118
+ response.set_cookie(
119
+ key="refresh_token",
120
+ value=token,
121
+ max_age=settings.REFRESH_TOKEN_EXPIRE_DAYS * 86400,
122
+ httponly=True,
123
+ secure=settings.is_production,
124
+ samesite="lax",
125
+ path="/api/v1/auth/refresh"
126
+ )
127
+
128
+ def clear_auth_cookies(response: Response):
129
+ response.delete_cookie("refresh_token", path="/api/v1/auth/refresh")
130
+
131
+ def get_refresh_token_from_cookie(request: Request) -> str:
132
+ token = request.cookies.get("refresh_token")
133
+ if not token:
134
+ raise AuthenticationError("No refresh token found")
135
+ return token
136
+
137
+ # 10 & 11. Security Utilities
138
+ def generate_verification_token() -> Tuple[str, str]:
139
+ plain = secrets.token_urlsafe(32)
140
+ hashed = hashlib.sha256(plain.encode()).hexdigest()
141
+ return plain, hashed
142
+
143
+ def verify_token_hash(plain: str, stored_hash: str) -> bool:
144
+ computed = hashlib.sha256(plain.encode()).hexdigest()
145
+ return secrets.compare_digest(computed, stored_hash)
backend/utils/validators.py CHANGED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import re
3
+ import uuid
4
+ import logging
5
+ import unicodedata
6
+ from pathlib import Path
7
+ from dataclasses import dataclass
8
+ from fastapi import UploadFile
9
+ import magic
10
+ from PIL import Image
11
+
12
+ from backend.core.config import settings
13
+ from backend.core.exceptions import (
14
+ FileTooLargeError, InvalidFileTypeError, InvalidFileError,
15
+ SecurityError, PromptInjectionError, ValidationError, PathTraversalError
16
+ )
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ @dataclass
21
+ class ImageMetadata:
22
+ filename: str
23
+ size_bytes: int
24
+ mime_type: str
25
+ width: int
26
+ height: int
27
+ mode: str
28
+ content: bytes
29
+
30
+ class ImageValidator:
31
+ ALLOWED_MIME_TYPES = {"image/jpeg", "image/png", "image/dicom"}
32
+ MAX_SIZE_BYTES = 10 * 1024 * 1024
33
+ MIN_DIMENSION = 64
34
+ MAX_DIMENSION = 4096
35
+
36
+ @staticmethod
37
+ async def validate(file: UploadFile) -> ImageMetadata:
38
+ content = await file.read(ImageValidator.MAX_SIZE_BYTES + 1)
39
+ if len(content) > ImageValidator.MAX_SIZE_BYTES:
40
+ raise FileTooLargeError(f"Max file size is {ImageValidator.MAX_SIZE_BYTES//1024//1024}MB")
41
+ await file.seek(0)
42
+
43
+ # Verify Mime via Magic Bytes
44
+ mime = magic.from_buffer(content[:2048], mime=True)
45
+ if mime not in ImageValidator.ALLOWED_MIME_TYPES:
46
+ raise InvalidFileTypeError(f"File type '{mime}' not supported. Allowed: {', '.join(ImageValidator.ALLOWED_MIME_TYPES)}")
47
+
48
+ try:
49
+ with Image.open(io.BytesIO(content)) as img:
50
+ width, height = img.size
51
+ mode = img.mode
52
+ except Exception:
53
+ raise InvalidFileError("File is corrupted or cannot be read as an image")
54
+
55
+ if width < ImageValidator.MIN_DIMENSION or height < ImageValidator.MIN_DIMENSION:
56
+ raise InvalidFileError(f"Image too small (min {ImageValidator.MIN_DIMENSION}px)")
57
+ if width > ImageValidator.MAX_DIMENSION or height > ImageValidator.MAX_DIMENSION:
58
+ raise InvalidFileError(f"Image too large (max {ImageValidator.MAX_DIMENSION}px)")
59
+
60
+ # EXIF script injection check
61
+ first_2kb = content[:2048].decode('utf-8', errors='ignore').lower()
62
+ if any(bad in first_2kb for bad in ["<script", "javascript:", "eval("]):
63
+ logger.error("Security alert: Script payload detected in image bytes")
64
+ raise SecurityError("Invalid file content detected")
65
+
66
+ return ImageMetadata(
67
+ filename=file.filename or "unknown.png", size_bytes=len(content),
68
+ mime_type=mime, width=width, height=height, mode=mode, content=content
69
+ )
70
+
71
+ def sanitize_symptoms_text(text: str) -> str:
72
+ if not text:
73
+ return ""
74
+ text = text.replace("\x00", "")
75
+ text = unicodedata.normalize("NFKC", text)
76
+ text = re.sub(r"<[^>]+>", "", text)
77
+
78
+ INJECTION_PATTERNS = [
79
+ r"ignore\s+(previous|all|above)\s+instructions",
80
+ r"you\s+are\s+now\s+(a|an)",
81
+ r"system\s*:",
82
+ r"assistant\s*:",
83
+ r"jailbreak",
84
+ r"DAN\s+mode",
85
+ r"pretend\s+you\s+are"
86
+ ]
87
+ for pattern in INJECTION_PATTERNS:
88
+ if re.search(pattern, text, re.IGNORECASE):
89
+ logger.warning(f"Prompt injection detected in input", extra={"pattern": pattern})
90
+ raise PromptInjectionError("Input contains disallowed content. Please describe symptoms naturally.")
91
+
92
+ MAX_LENGTH = 2000
93
+ if len(text) > MAX_LENGTH:
94
+ text = text[:MAX_LENGTH]
95
+
96
+ return " ".join(text.split()).strip()
97
+
98
+ def validate_patient_id(patient_id: str) -> str:
99
+ if not patient_id:
100
+ return ""
101
+ if not re.match(r"^[a-zA-Z0-9_-]{1,50}$", patient_id):
102
+ raise ValidationError("Patient ID must be 1-50 characters: letters, numbers, hyphens, underscores only")
103
+ return patient_id
104
+
105
+ def validate_chat_message(message: str) -> str:
106
+ if not message or not message.strip():
107
+ raise ValidationError("Message cannot be empty")
108
+ if len(message) > 500:
109
+ raise ValidationError("Message too long (max 500 characters)")
110
+ return sanitize_symptoms_text(message)
111
+
112
+ def safe_temp_path(filename: str) -> Path:
113
+ safe_name = re.sub(r"[^a-zA-Z0-9._-]", "_", Path(filename).name)
114
+ unique_name = f"{uuid.uuid4().hex}_{safe_name}"
115
+ temp_path = settings.TEMP_DIR / unique_name
116
+
117
+ resolved = temp_path.resolve()
118
+ temp_dir_resolved = settings.TEMP_DIR.resolve()
119
+ if not str(resolved).startswith(str(temp_dir_resolved)):
120
+ raise PathTraversalError(f"Invalid path detected: {filename}")
121
+
122
+ return temp_path
frontend/components/auth/GoogleLoginButton.jsx ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client';
2
+ import { useState } from 'react';
3
+
4
+ export default function GoogleLoginButton() {
5
+ const [loading, setLoading] = useState(false);
6
+
7
+ const handleGoogleLogin = () => {
8
+ setLoading(true);
9
+ window.location.href = `${process.env.NEXT_PUBLIC_API_URL}/api/v1/auth/google/login`;
10
+ };
11
+
12
+ return (
13
+ <button
14
+ onClick={handleGoogleLogin}
15
+ disabled={loading}
16
+ className="flex items-center justify-center w-full px-4 py-2 space-x-2 bg-white border border-gray-300 rounded focus:outline-none focus:ring-2 focus:ring-offset-1 focus:ring-gray-200 hover:shadow disabled:opacity-50"
17
+ >
18
+ {loading ? (
19
+ <div className="w-5 h-5 border-2 border-gray-300 border-t-gray-600 rounded-full animate-spin" />
20
+ ) : (
21
+ <svg className="w-5 h-5" viewBox="0 0 24 24">
22
+ <path fill="#4285F4" d="M22.56 12.25c0-.78-.07-1.53-.2-2.25H12v4.26h5.92c-.26 1.37-1.04 2.53-2.21 3.31v2.77h3.57c2.08-1.92 3.28-4.74 3.28-8.09z" />
23
+ <path fill="#34A853" d="M12 23c2.97 0 5.46-.98 7.28-2.66l-3.57-2.77c-.98.66-2.23 1.06-3.71 1.06-2.86 0-5.29-1.93-6.16-4.53H2.18v2.84C3.99 20.53 7.7 23 12 23z" />
24
+ <path fill="#FBBC05" d="M5.84 14.09c-.22-.66-.35-1.36-.35-2.09s.13-1.43.35-2.09V7.07H2.18C1.43 8.55 1 10.22 1 12s.43 3.45 1.18 4.93l2.85-2.22.81-.62z" />
25
+ <path fill="#EA4335" d="M12 5.38c1.62 0 3.06.56 4.21 1.64l3.15-3.15C17.45 2.09 14.97 1 12 1 7.7 1 3.99 3.47 2.18 7.07l3.66 2.84c.87-2.6 3.3-4.53 6.16-4.53z" />
26
+ </svg>
27
+ )}
28
+ <span className="text-sm font-medium text-gray-700">Continue with Google</span>
29
+ </button>
30
+ );
31
+ }