From 586344577430a59dcd1f55585b5423838fe23ef7 Mon Sep 17 00:00:00 2001 From: Muhammad Daffa Dinaya Date: Thu, 2 Jan 2025 12:19:48 +0700 Subject: [PATCH] fix: add expired token error on oauth callback --- fastapi_users/router/common.py | 3 + fastapi_users/router/oauth.py | 36 +++++++++++- tests/test_router_oauth.py | 102 +++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 2 deletions(-) diff --git a/fastapi_users/router/common.py b/fastapi_users/router/common.py index 0184758b..7200b9f5 100644 --- a/fastapi_users/router/common.py +++ b/fastapi_users/router/common.py @@ -25,3 +25,6 @@ class ErrorCode(str, Enum): VERIFY_USER_ALREADY_VERIFIED = "VERIFY_USER_ALREADY_VERIFIED" UPDATE_USER_EMAIL_ALREADY_EXISTS = "UPDATE_USER_EMAIL_ALREADY_EXISTS" UPDATE_USER_INVALID_PASSWORD = "UPDATE_USER_INVALID_PASSWORD" + ACCESS_TOKEN_ALREADY_EXPIRED = "ACCESS_TOKEN_ALREADY_EXPIRED" + ACCESS_TOKEN_DECODE_ERROR = "ACCESS_TOKEN_DECODE_ERROR" + diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 3640aa65..a4deb0db 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -90,6 +90,14 @@ def get_oauth_router( "summary": "User is inactive.", "value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS}, }, + ErrorCode.ACCESS_TOKEN_DECODE_ERROR: { + "summary": "Access token is error.", + "value": {"detail": ErrorCode.ACCESS_TOKEN_DECODE_ERROR}, + }, + ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED: { + "summary": "Access token is already expired.", + "value": {"detail": ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED}, + }, } } }, @@ -118,7 +126,15 @@ def get_oauth_router( try: decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) except jwt.DecodeError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.ACCESS_TOKEN_DECODE_ERROR, + ) + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED, + ) try: user = await user_manager.oauth_callback( @@ -221,6 +237,14 @@ def get_oauth_associate_router( "summary": "Invalid state token.", "value": None, }, + ErrorCode.ACCESS_TOKEN_DECODE_ERROR: { + "summary": "Access token is error.", + "value": {"detail": ErrorCode.ACCESS_TOKEN_DECODE_ERROR}, + }, + ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED: { + "summary": "Access token is already expired.", + "value": {"detail": ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED}, + }, } } }, @@ -249,7 +273,15 @@ def get_oauth_associate_router( try: state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) except jwt.DecodeError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.ACCESS_TOKEN_DECODE_ERROR, + ) + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED, + ) if state_data["sub"] != str(user.id): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index d67de146..2612a3d1 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -310,6 +310,57 @@ class TestCallback: assert json["detail"] == ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL + async def test_callback_token_expired( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + user_manager_oauth: UserManagerMock, + access_token: str, + ): + state_jwt = generate_state_token({}, "SECRET", lifetime_seconds=-1) + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + response = await test_app_client.get( + "/oauth/callback", + params={"code": "CODE", "state": state_jwt}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + data = cast(dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED + + assert user_manager_oauth.on_after_login.called is False + + async def test_callback_decode_token_error( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + user_manager_oauth: UserManagerMock, + access_token: str, + ): + state_jwt = generate_state_token({}, "RANDOM") + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + response = await test_app_client.get( + "/oauth/callback", + params={"code": "CODE", "state": state_jwt}, + ) + + data = cast(dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.ACCESS_TOKEN_DECODE_ERROR + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert user_manager_oauth.on_after_login.called is False + @pytest.mark.router @pytest.mark.oauth @pytest.mark.asyncio @@ -551,6 +602,57 @@ class TestAssociateCallback: assert response.status_code == status.HTTP_400_BAD_REQUEST json = response.json() assert json["detail"] == ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL + + async def test_callback_token_expired( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + user_manager_oauth: UserManagerMock, + access_token: str, + ): + state_jwt = generate_state_token({}, "SECRET", lifetime_seconds=-1) + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + response = await test_app_client.get( + "/oauth/callback", + params={"code": "CODE", "state": state_jwt}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + data = cast(dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED + + assert user_manager_oauth.on_after_login.called is False + + async def test_callback_decode_token_error( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + user_manager_oauth: UserManagerMock, + access_token: str, + ): + state_jwt = generate_state_token({}, "RANDOM") + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + response = await test_app_client.get( + "/oauth/callback", + params={"code": "CODE", "state": state_jwt}, + ) + + data = cast(dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.ACCESS_TOKEN_DECODE_ERROR + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert user_manager_oauth.on_after_login.called is False @pytest.mark.asyncio