diff --git a/docs/configuration/user-manager.md b/docs/configuration/user-manager.md index 2410cd02..be5ed195 100644 --- a/docs/configuration/user-manager.md +++ b/docs/configuration/user-manager.md @@ -187,7 +187,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): self, user: User, request: Optional[Request] = None, - login_return: Optional[Any] = None, + response: Optional[Response] = None, ): print(f"User {user.id} logged in.") ``` diff --git a/fastapi_users/authentication/backend.py b/fastapi_users/authentication/backend.py index b02ca98d..4dab8674 100644 --- a/fastapi_users/authentication/backend.py +++ b/fastapi_users/authentication/backend.py @@ -1,6 +1,6 @@ -from typing import Any, Generic +from typing import Generic -from fastapi import Response +from fastapi import Response, status from fastapi_users import models from fastapi_users.authentication.strategy import ( @@ -40,27 +40,22 @@ class AuthenticationBackend(Generic[models.UP, models.ID]): self.get_strategy = get_strategy async def login( - self, - strategy: Strategy[models.UP, models.ID], - user: models.UP, - response: Response, - ) -> Any: + self, strategy: Strategy[models.UP, models.ID], user: models.UP + ) -> Response: token = await strategy.write_token(user) - return await self.transport.get_login_response(token, response) + return await self.transport.get_login_response(token) async def logout( - self, - strategy: Strategy[models.UP, models.ID], - user: models.UP, - token: str, - response: Response, - ) -> Any: + self, strategy: Strategy[models.UP, models.ID], user: models.UP, token: str + ) -> Response: try: await strategy.destroy_token(token, user) except StrategyDestroyNotSupportedError: pass try: - await self.transport.get_logout_response(response) + response = await self.transport.get_logout_response() except TransportLogoutNotSupportedError: - return None + response = Response(status_code=status.HTTP_204_NO_CONTENT) + + return response diff --git a/fastapi_users/authentication/transport/base.py b/fastapi_users/authentication/transport/base.py index d54c3a5a..64073e7e 100644 --- a/fastapi_users/authentication/transport/base.py +++ b/fastapi_users/authentication/transport/base.py @@ -1,5 +1,4 @@ import sys -from typing import Any if sys.version_info < (3, 8): from typing_extensions import Protocol # pragma: no cover @@ -19,10 +18,10 @@ class TransportLogoutNotSupportedError(Exception): class Transport(Protocol): scheme: SecurityBase - async def get_login_response(self, token: str, response: Response) -> Any: + async def get_login_response(self, token: str) -> Response: ... # pragma: no cover - async def get_logout_response(self, response: Response) -> Any: + async def get_logout_response(self) -> Response: ... # pragma: no cover @staticmethod diff --git a/fastapi_users/authentication/transport/bearer.py b/fastapi_users/authentication/transport/bearer.py index 924fe9f8..d060720b 100644 --- a/fastapi_users/authentication/transport/bearer.py +++ b/fastapi_users/authentication/transport/bearer.py @@ -1,6 +1,5 @@ -from typing import Any - from fastapi import Response, status +from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordBearer from pydantic import BaseModel @@ -22,10 +21,11 @@ class BearerTransport(Transport): def __init__(self, tokenUrl: str): self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False) - async def get_login_response(self, token: str, response: Response) -> Any: - return BearerResponse(access_token=token, token_type="bearer") + async def get_login_response(self, token: str) -> Response: + bearer_response = BearerResponse(access_token=token, token_type="bearer") + return JSONResponse(bearer_response.dict()) - async def get_logout_response(self, response: Response) -> Any: + async def get_logout_response(self) -> Response: raise TransportLogoutNotSupportedError() @staticmethod diff --git a/fastapi_users/authentication/transport/cookie.py b/fastapi_users/authentication/transport/cookie.py index edc07aa9..68c6dd90 100644 --- a/fastapi_users/authentication/transport/cookie.py +++ b/fastapi_users/authentication/transport/cookie.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Optional +from typing import Optional if sys.version_info < (3, 8): from typing_extensions import Literal # pragma: no cover @@ -35,7 +35,15 @@ class CookieTransport(Transport): self.cookie_samesite = cookie_samesite self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False) - async def get_login_response(self, token: str, response: Response) -> Any: + async def get_login_response(self, token: str) -> Response: + response = Response(status_code=status.HTTP_204_NO_CONTENT) + return self._set_login_cookie(response, token) + + async def get_logout_response(self) -> Response: + response = Response(status_code=status.HTTP_204_NO_CONTENT) + return self._set_logout_cookie(response) + + def _set_login_cookie(self, response: Response, token: str) -> Response: response.set_cookie( self.cookie_name, token, @@ -46,12 +54,9 @@ class CookieTransport(Transport): httponly=self.cookie_httponly, samesite=self.cookie_samesite, ) + return response - # We shouldn't return directly the response - # so that FastAPI can terminate it properly - return None - - async def get_logout_response(self, response: Response) -> Any: + def _set_logout_cookie(self, response: Response) -> Response: response.set_cookie( self.cookie_name, "", @@ -62,11 +67,12 @@ class CookieTransport(Transport): httponly=self.cookie_httponly, samesite=self.cookie_samesite, ) + return response @staticmethod def get_openapi_login_responses_success() -> OpenAPIResponseType: - return {status.HTTP_200_OK: {"model": None}} + return {status.HTTP_204_NO_CONTENT: {"model": None}} @staticmethod def get_openapi_logout_responses_success() -> OpenAPIResponseType: - return {status.HTTP_200_OK: {"model": None}} + return {status.HTTP_204_NO_CONTENT: {"model": None}} diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index 3878c13d..60eb6c32 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -2,7 +2,7 @@ import uuid from typing import Any, Dict, Generic, Optional, Union import jwt -from fastapi import Request +from fastapi import Request, Response from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import exceptions, models, schemas @@ -589,7 +589,7 @@ class BaseUserManager(Generic[models.UP, models.ID]): self, user: models.UP, request: Optional[Request] = None, - login_return: Optional[Any] = None, + response: Optional[Response] = None, ) -> None: """ Perform logic after user login. @@ -598,8 +598,8 @@ class BaseUserManager(Generic[models.UP, models.ID]): :param user: The user that is logging in :param request: Optional FastAPI request - :param login_return: Optional return of the login - triggered the operation, defaults to None. + :param response: Optional response built by the transport. + Defaults to None """ return # pragma: no cover diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index b9587ab0..c61770f0 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -1,6 +1,6 @@ from typing import Tuple -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import models @@ -50,7 +50,6 @@ def get_auth_router( ) async def login( request: Request, - response: Response, credentials: OAuth2PasswordRequestForm = Depends(), user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), @@ -67,9 +66,9 @@ def get_auth_router( status_code=status.HTTP_400_BAD_REQUEST, detail=ErrorCode.LOGIN_USER_NOT_VERIFIED, ) - login_return = await backend.login(strategy, user, response) - await user_manager.on_after_login(user, request, login_return) - return login_return + response = await backend.login(strategy, user) + await user_manager.on_after_login(user, request, response) + return response logout_responses: OpenAPIResponseType = { **{ @@ -84,11 +83,10 @@ def get_auth_router( "/logout", name=f"auth:{backend.name}.logout", responses=logout_responses ) async def logout( - response: Response, user_token: Tuple[models.UP, str] = Depends(get_current_user_token), strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), ): user, token = user_token - return await backend.logout(strategy, user, token, response) + return await backend.logout(strategy, user, token) return router diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index e6c090db..cf43c9c4 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, Tuple, Type import jwt -from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token from pydantic import BaseModel @@ -100,7 +100,6 @@ def get_oauth_router( ) async def callback( request: Request, - response: Response, access_token_state: Tuple[OAuth2Token, str] = Depends( oauth2_authorize_callback ), @@ -148,9 +147,9 @@ def get_oauth_router( ) # Authenticate - login_return = await backend.login(strategy, user, response) - await user_manager.on_after_login(user, request, login_return) - return login_return + response = await backend.login(strategy, user) + await user_manager.on_after_login(user, request, response) + return response return router diff --git a/tests/conftest.py b/tests/conftest.py index 6fb14340..b3bcd4da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -517,8 +517,8 @@ class MockTransport(BearerTransport): def __init__(self, tokenUrl: str): super().__init__(tokenUrl) - async def get_logout_response(self, response: Response) -> Any: - return None + async def get_logout_response(self) -> Any: + return Response() @staticmethod def get_openapi_logout_responses_success() -> OpenAPIResponseType: diff --git a/tests/test_authentication_backend.py b/tests/test_authentication_backend.py index b7d99795..a4cbc050 100644 --- a/tests/test_authentication_backend.py +++ b/tests/test_authentication_backend.py @@ -57,5 +57,5 @@ def backend( @pytest.mark.authentication async def test_logout(backend: AuthenticationBackend, user: UserModel): strategy = cast(Strategy, backend.get_strategy()) - result = await backend.logout(strategy, user, "TOKEN", Response()) - assert result is None + result = await backend.logout(strategy, user, "TOKEN") + assert isinstance(result, Response) diff --git a/tests/test_authentication_transport_bearer.py b/tests/test_authentication_transport_bearer.py index cadf0457..9739be4d 100644 --- a/tests/test_authentication_transport_bearer.py +++ b/tests/test_authentication_transport_bearer.py @@ -1,5 +1,6 @@ import pytest -from fastapi import Response, status +from fastapi import status +from fastapi.responses import JSONResponse from fastapi_users.authentication.transport import ( BearerTransport, @@ -16,21 +17,17 @@ def bearer_transport() -> BearerTransport: @pytest.mark.authentication @pytest.mark.asyncio async def test_get_login_response(bearer_transport: BearerTransport): - response = Response() - login_response = await bearer_transport.get_login_response("TOKEN", response) + response = await bearer_transport.get_login_response("TOKEN") - assert isinstance(login_response, BearerResponse) - - assert login_response.access_token == "TOKEN" - assert login_response.token_type == "bearer" + assert isinstance(response, JSONResponse) + assert response.body == b'{"access_token":"TOKEN","token_type":"bearer"}' @pytest.mark.authentication @pytest.mark.asyncio async def test_get_logout_response(bearer_transport: BearerTransport): - response = Response() with pytest.raises(TransportLogoutNotSupportedError): - await bearer_transport.get_logout_response(response) + await bearer_transport.get_logout_response() @pytest.mark.authentication diff --git a/tests/test_authentication_transport_cookie.py b/tests/test_authentication_transport_cookie.py index ca99f9ca..fba43008 100644 --- a/tests/test_authentication_transport_cookie.py +++ b/tests/test_authentication_transport_cookie.py @@ -38,10 +38,10 @@ async def test_get_login_response(cookie_transport: CookieTransport): secure = cookie_transport.cookie_secure httponly = cookie_transport.cookie_httponly - response = Response() - login_response = await cookie_transport.get_login_response("TOKEN", response) + response = await cookie_transport.get_login_response("TOKEN") - assert login_response is None + assert isinstance(response, Response) + assert response.status_code == status.HTTP_204_NO_CONTENT cookies = [header for header in response.raw_headers if header[0] == b"set-cookie"] assert len(cookies) == 1 @@ -79,10 +79,10 @@ async def test_get_login_response(cookie_transport: CookieTransport): @pytest.mark.authentication @pytest.mark.asyncio async def test_get_logout_response(cookie_transport: CookieTransport): - response = Response() - logout_response = await cookie_transport.get_logout_response(response) + response = await cookie_transport.get_logout_response() - assert logout_response is None + assert isinstance(response, Response) + assert response.status_code == status.HTTP_204_NO_CONTENT cookies = [header for header in response.raw_headers if header[0] == b"set-cookie"] assert len(cookies) == 1 @@ -96,7 +96,7 @@ async def test_get_logout_response(cookie_transport: CookieTransport): @pytest.mark.openapi def test_get_openapi_login_responses_success(cookie_transport: CookieTransport): assert cookie_transport.get_openapi_login_responses_success() == { - status.HTTP_200_OK: {"model": None} + status.HTTP_204_NO_CONTENT: {"model": None} } @@ -104,5 +104,5 @@ def test_get_openapi_login_responses_success(cookie_transport: CookieTransport): @pytest.mark.openapi def test_get_openapi_logout_responses_success(cookie_transport: CookieTransport): assert cookie_transport.get_openapi_logout_responses_success() == { - status.HTTP_200_OK: {"model": None} + status.HTTP_204_NO_CONTENT: {"model": None} }