Revamp AccessToken DB strategy to adopt generic model approach

This commit is contained in:
François Voron
2022-04-29 15:45:14 +02:00
parent e271cc1352
commit 87ac51a7bd
7 changed files with 85 additions and 67 deletions

View File

@ -3,9 +3,9 @@ from fastapi_users.authentication.strategy.base import (
StrategyDestroyNotSupportedError, StrategyDestroyNotSupportedError,
) )
from fastapi_users.authentication.strategy.db import ( from fastapi_users.authentication.strategy.db import (
A, AP,
AccessTokenDatabase, AccessTokenDatabase,
BaseAccessToken, AccessTokenProtocol,
DatabaseStrategy, DatabaseStrategy,
) )
from fastapi_users.authentication.strategy.jwt import JWTStrategy from fastapi_users.authentication.strategy.jwt import JWTStrategy
@ -16,9 +16,9 @@ except ImportError: # pragma: no cover
pass pass
__all__ = [ __all__ = [
"A", "AP",
"AccessTokenDatabase", "AccessTokenDatabase",
"BaseAccessToken", "AccessTokenProtocol",
"DatabaseStrategy", "DatabaseStrategy",
"JWTStrategy", "JWTStrategy",
"Strategy", "Strategy",

View File

@ -1,5 +1,5 @@
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
from fastapi_users.authentication.strategy.db.models import A, BaseAccessToken from fastapi_users.authentication.strategy.db.models import AP, AccessTokenProtocol
from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy
__all__ = ["A", "AccessTokenDatabase", "BaseAccessToken", "DatabaseStrategy"] __all__ = ["AP", "AccessTokenDatabase", "AccessTokenProtocol", "DatabaseStrategy"]

View File

@ -1,38 +1,32 @@
import sys import sys
from datetime import datetime from datetime import datetime
from typing import Generic, Optional, Type from typing import Any, Dict, Generic, Optional
if sys.version_info < (3, 8): if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover from typing_extensions import Protocol # pragma: no cover
else: else:
from typing import Protocol # pragma: no cover from typing import Protocol # pragma: no cover
from fastapi_users.authentication.strategy.db.models import A from fastapi_users.authentication.strategy.db.models import AP
class AccessTokenDatabase(Protocol, Generic[A]): class AccessTokenDatabase(Protocol, Generic[AP]):
""" """Protocol for retrieving, creating and updating access tokens from a database."""
Protocol for retrieving, creating and updating access tokens from a database.
:param access_token_model: Pydantic model of an access token.
"""
access_token_model: Type[A]
async def get_by_token( async def get_by_token(
self, token: str, max_age: Optional[datetime] = None self, token: str, max_age: Optional[datetime] = None
) -> Optional[A]: ) -> Optional[AP]:
"""Get a single access token by token.""" """Get a single access token by token."""
... # pragma: no cover ... # pragma: no cover
async def create(self, access_token: A) -> A: async def create(self, create_dict: Dict[str, Any]) -> AP:
"""Create an access token.""" """Create an access token."""
... # pragma: no cover ... # pragma: no cover
async def update(self, access_token: A) -> A: async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
"""Update an access token.""" """Update an access token."""
... # pragma: no cover ... # pragma: no cover
async def delete(self, access_token: A) -> None: async def delete(self, access_token: AP) -> None:
"""Delete an access token.""" """Delete an access token."""
... # pragma: no cover ... # pragma: no cover

View File

@ -1,22 +1,23 @@
from datetime import datetime, timezone import sys
import uuid
from datetime import datetime
from typing import TypeVar from typing import TypeVar
from pydantic import UUID4, BaseModel, Field if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
else:
from typing import Protocol # pragma: no cover
def now_utc(): class AccessTokenProtocol(Protocol):
return datetime.now(timezone.utc) """Access token protocol that ORM model should follow."""
class BaseAccessToken(BaseModel):
"""Base access token model."""
token: str token: str
user_id: UUID4 user_id: uuid.UUID
created_at: datetime = Field(default_factory=now_utc) created_at: datetime
class Config: def __init__(self, *args, **kwargs) -> None:
orm_mode = True ... # pragma: no cover
A = TypeVar("A", bound=BaseAccessToken) AP = TypeVar("AP", bound=AccessTokenProtocol)

View File

@ -1,17 +1,17 @@
import secrets import secrets
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Generic, Optional from typing import Any, Dict, Generic, Optional
from fastapi_users import models from fastapi_users import models
from fastapi_users.authentication.strategy.base import Strategy from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
from fastapi_users.authentication.strategy.db.models import A from fastapi_users.authentication.strategy.db.models import AP
from fastapi_users.manager import BaseUserManager, UserNotExists from fastapi_users.manager import BaseUserManager, UserNotExists
class DatabaseStrategy(Strategy, Generic[models.UP, A]): class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
def __init__( def __init__(
self, database: AccessTokenDatabase[A], lifetime_seconds: Optional[int] = None self, database: AccessTokenDatabase[AP], lifetime_seconds: Optional[int] = None
): ):
self.database = database self.database = database
self.lifetime_seconds = lifetime_seconds self.lifetime_seconds = lifetime_seconds
@ -39,8 +39,8 @@ class DatabaseStrategy(Strategy, Generic[models.UP, A]):
return None return None
async def write_token(self, user: models.UP) -> str: async def write_token(self, user: models.UP) -> str:
access_token = self._create_access_token(user) access_token_dict = self._create_access_token_dict(user)
await self.database.create(access_token) access_token = await self.database.create(access_token_dict)
return access_token.token return access_token.token
async def destroy_token(self, token: str, user: models.UP) -> None: async def destroy_token(self, token: str, user: models.UP) -> None:
@ -48,6 +48,6 @@ class DatabaseStrategy(Strategy, Generic[models.UP, A]):
if access_token is not None: if access_token is not None:
await self.database.delete(access_token) await self.database.delete(access_token)
def _create_access_token(self, user: models.UP) -> A: def _create_access_token_dict(self, user: models.UP) -> Dict[str, Any]:
token = secrets.token_urlsafe() token = secrets.token_urlsafe()
return self.database.access_token_model(token=token, user_id=user.id) return {"token": token, "user_id": user.id}

View File

@ -9,6 +9,8 @@ else:
class UserProtocol(Protocol): class UserProtocol(Protocol):
"""User protocol that ORM model should follow."""
id: uuid.UUID id: uuid.UUID
email: str email: str
hashed_password: str hashed_password: str
@ -21,6 +23,8 @@ class UserProtocol(Protocol):
class OAuthAccountProtocol(Protocol): class OAuthAccountProtocol(Protocol):
"""OAuth account protocol that ORM model should follow."""
id: uuid.UUID id: uuid.UUID
oauth_name: str oauth_name: str
access_token: str access_token: str
@ -38,6 +42,8 @@ OAP = TypeVar("OAP", bound=OAuthAccountProtocol)
class UserOAuthProtocol(UserProtocol, Generic[OAP]): class UserOAuthProtocol(UserProtocol, Generic[OAP]):
"""User protocol including a list of OAuth accounts."""
oauth_accounts: List[OAP] oauth_accounts: List[OAP]

View File

@ -1,30 +1,37 @@
import dataclasses
import uuid import uuid
from datetime import datetime from datetime import datetime, timezone
from typing import Dict, Optional from typing import Any, Dict, Optional
import pytest import pytest
from fastapi_users.authentication.strategy import ( from fastapi_users.authentication.strategy import (
AccessTokenDatabase, AccessTokenDatabase,
BaseAccessToken, AccessTokenProtocol,
DatabaseStrategy, DatabaseStrategy,
) )
from tests.conftest import UserModel
@dataclasses.dataclass
class AccessTokenModel(AccessTokenProtocol):
token: str
user_id: uuid.UUID
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
created_at: datetime = dataclasses.field(
default_factory=lambda: datetime.now(timezone.utc)
)
class AccessToken(BaseAccessToken): class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]):
pass store: Dict[str, AccessTokenModel]
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
store: Dict[str, AccessToken]
def __init__(self): def __init__(self):
self.access_token_model = AccessToken
self.store = {} self.store = {}
async def get_by_token( async def get_by_token(
self, token: str, max_age: Optional[datetime] = None self, token: str, max_age: Optional[datetime] = None
) -> Optional[AccessToken]: ) -> Optional[AccessTokenModel]:
try: try:
access_token = self.store[token] access_token = self.store[token]
if max_age is not None and access_token.created_at < max_age: if max_age is not None and access_token.created_at < max_age:
@ -33,15 +40,20 @@ class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
except KeyError: except KeyError:
return None return None
async def create(self, access_token: AccessToken) -> AccessToken: async def create(self, create_dict: Dict[str, Any]) -> AccessTokenModel:
access_token = AccessTokenModel(**create_dict)
self.store[access_token.token] = access_token self.store[access_token.token] = access_token
return access_token return access_token
async def update(self, access_token: AccessToken) -> AccessToken: async def update(
self, access_token: AccessTokenModel, update_dict: Dict[str, Any]
) -> AccessTokenModel:
for field, value in update_dict.items():
setattr(access_token, field, value)
self.store[access_token.token] = access_token self.store[access_token.token] = access_token
return access_token return access_token
async def delete(self, access_token: AccessToken) -> None: async def delete(self, access_token: AccessTokenModel) -> None:
try: try:
del self.store[access_token.token] del self.store[access_token.token]
except KeyError: except KeyError:
@ -62,14 +74,18 @@ def database_strategy(access_token_database: AccessTokenDatabaseMock):
class TestReadToken: class TestReadToken:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_missing_token( async def test_missing_token(
self, database_strategy: DatabaseStrategy, user_manager self,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
user_manager,
): ):
authenticated_user = await database_strategy.read_token(None, user_manager) authenticated_user = await database_strategy.read_token(None, user_manager)
assert authenticated_user is None assert authenticated_user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_token( async def test_invalid_token(
self, database_strategy: DatabaseStrategy, user_manager self,
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
user_manager,
): ):
authenticated_user = await database_strategy.read_token("TOKEN", user_manager) authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is None assert authenticated_user is None
@ -77,14 +93,15 @@ class TestReadToken:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_token_not_existing_user( async def test_valid_token_not_existing_user(
self, self,
database_strategy: DatabaseStrategy, database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock, access_token_database: AccessTokenDatabaseMock,
user_manager, user_manager,
): ):
await access_token_database.create( await access_token_database.create(
AccessToken( {
token="TOKEN", user_id=uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f") "token": "TOKEN",
) "user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
}
) )
authenticated_user = await database_strategy.read_token("TOKEN", user_manager) authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is None assert authenticated_user is None
@ -92,12 +109,12 @@ class TestReadToken:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_token( async def test_valid_token(
self, self,
database_strategy: DatabaseStrategy, database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock, access_token_database: AccessTokenDatabaseMock,
user_manager, user_manager,
user, user: UserModel,
): ):
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id)) await access_token_database.create({"token": "TOKEN", "user_id": user.id})
authenticated_user = await database_strategy.read_token("TOKEN", user_manager) authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
assert authenticated_user is not None assert authenticated_user is not None
assert authenticated_user.id == user.id assert authenticated_user.id == user.id
@ -106,9 +123,9 @@ class TestReadToken:
@pytest.mark.authentication @pytest.mark.authentication
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_write_token( async def test_write_token(
database_strategy: DatabaseStrategy, database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock, access_token_database: AccessTokenDatabaseMock,
user, user: UserModel,
): ):
token = await database_strategy.write_token(user) token = await database_strategy.write_token(user)
@ -120,11 +137,11 @@ async def test_write_token(
@pytest.mark.authentication @pytest.mark.authentication
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_destroy_token( async def test_destroy_token(
database_strategy: DatabaseStrategy, database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
access_token_database: AccessTokenDatabaseMock, access_token_database: AccessTokenDatabaseMock,
user, user: UserModel,
): ):
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id)) await access_token_database.create({"token": "TOKEN", "user_id": user.id})
await database_strategy.destroy_token("TOKEN", user) await database_strategy.destroy_token("TOKEN", user)