Revamp AccessToken DB strategy to adopt generic model approach

This commit is contained in:
François Voron
2022-04-29 15:45:14 +02:00
parent e271cc1352
commit 87ac51a7bd
7 changed files with 85 additions and 67 deletions

View File

@ -1,30 +1,37 @@
import dataclasses
import uuid
from datetime import datetime
from typing import Dict, Optional
from datetime import datetime, timezone
from typing import Any, Dict, Optional
import pytest
from fastapi_users.authentication.strategy import (
AccessTokenDatabase,
BaseAccessToken,
AccessTokenProtocol,
DatabaseStrategy,
)
from tests.conftest import UserModel
class AccessToken(BaseAccessToken):
pass
@dataclasses.dataclass
class AccessTokenModel(AccessTokenProtocol):
token: str
user_id: uuid.UUID
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
created_at: datetime = dataclasses.field(
default_factory=lambda: datetime.now(timezone.utc)
)
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
store: Dict[str, AccessToken]
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]):
store: Dict[str, AccessTokenModel]
def __init__(self):
self.access_token_model = AccessToken
self.store = {}
async def get_by_token(
self, token: str, max_age: Optional[datetime] = None
) -> Optional[AccessToken]:
) -> Optional[AccessTokenModel]:
try:
access_token = self.store[token]
if max_age is not None and access_token.created_at < max_age:
@ -33,15 +40,20 @@ class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
except KeyError:
return None
async def create(self, access_token: AccessToken) -> AccessToken:
async def create(self, create_dict: Dict[str, Any]) -> AccessTokenModel:
access_token = AccessTokenModel(**create_dict)
self.store[access_token.token] = access_token
return access_token
async def update(self, access_token: AccessToken) -> AccessToken:
async def update(
self, access_token: AccessTokenModel, update_dict: Dict[str, Any]
) -> AccessTokenModel:
for field, value in update_dict.items():
setattr(access_token, field, value)
self.store[access_token.token] = access_token
return access_token
async def delete(self, access_token: AccessToken) -> None:
async def delete(self, access_token: AccessTokenModel) -> None:
try:
del self.store[access_token.token]
except KeyError:
@ -62,14 +74,18 @@ def database_strategy(access_token_database: AccessTokenDatabaseMock):
class TestReadToken:
@pytest.mark.asyncio
async def test_missing_token(
self, database_strategy: DatabaseStrategy, user_manager
self,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
user_manager,
):
authenticated_user = await database_strategy.read_token(None, user_manager)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_invalid_token(
self, database_strategy: DatabaseStrategy, user_manager
self,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
user_manager,
):
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is None
@ -77,14 +93,15 @@ class TestReadToken:
@pytest.mark.asyncio
async def test_valid_token_not_existing_user(
self,
database_strategy: DatabaseStrategy,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock,
user_manager,
):
await access_token_database.create(
AccessToken(
token="TOKEN", user_id=uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f")
)
{
"token": "TOKEN",
"user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
}
)
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is None
@ -92,12 +109,12 @@ class TestReadToken:
@pytest.mark.asyncio
async def test_valid_token(
self,
database_strategy: DatabaseStrategy,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock,
user_manager,
user,
user: UserModel,
):
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is not None
assert authenticated_user.id == user.id
@ -106,9 +123,9 @@ class TestReadToken:
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_write_token(
database_strategy: DatabaseStrategy,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock,
user,
user: UserModel,
):
token = await database_strategy.write_token(user)
@ -120,11 +137,11 @@ async def test_write_token(
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_destroy_token(
database_strategy: DatabaseStrategy,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock,
user,
user: UserModel,
):
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
await database_strategy.destroy_token("TOKEN", user)