mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-26 12:31:25 +08:00
Revamp AccessToken DB strategy to adopt generic model approach
This commit is contained in:
@ -3,9 +3,9 @@ from fastapi_users.authentication.strategy.base import (
|
|||||||
StrategyDestroyNotSupportedError,
|
StrategyDestroyNotSupportedError,
|
||||||
)
|
)
|
||||||
from fastapi_users.authentication.strategy.db import (
|
from fastapi_users.authentication.strategy.db import (
|
||||||
A,
|
AP,
|
||||||
AccessTokenDatabase,
|
AccessTokenDatabase,
|
||||||
BaseAccessToken,
|
AccessTokenProtocol,
|
||||||
DatabaseStrategy,
|
DatabaseStrategy,
|
||||||
)
|
)
|
||||||
from fastapi_users.authentication.strategy.jwt import JWTStrategy
|
from fastapi_users.authentication.strategy.jwt import JWTStrategy
|
||||||
@ -16,9 +16,9 @@ except ImportError: # pragma: no cover
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"A",
|
"AP",
|
||||||
"AccessTokenDatabase",
|
"AccessTokenDatabase",
|
||||||
"BaseAccessToken",
|
"AccessTokenProtocol",
|
||||||
"DatabaseStrategy",
|
"DatabaseStrategy",
|
||||||
"JWTStrategy",
|
"JWTStrategy",
|
||||||
"Strategy",
|
"Strategy",
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
|
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
|
||||||
from fastapi_users.authentication.strategy.db.models import A, BaseAccessToken
|
from fastapi_users.authentication.strategy.db.models import AP, AccessTokenProtocol
|
||||||
from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy
|
from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy
|
||||||
|
|
||||||
__all__ = ["A", "AccessTokenDatabase", "BaseAccessToken", "DatabaseStrategy"]
|
__all__ = ["AP", "AccessTokenDatabase", "AccessTokenProtocol", "DatabaseStrategy"]
|
||||||
|
@ -1,38 +1,32 @@
|
|||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Generic, Optional, Type
|
from typing import Any, Dict, Generic, Optional
|
||||||
|
|
||||||
if sys.version_info < (3, 8):
|
if sys.version_info < (3, 8):
|
||||||
from typing_extensions import Protocol # pragma: no cover
|
from typing_extensions import Protocol # pragma: no cover
|
||||||
else:
|
else:
|
||||||
from typing import Protocol # pragma: no cover
|
from typing import Protocol # pragma: no cover
|
||||||
|
|
||||||
from fastapi_users.authentication.strategy.db.models import A
|
from fastapi_users.authentication.strategy.db.models import AP
|
||||||
|
|
||||||
|
|
||||||
class AccessTokenDatabase(Protocol, Generic[A]):
|
class AccessTokenDatabase(Protocol, Generic[AP]):
|
||||||
"""
|
"""Protocol for retrieving, creating and updating access tokens from a database."""
|
||||||
Protocol for retrieving, creating and updating access tokens from a database.
|
|
||||||
|
|
||||||
:param access_token_model: Pydantic model of an access token.
|
|
||||||
"""
|
|
||||||
|
|
||||||
access_token_model: Type[A]
|
|
||||||
|
|
||||||
async def get_by_token(
|
async def get_by_token(
|
||||||
self, token: str, max_age: Optional[datetime] = None
|
self, token: str, max_age: Optional[datetime] = None
|
||||||
) -> Optional[A]:
|
) -> Optional[AP]:
|
||||||
"""Get a single access token by token."""
|
"""Get a single access token by token."""
|
||||||
... # pragma: no cover
|
... # pragma: no cover
|
||||||
|
|
||||||
async def create(self, access_token: A) -> A:
|
async def create(self, create_dict: Dict[str, Any]) -> AP:
|
||||||
"""Create an access token."""
|
"""Create an access token."""
|
||||||
... # pragma: no cover
|
... # pragma: no cover
|
||||||
|
|
||||||
async def update(self, access_token: A) -> A:
|
async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
|
||||||
"""Update an access token."""
|
"""Update an access token."""
|
||||||
... # pragma: no cover
|
... # pragma: no cover
|
||||||
|
|
||||||
async def delete(self, access_token: A) -> None:
|
async def delete(self, access_token: AP) -> None:
|
||||||
"""Delete an access token."""
|
"""Delete an access token."""
|
||||||
... # pragma: no cover
|
... # pragma: no cover
|
||||||
|
@ -1,22 +1,23 @@
|
|||||||
from datetime import datetime, timezone
|
import sys
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from pydantic import UUID4, BaseModel, Field
|
if sys.version_info < (3, 8):
|
||||||
|
from typing_extensions import Protocol # pragma: no cover
|
||||||
|
else:
|
||||||
|
from typing import Protocol # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
def now_utc():
|
class AccessTokenProtocol(Protocol):
|
||||||
return datetime.now(timezone.utc)
|
"""Access token protocol that ORM model should follow."""
|
||||||
|
|
||||||
|
|
||||||
class BaseAccessToken(BaseModel):
|
|
||||||
"""Base access token model."""
|
|
||||||
|
|
||||||
token: str
|
token: str
|
||||||
user_id: UUID4
|
user_id: uuid.UUID
|
||||||
created_at: datetime = Field(default_factory=now_utc)
|
created_at: datetime
|
||||||
|
|
||||||
class Config:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
orm_mode = True
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
A = TypeVar("A", bound=BaseAccessToken)
|
AP = TypeVar("AP", bound=AccessTokenProtocol)
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Generic, Optional
|
from typing import Any, Dict, Generic, Optional
|
||||||
|
|
||||||
from fastapi_users import models
|
from fastapi_users import models
|
||||||
from fastapi_users.authentication.strategy.base import Strategy
|
from fastapi_users.authentication.strategy.base import Strategy
|
||||||
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
|
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
|
||||||
from fastapi_users.authentication.strategy.db.models import A
|
from fastapi_users.authentication.strategy.db.models import AP
|
||||||
from fastapi_users.manager import BaseUserManager, UserNotExists
|
from fastapi_users.manager import BaseUserManager, UserNotExists
|
||||||
|
|
||||||
|
|
||||||
class DatabaseStrategy(Strategy, Generic[models.UP, A]):
|
class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, database: AccessTokenDatabase[A], lifetime_seconds: Optional[int] = None
|
self, database: AccessTokenDatabase[AP], lifetime_seconds: Optional[int] = None
|
||||||
):
|
):
|
||||||
self.database = database
|
self.database = database
|
||||||
self.lifetime_seconds = lifetime_seconds
|
self.lifetime_seconds = lifetime_seconds
|
||||||
@ -39,8 +39,8 @@ class DatabaseStrategy(Strategy, Generic[models.UP, A]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def write_token(self, user: models.UP) -> str:
|
async def write_token(self, user: models.UP) -> str:
|
||||||
access_token = self._create_access_token(user)
|
access_token_dict = self._create_access_token_dict(user)
|
||||||
await self.database.create(access_token)
|
access_token = await self.database.create(access_token_dict)
|
||||||
return access_token.token
|
return access_token.token
|
||||||
|
|
||||||
async def destroy_token(self, token: str, user: models.UP) -> None:
|
async def destroy_token(self, token: str, user: models.UP) -> None:
|
||||||
@ -48,6 +48,6 @@ class DatabaseStrategy(Strategy, Generic[models.UP, A]):
|
|||||||
if access_token is not None:
|
if access_token is not None:
|
||||||
await self.database.delete(access_token)
|
await self.database.delete(access_token)
|
||||||
|
|
||||||
def _create_access_token(self, user: models.UP) -> A:
|
def _create_access_token_dict(self, user: models.UP) -> Dict[str, Any]:
|
||||||
token = secrets.token_urlsafe()
|
token = secrets.token_urlsafe()
|
||||||
return self.database.access_token_model(token=token, user_id=user.id)
|
return {"token": token, "user_id": user.id}
|
||||||
|
@ -9,6 +9,8 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
class UserProtocol(Protocol):
|
class UserProtocol(Protocol):
|
||||||
|
"""User protocol that ORM model should follow."""
|
||||||
|
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
email: str
|
email: str
|
||||||
hashed_password: str
|
hashed_password: str
|
||||||
@ -21,6 +23,8 @@ class UserProtocol(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class OAuthAccountProtocol(Protocol):
|
class OAuthAccountProtocol(Protocol):
|
||||||
|
"""OAuth account protocol that ORM model should follow."""
|
||||||
|
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
oauth_name: str
|
oauth_name: str
|
||||||
access_token: str
|
access_token: str
|
||||||
@ -38,6 +42,8 @@ OAP = TypeVar("OAP", bound=OAuthAccountProtocol)
|
|||||||
|
|
||||||
|
|
||||||
class UserOAuthProtocol(UserProtocol, Generic[OAP]):
|
class UserOAuthProtocol(UserProtocol, Generic[OAP]):
|
||||||
|
"""User protocol including a list of OAuth accounts."""
|
||||||
|
|
||||||
oauth_accounts: List[OAP]
|
oauth_accounts: List[OAP]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,30 +1,37 @@
|
|||||||
|
import dataclasses
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from typing import Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from fastapi_users.authentication.strategy import (
|
from fastapi_users.authentication.strategy import (
|
||||||
AccessTokenDatabase,
|
AccessTokenDatabase,
|
||||||
BaseAccessToken,
|
AccessTokenProtocol,
|
||||||
DatabaseStrategy,
|
DatabaseStrategy,
|
||||||
)
|
)
|
||||||
|
from tests.conftest import UserModel
|
||||||
|
|
||||||
|
|
||||||
class AccessToken(BaseAccessToken):
|
@dataclasses.dataclass
|
||||||
pass
|
class AccessTokenModel(AccessTokenProtocol):
|
||||||
|
token: str
|
||||||
|
user_id: uuid.UUID
|
||||||
|
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
|
||||||
|
created_at: datetime = dataclasses.field(
|
||||||
|
default_factory=lambda: datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
|
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]):
|
||||||
store: Dict[str, AccessToken]
|
store: Dict[str, AccessTokenModel]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.access_token_model = AccessToken
|
|
||||||
self.store = {}
|
self.store = {}
|
||||||
|
|
||||||
async def get_by_token(
|
async def get_by_token(
|
||||||
self, token: str, max_age: Optional[datetime] = None
|
self, token: str, max_age: Optional[datetime] = None
|
||||||
) -> Optional[AccessToken]:
|
) -> Optional[AccessTokenModel]:
|
||||||
try:
|
try:
|
||||||
access_token = self.store[token]
|
access_token = self.store[token]
|
||||||
if max_age is not None and access_token.created_at < max_age:
|
if max_age is not None and access_token.created_at < max_age:
|
||||||
@ -33,15 +40,20 @@ class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def create(self, access_token: AccessToken) -> AccessToken:
|
async def create(self, create_dict: Dict[str, Any]) -> AccessTokenModel:
|
||||||
|
access_token = AccessTokenModel(**create_dict)
|
||||||
self.store[access_token.token] = access_token
|
self.store[access_token.token] = access_token
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
async def update(self, access_token: AccessToken) -> AccessToken:
|
async def update(
|
||||||
|
self, access_token: AccessTokenModel, update_dict: Dict[str, Any]
|
||||||
|
) -> AccessTokenModel:
|
||||||
|
for field, value in update_dict.items():
|
||||||
|
setattr(access_token, field, value)
|
||||||
self.store[access_token.token] = access_token
|
self.store[access_token.token] = access_token
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
async def delete(self, access_token: AccessToken) -> None:
|
async def delete(self, access_token: AccessTokenModel) -> None:
|
||||||
try:
|
try:
|
||||||
del self.store[access_token.token]
|
del self.store[access_token.token]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -62,14 +74,18 @@ def database_strategy(access_token_database: AccessTokenDatabaseMock):
|
|||||||
class TestReadToken:
|
class TestReadToken:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_token(
|
async def test_missing_token(
|
||||||
self, database_strategy: DatabaseStrategy, user_manager
|
self,
|
||||||
|
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||||
|
user_manager,
|
||||||
):
|
):
|
||||||
authenticated_user = await database_strategy.read_token(None, user_manager)
|
authenticated_user = await database_strategy.read_token(None, user_manager)
|
||||||
assert authenticated_user is None
|
assert authenticated_user is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_token(
|
async def test_invalid_token(
|
||||||
self, database_strategy: DatabaseStrategy, user_manager
|
self,
|
||||||
|
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||||
|
user_manager,
|
||||||
):
|
):
|
||||||
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
||||||
assert authenticated_user is None
|
assert authenticated_user is None
|
||||||
@ -77,14 +93,15 @@ class TestReadToken:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_token_not_existing_user(
|
async def test_valid_token_not_existing_user(
|
||||||
self,
|
self,
|
||||||
database_strategy: DatabaseStrategy,
|
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||||
access_token_database: AccessTokenDatabaseMock,
|
access_token_database: AccessTokenDatabaseMock,
|
||||||
user_manager,
|
user_manager,
|
||||||
):
|
):
|
||||||
await access_token_database.create(
|
await access_token_database.create(
|
||||||
AccessToken(
|
{
|
||||||
token="TOKEN", user_id=uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f")
|
"token": "TOKEN",
|
||||||
)
|
"user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
||||||
assert authenticated_user is None
|
assert authenticated_user is None
|
||||||
@ -92,12 +109,12 @@ class TestReadToken:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_token(
|
async def test_valid_token(
|
||||||
self,
|
self,
|
||||||
database_strategy: DatabaseStrategy,
|
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||||
access_token_database: AccessTokenDatabaseMock,
|
access_token_database: AccessTokenDatabaseMock,
|
||||||
user_manager,
|
user_manager,
|
||||||
user,
|
user: UserModel,
|
||||||
):
|
):
|
||||||
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
|
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
|
||||||
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
|
||||||
assert authenticated_user is not None
|
assert authenticated_user is not None
|
||||||
assert authenticated_user.id == user.id
|
assert authenticated_user.id == user.id
|
||||||
@ -106,9 +123,9 @@ class TestReadToken:
|
|||||||
@pytest.mark.authentication
|
@pytest.mark.authentication
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_write_token(
|
async def test_write_token(
|
||||||
database_strategy: DatabaseStrategy,
|
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||||
access_token_database: AccessTokenDatabaseMock,
|
access_token_database: AccessTokenDatabaseMock,
|
||||||
user,
|
user: UserModel,
|
||||||
):
|
):
|
||||||
token = await database_strategy.write_token(user)
|
token = await database_strategy.write_token(user)
|
||||||
|
|
||||||
@ -120,11 +137,11 @@ async def test_write_token(
|
|||||||
@pytest.mark.authentication
|
@pytest.mark.authentication
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_destroy_token(
|
async def test_destroy_token(
|
||||||
database_strategy: DatabaseStrategy,
|
database_strategy: DatabaseStrategy[UserModel, AccessTokenModel],
|
||||||
access_token_database: AccessTokenDatabaseMock,
|
access_token_database: AccessTokenDatabaseMock,
|
||||||
user,
|
user: UserModel,
|
||||||
):
|
):
|
||||||
await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id))
|
await access_token_database.create({"token": "TOKEN", "user_id": user.id})
|
||||||
|
|
||||||
await database_strategy.destroy_token("TOKEN", user)
|
await database_strategy.destroy_token("TOKEN", user)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user