diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index a7a362cf..284277c1 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -126,7 +126,6 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): oauth_client: BaseOAuth2, state_secret: SecretType, redirect_url: str = None, - after_register: Optional[Callable[[models.UD, Request], None]] = None, ) -> APIRouter: """ Return an OAuth router for a given OAuth client. @@ -135,17 +134,13 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): :param state_secret: Secret used to encode the state JWT. :param redirect_url: Optional arbitrary redirect URL for the OAuth2 flow. If not given, the URL to the callback endpoint will be generated. - :param after_register: Optional function called - after a successful registration. """ return get_oauth_router( oauth_client, self.get_user_manager, - self._user_db_model, self.authenticator, state_secret, redirect_url, - after_register, ) def get_users_router( diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index 6c9dd26c..ca63adeb 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -8,7 +8,7 @@ from pydantic import UUID4 from fastapi_users import models, password from fastapi_users.db import BaseUserDatabase from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt -from fastapi_users.password import get_password_hash +from fastapi_users.password import generate_password, get_password_hash RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset" @@ -104,6 +104,42 @@ class BaseUserManager(Generic[models.UC, models.UD]): return created_user + async def oauth_callback( + self, oauth_account: models.BaseOAuthAccount, request: Optional[Request] = None + ) -> models.UD: + try: + user = await self.get_by_oauth_account( + oauth_account.oauth_name, oauth_account.account_id + ) + except UserNotExists: + try: + # Link account + user = await self.get_by_email(oauth_account.account_email) + user.oauth_accounts.append(oauth_account) # type: ignore + await self.user_db.update(user) + except UserNotExists: + # Create account + password = generate_password() + user = self.user_db_model( + email=oauth_account.account_email, + hashed_password=get_password_hash(password), + oauth_accounts=[oauth_account], + ) + await self.user_db.create(user) + await self.on_after_register(user, request) + else: + # Update oauth + updated_oauth_accounts = [] + for existing_oauth_account in user.oauth_accounts: # type: ignore + if existing_oauth_account.account_id == oauth_account.account_id: + updated_oauth_accounts.append(oauth_account) + else: + updated_oauth_accounts.append(existing_oauth_account) + user.oauth_accounts = updated_oauth_accounts # type: ignore + await self.user_db.update(user) + + return user + async def forgot_password( self, user: models.UD, request: Optional[Request] = None ) -> None: diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index ac60b460..2c29dcdf 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional, Type +from typing import Dict, List import jwt from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status @@ -8,9 +8,8 @@ from httpx_oauth.oauth2 import BaseOAuth2 from fastapi_users import models from fastapi_users.authentication import Authenticator from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt -from fastapi_users.manager import BaseUserManager, UserManagerDependency, UserNotExists -from fastapi_users.password import generate_password, get_password_hash -from fastapi_users.router.common import ErrorCode, run_handler +from fastapi_users.manager import BaseUserManager, UserManagerDependency +from fastapi_users.router.common import ErrorCode STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state" @@ -25,11 +24,9 @@ def generate_state_token( def get_oauth_router( oauth_client: BaseOAuth2, get_user_manager: UserManagerDependency[models.UC, models.UD], - user_db_model: Type[models.UD], authenticator: Authenticator, state_secret: SecretType, redirect_url: str = None, - after_register: Optional[Callable[[models.UD, Request], None]] = None, ) -> APIRouter: """Generate a router with the OAuth routes.""" router = APIRouter() @@ -104,37 +101,7 @@ def get_oauth_router( account_email=account_email, ) - try: - user = await user_manager.get_by_oauth_account( - oauth_client.name, account_id - ) - except UserNotExists: - try: - # Link account - user = await user_manager.get_by_email(account_email) - user.oauth_accounts.append(new_oauth_account) # type: ignore - await user_manager.user_db.update(user) - except UserNotExists: - # Create account - password = generate_password() - user = user_db_model( - email=account_email, - hashed_password=get_password_hash(password), - oauth_accounts=[new_oauth_account], - ) - await user_manager.user_db.create(user) - if after_register: - await run_handler(after_register, user, request) - else: - # Update oauth - updated_oauth_accounts = [] - for oauth_account in user.oauth_accounts: # type: ignore - if oauth_account.account_id == account_id: - updated_oauth_accounts.append(new_oauth_account) - else: - updated_oauth_accounts.append(oauth_account) - user.oauth_accounts = updated_oauth_accounts # type: ignore - await user_manager.user_db.update(user) + user = await user_manager.oauth_callback(new_oauth_account, request) if not user.is_active: raise HTTPException( diff --git a/tests/conftest.py b/tests/conftest.py index f945dac3..2b6be053 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import asyncio -from typing import AsyncGenerator, List, Optional, Union +from typing import Any, AsyncGenerator, Callable, List, Optional, Union from unittest.mock import MagicMock import httpx @@ -65,16 +65,6 @@ class UserManager(BaseUserManager[UserCreate, UserDB]): reason="Password should be at least 3 characters" ) - def mock_method(self, name: str): - mock = MagicMock() - - future: asyncio.Future = asyncio.Future() - future.set_result(None) - mock.return_value = future - mock.side_effect = None - - setattr(self, name, mock) - class UserManagerMock(UserManager): get_by_email: MagicMock @@ -94,6 +84,30 @@ def event_loop(): yield loop +AsyncMethodMocker = Callable[..., MagicMock] + + +@pytest.fixture +def async_method_mocker(mocker: MockerFixture) -> AsyncMethodMocker: + def _async_method_mocker( + object: Any, + method: str, + return_value: Any = None, + ) -> MagicMock: + mock: MagicMock = mocker.MagicMock() + + future: asyncio.Future = asyncio.Future() + future.set_result(return_value) + mock.return_value = future + mock.side_effect = None + + setattr(object, method, mock) + + return mock + + return _async_method_mocker + + @pytest.fixture(params=["SECRET", SecretStr("SECRET")]) def secret(request) -> SecretType: return request.param @@ -361,33 +375,30 @@ def mock_user_db_oauth( @pytest.fixture -def get_mock_user_db(mock_user_db): - def _get_mock_user_db(): - yield mock_user_db +def make_user_manager(mocker: MockerFixture): + def _make_user_manager(user_db_model, mock_user_db): + user_manager = UserManager(user_db_model, mock_user_db) + mocker.spy(user_manager, "get_by_email") + mocker.spy(user_manager, "forgot_password") + mocker.spy(user_manager, "reset_password") + mocker.spy(user_manager, "on_after_register") + mocker.spy(user_manager, "on_after_update") + mocker.spy(user_manager, "on_after_forgot_password") + mocker.spy(user_manager, "on_after_reset_password") + mocker.spy(user_manager, "_update") + return user_manager - return _get_mock_user_db + return _make_user_manager @pytest.fixture -def get_mock_user_db_oauth(mock_user_db_oauth): - def _get_mock_user_db_oauth(): - yield mock_user_db_oauth - - return _get_mock_user_db_oauth +def user_manager(make_user_manager, mock_user_db): + return make_user_manager(UserDB, mock_user_db) @pytest.fixture -def user_manager(mocker: MockerFixture, mock_user_db): - user_manager = UserManager(UserDB, mock_user_db) - mocker.spy(user_manager, "get_by_email") - mocker.spy(user_manager, "forgot_password") - mocker.spy(user_manager, "reset_password") - mocker.spy(user_manager, "on_after_register") - mocker.spy(user_manager, "on_after_update") - mocker.spy(user_manager, "on_after_forgot_password") - mocker.spy(user_manager, "on_after_reset_password") - mocker.spy(user_manager, "_update") - return user_manager +def user_manager_oauth(make_user_manager, mock_user_db_oauth): + return make_user_manager(UserDBOAuth, mock_user_db_oauth) @pytest.fixture @@ -399,9 +410,9 @@ def get_user_manager(user_manager): @pytest.fixture -def get_user_manager_oauth(get_mock_user_db_oauth): - def _get_user_manager_oauth(user_db=Depends(get_mock_user_db_oauth)): - return UserManager(UserDBOAuth, user_db) +def get_user_manager_oauth(user_manager_oauth): + def _get_user_manager_oauth(): + return user_manager_oauth return _get_user_manager_oauth diff --git a/tests/test_manager.py b/tests/test_manager.py index 80ad6d97..8de35fc7 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,9 +1,10 @@ -from typing import Callable +from typing import Callable, cast import pytest from fastapi.security import OAuth2PasswordRequestForm from pytest_mock import MockerFixture +from fastapi_users import models from fastapi_users.jwt import decode_jwt, generate_jwt from fastapi_users.manager import ( InvalidPasswordException, @@ -13,7 +14,7 @@ from fastapi_users.manager import ( UserInactive, UserNotExists, ) -from tests.conftest import UserCreate, UserDB, UserManagerMock, UserUpdate +from tests.conftest import UserCreate, UserDB, UserDBOAuth, UserManagerMock, UserUpdate @pytest.fixture @@ -85,6 +86,61 @@ class TestCreateUser: assert user_manager.on_after_register.called is True +@pytest.mark.asyncio +class TestOAuthCallback: + async def test_existing_user_with_oauth( + self, user_manager_oauth: UserManagerMock, user_oauth: UserDBOAuth + ): + oauth_account = models.BaseOAuthAccount( + **user_oauth.oauth_accounts[0].dict(exclude={"id", "access_token"}), + access_token="UPDATED_TOKEN" + ) + user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) + + assert user.id == user_oauth.id + assert len(user.oauth_accounts) == 2 + assert user.oauth_accounts[0].oauth_name == "service1" + assert user.oauth_accounts[0].access_token == "UPDATED_TOKEN" + assert user.oauth_accounts[1].access_token == "TOKEN" + assert user.oauth_accounts[1].oauth_name == "service2" + + assert user_manager_oauth.on_after_register.called is False + + async def test_existing_user_without_oauth( + self, user_manager_oauth: UserManagerMock, superuser_oauth: UserDBOAuth + ): + oauth_account = models.BaseOAuthAccount( + oauth_name="service1", + access_token="TOKEN", + expires_at=1579000751, + account_id="superuser_oauth1", + account_email=superuser_oauth.email, + ) + user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) + + assert user.id == superuser_oauth.id + assert len(user.oauth_accounts) == 1 + assert user.oauth_accounts[0].id == oauth_account.id + + assert user_manager_oauth.on_after_register.called is False + + async def test_new_user(self, user_manager_oauth: UserManagerMock): + oauth_account = models.BaseOAuthAccount( + oauth_name="service1", + access_token="TOKEN", + expires_at=1579000751, + account_id="new_user_oauth1", + account_email="galahad@camelot.bt", + ) + user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) + + assert user.email == "galahad@camelot.bt" + assert len(user.oauth_accounts) == 1 + assert user.oauth_accounts[0].id == oauth_account.id + + assert user_manager_oauth.on_after_register.called is True + + @pytest.mark.asyncio class TestUpdateUser: async def test_safe_update(self, user: UserDB, user_manager: UserManagerMock): diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index 047c62e2..6b811a50 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -1,28 +1,18 @@ from typing import Any, AsyncGenerator, Dict, cast -from unittest.mock import MagicMock -import asynctest import httpx import pytest -from fastapi import FastAPI, Request, status +from fastapi import FastAPI, status +from httpx_oauth.oauth2 import BaseOAuth2 from fastapi_users.authentication import Authenticator -from fastapi_users.router.common import ErrorCode from fastapi_users.router.oauth import generate_state_token, get_oauth_router -from tests.conftest import MockAuthentication, UserDB - - -def after_register_sync(): - return MagicMock(return_value=None) - - -def after_register_async(): - return asynctest.CoroutineMock(return_value=None) - - -@pytest.fixture(params=[after_register_sync, after_register_async]) -def after_register(request): - return request.param() +from tests.conftest import ( + AsyncMethodMocker, + MockAuthentication, + UserDB, + UserManagerMock, +) @pytest.fixture @@ -31,7 +21,6 @@ def get_test_app_client( get_user_manager_oauth, mock_authentication, oauth_client, - after_register, get_test_client, ): async def _get_test_app_client( @@ -45,11 +34,9 @@ def get_test_app_client( oauth_router = get_oauth_router( oauth_client, get_user_manager_oauth, - UserDB, authenticator, secret, redirect_url, - after_register, ) app = FastAPI() @@ -80,64 +67,86 @@ async def test_app_client_redirect_url(get_test_app_client): @pytest.mark.asyncio class TestAuthorize: async def test_missing_authentication_backend( - self, test_app_client: httpx.AsyncClient, oauth_client + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, ): - with asynctest.patch.object(oauth_client, "get_authorization_url") as mock: - mock.return_value = "AUTHORIZATION_URL" - response = await test_app_client.get( - "/authorize", - params={"scopes": ["scope1", "scope2"]}, - ) + async_method_mocker( + oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL" + ) + + response = await test_app_client.get( + "/authorize", + params={"scopes": ["scope1", "scope2"]}, + ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY async def test_wrong_authentication_backend( - self, test_app_client: httpx.AsyncClient, oauth_client + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, ): - with asynctest.patch.object(oauth_client, "get_authorization_url") as mock: - mock.return_value = "AUTHORIZATION_URL" - response = await test_app_client.get( - "/authorize", - params={ - "authentication_backend": "foo", - "scopes": ["scope1", "scope2"], - }, - ) + async_method_mocker( + oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL" + ) + + response = await test_app_client.get( + "/authorize", + params={ + "authentication_backend": "foo", + "scopes": ["scope1", "scope2"], + }, + ) assert response.status_code == status.HTTP_400_BAD_REQUEST - async def test_success(self, test_app_client: httpx.AsyncClient, oauth_client): - with asynctest.patch.object(oauth_client, "get_authorization_url") as mock: - mock.return_value = "AUTHORIZATION_URL" - response = await test_app_client.get( - "/authorize", - params={ - "authentication_backend": "mock", - "scopes": ["scope1", "scope2"], - }, - ) + async def test_success( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + ): + get_authorization_url_mock = async_method_mocker( + oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL" + ) + + response = await test_app_client.get( + "/authorize", + params={ + "authentication_backend": "mock", + "scopes": ["scope1", "scope2"], + }, + ) assert response.status_code == status.HTTP_200_OK - mock.assert_awaited_once() + get_authorization_url_mock.assert_called_once() data = response.json() assert "authorization_url" in data async def test_with_redirect_url( - self, test_app_client_redirect_url: httpx.AsyncClient, oauth_client + self, + async_method_mocker: AsyncMethodMocker, + test_app_client_redirect_url: httpx.AsyncClient, + oauth_client: BaseOAuth2, ): - with asynctest.patch.object(oauth_client, "get_authorization_url") as mock: - mock.return_value = "AUTHORIZATION_URL" - response = await test_app_client_redirect_url.get( - "/authorize", - params={ - "authentication_backend": "mock", - "scopes": ["scope1", "scope2"], - }, - ) + get_authorization_url_mock = async_method_mocker( + oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL" + ) + + response = await test_app_client_redirect_url.get( + "/authorize", + params={ + "authentication_backend": "mock", + "scopes": ["scope1", "scope2"], + }, + ) assert response.status_code == status.HTTP_200_OK - mock.assert_awaited_once() + get_authorization_url_mock.assert_called_once() data = response.json() assert "authorization_url" in data @@ -156,196 +165,110 @@ class TestAuthorize: class TestCallback: async def test_invalid_state( self, + async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, - access_token, - oauth_client, - user_oauth, - after_register, + oauth_client: BaseOAuth2, + user_oauth: UserDB, + access_token: str, ): - with asynctest.patch.object( - oauth_client, "get_access_token" - ) as get_access_token_mock: - get_access_token_mock.return_value = access_token - with asynctest.patch.object( - oauth_client, "get_id_email" - ) as get_id_email_mock: - get_id_email_mock.return_value = ("user_oauth1", user_oauth.email) - response = await test_app_client.get( - "/callback", - params={"code": "CODE", "state": "STATE"}, - ) + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + get_id_email_mock = async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) - get_id_email_mock.assert_awaited_once_with("TOKEN") + response = await test_app_client.get( + "/callback", + params={"code": "CODE", "state": "STATE"}, + ) assert response.status_code == status.HTTP_400_BAD_REQUEST - assert after_register.called is False + get_id_email_mock.assert_called_once_with("TOKEN") - async def test_existing_user_with_oauth( + async def test_active_user( self, - mock_user_db_oauth, + async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, - access_token, - oauth_client, - user_oauth, - after_register, + oauth_client: BaseOAuth2, + user_oauth: UserDB, + user_manager_oauth: UserManagerMock, + access_token: str, ): state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET") - with asynctest.patch.object( - oauth_client, "get_access_token" - ) as get_access_token_mock: - get_access_token_mock.return_value = access_token - with asynctest.patch.object( - oauth_client, "get_id_email" - ) as get_id_email_mock: - with asynctest.patch.object( - mock_user_db_oauth, "update" - ) as user_update_mock: - get_id_email_mock.return_value = ("user_oauth1", user_oauth.email) - response = await test_app_client.get( - "/callback", - params={"code": "CODE", "state": state_jwt}, - ) + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + async_method_mocker( + user_manager_oauth, "oauth_callback", return_value=user_oauth + ) + + response = await test_app_client.get( + "/callback", + params={"code": "CODE", "state": state_jwt}, + ) + + assert response.status_code == status.HTTP_200_OK - get_id_email_mock.assert_awaited_once_with("TOKEN") - user_update_mock.assert_awaited_once() data = cast(Dict[str, Any], response.json()) - assert data["token"] == str(user_oauth.id) - assert after_register.called is False - - async def test_existing_user_without_oauth( - self, - mock_user_db_oauth, - test_app_client: httpx.AsyncClient, - access_token, - oauth_client, - superuser_oauth, - after_register, - ): - state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET") - with asynctest.patch.object( - oauth_client, "get_access_token" - ) as get_access_token_mock: - get_access_token_mock.return_value = access_token - with asynctest.patch.object( - oauth_client, "get_id_email" - ) as get_id_email_mock: - with asynctest.patch.object( - mock_user_db_oauth, "update" - ) as user_update_mock: - get_id_email_mock.return_value = ( - "superuser_oauth1", - superuser_oauth.email, - ) - response = await test_app_client.get( - "/callback", - params={"code": "CODE", "state": state_jwt}, - ) - - get_id_email_mock.assert_awaited_once_with("TOKEN") - user_update_mock.assert_awaited_once() - data = cast(Dict[str, Any], response.json()) - - assert data["token"] == str(superuser_oauth.id) - - assert after_register.called is False - - async def test_unknown_user( - self, - mock_user_db_oauth, - test_app_client: httpx.AsyncClient, - access_token, - oauth_client, - after_register, - ): - state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET") - with asynctest.patch.object( - oauth_client, "get_access_token" - ) as get_access_token_mock: - get_access_token_mock.return_value = access_token - with asynctest.patch.object( - oauth_client, "get_id_email" - ) as get_id_email_mock: - with asynctest.patch.object( - mock_user_db_oauth, "create" - ) as user_create_mock: - get_id_email_mock.return_value = ( - "unknown_user_oauth1", - "galahad@camelot.bt", - ) - response = await test_app_client.get( - "/callback", - params={"code": "CODE", "state": state_jwt}, - ) - - get_id_email_mock.assert_awaited_once_with("TOKEN") - user_create_mock.assert_awaited_once() - data = cast(Dict[str, Any], response.json()) - - assert "token" in data - - assert after_register.called is True - actual_user = after_register.call_args[0][0] - assert str(actual_user.id) == data["token"] - request = after_register.call_args[0][1] - assert isinstance(request, Request) - async def test_inactive_user( self, + async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, - access_token, - oauth_client, - inactive_user_oauth, - after_register, + oauth_client: BaseOAuth2, + inactive_user_oauth: UserDB, + user_manager_oauth: UserManagerMock, + access_token: str, ): state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET") - with asynctest.patch.object( - oauth_client, "get_access_token" - ) as get_access_token_mock: - get_access_token_mock.return_value = access_token - with asynctest.patch.object( - oauth_client, "get_id_email" - ) as get_id_email_mock: - get_id_email_mock.return_value = ( - "inactive_user_oauth1", - inactive_user_oauth.email, - ) - response = await test_app_client.get( - "/callback", - params={"code": "CODE", "state": state_jwt}, - ) + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + async_method_mocker( + oauth_client, + "get_id_email", + return_value=("user_oauth1", inactive_user_oauth.email), + ) + async_method_mocker( + user_manager_oauth, "oauth_callback", return_value=inactive_user_oauth + ) + + response = await test_app_client.get( + "/callback", + params={"code": "CODE", "state": state_jwt}, + ) assert response.status_code == status.HTTP_400_BAD_REQUEST - data = cast(Dict[str, Any], response.json()) - assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS - - assert after_register.called is False async def test_redirect_url_router( self, + async_method_mocker: AsyncMethodMocker, test_app_client_redirect_url: httpx.AsyncClient, - access_token, - oauth_client, - user_oauth, + oauth_client: BaseOAuth2, + user_oauth: UserDB, + user_manager_oauth: UserManagerMock, + access_token: str, ): state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET") - with asynctest.patch.object( - oauth_client, "get_access_token" - ) as get_access_token_mock: - get_access_token_mock.return_value = access_token - with asynctest.patch.object( - oauth_client, "get_id_email" - ) as get_id_email_mock: - get_id_email_mock.return_value = ("user_oauth1", user_oauth.email) - response = await test_app_client_redirect_url.get( - "/callback", - params={"code": "CODE", "state": state_jwt}, - ) + get_access_token_mock = async_method_mocker( + oauth_client, "get_access_token", return_value=access_token + ) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + async_method_mocker( + user_manager_oauth, "oauth_callback", return_value=user_oauth + ) - get_access_token_mock.assert_awaited_once_with( + response = await test_app_client_redirect_url.get( + "/callback", + params={"code": "CODE", "state": state_jwt}, + ) + + assert response.status_code == status.HTTP_200_OK + + get_access_token_mock.assert_called_once_with( "CODE", "http://www.tintagel.bt/callback" ) - data = cast(Dict[str, Any], response.json()) + data = cast(Dict[str, Any], response.json()) assert data["token"] == str(user_oauth.id) diff --git a/tests/test_router_reset.py b/tests/test_router_reset.py index 1005a815..e761cb7a 100644 --- a/tests/test_router_reset.py +++ b/tests/test_router_reset.py @@ -11,7 +11,7 @@ from fastapi_users.manager import ( UserNotExists, ) from fastapi_users.router import ErrorCode, get_reset_password_router -from tests.conftest import UserManagerMock +from tests.conftest import AsyncMethodMocker, UserManagerMock @pytest.fixture @@ -57,10 +57,11 @@ class TestForgotPassword: async def test_existing_user( self, + async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, ): - user_manager.mock_method("forgot_password") + async_method_mocker(user_manager, "forgot_password", return_value=None) json = {"email": "king.arthur@camelot.bt"} response = await test_app_client.post("/forgot-password", json=json) assert response.status_code == status.HTTP_202_ACCEPTED @@ -139,10 +140,11 @@ class TestResetPassword: async def test_valid_user_password( self, + async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, ): - user_manager.mock_method("reset_password") + async_method_mocker(user_manager, "reset_password", return_value=None) json = {"token": "foo", "password": "guinevere"} response = await test_app_client.post("/reset-password", json=json) assert response.status_code == status.HTTP_200_OK