Improve generic typing

This commit is contained in:
François Voron
2021-09-14 11:53:43 +02:00
parent 90aee2d487
commit fdc8e54253
15 changed files with 60 additions and 61 deletions

View File

@ -5,11 +5,11 @@ from typing import Optional, Sequence
from fastapi import Depends, HTTPException, status
from makefun import with_signature
from fastapi_users import models
from fastapi_users.authentication.base import BaseAuthentication # noqa: F401
from fastapi_users.authentication.cookie import CookieAuthentication # noqa: F401
from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401
from fastapi_users.manager import UserManager, UserManagerDependency
from fastapi_users.models import BaseUserDB
INVALID_CHARS_PATTERN = re.compile(r"[^0-9a-zA-Z_]")
INVALID_LEADING_CHARS_PATTERN = re.compile(r"^[^a-zA-Z_]+")
@ -43,7 +43,7 @@ class Authenticator:
def __init__(
self,
backends: Sequence[BaseAuthentication],
get_user_manager: UserManagerDependency,
get_user_manager: UserManagerDependency[models.UD],
):
self.backends = backends
self.get_user_manager = get_user_manager
@ -108,14 +108,14 @@ class Authenticator:
async def _authenticate(
self,
*args,
user_manager: UserManager,
user_manager: UserManager[models.UD],
optional: bool = False,
active: bool = False,
verified: bool = False,
superuser: bool = False,
**kwargs
) -> Optional[BaseUserDB]:
user: Optional[BaseUserDB] = None
) -> Optional[models.UD]:
user: Optional[models.UD] = None
for backend in self.backends:
token: str = kwargs[name_to_variable_name(backend.name)]
if token:

View File

@ -3,8 +3,8 @@ from typing import Any, Generic, Optional, TypeVar
from fastapi import Response
from fastapi.security.base import SecurityBase
from fastapi_users import models
from fastapi_users.manager import UserManager
from fastapi_users.models import BaseUserDB
T = TypeVar("T")
@ -28,12 +28,12 @@ class BaseAuthentication(Generic[T]):
self.logout = logout
async def __call__(
self, credentials: Optional[T], user_manager: UserManager
) -> Optional[BaseUserDB]:
self, credentials: Optional[T], user_manager: UserManager[models.UD]
) -> Optional[models.UD]:
raise NotImplementedError()
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
async def get_login_response(self, user: models.UD, response: Response) -> Any:
raise NotImplementedError()
async def get_logout_response(self, user: BaseUserDB, response: Response) -> Any:
async def get_logout_response(self, user: models.UD, response: Response) -> Any:
raise NotImplementedError()

View File

@ -5,10 +5,10 @@ from fastapi import Response
from fastapi.security import APIKeyCookie
from pydantic import UUID4
from fastapi_users import models
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.manager import UserManager, UserNotExists
from fastapi_users.models import BaseUserDB
class CookieAuthentication(BaseAuthentication[str]):
@ -67,8 +67,8 @@ class CookieAuthentication(BaseAuthentication[str]):
async def __call__(
self,
credentials: Optional[str],
user_manager: UserManager,
) -> Optional[BaseUserDB]:
user_manager: UserManager[models.UD],
) -> Optional[models.UD]:
if credentials is None:
return None
@ -88,7 +88,7 @@ class CookieAuthentication(BaseAuthentication[str]):
except UserNotExists:
return None
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
async def get_login_response(self, user: models.UD, response: Response) -> Any:
token = await self._generate_token(user)
response.set_cookie(
self.cookie_name,
@ -105,11 +105,11 @@ class CookieAuthentication(BaseAuthentication[str]):
# so that FastAPI can terminate it properly
return None
async def get_logout_response(self, user: BaseUserDB, response: Response) -> Any:
async def get_logout_response(self, user: models.UD, response: Response) -> Any:
response.delete_cookie(
self.cookie_name, path=self.cookie_path, domain=self.cookie_domain
)
async def _generate_token(self, user: BaseUserDB) -> str:
async def _generate_token(self, user: models.UD) -> str:
data = {"user_id": str(user.id), "aud": self.token_audience}
return generate_jwt(data, self.secret, self.lifetime_seconds)

View File

@ -5,10 +5,10 @@ from fastapi import Response
from fastapi.security import OAuth2PasswordBearer
from pydantic import UUID4
from fastapi_users import models
from fastapi_users.authentication.base import BaseAuthentication
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.manager import UserManager, UserNotExists
from fastapi_users.models import BaseUserDB
class JWTAuthentication(BaseAuthentication[str]):
@ -44,8 +44,8 @@ class JWTAuthentication(BaseAuthentication[str]):
async def __call__(
self,
credentials: Optional[str],
user_manager: UserManager,
) -> Optional[BaseUserDB]:
user_manager: UserManager[models.UD],
) -> Optional[models.UD]:
if credentials is None:
return None
@ -65,10 +65,10 @@ class JWTAuthentication(BaseAuthentication[str]):
except UserNotExists:
return None
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
async def get_login_response(self, user: models.UD, response: Response) -> Any:
token = await self._generate_token(user)
return {"access_token": token, "token_type": "bearer"}
async def _generate_token(self, user: BaseUserDB) -> str:
async def _generate_token(self, user: models.UD) -> str:
data = {"user_id": str(user.id), "aud": self.token_audience}
return generate_jwt(data, self.secret, self.lifetime_seconds)

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Optional, Sequence, Type
from typing import Any, Callable, Dict, Generic, Optional, Sequence, Type
from fastapi import APIRouter, Depends, Request
@ -24,7 +24,7 @@ except ModuleNotFoundError: # pragma: no cover
BaseOAuth2 = Type # type: ignore
class FastAPIUsers:
class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
"""
Main object that ties together the component for users authentication.
@ -45,19 +45,19 @@ class FastAPIUsers:
authenticator: Authenticator
validate_password: Optional[ValidatePasswordProtocol]
_user_model: Type[models.BaseUser]
_user_create_model: Type[models.BaseUserCreate]
_user_update_model: Type[models.BaseUserUpdate]
_user_db_model: Type[models.BaseUserDB]
_user_model: Type[models.U]
_user_create_model: Type[models.UC]
_user_update_model: Type[models.UU]
_user_db_model: Type[models.UD]
def __init__(
self,
get_db: UserDatabaseDependency,
get_db: UserDatabaseDependency[models.UD],
auth_backends: Sequence[BaseAuthentication],
user_model: Type[models.BaseUser],
user_create_model: Type[models.BaseUserCreate],
user_update_model: Type[models.BaseUserUpdate],
user_db_model: Type[models.BaseUserDB],
user_model: Type[models.U],
user_create_model: Type[models.UC],
user_update_model: Type[models.UU],
user_db_model: Type[models.UD],
validate_password: Optional[ValidatePasswordProtocol] = None,
):
def get_user_manager(

View File

@ -37,7 +37,7 @@ class InvalidPasswordException(FastAPIUsersException):
class ValidatePasswordProtocol(Protocol): # pragma: no cover
def __call__(
self, password: str, user: Union[models.BaseUserCreate, models.BaseUserDB]
self, password: str, user: Union[models.UC, models.UD]
) -> Awaitable[None]:
pass
@ -82,9 +82,7 @@ class UserManager(Generic[models.UD]):
return user
async def create(
self, user: models.BaseUserCreate, safe: bool = False
) -> models.UD:
async def create(self, user: models.UC, safe: bool = False) -> models.UD:
if self.validate_password:
await self.validate_password(user.password, user)
@ -107,7 +105,7 @@ class UserManager(Generic[models.UD]):
return await self.user_db.update(user)
async def update(
self, updated_user: models.BaseUserUpdate, user: models.UD, safe: bool = False
self, updated_user: models.UU, user: models.UD, safe: bool = False
) -> models.UD:
if safe:
updated_user_data = updated_user.create_update_dict()

View File

@ -54,6 +54,9 @@ class BaseUserDB(BaseUser):
orm_mode = True
U = TypeVar("U", bound=BaseUser)
UC = TypeVar("UC", bound=BaseUserCreate)
UU = TypeVar("UU", bound=BaseUserUpdate)
UD = TypeVar("UD", bound=BaseUserDB)

View File

@ -9,7 +9,7 @@ from fastapi_users.router.common import ErrorCode
def get_auth_router(
backend: BaseAuthentication,
get_user_manager: UserManagerDependency[models.BaseUserDB],
get_user_manager: UserManagerDependency[models.UD],
authenticator: Authenticator,
requires_verification: bool = False,
) -> APIRouter:
@ -23,7 +23,7 @@ def get_auth_router(
async def login(
response: Response,
credentials: OAuth2PasswordRequestForm = Depends(),
user_manager: UserManager[models.BaseUserDB] = Depends(get_user_manager),
user_manager: UserManager[models.UD] = Depends(get_user_manager),
):
user = await user_manager.authenticate(credentials)

View File

@ -24,8 +24,8 @@ def generate_state_token(
def get_oauth_router(
oauth_client: BaseOAuth2,
get_user_manager: UserManagerDependency[models.BaseUserDB],
user_db_model: Type[models.BaseUserDB],
get_user_manager: UserManagerDependency[models.UD],
user_db_model: Type[models.UD],
authenticator: Authenticator,
state_secret: SecretType,
redirect_url: str = None,
@ -83,7 +83,7 @@ def get_oauth_router(
request: Request,
response: Response,
access_token_state=Depends(oauth2_authorize_callback),
user_manager: UserManager[models.BaseUserDB] = Depends(get_user_manager),
user_manager: UserManager[models.UD] = Depends(get_user_manager),
):
token, state = access_token_state
account_id, account_email = await oauth_client.get_id_email(

View File

@ -1,4 +1,4 @@
from typing import Callable, Optional, Type, cast
from typing import Callable, Optional, Type
from fastapi import APIRouter, Depends, HTTPException, Request, status
@ -13,9 +13,9 @@ from fastapi_users.router.common import ErrorCode, run_handler
def get_register_router(
get_user_manager: UserManagerDependency[models.BaseUserDB],
user_model: Type[models.BaseUser],
user_create_model: Type[models.BaseUserCreate],
get_user_manager: UserManagerDependency[models.UD],
user_model: Type[models.U],
user_create_model: Type[models.UC],
after_register: Optional[Callable[[models.UD, Request], None]] = None,
) -> APIRouter:
"""Generate a router with the register route."""
@ -29,8 +29,6 @@ def get_register_router(
user: user_create_model, # type: ignore
user_manager: UserManager[models.UD] = Depends(get_user_manager),
):
user = cast(models.BaseUserCreate, user) # Prevent mypy complain
try:
created_user = await user_manager.create(user, safe=True)
except UserAlreadyExists:

View File

@ -20,7 +20,7 @@ RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
def get_reset_password_router(
get_user_manager: UserManagerDependency[models.BaseUserDB],
get_user_manager: UserManagerDependency[models.UD],
reset_password_token_secret: SecretType,
reset_password_token_lifetime_seconds: int = 3600,
after_forgot_password: Optional[Callable[[models.UD, str, Request], None]] = None,

View File

@ -17,9 +17,9 @@ from fastapi_users.router.common import ErrorCode, run_handler
def get_users_router(
get_user_manager: UserManagerDependency[models.UD],
user_model: Type[models.BaseUser],
user_update_model: Type[models.BaseUserUpdate],
user_db_model: Type[models.BaseUserDB],
user_model: Type[models.U],
user_update_model: Type[models.UU],
user_db_model: Type[models.UD],
authenticator: Authenticator,
after_update: Optional[Callable[[models.UD, Dict[str, Any], Request], None]] = None,
requires_verification: bool = False,
@ -36,7 +36,7 @@ def get_users_router(
async def get_user_or_404(
id: UUID4, user_manager: UserManager[models.UD] = Depends(get_user_manager)
) -> models.BaseUserDB:
) -> models.UD:
try:
return await user_manager.get(id)
except UserNotExists:

View File

@ -19,7 +19,7 @@ VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify"
def get_verify_router(
get_user_manager: UserManagerDependency[models.UD],
user_model: Type[models.BaseUser],
user_model: Type[models.U],
verification_token_secret: SecretType,
verification_token_lifetime_seconds: int = 3600,
after_verification_request: Optional[

View File

@ -4,9 +4,9 @@ import pytest
from fastapi import Request, status
from fastapi.security.base import SecurityBase
from fastapi_users import models
from fastapi_users.authentication import BaseAuthentication, DuplicateBackendNamesError
from fastapi_users.manager import UserManager
from fastapi_users.models import BaseUserDB
class MockSecurityScheme(SecurityBase):
@ -20,20 +20,20 @@ class BackendNone(BaseAuthentication[str]):
self.scheme = MockSecurityScheme()
async def __call__(
self, credentials: Optional[str], user_manager: UserManager
) -> Optional[BaseUserDB]:
self, credentials: Optional[str], user_manager: UserManager[models.UD]
) -> Optional[models.UD]:
return None
class BackendUser(BaseAuthentication[str]):
def __init__(self, user: BaseUserDB, name="user"):
def __init__(self, user: models.UD, name="user"):
super().__init__(name, logout=False)
self.scheme = MockSecurityScheme()
self.user = user
async def __call__(
self, credentials: Optional[str], user_manager: UserManager
) -> Optional[BaseUserDB]:
self, credentials: Optional[str], user_manager: UserManager[models.UD]
) -> Optional[models.UD]:
return self.user

View File

@ -18,7 +18,7 @@ async def test_app_client(
get_test_client,
validate_password,
) -> AsyncGenerator[httpx.AsyncClient, None]:
fastapi_users = FastAPIUsers(
fastapi_users = FastAPIUsers[User, UserCreate, UserUpdate, UserDB](
get_mock_user_db,
[mock_authentication],
User,