diff --git a/tests/conftest.py b/tests/conftest.py index 20fecee0..20b4e33a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import asyncio import dataclasses +import secrets import uuid from collections.abc import AsyncGenerator, Callable from typing import Any, Generic @@ -28,6 +29,7 @@ angharad_password_hash = password_helper.hash("angharad") viviane_password_hash = password_helper.hash("viviane") lancelot_password_hash = password_helper.hash("lancelot") excalibur_password_hash = password_helper.hash("excalibur") +JWT_SECRET = secrets.token_urlsafe(32) IDType = UUID4 @@ -79,8 +81,8 @@ class UserOAuth(User, schemas.BaseOAuthAccountMixin): class BaseTestUserManager( Generic[models.UP], UUIDIDMixin, BaseUserManager[models.UP, IDType] ): - reset_password_token_secret = "SECRET" - verification_token_secret = "SECRET" + reset_password_token_secret = JWT_SECRET + verification_token_secret = JWT_SECRET async def validate_password( self, password: str, user: schemas.UC | models.UP @@ -141,7 +143,7 @@ def async_method_mocker(mocker: MockerFixture) -> AsyncMethodMocker: return _async_method_mocker -@pytest.fixture(params=["SECRET", SecretStr("SECRET")]) +@pytest.fixture(params=[JWT_SECRET, SecretStr(JWT_SECRET)]) def secret(request) -> SecretType: return request.param diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index 39a5b7f1..5156c9cb 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -15,6 +15,7 @@ from fastapi_users.router.oauth import ( get_oauth_router, ) from tests.conftest import ( + JWT_SECRET, AsyncMethodMocker, User, UserManagerMock, @@ -173,7 +174,7 @@ class TestCallback: user_oauth: UserOAuthModel, access_token: str, ): - state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET") + state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, JWT_SECRET) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) get_id_email_mock = async_method_mocker( oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) @@ -198,7 +199,7 @@ class TestCallback: user_manager_oauth: UserManagerMock, access_token: str, ): - state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET") + state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, JWT_SECRET) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) @@ -229,7 +230,7 @@ class TestCallback: user_manager_oauth: UserManagerMock, access_token: str, ): - state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET") + state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, JWT_SECRET) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) @@ -260,7 +261,7 @@ class TestCallback: user_manager_oauth: UserManagerMock, access_token: str, ): - state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET") + state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, JWT_SECRET) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( oauth_client, @@ -289,7 +290,7 @@ class TestCallback: user_manager_oauth: UserManagerMock, access_token: str, ): - state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET") + state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, JWT_SECRET) get_access_token_mock = async_method_mocker( oauth_client, "get_access_token", return_value=access_token ) @@ -325,7 +326,7 @@ class TestCallback: user_manager_oauth: UserManagerMock, access_token: str, ): - state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET") + state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, JWT_SECRET) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( oauth_client, "get_id_email", return_value=("user_oauth1", None) @@ -354,7 +355,7 @@ class TestCallback: access_token: str, ): state_jwt = generate_state_token( - {"csrftoken": "CSRFTOKEN"}, "SECRET", lifetime_seconds=-1 + {"csrftoken": "CSRFTOKEN"}, JWT_SECRET, lifetime_seconds=-1 ) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( @@ -382,7 +383,9 @@ class TestCallback: user_manager_oauth: UserManagerMock, access_token: str, ): - state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "RANDOM") + state_jwt = generate_state_token( + {"csrftoken": "CSRFTOKEN"}, JWT_SECRET + "invalid" + ) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) @@ -393,10 +396,9 @@ class TestCallback: params={"code": "CODE", "state": state_jwt}, ) + assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(dict[str, Any], response.json()) assert data["detail"] == ErrorCode.ACCESS_TOKEN_DECODE_ERROR - - assert response.status_code == status.HTTP_400_BAD_REQUEST assert user_manager_oauth.on_after_login.called is False @@ -540,7 +542,7 @@ class TestAssociateCallback: user_oauth: UserOAuthModel, access_token: str, ): - state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET") + state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, JWT_SECRET) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) get_id_email_mock = async_method_mocker( oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) @@ -567,7 +569,7 @@ class TestAssociateCallback: access_token: str, ): state_jwt = generate_state_token( - {"sub": str(user.id), "csrftoken": "CSRFTOKEN"}, "SECRET" + {"sub": str(user.id), "csrftoken": "CSRFTOKEN"}, JWT_SECRET ) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) get_id_email_mock = async_method_mocker( @@ -595,7 +597,7 @@ class TestAssociateCallback: access_token: str, ): state_jwt = generate_state_token( - {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, "SECRET" + {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, JWT_SECRET ) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( @@ -627,7 +629,7 @@ class TestAssociateCallback: access_token: str, ): state_jwt = generate_state_token( - {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, "SECRET" + {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, JWT_SECRET ) get_access_token_mock = async_method_mocker( oauth_client, "get_access_token", return_value=access_token @@ -665,7 +667,7 @@ class TestAssociateCallback: access_token: str, ): state_jwt = generate_state_token( - {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, "SECRET" + {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, JWT_SECRET ) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( @@ -697,7 +699,7 @@ class TestAssociateCallback: ): state_jwt = generate_state_token( {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, - "SECRET", + JWT_SECRET, lifetime_seconds=-1, ) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) @@ -729,7 +731,8 @@ class TestAssociateCallback: access_token: str, ): state_jwt = generate_state_token( - {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, "RANDOM" + {"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, + JWT_SECRET + "invalid", ) async_method_mocker(oauth_client, "get_access_token", return_value=access_token) async_method_mocker( @@ -745,11 +748,10 @@ class TestAssociateCallback: headers={"Authorization": f"Bearer {user_oauth.id}"}, ) + assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(dict[str, Any], response.json()) assert data["detail"] == ErrorCode.ACCESS_TOKEN_DECODE_ERROR - assert response.status_code == status.HTTP_400_BAD_REQUEST - @pytest.mark.asyncio @pytest.mark.oauth