From 2e8f1f2eb2d8bdaf86ab82aab78ba01eeff76b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maty=C3=A1=C5=A1=20Richter?= Date: Wed, 29 Dec 2021 13:25:15 +0100 Subject: [PATCH] Fixed #823 (#824) * Added a failing test for the multi-oauth-router issue * Fixed the #823 regression. Using a regex for the backend name validation instead of an enum. * Fixed formatting errors * Moved the `AuthenticationBackendName` enum to `Authenticator` This prevents an issue with OpenAPI schema generation caused by two endpoints accepting a parameter with a duplicate name. --- fastapi_users/authentication/__init__.py | 6 ++++++ fastapi_users/router/oauth.py | 8 +------- tests/test_openapi.py | 19 +++++++++++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/fastapi_users/authentication/__init__.py b/fastapi_users/authentication/__init__.py index 51b938e4..aaa948ff 100644 --- a/fastapi_users/authentication/__init__.py +++ b/fastapi_users/authentication/__init__.py @@ -1,3 +1,4 @@ +import enum import re from inspect import Parameter, Signature from typing import Callable, Optional, Sequence @@ -42,6 +43,7 @@ class Authenticator: """ backends: Sequence[BaseAuthentication] + backends_enum: enum.Enum def __init__( self, @@ -49,6 +51,10 @@ class Authenticator: get_user_manager: UserManagerDependency[models.UC, models.UD], ): self.backends = backends + self.backends_enum = enum.Enum( # type: ignore + "AuthenticationBackendName", + {backend.name: backend.name for backend in backends}, # type: ignore + ) self.get_user_manager = get_user_manager def current_user( diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index e15d297c..ccd3f5d9 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -1,4 +1,3 @@ -import enum from typing import Dict, List import jwt @@ -44,11 +43,6 @@ def get_oauth_router( route_name=callback_route_name, ) - AuthenticationBackendName: enum.EnumMeta = enum.Enum( - "AuthenticationBackendName", - {backend.name: backend.name for backend in authenticator.backends}, - ) - @router.get( "/authorize", name="oauth:authorize", @@ -56,7 +50,7 @@ def get_oauth_router( ) async def authorize( request: Request, - authentication_backend: AuthenticationBackendName, + authentication_backend: authenticator.backends_enum, # type: ignore scopes: List[str] = Query(None), ): if redirect_url is not None: diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 921d4627..a7649903 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -1,6 +1,7 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient +from httpx_oauth.clients.facebook import FacebookOAuth2 from httpx_oauth.clients.google import GoogleOAuth2 import fastapi_users.authentication @@ -134,3 +135,21 @@ class TestOAuth2: def test_google_callback_status_codes(self, get_openapi_dict): route = get_openapi_dict["paths"]["/callback"]["get"] assert list(route["responses"].keys()) == ["200", "400", "422"] + + def test_two_oauth_routers(self): + a = FastAPI() + a.include_router( + users.get_oauth_router( + GoogleOAuth2(client_id="1234", client_secret="4321"), + state_secret="secret", + ), + prefix="/google", + ) + a.include_router( + users.get_oauth_router( + FacebookOAuth2(client_id="1234", client_secret="4321"), + state_secret="secret", + ), + prefix="/facebook", + ) + assert TestClient(a).get("/openapi.json").status_code == 200