mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Improve generic typing
This commit is contained in:
@ -5,11 +5,11 @@ from typing import Optional, Sequence
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from makefun import with_signature
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication.base import BaseAuthentication # noqa: F401
|
||||
from fastapi_users.authentication.cookie import CookieAuthentication # noqa: F401
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401
|
||||
from fastapi_users.manager import UserManager, UserManagerDependency
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
INVALID_CHARS_PATTERN = re.compile(r"[^0-9a-zA-Z_]")
|
||||
INVALID_LEADING_CHARS_PATTERN = re.compile(r"^[^a-zA-Z_]+")
|
||||
@ -43,7 +43,7 @@ class Authenticator:
|
||||
def __init__(
|
||||
self,
|
||||
backends: Sequence[BaseAuthentication],
|
||||
get_user_manager: UserManagerDependency,
|
||||
get_user_manager: UserManagerDependency[models.UD],
|
||||
):
|
||||
self.backends = backends
|
||||
self.get_user_manager = get_user_manager
|
||||
@ -108,14 +108,14 @@ class Authenticator:
|
||||
async def _authenticate(
|
||||
self,
|
||||
*args,
|
||||
user_manager: UserManager,
|
||||
user_manager: UserManager[models.UD],
|
||||
optional: bool = False,
|
||||
active: bool = False,
|
||||
verified: bool = False,
|
||||
superuser: bool = False,
|
||||
**kwargs
|
||||
) -> Optional[BaseUserDB]:
|
||||
user: Optional[BaseUserDB] = None
|
||||
) -> Optional[models.UD]:
|
||||
user: Optional[models.UD] = None
|
||||
for backend in self.backends:
|
||||
token: str = kwargs[name_to_variable_name(backend.name)]
|
||||
if token:
|
||||
|
@ -3,8 +3,8 @@ from typing import Any, Generic, Optional, TypeVar
|
||||
from fastapi import Response
|
||||
from fastapi.security.base import SecurityBase
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.manager import UserManager
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@ -28,12 +28,12 @@ class BaseAuthentication(Generic[T]):
|
||||
self.logout = logout
|
||||
|
||||
async def __call__(
|
||||
self, credentials: Optional[T], user_manager: UserManager
|
||||
) -> Optional[BaseUserDB]:
|
||||
self, credentials: Optional[T], user_manager: UserManager[models.UD]
|
||||
) -> Optional[models.UD]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
async def get_login_response(self, user: models.UD, response: Response) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_logout_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
async def get_logout_response(self, user: models.UD, response: Response) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
@ -5,10 +5,10 @@ from fastapi import Response
|
||||
from fastapi.security import APIKeyCookie
|
||||
from pydantic import UUID4
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import UserManager, UserNotExists
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class CookieAuthentication(BaseAuthentication[str]):
|
||||
@ -67,8 +67,8 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
async def __call__(
|
||||
self,
|
||||
credentials: Optional[str],
|
||||
user_manager: UserManager,
|
||||
) -> Optional[BaseUserDB]:
|
||||
user_manager: UserManager[models.UD],
|
||||
) -> Optional[models.UD]:
|
||||
if credentials is None:
|
||||
return None
|
||||
|
||||
@ -88,7 +88,7 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
except UserNotExists:
|
||||
return None
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
async def get_login_response(self, user: models.UD, response: Response) -> Any:
|
||||
token = await self._generate_token(user)
|
||||
response.set_cookie(
|
||||
self.cookie_name,
|
||||
@ -105,11 +105,11 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
# so that FastAPI can terminate it properly
|
||||
return None
|
||||
|
||||
async def get_logout_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
async def get_logout_response(self, user: models.UD, response: Response) -> Any:
|
||||
response.delete_cookie(
|
||||
self.cookie_name, path=self.cookie_path, domain=self.cookie_domain
|
||||
)
|
||||
|
||||
async def _generate_token(self, user: BaseUserDB) -> str:
|
||||
async def _generate_token(self, user: models.UD) -> str:
|
||||
data = {"user_id": str(user.id), "aud": self.token_audience}
|
||||
return generate_jwt(data, self.secret, self.lifetime_seconds)
|
||||
|
@ -5,10 +5,10 @@ from fastapi import Response
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import UUID4
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication.base import BaseAuthentication
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import UserManager, UserNotExists
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class JWTAuthentication(BaseAuthentication[str]):
|
||||
@ -44,8 +44,8 @@ class JWTAuthentication(BaseAuthentication[str]):
|
||||
async def __call__(
|
||||
self,
|
||||
credentials: Optional[str],
|
||||
user_manager: UserManager,
|
||||
) -> Optional[BaseUserDB]:
|
||||
user_manager: UserManager[models.UD],
|
||||
) -> Optional[models.UD]:
|
||||
if credentials is None:
|
||||
return None
|
||||
|
||||
@ -65,10 +65,10 @@ class JWTAuthentication(BaseAuthentication[str]):
|
||||
except UserNotExists:
|
||||
return None
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
async def get_login_response(self, user: models.UD, response: Response) -> Any:
|
||||
token = await self._generate_token(user)
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
|
||||
async def _generate_token(self, user: BaseUserDB) -> str:
|
||||
async def _generate_token(self, user: models.UD) -> str:
|
||||
data = {"user_id": str(user.id), "aud": self.token_audience}
|
||||
return generate_jwt(data, self.secret, self.lifetime_seconds)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Type
|
||||
from typing import Any, Callable, Dict, Generic, Optional, Sequence, Type
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
@ -24,7 +24,7 @@ except ModuleNotFoundError: # pragma: no cover
|
||||
BaseOAuth2 = Type # type: ignore
|
||||
|
||||
|
||||
class FastAPIUsers:
|
||||
class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
|
||||
"""
|
||||
Main object that ties together the component for users authentication.
|
||||
|
||||
@ -45,19 +45,19 @@ class FastAPIUsers:
|
||||
|
||||
authenticator: Authenticator
|
||||
validate_password: Optional[ValidatePasswordProtocol]
|
||||
_user_model: Type[models.BaseUser]
|
||||
_user_create_model: Type[models.BaseUserCreate]
|
||||
_user_update_model: Type[models.BaseUserUpdate]
|
||||
_user_db_model: Type[models.BaseUserDB]
|
||||
_user_model: Type[models.U]
|
||||
_user_create_model: Type[models.UC]
|
||||
_user_update_model: Type[models.UU]
|
||||
_user_db_model: Type[models.UD]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
get_db: UserDatabaseDependency,
|
||||
get_db: UserDatabaseDependency[models.UD],
|
||||
auth_backends: Sequence[BaseAuthentication],
|
||||
user_model: Type[models.BaseUser],
|
||||
user_create_model: Type[models.BaseUserCreate],
|
||||
user_update_model: Type[models.BaseUserUpdate],
|
||||
user_db_model: Type[models.BaseUserDB],
|
||||
user_model: Type[models.U],
|
||||
user_create_model: Type[models.UC],
|
||||
user_update_model: Type[models.UU],
|
||||
user_db_model: Type[models.UD],
|
||||
validate_password: Optional[ValidatePasswordProtocol] = None,
|
||||
):
|
||||
def get_user_manager(
|
||||
|
@ -37,7 +37,7 @@ class InvalidPasswordException(FastAPIUsersException):
|
||||
|
||||
class ValidatePasswordProtocol(Protocol): # pragma: no cover
|
||||
def __call__(
|
||||
self, password: str, user: Union[models.BaseUserCreate, models.BaseUserDB]
|
||||
self, password: str, user: Union[models.UC, models.UD]
|
||||
) -> Awaitable[None]:
|
||||
pass
|
||||
|
||||
@ -82,9 +82,7 @@ class UserManager(Generic[models.UD]):
|
||||
|
||||
return user
|
||||
|
||||
async def create(
|
||||
self, user: models.BaseUserCreate, safe: bool = False
|
||||
) -> models.UD:
|
||||
async def create(self, user: models.UC, safe: bool = False) -> models.UD:
|
||||
if self.validate_password:
|
||||
await self.validate_password(user.password, user)
|
||||
|
||||
@ -107,7 +105,7 @@ class UserManager(Generic[models.UD]):
|
||||
return await self.user_db.update(user)
|
||||
|
||||
async def update(
|
||||
self, updated_user: models.BaseUserUpdate, user: models.UD, safe: bool = False
|
||||
self, updated_user: models.UU, user: models.UD, safe: bool = False
|
||||
) -> models.UD:
|
||||
if safe:
|
||||
updated_user_data = updated_user.create_update_dict()
|
||||
|
@ -54,6 +54,9 @@ class BaseUserDB(BaseUser):
|
||||
orm_mode = True
|
||||
|
||||
|
||||
U = TypeVar("U", bound=BaseUser)
|
||||
UC = TypeVar("UC", bound=BaseUserCreate)
|
||||
UU = TypeVar("UU", bound=BaseUserUpdate)
|
||||
UD = TypeVar("UD", bound=BaseUserDB)
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ from fastapi_users.router.common import ErrorCode
|
||||
|
||||
def get_auth_router(
|
||||
backend: BaseAuthentication,
|
||||
get_user_manager: UserManagerDependency[models.BaseUserDB],
|
||||
get_user_manager: UserManagerDependency[models.UD],
|
||||
authenticator: Authenticator,
|
||||
requires_verification: bool = False,
|
||||
) -> APIRouter:
|
||||
@ -23,7 +23,7 @@ def get_auth_router(
|
||||
async def login(
|
||||
response: Response,
|
||||
credentials: OAuth2PasswordRequestForm = Depends(),
|
||||
user_manager: UserManager[models.BaseUserDB] = Depends(get_user_manager),
|
||||
user_manager: UserManager[models.UD] = Depends(get_user_manager),
|
||||
):
|
||||
user = await user_manager.authenticate(credentials)
|
||||
|
||||
|
@ -24,8 +24,8 @@ def generate_state_token(
|
||||
|
||||
def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
get_user_manager: UserManagerDependency[models.BaseUserDB],
|
||||
user_db_model: Type[models.BaseUserDB],
|
||||
get_user_manager: UserManagerDependency[models.UD],
|
||||
user_db_model: Type[models.UD],
|
||||
authenticator: Authenticator,
|
||||
state_secret: SecretType,
|
||||
redirect_url: str = None,
|
||||
@ -83,7 +83,7 @@ def get_oauth_router(
|
||||
request: Request,
|
||||
response: Response,
|
||||
access_token_state=Depends(oauth2_authorize_callback),
|
||||
user_manager: UserManager[models.BaseUserDB] = Depends(get_user_manager),
|
||||
user_manager: UserManager[models.UD] = Depends(get_user_manager),
|
||||
):
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Callable, Optional, Type, cast
|
||||
from typing import Callable, Optional, Type
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
@ -13,9 +13,9 @@ from fastapi_users.router.common import ErrorCode, run_handler
|
||||
|
||||
|
||||
def get_register_router(
|
||||
get_user_manager: UserManagerDependency[models.BaseUserDB],
|
||||
user_model: Type[models.BaseUser],
|
||||
user_create_model: Type[models.BaseUserCreate],
|
||||
get_user_manager: UserManagerDependency[models.UD],
|
||||
user_model: Type[models.U],
|
||||
user_create_model: Type[models.UC],
|
||||
after_register: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the register route."""
|
||||
@ -29,8 +29,6 @@ def get_register_router(
|
||||
user: user_create_model, # type: ignore
|
||||
user_manager: UserManager[models.UD] = Depends(get_user_manager),
|
||||
):
|
||||
user = cast(models.BaseUserCreate, user) # Prevent mypy complain
|
||||
|
||||
try:
|
||||
created_user = await user_manager.create(user, safe=True)
|
||||
except UserAlreadyExists:
|
||||
|
@ -20,7 +20,7 @@ RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
|
||||
|
||||
|
||||
def get_reset_password_router(
|
||||
get_user_manager: UserManagerDependency[models.BaseUserDB],
|
||||
get_user_manager: UserManagerDependency[models.UD],
|
||||
reset_password_token_secret: SecretType,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
after_forgot_password: Optional[Callable[[models.UD, str, Request], None]] = None,
|
||||
|
@ -17,9 +17,9 @@ from fastapi_users.router.common import ErrorCode, run_handler
|
||||
|
||||
def get_users_router(
|
||||
get_user_manager: UserManagerDependency[models.UD],
|
||||
user_model: Type[models.BaseUser],
|
||||
user_update_model: Type[models.BaseUserUpdate],
|
||||
user_db_model: Type[models.BaseUserDB],
|
||||
user_model: Type[models.U],
|
||||
user_update_model: Type[models.UU],
|
||||
user_db_model: Type[models.UD],
|
||||
authenticator: Authenticator,
|
||||
after_update: Optional[Callable[[models.UD, Dict[str, Any], Request], None]] = None,
|
||||
requires_verification: bool = False,
|
||||
@ -36,7 +36,7 @@ def get_users_router(
|
||||
|
||||
async def get_user_or_404(
|
||||
id: UUID4, user_manager: UserManager[models.UD] = Depends(get_user_manager)
|
||||
) -> models.BaseUserDB:
|
||||
) -> models.UD:
|
||||
try:
|
||||
return await user_manager.get(id)
|
||||
except UserNotExists:
|
||||
|
@ -19,7 +19,7 @@ VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify"
|
||||
|
||||
def get_verify_router(
|
||||
get_user_manager: UserManagerDependency[models.UD],
|
||||
user_model: Type[models.BaseUser],
|
||||
user_model: Type[models.U],
|
||||
verification_token_secret: SecretType,
|
||||
verification_token_lifetime_seconds: int = 3600,
|
||||
after_verification_request: Optional[
|
||||
|
@ -4,9 +4,9 @@ import pytest
|
||||
from fastapi import Request, status
|
||||
from fastapi.security.base import SecurityBase
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import BaseAuthentication, DuplicateBackendNamesError
|
||||
from fastapi_users.manager import UserManager
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class MockSecurityScheme(SecurityBase):
|
||||
@ -20,20 +20,20 @@ class BackendNone(BaseAuthentication[str]):
|
||||
self.scheme = MockSecurityScheme()
|
||||
|
||||
async def __call__(
|
||||
self, credentials: Optional[str], user_manager: UserManager
|
||||
) -> Optional[BaseUserDB]:
|
||||
self, credentials: Optional[str], user_manager: UserManager[models.UD]
|
||||
) -> Optional[models.UD]:
|
||||
return None
|
||||
|
||||
|
||||
class BackendUser(BaseAuthentication[str]):
|
||||
def __init__(self, user: BaseUserDB, name="user"):
|
||||
def __init__(self, user: models.UD, name="user"):
|
||||
super().__init__(name, logout=False)
|
||||
self.scheme = MockSecurityScheme()
|
||||
self.user = user
|
||||
|
||||
async def __call__(
|
||||
self, credentials: Optional[str], user_manager: UserManager
|
||||
) -> Optional[BaseUserDB]:
|
||||
self, credentials: Optional[str], user_manager: UserManager[models.UD]
|
||||
) -> Optional[models.UD]:
|
||||
return self.user
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ async def test_app_client(
|
||||
get_test_client,
|
||||
validate_password,
|
||||
) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
fastapi_users = FastAPIUsers(
|
||||
fastapi_users = FastAPIUsers[User, UserCreate, UserUpdate, UserDB](
|
||||
get_mock_user_db,
|
||||
[mock_authentication],
|
||||
User,
|
||||
|
Reference in New Issue
Block a user