Implement variant of dep injections to get active/super user

This commit is contained in:
François Voron
2019-10-11 08:09:47 +02:00
parent ef796abb55
commit 76bb7bf6a5
8 changed files with 330 additions and 42 deletions

View File

@ -1,18 +1,57 @@
from typing import Callable
from fastapi import HTTPException
from starlette import status
from starlette.responses import Response
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
class BaseAuthentication:
"""Base adapter for generating and decoding authentication tokens."""
"""
Base adapter for generating and decoding authentication tokens.
Provides dependency injectors to get current active/superuser user.
"""
async def get_login_response(self, user: BaseUserDB, response: Response):
raise NotImplementedError()
def get_authentication_method(
def get_current_user(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def get_current_active_user(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def get_current_superuser(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def _get_authentication_method(
self, user_db: BaseUserDatabase
) -> Callable[..., BaseUserDB]:
raise NotImplementedError()
def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB:
if user is None:
raise self._get_credentials_exception()
return user
def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB:
user = self._get_current_user_base(user)
if not user.is_active:
raise self._get_credentials_exception()
return user
def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB:
user = self._get_current_active_user_base(user)
if not user.is_superuser:
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
return user
def _get_credentials_exception(
self, status_code: int = status.HTTP_401_UNAUTHORIZED
) -> HTTPException:
return HTTPException(status_code=status_code)

View File

@ -1,9 +1,8 @@
from datetime import datetime, timedelta
import jwt
from fastapi import Depends, HTTPException
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.responses import Response
from fastapi_users.authentication.base import BaseAuthentication
@ -42,24 +41,36 @@ class JWTAuthentication(BaseAuthentication):
return {"token": token}
def get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED
)
def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user)
return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(token: str = Depends(oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(oauth2_scheme)):
try:
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
user_id = data.get("user_id")
if user_id is None:
raise credentials_exception
return None
except jwt.PyJWTError:
raise credentials_exception
user = await user_db.get(user_id)
if user is None or not user.is_active:
raise credentials_exception
return user
return None
return await user_db.get(user_id)
return authentication_method

View File

@ -34,5 +34,11 @@ class FastAPIUsers:
self.auth = auth
self.router = get_user_router(self.db, user_model, self.auth)
get_current_user = self.auth.get_authentication_method(self.db)
get_current_user = self.auth.get_current_user(self.db)
self.get_current_user = get_current_user # type: ignore
get_current_active_user = self.auth.get_current_active_user(self.db)
self.get_current_active_user = get_current_active_user # type: ignore
get_current_superuser = self.auth.get_current_superuser(self.db)
self.get_current_superuser = get_current_superuser # type: ignore