Revamp Transport so they always build a full Response object (#1049)

* Revamp Transport so they always build a full Response object

* Fix linting

* Add private methods to set cookies on CookieTransport

* Change on_after_login login_return parameter to response
This commit is contained in:
François Voron
2023-04-27 09:32:49 +02:00
committed by GitHub
parent 9a2515f56c
commit 8fd097cbc8
12 changed files with 65 additions and 71 deletions

View File

@ -187,7 +187,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self,
user: User,
request: Optional[Request] = None,
login_return: Optional[Any] = None,
response: Optional[Response] = None,
):
print(f"User {user.id} logged in.")
```

View File

@ -1,6 +1,6 @@
from typing import Any, Generic
from typing import Generic
from fastapi import Response
from fastapi import Response, status
from fastapi_users import models
from fastapi_users.authentication.strategy import (
@ -40,27 +40,22 @@ class AuthenticationBackend(Generic[models.UP, models.ID]):
self.get_strategy = get_strategy
async def login(
self,
strategy: Strategy[models.UP, models.ID],
user: models.UP,
response: Response,
) -> Any:
self, strategy: Strategy[models.UP, models.ID], user: models.UP
) -> Response:
token = await strategy.write_token(user)
return await self.transport.get_login_response(token, response)
return await self.transport.get_login_response(token)
async def logout(
self,
strategy: Strategy[models.UP, models.ID],
user: models.UP,
token: str,
response: Response,
) -> Any:
self, strategy: Strategy[models.UP, models.ID], user: models.UP, token: str
) -> Response:
try:
await strategy.destroy_token(token, user)
except StrategyDestroyNotSupportedError:
pass
try:
await self.transport.get_logout_response(response)
response = await self.transport.get_logout_response()
except TransportLogoutNotSupportedError:
return None
response = Response(status_code=status.HTTP_204_NO_CONTENT)
return response

View File

@ -1,5 +1,4 @@
import sys
from typing import Any
if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
@ -19,10 +18,10 @@ class TransportLogoutNotSupportedError(Exception):
class Transport(Protocol):
scheme: SecurityBase
async def get_login_response(self, token: str, response: Response) -> Any:
async def get_login_response(self, token: str) -> Response:
... # pragma: no cover
async def get_logout_response(self, response: Response) -> Any:
async def get_logout_response(self) -> Response:
... # pragma: no cover
@staticmethod

View File

@ -1,6 +1,5 @@
from typing import Any
from fastapi import Response, status
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
@ -22,10 +21,11 @@ class BearerTransport(Transport):
def __init__(self, tokenUrl: str):
self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False)
async def get_login_response(self, token: str, response: Response) -> Any:
return BearerResponse(access_token=token, token_type="bearer")
async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer")
return JSONResponse(bearer_response.dict())
async def get_logout_response(self, response: Response) -> Any:
async def get_logout_response(self) -> Response:
raise TransportLogoutNotSupportedError()
@staticmethod

View File

@ -1,5 +1,5 @@
import sys
from typing import Any, Optional
from typing import Optional
if sys.version_info < (3, 8):
from typing_extensions import Literal # pragma: no cover
@ -35,7 +35,15 @@ class CookieTransport(Transport):
self.cookie_samesite = cookie_samesite
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
async def get_login_response(self, token: str, response: Response) -> Any:
async def get_login_response(self, token: str) -> Response:
response = Response(status_code=status.HTTP_204_NO_CONTENT)
return self._set_login_cookie(response, token)
async def get_logout_response(self) -> Response:
response = Response(status_code=status.HTTP_204_NO_CONTENT)
return self._set_logout_cookie(response)
def _set_login_cookie(self, response: Response, token: str) -> Response:
response.set_cookie(
self.cookie_name,
token,
@ -46,12 +54,9 @@ class CookieTransport(Transport):
httponly=self.cookie_httponly,
samesite=self.cookie_samesite,
)
return response
# We shouldn't return directly the response
# so that FastAPI can terminate it properly
return None
async def get_logout_response(self, response: Response) -> Any:
def _set_logout_cookie(self, response: Response) -> Response:
response.set_cookie(
self.cookie_name,
"",
@ -62,11 +67,12 @@ class CookieTransport(Transport):
httponly=self.cookie_httponly,
samesite=self.cookie_samesite,
)
return response
@staticmethod
def get_openapi_login_responses_success() -> OpenAPIResponseType:
return {status.HTTP_200_OK: {"model": None}}
return {status.HTTP_204_NO_CONTENT: {"model": None}}
@staticmethod
def get_openapi_logout_responses_success() -> OpenAPIResponseType:
return {status.HTTP_200_OK: {"model": None}}
return {status.HTTP_204_NO_CONTENT: {"model": None}}

View File

@ -2,7 +2,7 @@ import uuid
from typing import Any, Dict, Generic, Optional, Union
import jwt
from fastapi import Request
from fastapi import Request, Response
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import exceptions, models, schemas
@ -589,7 +589,7 @@ class BaseUserManager(Generic[models.UP, models.ID]):
self,
user: models.UP,
request: Optional[Request] = None,
login_return: Optional[Any] = None,
response: Optional[Response] = None,
) -> None:
"""
Perform logic after user login.
@ -598,8 +598,8 @@ class BaseUserManager(Generic[models.UP, models.ID]):
:param user: The user that is logging in
:param request: Optional FastAPI request
:param login_return: Optional return of the login
triggered the operation, defaults to None.
:param response: Optional response built by the transport.
Defaults to None
"""
return # pragma: no cover

View File

@ -1,6 +1,6 @@
from typing import Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import models
@ -50,7 +50,6 @@ def get_auth_router(
)
async def login(
request: Request,
response: Response,
credentials: OAuth2PasswordRequestForm = Depends(),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
@ -67,9 +66,9 @@ def get_auth_router(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_USER_NOT_VERIFIED,
)
login_return = await backend.login(strategy, user, response)
await user_manager.on_after_login(user, request, login_return)
return login_return
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
return response
logout_responses: OpenAPIResponseType = {
**{
@ -84,11 +83,10 @@ def get_auth_router(
"/logout", name=f"auth:{backend.name}.logout", responses=logout_responses
)
async def logout(
response: Response,
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
):
user, token = user_token
return await backend.logout(strategy, user, token, response)
return await backend.logout(strategy, user, token)
return router

View File

@ -1,7 +1,7 @@
from typing import Dict, List, Optional, Tuple, Type
import jwt
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token
from pydantic import BaseModel
@ -100,7 +100,6 @@ def get_oauth_router(
)
async def callback(
request: Request,
response: Response,
access_token_state: Tuple[OAuth2Token, str] = Depends(
oauth2_authorize_callback
),
@ -148,9 +147,9 @@ def get_oauth_router(
)
# Authenticate
login_return = await backend.login(strategy, user, response)
await user_manager.on_after_login(user, request, login_return)
return login_return
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
return response
return router

View File

@ -517,8 +517,8 @@ class MockTransport(BearerTransport):
def __init__(self, tokenUrl: str):
super().__init__(tokenUrl)
async def get_logout_response(self, response: Response) -> Any:
return None
async def get_logout_response(self) -> Any:
return Response()
@staticmethod
def get_openapi_logout_responses_success() -> OpenAPIResponseType:

View File

@ -57,5 +57,5 @@ def backend(
@pytest.mark.authentication
async def test_logout(backend: AuthenticationBackend, user: UserModel):
strategy = cast(Strategy, backend.get_strategy())
result = await backend.logout(strategy, user, "TOKEN", Response())
assert result is None
result = await backend.logout(strategy, user, "TOKEN")
assert isinstance(result, Response)

View File

@ -1,5 +1,6 @@
import pytest
from fastapi import Response, status
from fastapi import status
from fastapi.responses import JSONResponse
from fastapi_users.authentication.transport import (
BearerTransport,
@ -16,21 +17,17 @@ def bearer_transport() -> BearerTransport:
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_login_response(bearer_transport: BearerTransport):
response = Response()
login_response = await bearer_transport.get_login_response("TOKEN", response)
response = await bearer_transport.get_login_response("TOKEN")
assert isinstance(login_response, BearerResponse)
assert login_response.access_token == "TOKEN"
assert login_response.token_type == "bearer"
assert isinstance(response, JSONResponse)
assert response.body == b'{"access_token":"TOKEN","token_type":"bearer"}'
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_logout_response(bearer_transport: BearerTransport):
response = Response()
with pytest.raises(TransportLogoutNotSupportedError):
await bearer_transport.get_logout_response(response)
await bearer_transport.get_logout_response()
@pytest.mark.authentication

View File

@ -38,10 +38,10 @@ async def test_get_login_response(cookie_transport: CookieTransport):
secure = cookie_transport.cookie_secure
httponly = cookie_transport.cookie_httponly
response = Response()
login_response = await cookie_transport.get_login_response("TOKEN", response)
response = await cookie_transport.get_login_response("TOKEN")
assert login_response is None
assert isinstance(response, Response)
assert response.status_code == status.HTTP_204_NO_CONTENT
cookies = [header for header in response.raw_headers if header[0] == b"set-cookie"]
assert len(cookies) == 1
@ -79,10 +79,10 @@ async def test_get_login_response(cookie_transport: CookieTransport):
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_logout_response(cookie_transport: CookieTransport):
response = Response()
logout_response = await cookie_transport.get_logout_response(response)
response = await cookie_transport.get_logout_response()
assert logout_response is None
assert isinstance(response, Response)
assert response.status_code == status.HTTP_204_NO_CONTENT
cookies = [header for header in response.raw_headers if header[0] == b"set-cookie"]
assert len(cookies) == 1
@ -96,7 +96,7 @@ async def test_get_logout_response(cookie_transport: CookieTransport):
@pytest.mark.openapi
def test_get_openapi_login_responses_success(cookie_transport: CookieTransport):
assert cookie_transport.get_openapi_login_responses_success() == {
status.HTTP_200_OK: {"model": None}
status.HTTP_204_NO_CONTENT: {"model": None}
}
@ -104,5 +104,5 @@ def test_get_openapi_login_responses_success(cookie_transport: CookieTransport):
@pytest.mark.openapi
def test_get_openapi_logout_responses_success(cookie_transport: CookieTransport):
assert cookie_transport.get_openapi_logout_responses_success() == {
status.HTTP_200_OK: {"model": None}
status.HTTP_204_NO_CONTENT: {"model": None}
}