mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-26 04:25:46 +08:00
Revamp AccessToken DB strategy to adopt generic model approach
This commit is contained in:
@ -1,30 +1,37 @@
|
||||
import dataclasses
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_users.authentication.strategy import (
|
||||
AccessTokenDatabase,
|
||||
BaseAccessToken,
|
||||
AccessTokenProtocol,
|
||||
DatabaseStrategy,
|
||||
)
|
||||
from tests.conftest import UserModel
|
||||
|
||||
|
||||
class AccessToken(BaseAccessToken):
|
||||
pass
|
||||
@dataclasses.dataclass
|
||||
class AccessTokenModel(AccessTokenProtocol):
|
||||
token: str
|
||||
user_id: uuid.UUID
|
||||
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
|
||||
created_at: datetime = dataclasses.field(
|
||||
default_factory=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
|
||||
store: Dict[str, AccessToken]
|
||||
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]):
|
||||
store: Dict[str, AccessTokenModel]
|
||||
|
||||
def __init__(self):
|
||||
self.access_token_model = AccessToken
|
||||
self.store = {}
|
||||
|
||||
async def get_by_token(
|
||||
self, token: str, max_age: Optional[datetime] = None
|
||||
) -> Optional[AccessToken]:
|
||||
) -> Optional[AccessTokenModel]:
|
||||
try:
|
||||
access_token = self.store[token]
|
||||
if max_age is not None and access_token.created_at < max_age:
|
||||
@ -33,15 +40,20 @@ class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
async def create(self, access_token: AccessToken) -> AccessToken:
|
||||
async def create(self, create_dict: Dict[str, Any]) -> AccessTokenModel:
|
||||
access_token = AccessTokenModel(**create_dict)
|
||||
self.store[access_token.token] = access_token
|
||||
return access_token
|
||||
|
||||
async def update(self, access_token: AccessToken) -> AccessToken:
|
||||
async def update(
|
||||
self, access_token: AccessTokenModel, update_dict: Dict[str, Any]
|
||||
) -> AccessTokenModel:
|
||||
for field, value in update_dict.items():
|
||||
setattr(access_token, field, value)
|
||||
self.store[access_token.token] = access_token
|
||||
return access_token
|
||||
|
||||
async def delete(self, access_token: AccessToken) -> None:
|
||||
async def delete(self, access_token: AccessTokenModel) -> None:
|
||||
try:
|
||||
del self.store[access_token.token]
|
||||
except KeyError:
|
||||
@ -62,14 +74,18 @@ def database_strategy(access_token_database: AccessTokenDatabaseMock):
|
||||
class TestReadToken:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_token(
|
||||
self, database_strategy: DatabaseStrategy, user_manager
|
||||
self,
|
||||
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||
user_manager,
|
||||
):
|
||||
authenticated_user = await database_strategy.read_token(None, user_manager)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_token(
|
||||
self, database_strategy: DatabaseStrategy, user_manager
|
||||
self,
|
||||
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||
user_manager,
|
||||
):
|
||||
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
||||
assert authenticated_user is None
|
||||
@ -77,14 +93,15 @@ class TestReadToken:
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token_not_existing_user(
|
||||
self,
|
||||
database_strategy: DatabaseStrategy,
|
||||
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||
access_token_database: AccessTokenDatabaseMock,
|
||||
user_manager,
|
||||
):
|
||||
await access_token_database.create(
|
||||
AccessToken(
|
||||
token="TOKEN", user_id=uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f")
|
||||
)
|
||||
{
|
||||
"token": "TOKEN",
|
||||
"user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
|
||||
}
|
||||
)
|
||||
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
||||
assert authenticated_user is None
|
||||
@ -92,12 +109,12 @@ class TestReadToken:
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token(
|
||||
self,
|
||||
database_strategy: DatabaseStrategy,
|
||||
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||
access_token_database: AccessTokenDatabaseMock,
|
||||
user_manager,
|
||||
user,
|
||||
user: UserModel,
|
||||
):
|
||||
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
|
||||
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
|
||||
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
||||
assert authenticated_user is not None
|
||||
assert authenticated_user.id == user.id
|
||||
@ -106,9 +123,9 @@ class TestReadToken:
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_token(
|
||||
database_strategy: DatabaseStrategy,
|
||||
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||
access_token_database: AccessTokenDatabaseMock,
|
||||
user,
|
||||
user: UserModel,
|
||||
):
|
||||
token = await database_strategy.write_token(user)
|
||||
|
||||
@ -120,11 +137,11 @@ async def test_write_token(
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_token(
|
||||
database_strategy: DatabaseStrategy,
|
||||
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||
access_token_database: AccessTokenDatabaseMock,
|
||||
user,
|
||||
user: UserModel,
|
||||
):
|
||||
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
|
||||
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
|
||||
|
||||
await database_strategy.destroy_token("TOKEN", user)
|
||||
|
||||
|
Reference in New Issue
Block a user