Native model and generic ID (#971)

* Use a generic Protocol model for User instead of Pydantic

* Remove UserDB Pydantic schema

* Harmonize schema variable naming to avoid confusions

* Revamp OAuth account model management

* Revamp AccessToken DB strategy to adopt generic model approach

* Make ID a generic instead of forcing UUIDs

* Improve generic typing

* Improve Strategy typing

* Tweak base DB typing

* Don't set Pydantic schemas on FastAPIUsers class: pass it directly on router creation

* Add IntegerIdMixin and export related classes

* Start to revamp doc for V10

* Revamp OAuth documentation

* Fix code highlights

* Write the 9.x.x ➡️ 10.x.x migration doc

* Fix pyproject.toml
This commit is contained in:
François Voron
2022-05-05 14:51:19 +02:00
committed by GitHub
parent b7734fc8b0
commit 72aa68c462
124 changed files with 2144 additions and 2114 deletions

View File

@ -2,16 +2,22 @@
__version__ = "9.3.2"
from fastapi_users import models # noqa: F401
from fastapi_users import models, schemas # noqa: F401
from fastapi_users.fastapi_users import FastAPIUsers # noqa: F401
from fastapi_users.manager import ( # noqa: F401
BaseUserManager,
IntegerIDMixin,
InvalidID,
InvalidPasswordException,
UUIDIDMixin,
)
__all__ = [
"models",
"schemas",
"FastAPIUsers",
"BaseUserManager",
"InvalidPasswordException",
"InvalidID",
"UUIDIDMixin",
"IntegerIDMixin",
]

View File

@ -51,7 +51,7 @@ class Authenticator:
def __init__(
self,
backends: Sequence[AuthenticationBackend],
get_user_manager: UserManagerDependency[models.UC, models.UD],
get_user_manager: UserManagerDependency[models.UP, models.ID],
):
self.backends = backends
self.get_user_manager = get_user_manager
@ -148,14 +148,14 @@ class Authenticator:
async def _authenticate(
self,
*args,
user_manager: BaseUserManager[models.UC, models.UD],
user_manager: BaseUserManager[models.UP, models.ID],
optional: bool = False,
active: bool = False,
verified: bool = False,
superuser: bool = False,
**kwargs,
) -> Tuple[Optional[models.UD], Optional[str]]:
user: Optional[models.UD] = None
) -> Tuple[Optional[models.UP], Optional[str]]:
user: Optional[models.UP] = None
token: Optional[str] = None
enabled_backends: Sequence[AuthenticationBackend] = kwargs.get(
"enabled_backends", self.backends
@ -163,7 +163,7 @@ class Authenticator:
for backend in self.backends:
if backend in enabled_backends:
token = kwargs[name_to_variable_name(backend.name)]
strategy: Strategy[models.UC, models.UD] = kwargs[
strategy: Strategy[models.UP, models.ID] = kwargs[
name_to_strategy_variable_name(backend.name)
]
if token is not None:

View File

@ -14,7 +14,7 @@ from fastapi_users.authentication.transport import (
from fastapi_users.types import DependencyCallable
class AuthenticationBackend(Generic[models.UC, models.UD]):
class AuthenticationBackend(Generic[models.UP]):
"""
Combination of an authentication transport and strategy.
@ -33,7 +33,7 @@ class AuthenticationBackend(Generic[models.UC, models.UD]):
self,
name: str,
transport: Transport,
get_strategy: DependencyCallable[Strategy[models.UC, models.UD]],
get_strategy: DependencyCallable[Strategy[models.UP, models.ID]],
):
self.name = name
self.transport = transport
@ -41,8 +41,8 @@ class AuthenticationBackend(Generic[models.UC, models.UD]):
async def login(
self,
strategy: Strategy[models.UC, models.UD],
user: models.UD,
strategy: Strategy[models.UP, models.ID],
user: models.UP,
response: Response,
) -> Any:
token = await strategy.write_token(user)
@ -50,8 +50,8 @@ class AuthenticationBackend(Generic[models.UC, models.UD]):
async def logout(
self,
strategy: Strategy[models.UC, models.UD],
user: models.UD,
strategy: Strategy[models.UP, models.ID],
user: models.UP,
token: str,
response: Response,
) -> Any:

View File

@ -3,9 +3,9 @@ from fastapi_users.authentication.strategy.base import (
StrategyDestroyNotSupportedError,
)
from fastapi_users.authentication.strategy.db import (
A,
AP,
AccessTokenDatabase,
BaseAccessToken,
AccessTokenProtocol,
DatabaseStrategy,
)
from fastapi_users.authentication.strategy.jwt import JWTStrategy
@ -16,9 +16,9 @@ except ImportError: # pragma: no cover
pass
__all__ = [
"A",
"AP",
"AccessTokenDatabase",
"BaseAccessToken",
"AccessTokenProtocol",
"DatabaseStrategy",
"JWTStrategy",
"Strategy",

View File

@ -14,14 +14,14 @@ class StrategyDestroyNotSupportedError(Exception):
pass
class Strategy(Protocol, Generic[models.UC, models.UD]):
class Strategy(Protocol, Generic[models.UP, models.ID]):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD]
) -> Optional[models.UD]:
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
... # pragma: no cover
async def write_token(self, user: models.UD) -> str:
async def write_token(self, user: models.UP) -> str:
... # pragma: no cover
async def destroy_token(self, token: str, user: models.UD) -> None:
async def destroy_token(self, token: str, user: models.UP) -> None:
... # pragma: no cover

View File

@ -1,5 +1,5 @@
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
from fastapi_users.authentication.strategy.db.models import A, BaseAccessToken
from fastapi_users.authentication.strategy.db.models import AP, AccessTokenProtocol
from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy
__all__ = ["A", "AccessTokenDatabase", "BaseAccessToken", "DatabaseStrategy"]
__all__ = ["AP", "AccessTokenDatabase", "AccessTokenProtocol", "DatabaseStrategy"]

View File

@ -1,38 +1,32 @@
import sys
from datetime import datetime
from typing import Generic, Optional, Type
from typing import Any, Dict, Generic, Optional
if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
else:
from typing import Protocol # pragma: no cover
from fastapi_users.authentication.strategy.db.models import A
from fastapi_users.authentication.strategy.db.models import AP
class AccessTokenDatabase(Protocol, Generic[A]):
"""
Protocol for retrieving, creating and updating access tokens from a database.
:param access_token_model: Pydantic model of an access token.
"""
access_token_model: Type[A]
class AccessTokenDatabase(Protocol, Generic[AP]):
"""Protocol for retrieving, creating and updating access tokens from a database."""
async def get_by_token(
self, token: str, max_age: Optional[datetime] = None
) -> Optional[A]:
) -> Optional[AP]:
"""Get a single access token by token."""
... # pragma: no cover
async def create(self, access_token: A) -> A:
async def create(self, create_dict: Dict[str, Any]) -> AP:
"""Create an access token."""
... # pragma: no cover
async def update(self, access_token: A) -> A:
async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
"""Update an access token."""
... # pragma: no cover
async def delete(self, access_token: A) -> None:
async def delete(self, access_token: AP) -> None:
"""Delete an access token."""
... # pragma: no cover

View File

@ -1,22 +1,24 @@
from datetime import datetime, timezone
import sys
from datetime import datetime
from typing import TypeVar
from pydantic import UUID4, BaseModel, Field
if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
else:
from typing import Protocol # pragma: no cover
from fastapi_users import models
def now_utc():
return datetime.now(timezone.utc)
class BaseAccessToken(BaseModel):
"""Base access token model."""
class AccessTokenProtocol(Protocol[models.ID]):
"""Access token protocol that ORM model should follow."""
token: str
user_id: UUID4
created_at: datetime = Field(default_factory=now_utc)
user_id: models.ID
created_at: datetime
class Config:
orm_mode = True
def __init__(self, *args, **kwargs) -> None:
... # pragma: no cover
A = TypeVar("A", bound=BaseAccessToken)
AP = TypeVar("AP", bound=AccessTokenProtocol)

View File

@ -1,24 +1,26 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import Generic, Optional
from typing import Any, Dict, Generic, Optional
from fastapi_users import models
from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
from fastapi_users.authentication.strategy.db.models import A
from fastapi_users.manager import BaseUserManager, UserNotExists
from fastapi_users.authentication.strategy.db.models import AP
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]):
class DatabaseStrategy(
Strategy[models.UP, models.ID], Generic[models.UP, models.ID, AP]
):
def __init__(
self, database: AccessTokenDatabase[A], lifetime_seconds: Optional[int] = None
self, database: AccessTokenDatabase[AP], lifetime_seconds: Optional[int] = None
):
self.database = database
self.lifetime_seconds = lifetime_seconds
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD]
) -> Optional[models.UD]:
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
if token is None:
return None
@ -33,21 +35,21 @@ class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]):
return None
try:
user_id = access_token.user_id
return await user_manager.get(user_id)
except UserNotExists:
parsed_id = user_manager.parse_id(access_token.user_id)
return await user_manager.get(parsed_id)
except (UserNotExists, InvalidID):
return None
async def write_token(self, user: models.UD) -> str:
access_token = self._create_access_token(user)
await self.database.create(access_token)
async def write_token(self, user: models.UP) -> str:
access_token_dict = self._create_access_token_dict(user)
access_token = await self.database.create(access_token_dict)
return access_token.token
async def destroy_token(self, token: str, user: models.UD) -> None:
async def destroy_token(self, token: str, user: models.UP) -> None:
access_token = await self.database.get_by_token(token)
if access_token is not None:
await self.database.delete(access_token)
def _create_access_token(self, user: models.UD) -> A:
def _create_access_token_dict(self, user: models.UP) -> Dict[str, Any]:
token = secrets.token_urlsafe()
return self.database.access_token_model(token=token, user_id=user.id)
return {"token": token, "user_id": user.id}

View File

@ -1,7 +1,6 @@
from typing import Generic, List, Optional
import jwt
from pydantic import UUID4
from fastapi_users import models
from fastapi_users.authentication.strategy.base import (
@ -9,10 +8,10 @@ from fastapi_users.authentication.strategy.base import (
StrategyDestroyNotSupportedError,
)
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.manager import BaseUserManager, UserNotExists
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
class JWTStrategy(Strategy, Generic[models.UC, models.UD]):
class JWTStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID]):
def __init__(
self,
secret: SecretType,
@ -36,8 +35,8 @@ class JWTStrategy(Strategy, Generic[models.UC, models.UD]):
return self.public_key or self.secret
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD]
) -> Optional[models.UD]:
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
if token is None:
return None
@ -52,20 +51,18 @@ class JWTStrategy(Strategy, Generic[models.UC, models.UD]):
return None
try:
user_uiid = UUID4(user_id)
return await user_manager.get(user_uiid)
except ValueError:
return None
except UserNotExists:
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (UserNotExists, InvalidID):
return None
async def write_token(self, user: models.UD) -> str:
async def write_token(self, user: models.UP) -> str:
data = {"user_id": str(user.id), "aud": self.token_audience}
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
async def destroy_token(self, token: str, user: models.UD) -> None:
async def destroy_token(self, token: str, user: models.UP) -> None:
raise StrategyDestroyNotSupportedError(
"A JWT can't be invalidated: it's valid until it expires."
)

View File

@ -2,21 +2,20 @@ import secrets
from typing import Generic, Optional
import aioredis
from pydantic import UUID4
from fastapi_users import models
from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users.manager import BaseUserManager, UserNotExists
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
class RedisStrategy(Strategy, Generic[models.UC, models.UD]):
class RedisStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID]):
def __init__(self, redis: aioredis.Redis, lifetime_seconds: Optional[int] = None):
self.redis = redis
self.lifetime_seconds = lifetime_seconds
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD]
) -> Optional[models.UD]:
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
if token is None:
return None
@ -25,17 +24,15 @@ class RedisStrategy(Strategy, Generic[models.UC, models.UD]):
return None
try:
user_uiid = UUID4(user_id)
return await user_manager.get(user_uiid)
except ValueError:
return None
except UserNotExists:
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (UserNotExists, InvalidID):
return None
async def write_token(self, user: models.UD) -> str:
async def write_token(self, user: models.UP) -> str:
token = secrets.token_urlsafe()
await self.redis.set(token, str(user.id), ex=self.lifetime_seconds)
return token
async def destroy_token(self, token: str, user: models.UD) -> None:
async def destroy_token(self, token: str, user: models.UP) -> None:
await self.redis.delete(token)

View File

@ -1,52 +1,36 @@
from fastapi_users.db.base import BaseUserDatabase, UserDatabaseDependency
__all__ = [
"BaseUserDatabase",
"UserDatabaseDependency",
]
__all__ = ["BaseUserDatabase", "UserDatabaseDependency"]
try: # pragma: no cover
from fastapi_users_db_mongodb import MongoDBUserDatabase # noqa: F401
__all__.append("MongoDBUserDatabase")
except ImportError: # pragma: no cover
pass
try: # pragma: no cover
from fastapi_users_db_sqlalchemy import ( # noqa: F401
SQLAlchemyBaseOAuthAccountTable,
SQLAlchemyBaseOAuthAccountTableUUID,
SQLAlchemyBaseUserTable,
SQLAlchemyBaseUserTableUUID,
SQLAlchemyUserDatabase,
)
__all__.append("SQLAlchemyBaseOAuthAccountTable")
__all__.append("SQLAlchemyBaseUserTable")
__all__.append("SQLAlchemyBaseUserTableUUID")
__all__.append("SQLAlchemyBaseOAuthAccountTable")
__all__.append("SQLAlchemyBaseOAuthAccountTableUUID")
__all__.append("SQLAlchemyUserDatabase")
except ImportError: # pragma: no cover
pass
try: # pragma: no cover
from fastapi_users_db_tortoise import ( # noqa: F401
TortoiseBaseOAuthAccountModel,
TortoiseBaseUserModel,
TortoiseUserDatabase,
from fastapi_users_db_beanie import ( # noqa: F401
BaseOAuthAccount,
BeanieBaseUser,
BeanieUserDatabase,
ObjectIDIDMixin,
)
__all__.append("TortoiseBaseOAuthAccountModel")
__all__.append("TortoiseBaseUserModel")
__all__.append("TortoiseUserDatabase")
except ImportError: # pragma: no cover
pass
try: # pragma: no cover
from fastapi_users_db_ormar import ( # noqa: F401
OrmarBaseOAuthAccountModel,
OrmarBaseUserModel,
OrmarUserDatabase,
)
__all__.append("OrmarBaseOAuthAccountModel")
__all__.append("OrmarBaseUserModel")
__all__.append("OrmarUserDatabase")
__all__.append("BeanieBaseUser")
__all__.append("BaseOAuthAccount")
__all__.append("BeanieUserDatabase")
__all__.append("ObjectIDIDMixin")
except ImportError: # pragma: no cover
pass

View File

@ -1,46 +1,50 @@
from typing import Generic, Optional, Type
from typing import Any, Dict, Generic, Optional
from pydantic import UUID4
from fastapi_users.models import UD
from fastapi_users.models import ID, OAP, UOAP, UP
from fastapi_users.types import DependencyCallable
class BaseUserDatabase(Generic[UD]):
"""
Base adapter for retrieving, creating and updating users from a database.
class BaseUserDatabase(Generic[UP, ID]):
"""Base adapter for retrieving, creating and updating users from a database."""
:param user_db_model: Pydantic model of a DB representation of a user.
"""
user_db_model: Type[UD]
def __init__(self, user_db_model: Type[UD]):
self.user_db_model = user_db_model
async def get(self, id: UUID4) -> Optional[UD]:
async def get(self, id: ID) -> Optional[UP]:
"""Get a single user by id."""
raise NotImplementedError()
async def get_by_email(self, email: str) -> Optional[UD]:
async def get_by_email(self, email: str) -> Optional[UP]:
"""Get a single user by email."""
raise NotImplementedError()
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]:
"""Get a single user by OAuth account id."""
raise NotImplementedError()
async def create(self, user: UD) -> UD:
async def create(self, create_dict: Dict[str, Any]) -> UP:
"""Create a user."""
raise NotImplementedError()
async def update(self, user: UD) -> UD:
async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP:
"""Update a user."""
raise NotImplementedError()
async def delete(self, user: UD) -> None:
async def delete(self, user: UP) -> None:
"""Delete a user."""
raise NotImplementedError()
async def add_oauth_account(
self: "BaseUserDatabase[UOAP, ID]", user: UOAP, create_dict: Dict[str, Any]
) -> UOAP:
"""Create an OAuth account and add it to the user."""
raise NotImplementedError()
UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UD]]
async def update_oauth_account(
self: "BaseUserDatabase[UOAP, ID]",
user: UOAP,
oauth_account: OAP,
update_dict: Dict[str, Any],
) -> UOAP:
"""Update an OAuth account on a user."""
raise NotImplementedError()
UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP, ID]]

View File

@ -2,7 +2,7 @@ from typing import Generic, Sequence, Type
from fastapi import APIRouter
from fastapi_users import models
from fastapi_users import models, schemas
from fastapi_users.authentication import AuthenticationBackend, Authenticator
from fastapi_users.jwt import SecretType
from fastapi_users.manager import UserManagerDependency
@ -22,58 +22,49 @@ except ModuleNotFoundError: # pragma: no cover
BaseOAuth2 = Type # type: ignore
class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
class FastAPIUsers(Generic[models.UP, models.ID]):
"""
Main object that ties together the component for users authentication.
:param get_user_manager: Dependency callable getter to inject the
user manager class instance.
:param auth_backends: List of authentication backends.
:param user_model: Pydantic model of a user.
:param user_create_model: Pydantic model for creating a user.
:param user_update_model: Pydantic model for updating a user.
:param user_db_model: Pydantic model of a DB representation of a user.
:attribute current_user: Dependency callable getter to inject authenticated user
with a specific set of parameters.
"""
authenticator: Authenticator
_user_model: Type[models.U]
_user_create_model: Type[models.UC]
_user_update_model: Type[models.UU]
_user_db_model: Type[models.UD]
def __init__(
self,
get_user_manager: UserManagerDependency[models.UC, models.UD],
get_user_manager: UserManagerDependency[models.UP, models.ID],
auth_backends: Sequence[AuthenticationBackend],
user_model: Type[models.U],
user_create_model: Type[models.UC],
user_update_model: Type[models.UU],
user_db_model: Type[models.UD],
):
self.authenticator = Authenticator(auth_backends, get_user_manager)
self._user_model = user_model
self._user_db_model = user_db_model
self._user_create_model = user_create_model
self._user_update_model = user_update_model
self.get_user_manager = get_user_manager
self.current_user = self.authenticator.current_user
def get_register_router(self) -> APIRouter:
"""Return a router with a register route."""
def get_register_router(
self, user_schema: Type[schemas.U], user_create_schema: Type[schemas.UC]
) -> APIRouter:
"""
Return a router with a register route.
:param user_schema: Pydantic schema of a public user.
:param user_create_schema: Pydantic schema for creating a user.
"""
return get_register_router(
self.get_user_manager,
self._user_model,
self._user_create_model,
self.get_user_manager, user_schema, user_create_schema
)
def get_verify_router(self) -> APIRouter:
"""Return a router with e-mail verification routes."""
return get_verify_router(self.get_user_manager, self._user_model)
def get_verify_router(self, user_schema: Type[schemas.U]) -> APIRouter:
"""
Return a router with e-mail verification routes.
:param user_schema: Pydantic schema of a public user.
"""
return get_verify_router(self.get_user_manager, user_schema)
def get_reset_password_router(self) -> APIRouter:
"""Return a reset password process router."""
@ -122,19 +113,22 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
def get_users_router(
self,
user_schema: Type[schemas.U],
user_update_schema: Type[schemas.UU],
requires_verification: bool = False,
) -> APIRouter:
"""
Return a router with routes to manage users.
:param user_schema: Pydantic schema of a public user.
:param user_update_schema: Pydantic schema for updating a user.
:param requires_verification: Whether the endpoints
require the users to be verified or not.
"""
return get_users_router(
self.get_user_manager,
self._user_model,
self._user_update_model,
self._user_db_model,
user_schema,
user_update_schema,
self.authenticator,
requires_verification,
)

View File

@ -1,11 +1,11 @@
from typing import Any, Dict, Generic, Optional, Type, Union
import uuid
from typing import Any, Dict, Generic, Optional, Union
import jwt
from fastapi import Request
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import UUID4
from fastapi_users import models
from fastapi_users import models, schemas
from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.password import PasswordHelper, PasswordHelperProtocol
@ -19,6 +19,10 @@ class FastAPIUsersException(Exception):
pass
class InvalidID(FastAPIUsersException):
pass
class UserAlreadyExists(FastAPIUsersException):
pass
@ -48,11 +52,10 @@ class InvalidPasswordException(FastAPIUsersException):
self.reason = reason
class BaseUserManager(Generic[models.UC, models.UD]):
class BaseUserManager(Generic[models.UP, models.ID]):
"""
User management logic.
:attribute user_db_model: Pydantic model of a DB representation of a user.
:attribute reset_password_token_secret: Secret to encode reset password token.
:attribute reset_password_token_lifetime_seconds: Lifetime of reset password token.
:attribute reset_password_token_audience: JWT audience of reset password token.
@ -63,7 +66,6 @@ class BaseUserManager(Generic[models.UC, models.UD]):
:param user_db: Database adapter instance.
"""
user_db_model: Type[models.UD]
reset_password_token_secret: SecretType
reset_password_token_lifetime_seconds: int = 3600
reset_password_token_audience: str = RESET_PASSWORD_TOKEN_AUDIENCE
@ -72,12 +74,12 @@ class BaseUserManager(Generic[models.UC, models.UD]):
verification_token_lifetime_seconds: int = 3600
verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE
user_db: BaseUserDatabase[models.UD]
user_db: BaseUserDatabase[models.UP, models.ID]
password_helper: PasswordHelperProtocol
def __init__(
self,
user_db: BaseUserDatabase[models.UD],
user_db: BaseUserDatabase[models.UP, models.ID],
password_helper: Optional[PasswordHelperProtocol] = None,
):
self.user_db = user_db
@ -86,7 +88,17 @@ class BaseUserManager(Generic[models.UC, models.UD]):
else:
self.password_helper = password_helper # pragma: no cover
async def get(self, id: UUID4) -> models.UD:
def parse_id(self, value: Any) -> models.ID:
"""
Parse a value into a correct models.ID instance.
:param value: The value to parse.
:raises InvalidID: The models.ID value is invalid.
:return: An models.ID object.
"""
raise NotImplementedError() # pragma: no cover
async def get(self, id: models.ID) -> models.UP:
"""
Get a user by id.
@ -101,7 +113,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return user
async def get_by_email(self, user_email: str) -> models.UD:
async def get_by_email(self, user_email: str) -> models.UP:
"""
Get a user by e-mail.
@ -116,7 +128,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return user
async def get_by_oauth_account(self, oauth: str, account_id: str) -> models.UD:
async def get_by_oauth_account(self, oauth: str, account_id: str) -> models.UP:
"""
Get a user by OAuth account.
@ -133,14 +145,17 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return user
async def create(
self, user: models.UC, safe: bool = False, request: Optional[Request] = None
) -> models.UD:
self,
user_create: schemas.UC,
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
"""
Create a user in database.
Triggers the on_after_register handler on success.
:param user: The UserCreate model to create.
:param user_create: The UserCreate model to create.
:param safe: If True, sensitive values like is_superuser or is_verified
will be ignored during the creation, defaults to False.
:param request: Optional FastAPI request that
@ -148,27 +163,36 @@ class BaseUserManager(Generic[models.UC, models.UD]):
:raises UserAlreadyExists: A user already exists with the same e-mail.
:return: A new user.
"""
await self.validate_password(user.password, user)
await self.validate_password(user_create.password, user_create)
existing_user = await self.user_db.get_by_email(user.email)
existing_user = await self.user_db.get_by_email(user_create.email)
if existing_user is not None:
raise UserAlreadyExists()
hashed_password = self.password_helper.hash(user.password)
user_dict = (
user.create_update_dict() if safe else user.create_update_dict_superuser()
user_create.create_update_dict()
if safe
else user_create.create_update_dict_superuser()
)
db_user = self.user_db_model(**user_dict, hashed_password=hashed_password)
password = user_dict.pop("password")
user_dict["hashed_password"] = self.password_helper.hash(password)
created_user = await self.user_db.create(db_user)
created_user = await self.user_db.create(user_dict)
await self.on_after_register(created_user, request)
return created_user
async def oauth_callback(
self, oauth_account: models.BaseOAuthAccount, request: Optional[Request] = None
) -> models.UD:
self: "BaseUserManager[models.UOAP, models.ID]",
oauth_name: str,
access_token: str,
account_id: str,
account_email: str,
expires_at: Optional[int] = None,
refresh_token: Optional[str] = None,
request: Optional[Request] = None,
) -> models.UOAP:
"""
Handle the callback after a successful OAuth authentication.
@ -180,50 +204,58 @@ class BaseUserManager(Generic[models.UC, models.UD]):
If the user does not exist, it is created and the on_after_register handler
is triggered.
:param oauth_account: The new OAuth account to create.
:param oauth_name: Name of the OAuth client.
:param access_token: Valid access token for the service provider.
:param account_id: models.ID of the user on the service provider.
:param account_email: E-mail of the user on the service provider.
:param expires_at: Optional timestamp at which the access token expires.
:param refresh_token: Optional refresh token to get a
fresh access token from the service provider.
:param request: Optional FastAPI request that
triggered the operation, defaults to None
:return: A user.
"""
oauth_account_dict = {
"oauth_name": oauth_name,
"access_token": access_token,
"account_id": account_id,
"account_email": account_email,
"expires_at": expires_at,
"refresh_token": refresh_token,
}
try:
user = await self.get_by_oauth_account(
oauth_account.oauth_name, oauth_account.account_id
)
user = await self.get_by_oauth_account(oauth_name, account_id)
except UserNotExists:
try:
# Link account
user = await self.get_by_email(oauth_account.account_email)
user.oauth_accounts.append(oauth_account) # type: ignore
await self.user_db.update(user)
user = await self.get_by_email(account_email)
user = await self.user_db.add_oauth_account(user, oauth_account_dict)
except UserNotExists:
# Create account
password = self.password_helper.generate()
user = self.user_db_model(
email=oauth_account.account_email,
hashed_password=self.password_helper.hash(password),
oauth_accounts=[oauth_account],
)
await self.user_db.create(user)
user_dict = {
"email": account_email,
"hashed_password": self.password_helper.hash(password),
}
user = await self.user_db.create(user_dict)
user = await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
# Update oauth
updated_oauth_accounts = []
for existing_oauth_account in user.oauth_accounts: # type: ignore
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == oauth_account.account_id
and existing_oauth_account.oauth_name == oauth_account.oauth_name
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
oauth_account.id = existing_oauth_account.id
updated_oauth_accounts.append(oauth_account)
else:
updated_oauth_accounts.append(existing_oauth_account)
user.oauth_accounts = updated_oauth_accounts # type: ignore
await self.user_db.update(user)
user = await self.user_db.update_oauth_account(
user, existing_oauth_account, oauth_account_dict
)
return user
async def request_verify(
self, user: models.UD, request: Optional[Request] = None
self, user: models.UP, request: Optional[Request] = None
) -> None:
"""
Start a verification request.
@ -253,7 +285,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
)
await self.on_after_request_verify(user, token, request)
async def verify(self, token: str, request: Optional[Request] = None) -> models.UD:
async def verify(self, token: str, request: Optional[Request] = None) -> models.UP:
"""
Validate a verification request.
@ -289,11 +321,11 @@ class BaseUserManager(Generic[models.UC, models.UD]):
raise InvalidVerifyToken()
try:
user_uuid = UUID4(user_id)
except ValueError:
parsed_id = self.parse_id(user_id)
except InvalidID:
raise InvalidVerifyToken()
if user_uuid != user.id:
if parsed_id != user.id:
raise InvalidVerifyToken()
if user.is_verified:
@ -306,7 +338,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return verified_user
async def forgot_password(
self, user: models.UD, request: Optional[Request] = None
self, user: models.UP, request: Optional[Request] = None
) -> None:
"""
Start a forgot password request.
@ -334,7 +366,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
async def reset_password(
self, token: str, password: str, request: Optional[Request] = None
) -> models.UD:
) -> models.UP:
"""
Reset the password of a user.
@ -364,11 +396,11 @@ class BaseUserManager(Generic[models.UC, models.UD]):
raise InvalidResetPasswordToken()
try:
user_uuid = UUID4(user_id)
except ValueError:
parsed_id = self.parse_id(user_id)
except InvalidID:
raise InvalidResetPasswordToken()
user = await self.get(user_uuid)
user = await self.get(parsed_id)
if not user.is_active:
raise UserInactive()
@ -381,11 +413,11 @@ class BaseUserManager(Generic[models.UC, models.UD]):
async def update(
self,
user_update: models.UU,
user: models.UD,
user_update: schemas.UU,
user: models.UP,
safe: bool = False,
request: Optional[Request] = None,
) -> models.UD:
) -> models.UP:
"""
Update a user.
@ -408,7 +440,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
await self.on_after_update(updated_user, updated_user_data, request)
return updated_user
async def delete(self, user: models.UD) -> None:
async def delete(self, user: models.UP) -> None:
"""
Delete a user.
@ -417,7 +449,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
await self.user_db.delete(user)
async def validate_password(
self, password: str, user: Union[models.UC, models.UD]
self, password: str, user: Union[schemas.UC, models.UP]
) -> None:
"""
Validate a password.
@ -432,7 +464,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return # pragma: no cover
async def on_after_register(
self, user: models.UD, request: Optional[Request] = None
self, user: models.UP, request: Optional[Request] = None
) -> None:
"""
Perform logic after successful user registration.
@ -447,7 +479,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
async def on_after_update(
self,
user: models.UD,
user: models.UP,
update_dict: Dict[str, Any],
request: Optional[Request] = None,
) -> None:
@ -464,7 +496,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return # pragma: no cover
async def on_after_request_verify(
self, user: models.UD, token: str, request: Optional[Request] = None
self, user: models.UP, token: str, request: Optional[Request] = None
) -> None:
"""
Perform logic after successful verification request.
@ -479,7 +511,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return # pragma: no cover
async def on_after_verify(
self, user: models.UD, request: Optional[Request] = None
self, user: models.UP, request: Optional[Request] = None
) -> None:
"""
Perform logic after successful user verification.
@ -493,7 +525,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return # pragma: no cover
async def on_after_forgot_password(
self, user: models.UD, token: str, request: Optional[Request] = None
self, user: models.UP, token: str, request: Optional[Request] = None
) -> None:
"""
Perform logic after successful forgot password request.
@ -508,7 +540,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return # pragma: no cover
async def on_after_reset_password(
self, user: models.UD, request: Optional[Request] = None
self, user: models.UP, request: Optional[Request] = None
) -> None:
"""
Perform logic after successful password reset.
@ -523,7 +555,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[models.UD]:
) -> Optional[models.UP]:
"""
Authenticate and return a user following an email and a password.
@ -546,27 +578,48 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return None
# Update password hash to a more robust one if needed
if updated_password_hash is not None:
user.hashed_password = updated_password_hash
await self.user_db.update(user)
await self.user_db.update(user, {"hashed_password": updated_password_hash})
return user
async def _update(self, user: models.UD, update_dict: Dict[str, Any]) -> models.UD:
async def _update(self, user: models.UP, update_dict: Dict[str, Any]) -> models.UP:
validated_update_dict = {}
for field, value in update_dict.items():
if field == "email" and value != user.email:
try:
await self.get_by_email(value)
raise UserAlreadyExists()
except UserNotExists:
user.email = value
user.is_verified = False
validated_update_dict["email"] = value
validated_update_dict["is_verified"] = False
elif field == "password":
await self.validate_password(value, user)
hashed_password = self.password_helper.hash(value)
user.hashed_password = hashed_password
validated_update_dict["hashed_password"] = self.password_helper.hash(
value
)
else:
setattr(user, field, value)
return await self.user_db.update(user)
validated_update_dict[field] = value
return await self.user_db.update(user, validated_update_dict)
UserManagerDependency = DependencyCallable[BaseUserManager[models.UC, models.UD]]
class UUIDIDMixin:
def parse_id(self, value: Any) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
try:
return uuid.UUID(value)
except ValueError as e:
raise InvalidID() from e
class IntegerIDMixin:
def parse_id(self, value: Any) -> int:
if isinstance(value, float):
raise InvalidID()
try:
return int(value)
except ValueError as e:
raise InvalidID() from e
UserManagerDependency = DependencyCallable[BaseUserManager[models.UP, models.ID]]

View File

@ -1,81 +1,51 @@
import uuid
from typing import List, Optional, TypeVar
import sys
from typing import Generic, List, Optional, TypeVar
from pydantic import UUID4, BaseModel, EmailStr, Field
if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
else:
from typing import Protocol # pragma: no cover
ID = TypeVar("ID")
class CreateUpdateDictModel(BaseModel):
def create_update_dict(self):
return self.dict(
exclude_unset=True,
exclude={
"id",
"is_superuser",
"is_active",
"is_verified",
"oauth_accounts",
},
)
class UserProtocol(Protocol[ID]):
"""User protocol that ORM model should follow."""
def create_update_dict_superuser(self):
return self.dict(exclude_unset=True, exclude={"id"})
class BaseUser(CreateUpdateDictModel):
"""Base User model."""
id: UUID4 = Field(default_factory=uuid.uuid4)
email: EmailStr
is_active: bool = True
is_superuser: bool = False
is_verified: bool = False
class BaseUserCreate(CreateUpdateDictModel):
email: EmailStr
password: str
is_active: Optional[bool] = True
is_superuser: Optional[bool] = False
is_verified: Optional[bool] = False
class BaseUserUpdate(CreateUpdateDictModel):
password: Optional[str]
email: Optional[EmailStr]
is_active: Optional[bool]
is_superuser: Optional[bool]
is_verified: Optional[bool]
class BaseUserDB(BaseUser):
id: ID
email: str
hashed_password: str
is_active: bool
is_superuser: bool
is_verified: bool
class Config:
orm_mode = True
def __init__(self, *args, **kwargs) -> None:
... # pragma: no cover
U = TypeVar("U", bound=BaseUser)
UC = TypeVar("UC", bound=BaseUserCreate)
UU = TypeVar("UU", bound=BaseUserUpdate)
UD = TypeVar("UD", bound=BaseUserDB)
class OAuthAccountProtocol(Protocol[ID]):
"""OAuth account protocol that ORM model should follow."""
class BaseOAuthAccount(BaseModel):
"""Base OAuth account model."""
id: UUID4 = Field(default_factory=uuid.uuid4)
id: ID
oauth_name: str
access_token: str
expires_at: Optional[int] = None
refresh_token: Optional[str] = None
expires_at: Optional[int]
refresh_token: Optional[str]
account_id: str
account_email: str
class Config:
orm_mode = True
def __init__(self, *args, **kwargs) -> None:
... # pragma: no cover
class BaseOAuthAccountMixin(BaseModel):
"""Adds OAuth accounts list to a User model."""
UP = TypeVar("UP", bound=UserProtocol)
OAP = TypeVar("OAP", bound=OAuthAccountProtocol)
oauth_accounts: List[BaseOAuthAccount] = []
class UserOAuthProtocol(UserProtocol[ID], Generic[ID, OAP]):
"""User protocol including a list of OAuth accounts."""
oauth_accounts: List[OAP]
UOAP = TypeVar("UOAP", bound=UserOAuthProtocol)

View File

@ -12,7 +12,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
def get_auth_router(
backend: AuthenticationBackend,
get_user_manager: UserManagerDependency[models.UC, models.UD],
get_user_manager: UserManagerDependency[models.UP, models.ID],
authenticator: Authenticator,
requires_verification: bool = False,
) -> APIRouter:
@ -51,8 +51,8 @@ def get_auth_router(
async def login(
response: Response,
credentials: OAuth2PasswordRequestForm = Depends(),
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
):
user = await user_manager.authenticate(credentials)
@ -82,8 +82,8 @@ def get_auth_router(
)
async def logout(
response: Response,
user_token: Tuple[models.UD, str] = Depends(get_current_user_token),
strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy),
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
):
user, token = user_token
return await backend.logout(strategy, user, token, response)

View File

@ -29,7 +29,7 @@ def generate_state_token(
def get_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
get_user_manager: UserManagerDependency[models.UC, models.UD],
get_user_manager: UserManagerDependency[models.UP, models.ID],
state_secret: SecretType,
redirect_url: str = None,
) -> APIRouter:
@ -101,8 +101,8 @@ def get_oauth_router(
access_token_state: Tuple[OAuth2Token, str] = Depends(
oauth2_authorize_callback
),
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
):
token, state = access_token_state
account_id, account_email = await oauth_client.get_id_email(
@ -114,17 +114,16 @@ def get_oauth_router(
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
new_oauth_account = models.BaseOAuthAccount(
oauth_name=oauth_client.name,
access_token=token["access_token"],
expires_at=token.get("expires_at"),
refresh_token=token.get("refresh_token"),
account_id=account_id,
account_email=account_email,
user = await user_manager.oauth_callback(
oauth_client.name,
token["access_token"],
account_id,
account_email,
token.get("expires_at"),
token.get("refresh_token"),
request,
)
user = await user_manager.oauth_callback(new_oauth_account, request)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -2,7 +2,7 @@ from typing import Type
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi_users import models
from fastapi_users import models, schemas
from fastapi_users.manager import (
BaseUserManager,
InvalidPasswordException,
@ -13,16 +13,16 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
def get_register_router(
get_user_manager: UserManagerDependency[models.UC, models.UD],
user_model: Type[models.U],
user_create_model: Type[models.UC],
get_user_manager: UserManagerDependency[models.UP, models.ID],
user_schema: Type[schemas.U],
user_create_schema: Type[schemas.UC],
) -> APIRouter:
"""Generate a router with the register route."""
router = APIRouter()
@router.post(
"/register",
response_model=user_model,
response_model=user_schema,
status_code=status.HTTP_201_CREATED,
name="register:register",
responses={
@ -55,11 +55,13 @@ def get_register_router(
)
async def register(
request: Request,
user: user_create_model, # type: ignore
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
user_create: user_create_schema, # type: ignore
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
created_user = await user_manager.create(user, safe=True, request=request)
created_user = await user_manager.create(
user_create, safe=True, request=request
)
except UserAlreadyExists:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -40,7 +40,7 @@ RESET_PASSWORD_RESPONSES: OpenAPIResponseType = {
def get_reset_password_router(
get_user_manager: UserManagerDependency[models.UC, models.UD]
get_user_manager: UserManagerDependency[models.UP, models.ID],
) -> APIRouter:
"""Generate a router with the reset password routes."""
router = APIRouter()
@ -53,7 +53,7 @@ def get_reset_password_router(
async def forgot_password(
request: Request,
email: EmailStr = Body(..., embed=True),
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
user = await user_manager.get_by_email(email)
@ -76,7 +76,7 @@ def get_reset_password_router(
request: Request,
token: str = Body(...),
password: str = Body(...),
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
await user_manager.reset_password(token, password, request)

View File

@ -1,12 +1,12 @@
from typing import Type
from typing import Any, Type
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import UUID4
from fastapi_users import models
from fastapi_users import models, schemas
from fastapi_users.authentication import Authenticator
from fastapi_users.manager import (
BaseUserManager,
InvalidID,
InvalidPasswordException,
UserAlreadyExists,
UserManagerDependency,
@ -16,10 +16,9 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
def get_users_router(
get_user_manager: UserManagerDependency[models.UC, models.UD],
user_model: Type[models.U],
user_update_model: Type[models.UU],
user_db_model: Type[models.UD],
get_user_manager: UserManagerDependency[models.UP, models.ID],
user_schema: Type[schemas.U],
user_update_schema: Type[schemas.UU],
authenticator: Authenticator,
requires_verification: bool = False,
) -> APIRouter:
@ -34,17 +33,18 @@ def get_users_router(
)
async def get_user_or_404(
id: UUID4,
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
) -> models.UD:
id: Any,
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
) -> models.UP:
try:
return await user_manager.get(id)
except UserNotExists:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
parsed_id = user_manager.parse_id(id)
return await user_manager.get(parsed_id)
except (UserNotExists, InvalidID) as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) from e
@router.get(
"/me",
response_model=user_model,
response_model=user_schema,
name="users:current_user",
responses={
status.HTTP_401_UNAUTHORIZED: {
@ -53,13 +53,13 @@ def get_users_router(
},
)
async def me(
user: user_db_model = Depends(get_current_active_user), # type: ignore
user: models.UP = Depends(get_current_active_user),
):
return user
@router.patch(
"/me",
response_model=user_model,
response_model=user_schema,
dependencies=[Depends(get_current_active_user)],
name="users:patch_current_user",
responses={
@ -95,9 +95,9 @@ def get_users_router(
)
async def update_me(
request: Request,
user_update: user_update_model, # type: ignore
user: user_db_model = Depends(get_current_active_user), # type: ignore
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
user_update: user_update_schema, # type: ignore
user: models.UP = Depends(get_current_active_user),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
return await user_manager.update(
@ -118,8 +118,8 @@ def get_users_router(
)
@router.get(
"/{id:uuid}",
response_model=user_model,
"/{id}",
response_model=user_schema,
dependencies=[Depends(get_current_superuser)],
name="users:user",
responses={
@ -138,8 +138,8 @@ def get_users_router(
return user
@router.patch(
"/{id:uuid}",
response_model=user_model,
"/{id}",
response_model=user_schema,
dependencies=[Depends(get_current_superuser)],
name="users:patch_user",
responses={
@ -180,10 +180,10 @@ def get_users_router(
},
)
async def update_user(
user_update: user_update_model, # type: ignore
user_update: user_update_schema, # type: ignore
request: Request,
user=Depends(get_user_or_404),
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
return await user_manager.update(
@ -204,7 +204,7 @@ def get_users_router(
)
@router.delete(
"/{id:uuid}",
"/{id}",
status_code=status.HTTP_204_NO_CONTENT,
response_class=Response,
dependencies=[Depends(get_current_superuser)],
@ -223,7 +223,7 @@ def get_users_router(
)
async def delete_user(
user=Depends(get_user_or_404),
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
await user_manager.delete(user)
return None

View File

@ -3,7 +3,7 @@ from typing import Type
from fastapi import APIRouter, Body, Depends, HTTPException, Request, status
from pydantic import EmailStr
from fastapi_users import models
from fastapi_users import models, schemas
from fastapi_users.manager import (
BaseUserManager,
InvalidVerifyToken,
@ -16,8 +16,8 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
def get_verify_router(
get_user_manager: UserManagerDependency[models.UC, models.UD],
user_model: Type[models.U],
get_user_manager: UserManagerDependency[models.UP, models.ID],
user_schema: Type[schemas.U],
):
router = APIRouter()
@ -29,7 +29,7 @@ def get_verify_router(
async def request_verify_token(
request: Request,
email: EmailStr = Body(..., embed=True),
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
user = await user_manager.get_by_email(email)
@ -41,7 +41,7 @@ def get_verify_router(
@router.post(
"/verify",
response_model=user_model,
response_model=user_schema,
name="verify:verify",
responses={
status.HTTP_400_BAD_REQUEST: {
@ -69,7 +69,7 @@ def get_verify_router(
async def verify(
request: Request,
token: str = Body(..., embed=True),
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
):
try:
return await user_manager.verify(token, request)

74
fastapi_users/schemas.py Normal file
View File

@ -0,0 +1,74 @@
from typing import Generic, List, Optional, TypeVar
from pydantic import BaseModel, EmailStr
from fastapi_users import models
class CreateUpdateDictModel(BaseModel):
def create_update_dict(self):
return self.dict(
exclude_unset=True,
exclude={
"id",
"is_superuser",
"is_active",
"is_verified",
"oauth_accounts",
},
)
def create_update_dict_superuser(self):
return self.dict(exclude_unset=True, exclude={"id"})
class BaseUser(Generic[models.ID], CreateUpdateDictModel):
"""Base User model."""
id: models.ID
email: EmailStr
is_active: bool = True
is_superuser: bool = False
is_verified: bool = False
class BaseUserCreate(CreateUpdateDictModel):
email: EmailStr
password: str
is_active: Optional[bool] = True
is_superuser: Optional[bool] = False
is_verified: Optional[bool] = False
class BaseUserUpdate(CreateUpdateDictModel):
password: Optional[str]
email: Optional[EmailStr]
is_active: Optional[bool]
is_superuser: Optional[bool]
is_verified: Optional[bool]
U = TypeVar("U", bound=BaseUser)
UC = TypeVar("UC", bound=BaseUserCreate)
UU = TypeVar("UU", bound=BaseUserUpdate)
class BaseOAuthAccount(Generic[models.ID], BaseModel):
"""Base OAuth account model."""
id: models.ID
oauth_name: str
access_token: str
expires_at: Optional[int] = None
refresh_token: Optional[str] = None
account_id: str
account_email: str
class Config:
orm_mode = True
class BaseOAuthAccountMixin(BaseModel):
"""Adds OAuth accounts list to a User model."""
oauth_accounts: List[BaseOAuthAccount] = []