diff --git a/docs/configuration/user-manager.md b/docs/configuration/user-manager.md index acc5faef..1b3fe69b 100644 --- a/docs/configuration/user-manager.md +++ b/docs/configuration/user-manager.md @@ -164,6 +164,33 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): print(f"User {user.id} has been updated with {update_dict}.") ``` +#### `on_after_login` + +Perform logic after a successful user login. + +It may be useful for custom logic or processes triggered by new logins, for example a daily login reward or for analytics. + +**Arguments** + +* `user` (`User`): the updated user. +* `request` (`Optional[Request]`): optional FastAPI request object that triggered the operation. Defaults to None. + +**Example** + +```py +from fastapi_users import BaseUserManager, UUIDIDMixin + + +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): + # ... + async def on_after_login( + self, + user: User, + request: Optional[Request] = None, + ): + print(f"User {user.id} logged in.") +``` + #### `on_after_request_verify` Perform logic after successful verification request. diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index dc78277c..58a3ee5c 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -571,6 +571,20 @@ class BaseUserManager(Generic[models.UP, models.ID]): """ return # pragma: no cover + async def on_after_login( + self, user: models.UP, request: Optional[Request] = None + ) -> None: + """ + Perform logic after user login. + + *You should overload this method to add your own logic.* + + :param user: The user that is logging in + :param request: Optional FastAPI request that + triggered the operation, defaults to None. + """ + return # pragma: no cover + async def on_before_delete( self, user: models.UP, request: Optional[Request] = None ) -> None: diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index cda0fa3e..10ce7652 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -1,6 +1,6 @@ from typing import Tuple -from fastapi import APIRouter, Depends, HTTPException, Response, status +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import models @@ -49,6 +49,7 @@ def get_auth_router( responses=login_responses, ) async def login( + request: Request, response: Response, credentials: OAuth2PasswordRequestForm = Depends(), user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), @@ -66,7 +67,9 @@ def get_auth_router( status_code=status.HTTP_400_BAD_REQUEST, detail=ErrorCode.LOGIN_USER_NOT_VERIFIED, ) - return await backend.login(strategy, user, response) + login_return = await backend.login(strategy, user, response) + await user_manager.on_after_login(user, request) + return login_return logout_responses: OpenAPIResponseType = { **{ diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 16e96d00..2da179d2 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -140,7 +140,9 @@ def get_oauth_router( ) # Authenticate - return await backend.login(strategy, user, response) + login_return = await backend.login(strategy, user, response) + await user_manager.on_after_login(user, request) + return login_return return router diff --git a/tests/conftest.py b/tests/conftest.py index 9f0c3497..e0fcabb5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -122,6 +122,7 @@ class UserManagerMock(BaseTestUserManager[models.UP]): on_after_update: MagicMock on_before_delete: MagicMock on_after_delete: MagicMock + on_after_login: MagicMock _update: MagicMock @@ -479,6 +480,7 @@ def make_user_manager(mocker: MockerFixture): mocker.spy(user_manager, "on_after_update") mocker.spy(user_manager, "on_before_delete") mocker.spy(user_manager, "on_after_delete") + mocker.spy(user_manager, "on_after_login") mocker.spy(user_manager, "_update") return user_manager diff --git a/tests/test_router_auth.py b/tests/test_router_auth.py index 1bbc8026..7d2a042d 100644 --- a/tests/test_router_auth.py +++ b/tests/test_router_auth.py @@ -61,35 +61,42 @@ class TestLogin: self, path, test_app_client: Tuple[httpx.AsyncClient, bool], + user_manager, ): client, _ = test_app_client response = await client.post(path, data={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert user_manager.on_after_login.called is False async def test_missing_username( self, path, test_app_client: Tuple[httpx.AsyncClient, bool], + user_manager, ): client, _ = test_app_client data = {"password": "guinevere"} response = await client.post(path, data=data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert user_manager.on_after_login.called is False async def test_missing_password( self, path, test_app_client: Tuple[httpx.AsyncClient, bool], + user_manager, ): client, _ = test_app_client data = {"username": "king.arthur@camelot.bt"} response = await client.post(path, data=data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert user_manager.on_after_login.called is False async def test_not_existing_user( self, path, test_app_client: Tuple[httpx.AsyncClient, bool], + user_manager, ): client, _ = test_app_client data = {"username": "lancelot@camelot.bt", "password": "guinevere"} @@ -97,11 +104,13 @@ class TestLogin: assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS + assert user_manager.on_after_login.called is False async def test_wrong_password( self, path, test_app_client: Tuple[httpx.AsyncClient, bool], + user_manager, ): client, _ = test_app_client data = {"username": "king.arthur@camelot.bt", "password": "percival"} @@ -109,6 +118,7 @@ class TestLogin: assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS + assert user_manager.on_after_login.called is False @pytest.mark.parametrize( "email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"] @@ -118,6 +128,7 @@ class TestLogin: path, email, test_app_client: Tuple[httpx.AsyncClient, bool], + user_manager, user: UserModel, ): client, requires_verification = test_app_client @@ -127,12 +138,14 @@ class TestLogin: assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.LOGIN_USER_NOT_VERIFIED + assert user_manager.on_after_login.called is False else: assert response.status_code == status.HTTP_200_OK assert response.json() == { "access_token": str(user.id), "token_type": "bearer", } + assert user_manager.on_after_login.called is True @pytest.mark.parametrize("email", ["lake.lady@camelot.bt", "Lake.Lady@camelot.bt"]) async def test_valid_credentials_verified( @@ -140,6 +153,7 @@ class TestLogin: path, email, test_app_client: Tuple[httpx.AsyncClient, bool], + user_manager, verified_user: UserModel, ): client, _ = test_app_client @@ -150,11 +164,13 @@ class TestLogin: "access_token": str(verified_user.id), "token_type": "bearer", } + assert user_manager.on_after_login.called is True async def test_inactive_user( self, path, test_app_client: Tuple[httpx.AsyncClient, bool], + user_manager, ): client, _ = test_app_client data = {"username": "percival@camelot.bt", "password": "angharad"} @@ -162,6 +178,7 @@ class TestLogin: assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS + assert user_manager.on_after_login.called is False @pytest.mark.router diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index c13c164d..11373442 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -188,6 +188,8 @@ class TestCallback: data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.OAUTH_USER_ALREADY_EXISTS + assert user_manager_oauth.on_after_login.called is False + async def test_active_user( self, async_method_mocker: AsyncMethodMocker, @@ -216,6 +218,8 @@ class TestCallback: data = cast(Dict[str, Any], response.json()) assert data["access_token"] == str(user_oauth.id) + assert user_manager_oauth.on_after_login.called is True + async def test_inactive_user( self, async_method_mocker: AsyncMethodMocker, @@ -242,6 +246,7 @@ class TestCallback: ) assert response.status_code == status.HTTP_400_BAD_REQUEST + assert user_manager_oauth.on_after_login.called is False async def test_redirect_url_router( self, @@ -276,6 +281,7 @@ class TestCallback: data = cast(Dict[str, Any], response.json()) assert data["access_token"] == str(user_oauth.id) + assert user_manager_oauth.on_after_login.called is True @pytest.mark.router