Files
fastapi-users/tests/test_authentication_strategy_db.py
2024-11-03 12:51:32 +00:00

150 lines
4.6 KiB
Python

import dataclasses
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
import pytest
from fastapi_users.authentication.strategy import (
AccessTokenDatabase,
AccessTokenProtocol,
DatabaseStrategy,
)
from tests.conftest import IDType, UserModel
def _utc_now() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


@dataclasses.dataclass
class AccessTokenModel(AccessTokenProtocol[IDType]):
    """Plain dataclass implementation of an access token for these tests.

    ``token`` and ``user_id`` are required. ``id`` defaults to a freshly
    generated random UUID and ``created_at`` to the current UTC time, both
    evaluated per instance.
    """

    token: str
    user_id: uuid.UUID
    id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
    created_at: datetime = dataclasses.field(default_factory=_utc_now)
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]):
    """Dictionary-backed stand-in for an access token database.

    Records live in ``store``, keyed by their ``token`` string, so every
    operation is a plain in-memory dict lookup.
    """

    store: dict[str, AccessTokenModel]

    def __init__(self):
        self.store = {}

    async def get_by_token(
        self, token: str, max_age: Optional[datetime] = None
    ) -> Optional[AccessTokenModel]:
        """Look up an access token by its token string.

        Args:
            token: Token string used as the lookup key.
            max_age: Optional cut-off; a record whose ``created_at`` is
                strictly earlier than this is treated as missing.

        Returns:
            The stored record, or ``None`` when the token is unknown or was
            created before ``max_age``.
        """
        access_token = self.store.get(token)
        if access_token is None:
            return None
        if max_age is not None and access_token.created_at < max_age:
            return None
        return access_token

    async def create(self, create_dict: dict[str, Any]) -> AccessTokenModel:
        """Build a record from ``create_dict`` and store it under its token.

        Args:
            create_dict: Keyword arguments forwarded to ``AccessTokenModel``.

        Returns:
            The newly stored record.
        """
        access_token = AccessTokenModel(**create_dict)
        self.store[access_token.token] = access_token
        return access_token

    async def update(
        self, access_token: AccessTokenModel, update_dict: dict[str, Any]
    ) -> AccessTokenModel:
        """Apply ``update_dict`` to ``access_token`` and re-store it.

        Args:
            access_token: Record to mutate in place.
            update_dict: Mapping of attribute name to new value.

        Returns:
            The same, now updated, record.
        """
        previous_token = access_token.token
        for field, value in update_dict.items():
            setattr(access_token, field, value)
        # If the update changed the token string, drop the entry under the
        # old key; otherwise the record would stay reachable through a token
        # it no longer carries. Only remove it when it is this very record.
        if (
            access_token.token != previous_token
            and self.store.get(previous_token) is access_token
        ):
            del self.store[previous_token]
        self.store[access_token.token] = access_token
        return access_token

    async def delete(self, access_token: AccessTokenModel) -> None:
        """Remove ``access_token`` from the store; a missing key is ignored."""
        self.store.pop(access_token.token, None)
@pytest.fixture
def access_token_database() -> AccessTokenDatabaseMock:
    """Provide a fresh, empty in-memory access token database."""
    database = AccessTokenDatabaseMock()
    return database
@pytest.fixture
def database_strategy(
    access_token_database: AccessTokenDatabaseMock,
) -> DatabaseStrategy[UserModel, IDType, AccessTokenModel]:
    """Provide a ``DatabaseStrategy`` wired to the in-memory token database."""
    # Second positional argument of DatabaseStrategy; presumably the token
    # lifetime in seconds (one hour) -- confirm against the strategy itself.
    lifetime = 3600
    strategy = DatabaseStrategy(access_token_database, lifetime)
    return strategy
@pytest.mark.authentication
class TestReadToken:
    """Checks for ``DatabaseStrategy.read_token``."""

    @pytest.mark.asyncio
    async def test_missing_token(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        user_manager,
    ):
        """Passing ``None`` as the token yields no user."""
        result = await database_strategy.read_token(None, user_manager)

        assert result is None

    @pytest.mark.asyncio
    async def test_invalid_token(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        user_manager,
    ):
        """A token that was never stored yields no user."""
        result = await database_strategy.read_token("TOKEN", user_manager)

        assert result is None

    @pytest.mark.asyncio
    async def test_valid_token_not_existing_user(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        access_token_database: AccessTokenDatabaseMock,
        user_manager,
    ):
        """A stored token whose user id cannot be resolved yields no user."""
        # Presumably an id that no user known to ``user_manager`` carries.
        unknown_user_id = uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f")
        token_data = {"token": "TOKEN", "user_id": unknown_user_id}
        await access_token_database.create(token_data)

        result = await database_strategy.read_token("TOKEN", user_manager)

        assert result is None

    @pytest.mark.asyncio
    async def test_valid_token(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        access_token_database: AccessTokenDatabaseMock,
        user_manager,
        user: UserModel,
    ):
        """A stored token for an existing user resolves to that user."""
        token_data = {"token": "TOKEN", "user_id": user.id}
        await access_token_database.create(token_data)

        result = await database_strategy.read_token("TOKEN", user_manager)

        assert result is not None
        assert result.id == user.id
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_write_token(
    database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
    access_token_database: AccessTokenDatabaseMock,
    user: UserModel,
):
    """``write_token`` stores a record for the user under the returned token."""
    issued_token = await database_strategy.write_token(user)

    stored = await access_token_database.get_by_token(issued_token)

    assert stored is not None
    assert stored.user_id == user.id
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_destroy_token(
    database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
    access_token_database: AccessTokenDatabaseMock,
    user: UserModel,
):
    """``destroy_token`` removes the matching record from the database."""
    token_data = {"token": "TOKEN", "user_id": user.id}
    await access_token_database.create(token_data)

    await database_strategy.destroy_token("TOKEN", user)

    remaining = await access_token_database.get_by_token("TOKEN")
    assert remaining is None