Fix #42: multiple authentication backends (#47)

* Revamp authentication to allow multiple backends

* Make router generate a login route for each backend

* Apply black

* Remove unused imports

* Complete docstrings

* Update documentation

* WIP add cookie auth

* Complete cookie auth unit tests

* Add documentation for cookie auth

* Fix cookie backend default name

* Don't make cookie return a Response
This commit is contained in:
François Voron
2019-12-04 13:32:49 +01:00
committed by GitHub
parent 5e4c7996de
commit 49deb437a6
22 changed files with 591 additions and 341 deletions

View File

@ -1,2 +1,60 @@
from typing import Sequence
from fastapi import HTTPException
from starlette import status
from starlette.requests import Request
from fastapi_users.authentication.base import BaseAuthentication # noqa: F401
from fastapi_users.authentication.cookie import CookieAuthentication # noqa: F401
from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
class Authenticator:
"""
Provides dependency callables to retrieve authenticated user.
It performs the authentication against a list of backends
defined by the end-developer. The first backend yielding a user wins.
If no backend yields a user, an HTTPException is raised.
:param backends: List of authentication backends.
:param user_db: Database adapter instance.
"""
backends: Sequence[BaseAuthentication]
user_db: BaseUserDatabase
def __init__(
self, backends: Sequence[BaseAuthentication], user_db: BaseUserDatabase
):
self.backends = backends
self.user_db = user_db
async def get_current_user(self, request: Request) -> BaseUserDB:
return await self._authenticate(request)
async def get_current_active_user(self, request: Request) -> BaseUserDB:
user = await self.get_current_user(request)
if not user.is_active:
raise self._get_credentials_exception()
return user
async def get_current_superuser(self, request: Request) -> BaseUserDB:
user = await self.get_current_active_user(request)
if not user.is_superuser:
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
return user
async def _authenticate(self, request: Request) -> BaseUserDB:
for backend in self.backends:
user = await backend(request, self.user_db)
if user is not None:
return user
raise self._get_credentials_exception()
def _get_credentials_exception(
self, status_code: int = status.HTTP_401_UNAUTHORIZED
) -> HTTPException:
return HTTPException(status_code=status_code)

View File

@ -1,70 +1,30 @@
from typing import Callable
from typing import Any, Optional
from fastapi import HTTPException
from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.requests import Request
from starlette.responses import Response
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
class BaseAuthentication:
"""
Base adapter for generating and decoding authentication tokens.
Base authentication backend.
Provides dependency injectors to get current active/superuser user.
Every backend should derive from this class.
:param scheme: Optional authentication scheme for OpenAPI.
Defaults to `OAuth2PasswordBearer(tokenUrl="/users/login")`.
Override it if your login route lives somewhere else.
:param name: Name of the backend. It will be used to name the login route.
"""
scheme: OAuth2PasswordBearer
name: str
def __init__(self, scheme: OAuth2PasswordBearer = None):
if scheme is None:
self.scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
else:
self.scheme = scheme
def __init__(self, name: str = "base"):
self.name = name
async def get_login_response(self, user: BaseUserDB, response: Response):
async def __call__(
self, request: Request, user_db: BaseUserDatabase
) -> Optional[BaseUserDB]:
raise NotImplementedError()
def get_current_user(self, user_db: BaseUserDatabase):
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
raise NotImplementedError()
def get_current_active_user(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def get_current_superuser(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def _get_authentication_method(
self, user_db: BaseUserDatabase
) -> Callable[..., BaseUserDB]:
raise NotImplementedError()
def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB:
if user is None:
raise self._get_credentials_exception()
return user
def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB:
user = self._get_current_user_base(user)
if not user.is_active:
raise self._get_credentials_exception()
return user
def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB:
user = self._get_current_active_user_base(user)
if not user.is_superuser:
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
return user
def _get_credentials_exception(
self, status_code: int = status.HTTP_401_UNAUTHORIZED
) -> HTTPException:
return HTTPException(status_code=status_code)

View File

@ -0,0 +1,53 @@
from typing import Any, Optional
from fastapi.security import APIKeyCookie
from starlette.requests import Request
from starlette.responses import Response
from fastapi_users.authentication.jwt import JWTAuthentication
from fastapi_users.models import BaseUserDB
class CookieAuthentication(JWTAuthentication):
"""
Authentication backend using a cookie.
Internally, uses a JWT token to store the data.
:param secret: Secret used to encode the cookie.
:param lifetime_seconds: Lifetime duration of the cookie in seconds.
:param cookie_name: Name of the cookie.
:param name: Name of the backend. It will be used to name the login route.
"""
lifetime_seconds: int
cookie_name: str
def __init__(
self,
secret: str,
lifetime_seconds: int,
cookie_name: str = "fastapiusersauth",
name: str = "cookie",
):
super().__init__(secret, lifetime_seconds, name=name)
self.lifetime_seconds = lifetime_seconds
self.cookie_name = cookie_name
self.api_key_cookie = APIKeyCookie(name=self.cookie_name, auto_error=False)
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
token = await self._generate_token(user)
response.set_cookie(
self.cookie_name,
token,
max_age=self.lifetime_seconds,
secure=True,
httponly=True,
)
# We shouldn't return directly the response
# so that FastAPI can terminate it properly
return None
async def _retrieve_token(self, request: Request) -> Optional[str]:
return await self.api_key_cookie.__call__(request)

View File

@ -1,5 +1,8 @@
from typing import Any, Optional
import jwt
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from starlette.requests import Request
from starlette.responses import Response
from fastapi_users.authentication.base import BaseAuthentication
@ -10,62 +13,58 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
class JWTAuthentication(BaseAuthentication):
"""
Authentication using a JWT.
Authentication backend using a JWT in a Bearer header.
:param secret: Secret used to encode the JWT.
:param lifetime_seconds: Lifetime duration of the JWT in seconds.
:param tokenUrl: Path where to get a token.
:param name: Name of the backend. It will be used to name the login route.
"""
token_audience: str = "fastapi-users:auth"
secret: str
lifetime_seconds: int
def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
secret: str,
lifetime_seconds: int,
tokenUrl: str = "/users/login",
name: str = "jwt",
):
super().__init__(name)
self.secret = secret
self.lifetime_seconds = lifetime_seconds
self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False)
async def get_login_response(self, user: BaseUserDB, response: Response):
data = {"user_id": user.id, "aud": self.token_audience}
token = generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)
async def __call__(
self, request: Request, user_db: BaseUserDatabase,
) -> Optional[BaseUserDB]:
token = await self._retrieve_token(request)
if token is None:
return None
try:
data = jwt.decode(
token,
self.secret,
audience=self.token_audience,
algorithms=[JWT_ALGORITHM],
)
user_id = data.get("user_id")
if user_id is None:
return None
except jwt.PyJWTError:
return None
return await user_db.get(user_id)
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
token = await self._generate_token(user)
return {"token": token}
def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user)
async def _retrieve_token(self, request: Request) -> Optional[str]:
return await self.scheme.__call__(request)
return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(self.scheme)):
try:
data = jwt.decode(
token,
self.secret,
audience=self.token_audience,
algorithms=[JWT_ALGORITHM],
)
user_id = data.get("user_id")
if user_id is None:
return None
except jwt.PyJWTError:
return None
return await user_db.get(user_id)
return authentication_method
async def _generate_token(self, user: BaseUserDB) -> str:
data = {"user_id": user.id, "aud": self.token_audience}
return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)

View File

@ -1,8 +1,8 @@
from typing import Callable, Type
from typing import Callable, Sequence, Type
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUser, BaseUserDB
from fastapi_users.models import BaseUser
from fastapi_users.router import Event, UserRouter, get_user_router
@ -11,7 +11,7 @@ class FastAPIUsers:
Main object that ties together the component for users authentication.
:param db: Database adapter instance.
:param auth: Authentication logic instance.
:param auth_backends: List of authentication backends.
:param user_model: Pydantic model of a user.
:param reset_password_token_secret: Secret to encode reset password token.
:param reset_password_token_lifetime_seconds: Lifetime of reset password token.
@ -21,36 +21,30 @@ class FastAPIUsers:
"""
db: BaseUserDatabase
auth: BaseAuthentication
authenticator: Authenticator
router: UserRouter
get_current_user: Callable[..., BaseUserDB]
def __init__(
self,
db: BaseUserDatabase,
auth: BaseAuthentication,
auth_backends: Sequence[BaseAuthentication],
user_model: Type[BaseUser],
reset_password_token_secret: str,
reset_password_token_lifetime_seconds: int = 3600,
):
self.db = db
self.auth = auth
self.authenticator = Authenticator(auth_backends, db)
self.router = get_user_router(
self.db,
user_model,
self.auth,
self.authenticator,
reset_password_token_secret,
reset_password_token_lifetime_seconds,
)
get_current_user = self.auth.get_current_user(self.db)
self.get_current_user = get_current_user # type: ignore
get_current_active_user = self.auth.get_current_active_user(self.db)
self.get_current_active_user = get_current_active_user # type: ignore
get_current_superuser = self.auth.get_current_superuser(self.db)
self.get_current_superuser = get_current_superuser # type: ignore
self.get_current_user = self.authenticator.get_current_user
self.get_current_active_user = self.authenticator.get_current_active_user
self.get_current_superuser = self.authenticator.get_current_superuser
def on_after_register(self) -> Callable:
"""Add an event handler on successful registration."""

View File

@ -10,7 +10,7 @@ from pydantic import EmailStr
from starlette import status
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUser, Models
from fastapi_users.password import get_password_hash
@ -46,10 +46,28 @@ class UserRouter(APIRouter):
handler(*args, **kwargs)
def _add_login_route(
router: UserRouter, user_db: BaseUserDatabase, auth_backend: BaseAuthentication
):
@router.post(f"/login/{auth_backend.name}")
async def login(
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
):
user = await user_db.authenticate(credentials)
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
return await auth_backend.get_login_response(user, response)
def get_user_router(
user_db: BaseUserDatabase,
user_model: typing.Type[BaseUser],
auth: BaseAuthentication,
authenticator: Authenticator,
reset_password_token_secret: str,
reset_password_token_lifetime_seconds: int = 3600,
) -> UserRouter:
@ -59,8 +77,8 @@ def get_user_router(
reset_password_token_audience = "fastapi-users:reset"
get_current_active_user = auth.get_current_active_user(user_db)
get_current_superuser = auth.get_current_superuser(user_db)
get_current_active_user = authenticator.get_current_active_user
get_current_superuser = authenticator.get_current_superuser
async def _get_or_404(id: str) -> models.UserDB: # type: ignore
user = await user_db.get(id)
@ -79,6 +97,9 @@ def get_user_router(
setattr(user, field, update_dict[field])
return await user_db.update(user)
for auth_backend in authenticator.backends:
_add_login_route(router, user_db, auth_backend)
@router.post(
"/register", response_model=models.User, status_code=status.HTTP_201_CREATED
)
@ -101,20 +122,6 @@ def get_user_router(
return created_user
@router.post("/login")
async def login(
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
):
user = await user_db.authenticate(credentials)
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
return await auth.get_login_response(user, response)
@router.post("/forgot-password", status_code=status.HTTP_202_ACCEPTED)
async def forgot_password(email: EmailStr = Body(..., embed=True)):
user = await user_db.get_by_email(email)