From f4338ca3dfd39d1e0326ac95fb47469ff3198850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 20 Jun 2022 16:55:58 +0200 Subject: [PATCH] Handle UserAlreadyExists error in oauth callback route --- fastapi_users/router/common.py | 1 + fastapi_users/router/oauth.py | 27 +++++++++++++++++---------- tests/test_router_oauth.py | 31 ++++++++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/fastapi_users/router/common.py b/fastapi_users/router/common.py index 567055f7..944068fc 100644 --- a/fastapi_users/router/common.py +++ b/fastapi_users/router/common.py @@ -16,6 +16,7 @@ class ErrorCodeReasonModel(BaseModel): class ErrorCode(str, Enum): REGISTER_INVALID_PASSWORD = "REGISTER_INVALID_PASSWORD" REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS" + OAUTH_USER_ALREADY_EXISTS = "OAUTH_USER_ALREADY_EXISTS" LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS" LOGIN_USER_NOT_VERIFIED = "LOGIN_USER_NOT_VERIFIED" RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN" diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 2a54882f..16e96d00 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from fastapi_users import models, schemas from fastapi_users.authentication import AuthenticationBackend, Authenticator, Strategy +from fastapi_users.exceptions import UserAlreadyExists from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import BaseUserManager, UserManagerDependency from fastapi_users.router.common import ErrorCode, ErrorModel @@ -115,16 +116,22 @@ def get_oauth_router( except jwt.DecodeError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - user = await user_manager.oauth_callback( - oauth_client.name, - token["access_token"], - account_id, - account_email, - token.get("expires_at"), - token.get("refresh_token"), - request, - associate_by_email=associate_by_email, - ) + try: + user = await user_manager.oauth_callback( + oauth_client.name, + token["access_token"], + account_id, + account_email, + token.get("expires_at"), + token.get("refresh_token"), + request, + associate_by_email=associate_by_email, + ) + except UserAlreadyExists: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS, + ) if not user.is_active: raise HTTPException( diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index e18bce27..c13c164d 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -5,7 +5,9 @@ import pytest from fastapi import FastAPI, status from httpx_oauth.oauth2 import BaseOAuth2, OAuth2 +from fastapi_users import exceptions from fastapi_users.authentication import AuthenticationBackend, Authenticator +from fastapi_users.router.common import ErrorCode from fastapi_users.router.oauth import ( generate_state_token, get_oauth_associate_router, @@ -16,7 +18,6 @@ from tests.conftest import ( User, UserManagerMock, UserModel, - UserOAuth, UserOAuthModel, ) @@ -159,6 +160,34 @@ class TestCallback: get_id_email_mock.assert_called_once_with("TOKEN") + async def test_already_exists_error( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + user_manager_oauth: UserManagerMock, + access_token: str, + ): + state_jwt = generate_state_token({}, "SECRET") + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + async_method_mocker( + user_manager_oauth, "oauth_callback" + ).side_effect = exceptions.UserAlreadyExists + + response = await test_app_client.get( + "/oauth/callback", + params={"code": "CODE", "state": state_jwt}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.OAUTH_USER_ALREADY_EXISTS + async def test_active_user( self, async_method_mocker: AsyncMethodMocker,