fix: add expired token error on oauth callback

This commit is contained in:
Muhammad Daffa Dinaya
2025-01-02 12:19:48 +07:00
committed by François Voron
parent 576683cccd
commit 5863445774
3 changed files with 139 additions and 2 deletions

View File

@ -25,3 +25,6 @@ class ErrorCode(str, Enum):
VERIFY_USER_ALREADY_VERIFIED = "VERIFY_USER_ALREADY_VERIFIED"
UPDATE_USER_EMAIL_ALREADY_EXISTS = "UPDATE_USER_EMAIL_ALREADY_EXISTS"
UPDATE_USER_INVALID_PASSWORD = "UPDATE_USER_INVALID_PASSWORD"
ACCESS_TOKEN_ALREADY_EXPIRED = "ACCESS_TOKEN_ALREADY_EXPIRED"
ACCESS_TOKEN_DECODE_ERROR = "ACCESS_TOKEN_DECODE_ERROR"

View File

@ -90,6 +90,14 @@ def get_oauth_router(
"summary": "User is inactive.",
"value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS},
},
ErrorCode.ACCESS_TOKEN_DECODE_ERROR: {
"summary": "Access token is error.",
"value": {"detail": ErrorCode.ACCESS_TOKEN_DECODE_ERROR},
},
ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED: {
"summary": "Access token is already expired.",
"value": {"detail": ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED},
},
}
}
},
@ -118,7 +126,15 @@ def get_oauth_router(
try:
decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.ACCESS_TOKEN_DECODE_ERROR,
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED,
)
try:
user = await user_manager.oauth_callback(
@ -221,6 +237,14 @@ def get_oauth_associate_router(
"summary": "Invalid state token.",
"value": None,
},
ErrorCode.ACCESS_TOKEN_DECODE_ERROR: {
"summary": "Access token is error.",
"value": {"detail": ErrorCode.ACCESS_TOKEN_DECODE_ERROR},
},
ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED: {
"summary": "Access token is already expired.",
"value": {"detail": ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED},
},
}
}
},
@ -249,7 +273,15 @@ def get_oauth_associate_router(
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.ACCESS_TOKEN_DECODE_ERROR,
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED,
)
if state_data["sub"] != str(user.id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

View File

@ -310,6 +310,57 @@ class TestCallback:
assert json["detail"] == ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL
async def test_callback_token_expired(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
oauth_client: BaseOAuth2,
user_oauth: UserOAuthModel,
user_manager_oauth: UserManagerMock,
access_token: str,
):
state_jwt = generate_state_token({}, "SECRET", lifetime_seconds=-1)
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
async_method_mocker(
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
)
response = await test_app_client.get(
"/oauth/callback",
params={"code": "CODE", "state": state_jwt},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = cast(dict[str, Any], response.json())
assert data["detail"] == ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED
assert user_manager_oauth.on_after_login.called is False
async def test_callback_decode_token_error(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
oauth_client: BaseOAuth2,
user_oauth: UserOAuthModel,
user_manager_oauth: UserManagerMock,
access_token: str,
):
state_jwt = generate_state_token({}, "RANDOM")
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
async_method_mocker(
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
)
response = await test_app_client.get(
"/oauth/callback",
params={"code": "CODE", "state": state_jwt},
)
data = cast(dict[str, Any], response.json())
assert data["detail"] == ErrorCode.ACCESS_TOKEN_DECODE_ERROR
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert user_manager_oauth.on_after_login.called is False
@pytest.mark.router
@pytest.mark.oauth
@pytest.mark.asyncio
@ -551,6 +602,57 @@ class TestAssociateCallback:
assert response.status_code == status.HTTP_400_BAD_REQUEST
json = response.json()
assert json["detail"] == ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL
async def test_callback_token_expired(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
oauth_client: BaseOAuth2,
user_oauth: UserOAuthModel,
user_manager_oauth: UserManagerMock,
access_token: str,
):
state_jwt = generate_state_token({}, "SECRET", lifetime_seconds=-1)
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
async_method_mocker(
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
)
response = await test_app_client.get(
"/oauth/callback",
params={"code": "CODE", "state": state_jwt},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = cast(dict[str, Any], response.json())
assert data["detail"] == ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED
assert user_manager_oauth.on_after_login.called is False
async def test_callback_decode_token_error(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
oauth_client: BaseOAuth2,
user_oauth: UserOAuthModel,
user_manager_oauth: UserManagerMock,
access_token: str,
):
state_jwt = generate_state_token({}, "RANDOM")
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
async_method_mocker(
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
)
response = await test_app_client.get(
"/oauth/callback",
params={"code": "CODE", "state": state_jwt},
)
data = cast(dict[str, Any], response.json())
assert data["detail"] == ErrorCode.ACCESS_TOKEN_DECODE_ERROR
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert user_manager_oauth.on_after_login.called is False
@pytest.mark.asyncio