From 7cf7154e2758fe1e8f1feffe9e3f4c778d4c7854 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 28 Dec 2020 08:53:31 +0100 Subject: [PATCH] Fix #431: make OAuth expires_at optional in model and DB schemas --- docs/configuration/oauth.md | 2 +- fastapi_users/db/sqlalchemy.py | 2 +- fastapi_users/db/tortoise.py | 2 +- fastapi_users/models.py | 2 +- fastapi_users/router/oauth.py | 2 +- tests/test_router_oauth.py | 43 +++++++++++++++------------------- 6 files changed, 24 insertions(+), 29 deletions(-) diff --git a/docs/configuration/oauth.md b/docs/configuration/oauth.md index 58d97f39..c9e3b96f 100644 --- a/docs/configuration/oauth.md +++ b/docs/configuration/oauth.md @@ -59,7 +59,7 @@ Notice that we inherit from the `BaseOAuthAccountMixin`, which adds a `List` of * `id` (`UUID4`) – Unique identifier of the OAuth account information. Default to a **UUID4**. * `oauth_name` (`str`) – Name of the OAuth service. It corresponds to the `name` property of the OAuth client. * `access_token` (`str`) – Access token. -* `expires_at` (`int`) - Timestamp at which the access token is expired. +* `expires_at` (`Optional[int]`) - Timestamp at which the access token is expired. * `refresh_token` (`Optional[str]`) – On services that support it, a token to get a fresh access token. * `account_id` (`str`) - Identifier of the OAuth account on the corresponding service. * `account_email` (`str`) - Email address of the OAuth account on the corresponding service. diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py index 76fc26ba..1dfba7a5 100644 --- a/fastapi_users/db/sqlalchemy.py +++ b/fastapi_users/db/sqlalchemy.py @@ -67,7 +67,7 @@ class SQLAlchemyBaseOAuthAccountTable: id = Column(GUID, primary_key=True) oauth_name = Column(String(length=100), index=True, nullable=False) access_token = Column(String(length=1024), nullable=False) - expires_at = Column(Integer, nullable=False) + expires_at = Column(Integer, nullable=True) refresh_token = Column(String(length=1024), nullable=True) account_id = Column(String(length=320), index=True, nullable=False) account_email = Column(String(length=320), nullable=False) diff --git a/fastapi_users/db/tortoise.py b/fastapi_users/db/tortoise.py index 7902242e..09e9151e 100644 --- a/fastapi_users/db/tortoise.py +++ b/fastapi_users/db/tortoise.py @@ -31,7 +31,7 @@ class TortoiseBaseOAuthAccountModel(models.Model): id = fields.UUIDField(pk=True, generated=False, max_length=255) oauth_name = fields.CharField(null=False, max_length=255) access_token = fields.CharField(null=False, max_length=255) - expires_at = fields.IntField(null=False) + expires_at = fields.IntField(null=True) refresh_token = fields.CharField(null=True, max_length=255) account_id = fields.CharField(index=True, null=False, max_length=255) account_email = fields.CharField(null=False, max_length=255) diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 6f843c10..f23912bf 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -56,7 +56,7 @@ class BaseOAuthAccount(BaseModel): id: Optional[UUID4] = None oauth_name: str access_token: str - expires_at: int + expires_at: Optional[int] = None refresh_token: Optional[str] = None account_id: str account_email: str diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index ae0646b0..e6aae2b7 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -108,7 +108,7 @@ def get_oauth_router( new_oauth_account = models.BaseOAuthAccount( oauth_name=oauth_client.name, access_token=token["access_token"], - expires_at=token["expires_at"], + expires_at=token.get("expires_at"), refresh_token=token.get("refresh_token"), account_id=account_id, account_email=account_email, diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index 64f83af9..58f6cb55 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -147,10 +147,18 @@ class TestAuthorize: @pytest.mark.router @pytest.mark.oauth @pytest.mark.asyncio +@pytest.mark.parametrize( + "access_token", + [ + ({"access_token": "TOKEN", "expires_at": 1579179542}), + ({"access_token": "TOKEN"}), + ], +) class TestCallback: async def test_invalid_state( self, test_app_client: httpx.AsyncClient, + access_token, oauth_client, user_oauth, after_register, @@ -158,10 +166,7 @@ class TestCallback: with asynctest.patch.object( oauth_client, "get_access_token" ) as get_access_token_mock: - get_access_token_mock.return_value = { - "access_token": "TOKEN", - "expires_at": 1579179542, - } + get_access_token_mock.return_value = access_token with asynctest.patch.object( oauth_client, "get_id_email" ) as get_id_email_mock: @@ -180,6 +185,7 @@ class TestCallback: self, mock_user_db_oauth, test_app_client: httpx.AsyncClient, + access_token, oauth_client, user_oauth, after_register, @@ -188,10 +194,7 @@ class TestCallback: with asynctest.patch.object( oauth_client, "get_access_token" ) as get_access_token_mock: - get_access_token_mock.return_value = { - "access_token": "TOKEN", - "expires_at": 1579179542, - } + get_access_token_mock.return_value = access_token with asynctest.patch.object( oauth_client, "get_id_email" ) as get_id_email_mock: @@ -216,6 +219,7 @@ class TestCallback: self, mock_user_db_oauth, test_app_client: httpx.AsyncClient, + access_token, oauth_client, superuser_oauth, after_register, @@ -224,10 +228,7 @@ class TestCallback: with asynctest.patch.object( oauth_client, "get_access_token" ) as get_access_token_mock: - get_access_token_mock.return_value = { - "access_token": "TOKEN", - "expires_at": 1579179542, - } + get_access_token_mock.return_value = access_token with asynctest.patch.object( oauth_client, "get_id_email" ) as get_id_email_mock: @@ -255,6 +256,7 @@ class TestCallback: self, mock_user_db_oauth, test_app_client: httpx.AsyncClient, + access_token, oauth_client, after_register, ): @@ -262,10 +264,7 @@ class TestCallback: with asynctest.patch.object( oauth_client, "get_access_token" ) as get_access_token_mock: - get_access_token_mock.return_value = { - "access_token": "TOKEN", - "expires_at": 1579179542, - } + get_access_token_mock.return_value = access_token with asynctest.patch.object( oauth_client, "get_id_email" ) as get_id_email_mock: @@ -297,6 +296,7 @@ class TestCallback: self, mock_user_db_oauth, test_app_client: httpx.AsyncClient, + access_token, oauth_client, inactive_user_oauth, after_register, @@ -305,10 +305,7 @@ class TestCallback: with asynctest.patch.object( oauth_client, "get_access_token" ) as get_access_token_mock: - get_access_token_mock.return_value = { - "access_token": "TOKEN", - "expires_at": 1579179542, - } + get_access_token_mock.return_value = access_token with asynctest.patch.object( oauth_client, "get_id_email" ) as get_id_email_mock: @@ -331,6 +328,7 @@ class TestCallback: self, mock_user_db_oauth, test_app_client_redirect_url: httpx.AsyncClient, + access_token, oauth_client, user_oauth, ): @@ -338,10 +336,7 @@ class TestCallback: with asynctest.patch.object( oauth_client, "get_access_token" ) as get_access_token_mock: - get_access_token_mock.return_value = { - "access_token": "TOKEN", - "expires_at": 1579179542, - } + get_access_token_mock.return_value = access_token with asynctest.patch.object( oauth_client, "get_id_email" ) as get_id_email_mock: