diff --git a/fastapi_users/__init__.py b/fastapi_users/__init__.py index 651e81a5..6d177186 100644 --- a/fastapi_users/__init__.py +++ b/fastapi_users/__init__.py @@ -2,7 +2,7 @@ __version__ = "9.3.1" -from fastapi_users import models # noqa: F401 +from fastapi_users import models, schemas # noqa: F401 from fastapi_users.fastapi_users import FastAPIUsers # noqa: F401 from fastapi_users.manager import ( # noqa: F401 BaseUserManager, @@ -10,7 +10,7 @@ from fastapi_users.manager import ( # noqa: F401 ) __all__ = [ - "models", + "schemas", "FastAPIUsers", "BaseUserManager", "InvalidPasswordException", diff --git a/fastapi_users/authentication/authenticator.py b/fastapi_users/authentication/authenticator.py index 4d876713..41f297d5 100644 --- a/fastapi_users/authentication/authenticator.py +++ b/fastapi_users/authentication/authenticator.py @@ -51,7 +51,7 @@ class Authenticator: def __init__( self, backends: Sequence[AuthenticationBackend], - get_user_manager: UserManagerDependency[models.UC, models.UD], + get_user_manager: UserManagerDependency[models.UP], ): self.backends = backends self.get_user_manager = get_user_manager @@ -148,14 +148,14 @@ class Authenticator: async def _authenticate( self, *args, - user_manager: BaseUserManager[models.UC, models.UD], + user_manager: BaseUserManager[models.UP], optional: bool = False, active: bool = False, verified: bool = False, superuser: bool = False, **kwargs, - ) -> Tuple[Optional[models.UD], Optional[str]]: - user: Optional[models.UD] = None + ) -> Tuple[Optional[models.UP], Optional[str]]: + user: Optional[models.UP] = None token: Optional[str] = None enabled_backends: Sequence[AuthenticationBackend] = kwargs.get( "enabled_backends", self.backends @@ -163,7 +163,7 @@ class Authenticator: for backend in self.backends: if backend in enabled_backends: token = kwargs[name_to_variable_name(backend.name)] - strategy: Strategy[models.UC, models.UD] = kwargs[ + strategy: Strategy[models.UP] = kwargs[ name_to_strategy_variable_name(backend.name) ] if token is not None: diff --git a/fastapi_users/authentication/backend.py b/fastapi_users/authentication/backend.py index 95f44345..f4931c75 100644 --- a/fastapi_users/authentication/backend.py +++ b/fastapi_users/authentication/backend.py @@ -14,7 +14,7 @@ from fastapi_users.authentication.transport import ( from fastapi_users.types import DependencyCallable -class AuthenticationBackend(Generic[models.UC, models.UD]): +class AuthenticationBackend(Generic[models.UP]): """ Combination of an authentication transport and strategy. @@ -33,7 +33,7 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): self, name: str, transport: Transport, - get_strategy: DependencyCallable[Strategy[models.UC, models.UD]], + get_strategy: DependencyCallable[Strategy[models.UP]], ): self.name = name self.transport = transport @@ -41,8 +41,8 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): async def login( self, - strategy: Strategy[models.UC, models.UD], - user: models.UD, + strategy: Strategy[models.UP], + user: models.UP, response: Response, ) -> Any: token = await strategy.write_token(user) @@ -50,8 +50,8 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): async def logout( self, - strategy: Strategy[models.UC, models.UD], - user: models.UD, + strategy: Strategy[models.UP], + user: models.UP, token: str, response: Response, ) -> Any: diff --git a/fastapi_users/authentication/strategy/base.py b/fastapi_users/authentication/strategy/base.py index 367b5204..6954cec0 100644 --- a/fastapi_users/authentication/strategy/base.py +++ b/fastapi_users/authentication/strategy/base.py @@ -14,14 +14,14 @@ class StrategyDestroyNotSupportedError(Exception): pass -class Strategy(Protocol, Generic[models.UC, models.UD]): +class Strategy(Protocol, Generic[models.UP]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP] + ) -> Optional[models.UP]: ... # pragma: no cover - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: models.UP) -> str: ... # pragma: no cover - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: ... # pragma: no cover diff --git a/fastapi_users/authentication/strategy/db/strategy.py b/fastapi_users/authentication/strategy/db/strategy.py index d7aef0e9..e09da2b3 100644 --- a/fastapi_users/authentication/strategy/db/strategy.py +++ b/fastapi_users/authentication/strategy/db/strategy.py @@ -9,7 +9,7 @@ from fastapi_users.authentication.strategy.db.models import A from fastapi_users.manager import BaseUserManager, UserNotExists -class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]): +class DatabaseStrategy(Strategy, Generic[models.UP, A]): def __init__( self, database: AccessTokenDatabase[A], lifetime_seconds: Optional[int] = None ): @@ -17,8 +17,8 @@ class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]): self.lifetime_seconds = lifetime_seconds async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP] + ) -> Optional[models.UP]: if token is None: return None @@ -38,16 +38,16 @@ class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]): except UserNotExists: return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: models.UP) -> str: access_token = self._create_access_token(user) await self.database.create(access_token) return access_token.token - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: access_token = await self.database.get_by_token(token) if access_token is not None: await self.database.delete(access_token) - def _create_access_token(self, user: models.UD) -> A: + def _create_access_token(self, user: models.UP) -> A: token = secrets.token_urlsafe() return self.database.access_token_model(token=token, user_id=user.id) diff --git a/fastapi_users/authentication/strategy/jwt.py b/fastapi_users/authentication/strategy/jwt.py index 8d761ad7..38ebe77f 100644 --- a/fastapi_users/authentication/strategy/jwt.py +++ b/fastapi_users/authentication/strategy/jwt.py @@ -12,7 +12,7 @@ from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import BaseUserManager, UserNotExists -class JWTStrategy(Strategy, Generic[models.UC, models.UD]): +class JWTStrategy(Strategy, Generic[models.UP]): def __init__( self, secret: SecretType, @@ -36,8 +36,8 @@ class JWTStrategy(Strategy, Generic[models.UC, models.UD]): return self.public_key or self.secret async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP] + ) -> Optional[models.UP]: if token is None: return None @@ -59,13 +59,13 @@ class JWTStrategy(Strategy, Generic[models.UC, models.UD]): except UserNotExists: return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: models.UP) -> str: data = {"user_id": str(user.id), "aud": self.token_audience} return generate_jwt( data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm ) - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: raise StrategyDestroyNotSupportedError( "A JWT can't be invalidated: it's valid until it expires." ) diff --git a/fastapi_users/authentication/strategy/redis.py b/fastapi_users/authentication/strategy/redis.py index 3f32e22e..ce6a6a21 100644 --- a/fastapi_users/authentication/strategy/redis.py +++ b/fastapi_users/authentication/strategy/redis.py @@ -9,14 +9,14 @@ from fastapi_users.authentication.strategy.base import Strategy from fastapi_users.manager import BaseUserManager, UserNotExists -class RedisStrategy(Strategy, Generic[models.UC, models.UD]): +class RedisStrategy(Strategy, Generic[models.UP]): def __init__(self, redis: aioredis.Redis, lifetime_seconds: Optional[int] = None): self.redis = redis self.lifetime_seconds = lifetime_seconds async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP] + ) -> Optional[models.UP]: if token is None: return None @@ -32,10 +32,10 @@ class RedisStrategy(Strategy, Generic[models.UC, models.UD]): except UserNotExists: return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: models.UP) -> str: token = secrets.token_urlsafe() await self.redis.set(token, str(user.id), ex=self.lifetime_seconds) return token - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: await self.redis.delete(token) diff --git a/fastapi_users/db/base.py b/fastapi_users/db/base.py index b6e92006..9c31844e 100644 --- a/fastapi_users/db/base.py +++ b/fastapi_users/db/base.py @@ -1,46 +1,37 @@ -from typing import Generic, Optional, Type +from typing import Any, Dict, Generic, Optional from pydantic import UUID4 -from fastapi_users.models import UD +from fastapi_users.models import UP from fastapi_users.types import DependencyCallable -class BaseUserDatabase(Generic[UD]): - """ - Base adapter for retrieving, creating and updating users from a database. +class BaseUserDatabase(Generic[UP]): + """Base adapter for retrieving, creating and updating users from a database.""" - :param user_db_model: Pydantic model of a DB representation of a user. - """ - - user_db_model: Type[UD] - - def __init__(self, user_db_model: Type[UD]): - self.user_db_model = user_db_model - - async def get(self, id: UUID4) -> Optional[UD]: + async def get(self, id: UUID4) -> Optional[UP]: """Get a single user by id.""" raise NotImplementedError() - async def get_by_email(self, email: str) -> Optional[UD]: + async def get_by_email(self, email: str) -> Optional[UP]: """Get a single user by email.""" raise NotImplementedError() - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: + async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: """Get a single user by OAuth account id.""" raise NotImplementedError() - async def create(self, user: UD) -> UD: + async def create(self, create_dict: Dict[str, Any]) -> UP: """Create a user.""" raise NotImplementedError() - async def update(self, user: UD) -> UD: + async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: """Update a user.""" raise NotImplementedError() - async def delete(self, user: UD) -> None: + async def delete(self, user: UP) -> None: """Delete a user.""" raise NotImplementedError() -UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UD]] +UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP]] diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index 674b5062..e33d7966 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -2,7 +2,7 @@ from typing import Generic, Sequence, Type from fastapi import APIRouter -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.authentication import AuthenticationBackend, Authenticator from fastapi_users.jwt import SecretType from fastapi_users.manager import UserManagerDependency @@ -22,43 +22,39 @@ except ModuleNotFoundError: # pragma: no cover BaseOAuth2 = Type # type: ignore -class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): +class FastAPIUsers(Generic[models.UP, schemas.U, schemas.UC, schemas.UU, schemas.UD]): """ Main object that ties together the component for users authentication. :param get_user_manager: Dependency callable getter to inject the user manager class instance. :param auth_backends: List of authentication backends. - :param user_model: Pydantic model of a user. - :param user_create_model: Pydantic model for creating a user. - :param user_update_model: Pydantic model for updating a user. - :param user_db_model: Pydantic model of a DB representation of a user. + :param user_schema: Pydantic schema of a public user. + :param user_create_schema: Pydantic schema for creating a user. + :param user_update_schema: Pydantic schema for updating a user. :attribute current_user: Dependency callable getter to inject authenticated user with a specific set of parameters. """ authenticator: Authenticator - _user_model: Type[models.U] - _user_create_model: Type[models.UC] - _user_update_model: Type[models.UU] - _user_db_model: Type[models.UD] + _user_schema: Type[schemas.U] + _user_create_schema: Type[schemas.UC] + _user_update_schema: Type[schemas.UU] def __init__( self, - get_user_manager: UserManagerDependency[models.UC, models.UD], + get_user_manager: UserManagerDependency[models.UP], auth_backends: Sequence[AuthenticationBackend], - user_model: Type[models.U], - user_create_model: Type[models.UC], - user_update_model: Type[models.UU], - user_db_model: Type[models.UD], + user_schema: Type[schemas.U], + user_create_schema: Type[schemas.UC], + user_update_model: Type[schemas.UU], ): self.authenticator = Authenticator(auth_backends, get_user_manager) - self._user_model = user_model - self._user_db_model = user_db_model - self._user_create_model = user_create_model - self._user_update_model = user_update_model + self._user_schema = user_schema + self._user_create_schema = user_create_schema + self._user_update_schema = user_update_model self.get_user_manager = get_user_manager self.current_user = self.authenticator.current_user @@ -67,13 +63,13 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): """Return a router with a register route.""" return get_register_router( self.get_user_manager, - self._user_model, - self._user_create_model, + self._user_schema, + self._user_create_schema, ) def get_verify_router(self) -> APIRouter: """Return a router with e-mail verification routes.""" - return get_verify_router(self.get_user_manager, self._user_model) + return get_verify_router(self.get_user_manager, self._user_schema) def get_reset_password_router(self) -> APIRouter: """Return a reset password process router.""" @@ -132,9 +128,8 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): """ return get_users_router( self.get_user_manager, - self._user_model, - self._user_update_model, - self._user_db_model, + self._user_schema, + self._user_update_schema, self.authenticator, requires_verification, ) diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index 3693ece3..44acbef6 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -5,7 +5,7 @@ from fastapi import Request from fastapi.security import OAuth2PasswordRequestForm from pydantic import UUID4 -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.db import BaseUserDatabase from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.password import PasswordHelper, PasswordHelperProtocol @@ -48,11 +48,11 @@ class InvalidPasswordException(FastAPIUsersException): self.reason = reason -class BaseUserManager(Generic[models.UC, models.UD]): +class BaseUserManager(Generic[models.UP]): """ User management logic. - :attribute user_db_model: Pydantic model of a DB representation of a user. + :attribute user_model: Model of a user. :attribute reset_password_token_secret: Secret to encode reset password token. :attribute reset_password_token_lifetime_seconds: Lifetime of reset password token. :attribute reset_password_token_audience: JWT audience of reset password token. @@ -63,7 +63,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): :param user_db: Database adapter instance. """ - user_db_model: Type[models.UD] + user_model: Type[models.UP] reset_password_token_secret: SecretType reset_password_token_lifetime_seconds: int = 3600 reset_password_token_audience: str = RESET_PASSWORD_TOKEN_AUDIENCE @@ -72,12 +72,12 @@ class BaseUserManager(Generic[models.UC, models.UD]): verification_token_lifetime_seconds: int = 3600 verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE - user_db: BaseUserDatabase[models.UD] + user_db: BaseUserDatabase[models.UP] password_helper: PasswordHelperProtocol def __init__( self, - user_db: BaseUserDatabase[models.UD], + user_db: BaseUserDatabase[models.UP], password_helper: Optional[PasswordHelperProtocol] = None, ): self.user_db = user_db @@ -86,7 +86,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): else: self.password_helper = password_helper # pragma: no cover - async def get(self, id: UUID4) -> models.UD: + async def get(self, id: UUID4) -> models.UP: """ Get a user by id. @@ -101,7 +101,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return user - async def get_by_email(self, user_email: str) -> models.UD: + async def get_by_email(self, user_email: str) -> models.UP: """ Get a user by e-mail. @@ -116,7 +116,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return user - async def get_by_oauth_account(self, oauth: str, account_id: str) -> models.UD: + async def get_by_oauth_account(self, oauth: str, account_id: str) -> models.UP: """ Get a user by OAuth account. @@ -133,14 +133,17 @@ class BaseUserManager(Generic[models.UC, models.UD]): return user async def create( - self, user: models.UC, safe: bool = False, request: Optional[Request] = None - ) -> models.UD: + self, + user_create: schemas.UC, + safe: bool = False, + request: Optional[Request] = None, + ) -> models.UP: """ Create a user in database. Triggers the on_after_register handler on success. - :param user: The UserCreate model to create. + :param user_create: The UserCreate model to create. :param safe: If True, sensitive values like is_superuser or is_verified will be ignored during the creation, defaults to False. :param request: Optional FastAPI request that @@ -148,27 +151,31 @@ class BaseUserManager(Generic[models.UC, models.UD]): :raises UserAlreadyExists: A user already exists with the same e-mail. :return: A new user. """ - await self.validate_password(user.password, user) + await self.validate_password(user_create.password, user_create) - existing_user = await self.user_db.get_by_email(user.email) + existing_user = await self.user_db.get_by_email(user_create.email) if existing_user is not None: raise UserAlreadyExists() - hashed_password = self.password_helper.hash(user.password) user_dict = ( - user.create_update_dict() if safe else user.create_update_dict_superuser() + user_create.create_update_dict() + if safe + else user_create.create_update_dict_superuser() ) - db_user = self.user_db_model(**user_dict, hashed_password=hashed_password) + password = user_dict.pop("password") + user_dict["hashed_password"] = self.password_helper.hash(password) - created_user = await self.user_db.create(db_user) + created_user = await self.user_db.create(user_dict) await self.on_after_register(created_user, request) return created_user async def oauth_callback( - self, oauth_account: models.BaseOAuthAccount, request: Optional[Request] = None - ) -> models.UD: + self: "BaseUserManager[models.UOAP]", + oauth_account: models.OAP, + request: Optional[Request] = None, + ) -> models.UOAP: """ Handle the callback after a successful OAuth authentication. @@ -193,17 +200,17 @@ class BaseUserManager(Generic[models.UC, models.UD]): try: # Link account user = await self.get_by_email(oauth_account.account_email) - user.oauth_accounts.append(oauth_account) # type: ignore - await self.user_db.update(user) + oauth_accounts = [*user.oauth_accounts, oauth_account] + await self.user_db.update(user, {"oauth_accounts": oauth_accounts}) except UserNotExists: # Create account password = self.password_helper.generate() - user = self.user_db_model( - email=oauth_account.account_email, - hashed_password=self.password_helper.hash(password), - oauth_accounts=[oauth_account], - ) - await self.user_db.create(user) + user_dict = { + "email": oauth_account.account_email, + "hashed_password": self.password_helper.hash(password), + "oauth_accounts": [oauth_account], + } + user = await self.user_db.create(user_dict) await self.on_after_register(user, request) else: # Update oauth @@ -217,13 +224,12 @@ class BaseUserManager(Generic[models.UC, models.UD]): updated_oauth_accounts.append(oauth_account) else: updated_oauth_accounts.append(existing_oauth_account) - user.oauth_accounts = updated_oauth_accounts # type: ignore - await self.user_db.update(user) + await self.user_db.update(user, {"oauth_accounts": updated_oauth_accounts}) return user async def request_verify( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Start a verification request. @@ -253,7 +259,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): ) await self.on_after_request_verify(user, token, request) - async def verify(self, token: str, request: Optional[Request] = None) -> models.UD: + async def verify(self, token: str, request: Optional[Request] = None) -> models.UP: """ Validate a verification request. @@ -306,7 +312,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return verified_user async def forgot_password( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Start a forgot password request. @@ -334,7 +340,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): async def reset_password( self, token: str, password: str, request: Optional[Request] = None - ) -> models.UD: + ) -> models.UP: """ Reset the password of a user. @@ -381,11 +387,11 @@ class BaseUserManager(Generic[models.UC, models.UD]): async def update( self, - user_update: models.UU, - user: models.UD, + user_update: schemas.UU, + user: models.UP, safe: bool = False, request: Optional[Request] = None, - ) -> models.UD: + ) -> models.UP: """ Update a user. @@ -408,7 +414,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): await self.on_after_update(updated_user, updated_user_data, request) return updated_user - async def delete(self, user: models.UD) -> None: + async def delete(self, user: models.UP) -> None: """ Delete a user. @@ -417,7 +423,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): await self.user_db.delete(user) async def validate_password( - self, password: str, user: Union[models.UC, models.UD] + self, password: str, user: Union[schemas.UC, models.UP] ) -> None: """ Validate a password. @@ -432,7 +438,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_register( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Perform logic after successful user registration. @@ -447,7 +453,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): async def on_after_update( self, - user: models.UD, + user: models.UP, update_dict: Dict[str, Any], request: Optional[Request] = None, ) -> None: @@ -464,7 +470,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_request_verify( - self, user: models.UD, token: str, request: Optional[Request] = None + self, user: models.UP, token: str, request: Optional[Request] = None ) -> None: """ Perform logic after successful verification request. @@ -479,7 +485,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_verify( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Perform logic after successful user verification. @@ -493,7 +499,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_forgot_password( - self, user: models.UD, token: str, request: Optional[Request] = None + self, user: models.UP, token: str, request: Optional[Request] = None ) -> None: """ Perform logic after successful forgot password request. @@ -508,7 +514,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_reset_password( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Perform logic after successful password reset. @@ -523,7 +529,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): async def authenticate( self, credentials: OAuth2PasswordRequestForm - ) -> Optional[models.UD]: + ) -> Optional[models.UP]: """ Authenticate and return a user following an email and a password. @@ -546,27 +552,28 @@ class BaseUserManager(Generic[models.UC, models.UD]): return None # Update password hash to a more robust one if needed if updated_password_hash is not None: - user.hashed_password = updated_password_hash - await self.user_db.update(user) + await self.user_db.update(user, {"hashed_password": updated_password_hash}) return user - async def _update(self, user: models.UD, update_dict: Dict[str, Any]) -> models.UD: + async def _update(self, user: models.UP, update_dict: Dict[str, Any]) -> models.UP: + validated_update_dict = {} for field, value in update_dict.items(): if field == "email" and value != user.email: try: await self.get_by_email(value) raise UserAlreadyExists() except UserNotExists: - user.email = value - user.is_verified = False + validated_update_dict["email"] = value + validated_update_dict["is_verified"] = False elif field == "password": await self.validate_password(value, user) - hashed_password = self.password_helper.hash(value) - user.hashed_password = hashed_password + validated_update_dict["hashed_password"] = self.password_helper.hash( + value + ) else: - setattr(user, field, value) - return await self.user_db.update(user) + validated_update_dict[field] = value + return await self.user_db.update(user, validated_update_dict) -UserManagerDependency = DependencyCallable[BaseUserManager[models.UC, models.UD]] +UserManagerDependency = DependencyCallable[BaseUserManager[models.UP]] diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 6529a3f5..0d311c83 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -1,81 +1,44 @@ +import sys import uuid -from typing import List, Optional, TypeVar +from typing import Generic, List, Optional, TypeVar -from pydantic import UUID4, BaseModel, EmailStr, Field +if sys.version_info < (3, 8): + from typing_extensions import Protocol # pragma: no cover +else: + from typing import Protocol # pragma: no cover -class CreateUpdateDictModel(BaseModel): - def create_update_dict(self): - return self.dict( - exclude_unset=True, - exclude={ - "id", - "is_superuser", - "is_active", - "is_verified", - "oauth_accounts", - }, - ) - - def create_update_dict_superuser(self): - return self.dict(exclude_unset=True, exclude={"id"}) - - -class BaseUser(CreateUpdateDictModel): - """Base User model.""" - - id: UUID4 = Field(default_factory=uuid.uuid4) - email: EmailStr - is_active: bool = True - is_superuser: bool = False - is_verified: bool = False - - -class BaseUserCreate(CreateUpdateDictModel): - email: EmailStr - password: str - is_active: Optional[bool] = True - is_superuser: Optional[bool] = False - is_verified: Optional[bool] = False - - -class BaseUserUpdate(CreateUpdateDictModel): - password: Optional[str] - email: Optional[EmailStr] - is_active: Optional[bool] - is_superuser: Optional[bool] - is_verified: Optional[bool] - - -class BaseUserDB(BaseUser): +class UserProtocol(Protocol): + id: uuid.UUID + email: str hashed_password: str + is_active: bool + is_superuser: bool + is_verified: bool - class Config: - orm_mode = True + def __init__(self, *args, **kwargs) -> None: + ... # pragma: no cover -U = TypeVar("U", bound=BaseUser) -UC = TypeVar("UC", bound=BaseUserCreate) -UU = TypeVar("UU", bound=BaseUserUpdate) -UD = TypeVar("UD", bound=BaseUserDB) - - -class BaseOAuthAccount(BaseModel): - """Base OAuth account model.""" - - id: UUID4 = Field(default_factory=uuid.uuid4) +class OAuthAccountProtocol(Protocol): + id: uuid.UUID oauth_name: str access_token: str - expires_at: Optional[int] = None - refresh_token: Optional[str] = None + expires_at: Optional[int] + refresh_token: Optional[str] account_id: str account_email: str - class Config: - orm_mode = True + def __init__(self, *args, **kwargs) -> None: + ... # pragma: no cover -class BaseOAuthAccountMixin(BaseModel): - """Adds OAuth accounts list to a User model.""" +UP = TypeVar("UP", bound=UserProtocol) +OAP = TypeVar("OAP", bound=OAuthAccountProtocol) - oauth_accounts: List[BaseOAuthAccount] = [] + +class UserOAuthProtocol(UserProtocol, Generic[OAP]): + oauth_accounts: List[OAP] + + +UOAP = TypeVar("UOAP", bound=UserOAuthProtocol) diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index 16ef5e07..3fc1275f 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -12,7 +12,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_auth_router( backend: AuthenticationBackend, - get_user_manager: UserManagerDependency[models.UC, models.UD], + get_user_manager: UserManagerDependency[models.UP], authenticator: Authenticator, requires_verification: bool = False, ) -> APIRouter: @@ -51,8 +51,8 @@ def get_auth_router( async def login( response: Response, credentials: OAuth2PasswordRequestForm = Depends(), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), - strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + strategy: Strategy[models.UP] = Depends(backend.get_strategy), ): user = await user_manager.authenticate(credentials) @@ -82,8 +82,8 @@ def get_auth_router( ) async def logout( response: Response, - user_token: Tuple[models.UD, str] = Depends(get_current_user_token), - strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), + user_token: Tuple[models.UP, str] = Depends(get_current_user_token), + strategy: Strategy[models.UP] = Depends(backend.get_strategy), ): user, token = user_token return await backend.logout(strategy, user, token, response) diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 2232be55..74bb3fb9 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -6,7 +6,7 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token from pydantic import BaseModel -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.authentication import AuthenticationBackend, Strategy from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import BaseUserManager, UserManagerDependency @@ -29,7 +29,7 @@ def generate_state_token( def get_oauth_router( oauth_client: BaseOAuth2, backend: AuthenticationBackend, - get_user_manager: UserManagerDependency[models.UC, models.UD], + get_user_manager: UserManagerDependency[models.UP], state_secret: SecretType, redirect_url: str = None, ) -> APIRouter: @@ -101,8 +101,8 @@ def get_oauth_router( access_token_state: Tuple[OAuth2Token, str] = Depends( oauth2_authorize_callback ), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), - strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + strategy: Strategy[models.UP] = Depends(backend.get_strategy), ): token, state = access_token_state account_id, account_email = await oauth_client.get_id_email( @@ -114,7 +114,7 @@ def get_oauth_router( except jwt.DecodeError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - new_oauth_account = models.BaseOAuthAccount( + new_oauth_account = schemas.BaseOAuthAccount( oauth_name=oauth_client.name, access_token=token["access_token"], expires_at=token.get("expires_at"), diff --git a/fastapi_users/router/register.py b/fastapi_users/router/register.py index 567d1a3c..4778c8a6 100644 --- a/fastapi_users/router/register.py +++ b/fastapi_users/router/register.py @@ -2,7 +2,7 @@ from typing import Type from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.manager import ( BaseUserManager, InvalidPasswordException, @@ -13,9 +13,9 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_register_router( - get_user_manager: UserManagerDependency[models.UC, models.UD], - user_model: Type[models.U], - user_create_model: Type[models.UC], + get_user_manager: UserManagerDependency[models.UP], + user_model: Type[schemas.U], + user_create_model: Type[schemas.UC], ) -> APIRouter: """Generate a router with the register route.""" router = APIRouter() @@ -56,7 +56,7 @@ def get_register_router( async def register( request: Request, user: user_create_model, # type: ignore - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), ): try: created_user = await user_manager.create(user, safe=True, request=request) diff --git a/fastapi_users/router/reset.py b/fastapi_users/router/reset.py index 46d6e924..e6458321 100644 --- a/fastapi_users/router/reset.py +++ b/fastapi_users/router/reset.py @@ -40,7 +40,7 @@ RESET_PASSWORD_RESPONSES: OpenAPIResponseType = { def get_reset_password_router( - get_user_manager: UserManagerDependency[models.UC, models.UD] + get_user_manager: UserManagerDependency[models.UP], ) -> APIRouter: """Generate a router with the reset password routes.""" router = APIRouter() @@ -53,7 +53,7 @@ def get_reset_password_router( async def forgot_password( request: Request, email: EmailStr = Body(..., embed=True), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), ): try: user = await user_manager.get_by_email(email) @@ -76,7 +76,7 @@ def get_reset_password_router( request: Request, token: str = Body(...), password: str = Body(...), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), ): try: await user_manager.reset_password(token, password, request) diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index aed2eaad..d23d87d8 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -3,7 +3,7 @@ from typing import Type from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from pydantic import UUID4 -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.authentication import Authenticator from fastapi_users.manager import ( BaseUserManager, @@ -16,10 +16,9 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_users_router( - get_user_manager: UserManagerDependency[models.UC, models.UD], - user_model: Type[models.U], - user_update_model: Type[models.UU], - user_db_model: Type[models.UD], + get_user_manager: UserManagerDependency[models.UP], + user_model: Type[schemas.U], + user_update_model: Type[schemas.UU], authenticator: Authenticator, requires_verification: bool = False, ) -> APIRouter: @@ -35,8 +34,8 @@ def get_users_router( async def get_user_or_404( id: UUID4, - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), - ) -> models.UD: + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), + ) -> models.UP: try: return await user_manager.get(id) except UserNotExists: @@ -53,7 +52,7 @@ def get_users_router( }, ) async def me( - user: user_db_model = Depends(get_current_active_user), # type: ignore + user: models.UP = Depends(get_current_active_user), ): return user @@ -96,8 +95,8 @@ def get_users_router( async def update_me( request: Request, user_update: user_update_model, # type: ignore - user: user_db_model = Depends(get_current_active_user), # type: ignore - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user: models.UP = Depends(get_current_active_user), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), ): try: return await user_manager.update( @@ -183,7 +182,7 @@ def get_users_router( user_update: user_update_model, # type: ignore request: Request, user=Depends(get_user_or_404), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), ): try: return await user_manager.update( @@ -223,7 +222,7 @@ def get_users_router( ) async def delete_user( user=Depends(get_user_or_404), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), ): await user_manager.delete(user) return None diff --git a/fastapi_users/router/verify.py b/fastapi_users/router/verify.py index d4046040..2e783e30 100644 --- a/fastapi_users/router/verify.py +++ b/fastapi_users/router/verify.py @@ -3,7 +3,7 @@ from typing import Type from fastapi import APIRouter, Body, Depends, HTTPException, Request, status from pydantic import EmailStr -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.manager import ( BaseUserManager, InvalidVerifyToken, @@ -16,8 +16,8 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_verify_router( - get_user_manager: UserManagerDependency[models.UC, models.UD], - user_model: Type[models.U], + get_user_manager: UserManagerDependency[models.UP], + user_model: Type[schemas.U], ): router = APIRouter() @@ -29,7 +29,7 @@ def get_verify_router( async def request_verify_token( request: Request, email: EmailStr = Body(..., embed=True), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), ): try: user = await user_manager.get_by_email(email) @@ -69,7 +69,7 @@ def get_verify_router( async def verify( request: Request, token: str = Body(..., embed=True), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP] = Depends(get_user_manager), ): try: return await user_manager.verify(token, request) diff --git a/fastapi_users/schemas.py b/fastapi_users/schemas.py new file mode 100644 index 00000000..6529a3f5 --- /dev/null +++ b/fastapi_users/schemas.py @@ -0,0 +1,81 @@ +import uuid +from typing import List, Optional, TypeVar + +from pydantic import UUID4, BaseModel, EmailStr, Field + + +class CreateUpdateDictModel(BaseModel): + def create_update_dict(self): + return self.dict( + exclude_unset=True, + exclude={ + "id", + "is_superuser", + "is_active", + "is_verified", + "oauth_accounts", + }, + ) + + def create_update_dict_superuser(self): + return self.dict(exclude_unset=True, exclude={"id"}) + + +class BaseUser(CreateUpdateDictModel): + """Base User model.""" + + id: UUID4 = Field(default_factory=uuid.uuid4) + email: EmailStr + is_active: bool = True + is_superuser: bool = False + is_verified: bool = False + + +class BaseUserCreate(CreateUpdateDictModel): + email: EmailStr + password: str + is_active: Optional[bool] = True + is_superuser: Optional[bool] = False + is_verified: Optional[bool] = False + + +class BaseUserUpdate(CreateUpdateDictModel): + password: Optional[str] + email: Optional[EmailStr] + is_active: Optional[bool] + is_superuser: Optional[bool] + is_verified: Optional[bool] + + +class BaseUserDB(BaseUser): + hashed_password: str + + class Config: + orm_mode = True + + +U = TypeVar("U", bound=BaseUser) +UC = TypeVar("UC", bound=BaseUserCreate) +UU = TypeVar("UU", bound=BaseUserUpdate) +UD = TypeVar("UD", bound=BaseUserDB) + + +class BaseOAuthAccount(BaseModel): + """Base OAuth account model.""" + + id: UUID4 = Field(default_factory=uuid.uuid4) + oauth_name: str + access_token: str + expires_at: Optional[int] = None + refresh_token: Optional[str] = None + account_id: str + account_email: str + + class Config: + orm_mode = True + + +class BaseOAuthAccountMixin(BaseModel): + """Adds OAuth accounts list to a User model.""" + + oauth_accounts: List[BaseOAuthAccount] = [] diff --git a/tests/conftest.py b/tests/conftest.py index d94bc94e..fc98da0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,17 @@ import asyncio -from typing import Any, AsyncGenerator, Callable, Generic, Optional, Type, Union +import dataclasses +import uuid +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + Generic, + List, + Optional, + Type, + Union, +) from unittest.mock import MagicMock import httpx @@ -10,7 +22,7 @@ from httpx_oauth.oauth2 import OAuth2 from pydantic import UUID4, SecretStr from pytest_mock import MockerFixture -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.authentication import AuthenticationBackend, BearerTransport from fastapi_users.authentication.strategy import Strategy from fastapi_users.db import BaseUserDatabase @@ -20,7 +32,6 @@ from fastapi_users.manager import ( InvalidPasswordException, UserNotExists, ) -from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin from fastapi_users.openapi import OpenAPIResponseType from fastapi_users.password import PasswordHelper @@ -32,23 +43,50 @@ lancelot_password_hash = password_helper.hash("lancelot") excalibur_password_hash = password_helper.hash("excalibur") -class User(models.BaseUser): +@dataclasses.dataclass +class UserModel(models.UserProtocol): + email: str + hashed_password: str + id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4) + is_active: bool = True + is_superuser: bool = False + is_verified: bool = False + first_name: Optional[str] = None + + +@dataclasses.dataclass +class OAuthAccountModel(models.OAuthAccountProtocol): + oauth_name: str + access_token: str + account_id: str + account_email: str + id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4) + expires_at: Optional[int] = None + refresh_token: Optional[str] = None + + +@dataclasses.dataclass +class UserOAuthModel(UserModel): + oauth_accounts: List[OAuthAccountModel] = dataclasses.field(default_factory=list) + + +class User(schemas.BaseUser): first_name: Optional[str] -class UserCreate(models.BaseUserCreate): +class UserCreate(schemas.BaseUserCreate): first_name: Optional[str] -class UserUpdate(models.BaseUserUpdate): +class UserUpdate(schemas.BaseUserUpdate): first_name: Optional[str] -class UserDB(User, models.BaseUserDB): +class UserDB(User, schemas.BaseUserDB): pass -class UserOAuth(User, BaseOAuthAccountMixin): +class UserOAuth(User, schemas.BaseOAuthAccountMixin): pass @@ -56,14 +94,12 @@ class UserDBOAuth(UserOAuth, UserDB): pass -class BaseTestUserManager( - Generic[models.UC, models.UD], BaseUserManager[models.UC, models.UD] -): +class BaseTestUserManager(Generic[models.UP], BaseUserManager[models.UP]): reset_password_token_secret = "SECRET" verification_token_secret = "SECRET" async def validate_password( - self, password: str, user: Union[models.UC, models.UD] + self, password: str, user: Union[schemas.UC, models.UP] ) -> None: if len(password) < 3: raise InvalidPasswordException( @@ -71,15 +107,15 @@ class BaseTestUserManager( ) -class UserManager(BaseTestUserManager[UserCreate, UserDB]): - user_db_model = UserDB +class UserManager(BaseTestUserManager[UserModel]): + user_model = UserModel -class UserManagerOAuth(BaseTestUserManager[UserCreate, UserDBOAuth]): - user_db_model = UserDBOAuth +class UserManagerOAuth(BaseTestUserManager[UserOAuthModel]): + user_model = UserOAuthModel -class UserManagerMock(UserManager): +class UserManagerMock(BaseTestUserManager[models.UP]): get_by_email: MagicMock request_verify: MagicMock verify: MagicMock @@ -131,16 +167,18 @@ def secret(request) -> SecretType: @pytest.fixture -def user() -> UserDB: - return UserDB( +def user() -> UserModel: + return UserModel( email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash, ) @pytest.fixture -def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth: - return UserDBOAuth( +def user_oauth( + oauth_account1: OAuthAccountModel, oauth_account2: OAuthAccountModel +) -> UserOAuthModel: + return UserOAuthModel( email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash, oauth_accounts=[oauth_account1, oauth_account2], @@ -148,8 +186,8 @@ def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth: @pytest.fixture -def inactive_user() -> UserDB: - return UserDB( +def inactive_user() -> UserModel: + return UserModel( email="percival@camelot.bt", hashed_password=angharad_password_hash, is_active=False, @@ -157,8 +195,8 @@ def inactive_user() -> UserDB: @pytest.fixture -def inactive_user_oauth(oauth_account3) -> UserDBOAuth: - return UserDBOAuth( +def inactive_user_oauth(oauth_account3: OAuthAccountModel) -> UserOAuthModel: + return UserOAuthModel( email="percival@camelot.bt", hashed_password=angharad_password_hash, is_active=False, @@ -167,8 +205,8 @@ def inactive_user_oauth(oauth_account3) -> UserDBOAuth: @pytest.fixture -def verified_user() -> UserDB: - return UserDB( +def verified_user() -> UserModel: + return UserModel( email="lake.lady@camelot.bt", hashed_password=excalibur_password_hash, is_active=True, @@ -177,8 +215,8 @@ def verified_user() -> UserDB: @pytest.fixture -def verified_user_oauth(oauth_account4) -> UserDBOAuth: - return UserDBOAuth( +def verified_user_oauth(oauth_account4: OAuthAccountModel) -> UserOAuthModel: + return UserOAuthModel( email="lake.lady@camelot.bt", hashed_password=excalibur_password_hash, is_active=False, @@ -187,8 +225,8 @@ def verified_user_oauth(oauth_account4) -> UserDBOAuth: @pytest.fixture -def superuser() -> UserDB: - return UserDB( +def superuser() -> UserModel: + return UserModel( email="merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -196,8 +234,8 @@ def superuser() -> UserDB: @pytest.fixture -def superuser_oauth() -> UserDBOAuth: - return UserDBOAuth( +def superuser_oauth() -> UserOAuthModel: + return UserOAuthModel( email="merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -206,8 +244,8 @@ def superuser_oauth() -> UserDBOAuth: @pytest.fixture -def verified_superuser() -> UserDB: - return UserDB( +def verified_superuser() -> UserModel: + return UserModel( email="the.real.merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -216,8 +254,8 @@ def verified_superuser() -> UserDB: @pytest.fixture -def verified_superuser_oauth() -> UserDBOAuth: - return UserDBOAuth( +def verified_superuser_oauth() -> UserOAuthModel: + return UserOAuthModel( email="the.real.merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -227,8 +265,8 @@ def verified_superuser_oauth() -> UserDBOAuth: @pytest.fixture -def oauth_account1() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account1() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service1", access_token="TOKEN", expires_at=1579000751, @@ -238,8 +276,8 @@ def oauth_account1() -> BaseOAuthAccount: @pytest.fixture -def oauth_account2() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account2() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service2", access_token="TOKEN", expires_at=1579000751, @@ -249,8 +287,8 @@ def oauth_account2() -> BaseOAuthAccount: @pytest.fixture -def oauth_account3() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account3() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service3", access_token="TOKEN", expires_at=1579000751, @@ -260,8 +298,8 @@ def oauth_account3() -> BaseOAuthAccount: @pytest.fixture -def oauth_account4() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account4() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service4", access_token="TOKEN", expires_at=1579000751, @@ -271,8 +309,8 @@ def oauth_account4() -> BaseOAuthAccount: @pytest.fixture -def oauth_account5() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account5() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service5", access_token="TOKEN", expires_at=1579000751, @@ -283,10 +321,14 @@ def oauth_account5() -> BaseOAuthAccount: @pytest.fixture def mock_user_db( - user, verified_user, inactive_user, superuser, verified_superuser -) -> BaseUserDatabase: - class MockUserDatabase(BaseUserDatabase[UserDB]): - async def get(self, id: UUID4) -> Optional[UserDB]: + user: UserModel, + verified_user: UserModel, + inactive_user: UserModel, + superuser: UserModel, + verified_superuser: UserModel, +) -> BaseUserDatabase[UserModel]: + class MockUserDatabase(BaseUserDatabase[UserModel]): + async def get(self, id: UUID4) -> Optional[UserModel]: if id == user.id: return user if id == verified_user.id: @@ -299,7 +341,7 @@ def mock_user_db( return verified_superuser return None - async def get_by_email(self, email: str) -> Optional[UserDB]: + async def get_by_email(self, email: str) -> Optional[UserModel]: lower_email = email.lower() if lower_email == user.email.lower(): return user @@ -313,28 +355,32 @@ def mock_user_db( return verified_superuser return None - async def create(self, user: UserDB) -> UserDB: + async def create(self, create_dict: Dict[str, Any]) -> UserModel: + return UserModel(**create_dict) + + async def update( + self, user: UserModel, update_dict: Dict[str, Any] + ) -> UserModel: + for field, value in update_dict.items(): + setattr(user, field, value) return user - async def update(self, user: UserDB) -> UserDB: - return user - - async def delete(self, user: UserDB) -> None: + async def delete(self, user: UserModel) -> None: pass - return MockUserDatabase(UserDB) + return MockUserDatabase() @pytest.fixture def mock_user_db_oauth( - user_oauth, - verified_user_oauth, - inactive_user_oauth, - superuser_oauth, - verified_superuser_oauth, -) -> BaseUserDatabase: - class MockUserDatabase(BaseUserDatabase[UserDBOAuth]): - async def get(self, id: UUID4) -> Optional[UserDBOAuth]: + user_oauth: UserOAuthModel, + verified_user_oauth: UserOAuthModel, + inactive_user_oauth: UserOAuthModel, + superuser_oauth: UserOAuthModel, + verified_superuser_oauth: UserOAuthModel, +) -> BaseUserDatabase[UserOAuthModel]: + class MockUserDatabase(BaseUserDatabase[UserOAuthModel]): + async def get(self, id: UUID4) -> Optional[UserOAuthModel]: if id == user_oauth.id: return user_oauth if id == verified_user_oauth.id: @@ -347,7 +393,7 @@ def mock_user_db_oauth( return verified_superuser_oauth return None - async def get_by_email(self, email: str) -> Optional[UserDBOAuth]: + async def get_by_email(self, email: str) -> Optional[UserOAuthModel]: lower_email = email.lower() if lower_email == user_oauth.email.lower(): return user_oauth @@ -363,7 +409,7 @@ def mock_user_db_oauth( async def get_by_oauth_account( self, oauth: str, account_id: str - ) -> Optional[UserDBOAuth]: + ) -> Optional[UserOAuthModel]: user_oauth_account = user_oauth.oauth_accounts[0] if ( user_oauth_account.oauth_name == oauth @@ -379,16 +425,20 @@ def mock_user_db_oauth( return inactive_user_oauth return None - async def create(self, user: UserDBOAuth) -> UserDBOAuth: - return user_oauth + async def create(self, create_dict: Dict[str, Any]) -> UserOAuthModel: + return UserOAuthModel(**create_dict) - async def update(self, user: UserDBOAuth) -> UserDBOAuth: - return user_oauth + async def update( + self, user: UserOAuthModel, update_dict: Dict[str, Any] + ) -> UserOAuthModel: + for field, value in update_dict.items(): + setattr(user, field, value) + return user - async def delete(self, user: UserDBOAuth) -> None: + async def delete(self, user: UserOAuthModel) -> None: pass - return MockUserDatabase(UserDBOAuth) + return MockUserDatabase() @pytest.fixture @@ -450,10 +500,10 @@ class MockTransport(BearerTransport): return {} -class MockStrategy(Strategy[UserCreate, UserDB]): +class MockStrategy(Strategy[UserModel]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[UserCreate, UserDB] - ) -> Optional[UserDB]: + self, token: Optional[str], user_manager: BaseUserManager[UserModel] + ) -> Optional[UserModel]: if token is not None: try: token_uuid = UUID4(token) @@ -464,10 +514,10 @@ class MockStrategy(Strategy[UserCreate, UserDB]): return None return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: UserModel) -> str: return str(user.id) - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: UserModel) -> None: return None diff --git a/tests/test_authentication_authenticator.py b/tests/test_authentication_authenticator.py index 39982ead..7ad8c1bc 100644 --- a/tests/test_authentication_authenticator.py +++ b/tests/test_authentication_authenticator.py @@ -12,7 +12,7 @@ from fastapi_users.authentication.strategy import Strategy from fastapi_users.authentication.transport import Transport from fastapi_users.manager import BaseUserManager from fastapi_users.types import DependencyCallable -from tests.conftest import UserDB +from tests.conftest import User, UserModel class MockSecurityScheme(SecurityBase): @@ -29,18 +29,18 @@ class MockTransport(Transport): class NoneStrategy(Strategy): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP] + ) -> Optional[models.UP]: return None -class UserStrategy(Strategy, Generic[models.UC, models.UD]): - def __init__(self, user: models.UD): +class UserStrategy(Strategy, Generic[models.UP]): + def __init__(self, user: models.UP): self.user = user async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP] + ) -> Optional[models.UP]: return self.user @@ -55,7 +55,7 @@ def get_backend_none(): @pytest.fixture -def get_backend_user(user: UserDB): +def get_backend_user(user: UserModel): def _get_backend_user(name: str = "user"): return AuthenticationBackend( name=name, @@ -78,17 +78,17 @@ def get_test_auth_client(get_user_manager, get_test_client): app = FastAPI() authenticator = Authenticator(backends, get_user_manager) - @app.get("/test-current-user") + @app.get("/test-current-user", response_model=User) def test_current_user( - user: UserDB = Depends( + user: UserModel = Depends( authenticator.current_user(get_enabled_backends=get_enabled_backends) ), ): return user - @app.get("/test-current-active-user") + @app.get("/test-current-active-user", response_model=User) def test_current_active_user( - user: UserDB = Depends( + user: UserModel = Depends( authenticator.current_user( active=True, get_enabled_backends=get_enabled_backends ) @@ -96,9 +96,9 @@ def get_test_auth_client(get_user_manager, get_test_client): ): return user - @app.get("/test-current-superuser") + @app.get("/test-current-superuser", response_model=User) def test_current_superuser( - user: UserDB = Depends( + user: UserModel = Depends( authenticator.current_user( active=True, superuser=True, diff --git a/tests/test_authentication_backend.py b/tests/test_authentication_backend.py index ddd4cec9..7b960c85 100644 --- a/tests/test_authentication_backend.py +++ b/tests/test_authentication_backend.py @@ -3,7 +3,7 @@ from typing import Callable, Generic, Optional, Type, cast import pytest from fastapi import Response -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, @@ -19,16 +19,16 @@ class MockTransportLogoutNotSupported(BearerTransport): pass -class MockStrategyDestroyNotSupported(Strategy, Generic[models.UC, models.UD]): +class MockStrategyDestroyNotSupported(Strategy, Generic[schemas.UC, schemas.UD]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP] + ) -> Optional[schemas.UD]: return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: schemas.UD) -> str: return "TOKEN" - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: schemas.UD) -> None: raise StrategyDestroyNotSupportedError diff --git a/tests/test_db_base.py b/tests/test_db_base.py index 9fab9fa8..5cfea0d2 100644 --- a/tests/test_db_base.py +++ b/tests/test_db_base.py @@ -1,13 +1,12 @@ import pytest from fastapi_users.db import BaseUserDatabase -from tests.conftest import UserDB @pytest.mark.asyncio @pytest.mark.db async def test_not_implemented_methods(user): - base_user_db = BaseUserDatabase(UserDB) + base_user_db = BaseUserDatabase() with pytest.raises(NotImplementedError): await base_user_db.get("aaa") @@ -19,10 +18,10 @@ async def test_not_implemented_methods(user): await base_user_db.get_by_oauth_account("google", "user_oauth1") with pytest.raises(NotImplementedError): - await base_user_db.create(user) + await base_user_db.create({}) with pytest.raises(NotImplementedError): - await base_user_db.update(user) + await base_user_db.update(user, {}) with pytest.raises(NotImplementedError): await base_user_db.delete(user) diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index 179ee7d7..e723436e 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -5,7 +5,7 @@ import pytest from fastapi import Depends, FastAPI, status from fastapi_users import FastAPIUsers -from tests.conftest import User, UserCreate, UserDB, UserUpdate +from tests.conftest import User, UserCreate, UserDB, UserModel, UserUpdate @pytest.fixture @@ -17,13 +17,12 @@ async def test_app_client( oauth_client, get_test_client, ) -> AsyncGenerator[httpx.AsyncClient, None]: - fastapi_users = FastAPIUsers[User, UserCreate, UserUpdate, UserDB]( + fastapi_users = FastAPIUsers[UserModel, User, UserCreate, UserUpdate, UserDB]( get_user_manager, [mock_authentication], User, UserCreate, UserUpdate, - UserDB, ) app = FastAPI() @@ -40,59 +39,71 @@ async def test_app_client( def custom_users_route(): return None - @app.get("/current-user") - def current_user(user=Depends(fastapi_users.current_user())): + @app.get("/current-user", response_model=User) + def current_user(user: UserModel = Depends(fastapi_users.current_user())): return user - @app.get("/current-active-user") - def current_active_user(user=Depends(fastapi_users.current_user(active=True))): - return user - - @app.get("/current-verified-user") - def current_verified_user(user=Depends(fastapi_users.current_user(verified=True))): - return user - - @app.get("/current-superuser") - def current_superuser( - user=Depends(fastapi_users.current_user(active=True, superuser=True)) + @app.get("/current-active-user", response_model=User) + def current_active_user( + user: UserModel = Depends(fastapi_users.current_user(active=True)), ): return user - @app.get("/current-verified-superuser") + @app.get("/current-verified-user", response_model=User) + def current_verified_user( + user: UserModel = Depends(fastapi_users.current_user(verified=True)), + ): + return user + + @app.get("/current-superuser", response_model=User) + def current_superuser( + user: UserModel = Depends( + fastapi_users.current_user(active=True, superuser=True) + ) + ): + return user + + @app.get("/current-verified-superuser", response_model=User) def current_verified_superuser( - user=Depends( + user: UserModel = Depends( fastapi_users.current_user(active=True, verified=True, superuser=True) ), ): return user - @app.get("/optional-current-user") - def optional_current_user(user=Depends(fastapi_users.current_user(optional=True))): + @app.get("/optional-current-user", response_model=User) + def optional_current_user( + user: UserModel = Depends(fastapi_users.current_user(optional=True)), + ): return user - @app.get("/optional-current-active-user") + @app.get("/optional-current-active-user", response_model=User) def optional_current_active_user( - user=Depends(fastapi_users.current_user(optional=True, active=True)), + user: UserModel = Depends( + fastapi_users.current_user(optional=True, active=True) + ), ): return user - @app.get("/optional-current-verified-user") + @app.get("/optional-current-verified-user", response_model=User) def optional_current_verified_user( - user=Depends(fastapi_users.current_user(optional=True, verified=True)), + user: UserModel = Depends( + fastapi_users.current_user(optional=True, verified=True) + ), ): return user - @app.get("/optional-current-superuser") + @app.get("/optional-current-superuser", response_model=User) def optional_current_superuser( - user=Depends( + user: UserModel = Depends( fastapi_users.current_user(optional=True, active=True, superuser=True) ), ): return user - @app.get("/optional-current-verified-superuser") + @app.get("/optional-current-verified-superuser", response_model=User) def optional_current_verified_superuser( - user=Depends( + user: UserModel = Depends( fastapi_users.current_user( optional=True, active=True, verified=True, superuser=True ) @@ -150,7 +161,9 @@ class TestGetCurrentUser: ) assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_valid_token(self, test_app_client: httpx.AsyncClient, user: UserDB): + async def test_valid_token( + self, test_app_client: httpx.AsyncClient, user: UserModel + ): response = await test_app_client.get( "/current-user", headers={"Authorization": f"Bearer {user.id}"} ) @@ -171,7 +184,7 @@ class TestGetCurrentActiveUser: assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_valid_token_inactive_user( - self, test_app_client: httpx.AsyncClient, inactive_user: UserDB + self, test_app_client: httpx.AsyncClient, inactive_user: UserModel ): response = await test_app_client.get( "/current-active-user", @@ -179,7 +192,9 @@ class TestGetCurrentActiveUser: ) assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_valid_token(self, test_app_client: httpx.AsyncClient, user: UserDB): + async def test_valid_token( + self, test_app_client: httpx.AsyncClient, user: UserModel + ): response = await test_app_client.get( "/current-active-user", headers={"Authorization": f"Bearer {user.id}"} ) @@ -200,7 +215,7 @@ class TestGetCurrentVerifiedUser: assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_valid_token_unverified_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/current-verified-user", @@ -209,7 +224,7 @@ class TestGetCurrentVerifiedUser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_verified_user( - self, test_app_client: httpx.AsyncClient, verified_user: UserDB + self, test_app_client: httpx.AsyncClient, verified_user: UserModel ): response = await test_app_client.get( "/current-verified-user", @@ -232,7 +247,7 @@ class TestGetCurrentSuperuser: assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_valid_token_regular_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/current-superuser", headers={"Authorization": f"Bearer {user.id}"} @@ -240,7 +255,7 @@ class TestGetCurrentSuperuser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_superuser( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( "/current-superuser", headers={"Authorization": f"Bearer {superuser.id}"} @@ -262,7 +277,7 @@ class TestGetCurrentVerifiedSuperuser: assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_valid_token_regular_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/current-verified-superuser", @@ -271,7 +286,7 @@ class TestGetCurrentVerifiedSuperuser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_verified_user( - self, test_app_client: httpx.AsyncClient, verified_user: UserDB + self, test_app_client: httpx.AsyncClient, verified_user: UserModel ): response = await test_app_client.get( "/current-verified-superuser", @@ -280,7 +295,7 @@ class TestGetCurrentVerifiedSuperuser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_superuser( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( "/current-verified-superuser", @@ -289,7 +304,7 @@ class TestGetCurrentVerifiedSuperuser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_verified_superuser( - self, test_app_client: httpx.AsyncClient, verified_superuser: UserDB + self, test_app_client: httpx.AsyncClient, verified_superuser: UserModel ): response = await test_app_client.get( "/current-verified-superuser", @@ -313,7 +328,9 @@ class TestOptionalGetCurrentUser: assert response.status_code == status.HTTP_200_OK assert response.json() is None - async def test_valid_token(self, test_app_client: httpx.AsyncClient, user: UserDB): + async def test_valid_token( + self, test_app_client: httpx.AsyncClient, user: UserModel + ): response = await test_app_client.get( "/optional-current-user", headers={"Authorization": f"Bearer {user.id}"} ) @@ -337,7 +354,7 @@ class TestOptionalGetCurrentVerifiedUser: assert response.json() is None async def test_valid_token_unverified_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/optional-current-verified-user", @@ -347,7 +364,7 @@ class TestOptionalGetCurrentVerifiedUser: assert response.json() is None async def test_valid_token_verified_user( - self, test_app_client: httpx.AsyncClient, verified_user: UserDB + self, test_app_client: httpx.AsyncClient, verified_user: UserModel ): response = await test_app_client.get( "/optional-current-verified-user", @@ -373,7 +390,7 @@ class TestOptionalGetCurrentActiveUser: assert response.json() is None async def test_valid_token_inactive_user( - self, test_app_client: httpx.AsyncClient, inactive_user: UserDB + self, test_app_client: httpx.AsyncClient, inactive_user: UserModel ): response = await test_app_client.get( "/optional-current-active-user", @@ -382,7 +399,9 @@ class TestOptionalGetCurrentActiveUser: assert response.status_code == status.HTTP_200_OK assert response.json() is None - async def test_valid_token(self, test_app_client: httpx.AsyncClient, user: UserDB): + async def test_valid_token( + self, test_app_client: httpx.AsyncClient, user: UserModel + ): response = await test_app_client.get( "/optional-current-active-user", headers={"Authorization": f"Bearer {user.id}"}, @@ -407,7 +426,7 @@ class TestOptionalGetCurrentSuperuser: assert response.json() is None async def test_valid_token_regular_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/optional-current-superuser", @@ -417,7 +436,7 @@ class TestOptionalGetCurrentSuperuser: assert response.json() is None async def test_valid_token_superuser( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( "/optional-current-superuser", @@ -444,7 +463,7 @@ class TestOptionalGetCurrentVerifiedSuperuser: assert response.json() is None async def test_valid_token_regular_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/optional-current-verified-superuser", @@ -454,7 +473,7 @@ class TestOptionalGetCurrentVerifiedSuperuser: assert response.json() is None async def test_valid_token_verified_user( - self, test_app_client: httpx.AsyncClient, verified_user: UserDB + self, test_app_client: httpx.AsyncClient, verified_user: UserModel ): response = await test_app_client.get( "/optional-current-verified-superuser", @@ -464,7 +483,7 @@ class TestOptionalGetCurrentVerifiedSuperuser: assert response.json() is None async def test_valid_token_superuser( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( "/optional-current-verified-superuser", @@ -474,7 +493,7 @@ class TestOptionalGetCurrentVerifiedSuperuser: assert response.json() is None async def test_valid_token_verified_superuser( - self, test_app_client: httpx.AsyncClient, verified_superuser: UserDB + self, test_app_client: httpx.AsyncClient, verified_superuser: UserModel ): response = await test_app_client.get( "/optional-current-verified-superuser", diff --git a/tests/test_manager.py b/tests/test_manager.py index 97f8cd1b..6c9238f6 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,11 +1,12 @@ -from typing import Callable, cast +import copy +import uuid +from typing import Callable import pytest from fastapi.security import OAuth2PasswordRequestForm from pydantic import UUID4 from pytest_mock import MockerFixture -from fastapi_users import models from fastapi_users.jwt import decode_jwt, generate_jwt from fastapi_users.manager import ( InvalidPasswordException, @@ -16,11 +17,19 @@ from fastapi_users.manager import ( UserInactive, UserNotExists, ) -from tests.conftest import UserCreate, UserDB, UserDBOAuth, UserManagerMock, UserUpdate +from tests.conftest import ( + OAuthAccountModel, + UserCreate, + UserDBOAuth, + UserManagerMock, + UserModel, + UserOAuthModel, + UserUpdate, +) @pytest.fixture -def verify_token(user_manager: UserManagerMock): +def verify_token(user_manager: UserManagerMock[UserModel]): def _verify_token( user_id=None, email=None, @@ -37,7 +46,7 @@ def verify_token(user_manager: UserManagerMock): @pytest.fixture -def forgot_password_token(user_manager: UserManagerMock): +def forgot_password_token(user_manager: UserManagerMock[UserModel]): def _forgot_password_token( user_id=None, lifetime=user_manager.reset_password_token_lifetime_seconds ): @@ -62,11 +71,13 @@ def create_oauth2_password_request_form() -> Callable[ @pytest.mark.asyncio @pytest.mark.manager class TestGet: - async def test_not_existing_user(self, user_manager: UserManagerMock): + async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]): with pytest.raises(UserNotExists): await user_manager.get(UUID4("d35d213e-f3d8-4f08-954a-7e0d1bea286f")) - async def test_existing_user(self, user_manager: UserManagerMock, user: UserDB): + async def test_existing_user( + self, user_manager: UserManagerMock[UserModel], user: UserModel + ): retrieved_user = await user_manager.get(user.id) assert retrieved_user.id == user.id @@ -74,11 +85,13 @@ class TestGet: @pytest.mark.asyncio @pytest.mark.manager class TestGetByEmail: - async def test_not_existing_user(self, user_manager: UserManagerMock): + async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]): with pytest.raises(UserNotExists): await user_manager.get_by_email("lancelot@camelot.bt") - async def test_existing_user(self, user_manager: UserManagerMock, user: UserDB): + async def test_existing_user( + self, user_manager: UserManagerMock[UserModel], user: UserModel + ): retrieved_user = await user_manager.get_by_email(user.email) assert retrieved_user.id == user.id @@ -86,12 +99,14 @@ class TestGetByEmail: @pytest.mark.asyncio @pytest.mark.manager class TestGetByOAuthAccount: - async def test_not_existing_user(self, user_manager_oauth: UserManagerMock): + async def test_not_existing_user( + self, user_manager_oauth: UserManagerMock[UserModel] + ): with pytest.raises(UserNotExists): await user_manager_oauth.get_by_oauth_account("service1", "foo") async def test_existing_user( - self, user_manager_oauth: UserManagerMock, user_oauth: UserDBOAuth + self, user_manager_oauth: UserManagerMock[UserModel], user_oauth: UserDBOAuth ): oauth_account = user_oauth.oauth_accounts[0] retrieved_user = await user_manager_oauth.get_by_oauth_account( @@ -106,42 +121,46 @@ class TestCreateUser: @pytest.mark.parametrize( "email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"] ) - async def test_existing_user(self, email: str, user_manager: UserManagerMock): + async def test_existing_user( + self, email: str, user_manager: UserManagerMock[UserModel] + ): user = UserCreate(email=email, password="guinevere") with pytest.raises(UserAlreadyExists): await user_manager.create(user) assert user_manager.on_after_register.called is False @pytest.mark.parametrize("email", ["lancelot@camelot.bt", "Lancelot@camelot.bt"]) - async def test_regular_user(self, email: str, user_manager: UserManagerMock): + async def test_regular_user( + self, email: str, user_manager: UserManagerMock[UserModel] + ): user = UserCreate(email=email, password="guinevere") created_user = await user_manager.create(user) - assert type(created_user) == UserDB + assert type(created_user) == UserModel assert user_manager.on_after_register.called is True @pytest.mark.parametrize("safe,result", [(True, False), (False, True)]) async def test_superuser( - self, user_manager: UserManagerMock, safe: bool, result: bool + self, user_manager: UserManagerMock[UserModel], safe: bool, result: bool ): user = UserCreate( email="lancelot@camelot.b", password="guinevere", is_superuser=True ) created_user = await user_manager.create(user, safe) - assert type(created_user) == UserDB + assert type(created_user) == UserModel assert created_user.is_superuser is result assert user_manager.on_after_register.called is True @pytest.mark.parametrize("safe,result", [(True, True), (False, False)]) async def test_is_active( - self, user_manager: UserManagerMock, safe: bool, result: bool + self, user_manager: UserManagerMock[UserModel], safe: bool, result: bool ): user = UserCreate( email="lancelot@camelot.b", password="guinevere", is_active=False ) created_user = await user_manager.create(user, safe) - assert type(created_user) == UserDB + assert type(created_user) == UserModel assert created_user.is_active is result assert user_manager.on_after_register.called is True @@ -151,13 +170,15 @@ class TestCreateUser: @pytest.mark.manager class TestOAuthCallback: async def test_existing_user_with_oauth( - self, user_manager_oauth: UserManagerMock, user_oauth: UserDBOAuth + self, + user_manager_oauth: UserManagerMock[UserOAuthModel], + user_oauth: UserOAuthModel, ): - oauth_account = models.BaseOAuthAccount( - **user_oauth.oauth_accounts[0].dict(exclude={"id", "access_token"}), - access_token="UPDATED_TOKEN" - ) - user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) + oauth_account = copy.deepcopy(user_oauth.oauth_accounts[0]) + oauth_account.id = uuid.uuid4() + oauth_account.access_token = "UPDATED_TOKEN" + + user = await user_manager_oauth.oauth_callback(oauth_account) assert user.id == user_oauth.id assert len(user.oauth_accounts) == 2 @@ -169,16 +190,19 @@ class TestOAuthCallback: assert user_manager_oauth.on_after_register.called is False async def test_existing_user_without_oauth( - self, user_manager_oauth: UserManagerMock, superuser_oauth: UserDBOAuth + self, + user_manager_oauth: UserManagerMock[UserOAuthModel], + superuser_oauth: UserDBOAuth, ): - oauth_account = models.BaseOAuthAccount( + oauth_account = OAuthAccountModel( oauth_name="service1", access_token="TOKEN", expires_at=1579000751, account_id="superuser_oauth1", account_email=superuser_oauth.email, ) - user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) + + user = await user_manager_oauth.oauth_callback(oauth_account) assert user.id == superuser_oauth.id assert len(user.oauth_accounts) == 1 @@ -186,15 +210,16 @@ class TestOAuthCallback: assert user_manager_oauth.on_after_register.called is False - async def test_new_user(self, user_manager_oauth: UserManagerMock): - oauth_account = models.BaseOAuthAccount( + async def test_new_user(self, user_manager_oauth: UserManagerMock[UserOAuthModel]): + oauth_account = OAuthAccountModel( oauth_name="service1", access_token="TOKEN", expires_at=1579000751, account_id="new_user_oauth1", account_email="galahad@camelot.bt", ) - user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) + + user = await user_manager_oauth.oauth_callback(oauth_account) assert user.email == "galahad@camelot.bt" assert len(user.oauth_accounts) == 1 @@ -207,19 +232,19 @@ class TestOAuthCallback: @pytest.mark.manager class TestRequestVerifyUser: async def test_user_inactive( - self, user_manager: UserManagerMock, inactive_user: UserDB + self, user_manager: UserManagerMock[UserModel], inactive_user: UserModel ): with pytest.raises(UserInactive): await user_manager.request_verify(inactive_user) async def test_user_verified( - self, user_manager: UserManagerMock, verified_user: UserDB + self, user_manager: UserManagerMock[UserModel], verified_user: UserModel ): with pytest.raises(UserAlreadyVerified): await user_manager.request_verify(verified_user) async def test_user_active_not_verified( - self, user_manager: UserManagerMock, user: UserDB + self, user_manager: UserManagerMock[UserModel], user: UserModel ): await user_manager.request_verify(user) assert user_manager.on_after_request_verify.called is True @@ -240,40 +265,40 @@ class TestRequestVerifyUser: @pytest.mark.asyncio @pytest.mark.manager class TestVerifyUser: - async def test_invalid_token(self, user_manager: UserManagerMock): + async def test_invalid_token(self, user_manager: UserManagerMock[UserModel]): with pytest.raises(InvalidVerifyToken): await user_manager.verify("foo") async def test_token_expired( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(user_id=user.id, email=user.email, lifetime=-1) await user_manager.verify(token) async def test_missing_user_id( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(email=user.email) await user_manager.verify(token) async def test_missing_user_email( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(user_id=user.id) await user_manager.verify(token) async def test_invalid_user_id( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(user_id="foo", email=user.email) await user_manager.verify(token) async def test_invalid_email( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(user_id=user.id, email="foo") @@ -281,9 +306,9 @@ class TestVerifyUser: async def test_email_id_mismatch( self, - user_manager: UserManagerMock, - user: UserDB, - inactive_user: UserDB, + user_manager: UserManagerMock[UserModel], + user: UserModel, + inactive_user: UserModel, verify_token, ): with pytest.raises(InvalidVerifyToken): @@ -291,14 +316,20 @@ class TestVerifyUser: await user_manager.verify(token) async def test_verified_user( - self, user_manager: UserManagerMock, verified_user: UserDB, verify_token + self, + user_manager: UserManagerMock[UserModel], + verified_user: UserModel, + verify_token, ): with pytest.raises(UserAlreadyVerified): token = verify_token(user_id=verified_user.id, email=verified_user.email) await user_manager.verify(token) async def test_inactive_user( - self, user_manager: UserManagerMock, inactive_user: UserDB, verify_token + self, + user_manager: UserManagerMock[UserModel], + inactive_user: UserModel, + verify_token, ): token = verify_token(user_id=inactive_user.id, email=inactive_user.email) verified_user = await user_manager.verify(token) @@ -307,7 +338,7 @@ class TestVerifyUser: assert verified_user.is_active is False async def test_active_user( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): token = verify_token(user_id=user.id, email=user.email) verified_user = await user_manager.verify(token) @@ -320,13 +351,15 @@ class TestVerifyUser: @pytest.mark.manager class TestForgotPassword: async def test_user_inactive( - self, user_manager: UserManagerMock, inactive_user: UserDB + self, user_manager: UserManagerMock[UserModel], inactive_user: UserModel ): with pytest.raises(UserInactive): await user_manager.forgot_password(inactive_user) assert user_manager.on_after_forgot_password.called is False - async def test_user_active(self, user_manager: UserManagerMock, user: UserDB): + async def test_user_active( + self, user_manager: UserManagerMock[UserModel], user: UserModel + ): await user_manager.forgot_password(user) assert user_manager.on_after_forgot_password.called is True @@ -345,14 +378,17 @@ class TestForgotPassword: @pytest.mark.asyncio @pytest.mark.manager class TestResetPassword: - async def test_invalid_token(self, user_manager: UserManagerMock): + async def test_invalid_token(self, user_manager: UserManagerMock[UserModel]): with pytest.raises(InvalidResetPasswordToken): await user_manager.reset_password("foo", "guinevere") assert user_manager._update.called is False assert user_manager.on_after_reset_password.called is False async def test_token_expired( - self, user_manager: UserManagerMock, user: UserDB, forgot_password_token + self, + user_manager: UserManagerMock[UserModel], + user: UserModel, + forgot_password_token, ): with pytest.raises(InvalidResetPasswordToken): await user_manager.reset_password( @@ -363,7 +399,10 @@ class TestResetPassword: @pytest.mark.parametrize("user_id", [None, "foo"]) async def test_valid_token_bad_payload( - self, user_id: str, user_manager: UserManagerMock, forgot_password_token + self, + user_id: str, + user_manager: UserManagerMock[UserModel], + forgot_password_token, ): with pytest.raises(InvalidResetPasswordToken): await user_manager.reset_password( @@ -373,7 +412,7 @@ class TestResetPassword: assert user_manager.on_after_reset_password.called is False async def test_not_existing_user( - self, user_manager: UserManagerMock, forgot_password_token + self, user_manager: UserManagerMock[UserModel], forgot_password_token ): with pytest.raises(UserNotExists): await user_manager.reset_password( @@ -385,8 +424,8 @@ class TestResetPassword: async def test_inactive_user( self, - inactive_user: UserDB, - user_manager: UserManagerMock, + inactive_user: UserModel, + user_manager: UserManagerMock[UserModel], forgot_password_token, ): with pytest.raises(UserInactive): @@ -398,7 +437,10 @@ class TestResetPassword: assert user_manager.on_after_reset_password.called is False async def test_invalid_password( - self, user: UserDB, user_manager: UserManagerMock, forgot_password_token + self, + user: UserModel, + user_manager: UserManagerMock[UserModel], + forgot_password_token, ): with pytest.raises(InvalidPasswordException): await user_manager.reset_password( @@ -408,7 +450,10 @@ class TestResetPassword: assert user_manager.on_after_reset_password.called is False async def test_valid_user_password( - self, user: UserDB, user_manager: UserManagerMock, forgot_password_token + self, + user: UserModel, + user_manager: UserManagerMock[UserModel], + forgot_password_token, ): await user_manager.reset_password(forgot_password_token(user.id), "holygrail") @@ -424,7 +469,9 @@ class TestResetPassword: @pytest.mark.asyncio @pytest.mark.manager class TestUpdateUser: - async def test_safe_update(self, user: UserDB, user_manager: UserManagerMock): + async def test_safe_update( + self, user: UserModel, user_manager: UserManagerMock[UserModel] + ): user_update = UserUpdate(first_name="Arthur", is_superuser=True) updated_user = await user_manager.update(user_update, user, safe=True) @@ -433,7 +480,9 @@ class TestUpdateUser: assert user_manager.on_after_update.called is True - async def test_unsafe_update(self, user: UserDB, user_manager: UserManagerMock): + async def test_unsafe_update( + self, user: UserModel, user_manager: UserManagerMock[UserModel] + ): user_update = UserUpdate(first_name="Arthur", is_superuser=True) updated_user = await user_manager.update(user_update, user, safe=False) @@ -443,7 +492,7 @@ class TestUpdateUser: assert user_manager.on_after_update.called is True async def test_password_update_invalid( - self, user: UserDB, user_manager: UserManagerMock + self, user: UserModel, user_manager: UserManagerMock[UserModel] ): user_update = UserUpdate(password="h") with pytest.raises(InvalidPasswordException): @@ -452,7 +501,7 @@ class TestUpdateUser: assert user_manager.on_after_update.called is False async def test_password_update_valid( - self, user: UserDB, user_manager: UserManagerMock + self, user: UserModel, user_manager: UserManagerMock[UserModel] ): old_hashed_password = user.hashed_password user_update = UserUpdate(password="holygrail") @@ -463,7 +512,10 @@ class TestUpdateUser: assert user_manager.on_after_update.called is True async def test_email_update_already_existing( - self, user: UserDB, superuser: UserDB, user_manager: UserManagerMock + self, + user: UserModel, + superuser: UserModel, + user_manager: UserManagerMock[UserModel], ): user_update = UserUpdate(email=superuser.email) with pytest.raises(UserAlreadyExists): @@ -472,7 +524,7 @@ class TestUpdateUser: assert user_manager.on_after_update.called is False async def test_email_update_with_same_email( - self, user: UserDB, user_manager: UserManagerMock + self, user: UserModel, user_manager: UserManagerMock[UserModel] ): user_update = UserUpdate(email=user.email) updated_user = await user_manager.update(user_update, user, safe=True) @@ -485,7 +537,9 @@ class TestUpdateUser: @pytest.mark.asyncio @pytest.mark.manager class TestDelete: - async def test_delete(self, user: UserDB, user_manager: UserManagerMock): + async def test_delete( + self, user: UserModel, user_manager: UserManagerMock[UserModel] + ): await user_manager.delete(user) @@ -497,7 +551,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManagerMock, + user_manager: UserManagerMock[UserModel], ): form = create_oauth2_password_request_form("lancelot@camelot.bt", "guinevere") user = await user_manager.authenticate(form) @@ -508,7 +562,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManagerMock, + user_manager: UserManagerMock[UserModel], ): form = create_oauth2_password_request_form("king.arthur@camelot.bt", "percival") user = await user_manager.authenticate(form) @@ -519,7 +573,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManagerMock, + user_manager: UserManagerMock[UserModel], ): form = create_oauth2_password_request_form( "king.arthur@camelot.bt", "guinevere" @@ -534,7 +588,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManagerMock, + user_manager: UserManagerMock[UserModel], ): verify_and_update_password_patch = mocker.patch.object( user_manager.password_helper, "verify_and_update" diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 205cf746..197aa9a1 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -3,18 +3,13 @@ import pytest from fastapi import FastAPI, status from fastapi_users.fastapi_users import FastAPIUsers -from tests.conftest import User, UserCreate, UserDB, UserUpdate +from tests.conftest import User, UserCreate, UserDB, UserModel, UserUpdate @pytest.fixture def fastapi_users(get_user_manager, mock_authentication) -> FastAPIUsers: - return FastAPIUsers[User, UserCreate, UserUpdate, UserDB]( - get_user_manager, - [mock_authentication], - User, - UserCreate, - UserUpdate, - UserDB, + return FastAPIUsers[UserModel, User, UserCreate, UserUpdate, UserDB]( + get_user_manager, [mock_authentication], User, UserCreate, UserUpdate ) diff --git a/tests/test_router_auth.py b/tests/test_router_auth.py index 067c3743..1bbc8026 100644 --- a/tests/test_router_auth.py +++ b/tests/test_router_auth.py @@ -6,7 +6,7 @@ from fastapi import FastAPI, status from fastapi_users.authentication import Authenticator from fastapi_users.router import ErrorCode, get_auth_router -from tests.conftest import UserDB, get_mock_authentication +from tests.conftest import UserModel, get_mock_authentication @pytest.fixture @@ -118,7 +118,7 @@ class TestLogin: path, email, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client data = {"username": email, "password": "guinevere"} @@ -140,7 +140,7 @@ class TestLogin: path, email, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client data = {"username": email, "password": "excalibur"} @@ -182,7 +182,7 @@ class TestLogout: mocker, path, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.post( @@ -198,7 +198,7 @@ class TestLogout: mocker, path, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.post( diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index 89194b3f..bc30648b 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -7,7 +7,7 @@ from httpx_oauth.oauth2 import BaseOAuth2, OAuth2 from fastapi_users.authentication import AuthenticationBackend from fastapi_users.router.oauth import generate_state_token, get_oauth_router -from tests.conftest import AsyncMethodMocker, UserDB, UserManagerMock +from tests.conftest import AsyncMethodMocker, UserManagerMock, UserOAuthModel @pytest.fixture @@ -112,7 +112,7 @@ class TestCallback: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, oauth_client: BaseOAuth2, - user_oauth: UserDB, + user_oauth: UserOAuthModel, access_token: str, ): async_method_mocker(oauth_client, "get_access_token", return_value=access_token) @@ -133,7 +133,7 @@ class TestCallback: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, oauth_client: BaseOAuth2, - user_oauth: UserDB, + user_oauth: UserOAuthModel, user_manager_oauth: UserManagerMock, access_token: str, ): @@ -161,7 +161,7 @@ class TestCallback: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, oauth_client: BaseOAuth2, - inactive_user_oauth: UserDB, + inactive_user_oauth: UserOAuthModel, user_manager_oauth: UserManagerMock, access_token: str, ): @@ -188,7 +188,7 @@ class TestCallback: async_method_mocker: AsyncMethodMocker, test_app_client_redirect_url: httpx.AsyncClient, oauth_client: BaseOAuth2, - user_oauth: UserDB, + user_oauth: UserOAuthModel, user_manager_oauth: UserManagerMock, access_token: str, ): diff --git a/tests/test_router_users.py b/tests/test_router_users.py index 1d4752c9..cc4364a9 100644 --- a/tests/test_router_users.py +++ b/tests/test_router_users.py @@ -6,7 +6,7 @@ from fastapi import FastAPI, status from fastapi_users.authentication import Authenticator from fastapi_users.router import ErrorCode, get_users_router -from tests.conftest import User, UserDB, UserUpdate, get_mock_authentication +from tests.conftest import User, UserModel, UserUpdate, get_mock_authentication @pytest.fixture @@ -21,7 +21,6 @@ def app_factory(get_user_manager, mock_authentication): get_user_manager, User, UserUpdate, - UserDB, authenticator, requires_verification=requires_verification, ) @@ -59,7 +58,7 @@ class TestMe: async def test_inactive_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - inactive_user: UserDB, + inactive_user: UserModel, ): client, _ = test_app_client response = await client.get( @@ -70,7 +69,7 @@ class TestMe: async def test_active_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.get( @@ -87,7 +86,7 @@ class TestMe: async def test_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.get( @@ -116,7 +115,7 @@ class TestUpdateMe: async def test_inactive_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - inactive_user: UserDB, + inactive_user: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -127,8 +126,8 @@ class TestUpdateMe: async def test_existing_email( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_user: UserDB, + user: UserModel, + verified_user: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -146,7 +145,7 @@ class TestUpdateMe: async def test_invalid_password( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -167,7 +166,7 @@ class TestUpdateMe: async def test_empty_body( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -184,7 +183,7 @@ class TestUpdateMe: async def test_valid_body( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -202,7 +201,7 @@ class TestUpdateMe: async def test_unverified_after_email_change( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -217,7 +216,7 @@ class TestUpdateMe: async def test_valid_body_is_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client json = {"is_superuser": True} @@ -235,7 +234,7 @@ class TestUpdateMe: async def test_valid_body_is_active( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client json = {"is_active": False} @@ -253,7 +252,7 @@ class TestUpdateMe: async def test_valid_body_is_verified( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client json = {"is_verified": True} @@ -273,7 +272,7 @@ class TestUpdateMe: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client mocker.spy(mock_user_db, "update") @@ -295,7 +294,7 @@ class TestUpdateMe: async def test_empty_body_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -309,7 +308,7 @@ class TestUpdateMe: async def test_valid_body_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -324,7 +323,7 @@ class TestUpdateMe: async def test_valid_body_is_superuser_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"is_superuser": True} @@ -339,7 +338,7 @@ class TestUpdateMe: async def test_valid_body_is_active_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"is_active": False} @@ -354,7 +353,7 @@ class TestUpdateMe: async def test_valid_body_is_verified_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"is_verified": False} @@ -371,7 +370,7 @@ class TestUpdateMe: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client mocker.spy(mock_user_db, "update") @@ -399,7 +398,7 @@ class TestGetUser: async def test_regular_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.get( @@ -412,7 +411,7 @@ class TestGetUser: async def test_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.get( @@ -424,7 +423,7 @@ class TestGetUser: async def test_not_existing_user_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - superuser: UserDB, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.get( @@ -439,7 +438,7 @@ class TestGetUser: async def test_not_existing_user_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_superuser: UserDB, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.get( @@ -451,8 +450,8 @@ class TestGetUser: async def test_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.get( @@ -470,8 +469,8 @@ class TestGetUser: async def test_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.get( @@ -483,7 +482,7 @@ class TestGetUser: assert data["id"] == str(user.id) assert "hashed_password" not in data - async def test_get_user_namespace(self, app_factory, user: UserDB): + async def test_get_user_namespace(self, app_factory, user: UserModel): assert app_factory(True).url_path_for("users:user", id=user.id) == f"/{user.id}" @@ -498,7 +497,7 @@ class TestUpdateUser: async def test_regular_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -511,7 +510,7 @@ class TestUpdateUser: async def test_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -523,7 +522,7 @@ class TestUpdateUser: async def test_not_existing_user_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - superuser: UserDB, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -539,7 +538,7 @@ class TestUpdateUser: async def test_not_existing_user_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_superuser: UserDB, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -552,8 +551,8 @@ class TestUpdateUser: async def test_empty_body_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -570,8 +569,8 @@ class TestUpdateUser: async def test_empty_body_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -587,8 +586,8 @@ class TestUpdateUser: async def test_valid_body_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -608,9 +607,9 @@ class TestUpdateUser: async def test_existing_email_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -625,8 +624,8 @@ class TestUpdateUser: async def test_invalid_password_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -644,8 +643,8 @@ class TestUpdateUser: async def test_valid_body_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -662,8 +661,8 @@ class TestUpdateUser: async def test_valid_body_is_superuser_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client json = {"is_superuser": True} @@ -683,8 +682,8 @@ class TestUpdateUser: async def test_valid_body_is_superuser_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client json = {"is_superuser": True} @@ -701,8 +700,8 @@ class TestUpdateUser: async def test_valid_body_is_active_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client json = {"is_active": False} @@ -722,8 +721,8 @@ class TestUpdateUser: async def test_valid_body_is_active_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client json = {"is_active": False} @@ -740,8 +739,8 @@ class TestUpdateUser: async def test_valid_body_is_verified_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client json = {"is_verified": True} @@ -761,8 +760,8 @@ class TestUpdateUser: async def test_valid_body_is_verified_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client json = {"is_verified": True} @@ -781,8 +780,8 @@ class TestUpdateUser: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client mocker.spy(mock_user_db, "update") @@ -808,8 +807,8 @@ class TestUpdateUser: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client mocker.spy(mock_user_db, "update") @@ -839,7 +838,7 @@ class TestDeleteUser: async def test_regular_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.delete( @@ -852,7 +851,7 @@ class TestDeleteUser: async def test_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.delete( @@ -864,7 +863,7 @@ class TestDeleteUser: async def test_not_existing_user_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - superuser: UserDB, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.delete( @@ -879,7 +878,7 @@ class TestDeleteUser: async def test_not_existing_user_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_superuser: UserDB, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.delete( @@ -893,8 +892,8 @@ class TestDeleteUser: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client mocker.spy(mock_user_db, "delete") @@ -917,8 +916,8 @@ class TestDeleteUser: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client mocker.spy(mock_user_db, "delete") diff --git a/tests/test_router_verify.py b/tests/test_router_verify.py index c954a0fd..f7619c45 100644 --- a/tests/test_router_verify.py +++ b/tests/test_router_verify.py @@ -11,7 +11,7 @@ from fastapi_users.manager import ( UserNotExists, ) from fastapi_users.router import ErrorCode, get_verify_router -from tests.conftest import AsyncMethodMocker, User, UserDB, UserManagerMock +from tests.conftest import AsyncMethodMocker, User, UserManagerMock, UserModel @pytest.fixture @@ -20,10 +20,7 @@ async def test_app_client( get_user_manager, get_test_client, ) -> AsyncGenerator[httpx.AsyncClient, None]: - verify_router = get_verify_router( - get_user_manager, - User, - ) + verify_router = get_verify_router(get_user_manager, User) app = FastAPI() app.include_router(verify_router) @@ -70,7 +67,7 @@ class TestVerifyTokenRequest: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, - user: UserDB, + user: UserModel, ): async_method_mocker(user_manager, "get_by_email", return_value=user) user_manager.request_verify.side_effect = UserInactive() @@ -83,7 +80,7 @@ class TestVerifyTokenRequest: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, - user: UserDB, + user: UserModel, ): async_method_mocker(user_manager, "get_by_email", return_value=user) user_manager.request_verify.side_effect = UserAlreadyVerified() @@ -96,7 +93,7 @@ class TestVerifyTokenRequest: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, - user: UserDB, + user: UserModel, ): async_method_mocker(user_manager, "get_by_email", return_value=user) async_method_mocker(user_manager, "request_verify", return_value=None) @@ -171,7 +168,7 @@ class TestVerify: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, - user: UserDB, + user: UserModel, ): async_method_mocker(user_manager, "verify", return_value=user) response = await test_app_client.post("/verify", json={"token": "foo"})