mirror of
				https://github.com/fastapi-users/fastapi-users.git
				synced 2025-11-04 14:45:50 +08:00 
			
		
		
		
* Added a `response_model` to the OAuth2 authorize route.
* Changed the type of the OAuth2 "authorize" route's backend parameter: when the OAuth2 router is created, an enum is generated for the `authentication_backend` parameter, which lets the OpenAPI docs be generated correctly.
* Reformatted with `make format`.
* Updated the OAuth2 authorize URL docs: `/authorize` now returns 422 instead of 400 for an invalid authentication backend.
		
			
				
	
	
		
			303 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			303 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import Any, AsyncGenerator, Dict, Optional, cast
 | 
						|
 | 
						|
import httpx
 | 
						|
import pytest
 | 
						|
from fastapi import FastAPI, status
 | 
						|
from httpx_oauth.oauth2 import BaseOAuth2
 | 
						|
 | 
						|
from fastapi_users.authentication import Authenticator
 | 
						|
from fastapi_users.router.oauth import generate_state_token, get_oauth_router
 | 
						|
from tests.conftest import (
 | 
						|
    AsyncMethodMocker,
 | 
						|
    MockAuthentication,
 | 
						|
    UserDB,
 | 
						|
    UserManagerMock,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def get_test_app_client(
    secret,
    get_user_manager_oauth,
    mock_authentication,
    oauth_client,
    get_test_client,
):
    """Return a factory that yields an HTTP client for an app mounting the OAuth router.

    The factory is an async generator function taking an optional
    ``redirect_url``, which is passed straight through to ``get_oauth_router``
    (how ``None`` is interpreted is up to ``get_oauth_router``).
    """

    async def _get_test_app_client(
        redirect_url: Optional[str] = None,
    ) -> AsyncGenerator[httpx.AsyncClient, None]:
        # Register a second backend ("mock-bis") next to the "mock" one so the
        # router is built with more than one authentication backend.
        mock_authentication_bis = MockAuthentication(name="mock-bis")
        authenticator = Authenticator(
            [mock_authentication, mock_authentication_bis], get_user_manager_oauth
        )

        oauth_router = get_oauth_router(
            oauth_client,
            get_user_manager_oauth,
            authenticator,
            secret,
            redirect_url,
        )

        app = FastAPI()
        app.include_router(oauth_router)

        async for client in get_test_client(app):
            yield client

    return _get_test_app_client
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
async def test_app_client(get_test_app_client):
    """Yield an HTTP client for an app whose OAuth router has no explicit redirect URL.

    No ``pytest.mark.asyncio`` here: marks applied to fixtures have no effect
    (pytest deprecates doing so); running this async fixture is handled by the
    pytest-asyncio setup for the async tests that request it.
    """
    async for client in get_test_app_client():
        yield client
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
async def test_app_client_redirect_url(get_test_app_client):
    """Yield an HTTP client for an app whose OAuth router uses a fixed redirect URL.

    The redirect URL is ``http://www.tintagel.bt/callback``. No
    ``pytest.mark.asyncio`` here: marks applied to fixtures have no effect
    (pytest deprecates doing so).
    """
    async for client in get_test_app_client("http://www.tintagel.bt/callback"):
        yield client
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.router
@pytest.mark.oauth
@pytest.mark.asyncio
class TestAuthorize:
    """Tests for the OAuth router's ``/authorize`` route."""

    async def test_missing_authentication_backend(
        self,
        async_method_mocker: AsyncMethodMocker,
        test_app_client: httpx.AsyncClient,
        oauth_client: BaseOAuth2,
    ):
        """Omitting ``authentication_backend`` is rejected with 422."""
        get_authorization_url_mock = async_method_mocker(
            oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
        )

        response = await test_app_client.get(
            "/authorize",
            params={"scopes": ["scope1", "scope2"]},
        )

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
        # The request is rejected during parameter validation, so the
        # provider must never have been asked for an authorization URL.
        get_authorization_url_mock.assert_not_called()

    async def test_wrong_authentication_backend(
        self,
        async_method_mocker: AsyncMethodMocker,
        test_app_client: httpx.AsyncClient,
        oauth_client: BaseOAuth2,
    ):
        """An unknown ``authentication_backend`` value is rejected with 422."""
        get_authorization_url_mock = async_method_mocker(
            oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
        )

        response = await test_app_client.get(
            "/authorize",
            params={
                "authentication_backend": "foo",
                "scopes": ["scope1", "scope2"],
            },
        )

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
        # Same as above: validation fails before the route handler runs.
        get_authorization_url_mock.assert_not_called()

    async def test_success(
        self,
        async_method_mocker: AsyncMethodMocker,
        test_app_client: httpx.AsyncClient,
        oauth_client: BaseOAuth2,
    ):
        """A known backend yields 200 and a JSON body with ``authorization_url``."""
        get_authorization_url_mock = async_method_mocker(
            oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
        )

        response = await test_app_client.get(
            "/authorize",
            params={
                "authentication_backend": "mock",
                "scopes": ["scope1", "scope2"],
            },
        )

        assert response.status_code == status.HTTP_200_OK
        get_authorization_url_mock.assert_called_once()

        data = response.json()
        assert "authorization_url" in data

    async def test_with_redirect_url(
        self,
        async_method_mocker: AsyncMethodMocker,
        test_app_client_redirect_url: httpx.AsyncClient,
        oauth_client: BaseOAuth2,
    ):
        """Same as ``test_success`` but against a router built with a redirect URL."""
        get_authorization_url_mock = async_method_mocker(
            oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
        )

        response = await test_app_client_redirect_url.get(
            "/authorize",
            params={
                "authentication_backend": "mock",
                "scopes": ["scope1", "scope2"],
            },
        )

        assert response.status_code == status.HTTP_200_OK
        get_authorization_url_mock.assert_called_once()

        data = response.json()
        assert "authorization_url" in data
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.router
@pytest.mark.oauth
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "access_token",
    [
        {"access_token": "TOKEN", "expires_at": 1579179542},
        {"access_token": "TOKEN"},
    ],
)
class TestCallback:
    """Tests for the OAuth router's ``/callback`` route.

    Every test runs once per ``access_token`` payload above (with and without
    ``expires_at``). State tokens are signed with the same ``secret`` fixture
    that ``get_test_app_client`` passes to ``get_oauth_router``, rather than a
    hard-coded literal, so they match the router's secret whatever its value.
    """

    async def test_invalid_state(
        self,
        async_method_mocker: AsyncMethodMocker,
        test_app_client: httpx.AsyncClient,
        oauth_client: BaseOAuth2,
        user_oauth: UserDB,
        access_token: Dict[str, Any],
    ):
        """A state value that is not a valid token yields 400."""
        async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
        get_id_email_mock = async_method_mocker(
            oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
        )

        response = await test_app_client.get(
            "/callback",
            params={"code": "CODE", "state": "STATE"},
        )
        assert response.status_code == status.HTTP_400_BAD_REQUEST

        get_id_email_mock.assert_called_once_with("TOKEN")

    async def test_active_user(
        self,
        async_method_mocker: AsyncMethodMocker,
        test_app_client: httpx.AsyncClient,
        oauth_client: BaseOAuth2,
        user_oauth: UserDB,
        user_manager_oauth: UserManagerMock,
        access_token: Dict[str, Any],
        secret,
    ):
        """An active user returned by ``oauth_callback`` gets a token (200)."""
        state_jwt = generate_state_token({"authentication_backend": "mock"}, secret)
        async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
        async_method_mocker(
            oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
        )
        async_method_mocker(
            user_manager_oauth, "oauth_callback", return_value=user_oauth
        )

        response = await test_app_client.get(
            "/callback",
            params={"code": "CODE", "state": state_jwt},
        )

        assert response.status_code == status.HTTP_200_OK

        data = cast(Dict[str, Any], response.json())
        assert data["token"] == str(user_oauth.id)

    async def test_inactive_user(
        self,
        async_method_mocker: AsyncMethodMocker,
        test_app_client: httpx.AsyncClient,
        oauth_client: BaseOAuth2,
        inactive_user_oauth: UserDB,
        user_manager_oauth: UserManagerMock,
        access_token: Dict[str, Any],
        secret,
    ):
        """An inactive user returned by ``oauth_callback`` is rejected with 400."""
        state_jwt = generate_state_token({"authentication_backend": "mock"}, secret)
        async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
        async_method_mocker(
            oauth_client,
            "get_id_email",
            return_value=("user_oauth1", inactive_user_oauth.email),
        )
        async_method_mocker(
            user_manager_oauth, "oauth_callback", return_value=inactive_user_oauth
        )

        response = await test_app_client.get(
            "/callback",
            params={"code": "CODE", "state": state_jwt},
        )

        assert response.status_code == status.HTTP_400_BAD_REQUEST

    async def test_redirect_url_router(
        self,
        async_method_mocker: AsyncMethodMocker,
        test_app_client_redirect_url: httpx.AsyncClient,
        oauth_client: BaseOAuth2,
        user_oauth: UserDB,
        user_manager_oauth: UserManagerMock,
        access_token: Dict[str, Any],
        secret,
    ):
        """The configured redirect URL is passed to ``get_access_token``."""
        state_jwt = generate_state_token({"authentication_backend": "mock"}, secret)
        get_access_token_mock = async_method_mocker(
            oauth_client, "get_access_token", return_value=access_token
        )
        async_method_mocker(
            oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
        )
        async_method_mocker(
            user_manager_oauth, "oauth_callback", return_value=user_oauth
        )

        response = await test_app_client_redirect_url.get(
            "/callback",
            params={"code": "CODE", "state": state_jwt},
        )

        assert response.status_code == status.HTTP_200_OK

        get_access_token_mock.assert_called_once_with(
            "CODE", "http://www.tintagel.bt/callback"
        )

        data = cast(Dict[str, Any], response.json())
        assert data["token"] == str(user_oauth.id)
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.router
@pytest.mark.oauth
@pytest.mark.asyncio
async def test_oauth_authorize_namespace(
    secret,
    get_user_manager_oauth,
    mock_authentication,
    oauth_client,
    get_test_client,
    redirect_url: Optional[str] = None,
):
    """The authorize route is registered under the ``oauth:authorize`` name.

    Carries the same ``router``/``oauth`` markers as the other tests in this
    module so marker-based selection includes it. ``redirect_url`` has a
    default, so pytest does not treat it as a fixture; it is always ``None``.
    """
    mock_authentication_bis = MockAuthentication(name="mock-bis")
    authenticator = Authenticator(
        [mock_authentication, mock_authentication_bis], get_user_manager_oauth
    )

    app = FastAPI()
    app.include_router(
        get_oauth_router(
            oauth_client,
            get_user_manager_oauth,
            authenticator,
            secret,
            redirect_url,
        )
    )
    assert app.url_path_for("oauth:authorize") == "/authorize"
 |