diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 40fcf13e..a26d1237 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,5 +24,5 @@ jobs: env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} run: | - pipenv run pytest --cov=./ + pipenv run pytest --cov=fastapi_users/ pipenv run codecov diff --git a/Makefile b/Makefile index 1677490c..f0da842d 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ format: $(PIPENV_RUN) black . test: - $(PIPENV_RUN) pytest + $(PIPENV_RUN) pytest --cov=fastapi_users/ docs-serve: $(PIPENV_RUN) mkdocs serve diff --git a/Pipfile b/Pipfile index b021d012..a49f2bab 100644 --- a/Pipfile +++ b/Pipfile @@ -18,6 +18,7 @@ mypy = "*" codecov = "*" pytest-cov = "*" pytest-mock = "*" +asynctest = "*" [packages] fastapi = "*" diff --git a/Pipfile.lock b/Pipfile.lock index e547ee47..6f087605 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "a45946502cd269a192b56dc49c22a1e2d2c95002de21c6f56da8599d4082014c" + "sha256": "4d1dac037a2a0eae31b75c787118ecef1fc552f381d7f6a112382de8003b2008" }, "pipfile-spec": 6, "requires": { @@ -184,6 +184,14 @@ ], "version": "==1.4.3" }, + "asynctest": { + "hashes": [ + "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676", + "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac" + ], + "index": "pypi", + "version": "==0.13.0" + }, "atomicwrites": { "hashes": [ "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4", diff --git a/fastapi_users/db/base.py b/fastapi_users/db/base.py index 42204c8b..b45048e2 100644 --- a/fastapi_users/db/base.py +++ b/fastapi_users/db/base.py @@ -2,8 +2,8 @@ from typing import List, Optional from fastapi.security import OAuth2PasswordRequestForm +from fastapi_users import password from fastapi_users.models import BaseUserDB -from fastapi_users.password import get_password_hash, verify_and_update_password class BaseUserDatabase: @@ -41,12 +41,12 @@ class BaseUserDatabase: # Always run the hasher to mitigate timing attack # Inspired from Django: https://code.djangoproject.com/ticket/20760 - get_password_hash(credentials.password) + password.get_password_hash(credentials.password) if user is None: return None else: - verified, updated_password_hash = verify_and_update_password( + verified, updated_password_hash = password.verify_and_update_password( credentials.password, user.hashed_password ) if not verified: diff --git a/fastapi_users/router.py b/fastapi_users/router.py index 515472c0..b7b47f72 100644 --- a/fastapi_users/router.py +++ b/fastapi_users/router.py @@ -1,4 +1,4 @@ -import inspect +import asyncio from typing import Any, Callable, Type import jwt @@ -28,7 +28,7 @@ def get_user_router( models = Models(user_model) reset_password_token_audience = "fastapi-users:reset" - is_on_after_forgot_password_async = inspect.iscoroutinefunction( + is_on_after_forgot_password_async = asyncio.iscoroutinefunction( on_after_forgot_password ) @@ -87,9 +87,8 @@ def get_user_router( if user is None or not user.is_active: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - updated_user = BaseUserDB(**user.dict()) - updated_user.hashed_password = get_password_hash(password) - await user_db.update(updated_user) + user.hashed_password = get_password_hash(password) + await user_db.update(user) except jwt.PyJWTError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) diff --git a/tests/conftest.py b/tests/conftest.py index 9ff2a917..9e641546 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,116 +1,147 @@ from typing import Optional import pytest -from fastapi import Depends +from fastapi import Depends, FastAPI from fastapi.security import OAuth2PasswordBearer from starlette.responses import Response +from starlette.testclient import TestClient from fastapi_users.authentication import BaseAuthentication from fastapi_users.db import BaseUserDatabase from fastapi_users.models import BaseUserDB from fastapi_users.password import get_password_hash -active_user_data = BaseUserDB( - id="aaa", - email="king.arthur@camelot.bt", - hashed_password=get_password_hash("guinevere"), -) - -inactive_user_data = BaseUserDB( - id="bbb", - email="percival@camelot.bt", - hashed_password=get_password_hash("angharad"), - is_active=False, -) - -superuser_data = BaseUserDB( - id="ccc", - email="merlin@camelot.bt", - hashed_password=get_password_hash("viviane"), - is_superuser=True, -) +guinevere_password_hash = get_password_hash("guinevere") +angharad_password_hash = get_password_hash("angharad") +viviane_password_hash = get_password_hash("viviane") @pytest.fixture def user() -> BaseUserDB: - return active_user_data + return BaseUserDB( + id="aaa", + email="king.arthur@camelot.bt", + hashed_password=guinevere_password_hash, + ) @pytest.fixture def inactive_user() -> BaseUserDB: - return inactive_user_data + return BaseUserDB( + id="bbb", + email="percival@camelot.bt", + hashed_password=angharad_password_hash, + is_active=False, + ) @pytest.fixture def superuser() -> BaseUserDB: - return superuser_data - - -class MockUserDatabase(BaseUserDatabase): - async def get(self, id: str) -> Optional[BaseUserDB]: - if id == active_user_data.id: - return active_user_data - elif id == inactive_user_data.id: - return inactive_user_data - elif id == superuser_data.id: - return superuser_data - return None - - async def get_by_email(self, email: str) -> Optional[BaseUserDB]: - if email == active_user_data.email: - return active_user_data - elif email == inactive_user_data.email: - return inactive_user_data - elif email == superuser_data.email: - return superuser_data - return None - - async def create(self, user: BaseUserDB) -> BaseUserDB: - return user - - async def update(self, user: BaseUserDB) -> BaseUserDB: - return user + return BaseUserDB( + id="ccc", + email="merlin@camelot.bt", + hashed_password=viviane_password_hash, + is_superuser=True, + ) @pytest.fixture -def mock_user_db() -> MockUserDatabase: +def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase: + class MockUserDatabase(BaseUserDatabase): + async def get(self, id: str) -> Optional[BaseUserDB]: + if id == user.id: + return user + elif id == inactive_user.id: + return inactive_user + elif id == superuser.id: + return superuser + return None + + async def get_by_email(self, email: str) -> Optional[BaseUserDB]: + if email == user.email: + return user + elif email == inactive_user.email: + return inactive_user + elif email == superuser.email: + return superuser + return None + + async def create(self, user: BaseUserDB) -> BaseUserDB: + return user + + async def update(self, user: BaseUserDB) -> BaseUserDB: + return user + return MockUserDatabase() -class MockAuthentication(BaseAuthentication): - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") +@pytest.fixture +def mock_authentication(): + class MockAuthentication(BaseAuthentication): + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") - async def get_login_response(self, user: BaseUserDB, response: Response): - return {"token": user.id} + async def get_login_response(self, user: BaseUserDB, response: Response): + return {"token": user.id} - def get_current_user(self, user_db: BaseUserDatabase): - async def _get_current_user(token: str = Depends(self.oauth2_scheme)): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_user_base(user) + def get_current_user(self, user_db: BaseUserDatabase): + async def _get_current_user(token: str = Depends(self.oauth2_scheme)): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_user_base(user) - return _get_current_user + return _get_current_user - def get_current_active_user(self, user_db: BaseUserDatabase): - async def _get_current_active_user(token: str = Depends(self.oauth2_scheme)): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_active_user_base(user) + def get_current_active_user(self, user_db: BaseUserDatabase): + async def _get_current_active_user( + token: str = Depends(self.oauth2_scheme) + ): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_active_user_base(user) - return _get_current_active_user + return _get_current_active_user - def get_current_superuser(self, user_db: BaseUserDatabase): - async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)): - user = await self._get_authentication_method(user_db)(token) - return self._get_current_superuser_base(user) + def get_current_superuser(self, user_db: BaseUserDatabase): + async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_superuser_base(user) - return _get_current_superuser + return _get_current_superuser - def _get_authentication_method(self, user_db: BaseUserDatabase): - async def authentication_method(token: str = Depends(self.oauth2_scheme)): - return await user_db.get(token) + def _get_authentication_method(self, user_db: BaseUserDatabase): + async def authentication_method(token: str = Depends(self.oauth2_scheme)): + return await user_db.get(token) - return authentication_method + return authentication_method + + return MockAuthentication() @pytest.fixture -def mock_authentication() -> MockAuthentication: - return MockAuthentication() +def get_test_auth_client(mock_user_db): + def _get_test_auth_client(authentication): + app = FastAPI() + + @app.get("/test-current-user") + def test_current_user( + user: BaseUserDB = Depends(authentication.get_current_user(mock_user_db)) + ): + return user + + @app.get("/test-current-active-user") + def test_current_active_user( + user: BaseUserDB = Depends( + authentication.get_current_active_user(mock_user_db) + ) + ): + return user + + @app.get("/test-current-superuser") + def test_current_superuser( + user: BaseUserDB = Depends( + authentication.get_current_superuser(mock_user_db) + ) + ): + return user + + return TestClient(app) + + return _get_test_auth_client diff --git a/tests/test_authentication_base.py b/tests/test_authentication_base.py index 8f4cf792..692add1c 100644 --- a/tests/test_authentication_base.py +++ b/tests/test_authentication_base.py @@ -1,130 +1,25 @@ import pytest -from fastapi import Depends, FastAPI -from starlette import status -from starlette.testclient import TestClient +from starlette.responses import Response -from fastapi_users.models import BaseUserDB +from fastapi_users.authentication import BaseAuthentication -@pytest.fixture -def test_auth_client(mock_authentication, mock_user_db): - app = FastAPI() +@pytest.mark.asyncio +async def test_not_implemented_methods(user, mock_user_db): + response = Response() + base_authentication = BaseAuthentication() - @app.get("/test-current-user") - def test_current_user( - user: BaseUserDB = Depends(mock_authentication.get_current_user(mock_user_db)) - ): - return user + with pytest.raises(NotImplementedError): + await base_authentication.get_login_response(user, response) - @app.get("/test-current-active-user") - def test_current_active_user( - user: BaseUserDB = Depends( - mock_authentication.get_current_active_user(mock_user_db) - ) - ): - return user + with pytest.raises(NotImplementedError): + await base_authentication.get_current_user(mock_user_db) - @app.get("/test-current-superuser") - def test_current_superuser( - user: BaseUserDB = Depends( - mock_authentication.get_current_superuser(mock_user_db) - ) - ): - return user + with pytest.raises(NotImplementedError): + await base_authentication.get_current_active_user(mock_user_db) - return TestClient(app) + with pytest.raises(NotImplementedError): + await base_authentication.get_current_superuser(mock_user_db) - -class TestGetCurrentUser: - def test_missing_token(self, test_auth_client): - response = test_auth_client.get("/test-current-user") - print(response.json()) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_invalid_token(self, test_auth_client): - response = test_auth_client.get( - "/test-current-user", headers={"Authorization": "Bearer foo"} - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_inactive_user(self, test_auth_client, inactive_user): - response = test_auth_client.get( - "/test-current-user", - headers={"Authorization": f"Bearer {inactive_user.id}"}, - ) - assert response.status_code == status.HTTP_200_OK - - response_json = response.json() - assert response_json["id"] == inactive_user.id - - def test_valid_token(self, test_auth_client, user): - response = test_auth_client.get( - "/test-current-user", headers={"Authorization": f"Bearer {user.id}"} - ) - assert response.status_code == status.HTTP_200_OK - - response_json = response.json() - assert response_json["id"] == user.id - - -class TestGetCurrentActiveUser: - def test_missing_token(self, test_auth_client): - response = test_auth_client.get("/test-current-active-user") - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_invalid_token(self, test_auth_client): - response = test_auth_client.get( - "/test-current-active-user", headers={"Authorization": "Bearer foo"} - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_inactive_user(self, test_auth_client, inactive_user): - response = test_auth_client.get( - "/test-current-active-user", - headers={"Authorization": f"Bearer {inactive_user.id}"}, - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token(self, test_auth_client, user): - response = test_auth_client.get( - "/test-current-active-user", headers={"Authorization": f"Bearer {user.id}"} - ) - assert response.status_code == status.HTTP_200_OK - - response_json = response.json() - assert response_json["id"] == user.id - - -class TestGetCurrentSuperuser: - def test_missing_token(self, test_auth_client): - response = test_auth_client.get("/test-current-superuser") - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_invalid_token(self, test_auth_client): - response = test_auth_client.get( - "/test-current-superuser", headers={"Authorization": "Bearer foo"} - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_inactive_user(self, test_auth_client, inactive_user): - response = test_auth_client.get( - "/test-current-superuser", - headers={"Authorization": f"Bearer {inactive_user.id}"}, - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - def test_valid_token_regular_user(self, test_auth_client, user): - response = test_auth_client.get( - "/test-current-superuser", headers={"Authorization": f"Bearer {user.id}"} - ) - assert response.status_code == status.HTTP_403_FORBIDDEN - - def test_valid_token_superuser(self, test_auth_client, superuser): - response = test_auth_client.get( - "/test-current-superuser", - headers={"Authorization": f"Bearer {superuser.id}"}, - ) - assert response.status_code == status.HTTP_200_OK - - response_json = response.json() - assert response_json["id"] == superuser.id + with pytest.raises(NotImplementedError): + await base_authentication._get_authentication_method(mock_user_db) diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py index 6a5511f2..bee80de9 100644 --- a/tests/test_authentication_jwt.py +++ b/tests/test_authentication_jwt.py @@ -1,12 +1,9 @@ import jwt import pytest -from fastapi import Depends, FastAPI from starlette import status from starlette.responses import Response -from starlette.testclient import TestClient from fastapi_users.authentication.jwt import JWTAuthentication -from fastapi_users.models import BaseUserDB from fastapi_users.utils import JWT_ALGORITHM, generate_jwt SECRET = "SECRET" @@ -20,24 +17,18 @@ def jwt_authentication(): @pytest.fixture def token(): - def _token(user, lifetime=LIFETIME): - data = {"user_id": user.id, "aud": "fastapi-users:auth"} + def _token(user=None, lifetime=LIFETIME): + data = {"aud": "fastapi-users:auth"} + if user is not None: + data["user_id"] = user.id return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return _token @pytest.fixture -def test_auth_client(jwt_authentication, mock_user_db): - app = FastAPI() - - @app.get("/test-auth") - def test_auth( - user: BaseUserDB = Depends(jwt_authentication.get_current_user(mock_user_db)) - ): - return user - - return TestClient(app) +def test_auth_client(get_test_auth_client, jwt_authentication): + return get_test_auth_client(jwt_authentication) @pytest.mark.asyncio @@ -55,18 +46,26 @@ async def test_get_login_response(jwt_authentication, user): class TestGetCurrentUser: def test_missing_token(self, test_auth_client): - response = test_auth_client.get("/test-auth") + response = test_auth_client.get("/test-current-user") + print(response.json()) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_invalid_token(self, test_auth_client): response = test_auth_client.get( - "/test-auth", headers={"Authorization": "Bearer foo"} + "/test-current-user", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_missing_user_payload(self, test_auth_client, token): + response = test_auth_client.get( + "/test-current-user", headers={"Authorization": f"Bearer {token()}"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user): response = test_auth_client.get( - "/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"} + "/test-current-user", + headers={"Authorization": f"Bearer {token(inactive_user)}"}, ) assert response.status_code == status.HTTP_200_OK @@ -75,9 +74,74 @@ class TestGetCurrentUser: def test_valid_token(self, test_auth_client, token, user): response = test_auth_client.get( - "/test-auth", headers={"Authorization": f"Bearer {token(user)}"} + "/test-current-user", headers={"Authorization": f"Bearer {token(user)}"} ) assert response.status_code == status.HTTP_200_OK response_json = response.json() assert response_json["id"] == user.id + + +class TestGetCurrentActiveUser: + def test_missing_token(self, test_auth_client): + response = test_auth_client.get("/test-current-active-user") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_auth_client): + response = test_auth_client.get( + "/test-current-active-user", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user): + response = test_auth_client.get( + "/test-current-active-user", + headers={"Authorization": f"Bearer {token(inactive_user)}"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token(self, test_auth_client, token, user): + response = test_auth_client.get( + "/test-current-active-user", + headers={"Authorization": f"Bearer {token(user)}"}, + ) + assert response.status_code == status.HTTP_200_OK + + response_json = response.json() + assert response_json["id"] == user.id + + +class TestGetCurrentSuperuser: + def test_missing_token(self, test_auth_client): + response = test_auth_client.get("/test-current-superuser") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_auth_client): + response = test_auth_client.get( + "/test-current-superuser", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user): + response = test_auth_client.get( + "/test-current-superuser", + headers={"Authorization": f"Bearer {token(inactive_user)}"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_regular_user(self, test_auth_client, token, user): + response = test_auth_client.get( + "/test-current-superuser", + headers={"Authorization": f"Bearer {token(user)}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_valid_token_superuser(self, test_auth_client, token, superuser): + response = test_auth_client.get( + "/test-current-superuser", + headers={"Authorization": f"Bearer {token(superuser)}"}, + ) + assert response.status_code == status.HTTP_200_OK + + response_json = response.json() + assert response_json["id"] == superuser.id diff --git a/tests/test_db_base.py b/tests/test_db_base.py index 70dfd823..fee574a8 100644 --- a/tests/test_db_base.py +++ b/tests/test_db_base.py @@ -1,6 +1,8 @@ import pytest from fastapi.security import OAuth2PasswordRequestForm +from fastapi_users.db import BaseUserDatabase + @pytest.fixture def create_oauth2_password_request_form(): @@ -10,6 +12,26 @@ def create_oauth2_password_request_form(): return _create_oauth2_password_request_form +@pytest.mark.asyncio +async def test_not_implemented_methods(user): + base_user_db = BaseUserDatabase() + + with pytest.raises(NotImplementedError): + await base_user_db.list() + + with pytest.raises(NotImplementedError): + await base_user_db.get("aaa") + + with pytest.raises(NotImplementedError): + await base_user_db.get_by_email("lancelot@camelot.bt") + + with pytest.raises(NotImplementedError): + await base_user_db.create(user) + + with pytest.raises(NotImplementedError): + await base_user_db.update(user) + + class TestAuthenticate: @pytest.mark.asyncio async def test_unknown_user( @@ -37,3 +59,21 @@ class TestAuthenticate: user = await mock_user_db.authenticate(form) assert user is not None assert user.email == "king.arthur@camelot.bt" + + @pytest.mark.asyncio + async def test_upgrade_password_hash( + self, mocker, create_oauth2_password_request_form, mock_user_db + ): + verify_and_update_password_patch = mocker.patch( + "fastapi_users.password.verify_and_update_password" + ) + verify_and_update_password_patch.return_value = (True, "updated_hash") + mocker.spy(mock_user_db, "update") + + form = create_oauth2_password_request_form( + "king.arthur@camelot.bt", "guinevere" + ) + user = await mock_user_db.authenticate(form) + assert user is not None + assert user.email == "king.arthur@camelot.bt" + assert mock_user_db.update.called is True diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index 1e06d827..3716b85f 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -9,21 +9,22 @@ from fastapi_users.models import BaseUser, BaseUserDB SECRET = "SECRET" -@pytest.fixture -def fastapi_users(mock_user_db, mock_authentication) -> FastAPIUsers: +def sync_on_after_forgot_password(): + return None + + +async def async_on_after_forgot_password(): + return None + + +@pytest.fixture(params=[sync_on_after_forgot_password, async_on_after_forgot_password]) +def test_app_client(request, mock_user_db, mock_authentication) -> TestClient: class User(BaseUser): pass - def on_after_forgot_password(user, token): - pass - - return FastAPIUsers( - mock_user_db, mock_authentication, User, on_after_forgot_password, SECRET + fastapi_users = FastAPIUsers( + mock_user_db, mock_authentication, User, request.param, SECRET ) - - -@pytest.fixture -def test_app_client(fastapi_users: FastAPIUsers) -> TestClient: app = FastAPI() app.include_router(fastapi_users.router) diff --git a/tests/test_router.py b/tests/test_router.py index 4cf42605..624dfdf6 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -1,6 +1,6 @@ -import asyncio from unittest.mock import MagicMock +import asynctest import jwt import pytest from fastapi import FastAPI @@ -17,57 +17,48 @@ LIFETIME = 3600 @pytest.fixture def forgot_password_token(): - def _forgot_password_token(user_id, lifetime=LIFETIME): - data = {"user_id": user_id, "aud": "fastapi-users:reset"} + def _forgot_password_token(user_id=None, lifetime=LIFETIME): + data = {"aud": "fastapi-users:reset"} + if user_id is not None: + data["user_id"] = user_id return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return _forgot_password_token -@pytest.fixture() def on_after_forgot_password_sync(): - on_after_forgot_password_mock = MagicMock(return_value=None) - return on_after_forgot_password_mock + return MagicMock(return_value=None) + + +def on_after_forgot_password_async(): + return asynctest.CoroutineMock(return_value=None) + + +@pytest.fixture(params=[on_after_forgot_password_sync, on_after_forgot_password_async]) +def on_after_forgot_password(request): + return request.param() @pytest.fixture() -def on_after_forgot_password_async(): - on_after_forgot_password_mock = MagicMock(return_value=asyncio.Future()) - on_after_forgot_password_mock.return_value.set_result(None) - return on_after_forgot_password_mock +def test_app_client( + mock_user_db, mock_authentication, on_after_forgot_password +) -> TestClient: + class User(BaseUser): + pass + userRouter = get_user_router( + mock_user_db, + User, + mock_authentication, + on_after_forgot_password, + SECRET, + LIFETIME, + ) -@pytest.fixture -def get_test_app_client(mock_user_db, mock_authentication): - def _get_test_app_client(on_after_forgot_password) -> TestClient: - class User(BaseUser): - pass + app = FastAPI() + app.include_router(userRouter) - userRouter = get_user_router( - mock_user_db, - User, - mock_authentication, - on_after_forgot_password, - SECRET, - LIFETIME, - ) - - app = FastAPI() - app.include_router(userRouter) - - return TestClient(app) - - return _get_test_app_client - - -@pytest.fixture -def test_app_client(get_test_app_client, on_after_forgot_password_sync): - return get_test_app_client(on_after_forgot_password_sync) - - -@pytest.fixture -def test_app_client_async(get_test_app_client, on_after_forgot_password_async): - return get_test_app_client(on_after_forgot_password_async) + return TestClient(app) class TestRegister: @@ -134,59 +125,36 @@ class TestLogin: class TestForgotPassword: - def test_empty_body( - self, test_app_client: TestClient, on_after_forgot_password_sync - ): + def test_empty_body(self, test_app_client: TestClient, on_after_forgot_password): response = test_app_client.post("/forgot-password", json={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - assert on_after_forgot_password_sync.called is False + assert on_after_forgot_password.called is False def test_not_existing_user( - self, test_app_client: TestClient, on_after_forgot_password_sync + self, test_app_client: TestClient, on_after_forgot_password ): json = {"email": "lancelot@camelot.bt"} response = test_app_client.post("/forgot-password", json=json) assert response.status_code == status.HTTP_202_ACCEPTED - assert on_after_forgot_password_sync.called is False + assert on_after_forgot_password.called is False - def test_inactive_user( - self, test_app_client: TestClient, on_after_forgot_password_sync - ): + def test_inactive_user(self, test_app_client: TestClient, on_after_forgot_password): json = {"email": "percival@camelot.bt"} response = test_app_client.post("/forgot-password", json=json) assert response.status_code == status.HTTP_202_ACCEPTED - assert on_after_forgot_password_sync.called is False + assert on_after_forgot_password.called is False - def test_existing_user_sync_hook( - self, test_app_client: TestClient, on_after_forgot_password_sync, user + def test_existing_user( + self, test_app_client: TestClient, on_after_forgot_password, user ): json = {"email": "king.arthur@camelot.bt"} response = test_app_client.post("/forgot-password", json=json) assert response.status_code == status.HTTP_202_ACCEPTED - assert on_after_forgot_password_sync.called is True + assert on_after_forgot_password.called is True - actual_user = on_after_forgot_password_sync.call_args[0][0] + actual_user = on_after_forgot_password.call_args[0][0] assert actual_user.id == user.id - actual_token = on_after_forgot_password_sync.call_args[0][1] - decoded_token = jwt.decode( - actual_token, - SECRET, - audience="fastapi-users:reset", - algorithms=[JWT_ALGORITHM], - ) - assert decoded_token["user_id"] == user.id - - def test_existing_user_async_hook( - self, test_app_client_async: TestClient, on_after_forgot_password_async, user - ): - json = {"email": "king.arthur@camelot.bt"} - response = test_app_client_async.post("/forgot-password", json=json) - assert response.status_code == status.HTTP_202_ACCEPTED - assert on_after_forgot_password_async.called is True - - actual_user = on_after_forgot_password_async.call_args[0][0] - assert actual_user.id == user.id - actual_token = on_after_forgot_password_async.call_args[0][1] + actual_token = on_after_forgot_password.call_args[0][1] decoded_token = jwt.decode( actual_token, SECRET, @@ -217,6 +185,21 @@ class TestResetPassword: print(response.json(), response.status_code) assert response.status_code == status.HTTP_400_BAD_REQUEST + def test_valid_token_missing_user_id_payload( + self, + mocker, + mock_user_db, + test_app_client: TestClient, + forgot_password_token, + inactive_user: BaseUserDB, + ): + mocker.spy(mock_user_db, "update") + + json = {"token": forgot_password_token(), "password": "holygrail"} + response = test_app_client.post("/reset-password", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert mock_user_db.update.called is False + def test_inactive_user( self, mocker, @@ -244,6 +227,7 @@ class TestResetPassword: user: BaseUserDB, ): mocker.spy(mock_user_db, "update") + current_hashed_passord = user.hashed_password json = {"token": forgot_password_token(user.id), "password": "holygrail"} response = test_app_client.post("/reset-password", json=json) @@ -251,4 +235,4 @@ class TestResetPassword: assert mock_user_db.update.called is True updated_user = mock_user_db.update.call_args[0][0] - assert updated_user.hashed_password != user.hashed_password + assert updated_user.hashed_password != current_hashed_passord