Fix #42: multiple authentication backends (#47)

* Revamp authentication to allow multiple backends

* Make router generate a login route for each backend

* Apply black

* Remove unused imports

* Complete docstrings

* Update documentation

* WIP add cookie auth

* Complete cookie auth unit tests

* Add documentation for cookie auth

* Fix cookie backend default name

* Don't make cookie return a Response
This commit is contained in:
François Voron
2019-12-04 13:32:49 +01:00
committed by GitHub
parent 5e4c7996de
commit 49deb437a6
22 changed files with 591 additions and 341 deletions

View File

@ -31,8 +31,9 @@ Add quickly a registration and authentication system to your [FastAPI](https://f
* [X] Customizable database backend
* [X] SQLAlchemy async backend included thanks to [encode/databases](https://www.encode.io/databases/)
* [X] MongoDB async backend included thanks to [mongodb/motor](https://github.com/mongodb/motor)
* [X] Customizable authentication backend
* [X] Multiple customizable authentication backends
* [X] JWT authentication backend included
* [X] Cookie authentication backend included
## Development

View File

@ -0,0 +1,50 @@
# Cookie
Cookies are an easy way to store stateful information into the user browser. Thus, it is more useful for browser-based navigation (e.g. a front-end app making API requests) rather than pure API interaction.
## Configuration
```py
from fastapi_users.authentication import CookieAuthentication
SECRET = "SECRET"
auth_backends = []
cookie_authentication = CookieAuthentication(secret=SECRET, lifetime_seconds=3600))
auth_backends.append(cookie_authentication)
```
As you can see, instantiation is quite simple. You just have to define a constant `SECRET` which is used to encode the token and the lifetime of the cookie (in seconds).
You can optionally define the `cookie_name`. **Defaults to `fastapiusersauth`**.
You can also optionally define the `name` which will be used to generate its [`/login` route](../../usage/routes.md#post-loginname). **Defaults to `cookie`**.
```py
cookie_authentication = CookieAuthentication(
secret=SECRET,
lifetime_seconds=3600,
name="my-cookie",
)
```
!!! tip
The value of the cookie is actually a JWT. This authentication backend shares most of its logic with the [JWT](./jwt.md) one.
## Login
This method will return a response with a valid `set-cookie` header upon successful login:
!!! success "`200 OK`"
> Check documentation about [login route](../../usage/routes.md#post-loginname).
## Authentication
This method expects that you provide a valid cookie in the headers.
## Next steps
We will now configure the main **FastAPI Users** object that will expose the [API router](../router.md).

View File

@ -2,6 +2,15 @@
**FastAPI Users** allows you to plug in several authentication methods.
## How it works?
You can have **several** authentication methods, e.g. a cookie authentication for browser-based queries and a JWT token authentication for pure API queries.
When checking authentication, each method is run one after the other. The first method yielding a user wins. If no method yields a user, an `HTTPException` is raised.
Each defined method will generate a [`/login/{name}`](../../usage/routes.md#post-loginname) route where `name` is defined on the authentication method object.
## Provided methods
* [JWT authentication](jwt.md)
* [Cookie authentication](cookie.md)

View File

@ -9,11 +9,38 @@ from fastapi_users.authentication import JWTAuthentication
SECRET = "SECRET"
auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600)
auth_backends = []
jwt_authentication = JWTAuthentication(secret=SECRET, lifetime_seconds=3600))
auth_backends.append(jwt_authentication)
```
As you can see, instantiation is quite simple. You just have to define a constant `SECRET` which is used to encode the token and the lifetime of token (in seconds).
You can also optionally define the `name` which will be used to generate its [`/login` route](../../usage/routes.md#post-loginname). **Defaults to `jwt`**.
```py
jwt_authentication = JWTAuthentication(
secret=SECRET,
lifetime_seconds=3600,
name="my-jwt",
)
```
## Login
This method will return a JWT token upon successful login:
!!! success "`200 OK`"
```json
{
"token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoiOTIyMWZmYzktNjQwZi00MzcyLTg2ZDMtY2U2NDJjYmE1NjAzIiwiYXVkIjoiZmFzdGFwaS11c2VyczphdXRoIiwiZXhwIjoxNTcxNTA0MTkzfQ.M10bjOe45I5Ncu_uXvOmVV8QxnL-nZfcH96U90JaocI"
}
```
> Check documentation about [login route](../../usage/routes.md#post-loginname).
## Authentication
This method expects that you provide a `Bearer` authentication with a valid JWT.

View File

@ -7,7 +7,7 @@ We're almost there! The last step is to configure the `FastAPIUsers` object that
Configure `FastAPIUsers` object with all the elements we defined before. More precisely:
* `db`: Database adapter instance.
* `auth`: Authentication logic instance.
* `auth_backends`: List of authentication backends. See [Authentication](./authentication/index.md).
* `user_model`: Pydantic model of a user.
* `reset_password_token_secret`: Secret to encode reset password token.
* `reset_password_token_lifetime_seconds`: Lifetime of reset password token in seconds. Default to one hour.
@ -17,7 +17,7 @@ from fastapi_users import FastAPIUsers
fastapi_users = FastAPIUsers(
user_db,
auth,
auth_backends,
User,
SECRET,
)

View File

@ -20,10 +20,12 @@ class User(BaseUser):
pass
auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600)
auth_backends = [
JWTAuthentication(secret=SECRET, lifetime_seconds=3600),
]
app = FastAPI()
fastapi_users = FastAPIUsers(user_db, auth, User, SECRET)
fastapi_users = FastAPIUsers(user_db, auth_backends, User, SECRET)
app.include_router(fastapi_users.router, prefix="/users", tags=["users"])

View File

@ -33,10 +33,12 @@ class User(BaseUser):
pass
auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600)
auth_backends = [
JWTAuthentication(secret=SECRET, lifetime_seconds=3600),
]
app = FastAPI()
fastapi_users = FastAPIUsers(user_db, auth, User, SECRET)
fastapi_users = FastAPIUsers(user_db, auth_backends, User, SECRET)
app.include_router(fastapi_users.router, prefix="/users", tags=["users"])

View File

@ -37,22 +37,15 @@ Register a new user. Will call the `on_after_register` [event handlers](../confi
}
```
### `POST /login`
### `POST /login/{name}`
Login a user.
Login a user against the method named `name`. Check the corresponding [authentication method](../configuration/authentication/index.md) to view the success response.
!!! abstract "Payload (`application/x-www-form-urlencoded`)"
```
username=king.arthur@camelot.bt&password=guinevere
```
!!! success "`200 OK`"
```json
{
"token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoiOTIyMWZmYzktNjQwZi00MzcyLTg2ZDMtY2U2NDJjYmE1NjAzIiwiYXVkIjoiZmFzdGFwaS11c2VyczphdXRoIiwiZXhwIjoxNTcxNTA0MTkzfQ.M10bjOe45I5Ncu_uXvOmVV8QxnL-nZfcH96U90JaocI"
}
```
!!! fail "`422 Validation Error`"
!!! fail "`400 Bad Request`"

View File

@ -1,2 +1,60 @@
from typing import Sequence
from fastapi import HTTPException
from starlette import status
from starlette.requests import Request
from fastapi_users.authentication.base import BaseAuthentication # noqa: F401
from fastapi_users.authentication.cookie import CookieAuthentication # noqa: F401
from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
class Authenticator:
"""
Provides dependency callables to retrieve authenticated user.
It performs the authentication against a list of backends
defined by the end-developer. The first backend yielding a user wins.
If no backend yields a user, an HTTPException is raised.
:param backends: List of authentication backends.
:param user_db: Database adapter instance.
"""
backends: Sequence[BaseAuthentication]
user_db: BaseUserDatabase
def __init__(
self, backends: Sequence[BaseAuthentication], user_db: BaseUserDatabase
):
self.backends = backends
self.user_db = user_db
async def get_current_user(self, request: Request) -> BaseUserDB:
return await self._authenticate(request)
async def get_current_active_user(self, request: Request) -> BaseUserDB:
user = await self.get_current_user(request)
if not user.is_active:
raise self._get_credentials_exception()
return user
async def get_current_superuser(self, request: Request) -> BaseUserDB:
user = await self.get_current_active_user(request)
if not user.is_superuser:
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
return user
async def _authenticate(self, request: Request) -> BaseUserDB:
for backend in self.backends:
user = await backend(request, self.user_db)
if user is not None:
return user
raise self._get_credentials_exception()
def _get_credentials_exception(
self, status_code: int = status.HTTP_401_UNAUTHORIZED
) -> HTTPException:
return HTTPException(status_code=status_code)

View File

@ -1,70 +1,30 @@
from typing import Callable
from typing import Any, Optional
from fastapi import HTTPException
from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.requests import Request
from starlette.responses import Response
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
class BaseAuthentication:
"""
Base adapter for generating and decoding authentication tokens.
Base authentication backend.
Provides dependency injectors to get current active/superuser user.
Every backend should derive from this class.
:param scheme: Optional authentication scheme for OpenAPI.
Defaults to `OAuth2PasswordBearer(tokenUrl="/users/login")`.
Override it if your login route lives somewhere else.
:param name: Name of the backend. It will be used to name the login route.
"""
scheme: OAuth2PasswordBearer
name: str
def __init__(self, scheme: OAuth2PasswordBearer = None):
if scheme is None:
self.scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
else:
self.scheme = scheme
def __init__(self, name: str = "base"):
self.name = name
async def get_login_response(self, user: BaseUserDB, response: Response):
async def __call__(
self, request: Request, user_db: BaseUserDatabase
) -> Optional[BaseUserDB]:
raise NotImplementedError()
def get_current_user(self, user_db: BaseUserDatabase):
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
raise NotImplementedError()
def get_current_active_user(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def get_current_superuser(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def _get_authentication_method(
self, user_db: BaseUserDatabase
) -> Callable[..., BaseUserDB]:
raise NotImplementedError()
def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB:
if user is None:
raise self._get_credentials_exception()
return user
def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB:
user = self._get_current_user_base(user)
if not user.is_active:
raise self._get_credentials_exception()
return user
def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB:
user = self._get_current_active_user_base(user)
if not user.is_superuser:
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
return user
def _get_credentials_exception(
self, status_code: int = status.HTTP_401_UNAUTHORIZED
) -> HTTPException:
return HTTPException(status_code=status_code)

View File

@ -0,0 +1,53 @@
from typing import Any, Optional
from fastapi.security import APIKeyCookie
from starlette.requests import Request
from starlette.responses import Response
from fastapi_users.authentication.jwt import JWTAuthentication
from fastapi_users.models import BaseUserDB
class CookieAuthentication(JWTAuthentication):
"""
Authentication backend using a cookie.
Internally, uses a JWT token to store the data.
:param secret: Secret used to encode the cookie.
:param lifetime_seconds: Lifetime duration of the cookie in seconds.
:param cookie_name: Name of the cookie.
:param name: Name of the backend. It will be used to name the login route.
"""
lifetime_seconds: int
cookie_name: str
def __init__(
self,
secret: str,
lifetime_seconds: int,
cookie_name: str = "fastapiusersauth",
name: str = "cookie",
):
super().__init__(secret, lifetime_seconds, name=name)
self.lifetime_seconds = lifetime_seconds
self.cookie_name = cookie_name
self.api_key_cookie = APIKeyCookie(name=self.cookie_name, auto_error=False)
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
token = await self._generate_token(user)
response.set_cookie(
self.cookie_name,
token,
max_age=self.lifetime_seconds,
secure=True,
httponly=True,
)
# We shouldn't return directly the response
# so that FastAPI can terminate it properly
return None
async def _retrieve_token(self, request: Request) -> Optional[str]:
return await self.api_key_cookie.__call__(request)

View File

@ -1,5 +1,8 @@
from typing import Any, Optional
import jwt
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from starlette.requests import Request
from starlette.responses import Response
from fastapi_users.authentication.base import BaseAuthentication
@ -10,62 +13,58 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
class JWTAuthentication(BaseAuthentication):
"""
Authentication using a JWT.
Authentication backend using a JWT in a Bearer header.
:param secret: Secret used to encode the JWT.
:param lifetime_seconds: Lifetime duration of the JWT in seconds.
:param tokenUrl: Path where to get a token.
:param name: Name of the backend. It will be used to name the login route.
"""
token_audience: str = "fastapi-users:auth"
secret: str
lifetime_seconds: int
def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
secret: str,
lifetime_seconds: int,
tokenUrl: str = "/users/login",
name: str = "jwt",
):
super().__init__(name)
self.secret = secret
self.lifetime_seconds = lifetime_seconds
self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False)
async def get_login_response(self, user: BaseUserDB, response: Response):
data = {"user_id": user.id, "aud": self.token_audience}
token = generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)
async def __call__(
self, request: Request, user_db: BaseUserDatabase,
) -> Optional[BaseUserDB]:
token = await self._retrieve_token(request)
if token is None:
return None
try:
data = jwt.decode(
token,
self.secret,
audience=self.token_audience,
algorithms=[JWT_ALGORITHM],
)
user_id = data.get("user_id")
if user_id is None:
return None
except jwt.PyJWTError:
return None
return await user_db.get(user_id)
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
token = await self._generate_token(user)
return {"token": token}
def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user)
async def _retrieve_token(self, request: Request) -> Optional[str]:
return await self.scheme.__call__(request)
return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(self.scheme)):
try:
data = jwt.decode(
token,
self.secret,
audience=self.token_audience,
algorithms=[JWT_ALGORITHM],
)
user_id = data.get("user_id")
if user_id is None:
return None
except jwt.PyJWTError:
return None
return await user_db.get(user_id)
return authentication_method
async def _generate_token(self, user: BaseUserDB) -> str:
data = {"user_id": user.id, "aud": self.token_audience}
return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)

View File

@ -1,8 +1,8 @@
from typing import Callable, Type
from typing import Callable, Sequence, Type
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUser, BaseUserDB
from fastapi_users.models import BaseUser
from fastapi_users.router import Event, UserRouter, get_user_router
@ -11,7 +11,7 @@ class FastAPIUsers:
Main object that ties together the component for users authentication.
:param db: Database adapter instance.
:param auth: Authentication logic instance.
:param auth_backends: List of authentication backends.
:param user_model: Pydantic model of a user.
:param reset_password_token_secret: Secret to encode reset password token.
:param reset_password_token_lifetime_seconds: Lifetime of reset password token.
@ -21,36 +21,30 @@ class FastAPIUsers:
"""
db: BaseUserDatabase
auth: BaseAuthentication
authenticator: Authenticator
router: UserRouter
get_current_user: Callable[..., BaseUserDB]
def __init__(
self,
db: BaseUserDatabase,
auth: BaseAuthentication,
auth_backends: Sequence[BaseAuthentication],
user_model: Type[BaseUser],
reset_password_token_secret: str,
reset_password_token_lifetime_seconds: int = 3600,
):
self.db = db
self.auth = auth
self.authenticator = Authenticator(auth_backends, db)
self.router = get_user_router(
self.db,
user_model,
self.auth,
self.authenticator,
reset_password_token_secret,
reset_password_token_lifetime_seconds,
)
get_current_user = self.auth.get_current_user(self.db)
self.get_current_user = get_current_user # type: ignore
get_current_active_user = self.auth.get_current_active_user(self.db)
self.get_current_active_user = get_current_active_user # type: ignore
get_current_superuser = self.auth.get_current_superuser(self.db)
self.get_current_superuser = get_current_superuser # type: ignore
self.get_current_user = self.authenticator.get_current_user
self.get_current_active_user = self.authenticator.get_current_active_user
self.get_current_superuser = self.authenticator.get_current_superuser
def on_after_register(self) -> Callable:
"""Add an event handler on successful registration."""

View File

@ -10,7 +10,7 @@ from pydantic import EmailStr
from starlette import status
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUser, Models
from fastapi_users.password import get_password_hash
@ -46,10 +46,28 @@ class UserRouter(APIRouter):
handler(*args, **kwargs)
def _add_login_route(
router: UserRouter, user_db: BaseUserDatabase, auth_backend: BaseAuthentication
):
@router.post(f"/login/{auth_backend.name}")
async def login(
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
):
user = await user_db.authenticate(credentials)
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
return await auth_backend.get_login_response(user, response)
def get_user_router(
user_db: BaseUserDatabase,
user_model: typing.Type[BaseUser],
auth: BaseAuthentication,
authenticator: Authenticator,
reset_password_token_secret: str,
reset_password_token_lifetime_seconds: int = 3600,
) -> UserRouter:
@ -59,8 +77,8 @@ def get_user_router(
reset_password_token_audience = "fastapi-users:reset"
get_current_active_user = auth.get_current_active_user(user_db)
get_current_superuser = auth.get_current_superuser(user_db)
get_current_active_user = authenticator.get_current_active_user
get_current_superuser = authenticator.get_current_superuser
async def _get_or_404(id: str) -> models.UserDB: # type: ignore
user = await user_db.get(id)
@ -79,6 +97,9 @@ def get_user_router(
setattr(user, field, update_dict[field])
return await user_db.update(user)
for auth_backend in authenticator.backends:
_add_login_route(router, user_db, auth_backend)
@router.post(
"/register", response_model=models.User, status_code=status.HTTP_201_CREATED
)
@ -101,20 +122,6 @@ def get_user_router(
return created_user
@router.post("/login")
async def login(
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
):
user = await user_db.authenticate(credentials)
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
return await auth.get_login_response(user, response)
@router.post("/forgot-password", status_code=status.HTTP_202_ACCEPTED)
async def forgot_password(email: EmailStr = Body(..., embed=True)):
user = await user_db.get_by_email(email)

View File

@ -34,8 +34,9 @@ nav:
- configuration/databases/mongodb.md
- configuration/databases/tortoise.md
- Authentication:
- configuration/authentication/index.md
- Introduction: configuration/authentication/index.md
- configuration/authentication/jwt.md
- configuration/authentication/cookie.md
- configuration/router.md
- configuration/full_example.md
- Usage:

View File

@ -1,12 +1,14 @@
from typing import List, Optional
from typing import Any, List, Mapping, Optional, Tuple
import http.cookies
import pytest
from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordBearer
from starlette.requests import Request
from starlette.responses import Response
from starlette.testclient import TestClient
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash
@ -81,70 +83,76 @@ def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase:
return MockUserDatabase()
class MockAuthentication(BaseAuthentication):
def __init__(self, name: str = "mock"):
super().__init__(name)
self.scheme = OAuth2PasswordBearer("/users/login", auto_error=False)
async def __call__(self, request: Request, user_db: BaseUserDatabase):
token = await self.scheme.__call__(request)
if token is not None:
return await user_db.get(token)
return None
async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id}
@pytest.fixture
def mock_authentication():
class MockAuthentication(BaseAuthentication):
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id}
def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user)
return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(
token: str = Depends(self.oauth2_scheme),
):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
return await user_db.get(token)
return authentication_method
return MockAuthentication()
@pytest.fixture
def request_builder():
def _request_builder(
headers: Mapping[str, Any] = None, cookies: Mapping[str, str] = None
) -> Request:
encoded_headers: List[Tuple[bytes, bytes]] = []
if headers is not None:
encoded_headers += [
(key.lower().encode("latin-1"), headers[key].encode("latin-1"))
for key in headers
]
if cookies is not None:
for key in cookies:
cookie = http.cookies.SimpleCookie() # type: http.cookies.BaseCookie
cookie[key] = cookies[key]
cookie_val = cookie.output(header="").strip()
encoded_headers.append((b"cookie", cookie_val.encode("latin-1")))
scope = {
"type": "http",
"headers": encoded_headers,
}
return Request(scope)
return _request_builder
@pytest.fixture
def get_test_auth_client(mock_user_db):
def _get_test_auth_client(authentication):
def _get_test_auth_client(backends: List[BaseAuthentication]) -> TestClient:
app = FastAPI()
authenticator = Authenticator(backends, mock_user_db)
@app.get("/test-current-user")
def test_current_user(
user: BaseUserDB = Depends(authentication.get_current_user(mock_user_db)),
user: BaseUserDB = Depends(authenticator.get_current_user),
):
return user
@app.get("/test-current-active-user")
def test_current_active_user(
user: BaseUserDB = Depends(
authentication.get_current_active_user(mock_user_db)
),
user: BaseUserDB = Depends(authenticator.get_current_active_user),
):
return user
@app.get("/test-current-superuser")
def test_current_superuser(
user: BaseUserDB = Depends(
authentication.get_current_superuser(mock_user_db)
),
user: BaseUserDB = Depends(authenticator.get_current_superuser),
):
return user

View File

@ -0,0 +1,45 @@
from typing import Optional
import pytest
from starlette import status
from starlette.requests import Request
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
@pytest.fixture()
def auth_backend_none():
class BackendNone(BaseAuthentication):
async def __call__(
self, request: Request, user_db: BaseUserDatabase
) -> Optional[BaseUserDB]:
return None
return BackendNone()
@pytest.fixture()
def auth_backend_user(user):
class BackendUser(BaseAuthentication):
async def __call__(
self, request: Request, user_db: BaseUserDatabase
) -> Optional[BaseUserDB]:
return user
return BackendUser()
@pytest.mark.authentication
def test_authenticator(get_test_auth_client, auth_backend_none, auth_backend_user):
client = get_test_auth_client([auth_backend_none, auth_backend_user])
response = client.get("/test-current-user")
assert response.status_code == status.HTTP_200_OK
@pytest.mark.authentication
def test_authenticator_none(get_test_auth_client, auth_backend_none):
client = get_test_auth_client([auth_backend_none, auth_backend_none])
response = client.get("/test-current-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@ -1,30 +1,27 @@
import pytest
from fastapi.security import OAuth2PasswordBearer
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
@pytest.mark.asyncio
@pytest.fixture
def base_authentication():
return BaseAuthentication()
@pytest.mark.authentication
@pytest.mark.parametrize(
"constructor_kwargs", [{}, {"scheme": OAuth2PasswordBearer(tokenUrl="/foo")}]
)
async def test_not_implemented_methods(constructor_kwargs, user, mock_user_db):
response = Response()
base_authentication = BaseAuthentication(**constructor_kwargs)
class TestAuthenticate:
@pytest.mark.asyncio
async def test_not_implemented(
self, base_authentication, mock_user_db, request_builder
):
request = request_builder({})
with pytest.raises(NotImplementedError):
await base_authentication(request, mock_user_db)
with pytest.raises(NotImplementedError):
await base_authentication.get_login_response(user, response)
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_login_response(base_authentication, user):
with pytest.raises(NotImplementedError):
await base_authentication.get_current_user(mock_user_db)
with pytest.raises(NotImplementedError):
await base_authentication.get_current_active_user(mock_user_db)
with pytest.raises(NotImplementedError):
await base_authentication.get_current_superuser(mock_user_db)
with pytest.raises(NotImplementedError):
await base_authentication._get_authentication_method(mock_user_db)
await base_authentication.get_login_response(user, Response())

View File

@ -0,0 +1,105 @@
import re
import jwt
import pytest
from starlette.responses import Response
from fastapi_users.authentication.cookie import CookieAuthentication
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
SECRET = "SECRET"
LIFETIME = 3600
COOKIE_NAME = "COOKIE_NAME"
@pytest.fixture
def cookie_authentication():
return CookieAuthentication(SECRET, LIFETIME, COOKIE_NAME)
@pytest.fixture
def token():
def _token(user=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:auth"}
if user is not None:
data["user_id"] = user.id
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
return _token
@pytest.mark.authentication
def test_default_name(cookie_authentication):
assert cookie_authentication.name == "cookie"
@pytest.mark.authentication
class TestAuthenticate:
@pytest.mark.asyncio
async def test_missing_token(
self, cookie_authentication, mock_user_db, request_builder
):
request = request_builder()
authenticated_user = await cookie_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_invalid_token(
self, cookie_authentication, mock_user_db, request_builder
):
cookies = {}
cookies[COOKIE_NAME] = "foo"
request = request_builder(cookies=cookies)
authenticated_user = await cookie_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token_missing_user_payload(
self, cookie_authentication, mock_user_db, request_builder, token
):
cookies = {}
cookies[COOKIE_NAME] = token()
request = request_builder(cookies=cookies)
authenticated_user = await cookie_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token(
self, cookie_authentication, mock_user_db, request_builder, token, user
):
cookies = {}
cookies[COOKIE_NAME] = token(user)
request = request_builder(cookies=cookies)
authenticated_user = await cookie_authentication(request, mock_user_db)
assert authenticated_user.id == user.id
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_login_response(cookie_authentication, user):
response = Response()
login_response = await cookie_authentication.get_login_response(user, response)
# We shouldn't return directly the response
# so that FastAPI can terminate it properly
assert login_response is None
cookies = [
header for header in response.raw_headers if header[0] == b"set-cookie"
]
assert len(cookies) == 1
cookie = cookies[0][1].decode("latin-1")
assert f"Max-Age={LIFETIME}" in cookie
cookie_name_value = re.match(r"^(\w+)=([^;]+);", cookie)
cookie_name = cookie_name_value[1]
assert cookie_name == COOKIE_NAME
cookie_value = cookie_name_value[2]
decoded = jwt.decode(
cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
)
assert decoded["user_id"] == user.id

View File

@ -1,6 +1,5 @@
import jwt
import pytest
from starlette import status
from starlette.responses import Response
from fastapi_users.authentication.jwt import JWTAuthentication
@ -8,11 +7,12 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
SECRET = "SECRET"
LIFETIME = 3600
TOKEN_URL = "/login"
@pytest.fixture
def jwt_authentication():
return JWTAuthentication(SECRET, LIFETIME)
return JWTAuthentication(SECRET, LIFETIME, TOKEN_URL)
@pytest.fixture
@ -26,11 +26,47 @@ def token():
return _token
@pytest.fixture
def test_auth_client(get_test_auth_client, jwt_authentication):
return get_test_auth_client(jwt_authentication)
@pytest.mark.authentication
def test_default_name(jwt_authentication):
assert jwt_authentication.name == "jwt"
@pytest.mark.authentication
class TestAuthenticate:
@pytest.mark.asyncio
async def test_missing_token(
self, jwt_authentication, mock_user_db, request_builder
):
request = request_builder(headers={})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_invalid_token(
self, jwt_authentication, mock_user_db, request_builder
):
request = request_builder(headers={"Authorization": "Bearer foo"})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token_missing_user_payload(
self, jwt_authentication, mock_user_db, request_builder, token
):
request = request_builder(headers={"Authorization": f"Bearer {token()}"})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token(
self, jwt_authentication, mock_user_db, request_builder, token, user
):
request = request_builder(headers={"Authorization": f"Bearer {token(user)}"})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user.id == user.id
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_login_response(jwt_authentication, user):
login_response = await jwt_authentication.get_login_response(user, Response())
@ -42,108 +78,3 @@ async def test_get_login_response(jwt_authentication, user):
token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
)
assert decoded["user_id"] == user.id
@pytest.mark.authentication
class TestGetCurrentUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_missing_user_payload(self, test_auth_client, token):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": f"Bearer {token()}"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get(
"/test-current-user",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == inactive_user.id
def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": f"Bearer {token(user)}"}
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
@pytest.mark.authentication
class TestGetCurrentActiveUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-active-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {token(user)}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
@pytest.mark.authentication
class TestGetCurrentSuperuser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-superuser")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_regular_user(self, test_auth_client, token, user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(user)}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_valid_token_superuser(self, test_auth_client, token, superuser):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(superuser)}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == superuser.id

View File

@ -21,7 +21,7 @@ def fastapi_users(request, mock_user_db, mock_authentication) -> FastAPIUsers:
class User(BaseUser):
pass
fastapi_users = FastAPIUsers(mock_user_db, mock_authentication, User, "SECRET")
fastapi_users = FastAPIUsers(mock_user_db, [mock_authentication], User, "SECRET")
@fastapi_users.on_after_register()
def on_after_register():

View File

@ -7,9 +7,11 @@ from fastapi import FastAPI
from starlette import status
from starlette.testclient import TestClient
from fastapi_users.authentication import Authenticator
from fastapi_users.models import BaseUser, BaseUserDB
from fastapi_users.router import ErrorCode, Event, get_user_router
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
from tests.conftest import MockAuthentication
SECRET = "SECRET"
LIFETIME = 3600
@ -44,10 +46,13 @@ def test_app_client(mock_user_db, mock_authentication, event_handler) -> TestCli
class User(BaseUser):
pass
userRouter = get_user_router(
mock_user_db, User, mock_authentication, SECRET, LIFETIME
mock_authentication_bis = MockAuthentication(name="mock-bis")
authenticator = Authenticator(
[mock_authentication, mock_authentication_bis], mock_user_db
)
userRouter = get_user_router(mock_user_db, User, authenticator, SECRET, LIFETIME)
userRouter.add_event_handler(Event.ON_AFTER_REGISTER, event_handler)
userRouter.add_event_handler(Event.ON_AFTER_FORGOT_PASSWORD, event_handler)
@ -125,42 +130,45 @@ class TestRegister:
@pytest.mark.router
@pytest.mark.parametrize("path", ["/login/mock", "/login/mock-bis"])
class TestLogin:
def test_empty_body(self, test_app_client: TestClient):
response = test_app_client.post("/login", data={})
def test_empty_body(self, path, test_app_client: TestClient):
response = test_app_client.post(path, data={})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_missing_username(self, test_app_client: TestClient):
def test_missing_username(self, path, test_app_client: TestClient):
data = {"password": "guinevere"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_missing_password(self, test_app_client: TestClient):
def test_missing_password(self, path, test_app_client: TestClient):
data = {"username": "king.arthur@camelot.bt"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_not_existing_user(self, test_app_client: TestClient):
def test_not_existing_user(self, path, test_app_client: TestClient):
data = {"username": "lancelot@camelot.bt", "password": "guinevere"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
def test_wrong_password(self, test_app_client: TestClient):
def test_wrong_password(self, path, test_app_client: TestClient):
data = {"username": "king.arthur@camelot.bt", "password": "percival"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB):
def test_valid_credentials(
self, path, test_app_client: TestClient, user: BaseUserDB
):
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"token": user.id}
def test_inactive_user(self, test_app_client: TestClient):
def test_inactive_user(self, path, test_app_client: TestClient):
data = {"username": "percival@camelot.bt", "password": "angharad"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS