mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-03 05:27:06 +08:00
* Revamp authentication to allow multiple backends * Make router generate a login route for each backend * Apply black * Remove unused imports * Complete docstrings * Update documentation * WIP add cookie auth * Complete cookie auth unit tests * Add documentation for cookie auth * Fix cookie backend default name * Don't make cookie return a Response
This commit is contained in:
@ -31,8 +31,9 @@ Add quickly a registration and authentication system to your [FastAPI](https://f
|
||||
* [X] Customizable database backend
|
||||
* [X] SQLAlchemy async backend included thanks to [encode/databases](https://www.encode.io/databases/)
|
||||
* [X] MongoDB async backend included thanks to [mongodb/motor](https://github.com/mongodb/motor)
|
||||
* [X] Customizable authentication backend
|
||||
* [X] Multiple customizable authentication backends
|
||||
* [X] JWT authentication backend included
|
||||
* [X] Cookie authentication backend included
|
||||
|
||||
## Development
|
||||
|
||||
|
||||
50
docs/configuration/authentication/cookie.md
Normal file
50
docs/configuration/authentication/cookie.md
Normal file
@ -0,0 +1,50 @@
|
||||
# Cookie
|
||||
|
||||
Cookies are an easy way to store stateful information into the user browser. Thus, it is more useful for browser-based navigation (e.g. a front-end app making API requests) rather than pure API interaction.
|
||||
|
||||
## Configuration
|
||||
|
||||
```py
|
||||
from fastapi_users.authentication import CookieAuthentication
|
||||
|
||||
SECRET = "SECRET"
|
||||
|
||||
auth_backends = []
|
||||
|
||||
cookie_authentication = CookieAuthentication(secret=SECRET, lifetime_seconds=3600))
|
||||
|
||||
auth_backends.append(cookie_authentication)
|
||||
```
|
||||
|
||||
As you can see, instantiation is quite simple. You just have to define a constant `SECRET` which is used to encode the token and the lifetime of the cookie (in seconds).
|
||||
|
||||
You can optionally define the `cookie_name`. **Defaults to `fastapiusersauth`**.
|
||||
|
||||
You can also optionally define the `name` which will be used to generate its [`/login` route](../../usage/routes.md#post-loginname). **Defaults to `cookie`**.
|
||||
|
||||
```py
|
||||
cookie_authentication = CookieAuthentication(
|
||||
secret=SECRET,
|
||||
lifetime_seconds=3600,
|
||||
name="my-cookie",
|
||||
)
|
||||
```
|
||||
|
||||
!!! tip
|
||||
The value of the cookie is actually a JWT. This authentication backend shares most of its logic with the [JWT](./jwt.md) one.
|
||||
|
||||
## Login
|
||||
|
||||
This method will return a response with a valid `set-cookie` header upon successful login:
|
||||
|
||||
!!! success "`200 OK`"
|
||||
|
||||
> Check documentation about [login route](../../usage/routes.md#post-loginname).
|
||||
|
||||
## Authentication
|
||||
|
||||
This method expects that you provide a valid cookie in the headers.
|
||||
|
||||
## Next steps
|
||||
|
||||
We will now configure the main **FastAPI Users** object that will expose the [API router](../router.md).
|
||||
@ -2,6 +2,15 @@
|
||||
|
||||
**FastAPI Users** allows you to plug in several authentication methods.
|
||||
|
||||
## How it works?
|
||||
|
||||
You can have **several** authentication methods, e.g. a cookie authentication for browser-based queries and a JWT token authentication for pure API queries.
|
||||
|
||||
When checking authentication, each method is run one after the other. The first method yielding a user wins. If no method yields a user, an `HTTPException` is raised.
|
||||
|
||||
Each defined method will generate a [`/login/{name}`](../../usage/routes.md#post-loginname) route where `name` is defined on the authentication method object.
|
||||
|
||||
## Provided methods
|
||||
|
||||
* [JWT authentication](jwt.md)
|
||||
* [Cookie authentication](cookie.md)
|
||||
|
||||
@ -9,11 +9,38 @@ from fastapi_users.authentication import JWTAuthentication
|
||||
|
||||
SECRET = "SECRET"
|
||||
|
||||
auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600)
|
||||
auth_backends = []
|
||||
|
||||
jwt_authentication = JWTAuthentication(secret=SECRET, lifetime_seconds=3600))
|
||||
|
||||
auth_backends.append(jwt_authentication)
|
||||
```
|
||||
|
||||
As you can see, instantiation is quite simple. You just have to define a constant `SECRET` which is used to encode the token and the lifetime of token (in seconds).
|
||||
|
||||
You can also optionally define the `name` which will be used to generate its [`/login` route](../../usage/routes.md#post-loginname). **Defaults to `jwt`**.
|
||||
|
||||
```py
|
||||
jwt_authentication = JWTAuthentication(
|
||||
secret=SECRET,
|
||||
lifetime_seconds=3600,
|
||||
name="my-jwt",
|
||||
)
|
||||
```
|
||||
|
||||
## Login
|
||||
|
||||
This method will return a JWT token upon successful login:
|
||||
|
||||
!!! success "`200 OK`"
|
||||
```json
|
||||
{
|
||||
"token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoiOTIyMWZmYzktNjQwZi00MzcyLTg2ZDMtY2U2NDJjYmE1NjAzIiwiYXVkIjoiZmFzdGFwaS11c2VyczphdXRoIiwiZXhwIjoxNTcxNTA0MTkzfQ.M10bjOe45I5Ncu_uXvOmVV8QxnL-nZfcH96U90JaocI"
|
||||
}
|
||||
```
|
||||
|
||||
> Check documentation about [login route](../../usage/routes.md#post-loginname).
|
||||
|
||||
## Authentication
|
||||
|
||||
This method expects that you provide a `Bearer` authentication with a valid JWT.
|
||||
|
||||
@ -7,7 +7,7 @@ We're almost there! The last step is to configure the `FastAPIUsers` object that
|
||||
Configure `FastAPIUsers` object with all the elements we defined before. More precisely:
|
||||
|
||||
* `db`: Database adapter instance.
|
||||
* `auth`: Authentication logic instance.
|
||||
* `auth_backends`: List of authentication backends. See [Authentication](./authentication/index.md).
|
||||
* `user_model`: Pydantic model of a user.
|
||||
* `reset_password_token_secret`: Secret to encode reset password token.
|
||||
* `reset_password_token_lifetime_seconds`: Lifetime of reset password token in seconds. Default to one hour.
|
||||
@ -17,7 +17,7 @@ from fastapi_users import FastAPIUsers
|
||||
|
||||
fastapi_users = FastAPIUsers(
|
||||
user_db,
|
||||
auth,
|
||||
auth_backends,
|
||||
User,
|
||||
SECRET,
|
||||
)
|
||||
|
||||
@ -20,10 +20,12 @@ class User(BaseUser):
|
||||
pass
|
||||
|
||||
|
||||
auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600)
|
||||
auth_backends = [
|
||||
JWTAuthentication(secret=SECRET, lifetime_seconds=3600),
|
||||
]
|
||||
|
||||
app = FastAPI()
|
||||
fastapi_users = FastAPIUsers(user_db, auth, User, SECRET)
|
||||
fastapi_users = FastAPIUsers(user_db, auth_backends, User, SECRET)
|
||||
app.include_router(fastapi_users.router, prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
|
||||
@ -33,10 +33,12 @@ class User(BaseUser):
|
||||
pass
|
||||
|
||||
|
||||
auth = JWTAuthentication(secret=SECRET, lifetime_seconds=3600)
|
||||
auth_backends = [
|
||||
JWTAuthentication(secret=SECRET, lifetime_seconds=3600),
|
||||
]
|
||||
|
||||
app = FastAPI()
|
||||
fastapi_users = FastAPIUsers(user_db, auth, User, SECRET)
|
||||
fastapi_users = FastAPIUsers(user_db, auth_backends, User, SECRET)
|
||||
app.include_router(fastapi_users.router, prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
|
||||
@ -37,22 +37,15 @@ Register a new user. Will call the `on_after_register` [event handlers](../confi
|
||||
}
|
||||
```
|
||||
|
||||
### `POST /login`
|
||||
### `POST /login/{name}`
|
||||
|
||||
Login a user.
|
||||
Login a user against the method named `name`. Check the corresponding [authentication method](../configuration/authentication/index.md) to view the success response.
|
||||
|
||||
!!! abstract "Payload (`application/x-www-form-urlencoded`)"
|
||||
```
|
||||
username=king.arthur@camelot.bt&password=guinevere
|
||||
```
|
||||
|
||||
!!! success "`200 OK`"
|
||||
```json
|
||||
{
|
||||
"token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoiOTIyMWZmYzktNjQwZi00MzcyLTg2ZDMtY2U2NDJjYmE1NjAzIiwiYXVkIjoiZmFzdGFwaS11c2VyczphdXRoIiwiZXhwIjoxNTcxNTA0MTkzfQ.M10bjOe45I5Ncu_uXvOmVV8QxnL-nZfcH96U90JaocI"
|
||||
}
|
||||
```
|
||||
|
||||
!!! fail "`422 Validation Error`"
|
||||
|
||||
!!! fail "`400 Bad Request`"
|
||||
|
||||
@ -1,2 +1,60 @@
|
||||
from typing import Sequence
|
||||
|
||||
from fastapi import HTTPException
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from fastapi_users.authentication.base import BaseAuthentication # noqa: F401
|
||||
from fastapi_users.authentication.cookie import CookieAuthentication # noqa: F401
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class Authenticator:
|
||||
"""
|
||||
Provides dependency callables to retrieve authenticated user.
|
||||
|
||||
It performs the authentication against a list of backends
|
||||
defined by the end-developer. The first backend yielding a user wins.
|
||||
If no backend yields a user, an HTTPException is raised.
|
||||
|
||||
:param backends: List of authentication backends.
|
||||
:param user_db: Database adapter instance.
|
||||
"""
|
||||
|
||||
backends: Sequence[BaseAuthentication]
|
||||
user_db: BaseUserDatabase
|
||||
|
||||
def __init__(
|
||||
self, backends: Sequence[BaseAuthentication], user_db: BaseUserDatabase
|
||||
):
|
||||
self.backends = backends
|
||||
self.user_db = user_db
|
||||
|
||||
async def get_current_user(self, request: Request) -> BaseUserDB:
|
||||
return await self._authenticate(request)
|
||||
|
||||
async def get_current_active_user(self, request: Request) -> BaseUserDB:
|
||||
user = await self.get_current_user(request)
|
||||
if not user.is_active:
|
||||
raise self._get_credentials_exception()
|
||||
return user
|
||||
|
||||
async def get_current_superuser(self, request: Request) -> BaseUserDB:
|
||||
user = await self.get_current_active_user(request)
|
||||
if not user.is_superuser:
|
||||
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
|
||||
return user
|
||||
|
||||
async def _authenticate(self, request: Request) -> BaseUserDB:
|
||||
for backend in self.backends:
|
||||
user = await backend(request, self.user_db)
|
||||
if user is not None:
|
||||
return user
|
||||
raise self._get_credentials_exception()
|
||||
|
||||
def _get_credentials_exception(
|
||||
self, status_code: int = status.HTTP_401_UNAUTHORIZED
|
||||
) -> HTTPException:
|
||||
return HTTPException(status_code=status_code)
|
||||
|
||||
@ -1,70 +1,30 @@
|
||||
from typing import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
class BaseAuthentication:
|
||||
"""
|
||||
Base adapter for generating and decoding authentication tokens.
|
||||
Base authentication backend.
|
||||
|
||||
Provides dependency injectors to get current active/superuser user.
|
||||
Every backend should derive from this class.
|
||||
|
||||
:param scheme: Optional authentication scheme for OpenAPI.
|
||||
Defaults to `OAuth2PasswordBearer(tokenUrl="/users/login")`.
|
||||
Override it if your login route lives somewhere else.
|
||||
:param name: Name of the backend. It will be used to name the login route.
|
||||
"""
|
||||
|
||||
scheme: OAuth2PasswordBearer
|
||||
name: str
|
||||
|
||||
def __init__(self, scheme: OAuth2PasswordBearer = None):
|
||||
if scheme is None:
|
||||
self.scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||
else:
|
||||
self.scheme = scheme
|
||||
def __init__(self, name: str = "base"):
|
||||
self.name = name
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
async def __call__(
|
||||
self, request: Request, user_db: BaseUserDatabase
|
||||
) -> Optional[BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_user(self, user_db: BaseUserDatabase):
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_authentication_method(
|
||||
self, user_db: BaseUserDatabase
|
||||
) -> Callable[..., BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||
if user is None:
|
||||
raise self._get_credentials_exception()
|
||||
return user
|
||||
|
||||
def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||
user = self._get_current_user_base(user)
|
||||
if not user.is_active:
|
||||
raise self._get_credentials_exception()
|
||||
return user
|
||||
|
||||
def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||
user = self._get_current_active_user_base(user)
|
||||
if not user.is_superuser:
|
||||
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
|
||||
return user
|
||||
|
||||
def _get_credentials_exception(
|
||||
self, status_code: int = status.HTTP_401_UNAUTHORIZED
|
||||
) -> HTTPException:
|
||||
return HTTPException(status_code=status_code)
|
||||
|
||||
53
fastapi_users/authentication/cookie.py
Normal file
53
fastapi_users/authentication/cookie.py
Normal file
@ -0,0 +1,53 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi.security import APIKeyCookie
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class CookieAuthentication(JWTAuthentication):
|
||||
"""
|
||||
Authentication backend using a cookie.
|
||||
|
||||
Internally, uses a JWT token to store the data.
|
||||
|
||||
:param secret: Secret used to encode the cookie.
|
||||
:param lifetime_seconds: Lifetime duration of the cookie in seconds.
|
||||
:param cookie_name: Name of the cookie.
|
||||
:param name: Name of the backend. It will be used to name the login route.
|
||||
"""
|
||||
|
||||
lifetime_seconds: int
|
||||
cookie_name: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secret: str,
|
||||
lifetime_seconds: int,
|
||||
cookie_name: str = "fastapiusersauth",
|
||||
name: str = "cookie",
|
||||
):
|
||||
super().__init__(secret, lifetime_seconds, name=name)
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.cookie_name = cookie_name
|
||||
self.api_key_cookie = APIKeyCookie(name=self.cookie_name, auto_error=False)
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
token = await self._generate_token(user)
|
||||
response.set_cookie(
|
||||
self.cookie_name,
|
||||
token,
|
||||
max_age=self.lifetime_seconds,
|
||||
secure=True,
|
||||
httponly=True,
|
||||
)
|
||||
|
||||
# We shouldn't return directly the response
|
||||
# so that FastAPI can terminate it properly
|
||||
return None
|
||||
|
||||
async def _retrieve_token(self, request: Request) -> Optional[str]:
|
||||
return await self.api_key_cookie.__call__(request)
|
||||
@ -1,5 +1,8 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication.base import BaseAuthentication
|
||||
@ -10,62 +13,58 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
class JWTAuthentication(BaseAuthentication):
|
||||
"""
|
||||
Authentication using a JWT.
|
||||
Authentication backend using a JWT in a Bearer header.
|
||||
|
||||
:param secret: Secret used to encode the JWT.
|
||||
:param lifetime_seconds: Lifetime duration of the JWT in seconds.
|
||||
:param tokenUrl: Path where to get a token.
|
||||
:param name: Name of the backend. It will be used to name the login route.
|
||||
"""
|
||||
|
||||
token_audience: str = "fastapi-users:auth"
|
||||
secret: str
|
||||
lifetime_seconds: int
|
||||
|
||||
def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
secret: str,
|
||||
lifetime_seconds: int,
|
||||
tokenUrl: str = "/users/login",
|
||||
name: str = "jwt",
|
||||
):
|
||||
super().__init__(name)
|
||||
self.secret = secret
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False)
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
data = {"user_id": user.id, "aud": self.token_audience}
|
||||
token = generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)
|
||||
async def __call__(
|
||||
self, request: Request, user_db: BaseUserDatabase,
|
||||
) -> Optional[BaseUserDB]:
|
||||
token = await self._retrieve_token(request)
|
||||
if token is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = jwt.decode(
|
||||
token,
|
||||
self.secret,
|
||||
audience=self.token_audience,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
return None
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
return await user_db.get(user_id)
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
token = await self._generate_token(user)
|
||||
return {"token": token}
|
||||
|
||||
def get_current_user(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_user(token: str = Depends(self.scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_user_base(user)
|
||||
async def _retrieve_token(self, request: Request) -> Optional[str]:
|
||||
return await self.scheme.__call__(request)
|
||||
|
||||
return _get_current_user
|
||||
|
||||
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_active_user(token: str = Depends(self.scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_active_user_base(user)
|
||||
|
||||
return _get_current_active_user
|
||||
|
||||
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_superuser(token: str = Depends(self.scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_superuser_base(user)
|
||||
|
||||
return _get_current_superuser
|
||||
|
||||
def _get_authentication_method(self, user_db: BaseUserDatabase):
|
||||
async def authentication_method(token: str = Depends(self.scheme)):
|
||||
try:
|
||||
data = jwt.decode(
|
||||
token,
|
||||
self.secret,
|
||||
audience=self.token_audience,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
return None
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
return await user_db.get(user_id)
|
||||
|
||||
return authentication_method
|
||||
async def _generate_token(self, user: BaseUserDB) -> str:
|
||||
data = {"user_id": user.id, "aud": self.token_audience}
|
||||
return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from typing import Callable, Type
|
||||
from typing import Callable, Sequence, Type
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUser, BaseUserDB
|
||||
from fastapi_users.models import BaseUser
|
||||
from fastapi_users.router import Event, UserRouter, get_user_router
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ class FastAPIUsers:
|
||||
Main object that ties together the component for users authentication.
|
||||
|
||||
:param db: Database adapter instance.
|
||||
:param auth: Authentication logic instance.
|
||||
:param auth_backends: List of authentication backends.
|
||||
:param user_model: Pydantic model of a user.
|
||||
:param reset_password_token_secret: Secret to encode reset password token.
|
||||
:param reset_password_token_lifetime_seconds: Lifetime of reset password token.
|
||||
@ -21,36 +21,30 @@ class FastAPIUsers:
|
||||
"""
|
||||
|
||||
db: BaseUserDatabase
|
||||
auth: BaseAuthentication
|
||||
authenticator: Authenticator
|
||||
router: UserRouter
|
||||
get_current_user: Callable[..., BaseUserDB]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: BaseUserDatabase,
|
||||
auth: BaseAuthentication,
|
||||
auth_backends: Sequence[BaseAuthentication],
|
||||
user_model: Type[BaseUser],
|
||||
reset_password_token_secret: str,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
):
|
||||
self.db = db
|
||||
self.auth = auth
|
||||
self.authenticator = Authenticator(auth_backends, db)
|
||||
self.router = get_user_router(
|
||||
self.db,
|
||||
user_model,
|
||||
self.auth,
|
||||
self.authenticator,
|
||||
reset_password_token_secret,
|
||||
reset_password_token_lifetime_seconds,
|
||||
)
|
||||
|
||||
get_current_user = self.auth.get_current_user(self.db)
|
||||
self.get_current_user = get_current_user # type: ignore
|
||||
|
||||
get_current_active_user = self.auth.get_current_active_user(self.db)
|
||||
self.get_current_active_user = get_current_active_user # type: ignore
|
||||
|
||||
get_current_superuser = self.auth.get_current_superuser(self.db)
|
||||
self.get_current_superuser = get_current_superuser # type: ignore
|
||||
self.get_current_user = self.authenticator.get_current_user
|
||||
self.get_current_active_user = self.authenticator.get_current_active_user
|
||||
self.get_current_superuser = self.authenticator.get_current_superuser
|
||||
|
||||
def on_after_register(self) -> Callable:
|
||||
"""Add an event handler on successful registration."""
|
||||
|
||||
@ -10,7 +10,7 @@ from pydantic import EmailStr
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUser, Models
|
||||
from fastapi_users.password import get_password_hash
|
||||
@ -46,10 +46,28 @@ class UserRouter(APIRouter):
|
||||
handler(*args, **kwargs)
|
||||
|
||||
|
||||
def _add_login_route(
|
||||
router: UserRouter, user_db: BaseUserDatabase, auth_backend: BaseAuthentication
|
||||
):
|
||||
@router.post(f"/login/{auth_backend.name}")
|
||||
async def login(
|
||||
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
||||
):
|
||||
user = await user_db.authenticate(credentials)
|
||||
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
return await auth_backend.get_login_response(user, response)
|
||||
|
||||
|
||||
def get_user_router(
|
||||
user_db: BaseUserDatabase,
|
||||
user_model: typing.Type[BaseUser],
|
||||
auth: BaseAuthentication,
|
||||
authenticator: Authenticator,
|
||||
reset_password_token_secret: str,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
) -> UserRouter:
|
||||
@ -59,8 +77,8 @@ def get_user_router(
|
||||
|
||||
reset_password_token_audience = "fastapi-users:reset"
|
||||
|
||||
get_current_active_user = auth.get_current_active_user(user_db)
|
||||
get_current_superuser = auth.get_current_superuser(user_db)
|
||||
get_current_active_user = authenticator.get_current_active_user
|
||||
get_current_superuser = authenticator.get_current_superuser
|
||||
|
||||
async def _get_or_404(id: str) -> models.UserDB: # type: ignore
|
||||
user = await user_db.get(id)
|
||||
@ -79,6 +97,9 @@ def get_user_router(
|
||||
setattr(user, field, update_dict[field])
|
||||
return await user_db.update(user)
|
||||
|
||||
for auth_backend in authenticator.backends:
|
||||
_add_login_route(router, user_db, auth_backend)
|
||||
|
||||
@router.post(
|
||||
"/register", response_model=models.User, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
@ -101,20 +122,6 @@ def get_user_router(
|
||||
|
||||
return created_user
|
||||
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
||||
):
|
||||
user = await user_db.authenticate(credentials)
|
||||
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
return await auth.get_login_response(user, response)
|
||||
|
||||
@router.post("/forgot-password", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def forgot_password(email: EmailStr = Body(..., embed=True)):
|
||||
user = await user_db.get_by_email(email)
|
||||
|
||||
@ -34,8 +34,9 @@ nav:
|
||||
- configuration/databases/mongodb.md
|
||||
- configuration/databases/tortoise.md
|
||||
- Authentication:
|
||||
- configuration/authentication/index.md
|
||||
- Introduction: configuration/authentication/index.md
|
||||
- configuration/authentication/jwt.md
|
||||
- configuration/authentication/cookie.md
|
||||
- configuration/router.md
|
||||
- configuration/full_example.md
|
||||
- Usage:
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Mapping, Optional, Tuple
|
||||
|
||||
import http.cookies
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUserDB
|
||||
from fastapi_users.password import get_password_hash
|
||||
@ -81,70 +83,76 @@ def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase:
|
||||
return MockUserDatabase()
|
||||
|
||||
|
||||
class MockAuthentication(BaseAuthentication):
|
||||
def __init__(self, name: str = "mock"):
|
||||
super().__init__(name)
|
||||
self.scheme = OAuth2PasswordBearer("/users/login", auto_error=False)
|
||||
|
||||
async def __call__(self, request: Request, user_db: BaseUserDatabase):
|
||||
token = await self.scheme.__call__(request)
|
||||
if token is not None:
|
||||
return await user_db.get(token)
|
||||
return None
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
return {"token": user.id}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_authentication():
|
||||
class MockAuthentication(BaseAuthentication):
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
return {"token": user.id}
|
||||
|
||||
def get_current_user(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_user_base(user)
|
||||
|
||||
return _get_current_user
|
||||
|
||||
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_active_user(
|
||||
token: str = Depends(self.oauth2_scheme),
|
||||
):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_active_user_base(user)
|
||||
|
||||
return _get_current_active_user
|
||||
|
||||
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_superuser_base(user)
|
||||
|
||||
return _get_current_superuser
|
||||
|
||||
def _get_authentication_method(self, user_db: BaseUserDatabase):
|
||||
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
|
||||
return await user_db.get(token)
|
||||
|
||||
return authentication_method
|
||||
|
||||
return MockAuthentication()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def request_builder():
|
||||
def _request_builder(
|
||||
headers: Mapping[str, Any] = None, cookies: Mapping[str, str] = None
|
||||
) -> Request:
|
||||
encoded_headers: List[Tuple[bytes, bytes]] = []
|
||||
|
||||
if headers is not None:
|
||||
encoded_headers += [
|
||||
(key.lower().encode("latin-1"), headers[key].encode("latin-1"))
|
||||
for key in headers
|
||||
]
|
||||
|
||||
if cookies is not None:
|
||||
for key in cookies:
|
||||
cookie = http.cookies.SimpleCookie() # type: http.cookies.BaseCookie
|
||||
cookie[key] = cookies[key]
|
||||
cookie_val = cookie.output(header="").strip()
|
||||
encoded_headers.append((b"cookie", cookie_val.encode("latin-1")))
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": encoded_headers,
|
||||
}
|
||||
return Request(scope)
|
||||
|
||||
return _request_builder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_test_auth_client(mock_user_db):
|
||||
def _get_test_auth_client(authentication):
|
||||
def _get_test_auth_client(backends: List[BaseAuthentication]) -> TestClient:
|
||||
app = FastAPI()
|
||||
authenticator = Authenticator(backends, mock_user_db)
|
||||
|
||||
@app.get("/test-current-user")
|
||||
def test_current_user(
|
||||
user: BaseUserDB = Depends(authentication.get_current_user(mock_user_db)),
|
||||
user: BaseUserDB = Depends(authenticator.get_current_user),
|
||||
):
|
||||
return user
|
||||
|
||||
@app.get("/test-current-active-user")
|
||||
def test_current_active_user(
|
||||
user: BaseUserDB = Depends(
|
||||
authentication.get_current_active_user(mock_user_db)
|
||||
),
|
||||
user: BaseUserDB = Depends(authenticator.get_current_active_user),
|
||||
):
|
||||
return user
|
||||
|
||||
@app.get("/test-current-superuser")
|
||||
def test_current_superuser(
|
||||
user: BaseUserDB = Depends(
|
||||
authentication.get_current_superuser(mock_user_db)
|
||||
),
|
||||
user: BaseUserDB = Depends(authenticator.get_current_superuser),
|
||||
):
|
||||
return user
|
||||
|
||||
|
||||
45
tests/test_authentication.py
Normal file
45
tests/test_authentication.py
Normal file
@ -0,0 +1,45 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_backend_none():
|
||||
class BackendNone(BaseAuthentication):
|
||||
async def __call__(
|
||||
self, request: Request, user_db: BaseUserDatabase
|
||||
) -> Optional[BaseUserDB]:
|
||||
return None
|
||||
|
||||
return BackendNone()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def auth_backend_user(user):
|
||||
class BackendUser(BaseAuthentication):
|
||||
async def __call__(
|
||||
self, request: Request, user_db: BaseUserDatabase
|
||||
) -> Optional[BaseUserDB]:
|
||||
return user
|
||||
|
||||
return BackendUser()
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
def test_authenticator(get_test_auth_client, auth_backend_none, auth_backend_user):
|
||||
client = get_test_auth_client([auth_backend_none, auth_backend_user])
|
||||
response = client.get("/test-current-user")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
def test_authenticator_none(get_test_auth_client, auth_backend_none):
|
||||
client = get_test_auth_client([auth_backend_none, auth_backend_none])
|
||||
response = client.get("/test-current-user")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@ -1,30 +1,27 @@
|
||||
import pytest
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.fixture
|
||||
def base_authentication():
|
||||
return BaseAuthentication()
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.parametrize(
|
||||
"constructor_kwargs", [{}, {"scheme": OAuth2PasswordBearer(tokenUrl="/foo")}]
|
||||
)
|
||||
async def test_not_implemented_methods(constructor_kwargs, user, mock_user_db):
|
||||
response = Response()
|
||||
base_authentication = BaseAuthentication(**constructor_kwargs)
|
||||
class TestAuthenticate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_implemented(
|
||||
self, base_authentication, mock_user_db, request_builder
|
||||
):
|
||||
request = request_builder({})
|
||||
with pytest.raises(NotImplementedError):
|
||||
await base_authentication(request, mock_user_db)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await base_authentication.get_login_response(user, response)
|
||||
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_login_response(base_authentication, user):
|
||||
with pytest.raises(NotImplementedError):
|
||||
await base_authentication.get_current_user(mock_user_db)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await base_authentication.get_current_active_user(mock_user_db)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await base_authentication.get_current_superuser(mock_user_db)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await base_authentication._get_authentication_method(mock_user_db)
|
||||
await base_authentication.get_login_response(user, Response())
|
||||
|
||||
105
tests/test_authentication_cookie.py
Normal file
105
tests/test_authentication_cookie.py
Normal file
@ -0,0 +1,105 @@
|
||||
import re
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication.cookie import CookieAuthentication
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
COOKIE_NAME = "COOKIE_NAME"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cookie_authentication():
|
||||
return CookieAuthentication(SECRET, LIFETIME, COOKIE_NAME)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token():
|
||||
def _token(user=None, lifetime=LIFETIME):
|
||||
data = {"aud": "fastapi-users:auth"}
|
||||
if user is not None:
|
||||
data["user_id"] = user.id
|
||||
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
||||
|
||||
return _token
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
def test_default_name(cookie_authentication):
|
||||
assert cookie_authentication.name == "cookie"
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
class TestAuthenticate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_token(
|
||||
self, cookie_authentication, mock_user_db, request_builder
|
||||
):
|
||||
request = request_builder()
|
||||
authenticated_user = await cookie_authentication(request, mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_token(
|
||||
self, cookie_authentication, mock_user_db, request_builder
|
||||
):
|
||||
cookies = {}
|
||||
cookies[COOKIE_NAME] = "foo"
|
||||
request = request_builder(cookies=cookies)
|
||||
authenticated_user = await cookie_authentication(request, mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token_missing_user_payload(
|
||||
self, cookie_authentication, mock_user_db, request_builder, token
|
||||
):
|
||||
cookies = {}
|
||||
cookies[COOKIE_NAME] = token()
|
||||
request = request_builder(cookies=cookies)
|
||||
authenticated_user = await cookie_authentication(request, mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token(
|
||||
self, cookie_authentication, mock_user_db, request_builder, token, user
|
||||
):
|
||||
cookies = {}
|
||||
cookies[COOKIE_NAME] = token(user)
|
||||
request = request_builder(cookies=cookies)
|
||||
authenticated_user = await cookie_authentication(request, mock_user_db)
|
||||
assert authenticated_user.id == user.id
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_login_response(cookie_authentication, user):
|
||||
response = Response()
|
||||
login_response = await cookie_authentication.get_login_response(user, response)
|
||||
|
||||
# We shouldn't return directly the response
|
||||
# so that FastAPI can terminate it properly
|
||||
assert login_response is None
|
||||
|
||||
cookies = [
|
||||
header for header in response.raw_headers if header[0] == b"set-cookie"
|
||||
]
|
||||
assert len(cookies) == 1
|
||||
|
||||
cookie = cookies[0][1].decode("latin-1")
|
||||
|
||||
assert f"Max-Age={LIFETIME}" in cookie
|
||||
|
||||
cookie_name_value = re.match(r"^(\w+)=([^;]+);", cookie)
|
||||
|
||||
cookie_name = cookie_name_value[1]
|
||||
assert cookie_name == COOKIE_NAME
|
||||
|
||||
cookie_value = cookie_name_value[2]
|
||||
decoded = jwt.decode(
|
||||
cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
|
||||
)
|
||||
assert decoded["user_id"] == user.id
|
||||
@ -1,6 +1,5 @@
|
||||
import jwt
|
||||
import pytest
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication
|
||||
@ -8,11 +7,12 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
TOKEN_URL = "/login"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_authentication():
|
||||
return JWTAuthentication(SECRET, LIFETIME)
|
||||
return JWTAuthentication(SECRET, LIFETIME, TOKEN_URL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -26,11 +26,47 @@ def token():
|
||||
return _token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_auth_client(get_test_auth_client, jwt_authentication):
|
||||
return get_test_auth_client(jwt_authentication)
|
||||
@pytest.mark.authentication
|
||||
def test_default_name(jwt_authentication):
|
||||
assert jwt_authentication.name == "jwt"
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
class TestAuthenticate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_token(
|
||||
self, jwt_authentication, mock_user_db, request_builder
|
||||
):
|
||||
request = request_builder(headers={})
|
||||
authenticated_user = await jwt_authentication(request, mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_token(
|
||||
self, jwt_authentication, mock_user_db, request_builder
|
||||
):
|
||||
request = request_builder(headers={"Authorization": "Bearer foo"})
|
||||
authenticated_user = await jwt_authentication(request, mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token_missing_user_payload(
|
||||
self, jwt_authentication, mock_user_db, request_builder, token
|
||||
):
|
||||
request = request_builder(headers={"Authorization": f"Bearer {token()}"})
|
||||
authenticated_user = await jwt_authentication(request, mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token(
|
||||
self, jwt_authentication, mock_user_db, request_builder, token, user
|
||||
):
|
||||
request = request_builder(headers={"Authorization": f"Bearer {token(user)}"})
|
||||
authenticated_user = await jwt_authentication(request, mock_user_db)
|
||||
assert authenticated_user.id == user.id
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_login_response(jwt_authentication, user):
|
||||
login_response = await jwt_authentication.get_login_response(user, Response())
|
||||
@ -42,108 +78,3 @@ async def test_get_login_response(jwt_authentication, user):
|
||||
token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
|
||||
)
|
||||
assert decoded["user_id"] == user.id
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
class TestGetCurrentUser:
|
||||
def test_missing_token(self, test_auth_client):
|
||||
response = test_auth_client.get("/test-current-user")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_auth_client):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-user", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_missing_user_payload(self, test_auth_client, token):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-user", headers={"Authorization": f"Bearer {token()}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-user",
|
||||
headers={"Authorization": f"Bearer {token(inactive_user)}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == inactive_user.id
|
||||
|
||||
def test_valid_token(self, test_auth_client, token, user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-user", headers={"Authorization": f"Bearer {token(user)}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == user.id
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
class TestGetCurrentActiveUser:
|
||||
def test_missing_token(self, test_auth_client):
|
||||
response = test_auth_client.get("/test-current-active-user")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_auth_client):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-active-user",
|
||||
headers={"Authorization": f"Bearer {token(inactive_user)}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token(self, test_auth_client, token, user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-active-user",
|
||||
headers={"Authorization": f"Bearer {token(user)}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == user.id
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
class TestGetCurrentSuperuser:
|
||||
def test_missing_token(self, test_auth_client):
|
||||
response = test_auth_client.get("/test-current-superuser")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_auth_client):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-superuser",
|
||||
headers={"Authorization": f"Bearer {token(inactive_user)}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_regular_user(self, test_auth_client, token, user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-superuser",
|
||||
headers={"Authorization": f"Bearer {token(user)}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_valid_token_superuser(self, test_auth_client, token, superuser):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-superuser",
|
||||
headers={"Authorization": f"Bearer {token(superuser)}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == superuser.id
|
||||
|
||||
@ -21,7 +21,7 @@ def fastapi_users(request, mock_user_db, mock_authentication) -> FastAPIUsers:
|
||||
class User(BaseUser):
|
||||
pass
|
||||
|
||||
fastapi_users = FastAPIUsers(mock_user_db, mock_authentication, User, "SECRET")
|
||||
fastapi_users = FastAPIUsers(mock_user_db, [mock_authentication], User, "SECRET")
|
||||
|
||||
@fastapi_users.on_after_register()
|
||||
def on_after_register():
|
||||
|
||||
@ -7,9 +7,11 @@ from fastapi import FastAPI
|
||||
from starlette import status
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from fastapi_users.authentication import Authenticator
|
||||
from fastapi_users.models import BaseUser, BaseUserDB
|
||||
from fastapi_users.router import ErrorCode, Event, get_user_router
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
from tests.conftest import MockAuthentication
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
@ -44,10 +46,13 @@ def test_app_client(mock_user_db, mock_authentication, event_handler) -> TestCli
|
||||
class User(BaseUser):
|
||||
pass
|
||||
|
||||
userRouter = get_user_router(
|
||||
mock_user_db, User, mock_authentication, SECRET, LIFETIME
|
||||
mock_authentication_bis = MockAuthentication(name="mock-bis")
|
||||
authenticator = Authenticator(
|
||||
[mock_authentication, mock_authentication_bis], mock_user_db
|
||||
)
|
||||
|
||||
userRouter = get_user_router(mock_user_db, User, authenticator, SECRET, LIFETIME)
|
||||
|
||||
userRouter.add_event_handler(Event.ON_AFTER_REGISTER, event_handler)
|
||||
userRouter.add_event_handler(Event.ON_AFTER_FORGOT_PASSWORD, event_handler)
|
||||
|
||||
@ -125,42 +130,45 @@ class TestRegister:
|
||||
|
||||
|
||||
@pytest.mark.router
|
||||
@pytest.mark.parametrize("path", ["/login/mock", "/login/mock-bis"])
|
||||
class TestLogin:
|
||||
def test_empty_body(self, test_app_client: TestClient):
|
||||
response = test_app_client.post("/login", data={})
|
||||
def test_empty_body(self, path, test_app_client: TestClient):
|
||||
response = test_app_client.post(path, data={})
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_missing_username(self, test_app_client: TestClient):
|
||||
def test_missing_username(self, path, test_app_client: TestClient):
|
||||
data = {"password": "guinevere"}
|
||||
response = test_app_client.post("/login", data=data)
|
||||
response = test_app_client.post(path, data=data)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_missing_password(self, test_app_client: TestClient):
|
||||
def test_missing_password(self, path, test_app_client: TestClient):
|
||||
data = {"username": "king.arthur@camelot.bt"}
|
||||
response = test_app_client.post("/login", data=data)
|
||||
response = test_app_client.post(path, data=data)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_not_existing_user(self, test_app_client: TestClient):
|
||||
def test_not_existing_user(self, path, test_app_client: TestClient):
|
||||
data = {"username": "lancelot@camelot.bt", "password": "guinevere"}
|
||||
response = test_app_client.post("/login", data=data)
|
||||
response = test_app_client.post(path, data=data)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
|
||||
|
||||
def test_wrong_password(self, test_app_client: TestClient):
|
||||
def test_wrong_password(self, path, test_app_client: TestClient):
|
||||
data = {"username": "king.arthur@camelot.bt", "password": "percival"}
|
||||
response = test_app_client.post("/login", data=data)
|
||||
response = test_app_client.post(path, data=data)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
|
||||
|
||||
def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB):
|
||||
def test_valid_credentials(
|
||||
self, path, test_app_client: TestClient, user: BaseUserDB
|
||||
):
|
||||
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
|
||||
response = test_app_client.post("/login", data=data)
|
||||
response = test_app_client.post(path, data=data)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == {"token": user.id}
|
||||
|
||||
def test_inactive_user(self, test_app_client: TestClient):
|
||||
def test_inactive_user(self, path, test_app_client: TestClient):
|
||||
data = {"username": "percival@camelot.bt", "password": "angharad"}
|
||||
response = test_app_client.post("/login", data=data)
|
||||
response = test_app_client.post(path, data=data)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
|
||||
|
||||
|
||||
Reference in New Issue
Block a user