mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Revamp Transport so they always build a full Response object (#1049)
* Revamp Transport so they always build a full Response object * Fix linting * Add private methods to set cookies on CookieTransport * Change on_after_login login_return parameter to response
This commit is contained in:
@ -187,7 +187,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self,
|
||||
user: User,
|
||||
request: Optional[Request] = None,
|
||||
login_return: Optional[Any] = None,
|
||||
response: Optional[Response] = None,
|
||||
):
|
||||
print(f"User {user.id} logged in.")
|
||||
```
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Any, Generic
|
||||
from typing import Generic
|
||||
|
||||
from fastapi import Response
|
||||
from fastapi import Response, status
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication.strategy import (
|
||||
@ -40,27 +40,22 @@ class AuthenticationBackend(Generic[models.UP, models.ID]):
|
||||
self.get_strategy = get_strategy
|
||||
|
||||
async def login(
|
||||
self,
|
||||
strategy: Strategy[models.UP, models.ID],
|
||||
user: models.UP,
|
||||
response: Response,
|
||||
) -> Any:
|
||||
self, strategy: Strategy[models.UP, models.ID], user: models.UP
|
||||
) -> Response:
|
||||
token = await strategy.write_token(user)
|
||||
return await self.transport.get_login_response(token, response)
|
||||
return await self.transport.get_login_response(token)
|
||||
|
||||
async def logout(
|
||||
self,
|
||||
strategy: Strategy[models.UP, models.ID],
|
||||
user: models.UP,
|
||||
token: str,
|
||||
response: Response,
|
||||
) -> Any:
|
||||
self, strategy: Strategy[models.UP, models.ID], user: models.UP, token: str
|
||||
) -> Response:
|
||||
try:
|
||||
await strategy.destroy_token(token, user)
|
||||
except StrategyDestroyNotSupportedError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await self.transport.get_logout_response(response)
|
||||
response = await self.transport.get_logout_response()
|
||||
except TransportLogoutNotSupportedError:
|
||||
return None
|
||||
response = Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
return response
|
||||
|
@ -1,5 +1,4 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Protocol # pragma: no cover
|
||||
@ -19,10 +18,10 @@ class TransportLogoutNotSupportedError(Exception):
|
||||
class Transport(Protocol):
|
||||
scheme: SecurityBase
|
||||
|
||||
async def get_login_response(self, token: str, response: Response) -> Any:
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
... # pragma: no cover
|
||||
|
||||
async def get_logout_response(self, response: Response) -> Any:
|
||||
async def get_logout_response(self) -> Response:
|
||||
... # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -22,10 +21,11 @@ class BearerTransport(Transport):
|
||||
def __init__(self, tokenUrl: str):
|
||||
self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False)
|
||||
|
||||
async def get_login_response(self, token: str, response: Response) -> Any:
|
||||
return BearerResponse(access_token=token, token_type="bearer")
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
bearer_response = BearerResponse(access_token=token, token_type="bearer")
|
||||
return JSONResponse(bearer_response.dict())
|
||||
|
||||
async def get_logout_response(self, response: Response) -> Any:
|
||||
async def get_logout_response(self) -> Response:
|
||||
raise TransportLogoutNotSupportedError()
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,5 +1,5 @@
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal # pragma: no cover
|
||||
@ -35,7 +35,15 @@ class CookieTransport(Transport):
|
||||
self.cookie_samesite = cookie_samesite
|
||||
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
|
||||
|
||||
async def get_login_response(self, token: str, response: Response) -> Any:
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
response = Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
return self._set_login_cookie(response, token)
|
||||
|
||||
async def get_logout_response(self) -> Response:
|
||||
response = Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
return self._set_logout_cookie(response)
|
||||
|
||||
def _set_login_cookie(self, response: Response, token: str) -> Response:
|
||||
response.set_cookie(
|
||||
self.cookie_name,
|
||||
token,
|
||||
@ -46,12 +54,9 @@ class CookieTransport(Transport):
|
||||
httponly=self.cookie_httponly,
|
||||
samesite=self.cookie_samesite,
|
||||
)
|
||||
return response
|
||||
|
||||
# We shouldn't return directly the response
|
||||
# so that FastAPI can terminate it properly
|
||||
return None
|
||||
|
||||
async def get_logout_response(self, response: Response) -> Any:
|
||||
def _set_logout_cookie(self, response: Response) -> Response:
|
||||
response.set_cookie(
|
||||
self.cookie_name,
|
||||
"",
|
||||
@ -62,11 +67,12 @@ class CookieTransport(Transport):
|
||||
httponly=self.cookie_httponly,
|
||||
samesite=self.cookie_samesite,
|
||||
)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def get_openapi_login_responses_success() -> OpenAPIResponseType:
|
||||
return {status.HTTP_200_OK: {"model": None}}
|
||||
return {status.HTTP_204_NO_CONTENT: {"model": None}}
|
||||
|
||||
@staticmethod
|
||||
def get_openapi_logout_responses_success() -> OpenAPIResponseType:
|
||||
return {status.HTTP_200_OK: {"model": None}}
|
||||
return {status.HTTP_204_NO_CONTENT: {"model": None}}
|
||||
|
@ -2,7 +2,7 @@ import uuid
|
||||
from typing import Any, Dict, Generic, Optional, Union
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from fastapi import Request, Response
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from fastapi_users import exceptions, models, schemas
|
||||
@ -589,7 +589,7 @@ class BaseUserManager(Generic[models.UP, models.ID]):
|
||||
self,
|
||||
user: models.UP,
|
||||
request: Optional[Request] = None,
|
||||
login_return: Optional[Any] = None,
|
||||
response: Optional[Response] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Perform logic after user login.
|
||||
@ -598,8 +598,8 @@ class BaseUserManager(Generic[models.UP, models.ID]):
|
||||
|
||||
:param user: The user that is logging in
|
||||
:param request: Optional FastAPI request
|
||||
:param login_return: Optional return of the login
|
||||
triggered the operation, defaults to None.
|
||||
:param response: Optional response built by the transport.
|
||||
Defaults to None
|
||||
"""
|
||||
return # pragma: no cover
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from fastapi_users import models
|
||||
@ -50,7 +50,6 @@ def get_auth_router(
|
||||
)
|
||||
async def login(
|
||||
request: Request,
|
||||
response: Response,
|
||||
credentials: OAuth2PasswordRequestForm = Depends(),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
@ -67,9 +66,9 @@ def get_auth_router(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_USER_NOT_VERIFIED,
|
||||
)
|
||||
login_return = await backend.login(strategy, user, response)
|
||||
await user_manager.on_after_login(user, request, login_return)
|
||||
return login_return
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
return response
|
||||
|
||||
logout_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
@ -84,11 +83,10 @@ def get_auth_router(
|
||||
"/logout", name=f"auth:{backend.name}.logout", responses=logout_responses
|
||||
)
|
||||
async def logout(
|
||||
response: Response,
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
):
|
||||
user, token = user_token
|
||||
return await backend.logout(strategy, user, token, response)
|
||||
return await backend.logout(strategy, user, token)
|
||||
|
||||
return router
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
@ -100,7 +100,6 @@ def get_oauth_router(
|
||||
)
|
||||
async def callback(
|
||||
request: Request,
|
||||
response: Response,
|
||||
access_token_state: Tuple[OAuth2Token, str] = Depends(
|
||||
oauth2_authorize_callback
|
||||
),
|
||||
@ -148,9 +147,9 @@ def get_oauth_router(
|
||||
)
|
||||
|
||||
# Authenticate
|
||||
login_return = await backend.login(strategy, user, response)
|
||||
await user_manager.on_after_login(user, request, login_return)
|
||||
return login_return
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
return response
|
||||
|
||||
return router
|
||||
|
||||
|
@ -517,8 +517,8 @@ class MockTransport(BearerTransport):
|
||||
def __init__(self, tokenUrl: str):
|
||||
super().__init__(tokenUrl)
|
||||
|
||||
async def get_logout_response(self, response: Response) -> Any:
|
||||
return None
|
||||
async def get_logout_response(self) -> Any:
|
||||
return Response()
|
||||
|
||||
@staticmethod
|
||||
def get_openapi_logout_responses_success() -> OpenAPIResponseType:
|
||||
|
@ -57,5 +57,5 @@ def backend(
|
||||
@pytest.mark.authentication
|
||||
async def test_logout(backend: AuthenticationBackend, user: UserModel):
|
||||
strategy = cast(Strategy, backend.get_strategy())
|
||||
result = await backend.logout(strategy, user, "TOKEN", Response())
|
||||
assert result is None
|
||||
result = await backend.logout(strategy, user, "TOKEN")
|
||||
assert isinstance(result, Response)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
from fastapi import Response, status
|
||||
from fastapi import status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_users.authentication.transport import (
|
||||
BearerTransport,
|
||||
@ -16,21 +17,17 @@ def bearer_transport() -> BearerTransport:
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_login_response(bearer_transport: BearerTransport):
|
||||
response = Response()
|
||||
login_response = await bearer_transport.get_login_response("TOKEN", response)
|
||||
response = await bearer_transport.get_login_response("TOKEN")
|
||||
|
||||
assert isinstance(login_response, BearerResponse)
|
||||
|
||||
assert login_response.access_token == "TOKEN"
|
||||
assert login_response.token_type == "bearer"
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.body == b'{"access_token":"TOKEN","token_type":"bearer"}'
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logout_response(bearer_transport: BearerTransport):
|
||||
response = Response()
|
||||
with pytest.raises(TransportLogoutNotSupportedError):
|
||||
await bearer_transport.get_logout_response(response)
|
||||
await bearer_transport.get_logout_response()
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
|
@ -38,10 +38,10 @@ async def test_get_login_response(cookie_transport: CookieTransport):
|
||||
secure = cookie_transport.cookie_secure
|
||||
httponly = cookie_transport.cookie_httponly
|
||||
|
||||
response = Response()
|
||||
login_response = await cookie_transport.get_login_response("TOKEN", response)
|
||||
response = await cookie_transport.get_login_response("TOKEN")
|
||||
|
||||
assert login_response is None
|
||||
assert isinstance(response, Response)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
cookies = [header for header in response.raw_headers if header[0] == b"set-cookie"]
|
||||
assert len(cookies) == 1
|
||||
@ -79,10 +79,10 @@ async def test_get_login_response(cookie_transport: CookieTransport):
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logout_response(cookie_transport: CookieTransport):
|
||||
response = Response()
|
||||
logout_response = await cookie_transport.get_logout_response(response)
|
||||
response = await cookie_transport.get_logout_response()
|
||||
|
||||
assert logout_response is None
|
||||
assert isinstance(response, Response)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
cookies = [header for header in response.raw_headers if header[0] == b"set-cookie"]
|
||||
assert len(cookies) == 1
|
||||
@ -96,7 +96,7 @@ async def test_get_logout_response(cookie_transport: CookieTransport):
|
||||
@pytest.mark.openapi
|
||||
def test_get_openapi_login_responses_success(cookie_transport: CookieTransport):
|
||||
assert cookie_transport.get_openapi_login_responses_success() == {
|
||||
status.HTTP_200_OK: {"model": None}
|
||||
status.HTTP_204_NO_CONTENT: {"model": None}
|
||||
}
|
||||
|
||||
|
||||
@ -104,5 +104,5 @@ def test_get_openapi_login_responses_success(cookie_transport: CookieTransport):
|
||||
@pytest.mark.openapi
|
||||
def test_get_openapi_logout_responses_success(cookie_transport: CookieTransport):
|
||||
assert cookie_transport.get_openapi_logout_responses_success() == {
|
||||
status.HTTP_200_OK: {"model": None}
|
||||
status.HTTP_204_NO_CONTENT: {"model": None}
|
||||
}
|
||||
|
Reference in New Issue
Block a user