diff --git a/README.md b/README.md index d50bfdb8..775f1ffe 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,9 @@ Add quickly a registration and authentication system to your [FastAPI](https://f * [X] Customizable database backend * [X] SQLAlchemy async backend included thanks to [encode/databases](https://www.encode.io/databases/) * [X] MongoDB async backend included thanks to [mongodb/motor](https://github.com/mongodb/motor) -* [X] Customizable authentication backend +* [X] Multiple customizable authentication backends * [X] JWT authentication backend included + * [X] Cookie authentication backend included ## Development diff --git a/docs/configuration/authentication/cookie.md b/docs/configuration/authentication/cookie.md new file mode 100644 index 00000000..908b390f --- /dev/null +++ b/docs/configuration/authentication/cookie.md @@ -0,0 +1,50 @@ +# Cookie + +Cookies are an easy way to store stateful information into the user browser. Thus, it is more useful for browser-based navigation (e.g. a front-end app making API requests) rather than pure API interaction. + +## Configuration + +```py +from fastapi_users.authentication import CookieAuthentication + +SECRET = "SECRET" + +auth_backends = [] + +cookie_authentication = CookieAuthentication(secret=SECRET, lifetime_seconds=3600)) + +auth_backends.append(cookie_authentication) +``` + +As you can see, instantiation is quite simple. You just have to define a constant `SECRET` which is used to encode the token and the lifetime of the cookie (in seconds). + +You can optionally define the `cookie_name`. **Defaults to `fastapiusersauth`**. + +You can also optionally define the `name` which will be used to generate its [`/login` route](../../usage/routes.md#post-loginname). **Defaults to `cookie`**. + +```py +cookie_authentication = CookieAuthentication( + secret=SECRET, + lifetime_seconds=3600, + name="my-cookie", +) +``` + +!!! tip + The value of the cookie is actually a JWT. This authentication backend shares most of its logic with the [JWT](./jwt.md) one. + +## Login + +This method will return a response with a valid `set-cookie` header upon successful login: + +!!! success "`200 OK`" + +> Check documentation about [login route](../../usage/routes.md#post-loginname). + +## Authentication + +This method expects that you provide a valid cookie in the headers. + +## Next steps + +We will now configure the main **FastAPI Users** object that will expose the [API router](../router.md). diff --git a/docs/configuration/authentication/index.md b/docs/configuration/authentication/index.md index 48af84f0..0359c933 100644 --- a/docs/configuration/authentication/index.md +++ b/docs/configuration/authentication/index.md @@ -2,6 +2,15 @@ **FastAPI Users** allows you to plug in several authentication methods. +## How it works? + +You can have **several** authentication methods, e.g. a cookie authentication for browser-based queries and a JWT token authentication for pure API queries. + +When checking authentication, each method is run one after the other. The first method yielding a user wins. If no method yields a user, an `HTTPException` is raised. + +Each defined method will generate a [`/login/{name}`](../../usage/routes.md#post-loginname) route where `name` is defined on the authentication method object. + ## Provided methods * [JWT authentication](jwt.md) +* [Cookie authentication](cookie.md) diff --git a/docs/configuration/authentication/jwt.md b/docs/configuration/authentication/jwt.md index 937fc2f1..39831e5a 100644 --- a/docs/configuration/authentication/jwt.md +++ b/docs/configuration/authentication/jwt.md @@ -9,11 +9,38 @@ from fastapi_users.authentication import JWTAuthentication SECRET = "SECRET" -auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600) +auth_backends = [] + +jwt_authentication = JWTAuthentication(secret=SECRET, lifetime_seconds=3600)) + +auth_backends.append(jwt_authentication) ``` As you can see, instantiation is quite simple. You just have to define a constant `SECRET` which is used to encode the token and the lifetime of token (in seconds). +You can also optionally define the `name` which will be used to generate its [`/login` route](../../usage/routes.md#post-loginname). **Defaults to `jwt`**. + +```py +jwt_authentication = JWTAuthentication( + secret=SECRET, + lifetime_seconds=3600, + name="my-jwt", +) +``` + +## Login + +This method will return a JWT token upon successful login: + +!!! success "`200 OK`" + ```json + { + "token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoiOTIyMWZmYzktNjQwZi00MzcyLTg2ZDMtY2U2NDJjYmE1NjAzIiwiYXVkIjoiZmFzdGFwaS11c2VyczphdXRoIiwiZXhwIjoxNTcxNTA0MTkzfQ.M10bjOe45I5Ncu_uXvOmVV8QxnL-nZfcH96U90JaocI" + } + ``` + +> Check documentation about [login route](../../usage/routes.md#post-loginname). + ## Authentication This method expects that you provide a `Bearer` authentication with a valid JWT. diff --git a/docs/configuration/router.md b/docs/configuration/router.md index a139eb90..c9b7cb60 100644 --- a/docs/configuration/router.md +++ b/docs/configuration/router.md @@ -7,7 +7,7 @@ We're almost there! The last step is to configure the `FastAPIUsers` object that Configure `FastAPIUsers` object with all the elements we defined before. More precisely: * `db`: Database adapter instance. -* `auth`: Authentication logic instance. +* `auth_backends`: List of authentication backends. See [Authentication](./authentication/index.md). * `user_model`: Pydantic model of a user. * `reset_password_token_secret`: Secret to encode reset password token. * `reset_password_token_lifetime_seconds`: Lifetime of reset password token in seconds. Default to one hour. @@ -17,7 +17,7 @@ from fastapi_users import FastAPIUsers fastapi_users = FastAPIUsers( user_db, - auth, + auth_backends, User, SECRET, ) diff --git a/docs/src/full_mongodb.py b/docs/src/full_mongodb.py index ca46e4bf..a9c7a127 100644 --- a/docs/src/full_mongodb.py +++ b/docs/src/full_mongodb.py @@ -20,10 +20,12 @@ class User(BaseUser): pass -auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600) +auth_backends = [ + JWTAuthentication(secret=SECRET, lifetime_seconds=3600), +] app = FastAPI() -fastapi_users = FastAPIUsers(user_db, auth, User, SECRET) +fastapi_users = FastAPIUsers(user_db, auth_backends, User, SECRET) app.include_router(fastapi_users.router, prefix="/users", tags=["users"]) diff --git a/docs/src/full_sqlalchemy.py b/docs/src/full_sqlalchemy.py index 1515ed1b..f8091d4e 100644 --- a/docs/src/full_sqlalchemy.py +++ b/docs/src/full_sqlalchemy.py @@ -33,10 +33,12 @@ class User(BaseUser): pass -auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600) +auth_backends = [ + JWTAuthentication(secret=SECRET, lifetime_seconds=3600), +] app = FastAPI() -fastapi_users = FastAPIUsers(user_db, auth, User, SECRET) +fastapi_users = FastAPIUsers(user_db, auth_backends, User, SECRET) app.include_router(fastapi_users.router, prefix="/users", tags=["users"]) diff --git a/docs/usage/routes.md b/docs/usage/routes.md index 896097df..8c4ffa5a 100644 --- a/docs/usage/routes.md +++ b/docs/usage/routes.md @@ -37,22 +37,15 @@ Register a new user. Will call the `on_after_register` [event handlers](../confi } ``` -### `POST /login` +### `POST /login/{name}` -Login a user. +Login a user against the method named `name`. Check the corresponding [authentication method](../configuration/authentication/index.md) to view the success response. !!! abstract "Payload (`application/x-www-form-urlencoded`)" ``` username=king.arthur@camelot.bt&password=guinevere ``` -!!! success "`200 OK`" - ```json - { - "token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoiOTIyMWZmYzktNjQwZi00MzcyLTg2ZDMtY2U2NDJjYmE1NjAzIiwiYXVkIjoiZmFzdGFwaS11c2VyczphdXRoIiwiZXhwIjoxNTcxNTA0MTkzfQ.M10bjOe45I5Ncu_uXvOmVV8QxnL-nZfcH96U90JaocI" - } - ``` - !!! fail "`422 Validation Error`" !!! fail "`400 Bad Request`" diff --git a/fastapi_users/authentication/__init__.py b/fastapi_users/authentication/__init__.py index 87ac17dc..70e28034 100644 --- a/fastapi_users/authentication/__init__.py +++ b/fastapi_users/authentication/__init__.py @@ -1,2 +1,60 @@ +from typing import Sequence + +from fastapi import HTTPException +from starlette import status +from starlette.requests import Request + from fastapi_users.authentication.base import BaseAuthentication # noqa: F401 +from fastapi_users.authentication.cookie import CookieAuthentication # noqa: F401 from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401 +from fastapi_users.db import BaseUserDatabase +from fastapi_users.models import BaseUserDB + + +class Authenticator: + """ + Provides dependency callables to retrieve authenticated user. + + It performs the authentication against a list of backends + defined by the end-developer. The first backend yielding a user wins. + If no backend yields a user, an HTTPException is raised. + + :param backends: List of authentication backends. + :param user_db: Database adapter instance. + """ + + backends: Sequence[BaseAuthentication] + user_db: BaseUserDatabase + + def __init__( + self, backends: Sequence[BaseAuthentication], user_db: BaseUserDatabase + ): + self.backends = backends + self.user_db = user_db + + async def get_current_user(self, request: Request) -> BaseUserDB: + return await self._authenticate(request) + + async def get_current_active_user(self, request: Request) -> BaseUserDB: + user = await self.get_current_user(request) + if not user.is_active: + raise self._get_credentials_exception() + return user + + async def get_current_superuser(self, request: Request) -> BaseUserDB: + user = await self.get_current_active_user(request) + if not user.is_superuser: + raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN) + return user + + async def _authenticate(self, request: Request) -> BaseUserDB: + for backend in self.backends: + user = await backend(request, self.user_db) + if user is not None: + return user + raise self._get_credentials_exception() + + def _get_credentials_exception( + self, status_code: int = status.HTTP_401_UNAUTHORIZED + ) -> HTTPException: + return HTTPException(status_code=status_code) diff --git a/fastapi_users/authentication/base.py b/fastapi_users/authentication/base.py index 41616680..799b579d 100644 --- a/fastapi_users/authentication/base.py +++ b/fastapi_users/authentication/base.py @@ -1,70 +1,30 @@ -from typing import Callable +from typing import Any, Optional -from fastapi import HTTPException -from fastapi.security import OAuth2PasswordBearer -from starlette import status +from starlette.requests import Request from starlette.responses import Response from fastapi_users.db import BaseUserDatabase from fastapi_users.models import BaseUserDB -credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - class BaseAuthentication: """ - Base adapter for generating and decoding authentication tokens. + Base authentication backend. - Provides dependency injectors to get current active/superuser user. + Every backend should derive from this class. - :param scheme: Optional authentication scheme for OpenAPI. - Defaults to `OAuth2PasswordBearer(tokenUrl="/users/login")`. - Override it if your login route lives somewhere else. + :param name: Name of the backend. It will be used to name the login route. """ - scheme: OAuth2PasswordBearer + name: str - def __init__(self, scheme: OAuth2PasswordBearer = None): - if scheme is None: - self.scheme = OAuth2PasswordBearer(tokenUrl="/users/login") - else: - self.scheme = scheme + def __init__(self, name: str = "base"): + self.name = name - async def get_login_response(self, user: BaseUserDB, response: Response): + async def __call__( + self, request: Request, user_db: BaseUserDatabase + ) -> Optional[BaseUserDB]: raise NotImplementedError() - def get_current_user(self, user_db: BaseUserDatabase): + async def get_login_response(self, user: BaseUserDB, response: Response) -> Any: raise NotImplementedError() - - def get_current_active_user(self, user_db: BaseUserDatabase): - raise NotImplementedError() - - def get_current_superuser(self, user_db: BaseUserDatabase): - raise NotImplementedError() - - def _get_authentication_method( - self, user_db: BaseUserDatabase - ) -> Callable[..., BaseUserDB]: - raise NotImplementedError() - - def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB: - if user is None: - raise self._get_credentials_exception() - return user - - def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB: - user = self._get_current_user_base(user) - if not user.is_active: - raise self._get_credentials_exception() - return user - - def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB: - user = self._get_current_active_user_base(user) - if not user.is_superuser: - raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN) - return user - - def _get_credentials_exception( - self, status_code: int = status.HTTP_401_UNAUTHORIZED - ) -> HTTPException: - return HTTPException(status_code=status_code) diff --git a/fastapi_users/authentication/cookie.py b/fastapi_users/authentication/cookie.py new file mode 100644 index 00000000..c027965f --- /dev/null +++ b/fastapi_users/authentication/cookie.py @@ -0,0 +1,53 @@ +from typing import Any, Optional + +from fastapi.security import APIKeyCookie +from starlette.requests import Request +from starlette.responses import Response + +from fastapi_users.authentication.jwt import JWTAuthentication +from fastapi_users.models import BaseUserDB + + +class CookieAuthentication(JWTAuthentication): + """ + Authentication backend using a cookie. + + Internally, uses a JWT token to store the data. + + :param secret: Secret used to encode the cookie. + :param lifetime_seconds: Lifetime duration of the cookie in seconds. + :param cookie_name: Name of the cookie. + :param name: Name of the backend. It will be used to name the login route. + """ + + lifetime_seconds: int + cookie_name: str + + def __init__( + self, + secret: str, + lifetime_seconds: int, + cookie_name: str = "fastapiusersauth", + name: str = "cookie", + ): + super().__init__(secret, lifetime_seconds, name=name) + self.lifetime_seconds = lifetime_seconds + self.cookie_name = cookie_name + self.api_key_cookie = APIKeyCookie(name=self.cookie_name, auto_error=False) + + async def get_login_response(self, user: BaseUserDB, response: Response) -> Any: + token = await self._generate_token(user) + response.set_cookie( + self.cookie_name, + token, + max_age=self.lifetime_seconds, + secure=True, + httponly=True, + ) + + # We shouldn't return directly the response + # so that FastAPI can terminate it properly + return None + + async def _retrieve_token(self, request: Request) -> Optional[str]: + return await self.api_key_cookie.__call__(request) diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index ce51618b..7d81d938 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -1,5 +1,8 @@ +from typing import Any, Optional + import jwt -from fastapi import Depends +from fastapi.security import OAuth2PasswordBearer +from starlette.requests import Request from starlette.responses import Response from fastapi_users.authentication.base import BaseAuthentication @@ -10,62 +13,58 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt class JWTAuthentication(BaseAuthentication): """ - Authentication using a JWT. + Authentication backend using a JWT in a Bearer header. :param secret: Secret used to encode the JWT. :param lifetime_seconds: Lifetime duration of the JWT in seconds. + :param tokenUrl: Path where to get a token. + :param name: Name of the backend. It will be used to name the login route. """ token_audience: str = "fastapi-users:auth" secret: str lifetime_seconds: int - def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + secret: str, + lifetime_seconds: int, + tokenUrl: str = "/users/login", + name: str = "jwt", + ): + super().__init__(name) self.secret = secret self.lifetime_seconds = lifetime_seconds + self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False) - async def get_login_response(self, user: BaseUserDB, response: Response): - data = {"user_id": user.id, "aud": self.token_audience} - token = generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM) + async def __call__( + self, request: Request, user_db: BaseUserDatabase, + ) -> Optional[BaseUserDB]: + token = await self._retrieve_token(request) + if token is None: + return None + try: + data = jwt.decode( + token, + self.secret, + audience=self.token_audience, + algorithms=[JWT_ALGORITHM], + ) + user_id = data.get("user_id") + if user_id is None: + return None + except jwt.PyJWTError: + return None + return await user_db.get(user_id) + + async def get_login_response(self, user: BaseUserDB, response: Response) -> Any: + token = await self._generate_token(user) return {"token": token} - def get_current_user(self, user_db: BaseUserDatabase): - async def _get_current_user(token: str = Depends(self.scheme)): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_user_base(user) + async def _retrieve_token(self, request: Request) -> Optional[str]: + return await self.scheme.__call__(request) - return _get_current_user - - def get_current_active_user(self, user_db: BaseUserDatabase): - async def _get_current_active_user(token: str = Depends(self.scheme)): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_active_user_base(user) - - return _get_current_active_user - - def get_current_superuser(self, user_db: BaseUserDatabase): - async def _get_current_superuser(token: str = Depends(self.scheme)): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_superuser_base(user) - - return _get_current_superuser - - def _get_authentication_method(self, user_db: BaseUserDatabase): - async def authentication_method(token: str = Depends(self.scheme)): - try: - data = jwt.decode( - token, - self.secret, - audience=self.token_audience, - algorithms=[JWT_ALGORITHM], - ) - user_id = data.get("user_id") - if user_id is None: - return None - except jwt.PyJWTError: - return None - return await user_db.get(user_id) - - return authentication_method + async def _generate_token(self, user: BaseUserDB) -> str: + data = {"user_id": user.id, "aud": self.token_audience} + return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM) diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index 4b8fbb94..ba754136 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -1,8 +1,8 @@ -from typing import Callable, Type +from typing import Callable, Sequence, Type -from fastapi_users.authentication import BaseAuthentication +from fastapi_users.authentication import Authenticator, BaseAuthentication from fastapi_users.db import BaseUserDatabase -from fastapi_users.models import BaseUser, BaseUserDB +from fastapi_users.models import BaseUser from fastapi_users.router import Event, UserRouter, get_user_router @@ -11,7 +11,7 @@ class FastAPIUsers: Main object that ties together the component for users authentication. :param db: Database adapter instance. - :param auth: Authentication logic instance. + :param auth_backends: List of authentication backends. :param user_model: Pydantic model of a user. :param reset_password_token_secret: Secret to encode reset password token. :param reset_password_token_lifetime_seconds: Lifetime of reset password token. @@ -21,36 +21,30 @@ class FastAPIUsers: """ db: BaseUserDatabase - auth: BaseAuthentication + authenticator: Authenticator router: UserRouter - get_current_user: Callable[..., BaseUserDB] def __init__( self, db: BaseUserDatabase, - auth: BaseAuthentication, + auth_backends: Sequence[BaseAuthentication], user_model: Type[BaseUser], reset_password_token_secret: str, reset_password_token_lifetime_seconds: int = 3600, ): self.db = db - self.auth = auth + self.authenticator = Authenticator(auth_backends, db) self.router = get_user_router( self.db, user_model, - self.auth, + self.authenticator, reset_password_token_secret, reset_password_token_lifetime_seconds, ) - get_current_user = self.auth.get_current_user(self.db) - self.get_current_user = get_current_user # type: ignore - - get_current_active_user = self.auth.get_current_active_user(self.db) - self.get_current_active_user = get_current_active_user # type: ignore - - get_current_superuser = self.auth.get_current_superuser(self.db) - self.get_current_superuser = get_current_superuser # type: ignore + self.get_current_user = self.authenticator.get_current_user + self.get_current_active_user = self.authenticator.get_current_active_user + self.get_current_superuser = self.authenticator.get_current_superuser def on_after_register(self) -> Callable: """Add an event handler on successful registration.""" diff --git a/fastapi_users/router.py b/fastapi_users/router.py index 496dcfca..29825e99 100644 --- a/fastapi_users/router.py +++ b/fastapi_users/router.py @@ -10,7 +10,7 @@ from pydantic import EmailStr from starlette import status from starlette.responses import Response -from fastapi_users.authentication import BaseAuthentication +from fastapi_users.authentication import Authenticator, BaseAuthentication from fastapi_users.db import BaseUserDatabase from fastapi_users.models import BaseUser, Models from fastapi_users.password import get_password_hash @@ -46,10 +46,28 @@ class UserRouter(APIRouter): handler(*args, **kwargs) +def _add_login_route( + router: UserRouter, user_db: BaseUserDatabase, auth_backend: BaseAuthentication +): + @router.post(f"/login/{auth_backend.name}") + async def login( + response: Response, credentials: OAuth2PasswordRequestForm = Depends() + ): + user = await user_db.authenticate(credentials) + + if user is None or not user.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.LOGIN_BAD_CREDENTIALS, + ) + + return await auth_backend.get_login_response(user, response) + + def get_user_router( user_db: BaseUserDatabase, user_model: typing.Type[BaseUser], - auth: BaseAuthentication, + authenticator: Authenticator, reset_password_token_secret: str, reset_password_token_lifetime_seconds: int = 3600, ) -> UserRouter: @@ -59,8 +77,8 @@ def get_user_router( reset_password_token_audience = "fastapi-users:reset" - get_current_active_user = auth.get_current_active_user(user_db) - get_current_superuser = auth.get_current_superuser(user_db) + get_current_active_user = authenticator.get_current_active_user + get_current_superuser = authenticator.get_current_superuser async def _get_or_404(id: str) -> models.UserDB: # type: ignore user = await user_db.get(id) @@ -79,6 +97,9 @@ def get_user_router( setattr(user, field, update_dict[field]) return await user_db.update(user) + for auth_backend in authenticator.backends: + _add_login_route(router, user_db, auth_backend) + @router.post( "/register", response_model=models.User, status_code=status.HTTP_201_CREATED ) @@ -101,20 +122,6 @@ def get_user_router( return created_user - @router.post("/login") - async def login( - response: Response, credentials: OAuth2PasswordRequestForm = Depends() - ): - user = await user_db.authenticate(credentials) - - if user is None or not user.is_active: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ErrorCode.LOGIN_BAD_CREDENTIALS, - ) - - return await auth.get_login_response(user, response) - @router.post("/forgot-password", status_code=status.HTTP_202_ACCEPTED) async def forgot_password(email: EmailStr = Body(..., embed=True)): user = await user_db.get_by_email(email) diff --git a/mkdocs.yml b/mkdocs.yml index defcc42b..0bf030c7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -34,8 +34,9 @@ nav: - configuration/databases/mongodb.md - configuration/databases/tortoise.md - Authentication: - - configuration/authentication/index.md + - Introduction: configuration/authentication/index.md - configuration/authentication/jwt.md + - configuration/authentication/cookie.md - configuration/router.md - configuration/full_example.md - Usage: diff --git a/tests/conftest.py b/tests/conftest.py index 9f9a6191..2341e081 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,14 @@ -from typing import List, Optional +from typing import Any, List, Mapping, Optional, Tuple +import http.cookies import pytest from fastapi import Depends, FastAPI from fastapi.security import OAuth2PasswordBearer +from starlette.requests import Request from starlette.responses import Response from starlette.testclient import TestClient -from fastapi_users.authentication import BaseAuthentication +from fastapi_users.authentication import Authenticator, BaseAuthentication from fastapi_users.db import BaseUserDatabase from fastapi_users.models import BaseUserDB from fastapi_users.password import get_password_hash @@ -81,70 +83,76 @@ def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase: return MockUserDatabase() +class MockAuthentication(BaseAuthentication): + def __init__(self, name: str = "mock"): + super().__init__(name) + self.scheme = OAuth2PasswordBearer("/users/login", auto_error=False) + + async def __call__(self, request: Request, user_db: BaseUserDatabase): + token = await self.scheme.__call__(request) + if token is not None: + return await user_db.get(token) + return None + + async def get_login_response(self, user: BaseUserDB, response: Response): + return {"token": user.id} + + @pytest.fixture def mock_authentication(): - class MockAuthentication(BaseAuthentication): - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") - - async def get_login_response(self, user: BaseUserDB, response: Response): - return {"token": user.id} - - def get_current_user(self, user_db: BaseUserDatabase): - async def _get_current_user(token: str = Depends(self.oauth2_scheme)): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_user_base(user) - - return _get_current_user - - def get_current_active_user(self, user_db: BaseUserDatabase): - async def _get_current_active_user( - token: str = Depends(self.oauth2_scheme), - ): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_active_user_base(user) - - return _get_current_active_user - - def get_current_superuser(self, user_db: BaseUserDatabase): - async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_superuser_base(user) - - return _get_current_superuser - - def _get_authentication_method(self, user_db: BaseUserDatabase): - async def authentication_method(token: str = Depends(self.oauth2_scheme)): - return await user_db.get(token) - - return authentication_method - return MockAuthentication() +@pytest.fixture +def request_builder(): + def _request_builder( + headers: Mapping[str, Any] = None, cookies: Mapping[str, str] = None + ) -> Request: + encoded_headers: List[Tuple[bytes, bytes]] = [] + + if headers is not None: + encoded_headers += [ + (key.lower().encode("latin-1"), headers[key].encode("latin-1")) + for key in headers + ] + + if cookies is not None: + for key in cookies: + cookie = http.cookies.SimpleCookie() # type: http.cookies.BaseCookie + cookie[key] = cookies[key] + cookie_val = cookie.output(header="").strip() + encoded_headers.append((b"cookie", cookie_val.encode("latin-1"))) + + scope = { + "type": "http", + "headers": encoded_headers, + } + return Request(scope) + + return _request_builder + + @pytest.fixture def get_test_auth_client(mock_user_db): - def _get_test_auth_client(authentication): + def _get_test_auth_client(backends: List[BaseAuthentication]) -> TestClient: app = FastAPI() + authenticator = Authenticator(backends, mock_user_db) @app.get("/test-current-user") def test_current_user( - user: BaseUserDB = Depends(authentication.get_current_user(mock_user_db)), + user: BaseUserDB = Depends(authenticator.get_current_user), ): return user @app.get("/test-current-active-user") def test_current_active_user( - user: BaseUserDB = Depends( - authentication.get_current_active_user(mock_user_db) - ), + user: BaseUserDB = Depends(authenticator.get_current_active_user), ): return user @app.get("/test-current-superuser") def test_current_superuser( - user: BaseUserDB = Depends( - authentication.get_current_superuser(mock_user_db) - ), + user: BaseUserDB = Depends(authenticator.get_current_superuser), ): return user diff --git a/tests/test_authentication.py b/tests/test_authentication.py new file mode 100644 index 00000000..60894259 --- /dev/null +++ b/tests/test_authentication.py @@ -0,0 +1,45 @@ +from typing import Optional + +import pytest +from starlette import status +from starlette.requests import Request + +from fastapi_users.authentication import BaseAuthentication +from fastapi_users.db import BaseUserDatabase +from fastapi_users.models import BaseUserDB + + +@pytest.fixture() +def auth_backend_none(): + class BackendNone(BaseAuthentication): + async def __call__( + self, request: Request, user_db: BaseUserDatabase + ) -> Optional[BaseUserDB]: + return None + + return BackendNone() + + +@pytest.fixture() +def auth_backend_user(user): + class BackendUser(BaseAuthentication): + async def __call__( + self, request: Request, user_db: BaseUserDatabase + ) -> Optional[BaseUserDB]: + return user + + return BackendUser() + + +@pytest.mark.authentication +def test_authenticator(get_test_auth_client, auth_backend_none, auth_backend_user): + client = get_test_auth_client([auth_backend_none, auth_backend_user]) + response = client.get("/test-current-user") + assert response.status_code == status.HTTP_200_OK + + +@pytest.mark.authentication +def test_authenticator_none(get_test_auth_client, auth_backend_none): + client = get_test_auth_client([auth_backend_none, auth_backend_none]) + response = client.get("/test-current-user") + assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/tests/test_authentication_base.py b/tests/test_authentication_base.py index 794e0757..72ae626d 100644 --- a/tests/test_authentication_base.py +++ b/tests/test_authentication_base.py @@ -1,30 +1,27 @@ import pytest -from fastapi.security import OAuth2PasswordBearer from starlette.responses import Response from fastapi_users.authentication import BaseAuthentication -@pytest.mark.asyncio +@pytest.fixture +def base_authentication(): + return BaseAuthentication() + + @pytest.mark.authentication -@pytest.mark.parametrize( - "constructor_kwargs", [{}, {"scheme": OAuth2PasswordBearer(tokenUrl="/foo")}] -) -async def test_not_implemented_methods(constructor_kwargs, user, mock_user_db): - response = Response() - base_authentication = BaseAuthentication(**constructor_kwargs) +class TestAuthenticate: + @pytest.mark.asyncio + async def test_not_implemented( + self, base_authentication, mock_user_db, request_builder + ): + request = request_builder({}) + with pytest.raises(NotImplementedError): + await base_authentication(request, mock_user_db) - with pytest.raises(NotImplementedError): - await base_authentication.get_login_response(user, response) +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_get_login_response(base_authentication, user): with pytest.raises(NotImplementedError): - await base_authentication.get_current_user(mock_user_db) - - with pytest.raises(NotImplementedError): - await base_authentication.get_current_active_user(mock_user_db) - - with pytest.raises(NotImplementedError): - await base_authentication.get_current_superuser(mock_user_db) - - with pytest.raises(NotImplementedError): - await base_authentication._get_authentication_method(mock_user_db) + await base_authentication.get_login_response(user, Response()) diff --git a/tests/test_authentication_cookie.py b/tests/test_authentication_cookie.py new file mode 100644 index 00000000..f91a99b9 --- /dev/null +++ b/tests/test_authentication_cookie.py @@ -0,0 +1,105 @@ +import re + +import jwt +import pytest +from starlette.responses import Response + +from fastapi_users.authentication.cookie import CookieAuthentication +from fastapi_users.utils import JWT_ALGORITHM, generate_jwt + +SECRET = "SECRET" +LIFETIME = 3600 +COOKIE_NAME = "COOKIE_NAME" + + +@pytest.fixture +def cookie_authentication(): + return CookieAuthentication(SECRET, LIFETIME, COOKIE_NAME) + + +@pytest.fixture +def token(): + def _token(user=None, lifetime=LIFETIME): + data = {"aud": "fastapi-users:auth"} + if user is not None: + data["user_id"] = user.id + return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) + + return _token + + +@pytest.mark.authentication +def test_default_name(cookie_authentication): + assert cookie_authentication.name == "cookie" + + +@pytest.mark.authentication +class TestAuthenticate: + @pytest.mark.asyncio + async def test_missing_token( + self, cookie_authentication, mock_user_db, request_builder + ): + request = request_builder() + authenticated_user = await cookie_authentication(request, mock_user_db) + assert authenticated_user is None + + @pytest.mark.asyncio + async def test_invalid_token( + self, cookie_authentication, mock_user_db, request_builder + ): + cookies = {} + cookies[COOKIE_NAME] = "foo" + request = request_builder(cookies=cookies) + authenticated_user = await cookie_authentication(request, mock_user_db) + assert authenticated_user is None + + @pytest.mark.asyncio + async def test_valid_token_missing_user_payload( + self, cookie_authentication, mock_user_db, request_builder, token + ): + cookies = {} + cookies[COOKIE_NAME] = token() + request = request_builder(cookies=cookies) + authenticated_user = await cookie_authentication(request, mock_user_db) + assert authenticated_user is None + + @pytest.mark.asyncio + async def test_valid_token( + self, cookie_authentication, mock_user_db, request_builder, token, user + ): + cookies = {} + cookies[COOKIE_NAME] = token(user) + request = request_builder(cookies=cookies) + authenticated_user = await cookie_authentication(request, mock_user_db) + assert authenticated_user.id == user.id + + +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_get_login_response(cookie_authentication, user): + response = Response() + login_response = await cookie_authentication.get_login_response(user, response) + + # We shouldn't return directly the response + # so that FastAPI can terminate it properly + assert login_response is None + + cookies = [ + header for header in response.raw_headers if header[0] == b"set-cookie" + ] + assert len(cookies) == 1 + + cookie = cookies[0][1].decode("latin-1") + + assert f"Max-Age={LIFETIME}" in cookie + + cookie_name_value = re.match(r"^(\w+)=([^;]+);", cookie) + + cookie_name = cookie_name_value[1] + assert cookie_name == COOKIE_NAME + + cookie_value = cookie_name_value[2] + decoded = jwt.decode( + cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM] + ) + assert decoded["user_id"] == user.id diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py index 391ad522..cff8751d 100644 --- a/tests/test_authentication_jwt.py +++ b/tests/test_authentication_jwt.py @@ -1,6 +1,5 @@ import jwt import pytest -from starlette import status from starlette.responses import Response from fastapi_users.authentication.jwt import JWTAuthentication @@ -8,11 +7,12 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt SECRET = "SECRET" LIFETIME = 3600 +TOKEN_URL = "/login" @pytest.fixture def jwt_authentication(): - return JWTAuthentication(SECRET, LIFETIME) + return JWTAuthentication(SECRET, LIFETIME, TOKEN_URL) @pytest.fixture @@ -26,11 +26,47 @@ def token(): return _token -@pytest.fixture -def test_auth_client(get_test_auth_client, jwt_authentication): - return get_test_auth_client(jwt_authentication) +@pytest.mark.authentication +def test_default_name(jwt_authentication): + assert jwt_authentication.name == "jwt" +@pytest.mark.authentication +class TestAuthenticate: + @pytest.mark.asyncio + async def test_missing_token( + self, jwt_authentication, mock_user_db, request_builder + ): + request = request_builder(headers={}) + authenticated_user = await jwt_authentication(request, mock_user_db) + assert authenticated_user is None + + @pytest.mark.asyncio + async def test_invalid_token( + self, jwt_authentication, mock_user_db, request_builder + ): + request = request_builder(headers={"Authorization": "Bearer foo"}) + authenticated_user = await jwt_authentication(request, mock_user_db) + assert authenticated_user is None + + @pytest.mark.asyncio + async def test_valid_token_missing_user_payload( + self, jwt_authentication, mock_user_db, request_builder, token + ): + request = request_builder(headers={"Authorization": f"Bearer {token()}"}) + authenticated_user = await jwt_authentication(request, mock_user_db) + assert authenticated_user is None + + @pytest.mark.asyncio + async def test_valid_token( + self, jwt_authentication, mock_user_db, request_builder, token, user + ): + request = request_builder(headers={"Authorization": f"Bearer {token(user)}"}) + authenticated_user = await jwt_authentication(request, mock_user_db) + assert authenticated_user.id == user.id + + +@pytest.mark.authentication @pytest.mark.asyncio async def test_get_login_response(jwt_authentication, user): login_response = await jwt_authentication.get_login_response(user, Response()) @@ -42,108 +78,3 @@ async def test_get_login_response(jwt_authentication, user): token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM] ) assert decoded["user_id"] == user.id - - -@pytest.mark.authentication -class TestGetCurrentUser: - def test_missing_token(self, test_auth_client): - response = test_auth_client.get("/test-current-user") - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_invalid_token(self, test_auth_client): - response = test_auth_client.get( - "/test-current-user", headers={"Authorization": "Bearer foo"} - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_missing_user_payload(self, test_auth_client, token): - response = test_auth_client.get( - "/test-current-user", headers={"Authorization": f"Bearer {token()}"} - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user): - response = test_auth_client.get( - "/test-current-user", - headers={"Authorization": f"Bearer {token(inactive_user)}"}, - ) - assert response.status_code == status.HTTP_200_OK - - response_json = response.json() - assert response_json["id"] == inactive_user.id - - def test_valid_token(self, test_auth_client, token, user): - response = test_auth_client.get( - "/test-current-user", headers={"Authorization": f"Bearer {token(user)}"} - ) - assert response.status_code == status.HTTP_200_OK - - response_json = response.json() - assert response_json["id"] == user.id - - -@pytest.mark.authentication -class TestGetCurrentActiveUser: - def test_missing_token(self, test_auth_client): - response = test_auth_client.get("/test-current-active-user") - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_invalid_token(self, test_auth_client): - response = test_auth_client.get( - "/test-current-active-user", headers={"Authorization": "Bearer foo"} - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user): - response = test_auth_client.get( - "/test-current-active-user", - headers={"Authorization": f"Bearer {token(inactive_user)}"}, - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token(self, test_auth_client, token, user): - response = test_auth_client.get( - "/test-current-active-user", - headers={"Authorization": f"Bearer {token(user)}"}, - ) - assert response.status_code == status.HTTP_200_OK - - response_json = response.json() - assert response_json["id"] == user.id - - -@pytest.mark.authentication -class TestGetCurrentSuperuser: - def test_missing_token(self, test_auth_client): - response = test_auth_client.get("/test-current-superuser") - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_invalid_token(self, test_auth_client): - response = test_auth_client.get( - "/test-current-superuser", headers={"Authorization": "Bearer foo"} - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user): - response = test_auth_client.get( - "/test-current-superuser", - headers={"Authorization": f"Bearer {token(inactive_user)}"}, - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_regular_user(self, test_auth_client, token, user): - response = test_auth_client.get( - "/test-current-superuser", - headers={"Authorization": f"Bearer {token(user)}"}, - ) - assert response.status_code == status.HTTP_403_FORBIDDEN - - def test_valid_token_superuser(self, test_auth_client, token, superuser): - response = test_auth_client.get( - "/test-current-superuser", - headers={"Authorization": f"Bearer {token(superuser)}"}, - ) - assert response.status_code == status.HTTP_200_OK - - response_json = response.json() - assert response_json["id"] == superuser.id diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index 0e3e4b85..0fb33e53 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -21,7 +21,7 @@ def fastapi_users(request, mock_user_db, mock_authentication) -> FastAPIUsers: class User(BaseUser): pass - fastapi_users = FastAPIUsers(mock_user_db, mock_authentication, User, "SECRET") + fastapi_users = FastAPIUsers(mock_user_db, [mock_authentication], User, "SECRET") @fastapi_users.on_after_register() def on_after_register(): diff --git a/tests/test_router.py b/tests/test_router.py index f2a038f7..5b895886 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -7,9 +7,11 @@ from fastapi import FastAPI from starlette import status from starlette.testclient import TestClient +from fastapi_users.authentication import Authenticator from fastapi_users.models import BaseUser, BaseUserDB from fastapi_users.router import ErrorCode, Event, get_user_router from fastapi_users.utils import JWT_ALGORITHM, generate_jwt +from tests.conftest import MockAuthentication SECRET = "SECRET" LIFETIME = 3600 @@ -44,10 +46,13 @@ def test_app_client(mock_user_db, mock_authentication, event_handler) -> TestCli class User(BaseUser): pass - userRouter = get_user_router( - mock_user_db, User, mock_authentication, SECRET, LIFETIME + mock_authentication_bis = MockAuthentication(name="mock-bis") + authenticator = Authenticator( + [mock_authentication, mock_authentication_bis], mock_user_db ) + userRouter = get_user_router(mock_user_db, User, authenticator, SECRET, LIFETIME) + userRouter.add_event_handler(Event.ON_AFTER_REGISTER, event_handler) userRouter.add_event_handler(Event.ON_AFTER_FORGOT_PASSWORD, event_handler) @@ -125,42 +130,45 @@ class TestRegister: @pytest.mark.router +@pytest.mark.parametrize("path", ["/login/mock", "/login/mock-bis"]) class TestLogin: - def test_empty_body(self, test_app_client: TestClient): - response = test_app_client.post("/login", data={}) + def test_empty_body(self, path, test_app_client: TestClient): + response = test_app_client.post(path, data={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - def test_missing_username(self, test_app_client: TestClient): + def test_missing_username(self, path, test_app_client: TestClient): data = {"password": "guinevere"} - response = test_app_client.post("/login", data=data) + response = test_app_client.post(path, data=data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - def test_missing_password(self, test_app_client: TestClient): + def test_missing_password(self, path, test_app_client: TestClient): data = {"username": "king.arthur@camelot.bt"} - response = test_app_client.post("/login", data=data) + response = test_app_client.post(path, data=data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - def test_not_existing_user(self, test_app_client: TestClient): + def test_not_existing_user(self, path, test_app_client: TestClient): data = {"username": "lancelot@camelot.bt", "password": "guinevere"} - response = test_app_client.post("/login", data=data) + response = test_app_client.post(path, data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS - def test_wrong_password(self, test_app_client: TestClient): + def test_wrong_password(self, path, test_app_client: TestClient): data = {"username": "king.arthur@camelot.bt", "password": "percival"} - response = test_app_client.post("/login", data=data) + response = test_app_client.post(path, data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS - def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB): + def test_valid_credentials( + self, path, test_app_client: TestClient, user: BaseUserDB + ): data = {"username": "king.arthur@camelot.bt", "password": "guinevere"} - response = test_app_client.post("/login", data=data) + response = test_app_client.post(path, data=data) assert response.status_code == status.HTTP_200_OK assert response.json() == {"token": user.id} - def test_inactive_user(self, test_app_client: TestClient): + def test_inactive_user(self, path, test_app_client: TestClient): data = {"username": "percival@camelot.bt", "password": "angharad"} - response = test_app_client.post("/login", data=data) + response = test_app_client.post(path, data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS