Revamp AccessToken DB strategy to adopt generic model approach

This commit is contained in:
François Voron
2022-04-29 15:45:14 +02:00
parent e271cc1352
commit 87ac51a7bd
7 changed files with 85 additions and 67 deletions

View File

@ -3,9 +3,9 @@ from fastapi_users.authentication.strategy.base import (
StrategyDestroyNotSupportedError,
)
from fastapi_users.authentication.strategy.db import (
A,
AP,
AccessTokenDatabase,
BaseAccessToken,
AccessTokenProtocol,
DatabaseStrategy,
)
from fastapi_users.authentication.strategy.jwt import JWTStrategy
@ -16,9 +16,9 @@ except ImportError: # pragma: no cover
pass
__all__ = [
"A",
"AP",
"AccessTokenDatabase",
"BaseAccessToken",
"AccessTokenProtocol",
"DatabaseStrategy",
"JWTStrategy",
"Strategy",

View File

@ -1,5 +1,5 @@
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
from fastapi_users.authentication.strategy.db.models import A, BaseAccessToken
from fastapi_users.authentication.strategy.db.models import AP, AccessTokenProtocol
from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy
__all__ = ["A", "AccessTokenDatabase", "BaseAccessToken", "DatabaseStrategy"]
__all__ = ["AP", "AccessTokenDatabase", "AccessTokenProtocol", "DatabaseStrategy"]

View File

@ -1,38 +1,32 @@
import sys
from datetime import datetime
from typing import Generic, Optional, Type
from typing import Any, Dict, Generic, Optional
if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
else:
from typing import Protocol # pragma: no cover
from fastapi_users.authentication.strategy.db.models import A
from fastapi_users.authentication.strategy.db.models import AP
class AccessTokenDatabase(Protocol, Generic[A]):
"""
Protocol for retrieving, creating and updating access tokens from a database.
:param access_token_model: Pydantic model of an access token.
"""
access_token_model: Type[A]
class AccessTokenDatabase(Protocol, Generic[AP]):
"""Protocol for retrieving, creating and updating access tokens from a database."""
async def get_by_token(
self, token: str, max_age: Optional[datetime] = None
) -> Optional[A]:
) -> Optional[AP]:
"""Get a single access token by token."""
... # pragma: no cover
async def create(self, access_token: A) -> A:
async def create(self, create_dict: Dict[str, Any]) -> AP:
"""Create an access token."""
... # pragma: no cover
async def update(self, access_token: A) -> A:
async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
"""Update an access token."""
... # pragma: no cover
async def delete(self, access_token: A) -> None:
async def delete(self, access_token: AP) -> None:
"""Delete an access token."""
... # pragma: no cover

View File

@ -1,22 +1,23 @@
from datetime import datetime, timezone
import sys
import uuid
from datetime import datetime
from typing import TypeVar
from pydantic import UUID4, BaseModel, Field
if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
else:
from typing import Protocol # pragma: no cover
def now_utc():
return datetime.now(timezone.utc)
class BaseAccessToken(BaseModel):
"""Base access token model."""
class AccessTokenProtocol(Protocol):
"""Access token protocol that ORM model should follow."""
token: str
user_id: UUID4
created_at: datetime = Field(default_factory=now_utc)
user_id: uuid.UUID
created_at: datetime
class Config:
orm_mode = True
def __init__(self, *args, **kwargs) -> None:
... # pragma: no cover
A = TypeVar("A", bound=BaseAccessToken)
AP = TypeVar("AP", bound=AccessTokenProtocol)

View File

@ -1,17 +1,17 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import Generic, Optional
from typing import Any, Dict, Generic, Optional
from fastapi_users import models
from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
from fastapi_users.authentication.strategy.db.models import A
from fastapi_users.authentication.strategy.db.models import AP
from fastapi_users.manager import BaseUserManager, UserNotExists
class DatabaseStrategy(Strategy, Generic[models.UP, A]):
class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
def __init__(
self, database: AccessTokenDatabase[A], lifetime_seconds: Optional[int] = None
self, database: AccessTokenDatabase[AP], lifetime_seconds: Optional[int] = None
):
self.database = database
self.lifetime_seconds = lifetime_seconds
@ -39,8 +39,8 @@ class DatabaseStrategy(Strategy, Generic[models.UP, A]):
return None
async def write_token(self, user: models.UP) -> str:
access_token = self._create_access_token(user)
await self.database.create(access_token)
access_token_dict = self._create_access_token_dict(user)
access_token = await self.database.create(access_token_dict)
return access_token.token
async def destroy_token(self, token: str, user: models.UP) -> None:
@ -48,6 +48,6 @@ class DatabaseStrategy(Strategy, Generic[models.UP, A]):
if access_token is not None:
await self.database.delete(access_token)
def _create_access_token(self, user: models.UP) -> A:
def _create_access_token_dict(self, user: models.UP) -> Dict[str, Any]:
token = secrets.token_urlsafe()
return self.database.access_token_model(token=token, user_id=user.id)
return {"token": token, "user_id": user.id}

View File

@ -9,6 +9,8 @@ else:
class UserProtocol(Protocol):
"""User protocol that ORM model should follow."""
id: uuid.UUID
email: str
hashed_password: str
@ -21,6 +23,8 @@ class UserProtocol(Protocol):
class OAuthAccountProtocol(Protocol):
"""OAuth account protocol that ORM model should follow."""
id: uuid.UUID
oauth_name: str
access_token: str
@ -38,6 +42,8 @@ OAP = TypeVar("OAP", bound=OAuthAccountProtocol)
class UserOAuthProtocol(UserProtocol, Generic[OAP]):
"""User protocol including a list of OAuth accounts."""
oauth_accounts: List[OAP]

View File

@ -1,30 +1,37 @@
import dataclasses
import uuid
from datetime import datetime
from typing import Dict, Optional
from datetime import datetime, timezone
from typing import Any, Dict, Optional
import pytest
from fastapi_users.authentication.strategy import (
AccessTokenDatabase,
BaseAccessToken,
AccessTokenProtocol,
DatabaseStrategy,
)
from tests.conftest import UserModel
@dataclasses.dataclass
class AccessTokenModel(AccessTokenProtocol):
token: str
user_id: uuid.UUID
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
created_at: datetime = dataclasses.field(
default_factory=lambda: datetime.now(timezone.utc)
)
class AccessToken(BaseAccessToken):
pass
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
store: Dict[str, AccessToken]
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]):
store: Dict[str, AccessTokenModel]
def __init__(self):
self.access_token_model = AccessToken
self.store = {}
async def get_by_token(
self, token: str, max_age: Optional[datetime] = None
) -> Optional[AccessToken]:
) -> Optional[AccessTokenModel]:
try:
access_token = self.store[token]
if max_age is not None and access_token.created_at < max_age:
@ -33,15 +40,20 @@ class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
except KeyError:
return None
async def create(self, access_token: AccessToken) -> AccessToken:
async def create(self, create_dict: Dict[str, Any]) -> AccessTokenModel:
access_token = AccessTokenModel(**create_dict)
self.store[access_token.token] = access_token
return access_token
async def update(self, access_token: AccessToken) -> AccessToken:
async def update(
self, access_token: AccessTokenModel, update_dict: Dict[str, Any]
) -> AccessTokenModel:
for field, value in update_dict.items():
setattr(access_token, field, value)
self.store[access_token.token] = access_token
return access_token
async def delete(self, access_token: AccessToken) -> None:
async def delete(self, access_token: AccessTokenModel) -> None:
try:
del self.store[access_token.token]
except KeyError:
@ -62,14 +74,18 @@ def database_strategy(access_token_database: AccessTokenDatabaseMock):
class TestReadToken:
@pytest.mark.asyncio
async def test_missing_token(
self, database_strategy: DatabaseStrategy, user_manager
self,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
user_manager,
):
authenticated_user = await database_strategy.read_token(None, user_manager)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_invalid_token(
self, database_strategy: DatabaseStrategy, user_manager
self,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
user_manager,
):
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is None
@ -77,14 +93,15 @@ class TestReadToken:
@pytest.mark.asyncio
async def test_valid_token_not_existing_user(
self,
database_strategy: DatabaseStrategy,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock,
user_manager,
):
await access_token_database.create(
AccessToken(
token="TOKEN", user_id=uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f")
)
{
"token": "TOKEN",
"user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
}
)
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is None
@ -92,12 +109,12 @@ class TestReadToken:
@pytest.mark.asyncio
async def test_valid_token(
self,
database_strategy: DatabaseStrategy,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock,
user_manager,
user,
user: UserModel,
):
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is not None
assert authenticated_user.id == user.id
@ -106,9 +123,9 @@ class TestReadToken:
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_write_token(
database_strategy: DatabaseStrategy,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock,
user,
user: UserModel,
):
token = await database_strategy.write_token(user)
@ -120,11 +137,11 @@ async def test_write_token(
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_destroy_token(
database_strategy: DatabaseStrategy,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock,
user,
user: UserModel,
):
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
await database_strategy.destroy_token("TOKEN", user)