Spaces:
Sleeping
Sleeping
| from datetime import datetime, timedelta | |
| from typing import Union, Any, List | |
| from jose import jwt, JWTError | |
| from fastapi import Depends, HTTPException, status, Request, Query | |
| from fastapi.security import OAuth2PasswordBearer | |
| from sqlalchemy.orm import Session | |
| from app.core.config import settings | |
| from app.core.database import get_db | |
| from app.models import models | |
| from app.crud import crud | |
# Bearer-token extractor injected into get_current_user via Depends; tokenUrl
# points at the login route under the API v1 prefix.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
def create_access_token(
    subject: Union[str, Any],
    role: str,
    expires_delta: Union[timedelta, None] = None,
) -> str:
    """Create a signed JWT access token.

    Args:
        subject: Identity stored (stringified) in the ``sub`` claim.
        role: Value stored in the ``role`` claim.
        expires_delta: Token lifetime. ``None`` means
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. An explicit
            ``timedelta(0)`` is honoured rather than replaced by the default.

    Returns:
        The encoded JWT, signed with ``settings.JWT_SECRET_KEY`` using
        ``settings.JWT_ALGORITHM`` and tagged ``"type": "access"``.
    """
    # Local import: the module header only imports datetime and timedelta.
    from datetime import timezone

    # `is None` (not truthiness): timedelta(0) is falsy and would otherwise be
    # silently swapped for the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    # Aware UTC "now" instead of the deprecated, naive datetime.utcnow().
    # python-jose serializes datetime claims to a UTC epoch, so the encoded
    # `exp` instant is the same.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {
        "exp": expire,
        "sub": str(subject),
        "role": role,
        "type": "access",
    }
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
def create_refresh_token(
    subject: Union[str, Any],
    role: str,
    expires_delta: Union[timedelta, None] = None,
) -> str:
    """Create a signed JWT refresh token.

    Args:
        subject: Identity stored (stringified) in the ``sub`` claim.
        role: Value stored in the ``role`` claim.
        expires_delta: Token lifetime. ``None`` means
            ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days. An explicit
            ``timedelta(0)`` is honoured rather than replaced by the default.

    Returns:
        The encoded JWT, signed with ``settings.JWT_SECRET_KEY`` using
        ``settings.JWT_ALGORITHM`` and tagged ``"type": "refresh"``. Tokens
        with this type are rejected by the access-token checks in this module.
    """
    # Local import: the module header only imports datetime and timedelta.
    from datetime import timezone

    # `is None` (not truthiness): timedelta(0) is falsy and would otherwise be
    # silently swapped for the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)

    # Aware UTC "now" instead of the deprecated, naive datetime.utcnow().
    # python-jose serializes datetime claims to a UTC epoch, so the encoded
    # `exp` instant is the same.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {
        "exp": expire,
        "sub": str(subject),
        "role": role,
        "type": "refresh",
    }
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
def get_current_user(
    db: Session = Depends(get_db),
    token: str = Depends(oauth2_scheme)
) -> models.User:
    """FastAPI dependency: resolve the authenticated user from an access token.

    The raw token is supplied by ``oauth2_scheme``. Validation is delegated to
    ``get_current_user_from_token`` so there is a single copy of the
    decode / claim-check / lookup logic. Previously this body duplicated it
    line for line, which risked the two paths drifting apart.

    Raises:
        HTTPException: 401 if the token is invalid, expired, not an access
            token, or names an unknown user; 400 if the user is inactive.
    """
    return get_current_user_from_token(token, db)
class RoleChecker:
    """Callable dependency that admits only users whose role is allow-listed.

    Instances are called with the user produced by ``get_current_user`` and
    return that user unchanged when ``user.role.name`` is in ``allowed_roles``.
    """

    def __init__(self, allowed_roles: List[str]):
        # Stored as given (same object, not copied).
        self.allowed_roles = allowed_roles

    def __call__(
        self,
        current_user: models.User = Depends(get_current_user)
    ) -> models.User:
        """Return ``current_user`` if permitted, else raise HTTP 403."""
        role_name = current_user.role.name
        if role_name in self.allowed_roles:
            return current_user
        # NOTE(review): the detail echoes the allow-list back to the client.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                f"User role '{role_name}' does not have permission to access "
                f"this resource. Allowed: {self.allowed_roles}"
            ),
        )
def get_current_user_from_token(token: str, db: Session) -> models.User:
    """Validate a raw JWT access token and return the matching active user.

    Plain-function form (no dependency injection) so callers that obtain the
    token themselves can reuse the same checks.

    Args:
        token: Encoded JWT.
        db: Session passed through to ``crud.get_user_by_email``.

    Raises:
        HTTPException: 401 if decoding fails, ``sub`` is missing, ``type`` is
            not ``"access"``, or no user matches ``sub``; 400 if the user is
            inactive.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode itself can raise JWTError (bad signature, expired, malformed).
    try:
        claims = jwt.decode(
            token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError:
        raise unauthorized

    # `sub` holds the email used for lookup; refresh tokens are refused here.
    subject = claims.get("sub")
    if subject is None or claims.get("type") != "access":
        raise unauthorized

    user = crud.get_user_by_email(db, email=subject)
    if user is None:
        raise unauthorized
    if not user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    return user
def get_current_user_sse(
    request: Request,
    token: Union[str, None] = Query(None),
    db: Session = Depends(get_db)
) -> models.User:
    """FastAPI dependency for SSE endpoints: accept the token from query or header.

    Lookup order:
      1. ``?token=...`` query parameter. Presumably this exists because
         browser EventSource clients cannot set headers; confirm with the
         frontend. NOTE(review): tokens in URLs can end up in access logs.
      2. ``Authorization: Bearer <token>`` header. The scheme is matched
         case-insensitively, as auth schemes are (RFC 7235).

    Raises:
        HTTPException: 401 "Not authenticated" (with ``WWW-Authenticate:
            Bearer``) when no token is supplied. Otherwise, whatever
            ``get_current_user_from_token`` raises.
    """
    if token:
        return get_current_user_from_token(token, db)

    auth_header = request.headers.get("Authorization")
    if auth_header:
        scheme, _, credentials = auth_header.partition(" ")
        credentials = credentials.strip()
        if scheme.lower() == "bearer" and credentials:
            return get_current_user_from_token(credentials, db)

    # Same challenge header as the other 401s in this module.
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
        headers={"WWW-Authenticate": "Bearer"},
    )
class RoleCheckerSSE:
    """Role allow-list dependency for SSE routes.

    Same check as a plain role guard, but the user comes from
    ``get_current_user_sse``, which also accepts a ``token`` query parameter.
    """

    def __init__(self, allowed_roles: List[str]):
        # Stored as given (same object, not copied).
        self.allowed_roles = allowed_roles

    def __call__(
        self,
        current_user: models.User = Depends(get_current_user_sse)
    ) -> models.User:
        """Return ``current_user`` if its role is allowed, else raise HTTP 403."""
        allowed = self.allowed_roles
        role_name = current_user.role.name
        if role_name not in allowed:
            # NOTE(review): the detail echoes the allow-list back to the client.
            message = (
                f"User role '{role_name}' does not have permission to access "
                f"this resource. Allowed: {allowed}"
            )
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message)
        return current_user