Fix #701: factorize JWT handling and support secrets as SecretStr

This commit is contained in:
François Voron
2021-09-09 11:51:55 +02:00
parent c7f1e448a2
commit 7ae2042500
21 changed files with 175 additions and 158 deletions

View File

@ -18,7 +18,7 @@ auth_backends.append(cookie_authentication)
As you can see, instantiation is quite simple. It accepts the following arguments:
* `secret` (`str`): A constant secret which is used to encode the cookie. **Use a strong passphrase and keep it secure.**
* `secret` (`Union[str, pydantic.SecretStr]`): A constant secret which is used to encode the cookie. **Use a strong passphrase and keep it secure.**
* `lifetime_seconds` (`int`): The lifetime of the cookie in seconds.
* `cookie_name` (`fastapiusersauth`): Name of the cookie.
* `cookie_path` (`/`): Cookie path.

View File

@ -18,7 +18,7 @@ auth_backends.append(jwt_authentication)
As you can see, instantiation is quite simple. It accepts the following arguments:
* `secret` (`str`): A constant secret which is used to encode the token. **Use a strong passphrase and keep it secure.**
* `secret` (`Union[str, pydantic.SecretStr]`): A constant secret which is used to encode the token. **Use a strong passphrase and keep it secure.**
* `lifetime_seconds` (`int`): The lifetime of the token in seconds.
* `tokenUrl` (`Optional[str]`): The exact path of your login endpoint. It'll allow the interactive documentation to automatically discover it and get a working *Authorize* button. In most cases, you'll probably need a **relative** path, not absolute. You can read more details about this in the [FastAPI documentation](https://fastapi.tiangolo.com/tutorial/security/first-steps/#fastapis-oauth2passwordbearer). Defaults to `auth/jwt/login`.
* `name` (`Optional[str]`): Name of the backend. It's useful in the case you wish to have several backends of the same class. Each backend should have a unique name. Defaults to `jwt`.

View File

@ -29,8 +29,8 @@ app.include_router(
Parameters:
* `reset_password_token_secret`: Secret to encode reset password token.
* `reset_password_token_lifetime_seconds`: Lifetime of reset password token. **Defaults to 3600**.
* `reset_password_token_secret` (`Union[str, pydantic.SecretStr]`): Secret to encode reset password token.
* `reset_password_token_lifetime_seconds` (`int`): Lifetime of reset password token. **Defaults to 3600**.
* `after_forgot_password`: Optional function called after a successful forgot password request. See below.
## After forgot password

View File

@ -30,8 +30,8 @@ app.include_router(
Parameters:
* `verification_token_secret`: Secret to encode verify token.
* `verification_token_lifetime_seconds`: Lifetime of verify token. **Defaults to 3600**.
* `verification_token_secret` (`Union[str, pydantic.SecretStr]`): Secret to encode verify token.
* `verification_token_lifetime_seconds` (`int`): Lifetime of verify token. **Defaults to 3600**.
* `after_verification_request`: Optional function called after a successful verify request. See below.
* `after_verification`: Optional function called after a successful verification. See below.

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, List, Optional
import jwt
from fastapi import Response
@ -7,8 +7,8 @@ from pydantic import UUID4
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.models import BaseUserDB
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
class CookieAuthentication(BaseAuthentication[str]):
@ -25,11 +25,12 @@ class CookieAuthentication(BaseAuthentication[str]):
:param cookie_secure: Whether to only send the cookie to the server via SSL request.
:param cookie_httponly: Whether to prevent access to the cookie via JavaScript.
:param name: Name of the backend. It will be used to name the login route.
:param token_audience: List of valid audiences for the JWT.
"""
scheme: APIKeyCookie
token_audience: str = "fastapi-users:auth"
secret: str
token_audience: List[str]
secret: SecretType
lifetime_seconds: Optional[int]
cookie_name: str
cookie_path: str
@ -40,7 +41,7 @@ class CookieAuthentication(BaseAuthentication[str]):
def __init__(
self,
secret: str,
secret: SecretType,
lifetime_seconds: Optional[int] = None,
cookie_name: str = "fastapiusersauth",
cookie_path: str = "/",
@ -49,6 +50,7 @@ class CookieAuthentication(BaseAuthentication[str]):
cookie_httponly: bool = True,
cookie_samesite: str = "lax",
name: str = "cookie",
token_audience: List[str] = ["fastapi-users:auth"],
):
super().__init__(name, logout=True)
self.secret = secret
@ -59,6 +61,7 @@ class CookieAuthentication(BaseAuthentication[str]):
self.cookie_secure = cookie_secure
self.cookie_httponly = cookie_httponly
self.cookie_samesite = cookie_samesite
self.token_audience = token_audience
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
async def __call__(
@ -70,12 +73,7 @@ class CookieAuthentication(BaseAuthentication[str]):
return None
try:
data = jwt.decode(
credentials,
self.secret,
audience=self.token_audience,
algorithms=[JWT_ALGORITHM],
)
data = decode_jwt(credentials, self.secret, self.token_audience)
user_id = data.get("user_id")
if user_id is None:
return None
@ -112,4 +110,4 @@ class CookieAuthentication(BaseAuthentication[str]):
async def _generate_token(self, user: BaseUserDB) -> str:
data = {"user_id": str(user.id), "aud": self.token_audience}
return generate_jwt(data, self.secret, self.lifetime_seconds, JWT_ALGORITHM)
return generate_jwt(data, self.secret, self.lifetime_seconds)

View File

@ -7,8 +7,8 @@ from pydantic import UUID4
from fastapi_users.authentication.base import BaseAuthentication
from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.models import BaseUserDB
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
class JWTAuthentication(BaseAuthentication[str]):
@ -19,11 +19,12 @@ class JWTAuthentication(BaseAuthentication[str]):
:param lifetime_seconds: Lifetime duration of the JWT in seconds.
:param tokenUrl: Path where to get a token.
:param name: Name of the backend. It will be used to name the login route.
:param token_audience: List of valid audiences for the JWT.
"""
scheme: OAuth2PasswordBearer
token_audience: List[str] = ["fastapi-users:auth"]
secret: str
token_audience: List[str]
secret: SecretType
lifetime_seconds: int
def __init__(
@ -49,12 +50,7 @@ class JWTAuthentication(BaseAuthentication[str]):
return None
try:
data = jwt.decode(
credentials,
self.secret,
audience=self.token_audience,
algorithms=[JWT_ALGORITHM],
)
data = decode_jwt(credentials, self.secret, self.token_audience)
user_id = data.get("user_id")
if user_id is None:
return None
@ -73,4 +69,4 @@ class JWTAuthentication(BaseAuthentication[str]):
async def _generate_token(self, user: BaseUserDB) -> str:
data = {"user_id": str(user.id), "aud": self.token_audience}
return generate_jwt(data, self.secret, self.lifetime_seconds, JWT_ALGORITHM)
return generate_jwt(data, self.secret, self.lifetime_seconds)

View File

@ -5,6 +5,7 @@ from fastapi import APIRouter, Request
from fastapi_users import models
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType
from fastapi_users.router import (
get_auth_router,
get_register_router,
@ -132,7 +133,7 @@ class FastAPIUsers:
def get_verify_router(
self,
verification_token_secret: str,
verification_token_secret: SecretType,
verification_token_lifetime_seconds: int = 3600,
after_verification_request: Optional[
Callable[[models.UD, str, Request], None]
@ -161,7 +162,7 @@ class FastAPIUsers:
def get_reset_password_router(
self,
reset_password_token_secret: str,
reset_password_token_secret: SecretType,
reset_password_token_lifetime_seconds: int = 3600,
after_forgot_password: Optional[
Callable[[models.UD, str, Request], None]
@ -207,7 +208,7 @@ class FastAPIUsers:
def get_oauth_router(
self,
oauth_client: BaseOAuth2,
state_secret: str,
state_secret: SecretType,
redirect_url: str = None,
after_register: Optional[Callable[[models.UD, Request], None]] = None,
) -> APIRouter:

41
fastapi_users/jwt.py Normal file
View File

@ -0,0 +1,41 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union
import jwt
from pydantic import SecretStr
SecretType = Union[str, SecretStr]
JWT_ALGORITHM = "HS256"
def _get_secret_value(secret: SecretType) -> str:
if isinstance(secret, SecretStr):
return secret.get_secret_value()
return secret
def generate_jwt(
data: dict,
secret: SecretType,
lifetime_seconds: Optional[int] = None,
algorithm: str = JWT_ALGORITHM,
) -> str:
payload = data.copy()
if lifetime_seconds:
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
payload["exp"] = expire
return jwt.encode(payload, _get_secret_value(secret), algorithm=algorithm)
def decode_jwt(
encoded_jwt: str,
secret: SecretType,
audience: List[str],
algorithms: List[str] = [JWT_ALGORITHM],
) -> Dict[str, Any]:
return jwt.decode(
encoded_jwt,
_get_secret_value(secret),
audience=audience,
algorithms=algorithms,
)

View File

@ -8,27 +8,18 @@ from httpx_oauth.oauth2 import BaseOAuth2
from fastapi_users import models
from fastapi_users.authentication import Authenticator
from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.password import generate_password, get_password_hash
from fastapi_users.router.common import ErrorCode, run_handler
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
def generate_state_token(
data: Dict[str, str], secret: str, lifetime_seconds: int = 3600
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, secret, lifetime_seconds, JWT_ALGORITHM)
def decode_state_token(token: str, secret: str) -> Dict[str, str]:
return jwt.decode(
token,
secret,
audience=STATE_TOKEN_AUDIENCE,
algorithms=[JWT_ALGORITHM],
)
return generate_jwt(data, secret, lifetime_seconds)
def get_oauth_router(
@ -36,7 +27,7 @@ def get_oauth_router(
user_db: BaseUserDatabase[models.BaseUserDB],
user_db_model: Type[models.BaseUserDB],
authenticator: Authenticator,
state_secret: str,
state_secret: SecretType,
redirect_url: str = None,
after_register: Optional[Callable[[models.UD, Request], None]] = None,
) -> APIRouter:
@ -99,7 +90,7 @@ def get_oauth_router(
)
try:
state_data = decode_state_token(state, state_secret)
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

View File

@ -6,17 +6,17 @@ from pydantic import UUID4, EmailStr
from fastapi_users import models
from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.password import get_password_hash
from fastapi_users.router.common import ErrorCode, run_handler
from fastapi_users.user import InvalidPasswordException, ValidatePasswordProtocol
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
def get_reset_password_router(
user_db: BaseUserDatabase[models.BaseUserDB],
reset_password_token_secret: str,
reset_password_token_secret: SecretType,
reset_password_token_lifetime_seconds: int = 3600,
after_forgot_password: Optional[Callable[[models.UD, str, Request], None]] = None,
after_reset_password: Optional[Callable[[models.UD, Request], None]] = None,
@ -48,11 +48,8 @@ def get_reset_password_router(
request: Request, token: str = Body(...), password: str = Body(...)
):
try:
data = jwt.decode(
token,
reset_password_token_secret,
audience=RESET_PASSWORD_TOKEN_AUDIENCE,
algorithms=[JWT_ALGORITHM],
data = decode_jwt(
token, reset_password_token_secret, [RESET_PASSWORD_TOKEN_AUDIENCE]
)
user_id = data.get("user_id")
if user_id is None:

View File

@ -5,6 +5,7 @@ from fastapi import APIRouter, Body, HTTPException, Request, status
from pydantic import UUID4, EmailStr
from fastapi_users import models
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.router.common import ErrorCode, run_handler
from fastapi_users.user import (
GetUserProtocol,
@ -12,7 +13,6 @@ from fastapi_users.user import (
UserNotExists,
VerifyUserProtocol,
)
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify"
@ -21,7 +21,7 @@ def get_verify_router(
verify_user: VerifyUserProtocol,
get_user: GetUserProtocol,
user_model: Type[models.BaseUser],
verification_token_secret: str,
verification_token_secret: SecretType,
verification_token_lifetime_seconds: int = 3600,
after_verification_request: Optional[
Callable[[models.UD, str, Request], None]
@ -58,11 +58,8 @@ def get_verify_router(
@router.post("/verify", response_model=user_model)
async def verify(request: Request, token: str = Body(..., embed=True)):
try:
data = jwt.decode(
token,
verification_token_secret,
audience=VERIFY_USER_TOKEN_AUDIENCE,
algorithms=[JWT_ALGORITHM],
data = decode_jwt(
token, verification_token_secret, [VERIFY_USER_TOKEN_AUDIENCE]
)
except jwt.exceptions.ExpiredSignatureError:
raise HTTPException(

View File

@ -1,19 +0,0 @@
from datetime import datetime, timedelta
from typing import Optional
import jwt
JWT_ALGORITHM = "HS256"
def generate_jwt(
data: dict,
secret: str,
lifetime_seconds: Optional[int] = None,
algorithm: str = JWT_ALGORITHM,
) -> str:
payload = data.copy()
if lifetime_seconds:
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
payload["exp"] = expire
return jwt.encode(payload, secret, algorithm=algorithm)

View File

@ -8,12 +8,13 @@ from asgi_lifespan import LifespanManager
from fastapi import Depends, FastAPI, Response
from fastapi.security import OAuth2PasswordBearer
from httpx_oauth.oauth2 import OAuth2
from pydantic import UUID4
from pydantic import UUID4, SecretStr
from starlette.applications import ASGIApp
from fastapi_users import models
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType
from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin, BaseUserDB
from fastapi_users.password import get_password_hash
from fastapi_users.user import InvalidPasswordException, ValidatePasswordProtocol
@ -56,6 +57,11 @@ def event_loop():
yield loop
@pytest.fixture(params=["SECRET", SecretStr("SECRET")])
def secret(request) -> SecretType:
return request.param
@pytest.fixture
def user() -> UserDB:
return UserDB(

View File

@ -1,90 +1,101 @@
import re
import jwt
import pytest
from fastapi import Response
from fastapi_users.authentication.cookie import CookieAuthentication
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
SECRET = "SECRET"
LIFETIME = 3600
COOKIE_NAME = "COOKIE_NAME"
cookie_authentication = CookieAuthentication(SECRET, LIFETIME, COOKIE_NAME)
cookie_authentication_path = CookieAuthentication(
SECRET, LIFETIME, COOKIE_NAME, cookie_path="/arthur"
)
cookie_authentication_domain = CookieAuthentication(
SECRET, LIFETIME, COOKIE_NAME, cookie_domain="camelot.bt"
)
cookie_authentication_secure = CookieAuthentication(
SECRET, LIFETIME, COOKIE_NAME, cookie_secure=False
)
cookie_authentication_httponly = CookieAuthentication(
SECRET, LIFETIME, COOKIE_NAME, cookie_httponly=False
@pytest.fixture(
params=[
("/", None, True, True),
("/arthur", None, True, True),
("/", "camelot.bt", True, True),
("/", None, False, True),
("/", None, True, False),
]
)
def cookie_authentication(secret: SecretType, request):
path, domain, secure, httponly = request.param
return CookieAuthentication(
secret,
lifetime_seconds=LIFETIME,
cookie_name=COOKIE_NAME,
cookie_path=path,
cookie_domain=domain,
cookie_secure=secure,
cookie_httponly=httponly,
)
@pytest.fixture
def token():
def token(secret):
def _token(user_id=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:auth"}
if user_id is not None:
data["user_id"] = str(user_id)
return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
return generate_jwt(data, secret, lifetime)
return _token
@pytest.mark.authentication
def test_default_name():
def test_default_name(cookie_authentication: CookieAuthentication):
assert cookie_authentication.name == "cookie"
@pytest.mark.authentication
class TestAuthenticate:
@pytest.mark.asyncio
async def test_missing_token(self, mock_user_db):
async def test_missing_token(
self, mock_user_db, cookie_authentication: CookieAuthentication
):
authenticated_user = await cookie_authentication(None, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_invalid_token(self, mock_user_db):
async def test_invalid_token(
self, mock_user_db, cookie_authentication: CookieAuthentication
):
authenticated_user = await cookie_authentication("foo", mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token_missing_user_payload(self, mock_user_db, token):
async def test_valid_token_missing_user_payload(
self, mock_user_db, token, cookie_authentication: CookieAuthentication
):
authenticated_user = await cookie_authentication(token(), mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token_invalid_uuid(self, mock_user_db, token):
async def test_valid_token_invalid_uuid(
self, mock_user_db, token, cookie_authentication: CookieAuthentication
):
authenticated_user = await cookie_authentication(token("foo"), mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token(self, mock_user_db, token, user):
async def test_valid_token(
self, mock_user_db, token, user, cookie_authentication: CookieAuthentication
):
authenticated_user = await cookie_authentication(token(user.id), mock_user_db)
assert authenticated_user is not None
assert authenticated_user.id == user.id
@pytest.mark.authentication
@pytest.mark.asyncio
@pytest.mark.parametrize(
"cookie_authentication,path,domain,secure,httponly",
[
(cookie_authentication, "/", None, True, True),
(cookie_authentication_path, "/arthur", None, True, True),
(cookie_authentication_domain, "/", "camelot.bt", True, True),
(cookie_authentication_secure, "/", None, False, True),
(cookie_authentication_httponly, "/", None, True, False),
],
)
async def test_get_login_response(
user, cookie_authentication, path, domain, secure, httponly
):
async def test_get_login_response(user, cookie_authentication: CookieAuthentication):
secret = cookie_authentication.secret
path = cookie_authentication.cookie_path
domain = cookie_authentication.cookie_domain
secure = cookie_authentication.cookie_secure
httponly = cookie_authentication.cookie_httponly
response = Response()
login_response = await cookie_authentication.get_login_response(user, response)
@ -116,20 +127,19 @@ async def test_get_login_response(
assert "HttpOnly" not in cookie
cookie_name_value = re.match(r"^(\w+)=([^;]+);", cookie)
assert cookie_name_value is not None
cookie_name = cookie_name_value[1]
assert cookie_name == COOKIE_NAME
cookie_value = cookie_name_value[2]
decoded = jwt.decode(
cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
)
decoded = decode_jwt(cookie_value, secret, audience=["fastapi-users:auth"])
assert decoded["user_id"] == str(user.id)
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_logout_response(user):
async def test_get_logout_response(user, cookie_authentication: CookieAuthentication):
response = Response()
logout_response = await cookie_authentication.get_logout_response(user, response)

View File

@ -1,27 +1,25 @@
import jwt
import pytest
from fastapi import Response
from fastapi_users.authentication.jwt import JWTAuthentication
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
from fastapi_users.jwt import decode_jwt, generate_jwt
SECRET = "SECRET"
LIFETIME = 3600
TOKEN_URL = "/login"
@pytest.fixture
def jwt_authentication():
return JWTAuthentication(SECRET, LIFETIME, TOKEN_URL)
def jwt_authentication(secret):
return JWTAuthentication(secret, LIFETIME, TOKEN_URL)
@pytest.fixture
def token():
def token(secret):
def _token(user_id=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:auth"}
if user_id is not None:
data["user_id"] = str(user_id)
return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
return generate_jwt(data, secret, lifetime)
return _token
@ -72,8 +70,8 @@ async def test_get_login_response(jwt_authentication, user):
assert login_response["token_type"] == "bearer"
token = login_response["access_token"]
decoded = jwt.decode(
token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
decoded = decode_jwt(
token, jwt_authentication.secret, audience=["fastapi-users:auth"]
)
assert decoded["user_id"] == str(user.id)

View File

@ -11,7 +11,12 @@ from tests.conftest import User, UserCreate, UserDB, UserUpdate
@pytest.fixture
@pytest.mark.asyncio
async def test_app_client(
mock_user_db, mock_authentication, oauth_client, get_test_client, validate_password
secret,
mock_user_db,
mock_authentication,
oauth_client,
get_test_client,
validate_password,
) -> AsyncGenerator[httpx.AsyncClient, None]:
fastapi_users = FastAPIUsers(
mock_user_db,
@ -25,11 +30,11 @@ async def test_app_client(
app = FastAPI()
app.include_router(fastapi_users.get_register_router())
app.include_router(fastapi_users.get_reset_password_router("SECRET"))
app.include_router(fastapi_users.get_reset_password_router(secret))
app.include_router(fastapi_users.get_auth_router(mock_authentication))
app.include_router(fastapi_users.get_oauth_router(oauth_client, "SECRET"))
app.include_router(fastapi_users.get_oauth_router(oauth_client, secret))
app.include_router(fastapi_users.get_users_router(), prefix="/users")
app.include_router(fastapi_users.get_verify_router("SECRET"))
app.include_router(fastapi_users.get_verify_router(secret))
@app.delete("/users/me")
def custom_users_route():

View File

@ -11,8 +11,6 @@ from fastapi_users.router.common import ErrorCode
from fastapi_users.router.oauth import generate_state_token, get_oauth_router
from tests.conftest import MockAuthentication, UserDB
SECRET = "SECRET"
def after_register_sync():
return MagicMock(return_value=None)
@ -29,6 +27,7 @@ def after_register(request):
@pytest.fixture
def get_test_app_client(
secret,
mock_user_db_oauth,
mock_authentication,
oauth_client,
@ -48,7 +47,7 @@ def get_test_app_client(
mock_user_db_oauth,
UserDB,
authenticator,
SECRET,
secret,
redirect_url,
after_register,
)

View File

@ -10,9 +10,6 @@ from fastapi_users.router import ErrorCode, get_register_router
from fastapi_users.user import get_create_user
from tests.conftest import User, UserCreate, UserDB
SECRET = "SECRET"
LIFETIME = 3600
def after_register_sync():
return MagicMock(return_value=None)

View File

@ -3,25 +3,23 @@ from unittest.mock import MagicMock
import asynctest
import httpx
import jwt
import pytest
from fastapi import FastAPI, Request, status
from fastapi_users.jwt import decode_jwt, generate_jwt
from fastapi_users.router import ErrorCode, get_reset_password_router
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
from tests.conftest import UserDB
SECRET = "SECRET"
LIFETIME = 3600
@pytest.fixture
def forgot_password_token():
def forgot_password_token(secret):
def _forgot_password_token(user_id=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:reset"}
if user_id is not None:
data["user_id"] = str(user_id)
return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
return generate_jwt(data, secret, lifetime)
return _forgot_password_token
@ -55,6 +53,7 @@ def after_reset_password(request):
@pytest.fixture
@pytest.mark.asyncio
async def test_app_client(
secret,
mock_user_db,
after_forgot_password,
after_reset_password,
@ -63,7 +62,7 @@ async def test_app_client(
) -> AsyncGenerator[httpx.AsyncClient, None]:
reset_router = get_reset_password_router(
mock_user_db,
SECRET,
secret,
LIFETIME,
after_forgot_password,
after_reset_password,
@ -107,7 +106,12 @@ class TestForgotPassword:
"email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"]
)
async def test_existing_user(
self, email, test_app_client: httpx.AsyncClient, after_forgot_password, user
self,
secret,
email,
test_app_client: httpx.AsyncClient,
after_forgot_password,
user,
):
json = {"email": email}
response = await test_app_client.post("/forgot-password", json=json)
@ -117,11 +121,11 @@ class TestForgotPassword:
actual_user = after_forgot_password.call_args[0][0]
assert actual_user.id == user.id
actual_token = after_forgot_password.call_args[0][1]
decoded_token = jwt.decode(
decoded_token = decode_jwt(
actual_token,
SECRET,
audience="fastapi-users:reset",
algorithms=[JWT_ALGORITHM],
secret,
audience=["fastapi-users:reset"],
)
assert decoded_token["user_id"] == str(user.id)
request = after_forgot_password.call_args[0][2]

View File

@ -10,9 +10,6 @@ from fastapi_users.authentication import Authenticator
from fastapi_users.router import ErrorCode, get_users_router
from tests.conftest import MockAuthentication, User, UserDB, UserUpdate
SECRET = "SECRET"
LIFETIME = 3600
def after_update_sync():
return MagicMock(return_value=None)

View File

@ -6,26 +6,24 @@ import httpx
import pytest
from fastapi import FastAPI, status
from fastapi_users.jwt import generate_jwt
from fastapi_users.router import ErrorCode, get_verify_router
from fastapi_users.user import get_get_user, get_verify_user
from fastapi_users.utils import generate_jwt
from tests.conftest import User, UserDB
SECRET = "SECRET"
LIFETIME = 3600
VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify"
JWT_ALGORITHM = "HS256"
@pytest.fixture
def verify_token():
def verify_token(secret):
def _verify_token(user_id=None, email=None, lifetime=LIFETIME):
data = {"aud": VERIFY_USER_TOKEN_AUDIENCE}
if user_id is not None:
data["user_id"] = str(user_id)
if email is not None:
data["email"] = email
return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
return generate_jwt(data, secret, lifetime)
return _verify_token
@ -61,6 +59,7 @@ def after_verification_request(request):
@pytest.fixture
@pytest.mark.asyncio
async def test_app_client(
secret,
mock_user_db,
after_verification_request,
after_verification,
@ -72,7 +71,7 @@ async def test_app_client(
verify_user,
get_user,
User,
SECRET,
secret,
LIFETIME,
after_verification_request,
after_verification,