mirror of
				https://github.com/fastapi-users/fastapi-users.git
				synced 2025-11-04 14:45:50 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			338 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			338 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import Optional
from unittest.mock import MagicMock

import asynctest
import pytest
from fastapi import FastAPI
from starlette import status
from starlette.requests import Request
from starlette.testclient import TestClient

from fastapi_users.authentication import Authenticator
from fastapi_users.router.common import ErrorCode, Event
from fastapi_users.router.oauth import generate_state_token, get_oauth_router
from tests.conftest import MockAuthentication, UserDB
 | 
						|
 | 
						|
 | 
						|
# Secret handed to get_oauth_router by the get_test_app_client fixture.
SECRET = "SECRET"
 | 
						|
 | 
						|
 | 
						|
def event_handler_sync():
    """Build a synchronous event-handler stand-in.

    Returns:
        A fresh ``MagicMock`` that returns ``None`` when called, so tests can
        inspect ``called`` / ``call_args`` afterwards.
    """
    return MagicMock(return_value=None)
 | 
						|
 | 
						|
 | 
						|
def event_handler_async():
    """Build an asynchronous event-handler stand-in.

    Returns:
        A fresh ``asynctest.CoroutineMock`` configured with
        ``return_value=None``.
    """
    return asynctest.CoroutineMock(return_value=None)
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture(params=[event_handler_sync, event_handler_async])
def event_handler(request):
    """Provide an event-handler mock, parametrized over sync and async.

    Each dependent test runs once per factory listed in ``params``; the
    factory is called here so every test gets its own mock instance.
    """
    handler_factory = request.param
    return handler_factory()
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture()
def get_test_app_client(
    mock_user_db_oauth, mock_authentication, oauth_client, event_handler
):
    """Provide a factory that builds a ``TestClient`` around the OAuth router.

    The returned callable takes an optional ``redirect_url`` which is
    forwarded to ``get_oauth_router``. Every app it builds registers two
    authentication backends (the ``mock_authentication`` fixture and a second
    ``MockAuthentication`` named ``"mock-bis"``) and attaches
    ``event_handler`` to ``Event.ON_AFTER_REGISTER``.
    """

    def _get_test_app_client(redirect_url: Optional[str] = None) -> TestClient:
        # A second backend lets tests select a backend by name.
        mock_authentication_bis = MockAuthentication(name="mock-bis")
        authenticator = Authenticator(
            [mock_authentication, mock_authentication_bis], mock_user_db_oauth
        )

        oauth_router = get_oauth_router(
            oauth_client,
            mock_user_db_oauth,
            UserDB,
            authenticator,
            SECRET,
            redirect_url,
        )
        oauth_router.add_event_handler(Event.ON_AFTER_REGISTER, event_handler)

        app = FastAPI()
        app.include_router(oauth_router)

        return TestClient(app)

    return _get_test_app_client
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture()
def test_app_client(get_test_app_client):
    """Provide a ``TestClient`` built by the factory with its default arguments."""
    return get_test_app_client()
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture()
def test_app_client_redirect_url(get_test_app_client):
    """Provide a ``TestClient`` whose router is given an explicit redirect URL."""
    return get_test_app_client("http://www.tintagel.bt/callback")
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.router
@pytest.mark.oauth
class TestAuthorize:
    """Tests for the ``/authorize`` route of the OAuth router."""

    def test_missing_authentication_backend(
        self, test_app_client: TestClient, oauth_client
    ):
        """Omitting ``authentication_backend`` is answered with a 422."""
        with asynctest.patch.object(
            oauth_client, "get_authorization_url"
        ) as get_authorization_url_mock:
            get_authorization_url_mock.return_value = "AUTHORIZATION_URL"
            response = test_app_client.get(
                "/authorize", params={"scopes": ["scope1", "scope2"]}
            )

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    def test_wrong_authentication_backend(
        self, test_app_client: TestClient, oauth_client
    ):
        """A backend name other than those registered is answered with a 400."""
        query = {
            "authentication_backend": "foo",
            "scopes": ["scope1", "scope2"],
        }
        with asynctest.patch.object(
            oauth_client, "get_authorization_url"
        ) as get_authorization_url_mock:
            get_authorization_url_mock.return_value = "AUTHORIZATION_URL"
            response = test_app_client.get("/authorize", params=query)

        assert response.status_code == status.HTTP_400_BAD_REQUEST

    def test_success(self, test_app_client: TestClient, oauth_client):
        """A valid request awaits the client once and returns the URL key."""
        query = {
            "authentication_backend": "mock",
            "scopes": ["scope1", "scope2"],
        }
        with asynctest.patch.object(
            oauth_client, "get_authorization_url"
        ) as get_authorization_url_mock:
            get_authorization_url_mock.return_value = "AUTHORIZATION_URL"
            response = test_app_client.get("/authorize", params=query)

        assert response.status_code == status.HTTP_200_OK
        get_authorization_url_mock.assert_awaited_once()

        data = response.json()
        assert "authorization_url" in data

    def test_with_redirect_url(
        self, test_app_client_redirect_url: TestClient, oauth_client
    ):
        """Same as ``test_success`` but against the redirect-URL client."""
        query = {
            "authentication_backend": "mock",
            "scopes": ["scope1", "scope2"],
        }
        with asynctest.patch.object(
            oauth_client, "get_authorization_url"
        ) as get_authorization_url_mock:
            get_authorization_url_mock.return_value = "AUTHORIZATION_URL"
            response = test_app_client_redirect_url.get("/authorize", params=query)

        assert response.status_code == status.HTTP_200_OK
        get_authorization_url_mock.assert_awaited_once()

        data = response.json()
        assert "authorization_url" in data
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.router
@pytest.mark.oauth
class TestCallback:
    """Tests for the ``/callback`` route of the OAuth router.

    State tokens are signed with the module-level ``SECRET`` so they always
    match the secret the router under test was built with.
    """

    def test_invalid_state(
        self, test_app_client: TestClient, oauth_client, user_oauth, event_handler
    ):
        """A ``state`` that is not a generated token yields a 400.

        The identity lookup is still awaited once with the access token, and
        the register event handler is not called.
        """
        with asynctest.patch.object(
            oauth_client, "get_access_token"
        ) as get_access_token_mock, asynctest.patch.object(
            oauth_client, "get_id_email"
        ) as get_id_email_mock:
            get_access_token_mock.return_value = {
                "access_token": "TOKEN",
                "expires_at": 1579179542,
            }
            get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
            response = test_app_client.get(
                "/callback", params={"code": "CODE", "state": "STATE"}
            )

        get_id_email_mock.assert_awaited_once_with("TOKEN")
        assert response.status_code == status.HTTP_400_BAD_REQUEST

        assert event_handler.called is False

    def test_existing_user_with_oauth(
        self,
        mock_user_db_oauth,
        test_app_client: TestClient,
        oauth_client,
        user_oauth,
        event_handler,
    ):
        """Callback for ``user_oauth`` updates the user and logs them in.

        The response token equals ``user_oauth.id`` and the register event
        handler is not called.
        """
        state_jwt = generate_state_token({"authentication_backend": "mock"}, SECRET)
        with asynctest.patch.object(
            oauth_client, "get_access_token"
        ) as get_access_token_mock, asynctest.patch.object(
            oauth_client, "get_id_email"
        ) as get_id_email_mock, asynctest.patch.object(
            mock_user_db_oauth, "update"
        ) as user_update_mock:
            get_access_token_mock.return_value = {
                "access_token": "TOKEN",
                "expires_at": 1579179542,
            }
            get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
            response = test_app_client.get(
                "/callback", params={"code": "CODE", "state": state_jwt}
            )

        get_id_email_mock.assert_awaited_once_with("TOKEN")
        user_update_mock.assert_awaited_once()
        data = response.json()

        assert data["token"] == user_oauth.id

        assert event_handler.called is False

    def test_existing_user_without_oauth(
        self,
        mock_user_db_oauth,
        test_app_client: TestClient,
        oauth_client,
        superuser_oauth,
        event_handler,
    ):
        """Callback matching ``superuser_oauth``'s email updates that user.

        The response token equals ``superuser_oauth.id`` and the register
        event handler is not called.
        """
        state_jwt = generate_state_token({"authentication_backend": "mock"}, SECRET)
        with asynctest.patch.object(
            oauth_client, "get_access_token"
        ) as get_access_token_mock, asynctest.patch.object(
            oauth_client, "get_id_email"
        ) as get_id_email_mock, asynctest.patch.object(
            mock_user_db_oauth, "update"
        ) as user_update_mock:
            get_access_token_mock.return_value = {
                "access_token": "TOKEN",
                "expires_at": 1579179542,
            }
            get_id_email_mock.return_value = (
                "superuser_oauth1",
                superuser_oauth.email,
            )
            response = test_app_client.get(
                "/callback", params={"code": "CODE", "state": state_jwt}
            )

        get_id_email_mock.assert_awaited_once_with("TOKEN")
        user_update_mock.assert_awaited_once()
        data = response.json()

        assert data["token"] == superuser_oauth.id

        assert event_handler.called is False

    def test_unknown_user(
        self,
        mock_user_db_oauth,
        test_app_client: TestClient,
        oauth_client,
        event_handler,
    ):
        """Callback for an unseen account id/email creates a user.

        The register event handler is called with the created user (whose
        ``id`` is the response token) and the ``Request``.
        """
        state_jwt = generate_state_token({"authentication_backend": "mock"}, SECRET)
        with asynctest.patch.object(
            oauth_client, "get_access_token"
        ) as get_access_token_mock, asynctest.patch.object(
            oauth_client, "get_id_email"
        ) as get_id_email_mock, asynctest.patch.object(
            mock_user_db_oauth, "create"
        ) as user_create_mock:
            get_access_token_mock.return_value = {
                "access_token": "TOKEN",
                "expires_at": 1579179542,
            }
            get_id_email_mock.return_value = (
                "unknown_user_oauth1",
                "galahad@camelot.bt",
            )
            response = test_app_client.get(
                "/callback", params={"code": "CODE", "state": state_jwt}
            )

        get_id_email_mock.assert_awaited_once_with("TOKEN")
        user_create_mock.assert_awaited_once()
        data = response.json()

        assert "token" in data

        assert event_handler.called is True
        # The handler receives (user, request) as positional arguments.
        handler_args = event_handler.call_args[0]
        actual_user = handler_args[0]
        assert actual_user.id == data["token"]
        request = handler_args[1]
        assert isinstance(request, Request)

    def test_inactive_user(
        self,
        mock_user_db_oauth,
        test_app_client: TestClient,
        oauth_client,
        inactive_user_oauth,
        event_handler,
    ):
        """Callback for ``inactive_user_oauth`` is rejected.

        The response is a 400 with ``ErrorCode.LOGIN_BAD_CREDENTIALS`` and
        the register event handler is not called.
        """
        state_jwt = generate_state_token({"authentication_backend": "mock"}, SECRET)
        with asynctest.patch.object(
            oauth_client, "get_access_token"
        ) as get_access_token_mock, asynctest.patch.object(
            oauth_client, "get_id_email"
        ) as get_id_email_mock:
            get_access_token_mock.return_value = {
                "access_token": "TOKEN",
                "expires_at": 1579179542,
            }
            get_id_email_mock.return_value = (
                "inactive_user_oauth1",
                inactive_user_oauth.email,
            )
            response = test_app_client.get(
                "/callback", params={"code": "CODE", "state": state_jwt}
            )

        assert response.status_code == status.HTTP_400_BAD_REQUEST
        assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS

        assert event_handler.called is False

    def test_redirect_url_router(
        self,
        mock_user_db_oauth,
        test_app_client_redirect_url: TestClient,
        oauth_client,
        user_oauth,
    ):
        """The router's redirect URL is passed to the token exchange.

        ``get_access_token`` must be awaited once with the code and the URL
        the redirect-URL client was built with.
        """
        state_jwt = generate_state_token({"authentication_backend": "mock"}, SECRET)
        with asynctest.patch.object(
            oauth_client, "get_access_token"
        ) as get_access_token_mock, asynctest.patch.object(
            oauth_client, "get_id_email"
        ) as get_id_email_mock:
            get_access_token_mock.return_value = {
                "access_token": "TOKEN",
                "expires_at": 1579179542,
            }
            get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
            response = test_app_client_redirect_url.get(
                "/callback", params={"code": "CODE", "state": state_jwt}
            )

        get_access_token_mock.assert_awaited_once_with(
            "CODE", "http://www.tintagel.bt/callback"
        )
        data = response.json()

        assert data["token"] == user_oauth.id
 |