mirror of
				https://github.com/fastapi-users/fastapi-users.git
				synced 2025-10-31 09:28:45 +08:00 
			
		
		
		
	fix: add expired token error on oauth callback
This commit is contained in:
		 Muhammad Daffa Dinaya
					Muhammad Daffa Dinaya
				
			
				
					committed by
					
						 François Voron
						François Voron
					
				
			
			
				
	
			
			
			 François Voron
						François Voron
					
				
			
						parent
						
							576683cccd
						
					
				
				
					commit
					5863445774
				
			| @ -25,3 +25,6 @@ class ErrorCode(str, Enum): | |||||||
|     VERIFY_USER_ALREADY_VERIFIED = "VERIFY_USER_ALREADY_VERIFIED" |     VERIFY_USER_ALREADY_VERIFIED = "VERIFY_USER_ALREADY_VERIFIED" | ||||||
|     UPDATE_USER_EMAIL_ALREADY_EXISTS = "UPDATE_USER_EMAIL_ALREADY_EXISTS" |     UPDATE_USER_EMAIL_ALREADY_EXISTS = "UPDATE_USER_EMAIL_ALREADY_EXISTS" | ||||||
|     UPDATE_USER_INVALID_PASSWORD = "UPDATE_USER_INVALID_PASSWORD" |     UPDATE_USER_INVALID_PASSWORD = "UPDATE_USER_INVALID_PASSWORD" | ||||||
|  |     ACCESS_TOKEN_ALREADY_EXPIRED = "ACCESS_TOKEN_ALREADY_EXPIRED" | ||||||
|  |     ACCESS_TOKEN_DECODE_ERROR = "ACCESS_TOKEN_DECODE_ERROR" | ||||||
|  |  | ||||||
|  | |||||||
| @ -90,6 +90,14 @@ def get_oauth_router( | |||||||
|                                 "summary": "User is inactive.", |                                 "summary": "User is inactive.", | ||||||
|                                 "value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS}, |                                 "value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS}, | ||||||
|                             }, |                             }, | ||||||
|  |                             ErrorCode.ACCESS_TOKEN_DECODE_ERROR: { | ||||||
|  |                                 "summary": "Access token is error.", | ||||||
|  |                                 "value": {"detail": ErrorCode.ACCESS_TOKEN_DECODE_ERROR}, | ||||||
|  |                             }, | ||||||
|  |                             ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED: { | ||||||
|  |                                 "summary": "Access token is already expired.", | ||||||
|  |                                 "value": {"detail": ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED}, | ||||||
|  |                             }, | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 }, |                 }, | ||||||
| @ -118,7 +126,15 @@ def get_oauth_router( | |||||||
|         try: |         try: | ||||||
|             decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) |             decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) | ||||||
|         except jwt.DecodeError: |         except jwt.DecodeError: | ||||||
|             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) |             raise HTTPException( | ||||||
|  |                 status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|  |                 detail=ErrorCode.ACCESS_TOKEN_DECODE_ERROR, | ||||||
|  |             ) | ||||||
|  |         except jwt.ExpiredSignatureError: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|  |                 detail=ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             user = await user_manager.oauth_callback( |             user = await user_manager.oauth_callback( | ||||||
| @ -221,6 +237,14 @@ def get_oauth_associate_router( | |||||||
|                                 "summary": "Invalid state token.", |                                 "summary": "Invalid state token.", | ||||||
|                                 "value": None, |                                 "value": None, | ||||||
|                             }, |                             }, | ||||||
|  |                             ErrorCode.ACCESS_TOKEN_DECODE_ERROR: { | ||||||
|  |                                 "summary": "Access token is error.", | ||||||
|  |                                 "value": {"detail": ErrorCode.ACCESS_TOKEN_DECODE_ERROR}, | ||||||
|  |                             }, | ||||||
|  |                             ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED: { | ||||||
|  |                                 "summary": "Access token is already expired.", | ||||||
|  |                                 "value": {"detail": ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED}, | ||||||
|  |                             }, | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 }, |                 }, | ||||||
| @ -249,7 +273,15 @@ def get_oauth_associate_router( | |||||||
|         try: |         try: | ||||||
|             state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) |             state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) | ||||||
|         except jwt.DecodeError: |         except jwt.DecodeError: | ||||||
|             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) |             raise HTTPException( | ||||||
|  |                 status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|  |                 detail=ErrorCode.ACCESS_TOKEN_DECODE_ERROR, | ||||||
|  |             ) | ||||||
|  |         except jwt.ExpiredSignatureError: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|  |                 detail=ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         if state_data["sub"] != str(user.id): |         if state_data["sub"] != str(user.id): | ||||||
|             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) |             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) | ||||||
|  | |||||||
| @ -310,6 +310,57 @@ class TestCallback: | |||||||
|         assert json["detail"] == ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL |         assert json["detail"] == ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     async def test_callback_token_expired( | ||||||
|  |         self, | ||||||
|  |         async_method_mocker: AsyncMethodMocker, | ||||||
|  |         test_app_client: httpx.AsyncClient, | ||||||
|  |         oauth_client: BaseOAuth2, | ||||||
|  |         user_oauth: UserOAuthModel, | ||||||
|  |         user_manager_oauth: UserManagerMock, | ||||||
|  |         access_token: str, | ||||||
|  |     ): | ||||||
|  |         state_jwt = generate_state_token({}, "SECRET", lifetime_seconds=-1) | ||||||
|  |         async_method_mocker(oauth_client, "get_access_token", return_value=access_token) | ||||||
|  |         async_method_mocker( | ||||||
|  |             oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) | ||||||
|  |         ) | ||||||
|  |         response = await test_app_client.get( | ||||||
|  |             "/oauth/callback", | ||||||
|  |             params={"code": "CODE", "state": state_jwt}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         assert response.status_code == status.HTTP_400_BAD_REQUEST | ||||||
|  |  | ||||||
|  |         data = cast(dict[str, Any], response.json()) | ||||||
|  |         assert data["detail"] == ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED | ||||||
|  |  | ||||||
|  |         assert user_manager_oauth.on_after_login.called is False | ||||||
|  |  | ||||||
|  |     async def test_callback_decode_token_error( | ||||||
|  |         self, | ||||||
|  |         async_method_mocker: AsyncMethodMocker, | ||||||
|  |         test_app_client: httpx.AsyncClient, | ||||||
|  |         oauth_client: BaseOAuth2, | ||||||
|  |         user_oauth: UserOAuthModel, | ||||||
|  |         user_manager_oauth: UserManagerMock, | ||||||
|  |         access_token: str, | ||||||
|  |     ): | ||||||
|  |         state_jwt = generate_state_token({}, "RANDOM") | ||||||
|  |         async_method_mocker(oauth_client, "get_access_token", return_value=access_token) | ||||||
|  |         async_method_mocker( | ||||||
|  |             oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) | ||||||
|  |         ) | ||||||
|  |         response = await test_app_client.get( | ||||||
|  |             "/oauth/callback", | ||||||
|  |             params={"code": "CODE", "state": state_jwt}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         data = cast(dict[str, Any], response.json()) | ||||||
|  |         assert data["detail"] == ErrorCode.ACCESS_TOKEN_DECODE_ERROR | ||||||
|  |  | ||||||
|  |         assert response.status_code == status.HTTP_400_BAD_REQUEST | ||||||
|  |         assert user_manager_oauth.on_after_login.called is False | ||||||
|  |  | ||||||
| @pytest.mark.router | @pytest.mark.router | ||||||
| @pytest.mark.oauth | @pytest.mark.oauth | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @ -552,6 +603,57 @@ class TestAssociateCallback: | |||||||
|         json = response.json() |         json = response.json() | ||||||
|         assert json["detail"] == ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL |         assert json["detail"] == ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL | ||||||
|      |      | ||||||
|  |     async def test_callback_token_expired( | ||||||
|  |         self, | ||||||
|  |         async_method_mocker: AsyncMethodMocker, | ||||||
|  |         test_app_client: httpx.AsyncClient, | ||||||
|  |         oauth_client: BaseOAuth2, | ||||||
|  |         user_oauth: UserOAuthModel, | ||||||
|  |         user_manager_oauth: UserManagerMock, | ||||||
|  |         access_token: str, | ||||||
|  |     ): | ||||||
|  |         state_jwt = generate_state_token({}, "SECRET", lifetime_seconds=-1) | ||||||
|  |         async_method_mocker(oauth_client, "get_access_token", return_value=access_token) | ||||||
|  |         async_method_mocker( | ||||||
|  |             oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) | ||||||
|  |         ) | ||||||
|  |         response = await test_app_client.get( | ||||||
|  |             "/oauth/callback", | ||||||
|  |             params={"code": "CODE", "state": state_jwt}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         assert response.status_code == status.HTTP_400_BAD_REQUEST | ||||||
|  |  | ||||||
|  |         data = cast(dict[str, Any], response.json()) | ||||||
|  |         assert data["detail"] == ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED | ||||||
|  |  | ||||||
|  |         assert user_manager_oauth.on_after_login.called is False | ||||||
|  |  | ||||||
|  |     async def test_callback_decode_token_error( | ||||||
|  |         self, | ||||||
|  |         async_method_mocker: AsyncMethodMocker, | ||||||
|  |         test_app_client: httpx.AsyncClient, | ||||||
|  |         oauth_client: BaseOAuth2, | ||||||
|  |         user_oauth: UserOAuthModel, | ||||||
|  |         user_manager_oauth: UserManagerMock, | ||||||
|  |         access_token: str, | ||||||
|  |     ): | ||||||
|  |         state_jwt = generate_state_token({}, "RANDOM") | ||||||
|  |         async_method_mocker(oauth_client, "get_access_token", return_value=access_token) | ||||||
|  |         async_method_mocker( | ||||||
|  |             oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) | ||||||
|  |         ) | ||||||
|  |         response = await test_app_client.get( | ||||||
|  |             "/oauth/callback", | ||||||
|  |             params={"code": "CODE", "state": state_jwt}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         data = cast(dict[str, Any], response.json()) | ||||||
|  |         assert data["detail"] == ErrorCode.ACCESS_TOKEN_DECODE_ERROR | ||||||
|  |  | ||||||
|  |         assert response.status_code == status.HTTP_400_BAD_REQUEST | ||||||
|  |         assert user_manager_oauth.on_after_login.called is False | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @pytest.mark.oauth | @pytest.mark.oauth | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user