diff --git a/docs/usage/routes.md b/docs/usage/routes.md index f30a390a..6f4c8f51 100644 --- a/docs/usage/routes.md +++ b/docs/usage/routes.md @@ -229,9 +229,7 @@ Return the authorization URL for the OAuth service where you should redirect you ``` !!! fail "`422 Validation Error`" - -!!! fail "`400 Bad Request`" - Unknown authentication backend. + Invalid parameters - e.g. unknown authentication backend. ### `GET /callback` diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 6529a3f5..b9c8330d 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -79,3 +79,7 @@ class BaseOAuthAccountMixin(BaseModel): """Adds OAuth accounts list to a User model.""" oauth_accounts: List[BaseOAuthAccount] = [] + + +class OAuth2AuthorizeResponse(BaseModel): + authorization_url: str diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 1bbdbf05..1553635d 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -1,3 +1,4 @@ +import enum from typing import Dict, List import jwt @@ -43,27 +44,28 @@ def get_oauth_router( route_name=callback_route_name, ) - @router.get("/authorize", name="oauth:authorize") + AuthenticationBackendName: enum.EnumMeta = enum.Enum( + "AuthenticationBackendName", + {backend.name: backend.name for backend in authenticator.backends}, + ) + + @router.get( + "/authorize", + name="oauth:authorize", + response_model=models.OAuth2AuthorizeResponse, + ) async def authorize( request: Request, - authentication_backend: str, + authentication_backend: AuthenticationBackendName, scopes: List[str] = Query(None), ): - # Check that authentication_backend exists - backend_exists = any( - backend.name == authentication_backend for backend in authenticator.backends - ) - - if not backend_exists: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - if redirect_url is not None: authorize_redirect_url = redirect_url else: authorize_redirect_url = request.url_for(callback_route_name) state_data = { - "authentication_backend": authentication_backend, + "authentication_backend": str(authentication_backend), } state = generate_state_token(state_data, state_secret) authorization_url = await oauth_client.get_authorization_url( diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index 53b4e2b1..d4c87450 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -101,7 +101,7 @@ class TestAuthorize: }, ) - assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY async def test_success( self,