diff --git a/fastapi_users/authentication/authenticator.py b/fastapi_users/authentication/authenticator.py index 41f297d5..7fab4b78 100644 --- a/fastapi_users/authentication/authenticator.py +++ b/fastapi_users/authentication/authenticator.py @@ -51,7 +51,7 @@ class Authenticator: def __init__( self, backends: Sequence[AuthenticationBackend], - get_user_manager: UserManagerDependency[models.UP], + get_user_manager: UserManagerDependency[models.UP, models.ID], ): self.backends = backends self.get_user_manager = get_user_manager @@ -148,7 +148,7 @@ class Authenticator: async def _authenticate( self, *args, - user_manager: BaseUserManager[models.UP], + user_manager: BaseUserManager[models.UP, models.ID], optional: bool = False, active: bool = False, verified: bool = False, @@ -163,7 +163,7 @@ class Authenticator: for backend in self.backends: if backend in enabled_backends: token = kwargs[name_to_variable_name(backend.name)] - strategy: Strategy[models.UP] = kwargs[ + strategy: Strategy[models.UP, models.ID] = kwargs[ name_to_strategy_variable_name(backend.name) ] if token is not None: diff --git a/fastapi_users/authentication/backend.py b/fastapi_users/authentication/backend.py index f4931c75..09861210 100644 --- a/fastapi_users/authentication/backend.py +++ b/fastapi_users/authentication/backend.py @@ -33,7 +33,7 @@ class AuthenticationBackend(Generic[models.UP]): self, name: str, transport: Transport, - get_strategy: DependencyCallable[Strategy[models.UP]], + get_strategy: DependencyCallable[Strategy[models.UP, models.ID]], ): self.name = name self.transport = transport @@ -41,7 +41,7 @@ class AuthenticationBackend(Generic[models.UP]): async def login( self, - strategy: Strategy[models.UP], + strategy: Strategy[models.UP, models.ID], user: models.UP, response: Response, ) -> Any: @@ -50,7 +50,7 @@ class AuthenticationBackend(Generic[models.UP]): async def logout( self, - strategy: Strategy[models.UP], + strategy: Strategy[models.UP, models.ID], user: models.UP, token: str, response: Response, diff --git a/fastapi_users/authentication/strategy/base.py b/fastapi_users/authentication/strategy/base.py index 6954cec0..ce60db13 100644 --- a/fastapi_users/authentication/strategy/base.py +++ b/fastapi_users/authentication/strategy/base.py @@ -14,9 +14,9 @@ class StrategyDestroyNotSupportedError(Exception): pass -class Strategy(Protocol, Generic[models.UP]): +class Strategy(Protocol, Generic[models.UP, models.ID]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP] + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] ) -> Optional[models.UP]: ... # pragma: no cover diff --git a/fastapi_users/authentication/strategy/db/models.py b/fastapi_users/authentication/strategy/db/models.py index 810d0a5d..6c3b58be 100644 --- a/fastapi_users/authentication/strategy/db/models.py +++ b/fastapi_users/authentication/strategy/db/models.py @@ -1,5 +1,4 @@ import sys -import uuid from datetime import datetime from typing import TypeVar @@ -8,12 +7,14 @@ if sys.version_info < (3, 8): else: from typing import Protocol # pragma: no cover +from fastapi_users import models -class AccessTokenProtocol(Protocol): + +class AccessTokenProtocol(Protocol[models.ID]): """Access token protocol that ORM model should follow.""" token: str - user_id: uuid.UUID + user_id: models.ID created_at: datetime def __init__(self, *args, **kwargs) -> None: diff --git a/fastapi_users/authentication/strategy/db/strategy.py b/fastapi_users/authentication/strategy/db/strategy.py index ce4998c8..e3cb93df 100644 --- a/fastapi_users/authentication/strategy/db/strategy.py +++ b/fastapi_users/authentication/strategy/db/strategy.py @@ -6,7 +6,7 @@ from fastapi_users import models from fastapi_users.authentication.strategy.base import Strategy from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase from fastapi_users.authentication.strategy.db.models import AP -from fastapi_users.manager import BaseUserManager, UserNotExists +from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists class DatabaseStrategy(Strategy, Generic[models.UP, AP]): @@ -17,7 +17,7 @@ class DatabaseStrategy(Strategy, Generic[models.UP, AP]): self.lifetime_seconds = lifetime_seconds async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP] + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] ) -> Optional[models.UP]: if token is None: return None @@ -33,9 +33,9 @@ class DatabaseStrategy(Strategy, Generic[models.UP, AP]): return None try: - user_id = access_token.user_id - return await user_manager.get(user_id) - except UserNotExists: + parsed_id = user_manager.parse_id(access_token.user_id) + return await user_manager.get(parsed_id) + except (UserNotExists, InvalidID): return None async def write_token(self, user: models.UP) -> str: diff --git a/fastapi_users/authentication/strategy/jwt.py b/fastapi_users/authentication/strategy/jwt.py index 38ebe77f..a2fbf58e 100644 --- a/fastapi_users/authentication/strategy/jwt.py +++ b/fastapi_users/authentication/strategy/jwt.py @@ -1,7 +1,6 @@ from typing import Generic, List, Optional import jwt -from pydantic import UUID4 from fastapi_users import models from fastapi_users.authentication.strategy.base import ( @@ -9,7 +8,7 @@ from fastapi_users.authentication.strategy.base import ( StrategyDestroyNotSupportedError, ) from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt -from fastapi_users.manager import BaseUserManager, UserNotExists +from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists class JWTStrategy(Strategy, Generic[models.UP]): @@ -36,7 +35,7 @@ class JWTStrategy(Strategy, Generic[models.UP]): return self.public_key or self.secret async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP] + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] ) -> Optional[models.UP]: if token is None: return None @@ -52,11 +51,9 @@ class JWTStrategy(Strategy, Generic[models.UP]): return None try: - user_uiid = UUID4(user_id) - return await user_manager.get(user_uiid) - except ValueError: - return None - except UserNotExists: + parsed_id = user_manager.parse_id(user_id) + return await user_manager.get(parsed_id) + except (UserNotExists, InvalidID): return None async def write_token(self, user: models.UP) -> str: diff --git a/fastapi_users/authentication/strategy/redis.py b/fastapi_users/authentication/strategy/redis.py index ce6a6a21..d6527892 100644 --- a/fastapi_users/authentication/strategy/redis.py +++ b/fastapi_users/authentication/strategy/redis.py @@ -2,11 +2,10 @@ import secrets from typing import Generic, Optional import aioredis -from pydantic import UUID4 from fastapi_users import models from fastapi_users.authentication.strategy.base import Strategy -from fastapi_users.manager import BaseUserManager, UserNotExists +from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists class RedisStrategy(Strategy, Generic[models.UP]): @@ -15,7 +14,7 @@ class RedisStrategy(Strategy, Generic[models.UP]): self.lifetime_seconds = lifetime_seconds async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP] + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] ) -> Optional[models.UP]: if token is None: return None @@ -25,11 +24,9 @@ class RedisStrategy(Strategy, Generic[models.UP]): return None try: - user_uiid = UUID4(user_id) - return await user_manager.get(user_uiid) - except ValueError: - return None - except UserNotExists: + parsed_id = user_manager.parse_id(user_id) + return await user_manager.get(parsed_id) + except (UserNotExists, InvalidID): return None async def write_token(self, user: models.UP) -> str: diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index 1856e059..d29c354b 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -1,9 +1,6 @@ from fastapi_users.db.base import BaseUserDatabase, UserDatabaseDependency -__all__ = [ - "BaseUserDatabase", - "UserDatabaseDependency", -] +__all__ = ["BaseUserDatabase", "UserDatabaseDependency"] try: # pragma: no cover from fastapi_users_db_mongodb import MongoDBUserDatabase # noqa: F401 diff --git a/fastapi_users/db/base.py b/fastapi_users/db/base.py index 6ac78730..681bad92 100644 --- a/fastapi_users/db/base.py +++ b/fastapi_users/db/base.py @@ -1,15 +1,13 @@ from typing import Any, Dict, Generic, Optional -from pydantic import UUID4 - -from fastapi_users.models import OAP, UP +from fastapi_users.models import ID, OAP, UP from fastapi_users.types import DependencyCallable -class BaseUserDatabase(Generic[UP]): +class BaseUserDatabase(Generic[UP, ID]): """Base adapter for retrieving, creating and updating users from a database.""" - async def get(self, id: UUID4) -> Optional[UP]: + async def get(self, id: ID) -> Optional[UP]: """Get a single user by id.""" raise NotImplementedError() @@ -44,4 +42,4 @@ class BaseUserDatabase(Generic[UP]): raise NotImplementedError() -UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP]] +UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP, ID]] diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index 15517cdf..f784966d 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -44,7 +44,7 @@ class FastAPIUsers(Generic[models.UP, schemas.U, schemas.UC, schemas.UU]): def __init__( self, - get_user_manager: UserManagerDependency[models.UP], + get_user_manager: UserManagerDependency[models.UP, models.ID], auth_backends: Sequence[AuthenticationBackend], user_schema: Type[schemas.U], user_create_schema: Type[schemas.UC], diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index 77847561..116db23a 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -1,9 +1,9 @@ +import uuid from typing import Any, Dict, Generic, Optional, Union import jwt from fastapi import Request from fastapi.security import OAuth2PasswordRequestForm -from pydantic import UUID4 from fastapi_users import models, schemas from fastapi_users.db import BaseUserDatabase @@ -19,6 +19,10 @@ class FastAPIUsersException(Exception): pass +class InvalidID(FastAPIUsersException): + pass + + class UserAlreadyExists(FastAPIUsersException): pass @@ -48,7 +52,7 @@ class InvalidPasswordException(FastAPIUsersException): self.reason = reason -class BaseUserManager(Generic[models.UP]): +class BaseUserManager(Generic[models.UP, models.ID]): """ User management logic. @@ -70,12 +74,12 @@ class BaseUserManager(Generic[models.UP]): verification_token_lifetime_seconds: int = 3600 verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE - user_db: BaseUserDatabase[models.UP] + user_db: BaseUserDatabase[models.UP, models.ID] password_helper: PasswordHelperProtocol def __init__( self, - user_db: BaseUserDatabase[models.UP], + user_db: BaseUserDatabase[models.UP, models.ID], password_helper: Optional[PasswordHelperProtocol] = None, ): self.user_db = user_db @@ -84,7 +88,17 @@ class BaseUserManager(Generic[models.UP]): else: self.password_helper = password_helper # pragma: no cover - async def get(self, id: UUID4) -> models.UP: + def parse_id(self, value: Any) -> models.ID: + """ + Parse a value into a correct models.ID instance. + + :param value: The value to parse. + :raises InvalidID: The models.ID value is invalid. + :return: An models.ID object. + """ + raise NotImplementedError() # pragma: no cover + + async def get(self, id: models.ID) -> models.UP: """ Get a user by id. @@ -170,7 +184,7 @@ class BaseUserManager(Generic[models.UP]): return created_user async def oauth_callback( - self: "BaseUserManager[models.UOAP]", + self: "BaseUserManager[models.UOAP, models.ID]", oauth_name: str, access_token: str, account_id: str, @@ -192,7 +206,7 @@ class BaseUserManager(Generic[models.UP]): :param oauth_name: Name of the OAuth client. :param access_token: Valid access token for the service provider. - :param account_id: ID of the user on the service provider. + :param account_id: models.ID of the user on the service provider. :param account_email: E-mail of the user on the service provider. :param expires_at: Optional timestamp at which the access token expires. :param refresh_token: Optional refresh token to get a @@ -307,11 +321,11 @@ class BaseUserManager(Generic[models.UP]): raise InvalidVerifyToken() try: - user_uuid = UUID4(user_id) - except ValueError: + parsed_id = self.parse_id(user_id) + except InvalidID: raise InvalidVerifyToken() - if user_uuid != user.id: + if parsed_id != user.id: raise InvalidVerifyToken() if user.is_verified: @@ -382,11 +396,11 @@ class BaseUserManager(Generic[models.UP]): raise InvalidResetPasswordToken() try: - user_uuid = UUID4(user_id) - except ValueError: + parsed_id = self.parse_id(user_id) + except InvalidID: raise InvalidResetPasswordToken() - user = await self.get(user_uuid) + user = await self.get(parsed_id) if not user.is_active: raise UserInactive() @@ -588,4 +602,14 @@ class BaseUserManager(Generic[models.UP]): return await self.user_db.update(user, validated_update_dict) -UserManagerDependency = DependencyCallable[BaseUserManager[models.UP]] +class UUIDIDMixin: + def parse_id(self, value: Any) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + try: + return uuid.UUID(value) + except ValueError as e: + raise InvalidID() from e + + +UserManagerDependency = DependencyCallable[BaseUserManager[models.UP, models.ID]] diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 37a9ff64..f6bbe1ab 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -1,5 +1,4 @@ import sys -import uuid from typing import Generic, List, Optional, TypeVar if sys.version_info < (3, 8): @@ -7,11 +6,13 @@ if sys.version_info < (3, 8): else: from typing import Protocol # pragma: no cover +ID = TypeVar("ID") -class UserProtocol(Protocol): + +class UserProtocol(Protocol[ID]): """User protocol that ORM model should follow.""" - id: uuid.UUID + id: ID email: str hashed_password: str is_active: bool @@ -22,10 +23,10 @@ class UserProtocol(Protocol): ... # pragma: no cover -class OAuthAccountProtocol(Protocol): +class OAuthAccountProtocol(Protocol[ID]): """OAuth account protocol that ORM model should follow.""" - id: uuid.UUID + id: ID oauth_name: str access_token: str expires_at: Optional[int] @@ -41,7 +42,7 @@ UP = TypeVar("UP", bound=UserProtocol) OAP = TypeVar("OAP", bound=OAuthAccountProtocol) -class UserOAuthProtocol(UserProtocol, Generic[OAP]): +class UserOAuthProtocol(UserProtocol[ID], Generic[ID, OAP]): """User protocol including a list of OAuth accounts.""" oauth_accounts: List[OAP] diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index 3fc1275f..cda0fa3e 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -12,7 +12,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_auth_router( backend: AuthenticationBackend, - get_user_manager: UserManagerDependency[models.UP], + get_user_manager: UserManagerDependency[models.UP, models.ID], authenticator: Authenticator, requires_verification: bool = False, ) -> APIRouter: @@ -51,8 +51,8 @@ def get_auth_router( async def login( response: Response, credentials: OAuth2PasswordRequestForm = Depends(), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), - strategy: Strategy[models.UP] = Depends(backend.get_strategy), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), ): user = await user_manager.authenticate(credentials) @@ -83,7 +83,7 @@ def get_auth_router( async def logout( response: Response, user_token: Tuple[models.UP, str] = Depends(get_current_user_token), - strategy: Strategy[models.UP] = Depends(backend.get_strategy), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), ): user, token = user_token return await backend.logout(strategy, user, token, response) diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 00499d2b..9ec1ac26 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -29,7 +29,7 @@ def generate_state_token( def get_oauth_router( oauth_client: BaseOAuth2, backend: AuthenticationBackend, - get_user_manager: UserManagerDependency[models.UP], + get_user_manager: UserManagerDependency[models.UP, models.ID], state_secret: SecretType, redirect_url: str = None, ) -> APIRouter: @@ -101,8 +101,8 @@ def get_oauth_router( access_token_state: Tuple[OAuth2Token, str] = Depends( oauth2_authorize_callback ), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), - strategy: Strategy[models.UP] = Depends(backend.get_strategy), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), ): token, state = access_token_state account_id, account_email = await oauth_client.get_id_email( diff --git a/fastapi_users/router/register.py b/fastapi_users/router/register.py index 862b0406..3eab139e 100644 --- a/fastapi_users/router/register.py +++ b/fastapi_users/router/register.py @@ -13,7 +13,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_register_router( - get_user_manager: UserManagerDependency[models.UP], + get_user_manager: UserManagerDependency[models.UP, models.ID], user_schema: Type[schemas.U], user_create_schema: Type[schemas.UC], ) -> APIRouter: @@ -56,7 +56,7 @@ def get_register_router( async def register( request: Request, user_create: user_create_schema, # type: ignore - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: created_user = await user_manager.create( diff --git a/fastapi_users/router/reset.py b/fastapi_users/router/reset.py index e6458321..95e2e93d 100644 --- a/fastapi_users/router/reset.py +++ b/fastapi_users/router/reset.py @@ -40,7 +40,7 @@ RESET_PASSWORD_RESPONSES: OpenAPIResponseType = { def get_reset_password_router( - get_user_manager: UserManagerDependency[models.UP], + get_user_manager: UserManagerDependency[models.UP, models.ID], ) -> APIRouter: """Generate a router with the reset password routes.""" router = APIRouter() @@ -53,7 +53,7 @@ def get_reset_password_router( async def forgot_password( request: Request, email: EmailStr = Body(..., embed=True), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: user = await user_manager.get_by_email(email) @@ -76,7 +76,7 @@ def get_reset_password_router( request: Request, token: str = Body(...), password: str = Body(...), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: await user_manager.reset_password(token, password, request) diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index 9e744734..b6a48caf 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -1,12 +1,12 @@ -from typing import Type +from typing import Any, Type from fastapi import APIRouter, Depends, HTTPException, Request, Response, status -from pydantic import UUID4 from fastapi_users import models, schemas from fastapi_users.authentication import Authenticator from fastapi_users.manager import ( BaseUserManager, + InvalidID, InvalidPasswordException, UserAlreadyExists, UserManagerDependency, @@ -16,7 +16,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_users_router( - get_user_manager: UserManagerDependency[models.UP], + get_user_manager: UserManagerDependency[models.UP, models.ID], user_schema: Type[schemas.U], user_update_schema: Type[schemas.UU], authenticator: Authenticator, @@ -33,13 +33,14 @@ def get_users_router( ) async def get_user_or_404( - id: UUID4, - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + id: Any, + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ) -> models.UP: try: - return await user_manager.get(id) - except UserNotExists: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + parsed_id = user_manager.parse_id(id) + return await user_manager.get(parsed_id) + except (UserNotExists, InvalidID) as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) from e @router.get( "/me", @@ -96,7 +97,7 @@ def get_users_router( request: Request, user_update: user_update_schema, # type: ignore user: models.UP = Depends(get_current_active_user), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: return await user_manager.update( @@ -117,7 +118,7 @@ def get_users_router( ) @router.get( - "/{id:uuid}", + "/{id}", response_model=user_schema, dependencies=[Depends(get_current_superuser)], name="users:user", @@ -137,7 +138,7 @@ def get_users_router( return user @router.patch( - "/{id:uuid}", + "/{id}", response_model=user_schema, dependencies=[Depends(get_current_superuser)], name="users:patch_user", @@ -182,7 +183,7 @@ def get_users_router( user_update: user_update_schema, # type: ignore request: Request, user=Depends(get_user_or_404), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: return await user_manager.update( @@ -203,7 +204,7 @@ def get_users_router( ) @router.delete( - "/{id:uuid}", + "/{id}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response, dependencies=[Depends(get_current_superuser)], @@ -222,7 +223,7 @@ def get_users_router( ) async def delete_user( user=Depends(get_user_or_404), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): await user_manager.delete(user) return None diff --git a/fastapi_users/router/verify.py b/fastapi_users/router/verify.py index 7860e18e..af6578f5 100644 --- a/fastapi_users/router/verify.py +++ b/fastapi_users/router/verify.py @@ -16,7 +16,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_verify_router( - get_user_manager: UserManagerDependency[models.UP], + get_user_manager: UserManagerDependency[models.UP, models.ID], user_schema: Type[schemas.U], ): router = APIRouter() @@ -29,7 +29,7 @@ def get_verify_router( async def request_verify_token( request: Request, email: EmailStr = Body(..., embed=True), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: user = await user_manager.get_by_email(email) @@ -69,7 +69,7 @@ def get_verify_router( async def verify( request: Request, token: str = Body(..., embed=True), - user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: return await user_manager.verify(token, request) diff --git a/fastapi_users/schemas.py b/fastapi_users/schemas.py index a361bd58..6ab2994d 100644 --- a/fastapi_users/schemas.py +++ b/fastapi_users/schemas.py @@ -1,7 +1,8 @@ -import uuid -from typing import List, Optional, TypeVar +from typing import Generic, List, Optional, TypeVar -from pydantic import UUID4, BaseModel, EmailStr, Field +from pydantic import BaseModel, EmailStr + +from fastapi_users import models class CreateUpdateDictModel(BaseModel): @@ -21,10 +22,10 @@ class CreateUpdateDictModel(BaseModel): return self.dict(exclude_unset=True, exclude={"id"}) -class BaseUser(CreateUpdateDictModel): +class BaseUser(Generic[models.ID], CreateUpdateDictModel): """Base User model.""" - id: UUID4 = Field(default_factory=uuid.uuid4) + id: models.ID email: EmailStr is_active: bool = True is_superuser: bool = False @@ -52,10 +53,10 @@ UC = TypeVar("UC", bound=BaseUserCreate) UU = TypeVar("UU", bound=BaseUserUpdate) -class BaseOAuthAccount(BaseModel): +class BaseOAuthAccount(Generic[models.ID], BaseModel): """Base OAuth account model.""" - id: UUID4 = Field(default_factory=uuid.uuid4) + id: models.ID oauth_name: str access_token: str expires_at: Optional[int] = None diff --git a/tests/conftest.py b/tests/conftest.py index c995b13b..c998fa78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,8 +29,10 @@ from fastapi_users.db import BaseUserDatabase from fastapi_users.jwt import SecretType from fastapi_users.manager import ( BaseUserManager, + InvalidID, InvalidPasswordException, UserNotExists, + UUIDIDMixin, ) from fastapi_users.openapi import OpenAPIResponseType from fastapi_users.password import PasswordHelper @@ -43,8 +45,11 @@ lancelot_password_hash = password_helper.hash("lancelot") excalibur_password_hash = password_helper.hash("excalibur") +IDType = uuid.UUID + + @dataclasses.dataclass -class UserModel(models.UserProtocol): +class UserModel(models.UserProtocol[IDType]): email: str hashed_password: str id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4) @@ -55,7 +60,7 @@ class UserModel(models.UserProtocol): @dataclasses.dataclass -class OAuthAccountModel(models.OAuthAccountProtocol): +class OAuthAccountModel(models.OAuthAccountProtocol[IDType]): oauth_name: str access_token: str account_id: str @@ -86,7 +91,9 @@ class UserOAuth(User, schemas.BaseOAuthAccountMixin): pass -class BaseTestUserManager(Generic[models.UP], BaseUserManager[models.UP]): +class BaseTestUserManager( + Generic[models.UP], UUIDIDMixin, BaseUserManager[models.UP, IDType] +): reset_password_token_secret = "SECRET" verification_token_secret = "SECRET" @@ -318,8 +325,8 @@ def mock_user_db( inactive_user: UserModel, superuser: UserModel, verified_superuser: UserModel, -) -> BaseUserDatabase[UserModel]: - class MockUserDatabase(BaseUserDatabase[UserModel]): +) -> BaseUserDatabase[UserModel, IDType]: + class MockUserDatabase(BaseUserDatabase[UserModel, IDType]): async def get(self, id: UUID4) -> Optional[UserModel]: if id == user.id: return user @@ -370,8 +377,8 @@ def mock_user_db_oauth( inactive_user_oauth: UserOAuthModel, superuser_oauth: UserOAuthModel, verified_superuser_oauth: UserOAuthModel, -) -> BaseUserDatabase[UserOAuthModel]: - class MockUserDatabase(BaseUserDatabase[UserOAuthModel]): +) -> BaseUserDatabase[UserOAuthModel, IDType]: + class MockUserDatabase(BaseUserDatabase[UserOAuthModel, IDType]): async def get(self, id: UUID4) -> Optional[UserOAuthModel]: if id == user_oauth.id: return user_oauth @@ -518,17 +525,15 @@ class MockTransport(BearerTransport): return {} -class MockStrategy(Strategy[UserModel]): +class MockStrategy(Strategy[UserModel, IDType]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[UserModel] + self, token: Optional[str], user_manager: BaseUserManager[UserModel, IDType] ) -> Optional[UserModel]: if token is not None: try: - token_uuid = UUID4(token) - return await user_manager.get(token_uuid) - except ValueError: - return None - except UserNotExists: + parsed_id = user_manager.parse_id(token) + return await user_manager.get(parsed_id) + except (InvalidID, UserNotExists): return None return None diff --git a/tests/test_authentication_authenticator.py b/tests/test_authentication_authenticator.py index 7ad8c1bc..d2ee48e2 100644 --- a/tests/test_authentication_authenticator.py +++ b/tests/test_authentication_authenticator.py @@ -29,7 +29,7 @@ class MockTransport(Transport): class NoneStrategy(Strategy): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP] + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] ) -> Optional[models.UP]: return None @@ -39,7 +39,7 @@ class UserStrategy(Strategy, Generic[models.UP]): self.user = user async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP] + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] ) -> Optional[models.UP]: return self.user diff --git a/tests/test_authentication_backend.py b/tests/test_authentication_backend.py index 8aad2756..b7d99795 100644 --- a/tests/test_authentication_backend.py +++ b/tests/test_authentication_backend.py @@ -21,7 +21,7 @@ class MockTransportLogoutNotSupported(BearerTransport): class MockStrategyDestroyNotSupported(Strategy, Generic[models.UP]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP] + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] ) -> Optional[models.UP]: return None diff --git a/tests/test_db_base.py b/tests/test_db_base.py index 82501d2f..23e9e44b 100644 --- a/tests/test_db_base.py +++ b/tests/test_db_base.py @@ -3,7 +3,7 @@ import uuid import pytest from fastapi_users.db import BaseUserDatabase -from tests.conftest import OAuthAccountModel, UserModel +from tests.conftest import IDType, OAuthAccountModel, UserModel @pytest.mark.asyncio @@ -11,7 +11,7 @@ from tests.conftest import OAuthAccountModel, UserModel async def test_not_implemented_methods( user: UserModel, oauth_account1: OAuthAccountModel ): - base_user_db = BaseUserDatabase[UserModel]() + base_user_db = BaseUserDatabase[UserModel, IDType]() with pytest.raises(NotImplementedError): await base_user_db.get(uuid.uuid4()) diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index 8ca1af4d..ed929038 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -32,13 +32,14 @@ async def test_app_client( app.include_router( fastapi_users.get_oauth_router(oauth_client, mock_authentication, secret) ) - app.include_router(fastapi_users.get_users_router(), prefix="/users") - app.include_router(fastapi_users.get_verify_router()) @app.delete("/users/me") def custom_users_route(): return None + app.include_router(fastapi_users.get_users_router(), prefix="/users") + app.include_router(fastapi_users.get_verify_router()) + @app.get("/current-user", response_model=User) def current_user(user: UserModel = Depends(fastapi_users.current_user())): return user diff --git a/tests/test_manager.py b/tests/test_manager.py index cc1360d9..6f423aab 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,5 +1,3 @@ -import copy -import uuid from typing import Callable import pytest @@ -18,7 +16,6 @@ from fastapi_users.manager import ( UserNotExists, ) from tests.conftest import ( - OAuthAccountModel, UserCreate, UserManagerMock, UserModel,