diff --git a/docs/configuration/router.md b/docs/configuration/router.md index 465c2ff2..a139eb90 100644 --- a/docs/configuration/router.md +++ b/docs/configuration/router.md @@ -34,6 +34,22 @@ app.include_router(fastapi_users.router, prefix="/users", tags=["users"]) In order to be as unopinionated as possible, we expose decorators that allow you to plug your own logic after some actions. You can have several handlers per event. +### After register + +This event handler is called after a successful registration. It is called with **one argument**: the **user** that has just registered. + +Typically, you'll want to **send a welcome e-mail** or add it to your marketing analytics pipeline. + +You can define it as an `async` or standard method. + +Example: + +```py +@fastapi_users.on_after_register() +def on_after_register(user: User): + print(f"User {user.id} has registered.") +``` + ### After forgot password This event handler is called after a successful forgot password request. It is called with **two arguments**: the **user** which has requested to reset their password and a ready-to-use **JWT token** that will be accepted by the reset password route. @@ -46,8 +62,8 @@ Example: ```py @fastapi_users.on_after_forgot_password() -def on_after_forgot_password(user, token): - print(f'User {user.id} has forgot their password. Reset token: {token}') +def on_after_forgot_password(user: User, token: str): + print(f"User {user.id} has forgot their password. Reset token: {token}") ``` ## Next steps diff --git a/docs/src/full_sqlalchemy.py b/docs/src/full_sqlalchemy.py index 14fac916..1515ed1b 100644 --- a/docs/src/full_sqlalchemy.py +++ b/docs/src/full_sqlalchemy.py @@ -40,8 +40,13 @@ fastapi_users = FastAPIUsers(user_db, auth, User, SECRET) app.include_router(fastapi_users.router, prefix="/users", tags=["users"]) +@fastapi_users.on_after_register() +def on_after_register(user: User): + print(f"User {user.id} has registered.") + + @fastapi_users.on_after_forgot_password() -def on_after_forgot_password(user, token): +def on_after_forgot_password(user: User, token: str): print(f"User {user.id} has forgot their password. Reset token: {token}") diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index 5a11e072..34a94f3e 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -52,6 +52,10 @@ class FastAPIUsers: get_current_superuser = self.auth.get_current_superuser(self.db) self.get_current_superuser = get_current_superuser # type: ignore + def on_after_register(self) -> Callable: + """Add an event handler on successful registration.""" + return self._on_event(Events.ON_AFTER_REGISTER) + def on_after_forgot_password(self) -> Callable: """Add an event handler on successful forgot password request.""" return self._on_event(Events.ON_AFTER_FORGOT_PASSWORD) diff --git a/fastapi_users/router.py b/fastapi_users/router.py index 6175ede7..db6ec268 100644 --- a/fastapi_users/router.py +++ b/fastapi_users/router.py @@ -18,7 +18,8 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt class Events(Enum): - ON_AFTER_FORGOT_PASSWORD = 1 + ON_AFTER_REGISTER = 1 + ON_AFTER_FORGOT_PASSWORD = 2 class UserRouter(APIRouter): @@ -68,6 +69,9 @@ def get_user_router( **user.create_update_dict(), hashed_password=hashed_password ) created_user = await user_db.create(db_user) + + await router.run_handlers(Events.ON_AFTER_REGISTER, created_user) + return created_user @router.post("/login") diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index a4476050..1cfac3ee 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -5,29 +5,37 @@ from starlette.testclient import TestClient from fastapi_users import FastAPIUsers from fastapi_users.models import BaseUser, BaseUserDB - -SECRET = "SECRET" +from fastapi_users.router import Events -def sync_on_after_forgot_password(): +def sync_event_handler(): return None -async def async_on_after_forgot_password(): +async def async_event_handler(): return None -@pytest.fixture(params=[sync_on_after_forgot_password, async_on_after_forgot_password]) -def test_app_client(request, mock_user_db, mock_authentication) -> TestClient: +@pytest.fixture(params=[sync_event_handler, async_event_handler]) +def fastapi_users(request, mock_user_db, mock_authentication) -> FastAPIUsers: class User(BaseUser): pass - fastapi_users = FastAPIUsers(mock_user_db, mock_authentication, User, SECRET) + fastapi_users = FastAPIUsers(mock_user_db, mock_authentication, User, "SECRET") + + @fastapi_users.on_after_register() + def on_after_register(): + return request.param() @fastapi_users.on_after_forgot_password() def on_after_forgot_password(): return request.param() + return fastapi_users + + +@pytest.fixture() +def test_app_client(fastapi_users) -> TestClient: app = FastAPI() app.include_router(fastapi_users.router) @@ -46,6 +54,13 @@ def test_app_client(request, mock_user_db, mock_authentication) -> TestClient: return TestClient(app) +class TestFastAPIUsers: + def test_event_handlers(self, fastapi_users): + event_handlers = fastapi_users.router.event_handlers + assert len(event_handlers[Events.ON_AFTER_REGISTER]) == 1 + assert len(event_handlers[Events.ON_AFTER_FORGOT_PASSWORD]) == 1 + + class TestRouter: def test_routes_exist(self, test_app_client: TestClient): response = test_app_client.post("/register") diff --git a/tests/test_router.py b/tests/test_router.py index a83f0720..4edf82c6 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -26,23 +26,21 @@ def forgot_password_token(): return _forgot_password_token -def on_after_forgot_password_sync(): +def event_handler_sync(): return MagicMock(return_value=None) -def on_after_forgot_password_async(): +def event_handler_async(): return asynctest.CoroutineMock(return_value=None) -@pytest.fixture(params=[on_after_forgot_password_sync, on_after_forgot_password_async]) -def on_after_forgot_password(request): +@pytest.fixture(params=[event_handler_sync, event_handler_async]) +def event_handler(request): return request.param() @pytest.fixture() -def test_app_client( - mock_user_db, mock_authentication, on_after_forgot_password -) -> TestClient: +def test_app_client(mock_user_db, mock_authentication, event_handler) -> TestClient: class User(BaseUser): pass @@ -50,9 +48,8 @@ def test_app_client( mock_user_db, User, mock_authentication, SECRET, LIFETIME ) - userRouter.add_event_handler( - Events.ON_AFTER_FORGOT_PASSWORD, on_after_forgot_password - ) + userRouter.add_event_handler(Events.ON_AFTER_REGISTER, event_handler) + userRouter.add_event_handler(Events.ON_AFTER_FORGOT_PASSWORD, event_handler) app = FastAPI() app.include_router(userRouter) @@ -61,36 +58,44 @@ def test_app_client( class TestRegister: - def test_empty_body(self, test_app_client: TestClient): + def test_empty_body(self, test_app_client: TestClient, event_handler): response = test_app_client.post("/register", json={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert event_handler.called is False - def test_missing_password(self, test_app_client: TestClient): + def test_missing_password(self, test_app_client: TestClient, event_handler): json = {"email": "king.arthur@camelot.bt"} response = test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert event_handler.called is False - def test_wrong_email(self, test_app_client: TestClient): + def test_wrong_email(self, test_app_client: TestClient, event_handler): json = {"email": "king.arthur", "password": "guinevere"} response = test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert event_handler.called is False - def test_existing_user(self, test_app_client: TestClient): + def test_existing_user(self, test_app_client: TestClient, event_handler): json = {"email": "king.arthur@camelot.bt", "password": "guinevere"} response = test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_400_BAD_REQUEST + assert event_handler.called is False - def test_valid_body(self, test_app_client: TestClient): + def test_valid_body(self, test_app_client: TestClient, event_handler): json = {"email": "lancelot@camelot.bt", "password": "guinevere"} response = test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_201_CREATED + assert event_handler.called is True response_json = response.json() assert "hashed_password" not in response_json assert "password" not in response_json assert response_json["id"] is not None - def test_valid_body_is_superuser(self, test_app_client: TestClient): + actual_user = event_handler.call_args[0][0] + assert actual_user.id == response_json["id"] + + def test_valid_body_is_superuser(self, test_app_client: TestClient, event_handler): json = { "email": "lancelot@camelot.bt", "password": "guinevere", @@ -98,11 +103,12 @@ class TestRegister: } response = test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_201_CREATED + assert event_handler.called is True response_json = response.json() assert response_json["is_superuser"] is False - def test_valid_body_is_active(self, test_app_client: TestClient): + def test_valid_body_is_active(self, test_app_client: TestClient, event_handler): json = { "email": "lancelot@camelot.bt", "password": "guinevere", @@ -110,6 +116,7 @@ class TestRegister: } response = test_app_client.post("/register", json=json) assert response.status_code == status.HTTP_201_CREATED + assert event_handler.called is True response_json = response.json() assert response_json["is_active"] is True @@ -153,36 +160,32 @@ class TestLogin: class TestForgotPassword: - def test_empty_body(self, test_app_client: TestClient, on_after_forgot_password): + def test_empty_body(self, test_app_client: TestClient, event_handler): response = test_app_client.post("/forgot-password", json={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - assert on_after_forgot_password.called is False + assert event_handler.called is False - def test_not_existing_user( - self, test_app_client: TestClient, on_after_forgot_password - ): + def test_not_existing_user(self, test_app_client: TestClient, event_handler): json = {"email": "lancelot@camelot.bt"} response = test_app_client.post("/forgot-password", json=json) assert response.status_code == status.HTTP_202_ACCEPTED - assert on_after_forgot_password.called is False + assert event_handler.called is False - def test_inactive_user(self, test_app_client: TestClient, on_after_forgot_password): + def test_inactive_user(self, test_app_client: TestClient, event_handler): json = {"email": "percival@camelot.bt"} response = test_app_client.post("/forgot-password", json=json) assert response.status_code == status.HTTP_202_ACCEPTED - assert on_after_forgot_password.called is False + assert event_handler.called is False - def test_existing_user( - self, test_app_client: TestClient, on_after_forgot_password, user - ): + def test_existing_user(self, test_app_client: TestClient, event_handler, user): json = {"email": "king.arthur@camelot.bt"} response = test_app_client.post("/forgot-password", json=json) assert response.status_code == status.HTTP_202_ACCEPTED - assert on_after_forgot_password.called is True + assert event_handler.called is True - actual_user = on_after_forgot_password.call_args[0][0] + actual_user = event_handler.call_args[0][0] assert actual_user.id == user.id - actual_token = on_after_forgot_password.call_args[0][1] + actual_token = event_handler.call_args[0][1] decoded_token = jwt.decode( actual_token, SECRET,