From fdc8e542530faddea89799a58bf811adfb00ffe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 14 Sep 2021 11:53:43 +0200 Subject: [PATCH] Improve generic typing --- fastapi_users/authentication/__init__.py | 10 +++++----- fastapi_users/authentication/base.py | 10 +++++----- fastapi_users/authentication/cookie.py | 12 ++++++------ fastapi_users/authentication/jwt.py | 10 +++++----- fastapi_users/fastapi_users.py | 22 +++++++++++----------- fastapi_users/manager.py | 8 +++----- fastapi_users/models.py | 3 +++ fastapi_users/router/auth.py | 4 ++-- fastapi_users/router/oauth.py | 6 +++--- fastapi_users/router/register.py | 10 ++++------ fastapi_users/router/reset.py | 2 +- fastapi_users/router/users.py | 8 ++++---- fastapi_users/router/verify.py | 2 +- tests/test_authentication.py | 12 ++++++------ tests/test_fastapi_users.py | 2 +- 15 files changed, 60 insertions(+), 61 deletions(-) diff --git a/fastapi_users/authentication/__init__.py b/fastapi_users/authentication/__init__.py index d95fb970..f80827a9 100644 --- a/fastapi_users/authentication/__init__.py +++ b/fastapi_users/authentication/__init__.py @@ -5,11 +5,11 @@ from typing import Optional, Sequence from fastapi import Depends, HTTPException, status from makefun import with_signature +from fastapi_users import models from fastapi_users.authentication.base import BaseAuthentication # noqa: F401 from fastapi_users.authentication.cookie import CookieAuthentication # noqa: F401 from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401 from fastapi_users.manager import UserManager, UserManagerDependency -from fastapi_users.models import BaseUserDB INVALID_CHARS_PATTERN = re.compile(r"[^0-9a-zA-Z_]") INVALID_LEADING_CHARS_PATTERN = re.compile(r"^[^a-zA-Z_]+") @@ -43,7 +43,7 @@ class Authenticator: def __init__( self, backends: Sequence[BaseAuthentication], - get_user_manager: UserManagerDependency, + get_user_manager: UserManagerDependency[models.UD], ): self.backends = backends self.get_user_manager = get_user_manager @@ -108,14 +108,14 @@ class Authenticator: async def _authenticate( self, *args, - user_manager: UserManager, + user_manager: UserManager[models.UD], optional: bool = False, active: bool = False, verified: bool = False, superuser: bool = False, **kwargs - ) -> Optional[BaseUserDB]: - user: Optional[BaseUserDB] = None + ) -> Optional[models.UD]: + user: Optional[models.UD] = None for backend in self.backends: token: str = kwargs[name_to_variable_name(backend.name)] if token: diff --git a/fastapi_users/authentication/base.py b/fastapi_users/authentication/base.py index 04912b42..2d780ee1 100644 --- a/fastapi_users/authentication/base.py +++ b/fastapi_users/authentication/base.py @@ -3,8 +3,8 @@ from typing import Any, Generic, Optional, TypeVar from fastapi import Response from fastapi.security.base import SecurityBase +from fastapi_users import models from fastapi_users.manager import UserManager -from fastapi_users.models import BaseUserDB T = TypeVar("T") @@ -28,12 +28,12 @@ class BaseAuthentication(Generic[T]): self.logout = logout async def __call__( - self, credentials: Optional[T], user_manager: UserManager - ) -> Optional[BaseUserDB]: + self, credentials: Optional[T], user_manager: UserManager[models.UD] + ) -> Optional[models.UD]: raise NotImplementedError() - async def get_login_response(self, user: BaseUserDB, response: Response) -> Any: + async def get_login_response(self, user: models.UD, response: Response) -> Any: raise NotImplementedError() - async def get_logout_response(self, user: BaseUserDB, response: Response) -> Any: + async def get_logout_response(self, user: models.UD, response: Response) -> Any: raise NotImplementedError() diff --git a/fastapi_users/authentication/cookie.py b/fastapi_users/authentication/cookie.py index 79ae205f..e2174dca 100644 --- a/fastapi_users/authentication/cookie.py +++ b/fastapi_users/authentication/cookie.py @@ -5,10 +5,10 @@ from fastapi import Response from fastapi.security import APIKeyCookie from pydantic import UUID4 +from fastapi_users import models from fastapi_users.authentication import BaseAuthentication from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import UserManager, UserNotExists -from fastapi_users.models import BaseUserDB class CookieAuthentication(BaseAuthentication[str]): @@ -67,8 +67,8 @@ class CookieAuthentication(BaseAuthentication[str]): async def __call__( self, credentials: Optional[str], - user_manager: UserManager, - ) -> Optional[BaseUserDB]: + user_manager: UserManager[models.UD], + ) -> Optional[models.UD]: if credentials is None: return None @@ -88,7 +88,7 @@ class CookieAuthentication(BaseAuthentication[str]): except UserNotExists: return None - async def get_login_response(self, user: BaseUserDB, response: Response) -> Any: + async def get_login_response(self, user: models.UD, response: Response) -> Any: token = await self._generate_token(user) response.set_cookie( self.cookie_name, @@ -105,11 +105,11 @@ class CookieAuthentication(BaseAuthentication[str]): # so that FastAPI can terminate it properly return None - async def get_logout_response(self, user: BaseUserDB, response: Response) -> Any: + async def get_logout_response(self, user: models.UD, response: Response) -> Any: response.delete_cookie( self.cookie_name, path=self.cookie_path, domain=self.cookie_domain ) - async def _generate_token(self, user: BaseUserDB) -> str: + async def _generate_token(self, user: models.UD) -> str: data = {"user_id": str(user.id), "aud": self.token_audience} return generate_jwt(data, self.secret, self.lifetime_seconds) diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index b6a95297..7315431e 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -5,10 +5,10 @@ from fastapi import Response from fastapi.security import OAuth2PasswordBearer from pydantic import UUID4 +from fastapi_users import models from fastapi_users.authentication.base import BaseAuthentication from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import UserManager, UserNotExists -from fastapi_users.models import BaseUserDB class JWTAuthentication(BaseAuthentication[str]): @@ -44,8 +44,8 @@ class JWTAuthentication(BaseAuthentication[str]): async def __call__( self, credentials: Optional[str], - user_manager: UserManager, - ) -> Optional[BaseUserDB]: + user_manager: UserManager[models.UD], + ) -> Optional[models.UD]: if credentials is None: return None @@ -65,10 +65,10 @@ class JWTAuthentication(BaseAuthentication[str]): except UserNotExists: return None - async def get_login_response(self, user: BaseUserDB, response: Response) -> Any: + async def get_login_response(self, user: models.UD, response: Response) -> Any: token = await self._generate_token(user) return {"access_token": token, "token_type": "bearer"} - async def _generate_token(self, user: BaseUserDB) -> str: + async def _generate_token(self, user: models.UD) -> str: data = {"user_id": str(user.id), "aud": self.token_audience} return generate_jwt(data, self.secret, self.lifetime_seconds) diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index f0325ee8..a13fd913 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Optional, Sequence, Type +from typing import Any, Callable, Dict, Generic, Optional, Sequence, Type from fastapi import APIRouter, Depends, Request @@ -24,7 +24,7 @@ except ModuleNotFoundError: # pragma: no cover BaseOAuth2 = Type # type: ignore -class FastAPIUsers: +class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): """ Main object that ties together the component for users authentication. @@ -45,19 +45,19 @@ class FastAPIUsers: authenticator: Authenticator validate_password: Optional[ValidatePasswordProtocol] - _user_model: Type[models.BaseUser] - _user_create_model: Type[models.BaseUserCreate] - _user_update_model: Type[models.BaseUserUpdate] - _user_db_model: Type[models.BaseUserDB] + _user_model: Type[models.U] + _user_create_model: Type[models.UC] + _user_update_model: Type[models.UU] + _user_db_model: Type[models.UD] def __init__( self, - get_db: UserDatabaseDependency, + get_db: UserDatabaseDependency[models.UD], auth_backends: Sequence[BaseAuthentication], - user_model: Type[models.BaseUser], - user_create_model: Type[models.BaseUserCreate], - user_update_model: Type[models.BaseUserUpdate], - user_db_model: Type[models.BaseUserDB], + user_model: Type[models.U], + user_create_model: Type[models.UC], + user_update_model: Type[models.UU], + user_db_model: Type[models.UD], validate_password: Optional[ValidatePasswordProtocol] = None, ): def get_user_manager( diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index 3a61066a..257beda4 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -37,7 +37,7 @@ class InvalidPasswordException(FastAPIUsersException): class ValidatePasswordProtocol(Protocol): # pragma: no cover def __call__( - self, password: str, user: Union[models.BaseUserCreate, models.BaseUserDB] + self, password: str, user: Union[models.UC, models.UD] ) -> Awaitable[None]: pass @@ -82,9 +82,7 @@ class UserManager(Generic[models.UD]): return user - async def create( - self, user: models.BaseUserCreate, safe: bool = False - ) -> models.UD: + async def create(self, user: models.UC, safe: bool = False) -> models.UD: if self.validate_password: await self.validate_password(user.password, user) @@ -107,7 +105,7 @@ class UserManager(Generic[models.UD]): return await self.user_db.update(user) async def update( - self, updated_user: models.BaseUserUpdate, user: models.UD, safe: bool = False + self, updated_user: models.UU, user: models.UD, safe: bool = False ) -> models.UD: if safe: updated_user_data = updated_user.create_update_dict() diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 5693803b..6529a3f5 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -54,6 +54,9 @@ class BaseUserDB(BaseUser): orm_mode = True +U = TypeVar("U", bound=BaseUser) +UC = TypeVar("UC", bound=BaseUserCreate) +UU = TypeVar("UU", bound=BaseUserUpdate) UD = TypeVar("UD", bound=BaseUserDB) diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index f508e798..321856b1 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -9,7 +9,7 @@ from fastapi_users.router.common import ErrorCode def get_auth_router( backend: BaseAuthentication, - get_user_manager: UserManagerDependency[models.BaseUserDB], + get_user_manager: UserManagerDependency[models.UD], authenticator: Authenticator, requires_verification: bool = False, ) -> APIRouter: @@ -23,7 +23,7 @@ def get_auth_router( async def login( response: Response, credentials: OAuth2PasswordRequestForm = Depends(), - user_manager: UserManager[models.BaseUserDB] = Depends(get_user_manager), + user_manager: UserManager[models.UD] = Depends(get_user_manager), ): user = await user_manager.authenticate(credentials) diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 753dfd2f..0b7f66f7 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -24,8 +24,8 @@ def generate_state_token( def get_oauth_router( oauth_client: BaseOAuth2, - get_user_manager: UserManagerDependency[models.BaseUserDB], - user_db_model: Type[models.BaseUserDB], + get_user_manager: UserManagerDependency[models.UD], + user_db_model: Type[models.UD], authenticator: Authenticator, state_secret: SecretType, redirect_url: str = None, @@ -83,7 +83,7 @@ def get_oauth_router( request: Request, response: Response, access_token_state=Depends(oauth2_authorize_callback), - user_manager: UserManager[models.BaseUserDB] = Depends(get_user_manager), + user_manager: UserManager[models.UD] = Depends(get_user_manager), ): token, state = access_token_state account_id, account_email = await oauth_client.get_id_email( diff --git a/fastapi_users/router/register.py b/fastapi_users/router/register.py index c4d87631..51ecc5a2 100644 --- a/fastapi_users/router/register.py +++ b/fastapi_users/router/register.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Type, cast +from typing import Callable, Optional, Type from fastapi import APIRouter, Depends, HTTPException, Request, status @@ -13,9 +13,9 @@ from fastapi_users.router.common import ErrorCode, run_handler def get_register_router( - get_user_manager: UserManagerDependency[models.BaseUserDB], - user_model: Type[models.BaseUser], - user_create_model: Type[models.BaseUserCreate], + get_user_manager: UserManagerDependency[models.UD], + user_model: Type[models.U], + user_create_model: Type[models.UC], after_register: Optional[Callable[[models.UD, Request], None]] = None, ) -> APIRouter: """Generate a router with the register route.""" @@ -29,8 +29,6 @@ def get_register_router( user: user_create_model, # type: ignore user_manager: UserManager[models.UD] = Depends(get_user_manager), ): - user = cast(models.BaseUserCreate, user) # Prevent mypy complain - try: created_user = await user_manager.create(user, safe=True) except UserAlreadyExists: diff --git a/fastapi_users/router/reset.py b/fastapi_users/router/reset.py index 21bec218..a148e3cc 100644 --- a/fastapi_users/router/reset.py +++ b/fastapi_users/router/reset.py @@ -20,7 +20,7 @@ RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset" def get_reset_password_router( - get_user_manager: UserManagerDependency[models.BaseUserDB], + get_user_manager: UserManagerDependency[models.UD], reset_password_token_secret: SecretType, reset_password_token_lifetime_seconds: int = 3600, after_forgot_password: Optional[Callable[[models.UD, str, Request], None]] = None, diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index 6415acbf..827d7681 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -17,9 +17,9 @@ from fastapi_users.router.common import ErrorCode, run_handler def get_users_router( get_user_manager: UserManagerDependency[models.UD], - user_model: Type[models.BaseUser], - user_update_model: Type[models.BaseUserUpdate], - user_db_model: Type[models.BaseUserDB], + user_model: Type[models.U], + user_update_model: Type[models.UU], + user_db_model: Type[models.UD], authenticator: Authenticator, after_update: Optional[Callable[[models.UD, Dict[str, Any], Request], None]] = None, requires_verification: bool = False, @@ -36,7 +36,7 @@ def get_users_router( async def get_user_or_404( id: UUID4, user_manager: UserManager[models.UD] = Depends(get_user_manager) - ) -> models.BaseUserDB: + ) -> models.UD: try: return await user_manager.get(id) except UserNotExists: diff --git a/fastapi_users/router/verify.py b/fastapi_users/router/verify.py index 3573c91a..0827177a 100644 --- a/fastapi_users/router/verify.py +++ b/fastapi_users/router/verify.py @@ -19,7 +19,7 @@ VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify" def get_verify_router( get_user_manager: UserManagerDependency[models.UD], - user_model: Type[models.BaseUser], + user_model: Type[models.U], verification_token_secret: SecretType, verification_token_lifetime_seconds: int = 3600, after_verification_request: Optional[ diff --git a/tests/test_authentication.py b/tests/test_authentication.py index b994a1cb..a075f17e 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -4,9 +4,9 @@ import pytest from fastapi import Request, status from fastapi.security.base import SecurityBase +from fastapi_users import models from fastapi_users.authentication import BaseAuthentication, DuplicateBackendNamesError from fastapi_users.manager import UserManager -from fastapi_users.models import BaseUserDB class MockSecurityScheme(SecurityBase): @@ -20,20 +20,20 @@ class BackendNone(BaseAuthentication[str]): self.scheme = MockSecurityScheme() async def __call__( - self, credentials: Optional[str], user_manager: UserManager - ) -> Optional[BaseUserDB]: + self, credentials: Optional[str], user_manager: UserManager[models.UD] + ) -> Optional[models.UD]: return None class BackendUser(BaseAuthentication[str]): - def __init__(self, user: BaseUserDB, name="user"): + def __init__(self, user: models.UD, name="user"): super().__init__(name, logout=False) self.scheme = MockSecurityScheme() self.user = user async def __call__( - self, credentials: Optional[str], user_manager: UserManager - ) -> Optional[BaseUserDB]: + self, credentials: Optional[str], user_manager: UserManager[models.UD] + ) -> Optional[models.UD]: return self.user diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index 4e5ce393..d97b6cbb 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -18,7 +18,7 @@ async def test_app_client( get_test_client, validate_password, ) -> AsyncGenerator[httpx.AsyncClient, None]: - fastapi_users = FastAPIUsers( + fastapi_users = FastAPIUsers[User, UserCreate, UserUpdate, UserDB]( get_mock_user_db, [mock_authentication], User,