Make ID a generic instead of forcing UUIDs

This commit is contained in:
François Voron
2022-05-01 11:18:27 +02:00
parent 87ac51a7bd
commit 7093c9e38a
25 changed files with 143 additions and 123 deletions

View File

@ -51,7 +51,7 @@ class Authenticator:
def __init__(
self,
backends: Sequence[AuthenticationBackend],
get_user_manager: UserManagerDependency[models.UP],
get_user_manager: UserManagerDependency[models.UP, models.ID],
):
self.backends = backends
self.get_user_manager = get_user_manager
@ -148,7 +148,7 @@ class Authenticator:
async def _authenticate(
self,
*args,
user_manager: BaseUserManager[models.UP],
user_manager: BaseUserManager[models.UP, models.ID],
optional: bool = False,
active: bool = False,
verified: bool = False,
@ -163,7 +163,7 @@ class Authenticator:
for backend in self.backends:
if backend in enabled_backends:
token = kwargs[name_to_variable_name(backend.name)]
strategy: Strategy[models.UP] = kwargs[
strategy: Strategy[models.UP, models.ID] = kwargs[
name_to_strategy_variable_name(backend.name)
]
if token is not None:

View File

@ -33,7 +33,7 @@ class AuthenticationBackend(Generic[models.UP]):
self,
name: str,
transport: Transport,
get_strategy: DependencyCallable[Strategy[models.UP]],
get_strategy: DependencyCallable[Strategy[models.UP, models.ID]],
):
self.name = name
self.transport = transport
@ -41,7 +41,7 @@ class AuthenticationBackend(Generic[models.UP]):
async def login(
self,
strategy: Strategy[models.UP],
strategy: Strategy[models.UP, models.ID],
user: models.UP,
response: Response,
) -> Any:
@ -50,7 +50,7 @@ class AuthenticationBackend(Generic[models.UP]):
async def logout(
self,
strategy: Strategy[models.UP],
strategy: Strategy[models.UP, models.ID],
user: models.UP,
token: str,
response: Response,

View File

@ -14,9 +14,9 @@ class StrategyDestroyNotSupportedError(Exception):
pass
class Strategy(Protocol, Generic[models.UP]):
class Strategy(Protocol, Generic[models.UP, models.ID]):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
... # pragma: no cover

View File

@ -1,5 +1,4 @@
import sys
import uuid
from datetime import datetime
from typing import TypeVar
@ -8,12 +7,14 @@ if sys.version_info < (3, 8):
else:
from typing import Protocol # pragma: no cover
from fastapi_users import models
class AccessTokenProtocol(Protocol):
class AccessTokenProtocol(Protocol[models.ID]):
"""Access token protocol that ORM model should follow."""
token: str
user_id: uuid.UUID
user_id: models.ID
created_at: datetime
def __init__(self, *args, **kwargs) -> None:

View File

@ -6,7 +6,7 @@ from fastapi_users import models
from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
from fastapi_users.authentication.strategy.db.models import AP
from fastapi_users.manager import BaseUserManager, UserNotExists
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
@ -17,7 +17,7 @@ class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
self.lifetime_seconds = lifetime_seconds
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
if token is None:
return None
@ -33,9 +33,9 @@ class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
return None
try:
user_id = access_token.user_id
return await user_manager.get(user_id)
except UserNotExists:
parsed_id = user_manager.parse_id(access_token.user_id)
return await user_manager.get(parsed_id)
except (UserNotExists, InvalidID):
return None
async def write_token(self, user: models.UP) -> str:

View File

@ -1,7 +1,6 @@
from typing import Generic, List, Optional
import jwt
from pydantic import UUID4
from fastapi_users import models
from fastapi_users.authentication.strategy.base import (
@ -9,7 +8,7 @@ from fastapi_users.authentication.strategy.base import (
StrategyDestroyNotSupportedError,
)
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.manager import BaseUserManager, UserNotExists
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
class JWTStrategy(Strategy, Generic[models.UP]):
@ -36,7 +35,7 @@ class JWTStrategy(Strategy, Generic[models.UP]):
return self.public_key or self.secret
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
if token is None:
return None
@ -52,11 +51,9 @@ class JWTStrategy(Strategy, Generic[models.UP]):
return None
try:
user_uiid = UUID4(user_id)
return await user_manager.get(user_uiid)
except ValueError:
return None
except UserNotExists:
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (UserNotExists, InvalidID):
return None
async def write_token(self, user: models.UP) -> str:

View File

@ -2,11 +2,10 @@ import secrets
from typing import Generic, Optional
import aioredis
from pydantic import UUID4
from fastapi_users import models
from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users.manager import BaseUserManager, UserNotExists
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
class RedisStrategy(Strategy, Generic[models.UP]):
@ -15,7 +14,7 @@ class RedisStrategy(Strategy, Generic[models.UP]):
self.lifetime_seconds = lifetime_seconds
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
if token is None:
return None
@ -25,11 +24,9 @@ class RedisStrategy(Strategy, Generic[models.UP]):
return None
try:
user_uiid = UUID4(user_id)
return await user_manager.get(user_uiid)
except ValueError:
return None
except UserNotExists:
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (UserNotExists, InvalidID):
return None
async def write_token(self, user: models.UP) -> str:

View File

@ -1,9 +1,6 @@
from fastapi_users.db.base import BaseUserDatabase, UserDatabaseDependency
__all__ = [
"BaseUserDatabase",
"UserDatabaseDependency",
]
__all__ = ["BaseUserDatabase", "UserDatabaseDependency"]
try: # pragma: no cover
from fastapi_users_db_mongodb import MongoDBUserDatabase # noqa: F401

View File

@ -1,15 +1,13 @@
from typing import Any, Dict, Generic, Optional
from pydantic import UUID4
from fastapi_users.models import OAP, UP
from fastapi_users.models import ID, OAP, UP
from fastapi_users.types import DependencyCallable
class BaseUserDatabase(Generic[UP]):
class BaseUserDatabase(Generic[UP, ID]):
"""Base adapter for retrieving, creating and updating users from a database."""
async def get(self, id: UUID4) -> Optional[UP]:
async def get(self, id: ID) -> Optional[UP]:
"""Get a single user by id."""
raise NotImplementedError()
@ -44,4 +42,4 @@ class BaseUserDatabase(Generic[UP]):
raise NotImplementedError()
UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP]]
UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP, ID]]

View File

@ -44,7 +44,7 @@ class FastAPIUsers(Generic[models.UP, schemas.U, schemas.UC, schemas.UU]):
def __init__(
self,
get_user_manager: UserManagerDependency[models.UP],
get_user_manager: UserManagerDependency[models.UP, models.ID],
auth_backends: Sequence[AuthenticationBackend],
user_schema: Type[schemas.U],
user_create_schema: Type[schemas.UC],

View File

@ -1,9 +1,9 @@
import uuid
from typing import Any, Dict, Generic, Optional, Union
import jwt
from fastapi import Request
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import UUID4
from fastapi_users import models, schemas
from fastapi_users.db import BaseUserDatabase
@ -19,6 +19,10 @@ class FastAPIUsersException(Exception):
pass
class InvalidID(FastAPIUsersException):
pass
class UserAlreadyExists(FastAPIUsersException):
pass
@ -48,7 +52,7 @@ class InvalidPasswordException(FastAPIUsersException):
self.reason = reason
class BaseUserManager(Generic[models.UP]):
class BaseUserManager(Generic[models.UP, models.ID]):
"""
User management logic.
@ -70,12 +74,12 @@ class BaseUserManager(Generic[models.UP]):
verification_token_lifetime_seconds: int = 3600
verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE
user_db: BaseUserDatabase[models.UP]
user_db: BaseUserDatabase[models.UP, models.ID]
password_helper: PasswordHelperProtocol
def __init__(
self,
user_db: BaseUserDatabase[models.UP],
user_db: BaseUserDatabase[models.UP, models.ID],
password_helper: Optional[PasswordHelperProtocol] = None,
):
self.user_db = user_db
@ -84,7 +88,17 @@ class BaseUserManager(Generic[models.UP]):
else:
self.password_helper = password_helper # pragma: no cover
async def get(self, id: UUID4) -> models.UP:
def parse_id(self, value: Any) -> models.ID:
"""
Parse a value into a correct models.ID instance.
:param value: The value to parse.
:raises InvalidID: The models.ID value is invalid.
:return: An models.ID object.
"""
raise NotImplementedError() # pragma: no cover
async def get(self, id: models.ID) -> models.UP:
"""
Get a user by id.
@ -170,7 +184,7 @@ class BaseUserManager(Generic[models.UP]):
return created_user
async def oauth_callback(
self: "BaseUserManager[models.UOAP]",
self: "BaseUserManager[models.UOAP, models.ID]",
oauth_name: str,
access_token: str,
account_id: str,
@ -192,7 +206,7 @@ class BaseUserManager(Generic[models.UP]):
:param oauth_name: Name of the OAuth client.
:param access_token: Valid access token for the service provider.
:param account_id: ID of the user on the service provider.
:param account_id: models.ID of the user on the service provider.
:param account_email: E-mail of the user on the service provider.
:param expires_at: Optional timestamp at which the access token expires.
:param refresh_token: Optional refresh token to get a
@ -307,11 +321,11 @@ class BaseUserManager(Generic[models.UP]):
raise InvalidVerifyToken()
try:
user_uuid = UUID4(user_id)
except ValueError:
parsed_id = self.parse_id(user_id)
except InvalidID:
raise InvalidVerifyToken()
if user_uuid != user.id:
if parsed_id != user.id:
raise InvalidVerifyToken()
if user.is_verified:
@ -382,11 +396,11 @@ class BaseUserManager(Generic[models.UP]):
raise InvalidResetPasswordToken()
try:
user_uuid = UUID4(user_id)
except ValueError:
parsed_id = self.parse_id(user_id)
except InvalidID:
raise InvalidResetPasswordToken()
user = await self.get(user_uuid)
user = await self.get(parsed_id)
if not user.is_active:
raise UserInactive()
@ -588,4 +602,14 @@ class BaseUserManager(Generic[models.UP]):
return await self.user_db.update(user, validated_update_dict)
UserManagerDependency = DependencyCallable[BaseUserManager[models.UP]]
class UUIDIDMixin:
def parse_id(self, value: Any) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
try:
return uuid.UUID(value)
except ValueError as e:
raise InvalidID() from e
UserManagerDependency = DependencyCallable[BaseUserManager[models.UP, models.ID]]

View File

@ -1,5 +1,4 @@
import sys
import uuid
from typing import Generic, List, Optional, TypeVar
if sys.version_info < (3, 8):
@ -7,11 +6,13 @@ if sys.version_info < (3, 8):
else:
from typing import Protocol # pragma: no cover
ID = TypeVar("ID")
class UserProtocol(Protocol):
class UserProtocol(Protocol[ID]):
"""User protocol that ORM model should follow."""
id: uuid.UUID
id: ID
email: str
hashed_password: str
is_active: bool
@ -22,10 +23,10 @@ class UserProtocol(Protocol):
... # pragma: no cover
class OAuthAccountProtocol(Protocol):
class OAuthAccountProtocol(Protocol[ID]):
"""OAuth account protocol that ORM model should follow."""
id: uuid.UUID
id: ID
oauth_name: str
access_token: str
expires_at: Optional[int]
@ -41,7 +42,7 @@ UP = TypeVar("UP", bound=UserProtocol)
OAP = TypeVar("OAP", bound=OAuthAccountProtocol)
class UserOAuthProtocol(UserProtocol, Generic[OAP]):
class UserOAuthProtocol(UserProtocol[ID], Generic[ID, OAP]):
"""User protocol including a list of OAuth accounts."""
oauth_accounts: List[OAP]

View File

@ -12,7 +12,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
def get_auth_router(
backend: AuthenticationBackend,
get_user_manager: UserManagerDependency[models.UP],
get_user_manager: UserManagerDependency[models.UP, models.ID],
authenticator: Authenticator,
requires_verification: bool = False,
) -> APIRouter:
@ -51,8 +51,8 @@ def get_auth_router(
async def login(
response: Response,
credentials: OAuth2PasswordRequestForm = Depends(),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
strategy: Strategy[models.UP] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
):
user = await user_manager.authenticate(credentials)
@ -83,7 +83,7 @@ def get_auth_router(
async def logout(
response: Response,
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP] = Depends(backend.get_strategy),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
):
user, token = user_token
return await backend.logout(strategy, user, token, response)

View File

@ -29,7 +29,7 @@ def generate_state_token(
def get_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
get_user_manager: UserManagerDependency[models.UP],
get_user_manager: UserManagerDependency[models.UP, models.ID],
state_secret: SecretType,
redirect_url: str = None,
) -> APIRouter:
@ -101,8 +101,8 @@ def get_oauth_router(
access_token_state: Tuple[OAuth2Token, str] = Depends(
oauth2_authorize_callback
),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
strategy: Strategy[models.UP] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
):
token, state = access_token_state
account_id, account_email = await oauth_client.get_id_email(

View File

@ -13,7 +13,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
def get_register_router(
get_user_manager: UserManagerDependency[models.UP],
get_user_manager: UserManagerDependency[models.UP, models.ID],
user_schema: Type[schemas.U],
user_create_schema: Type[schemas.UC],
) -> APIRouter:
@ -56,7 +56,7 @@ def get_register_router(
async def register(
request: Request,
user_create: user_create_schema, # type: ignore
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
created_user = await user_manager.create(

View File

@ -40,7 +40,7 @@ RESET_PASSWORD_RESPONSES: OpenAPIResponseType = {
def get_reset_password_router(
get_user_manager: UserManagerDependency[models.UP],
get_user_manager: UserManagerDependency[models.UP, models.ID],
) -> APIRouter:
"""Generate a router with the reset password routes."""
router = APIRouter()
@ -53,7 +53,7 @@ def get_reset_password_router(
async def forgot_password(
request: Request,
email: EmailStr = Body(..., embed=True),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
user = await user_manager.get_by_email(email)
@ -76,7 +76,7 @@ def get_reset_password_router(
request: Request,
token: str = Body(...),
password: str = Body(...),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
await user_manager.reset_password(token, password, request)

View File

@ -1,12 +1,12 @@
from typing import Type
from typing import Any, Type
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import UUID4
from fastapi_users import models, schemas
from fastapi_users.authentication import Authenticator
from fastapi_users.manager import (
BaseUserManager,
InvalidID,
InvalidPasswordException,
UserAlreadyExists,
UserManagerDependency,
@ -16,7 +16,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
def get_users_router(
get_user_manager: UserManagerDependency[models.UP],
get_user_manager: UserManagerDependency[models.UP, models.ID],
user_schema: Type[schemas.U],
user_update_schema: Type[schemas.UU],
authenticator: Authenticator,
@ -33,13 +33,14 @@ def get_users_router(
)
async def get_user_or_404(
id: UUID4,
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
id: Any,
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
) -> models.UP:
try:
return await user_manager.get(id)
except UserNotExists:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
parsed_id = user_manager.parse_id(id)
return await user_manager.get(parsed_id)
except (UserNotExists, InvalidID) as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) from e
@router.get(
"/me",
@ -96,7 +97,7 @@ def get_users_router(
request: Request,
user_update: user_update_schema, # type: ignore
user: models.UP = Depends(get_current_active_user),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
return await user_manager.update(
@ -117,7 +118,7 @@ def get_users_router(
)
@router.get(
"/{id:uuid}",
"/{id}",
response_model=user_schema,
dependencies=[Depends(get_current_superuser)],
name="users:user",
@ -137,7 +138,7 @@ def get_users_router(
return user
@router.patch(
"/{id:uuid}",
"/{id}",
response_model=user_schema,
dependencies=[Depends(get_current_superuser)],
name="users:patch_user",
@ -182,7 +183,7 @@ def get_users_router(
user_update: user_update_schema, # type: ignore
request: Request,
user=Depends(get_user_or_404),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
return await user_manager.update(
@ -203,7 +204,7 @@ def get_users_router(
)
@router.delete(
"/{id:uuid}",
"/{id}",
status_code=status.HTTP_204_NO_CONTENT,
response_class=Response,
dependencies=[Depends(get_current_superuser)],
@ -222,7 +223,7 @@ def get_users_router(
)
async def delete_user(
user=Depends(get_user_or_404),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
await user_manager.delete(user)
return None

View File

@ -16,7 +16,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
def get_verify_router(
get_user_manager: UserManagerDependency[models.UP],
get_user_manager: UserManagerDependency[models.UP, models.ID],
user_schema: Type[schemas.U],
):
router = APIRouter()
@ -29,7 +29,7 @@ def get_verify_router(
async def request_verify_token(
request: Request,
email: EmailStr = Body(..., embed=True),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
user = await user_manager.get_by_email(email)
@ -69,7 +69,7 @@ def get_verify_router(
async def verify(
request: Request,
token: str = Body(..., embed=True),
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
return await user_manager.verify(token, request)

View File

@ -1,7 +1,8 @@
import uuid
from typing import List, Optional, TypeVar
from typing import Generic, List, Optional, TypeVar
from pydantic import UUID4, BaseModel, EmailStr, Field
from pydantic import BaseModel, EmailStr
from fastapi_users import models
class CreateUpdateDictModel(BaseModel):
@ -21,10 +22,10 @@ class CreateUpdateDictModel(BaseModel):
return self.dict(exclude_unset=True, exclude={"id"})
class BaseUser(CreateUpdateDictModel):
class BaseUser(Generic[models.ID], CreateUpdateDictModel):
"""Base User model."""
id: UUID4 = Field(default_factory=uuid.uuid4)
id: models.ID
email: EmailStr
is_active: bool = True
is_superuser: bool = False
@ -52,10 +53,10 @@ UC = TypeVar("UC", bound=BaseUserCreate)
UU = TypeVar("UU", bound=BaseUserUpdate)
class BaseOAuthAccount(BaseModel):
class BaseOAuthAccount(Generic[models.ID], BaseModel):
"""Base OAuth account model."""
id: UUID4 = Field(default_factory=uuid.uuid4)
id: models.ID
oauth_name: str
access_token: str
expires_at: Optional[int] = None

View File

@ -29,8 +29,10 @@ from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType
from fastapi_users.manager import (
BaseUserManager,
InvalidID,
InvalidPasswordException,
UserNotExists,
UUIDIDMixin,
)
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.password import PasswordHelper
@ -43,8 +45,11 @@ lancelot_password_hash = password_helper.hash("lancelot")
excalibur_password_hash = password_helper.hash("excalibur")
IDType = uuid.UUID
@dataclasses.dataclass
class UserModel(models.UserProtocol):
class UserModel(models.UserProtocol[IDType]):
email: str
hashed_password: str
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
@ -55,7 +60,7 @@ class UserModel(models.UserProtocol):
@dataclasses.dataclass
class OAuthAccountModel(models.OAuthAccountProtocol):
class OAuthAccountModel(models.OAuthAccountProtocol[IDType]):
oauth_name: str
access_token: str
account_id: str
@ -86,7 +91,9 @@ class UserOAuth(User, schemas.BaseOAuthAccountMixin):
pass
class BaseTestUserManager(Generic[models.UP], BaseUserManager[models.UP]):
class BaseTestUserManager(
Generic[models.UP], UUIDIDMixin, BaseUserManager[models.UP, IDType]
):
reset_password_token_secret = "SECRET"
verification_token_secret = "SECRET"
@ -318,8 +325,8 @@ def mock_user_db(
inactive_user: UserModel,
superuser: UserModel,
verified_superuser: UserModel,
) -> BaseUserDatabase[UserModel]:
class MockUserDatabase(BaseUserDatabase[UserModel]):
) -> BaseUserDatabase[UserModel, IDType]:
class MockUserDatabase(BaseUserDatabase[UserModel, IDType]):
async def get(self, id: UUID4) -> Optional[UserModel]:
if id == user.id:
return user
@ -370,8 +377,8 @@ def mock_user_db_oauth(
inactive_user_oauth: UserOAuthModel,
superuser_oauth: UserOAuthModel,
verified_superuser_oauth: UserOAuthModel,
) -> BaseUserDatabase[UserOAuthModel]:
class MockUserDatabase(BaseUserDatabase[UserOAuthModel]):
) -> BaseUserDatabase[UserOAuthModel, IDType]:
class MockUserDatabase(BaseUserDatabase[UserOAuthModel, IDType]):
async def get(self, id: UUID4) -> Optional[UserOAuthModel]:
if id == user_oauth.id:
return user_oauth
@ -518,17 +525,15 @@ class MockTransport(BearerTransport):
return {}
class MockStrategy(Strategy[UserModel]):
class MockStrategy(Strategy[UserModel, IDType]):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[UserModel]
self, token: Optional[str], user_manager: BaseUserManager[UserModel, IDType]
) -> Optional[UserModel]:
if token is not None:
try:
token_uuid = UUID4(token)
return await user_manager.get(token_uuid)
except ValueError:
return None
except UserNotExists:
parsed_id = user_manager.parse_id(token)
return await user_manager.get(parsed_id)
except (InvalidID, UserNotExists):
return None
return None

View File

@ -29,7 +29,7 @@ class MockTransport(Transport):
class NoneStrategy(Strategy):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
return None
@ -39,7 +39,7 @@ class UserStrategy(Strategy, Generic[models.UP]):
self.user = user
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
return self.user

View File

@ -21,7 +21,7 @@ class MockTransportLogoutNotSupported(BearerTransport):
class MockStrategyDestroyNotSupported(Strategy, Generic[models.UP]):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
return None

View File

@ -3,7 +3,7 @@ import uuid
import pytest
from fastapi_users.db import BaseUserDatabase
from tests.conftest import OAuthAccountModel, UserModel
from tests.conftest import IDType, OAuthAccountModel, UserModel
@pytest.mark.asyncio
@ -11,7 +11,7 @@ from tests.conftest import OAuthAccountModel, UserModel
async def test_not_implemented_methods(
user: UserModel, oauth_account1: OAuthAccountModel
):
base_user_db = BaseUserDatabase[UserModel]()
base_user_db = BaseUserDatabase[UserModel, IDType]()
with pytest.raises(NotImplementedError):
await base_user_db.get(uuid.uuid4())

View File

@ -32,13 +32,14 @@ async def test_app_client(
app.include_router(
fastapi_users.get_oauth_router(oauth_client, mock_authentication, secret)
)
app.include_router(fastapi_users.get_users_router(), prefix="/users")
app.include_router(fastapi_users.get_verify_router())
@app.delete("/users/me")
def custom_users_route():
return None
app.include_router(fastapi_users.get_users_router(), prefix="/users")
app.include_router(fastapi_users.get_verify_router())
@app.get("/current-user", response_model=User)
def current_user(user: UserModel = Depends(fastapi_users.current_user())):
return user

View File

@ -1,5 +1,3 @@
import copy
import uuid
from typing import Callable
import pytest
@ -18,7 +16,6 @@ from fastapi_users.manager import (
UserNotExists,
)
from tests.conftest import (
OAuthAccountModel,
UserCreate,
UserManagerMock,
UserModel,