From abfa9a1c47a786eddf84a297667ee96ca0f2aca9 Mon Sep 17 00:00:00 2001 From: Alexander Zinov <33320473+sashkent3@users.noreply.github.com> Date: Sun, 14 Jul 2024 17:04:13 +0400 Subject: [PATCH] Improve type hints (#1401) * Add type parameters to `AuthenticationBackend` * add more type-hints --- examples/beanie-oauth/app/users.py | 4 ++-- examples/sqlalchemy-oauth/app/users.py | 4 ++-- examples/sqlalchemy/app/users.py | 4 ++-- fastapi_users/authentication/authenticator.py | 20 +++++++++---------- fastapi_users/fastapi_users.py | 8 ++++---- fastapi_users/router/auth.py | 4 ++-- fastapi_users/router/oauth.py | 4 ++-- fastapi_users/router/users.py | 2 +- 8 files changed, 25 insertions(+), 25 deletions(-) diff --git a/examples/beanie-oauth/app/users.py b/examples/beanie-oauth/app/users.py index 81ad5ca4..cd7097c4 100644 --- a/examples/beanie-oauth/app/users.py +++ b/examples/beanie-oauth/app/users.py @@ -3,7 +3,7 @@ from typing import Optional from beanie import PydanticObjectId from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers +from fastapi_users import BaseUserManager, FastAPIUsers, models from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, @@ -47,7 +47,7 @@ async def get_user_manager(user_db: BeanieUserDatabase = Depends(get_user_db)): bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") -def get_jwt_strategy() -> JWTStrategy: +def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: return JWTStrategy(secret=SECRET, lifetime_seconds=3600) diff --git a/examples/sqlalchemy-oauth/app/users.py b/examples/sqlalchemy-oauth/app/users.py index 0b61b2a5..a7337e7f 100644 --- a/examples/sqlalchemy-oauth/app/users.py +++ b/examples/sqlalchemy-oauth/app/users.py @@ -3,7 +3,7 @@ import uuid from typing import Optional from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, @@ -47,7 +47,7 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") -def get_jwt_strategy() -> JWTStrategy: +def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: return JWTStrategy(secret=SECRET, lifetime_seconds=3600) diff --git a/examples/sqlalchemy/app/users.py b/examples/sqlalchemy/app/users.py index 479c49e2..f37f0ac3 100644 --- a/examples/sqlalchemy/app/users.py +++ b/examples/sqlalchemy/app/users.py @@ -2,7 +2,7 @@ import uuid from typing import Optional from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, @@ -40,7 +40,7 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") -def get_jwt_strategy() -> JWTStrategy: +def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: return JWTStrategy(secret=SECRET, lifetime_seconds=3600) diff --git a/fastapi_users/authentication/authenticator.py b/fastapi_users/authentication/authenticator.py index 7fab4b78..6117765f 100644 --- a/fastapi_users/authentication/authenticator.py +++ b/fastapi_users/authentication/authenticator.py @@ -1,6 +1,6 @@ import re from inspect import Parameter, Signature -from typing import Callable, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, Generic, List, Optional, Sequence, Tuple, cast from fastapi import Depends, HTTPException, status from makefun import with_signature @@ -31,10 +31,10 @@ class DuplicateBackendNamesError(Exception): pass -EnabledBackendsDependency = DependencyCallable[Sequence[AuthenticationBackend]] +EnabledBackendsDependency = DependencyCallable[Sequence[AuthenticationBackend[models.UP, models.ID]]] -class Authenticator: +class Authenticator(Generic[models.UP, models.ID]): """ Provides dependency callables to retrieve authenticated user. @@ -46,11 +46,11 @@ class Authenticator: :param get_user_manager: User manager dependency callable. """ - backends: Sequence[AuthenticationBackend] + backends: Sequence[AuthenticationBackend[models.UP, models.ID]] def __init__( self, - backends: Sequence[AuthenticationBackend], + backends: Sequence[AuthenticationBackend[models.UP, models.ID]], get_user_manager: UserManagerDependency[models.UP, models.ID], ): self.backends = backends @@ -62,7 +62,7 @@ class Authenticator: active: bool = False, verified: bool = False, superuser: bool = False, - get_enabled_backends: Optional[EnabledBackendsDependency] = None, + get_enabled_backends: Optional[EnabledBackendsDependency[models.UP, models.ID]] = None, ): """ Return a dependency callable to retrieve currently authenticated user and token. @@ -88,7 +88,7 @@ class Authenticator: signature = self._get_dependency_signature(get_enabled_backends) @with_signature(signature) - async def current_user_token_dependency(*args, **kwargs): + async def current_user_token_dependency(*args: Any, **kwargs: Any): return await self._authenticate( *args, optional=optional, @@ -106,7 +106,7 @@ class Authenticator: active: bool = False, verified: bool = False, superuser: bool = False, - get_enabled_backends: Optional[EnabledBackendsDependency] = None, + get_enabled_backends: Optional[EnabledBackendsDependency[models.UP, models.ID]] = None, ): """ Return a dependency callable to retrieve currently authenticated user. @@ -132,7 +132,7 @@ class Authenticator: signature = self._get_dependency_signature(get_enabled_backends) @with_signature(signature) - async def current_user_dependency(*args, **kwargs): + async def current_user_dependency(*args: Any, **kwargs: Any): user, _ = await self._authenticate( *args, optional=optional, @@ -157,7 +157,7 @@ class Authenticator: ) -> Tuple[Optional[models.UP], Optional[str]]: user: Optional[models.UP] = None token: Optional[str] = None - enabled_backends: Sequence[AuthenticationBackend] = kwargs.get( + enabled_backends: Sequence[AuthenticationBackend[models.UP, models.ID]] = kwargs.get( "enabled_backends", self.backends ) for backend in self.backends: diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index 10ca1465..85edb1bd 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -35,12 +35,12 @@ class FastAPIUsers(Generic[models.UP, models.ID]): with a specific set of parameters. """ - authenticator: Authenticator + authenticator: Authenticator[models.UP, models.ID] def __init__( self, get_user_manager: UserManagerDependency[models.UP, models.ID], - auth_backends: Sequence[AuthenticationBackend], + auth_backends: Sequence[AuthenticationBackend[models.UP, models.ID]], ): self.authenticator = Authenticator(auth_backends, get_user_manager) self.get_user_manager = get_user_manager @@ -72,7 +72,7 @@ class FastAPIUsers(Generic[models.UP, models.ID]): return get_reset_password_router(self.get_user_manager) def get_auth_router( - self, backend: AuthenticationBackend, requires_verification: bool = False + self, backend: AuthenticationBackend[models.UP, models.ID], requires_verification: bool = False ) -> APIRouter: """ Return an auth router for a given authentication backend. @@ -91,7 +91,7 @@ class FastAPIUsers(Generic[models.UP, models.ID]): def get_oauth_router( self, oauth_client: BaseOAuth2, - backend: AuthenticationBackend, + backend: AuthenticationBackend[models.UP, models.ID], state_secret: SecretType, redirect_url: Optional[str] = None, associate_by_email: bool = False, diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index c61770f0..57f397d0 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -11,9 +11,9 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_auth_router( - backend: AuthenticationBackend, + backend: AuthenticationBackend[models.UP, models.ID], get_user_manager: UserManagerDependency[models.UP, models.ID], - authenticator: Authenticator, + authenticator: Authenticator[models.UP, models.ID], requires_verification: bool = False, ) -> APIRouter: """Generate a router with login/logout routes for an authentication backend.""" diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 9300c603..12cdf325 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -29,7 +29,7 @@ def generate_state_token( def get_oauth_router( oauth_client: BaseOAuth2, - backend: AuthenticationBackend, + backend: AuthenticationBackend[models.UP, models.ID], get_user_manager: UserManagerDependency[models.UP, models.ID], state_secret: SecretType, redirect_url: Optional[str] = None, @@ -156,7 +156,7 @@ def get_oauth_router( def get_oauth_associate_router( oauth_client: BaseOAuth2, - authenticator: Authenticator, + authenticator: Authenticator[models.UP, models.ID], get_user_manager: UserManagerDependency[models.UP, models.ID], user_schema: Type[schemas.U], state_secret: SecretType, diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index b3cc4351..179230aa 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -12,7 +12,7 @@ def get_users_router( get_user_manager: UserManagerDependency[models.UP, models.ID], user_schema: Type[schemas.U], user_update_schema: Type[schemas.UU], - authenticator: Authenticator, + authenticator: Authenticator[models.UP, models.ID], requires_verification: bool = False, ) -> APIRouter: """Generate a router with the authentication routes."""