diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index 322b7944..6e40f890 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -64,21 +64,12 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): self.get_user_manager = get_user_manager self.current_user = self.authenticator.current_user - def get_register_router( - self, - after_register: Optional[Callable[[models.UD, Request], None]] = None, - ) -> APIRouter: - """ - Return a router with a register route. - - :param after_register: Optional function called - after a successful registration. - """ + def get_register_router(self) -> APIRouter: + """Return a router with a register route.""" return get_register_router( self.get_user_manager, self._user_model, self._user_create_model, - after_register, ) def get_verify_router( diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index 0cde1105..5ef0caea 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -1,8 +1,8 @@ from typing import Any, Callable, Dict, Generic, Optional, Type, Union -from pydantic.types import UUID4 - +from fastapi import Request from fastapi.security import OAuth2PasswordRequestForm +from pydantic.types import UUID4 from fastapi_users import models, password from fastapi_users.db import BaseUserDatabase @@ -67,7 +67,9 @@ class BaseUserManager(Generic[models.UC, models.UD]): return user - async def create(self, user: models.UC, safe: bool = False) -> models.UD: + async def create( + self, user: models.UC, safe: bool = False, request: Optional[Request] = None + ) -> models.UD: await self.validate_password(user.password, user) existing_user = await self.user_db.get_by_email(user.email) @@ -79,7 +81,12 @@ class BaseUserManager(Generic[models.UC, models.UD]): user.create_update_dict() if safe else user.create_update_dict_superuser() ) db_user = self.user_db_model(**user_dict, hashed_password=hashed_password) - return await self.user_db.create(db_user) + + created_user = await self.user_db.create(db_user) + + await self.on_after_register(created_user, request) + + return created_user async def verify(self, user: models.UD) -> models.UD: if user.is_verified: @@ -105,6 +112,11 @@ class BaseUserManager(Generic[models.UC, models.UD]): ) -> None: return # pragma: no cover + async def on_after_register( + self, user: models.UD, request: Optional[Request] = None + ) -> None: + return # pragma: no cover + async def authenticate( self, credentials: OAuth2PasswordRequestForm ) -> Optional[models.UD]: diff --git a/fastapi_users/router/register.py b/fastapi_users/router/register.py index 92b94e1c..452262be 100644 --- a/fastapi_users/router/register.py +++ b/fastapi_users/router/register.py @@ -1,22 +1,21 @@ -from typing import Callable, Optional, Type +from typing import Type from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi_users import models from fastapi_users.manager import ( + BaseUserManager, InvalidPasswordException, UserAlreadyExists, - BaseUserManager, UserManagerDependency, ) -from fastapi_users.router.common import ErrorCode, run_handler +from fastapi_users.router.common import ErrorCode def get_register_router( get_user_manager: UserManagerDependency[models.UC, models.UD], user_model: Type[models.U], user_create_model: Type[models.UC], - after_register: Optional[Callable[[models.UD, Request], None]] = None, ) -> APIRouter: """Generate a router with the register route.""" router = APIRouter() @@ -30,7 +29,7 @@ def get_register_router( user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), ): try: - created_user = await user_manager.create(user, safe=True) + created_user = await user_manager.create(user, safe=True, request=request) except UserAlreadyExists: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -45,9 +44,6 @@ def get_register_router( }, ) - if after_register: - await run_handler(after_register, created_user, request) - return created_user return router diff --git a/fastapi_users/router/reset.py b/fastapi_users/router/reset.py index a9a26358..a3bdc3de 100644 --- a/fastapi_users/router/reset.py +++ b/fastapi_users/router/reset.py @@ -7,8 +7,8 @@ from pydantic import UUID4, EmailStr from fastapi_users import models from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import ( - InvalidPasswordException, BaseUserManager, + InvalidPasswordException, UserManagerDependency, UserNotExists, ) diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index e368ff81..c09e626b 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -6,9 +6,9 @@ from pydantic import UUID4 from fastapi_users import models from fastapi_users.authentication import Authenticator from fastapi_users.manager import ( + BaseUserManager, InvalidPasswordException, UserAlreadyExists, - BaseUserManager, UserManagerDependency, UserNotExists, ) diff --git a/fastapi_users/router/verify.py b/fastapi_users/router/verify.py index 01b09653..9d75419a 100644 --- a/fastapi_users/router/verify.py +++ b/fastapi_users/router/verify.py @@ -7,8 +7,8 @@ from pydantic import UUID4, EmailStr from fastapi_users import models from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import ( - UserAlreadyVerified, BaseUserManager, + UserAlreadyVerified, UserManagerDependency, UserNotExists, ) diff --git a/tests/conftest.py b/tests/conftest.py index 68280bbc..a44e14af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,26 @@ import asyncio from typing import AsyncGenerator, List, Optional, Union +from unittest.mock import MagicMock import httpx import pytest from asgi_lifespan import LifespanManager -from fastapi import Depends, FastAPI, Response +from fastapi import Depends, FastAPI, Request, Response from fastapi.security import OAuth2PasswordBearer from httpx_oauth.oauth2 import OAuth2 from pydantic import UUID4, SecretStr +from pytest_mock import MockerFixture from starlette.applications import ASGIApp from fastapi_users import models from fastapi_users.authentication import Authenticator, BaseAuthentication from fastapi_users.db import BaseUserDatabase from fastapi_users.jwt import SecretType -from fastapi_users.manager import InvalidPasswordException -from fastapi_users.manager import BaseUserManager -from fastapi_users.manager import UserNotExists +from fastapi_users.manager import ( + BaseUserManager, + InvalidPasswordException, + UserNotExists, +) from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin from fastapi_users.password import get_password_hash @@ -60,6 +64,15 @@ class UserManager(BaseUserManager[UserCreate, UserDB]): reason="Password should be at least 3 characters" ) + async def on_after_register( + self, user: UserDB, request: Optional[Request] = None + ) -> None: + return + + +class UserManagerMock(UserManager): + on_after_register: MagicMock + @pytest.fixture(scope="session") def event_loop(): @@ -351,8 +364,10 @@ def get_mock_user_db_oauth(mock_user_db_oauth): @pytest.fixture -def user_manager(mock_user_db): - return UserManager(UserDB, mock_user_db) +def user_manager(mocker: MockerFixture, mock_user_db): + user_manager = UserManager(UserDB, mock_user_db) + mocker.spy(user_manager, "on_after_register") + return user_manager @pytest.fixture diff --git a/tests/test_manager.py b/tests/test_manager.py index f8759a4e..059652c8 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -5,7 +5,7 @@ from fastapi.security import OAuth2PasswordRequestForm from pytest_mock import MockerFixture from fastapi_users.manager import UserAlreadyExists, UserAlreadyVerified -from tests.conftest import UserCreate, UserDB, UserManager +from tests.conftest import UserCreate, UserDB, UserManagerMock @pytest.fixture @@ -23,19 +23,24 @@ class TestCreateUser: @pytest.mark.parametrize( "email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"] ) - async def test_existing_user(self, email: str, user_manager: UserManager): + async def test_existing_user(self, email: str, user_manager: UserManagerMock): user = UserCreate(email=email, password="guinevere") with pytest.raises(UserAlreadyExists): await user_manager.create(user) + assert user_manager.on_after_register.called is False @pytest.mark.parametrize("email", ["lancelot@camelot.bt", "Lancelot@camelot.bt"]) - async def test_regular_user(self, email: str, user_manager: UserManager): + async def test_regular_user(self, email: str, user_manager: UserManagerMock): user = UserCreate(email=email, password="guinevere") created_user = await user_manager.create(user) assert type(created_user) == UserDB + assert user_manager.on_after_register.called is True + @pytest.mark.parametrize("safe,result", [(True, False), (False, True)]) - async def test_superuser(self, user_manager: UserManager, safe: bool, result: bool): + async def test_superuser( + self, user_manager: UserManagerMock, safe: bool, result: bool + ): user = UserCreate( email="lancelot@camelot.b", password="guinevere", is_superuser=True ) @@ -43,8 +48,12 @@ class TestCreateUser: assert type(created_user) == UserDB assert created_user.is_superuser is result + assert user_manager.on_after_register.called is True + @pytest.mark.parametrize("safe,result", [(True, True), (False, False)]) - async def test_is_active(self, user_manager: UserManager, safe: bool, result: bool): + async def test_is_active( + self, user_manager: UserManagerMock, safe: bool, result: bool + ): user = UserCreate( email="lancelot@camelot.b", password="guinevere", is_active=False ) @@ -52,16 +61,18 @@ class TestCreateUser: assert type(created_user) == UserDB assert created_user.is_active is result + assert user_manager.on_after_register.called is True + @pytest.mark.asyncio class TestVerifyUser: async def test_already_verified_user( - self, user_manager: UserManager, verified_user: UserDB + self, user_manager: UserManagerMock, verified_user: UserDB ): with pytest.raises(UserAlreadyVerified): await user_manager.verify(verified_user) - async def test_non_verified_user(self, user_manager: UserManager, user: UserDB): + async def test_non_verified_user(self, user_manager: UserManagerMock, user: UserDB): user = await user_manager.verify(user) assert user.is_verified @@ -74,7 +85,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManager, + user_manager: UserManagerMock, ): form = create_oauth2_password_request_form("lancelot@camelot.bt", "guinevere") user = await user_manager.authenticate(form) @@ -86,7 +97,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManager, + user_manager: UserManagerMock, ): form = create_oauth2_password_request_form("king.arthur@camelot.bt", "percival") user = await user_manager.authenticate(form) @@ -98,7 +109,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManager, + user_manager: UserManagerMock, ): form = create_oauth2_password_request_form( "king.arthur@camelot.bt", "guinevere" @@ -114,7 +125,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManager, + user_manager: UserManagerMock, ): verify_and_update_password_patch = mocker.patch( "fastapi_users.password.verify_and_update_password" diff --git a/tests/test_router_register.py b/tests/test_router_register.py index 13a6d38e..a7f7c909 100644 --- a/tests/test_router_register.py +++ b/tests/test_router_register.py @@ -1,38 +1,22 @@ from typing import Any, AsyncGenerator, Dict, cast -from unittest.mock import MagicMock -import asynctest import httpx import pytest -from fastapi import FastAPI, Request, status +from fastapi import FastAPI, status from fastapi_users.router import ErrorCode, get_register_router from tests.conftest import User, UserCreate -def after_register_sync(): - return MagicMock(return_value=None) - - -def after_register_async(): - return asynctest.CoroutineMock(return_value=None) - - -@pytest.fixture(params=[after_register_sync, after_register_async]) -def after_register(request): - return request.param() - - @pytest.fixture @pytest.mark.asyncio async def test_app_client( - get_user_manager, after_register, get_test_client + get_user_manager, get_test_client ) -> AsyncGenerator[httpx.AsyncClient, None]: register_router = get_register_router( get_user_manager, User, UserCreate, - after_register, ) app = FastAPI() @@ -45,38 +29,26 @@ async def test_app_client( @pytest.mark.router @pytest.mark.asyncio class TestRegister: - async def test_empty_body(self, test_app_client: httpx.AsyncClient, after_register): + async def test_empty_body(self, test_app_client: httpx.AsyncClient): response = await test_app_client.post("/register", json={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - assert after_register.called is False - async def test_missing_email( - self, test_app_client: httpx.AsyncClient, after_register - ): + async def test_missing_email(self, test_app_client: httpx.AsyncClient): json = {"password": "guinevere"} response = await test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - assert after_register.called is False - async def test_missing_password( - self, test_app_client: httpx.AsyncClient, after_register - ): + async def test_missing_password(self, test_app_client: httpx.AsyncClient): json = {"email": "king.arthur@camelot.bt"} response = await test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - assert after_register.called is False - async def test_wrong_email( - self, test_app_client: httpx.AsyncClient, after_register - ): + async def test_wrong_email(self, test_app_client: httpx.AsyncClient): json = {"email": "king.arthur", "password": "guinevere"} response = await test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - assert after_register.called is False - async def test_invalid_password( - self, test_app_client: httpx.AsyncClient, after_register - ): + async def test_invalid_password(self, test_app_client: httpx.AsyncClient): json = {"email": "king.arthur@camelot.bt", "password": "g"} response = await test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -85,44 +57,29 @@ class TestRegister: "code": ErrorCode.REGISTER_INVALID_PASSWORD, "reason": "Password should be at least 3 characters", } - assert after_register.called is False @pytest.mark.parametrize( "email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"] ) - async def test_existing_user( - self, email, test_app_client: httpx.AsyncClient, after_register - ): + async def test_existing_user(self, email, test_app_client: httpx.AsyncClient): json = {"email": email, "password": "guinevere"} response = await test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.REGISTER_USER_ALREADY_EXISTS - assert after_register.called is False @pytest.mark.parametrize("email", ["lancelot@camelot.bt", "Lancelot@camelot.bt"]) - async def test_valid_body( - self, email, test_app_client: httpx.AsyncClient, after_register - ): + async def test_valid_body(self, email, test_app_client: httpx.AsyncClient): json = {"email": email, "password": "guinevere"} response = await test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_201_CREATED - assert after_register.called is True data = cast(Dict[str, Any], response.json()) assert "hashed_password" not in data assert "password" not in data assert data["id"] is not None - actual_user = after_register.call_args[0][0] - assert str(actual_user.id) == data["id"] - assert str(actual_user.email) == email - request = after_register.call_args[0][1] - assert isinstance(request, Request) - - async def test_valid_body_is_superuser( - self, test_app_client: httpx.AsyncClient, after_register - ): + async def test_valid_body_is_superuser(self, test_app_client: httpx.AsyncClient): json = { "email": "lancelot@camelot.bt", "password": "guinevere", @@ -130,14 +87,11 @@ class TestRegister: } response = await test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_201_CREATED - assert after_register.called is True data = cast(Dict[str, Any], response.json()) assert data["is_superuser"] is False - async def test_valid_body_is_active( - self, test_app_client: httpx.AsyncClient, after_register - ): + async def test_valid_body_is_active(self, test_app_client: httpx.AsyncClient): json = { "email": "lancelot@camelot.bt", "password": "guinevere", @@ -145,7 +99,6 @@ class TestRegister: } response = await test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_201_CREATED - assert after_register.called is True data = cast(Dict[str, Any], response.json()) assert data["is_active"] is True