mirror of
				https://github.com/fastapi-users/fastapi-users.git
				synced 2025-11-01 01:48:46 +08:00 
			
		
		
		
	Native model and generic ID (#971)
* Use a generic Protocol model for User instead of Pydantic
* Remove UserDB Pydantic schema
* Harmonize schema variable naming to avoid confusions
* Revamp OAuth account model management
* Revamp AccessToken DB strategy to adopt generic model approach
* Make ID a generic instead of forcing UUIDs
* Improve generic typing
* Improve Strategy typing
* Tweak base DB typing
* Don't set Pydantic schemas on FastAPIUsers class: pass it directly on router creation
* Add IntegerIdMixin and export related classes
* Start to revamp doc for V10
* Revamp OAuth documentation
* Fix code highlights
* Write the 9.x.x ➡️ 10.x.x migration doc
* Fix pyproject.toml
			
			
This commit is contained in:
		| @ -2,16 +2,22 @@ | ||||
|  | ||||
| __version__ = "9.3.2" | ||||
|  | ||||
| from fastapi_users import models  # noqa: F401 | ||||
| from fastapi_users import models, schemas  # noqa: F401 | ||||
| from fastapi_users.fastapi_users import FastAPIUsers  # noqa: F401 | ||||
| from fastapi_users.manager import (  # noqa: F401 | ||||
|     BaseUserManager, | ||||
|     IntegerIDMixin, | ||||
|     InvalidID, | ||||
|     InvalidPasswordException, | ||||
|     UUIDIDMixin, | ||||
| ) | ||||
|  | ||||
| __all__ = [ | ||||
|     "models", | ||||
|     "schemas", | ||||
|     "FastAPIUsers", | ||||
|     "BaseUserManager", | ||||
|     "InvalidPasswordException", | ||||
|     "InvalidID", | ||||
|     "UUIDIDMixin", | ||||
|     "IntegerIDMixin", | ||||
| ] | ||||
|  | ||||
| @ -51,7 +51,7 @@ class Authenticator: | ||||
|     def __init__( | ||||
|         self, | ||||
|         backends: Sequence[AuthenticationBackend], | ||||
|         get_user_manager: UserManagerDependency[models.UC, models.UD], | ||||
|         get_user_manager: UserManagerDependency[models.UP, models.ID], | ||||
|     ): | ||||
|         self.backends = backends | ||||
|         self.get_user_manager = get_user_manager | ||||
| @ -148,14 +148,14 @@ class Authenticator: | ||||
|     async def _authenticate( | ||||
|         self, | ||||
|         *args, | ||||
|         user_manager: BaseUserManager[models.UC, models.UD], | ||||
|         user_manager: BaseUserManager[models.UP, models.ID], | ||||
|         optional: bool = False, | ||||
|         active: bool = False, | ||||
|         verified: bool = False, | ||||
|         superuser: bool = False, | ||||
|         **kwargs, | ||||
|     ) -> Tuple[Optional[models.UD], Optional[str]]: | ||||
|         user: Optional[models.UD] = None | ||||
|     ) -> Tuple[Optional[models.UP], Optional[str]]: | ||||
|         user: Optional[models.UP] = None | ||||
|         token: Optional[str] = None | ||||
|         enabled_backends: Sequence[AuthenticationBackend] = kwargs.get( | ||||
|             "enabled_backends", self.backends | ||||
| @ -163,7 +163,7 @@ class Authenticator: | ||||
|         for backend in self.backends: | ||||
|             if backend in enabled_backends: | ||||
|                 token = kwargs[name_to_variable_name(backend.name)] | ||||
|                 strategy: Strategy[models.UC, models.UD] = kwargs[ | ||||
|                 strategy: Strategy[models.UP, models.ID] = kwargs[ | ||||
|                     name_to_strategy_variable_name(backend.name) | ||||
|                 ] | ||||
|                 if token is not None: | ||||
|  | ||||
| @ -14,7 +14,7 @@ from fastapi_users.authentication.transport import ( | ||||
| from fastapi_users.types import DependencyCallable | ||||
|  | ||||
|  | ||||
| class AuthenticationBackend(Generic[models.UC, models.UD]): | ||||
| class AuthenticationBackend(Generic[models.UP]): | ||||
|     """ | ||||
|     Combination of an authentication transport and strategy. | ||||
|  | ||||
| @ -33,7 +33,7 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): | ||||
|         self, | ||||
|         name: str, | ||||
|         transport: Transport, | ||||
|         get_strategy: DependencyCallable[Strategy[models.UC, models.UD]], | ||||
|         get_strategy: DependencyCallable[Strategy[models.UP, models.ID]], | ||||
|     ): | ||||
|         self.name = name | ||||
|         self.transport = transport | ||||
| @ -41,8 +41,8 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): | ||||
|  | ||||
|     async def login( | ||||
|         self, | ||||
|         strategy: Strategy[models.UC, models.UD], | ||||
|         user: models.UD, | ||||
|         strategy: Strategy[models.UP, models.ID], | ||||
|         user: models.UP, | ||||
|         response: Response, | ||||
|     ) -> Any: | ||||
|         token = await strategy.write_token(user) | ||||
| @ -50,8 +50,8 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): | ||||
|  | ||||
|     async def logout( | ||||
|         self, | ||||
|         strategy: Strategy[models.UC, models.UD], | ||||
|         user: models.UD, | ||||
|         strategy: Strategy[models.UP, models.ID], | ||||
|         user: models.UP, | ||||
|         token: str, | ||||
|         response: Response, | ||||
|     ) -> Any: | ||||
|  | ||||
| @ -3,9 +3,9 @@ from fastapi_users.authentication.strategy.base import ( | ||||
|     StrategyDestroyNotSupportedError, | ||||
| ) | ||||
| from fastapi_users.authentication.strategy.db import ( | ||||
|     A, | ||||
|     AP, | ||||
|     AccessTokenDatabase, | ||||
|     BaseAccessToken, | ||||
|     AccessTokenProtocol, | ||||
|     DatabaseStrategy, | ||||
| ) | ||||
| from fastapi_users.authentication.strategy.jwt import JWTStrategy | ||||
| @ -16,9 +16,9 @@ except ImportError:  # pragma: no cover | ||||
|     pass | ||||
|  | ||||
| __all__ = [ | ||||
|     "A", | ||||
|     "AP", | ||||
|     "AccessTokenDatabase", | ||||
|     "BaseAccessToken", | ||||
|     "AccessTokenProtocol", | ||||
|     "DatabaseStrategy", | ||||
|     "JWTStrategy", | ||||
|     "Strategy", | ||||
|  | ||||
| @ -14,14 +14,14 @@ class StrategyDestroyNotSupportedError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class Strategy(Protocol, Generic[models.UC, models.UD]): | ||||
| class Strategy(Protocol, Generic[models.UP, models.ID]): | ||||
|     async def read_token( | ||||
|         self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] | ||||
|     ) -> Optional[models.UD]: | ||||
|         self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] | ||||
|     ) -> Optional[models.UP]: | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
|     async def write_token(self, user: models.UD) -> str: | ||||
|     async def write_token(self, user: models.UP) -> str: | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
|     async def destroy_token(self, token: str, user: models.UD) -> None: | ||||
|     async def destroy_token(self, token: str, user: models.UP) -> None: | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase | ||||
| from fastapi_users.authentication.strategy.db.models import A, BaseAccessToken | ||||
| from fastapi_users.authentication.strategy.db.models import AP, AccessTokenProtocol | ||||
| from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy | ||||
|  | ||||
| __all__ = ["A", "AccessTokenDatabase", "BaseAccessToken", "DatabaseStrategy"] | ||||
| __all__ = ["AP", "AccessTokenDatabase", "AccessTokenProtocol", "DatabaseStrategy"] | ||||
|  | ||||
| @ -1,38 +1,32 @@ | ||||
| import sys | ||||
| from datetime import datetime | ||||
| from typing import Generic, Optional, Type | ||||
| from typing import Any, Dict, Generic, Optional | ||||
|  | ||||
| if sys.version_info < (3, 8): | ||||
|     from typing_extensions import Protocol  # pragma: no cover | ||||
| else: | ||||
|     from typing import Protocol  # pragma: no cover | ||||
|  | ||||
| from fastapi_users.authentication.strategy.db.models import A | ||||
| from fastapi_users.authentication.strategy.db.models import AP | ||||
|  | ||||
|  | ||||
| class AccessTokenDatabase(Protocol, Generic[A]): | ||||
|     """ | ||||
|     Protocol for retrieving, creating and updating access tokens from a database. | ||||
|  | ||||
|     :param access_token_model: Pydantic model of an access token. | ||||
|     """ | ||||
|  | ||||
|     access_token_model: Type[A] | ||||
| class AccessTokenDatabase(Protocol, Generic[AP]): | ||||
|     """Protocol for retrieving, creating and updating access tokens from a database.""" | ||||
|  | ||||
|     async def get_by_token( | ||||
|         self, token: str, max_age: Optional[datetime] = None | ||||
|     ) -> Optional[A]: | ||||
|     ) -> Optional[AP]: | ||||
|         """Get a single access token by token.""" | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
|     async def create(self, access_token: A) -> A: | ||||
|     async def create(self, create_dict: Dict[str, Any]) -> AP: | ||||
|         """Create an access token.""" | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
|     async def update(self, access_token: A) -> A: | ||||
|     async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: | ||||
|         """Update an access token.""" | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
|     async def delete(self, access_token: A) -> None: | ||||
|     async def delete(self, access_token: AP) -> None: | ||||
|         """Delete an access token.""" | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
| @ -1,22 +1,24 @@ | ||||
| from datetime import datetime, timezone | ||||
| import sys | ||||
| from datetime import datetime | ||||
| from typing import TypeVar | ||||
|  | ||||
| from pydantic import UUID4, BaseModel, Field | ||||
| if sys.version_info < (3, 8): | ||||
|     from typing_extensions import Protocol  # pragma: no cover | ||||
| else: | ||||
|     from typing import Protocol  # pragma: no cover | ||||
|  | ||||
| from fastapi_users import models | ||||
|  | ||||
|  | ||||
| def now_utc(): | ||||
|     return datetime.now(timezone.utc) | ||||
|  | ||||
|  | ||||
| class BaseAccessToken(BaseModel): | ||||
|     """Base access token model.""" | ||||
| class AccessTokenProtocol(Protocol[models.ID]): | ||||
|     """Access token protocol that ORM model should follow.""" | ||||
|  | ||||
|     token: str | ||||
|     user_id: UUID4 | ||||
|     created_at: datetime = Field(default_factory=now_utc) | ||||
|     user_id: models.ID | ||||
|     created_at: datetime | ||||
|  | ||||
|     class Config: | ||||
|         orm_mode = True | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
|  | ||||
| A = TypeVar("A", bound=BaseAccessToken) | ||||
| AP = TypeVar("AP", bound=AccessTokenProtocol) | ||||
|  | ||||
| @ -1,24 +1,26 @@ | ||||
| import secrets | ||||
| from datetime import datetime, timedelta, timezone | ||||
| from typing import Generic, Optional | ||||
| from typing import Any, Dict, Generic, Optional | ||||
|  | ||||
| from fastapi_users import models | ||||
| from fastapi_users.authentication.strategy.base import Strategy | ||||
| from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase | ||||
| from fastapi_users.authentication.strategy.db.models import A | ||||
| from fastapi_users.manager import BaseUserManager, UserNotExists | ||||
| from fastapi_users.authentication.strategy.db.models import AP | ||||
| from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists | ||||
|  | ||||
|  | ||||
| class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]): | ||||
| class DatabaseStrategy( | ||||
|     Strategy[models.UP, models.ID], Generic[models.UP, models.ID, AP] | ||||
| ): | ||||
|     def __init__( | ||||
|         self, database: AccessTokenDatabase[A], lifetime_seconds: Optional[int] = None | ||||
|         self, database: AccessTokenDatabase[AP], lifetime_seconds: Optional[int] = None | ||||
|     ): | ||||
|         self.database = database | ||||
|         self.lifetime_seconds = lifetime_seconds | ||||
|  | ||||
|     async def read_token( | ||||
|         self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] | ||||
|     ) -> Optional[models.UD]: | ||||
|         self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] | ||||
|     ) -> Optional[models.UP]: | ||||
|         if token is None: | ||||
|             return None | ||||
|  | ||||
| @ -33,21 +35,21 @@ class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]): | ||||
|             return None | ||||
|  | ||||
|         try: | ||||
|             user_id = access_token.user_id | ||||
|             return await user_manager.get(user_id) | ||||
|         except UserNotExists: | ||||
|             parsed_id = user_manager.parse_id(access_token.user_id) | ||||
|             return await user_manager.get(parsed_id) | ||||
|         except (UserNotExists, InvalidID): | ||||
|             return None | ||||
|  | ||||
|     async def write_token(self, user: models.UD) -> str: | ||||
|         access_token = self._create_access_token(user) | ||||
|         await self.database.create(access_token) | ||||
|     async def write_token(self, user: models.UP) -> str: | ||||
|         access_token_dict = self._create_access_token_dict(user) | ||||
|         access_token = await self.database.create(access_token_dict) | ||||
|         return access_token.token | ||||
|  | ||||
|     async def destroy_token(self, token: str, user: models.UD) -> None: | ||||
|     async def destroy_token(self, token: str, user: models.UP) -> None: | ||||
|         access_token = await self.database.get_by_token(token) | ||||
|         if access_token is not None: | ||||
|             await self.database.delete(access_token) | ||||
|  | ||||
|     def _create_access_token(self, user: models.UD) -> A: | ||||
|     def _create_access_token_dict(self, user: models.UP) -> Dict[str, Any]: | ||||
|         token = secrets.token_urlsafe() | ||||
|         return self.database.access_token_model(token=token, user_id=user.id) | ||||
|         return {"token": token, "user_id": user.id} | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| from typing import Generic, List, Optional | ||||
|  | ||||
| import jwt | ||||
| from pydantic import UUID4 | ||||
|  | ||||
| from fastapi_users import models | ||||
| from fastapi_users.authentication.strategy.base import ( | ||||
| @ -9,10 +8,10 @@ from fastapi_users.authentication.strategy.base import ( | ||||
|     StrategyDestroyNotSupportedError, | ||||
| ) | ||||
| from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt | ||||
| from fastapi_users.manager import BaseUserManager, UserNotExists | ||||
| from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists | ||||
|  | ||||
|  | ||||
| class JWTStrategy(Strategy, Generic[models.UC, models.UD]): | ||||
| class JWTStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID]): | ||||
|     def __init__( | ||||
|         self, | ||||
|         secret: SecretType, | ||||
| @ -36,8 +35,8 @@ class JWTStrategy(Strategy, Generic[models.UC, models.UD]): | ||||
|         return self.public_key or self.secret | ||||
|  | ||||
|     async def read_token( | ||||
|         self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] | ||||
|     ) -> Optional[models.UD]: | ||||
|         self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] | ||||
|     ) -> Optional[models.UP]: | ||||
|         if token is None: | ||||
|             return None | ||||
|  | ||||
| @ -52,20 +51,18 @@ class JWTStrategy(Strategy, Generic[models.UC, models.UD]): | ||||
|             return None | ||||
|  | ||||
|         try: | ||||
|             user_uiid = UUID4(user_id) | ||||
|             return await user_manager.get(user_uiid) | ||||
|         except ValueError: | ||||
|             return None | ||||
|         except UserNotExists: | ||||
|             parsed_id = user_manager.parse_id(user_id) | ||||
|             return await user_manager.get(parsed_id) | ||||
|         except (UserNotExists, InvalidID): | ||||
|             return None | ||||
|  | ||||
|     async def write_token(self, user: models.UD) -> str: | ||||
|     async def write_token(self, user: models.UP) -> str: | ||||
|         data = {"user_id": str(user.id), "aud": self.token_audience} | ||||
|         return generate_jwt( | ||||
|             data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm | ||||
|         ) | ||||
|  | ||||
|     async def destroy_token(self, token: str, user: models.UD) -> None: | ||||
|     async def destroy_token(self, token: str, user: models.UP) -> None: | ||||
|         raise StrategyDestroyNotSupportedError( | ||||
|             "A JWT can't be invalidated: it's valid until it expires." | ||||
|         ) | ||||
|  | ||||
| @ -2,21 +2,20 @@ import secrets | ||||
| from typing import Generic, Optional | ||||
|  | ||||
| import aioredis | ||||
| from pydantic import UUID4 | ||||
|  | ||||
| from fastapi_users import models | ||||
| from fastapi_users.authentication.strategy.base import Strategy | ||||
| from fastapi_users.manager import BaseUserManager, UserNotExists | ||||
| from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists | ||||
|  | ||||
|  | ||||
| class RedisStrategy(Strategy, Generic[models.UC, models.UD]): | ||||
| class RedisStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID]): | ||||
|     def __init__(self, redis: aioredis.Redis, lifetime_seconds: Optional[int] = None): | ||||
|         self.redis = redis | ||||
|         self.lifetime_seconds = lifetime_seconds | ||||
|  | ||||
|     async def read_token( | ||||
|         self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] | ||||
|     ) -> Optional[models.UD]: | ||||
|         self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] | ||||
|     ) -> Optional[models.UP]: | ||||
|         if token is None: | ||||
|             return None | ||||
|  | ||||
| @ -25,17 +24,15 @@ class RedisStrategy(Strategy, Generic[models.UC, models.UD]): | ||||
|             return None | ||||
|  | ||||
|         try: | ||||
|             user_uiid = UUID4(user_id) | ||||
|             return await user_manager.get(user_uiid) | ||||
|         except ValueError: | ||||
|             return None | ||||
|         except UserNotExists: | ||||
|             parsed_id = user_manager.parse_id(user_id) | ||||
|             return await user_manager.get(parsed_id) | ||||
|         except (UserNotExists, InvalidID): | ||||
|             return None | ||||
|  | ||||
|     async def write_token(self, user: models.UD) -> str: | ||||
|     async def write_token(self, user: models.UP) -> str: | ||||
|         token = secrets.token_urlsafe() | ||||
|         await self.redis.set(token, str(user.id), ex=self.lifetime_seconds) | ||||
|         return token | ||||
|  | ||||
|     async def destroy_token(self, token: str, user: models.UD) -> None: | ||||
|     async def destroy_token(self, token: str, user: models.UP) -> None: | ||||
|         await self.redis.delete(token) | ||||
|  | ||||
| @ -1,52 +1,36 @@ | ||||
| from fastapi_users.db.base import BaseUserDatabase, UserDatabaseDependency | ||||
|  | ||||
| __all__ = [ | ||||
|     "BaseUserDatabase", | ||||
|     "UserDatabaseDependency", | ||||
| ] | ||||
| __all__ = ["BaseUserDatabase", "UserDatabaseDependency"] | ||||
|  | ||||
| try:  # pragma: no cover | ||||
|     from fastapi_users_db_mongodb import MongoDBUserDatabase  # noqa: F401 | ||||
|  | ||||
|     __all__.append("MongoDBUserDatabase") | ||||
| except ImportError:  # pragma: no cover | ||||
|     pass | ||||
|  | ||||
| try:  # pragma: no cover | ||||
|     from fastapi_users_db_sqlalchemy import (  # noqa: F401 | ||||
|         SQLAlchemyBaseOAuthAccountTable, | ||||
|         SQLAlchemyBaseOAuthAccountTableUUID, | ||||
|         SQLAlchemyBaseUserTable, | ||||
|         SQLAlchemyBaseUserTableUUID, | ||||
|         SQLAlchemyUserDatabase, | ||||
|     ) | ||||
|  | ||||
|     __all__.append("SQLAlchemyBaseOAuthAccountTable") | ||||
|     __all__.append("SQLAlchemyBaseUserTable") | ||||
|     __all__.append("SQLAlchemyBaseUserTableUUID") | ||||
|     __all__.append("SQLAlchemyBaseOAuthAccountTable") | ||||
|     __all__.append("SQLAlchemyBaseOAuthAccountTableUUID") | ||||
|     __all__.append("SQLAlchemyUserDatabase") | ||||
| except ImportError:  # pragma: no cover | ||||
|     pass | ||||
|  | ||||
| try:  # pragma: no cover | ||||
|     from fastapi_users_db_tortoise import (  # noqa: F401 | ||||
|         TortoiseBaseOAuthAccountModel, | ||||
|         TortoiseBaseUserModel, | ||||
|         TortoiseUserDatabase, | ||||
|     from fastapi_users_db_beanie import (  # noqa: F401 | ||||
|         BaseOAuthAccount, | ||||
|         BeanieBaseUser, | ||||
|         BeanieUserDatabase, | ||||
|         ObjectIDIDMixin, | ||||
|     ) | ||||
|  | ||||
|     __all__.append("TortoiseBaseOAuthAccountModel") | ||||
|     __all__.append("TortoiseBaseUserModel") | ||||
|     __all__.append("TortoiseUserDatabase") | ||||
| except ImportError:  # pragma: no cover | ||||
|     pass | ||||
|  | ||||
| try:  # pragma: no cover | ||||
|     from fastapi_users_db_ormar import (  # noqa: F401 | ||||
|         OrmarBaseOAuthAccountModel, | ||||
|         OrmarBaseUserModel, | ||||
|         OrmarUserDatabase, | ||||
|     ) | ||||
|  | ||||
|     __all__.append("OrmarBaseOAuthAccountModel") | ||||
|     __all__.append("OrmarBaseUserModel") | ||||
|     __all__.append("OrmarUserDatabase") | ||||
|     __all__.append("BeanieBaseUser") | ||||
|     __all__.append("BaseOAuthAccount") | ||||
|     __all__.append("BeanieUserDatabase") | ||||
|     __all__.append("ObjectIDIDMixin") | ||||
| except ImportError:  # pragma: no cover | ||||
|     pass | ||||
|  | ||||
| @ -1,46 +1,50 @@ | ||||
| from typing import Generic, Optional, Type | ||||
| from typing import Any, Dict, Generic, Optional | ||||
|  | ||||
| from pydantic import UUID4 | ||||
|  | ||||
| from fastapi_users.models import UD | ||||
| from fastapi_users.models import ID, OAP, UOAP, UP | ||||
| from fastapi_users.types import DependencyCallable | ||||
|  | ||||
|  | ||||
| class BaseUserDatabase(Generic[UD]): | ||||
|     """ | ||||
|     Base adapter for retrieving, creating and updating users from a database. | ||||
| class BaseUserDatabase(Generic[UP, ID]): | ||||
|     """Base adapter for retrieving, creating and updating users from a database.""" | ||||
|  | ||||
|     :param user_db_model: Pydantic model of a DB representation of a user. | ||||
|     """ | ||||
|  | ||||
|     user_db_model: Type[UD] | ||||
|  | ||||
|     def __init__(self, user_db_model: Type[UD]): | ||||
|         self.user_db_model = user_db_model | ||||
|  | ||||
|     async def get(self, id: UUID4) -> Optional[UD]: | ||||
|     async def get(self, id: ID) -> Optional[UP]: | ||||
|         """Get a single user by id.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     async def get_by_email(self, email: str) -> Optional[UD]: | ||||
|     async def get_by_email(self, email: str) -> Optional[UP]: | ||||
|         """Get a single user by email.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: | ||||
|     async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: | ||||
|         """Get a single user by OAuth account id.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     async def create(self, user: UD) -> UD: | ||||
|     async def create(self, create_dict: Dict[str, Any]) -> UP: | ||||
|         """Create a user.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     async def update(self, user: UD) -> UD: | ||||
|     async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: | ||||
|         """Update a user.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     async def delete(self, user: UD) -> None: | ||||
|     async def delete(self, user: UP) -> None: | ||||
|         """Delete a user.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     async def add_oauth_account( | ||||
|         self: "BaseUserDatabase[UOAP, ID]", user: UOAP, create_dict: Dict[str, Any] | ||||
|     ) -> UOAP: | ||||
|         """Create an OAuth account and add it to the user.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
| UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UD]] | ||||
|     async def update_oauth_account( | ||||
|         self: "BaseUserDatabase[UOAP, ID]", | ||||
|         user: UOAP, | ||||
|         oauth_account: OAP, | ||||
|         update_dict: Dict[str, Any], | ||||
|     ) -> UOAP: | ||||
|         """Update an OAuth account on a user.""" | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|  | ||||
| UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP, ID]] | ||||
|  | ||||
| @ -2,7 +2,7 @@ from typing import Generic, Sequence, Type | ||||
|  | ||||
| from fastapi import APIRouter | ||||
|  | ||||
| from fastapi_users import models | ||||
| from fastapi_users import models, schemas | ||||
| from fastapi_users.authentication import AuthenticationBackend, Authenticator | ||||
| from fastapi_users.jwt import SecretType | ||||
| from fastapi_users.manager import UserManagerDependency | ||||
| @ -22,58 +22,49 @@ except ModuleNotFoundError:  # pragma: no cover | ||||
|     BaseOAuth2 = Type  # type: ignore | ||||
|  | ||||
|  | ||||
| class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): | ||||
| class FastAPIUsers(Generic[models.UP, models.ID]): | ||||
|     """ | ||||
|     Main object that ties together the component for users authentication. | ||||
|  | ||||
|     :param get_user_manager: Dependency callable getter to inject the | ||||
|     user manager class instance. | ||||
|     :param auth_backends: List of authentication backends. | ||||
|     :param user_model: Pydantic model of a user. | ||||
|     :param user_create_model: Pydantic model for creating a user. | ||||
|     :param user_update_model: Pydantic model for updating a user. | ||||
|     :param user_db_model: Pydantic model of a DB representation of a user. | ||||
|  | ||||
|     :attribute current_user: Dependency callable getter to inject authenticated user | ||||
|     with a specific set of parameters. | ||||
|     """ | ||||
|  | ||||
|     authenticator: Authenticator | ||||
|     _user_model: Type[models.U] | ||||
|     _user_create_model: Type[models.UC] | ||||
|     _user_update_model: Type[models.UU] | ||||
|     _user_db_model: Type[models.UD] | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         get_user_manager: UserManagerDependency[models.UC, models.UD], | ||||
|         get_user_manager: UserManagerDependency[models.UP, models.ID], | ||||
|         auth_backends: Sequence[AuthenticationBackend], | ||||
|         user_model: Type[models.U], | ||||
|         user_create_model: Type[models.UC], | ||||
|         user_update_model: Type[models.UU], | ||||
|         user_db_model: Type[models.UD], | ||||
|     ): | ||||
|         self.authenticator = Authenticator(auth_backends, get_user_manager) | ||||
|  | ||||
|         self._user_model = user_model | ||||
|         self._user_db_model = user_db_model | ||||
|         self._user_create_model = user_create_model | ||||
|         self._user_update_model = user_update_model | ||||
|  | ||||
|         self.get_user_manager = get_user_manager | ||||
|         self.current_user = self.authenticator.current_user | ||||
|  | ||||
|     def get_register_router(self) -> APIRouter: | ||||
|         """Return a router with a register route.""" | ||||
|     def get_register_router( | ||||
|         self, user_schema: Type[schemas.U], user_create_schema: Type[schemas.UC] | ||||
|     ) -> APIRouter: | ||||
|         """ | ||||
|         Return a router with a register route. | ||||
|  | ||||
|         :param user_schema: Pydantic schema of a public user. | ||||
|         :param user_create_schema: Pydantic schema for creating a user. | ||||
|         """ | ||||
|         return get_register_router( | ||||
|             self.get_user_manager, | ||||
|             self._user_model, | ||||
|             self._user_create_model, | ||||
|             self.get_user_manager, user_schema, user_create_schema | ||||
|         ) | ||||
|  | ||||
|     def get_verify_router(self) -> APIRouter: | ||||
|         """Return a router with e-mail verification routes.""" | ||||
|         return get_verify_router(self.get_user_manager, self._user_model) | ||||
|     def get_verify_router(self, user_schema: Type[schemas.U]) -> APIRouter: | ||||
|         """ | ||||
|         Return a router with e-mail verification routes. | ||||
|  | ||||
|         :param user_schema: Pydantic schema of a public user. | ||||
|         """ | ||||
|         return get_verify_router(self.get_user_manager, user_schema) | ||||
|  | ||||
|     def get_reset_password_router(self) -> APIRouter: | ||||
|         """Return a reset password process router.""" | ||||
| @ -122,19 +113,22 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): | ||||
|  | ||||
|     def get_users_router( | ||||
|         self, | ||||
|         user_schema: Type[schemas.U], | ||||
|         user_update_schema: Type[schemas.UU], | ||||
|         requires_verification: bool = False, | ||||
|     ) -> APIRouter: | ||||
|         """ | ||||
|         Return a router with routes to manage users. | ||||
|  | ||||
|         :param user_schema: Pydantic schema of a public user. | ||||
|         :param user_update_schema: Pydantic schema for updating a user. | ||||
|         :param requires_verification: Whether the endpoints | ||||
|         require the users to be verified or not. | ||||
|         """ | ||||
|         return get_users_router( | ||||
|             self.get_user_manager, | ||||
|             self._user_model, | ||||
|             self._user_update_model, | ||||
|             self._user_db_model, | ||||
|             user_schema, | ||||
|             user_update_schema, | ||||
|             self.authenticator, | ||||
|             requires_verification, | ||||
|         ) | ||||
|  | ||||
| @ -1,11 +1,11 @@ | ||||
| from typing import Any, Dict, Generic, Optional, Type, Union | ||||
| import uuid | ||||
| from typing import Any, Dict, Generic, Optional, Union | ||||
|  | ||||
| import jwt | ||||
| from fastapi import Request | ||||
| from fastapi.security import OAuth2PasswordRequestForm | ||||
| from pydantic import UUID4 | ||||
|  | ||||
| from fastapi_users import models | ||||
| from fastapi_users import models, schemas | ||||
| from fastapi_users.db import BaseUserDatabase | ||||
| from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt | ||||
| from fastapi_users.password import PasswordHelper, PasswordHelperProtocol | ||||
| @ -19,6 +19,10 @@ class FastAPIUsersException(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class InvalidID(FastAPIUsersException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class UserAlreadyExists(FastAPIUsersException): | ||||
|     pass | ||||
|  | ||||
| @ -48,11 +52,10 @@ class InvalidPasswordException(FastAPIUsersException): | ||||
|         self.reason = reason | ||||
|  | ||||
|  | ||||
| class BaseUserManager(Generic[models.UC, models.UD]): | ||||
| class BaseUserManager(Generic[models.UP, models.ID]): | ||||
|     """ | ||||
|     User management logic. | ||||
|  | ||||
|     :attribute user_db_model: Pydantic model of a DB representation of a user. | ||||
|     :attribute reset_password_token_secret: Secret to encode reset password token. | ||||
|     :attribute reset_password_token_lifetime_seconds: Lifetime of reset password token. | ||||
|     :attribute reset_password_token_audience: JWT audience of reset password token. | ||||
| @ -63,7 +66,6 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|     :param user_db: Database adapter instance. | ||||
|     """ | ||||
|  | ||||
|     user_db_model: Type[models.UD] | ||||
|     reset_password_token_secret: SecretType | ||||
|     reset_password_token_lifetime_seconds: int = 3600 | ||||
|     reset_password_token_audience: str = RESET_PASSWORD_TOKEN_AUDIENCE | ||||
| @ -72,12 +74,12 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|     verification_token_lifetime_seconds: int = 3600 | ||||
|     verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE | ||||
|  | ||||
|     user_db: BaseUserDatabase[models.UD] | ||||
|     user_db: BaseUserDatabase[models.UP, models.ID] | ||||
|     password_helper: PasswordHelperProtocol | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         user_db: BaseUserDatabase[models.UD], | ||||
|         user_db: BaseUserDatabase[models.UP, models.ID], | ||||
|         password_helper: Optional[PasswordHelperProtocol] = None, | ||||
|     ): | ||||
|         self.user_db = user_db | ||||
| @ -86,7 +88,17 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         else: | ||||
|             self.password_helper = password_helper  # pragma: no cover | ||||
|  | ||||
|     async def get(self, id: UUID4) -> models.UD: | ||||
|     def parse_id(self, value: Any) -> models.ID: | ||||
|         """ | ||||
|         Parse a value into a correct models.ID instance. | ||||
|  | ||||
|         :param value: The value to parse. | ||||
|         :raises InvalidID: The models.ID value is invalid. | ||||
|         :return: An models.ID object. | ||||
|         """ | ||||
|         raise NotImplementedError()  # pragma: no cover | ||||
|  | ||||
|     async def get(self, id: models.ID) -> models.UP: | ||||
|         """ | ||||
|         Get a user by id. | ||||
|  | ||||
| @ -101,7 +113,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|  | ||||
|         return user | ||||
|  | ||||
|     async def get_by_email(self, user_email: str) -> models.UD: | ||||
|     async def get_by_email(self, user_email: str) -> models.UP: | ||||
|         """ | ||||
|         Get a user by e-mail. | ||||
|  | ||||
| @ -116,7 +128,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|  | ||||
|         return user | ||||
|  | ||||
|     async def get_by_oauth_account(self, oauth: str, account_id: str) -> models.UD: | ||||
|     async def get_by_oauth_account(self, oauth: str, account_id: str) -> models.UP: | ||||
|         """ | ||||
|         Get a user by OAuth account. | ||||
|  | ||||
| @ -133,14 +145,17 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         return user | ||||
|  | ||||
|     async def create( | ||||
|         self, user: models.UC, safe: bool = False, request: Optional[Request] = None | ||||
|     ) -> models.UD: | ||||
|         self, | ||||
|         user_create: schemas.UC, | ||||
|         safe: bool = False, | ||||
|         request: Optional[Request] = None, | ||||
|     ) -> models.UP: | ||||
|         """ | ||||
|         Create a user in database. | ||||
|  | ||||
|         Triggers the on_after_register handler on success. | ||||
|  | ||||
|         :param user: The UserCreate model to create. | ||||
|         :param user_create: The UserCreate model to create. | ||||
|         :param safe: If True, sensitive values like is_superuser or is_verified | ||||
|         will be ignored during the creation, defaults to False. | ||||
|         :param request: Optional FastAPI request that | ||||
| @ -148,27 +163,36 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         :raises UserAlreadyExists: A user already exists with the same e-mail. | ||||
|         :return: A new user. | ||||
|         """ | ||||
|         await self.validate_password(user.password, user) | ||||
|         await self.validate_password(user_create.password, user_create) | ||||
|  | ||||
|         existing_user = await self.user_db.get_by_email(user.email) | ||||
|         existing_user = await self.user_db.get_by_email(user_create.email) | ||||
|         if existing_user is not None: | ||||
|             raise UserAlreadyExists() | ||||
|  | ||||
|         hashed_password = self.password_helper.hash(user.password) | ||||
|         user_dict = ( | ||||
|             user.create_update_dict() if safe else user.create_update_dict_superuser() | ||||
|             user_create.create_update_dict() | ||||
|             if safe | ||||
|             else user_create.create_update_dict_superuser() | ||||
|         ) | ||||
|         db_user = self.user_db_model(**user_dict, hashed_password=hashed_password) | ||||
|         password = user_dict.pop("password") | ||||
|         user_dict["hashed_password"] = self.password_helper.hash(password) | ||||
|  | ||||
|         created_user = await self.user_db.create(db_user) | ||||
|         created_user = await self.user_db.create(user_dict) | ||||
|  | ||||
|         await self.on_after_register(created_user, request) | ||||
|  | ||||
|         return created_user | ||||
|  | ||||
|     async def oauth_callback( | ||||
|         self, oauth_account: models.BaseOAuthAccount, request: Optional[Request] = None | ||||
|     ) -> models.UD: | ||||
|         self: "BaseUserManager[models.UOAP, models.ID]", | ||||
|         oauth_name: str, | ||||
|         access_token: str, | ||||
|         account_id: str, | ||||
|         account_email: str, | ||||
|         expires_at: Optional[int] = None, | ||||
|         refresh_token: Optional[str] = None, | ||||
|         request: Optional[Request] = None, | ||||
|     ) -> models.UOAP: | ||||
|         """ | ||||
|         Handle the callback after a successful OAuth authentication. | ||||
|  | ||||
| @ -180,50 +204,58 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         If the user does not exist, it is created and the on_after_register handler | ||||
|         is triggered. | ||||
|  | ||||
|         :param oauth_account: The new OAuth account to create. | ||||
|         :param oauth_name: Name of the OAuth client. | ||||
|         :param access_token: Valid access token for the service provider. | ||||
|         :param account_id: models.ID of the user on the service provider. | ||||
|         :param account_email: E-mail of the user on the service provider. | ||||
|         :param expires_at: Optional timestamp at which the access token expires. | ||||
|         :param refresh_token: Optional refresh token to get a | ||||
|         fresh access token from the service provider. | ||||
|         :param request: Optional FastAPI request that | ||||
|         triggered the operation, defaults to None | ||||
|         :return: A user. | ||||
|         """ | ||||
|         oauth_account_dict = { | ||||
|             "oauth_name": oauth_name, | ||||
|             "access_token": access_token, | ||||
|             "account_id": account_id, | ||||
|             "account_email": account_email, | ||||
|             "expires_at": expires_at, | ||||
|             "refresh_token": refresh_token, | ||||
|         } | ||||
|  | ||||
|         try: | ||||
|             user = await self.get_by_oauth_account( | ||||
|                 oauth_account.oauth_name, oauth_account.account_id | ||||
|             ) | ||||
|             user = await self.get_by_oauth_account(oauth_name, account_id) | ||||
|         except UserNotExists: | ||||
|             try: | ||||
|                 # Link account | ||||
|                 user = await self.get_by_email(oauth_account.account_email) | ||||
|                 user.oauth_accounts.append(oauth_account)  # type: ignore | ||||
|                 await self.user_db.update(user) | ||||
|                 user = await self.get_by_email(account_email) | ||||
|                 user = await self.user_db.add_oauth_account(user, oauth_account_dict) | ||||
|             except UserNotExists: | ||||
|                 # Create account | ||||
|                 password = self.password_helper.generate() | ||||
|                 user = self.user_db_model( | ||||
|                     email=oauth_account.account_email, | ||||
|                     hashed_password=self.password_helper.hash(password), | ||||
|                     oauth_accounts=[oauth_account], | ||||
|                 ) | ||||
|                 await self.user_db.create(user) | ||||
|                 user_dict = { | ||||
|                     "email": account_email, | ||||
|                     "hashed_password": self.password_helper.hash(password), | ||||
|                 } | ||||
|                 user = await self.user_db.create(user_dict) | ||||
|                 user = await self.user_db.add_oauth_account(user, oauth_account_dict) | ||||
|                 await self.on_after_register(user, request) | ||||
|         else: | ||||
|             # Update oauth | ||||
|             updated_oauth_accounts = [] | ||||
|             for existing_oauth_account in user.oauth_accounts:  # type: ignore | ||||
|             for existing_oauth_account in user.oauth_accounts: | ||||
|                 if ( | ||||
|                     existing_oauth_account.account_id == oauth_account.account_id | ||||
|                     and existing_oauth_account.oauth_name == oauth_account.oauth_name | ||||
|                     existing_oauth_account.account_id == account_id | ||||
|                     and existing_oauth_account.oauth_name == oauth_name | ||||
|                 ): | ||||
|                     oauth_account.id = existing_oauth_account.id | ||||
|                     updated_oauth_accounts.append(oauth_account) | ||||
|                 else: | ||||
|                     updated_oauth_accounts.append(existing_oauth_account) | ||||
|             user.oauth_accounts = updated_oauth_accounts  # type: ignore | ||||
|             await self.user_db.update(user) | ||||
|                     user = await self.user_db.update_oauth_account( | ||||
|                         user, existing_oauth_account, oauth_account_dict | ||||
|                     ) | ||||
|  | ||||
|         return user | ||||
|  | ||||
|     async def request_verify( | ||||
|         self, user: models.UD, request: Optional[Request] = None | ||||
|         self, user: models.UP, request: Optional[Request] = None | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Start a verification request. | ||||
| @ -253,7 +285,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         ) | ||||
|         await self.on_after_request_verify(user, token, request) | ||||
|  | ||||
|     async def verify(self, token: str, request: Optional[Request] = None) -> models.UD: | ||||
|     async def verify(self, token: str, request: Optional[Request] = None) -> models.UP: | ||||
|         """ | ||||
|         Validate a verification request. | ||||
|  | ||||
| @ -289,11 +321,11 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|             raise InvalidVerifyToken() | ||||
|  | ||||
|         try: | ||||
|             user_uuid = UUID4(user_id) | ||||
|         except ValueError: | ||||
|             parsed_id = self.parse_id(user_id) | ||||
|         except InvalidID: | ||||
|             raise InvalidVerifyToken() | ||||
|  | ||||
|         if user_uuid != user.id: | ||||
|         if parsed_id != user.id: | ||||
|             raise InvalidVerifyToken() | ||||
|  | ||||
|         if user.is_verified: | ||||
| @ -306,7 +338,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         return verified_user | ||||
|  | ||||
|     async def forgot_password( | ||||
|         self, user: models.UD, request: Optional[Request] = None | ||||
|         self, user: models.UP, request: Optional[Request] = None | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Start a forgot password request. | ||||
| @ -334,7 +366,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|  | ||||
|     async def reset_password( | ||||
|         self, token: str, password: str, request: Optional[Request] = None | ||||
|     ) -> models.UD: | ||||
|     ) -> models.UP: | ||||
|         """ | ||||
|         Reset the password of a user. | ||||
|  | ||||
| @ -364,11 +396,11 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|             raise InvalidResetPasswordToken() | ||||
|  | ||||
|         try: | ||||
|             user_uuid = UUID4(user_id) | ||||
|         except ValueError: | ||||
|             parsed_id = self.parse_id(user_id) | ||||
|         except InvalidID: | ||||
|             raise InvalidResetPasswordToken() | ||||
|  | ||||
|         user = await self.get(user_uuid) | ||||
|         user = await self.get(parsed_id) | ||||
|  | ||||
|         if not user.is_active: | ||||
|             raise UserInactive() | ||||
| @ -381,11 +413,11 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|  | ||||
|     async def update( | ||||
|         self, | ||||
|         user_update: models.UU, | ||||
|         user: models.UD, | ||||
|         user_update: schemas.UU, | ||||
|         user: models.UP, | ||||
|         safe: bool = False, | ||||
|         request: Optional[Request] = None, | ||||
|     ) -> models.UD: | ||||
|     ) -> models.UP: | ||||
|         """ | ||||
|         Update a user. | ||||
|  | ||||
| @ -408,7 +440,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         await self.on_after_update(updated_user, updated_user_data, request) | ||||
|         return updated_user | ||||
|  | ||||
|     async def delete(self, user: models.UD) -> None: | ||||
|     async def delete(self, user: models.UP) -> None: | ||||
|         """ | ||||
|         Delete a user. | ||||
|  | ||||
| @ -417,7 +449,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         await self.user_db.delete(user) | ||||
|  | ||||
|     async def validate_password( | ||||
|         self, password: str, user: Union[models.UC, models.UD] | ||||
|         self, password: str, user: Union[schemas.UC, models.UP] | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Validate a password. | ||||
| @ -432,7 +464,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         return  # pragma: no cover | ||||
|  | ||||
|     async def on_after_register( | ||||
|         self, user: models.UD, request: Optional[Request] = None | ||||
|         self, user: models.UP, request: Optional[Request] = None | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Perform logic after successful user registration. | ||||
| @ -447,7 +479,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|  | ||||
|     async def on_after_update( | ||||
|         self, | ||||
|         user: models.UD, | ||||
|         user: models.UP, | ||||
|         update_dict: Dict[str, Any], | ||||
|         request: Optional[Request] = None, | ||||
|     ) -> None: | ||||
| @ -464,7 +496,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         return  # pragma: no cover | ||||
|  | ||||
|     async def on_after_request_verify( | ||||
|         self, user: models.UD, token: str, request: Optional[Request] = None | ||||
|         self, user: models.UP, token: str, request: Optional[Request] = None | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Perform logic after successful verification request. | ||||
| @ -479,7 +511,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         return  # pragma: no cover | ||||
|  | ||||
|     async def on_after_verify( | ||||
|         self, user: models.UD, request: Optional[Request] = None | ||||
|         self, user: models.UP, request: Optional[Request] = None | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Perform logic after successful user verification. | ||||
| @ -493,7 +525,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         return  # pragma: no cover | ||||
|  | ||||
|     async def on_after_forgot_password( | ||||
|         self, user: models.UD, token: str, request: Optional[Request] = None | ||||
|         self, user: models.UP, token: str, request: Optional[Request] = None | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Perform logic after successful forgot password request. | ||||
| @ -508,7 +540,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|         return  # pragma: no cover | ||||
|  | ||||
|     async def on_after_reset_password( | ||||
|         self, user: models.UD, request: Optional[Request] = None | ||||
|         self, user: models.UP, request: Optional[Request] = None | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Perform logic after successful password reset. | ||||
| @ -523,7 +555,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|  | ||||
|     async def authenticate( | ||||
|         self, credentials: OAuth2PasswordRequestForm | ||||
|     ) -> Optional[models.UD]: | ||||
|     ) -> Optional[models.UP]: | ||||
|         """ | ||||
|         Authenticate and return a user following an email and a password. | ||||
|  | ||||
| @ -546,27 +578,48 @@ class BaseUserManager(Generic[models.UC, models.UD]): | ||||
|             return None | ||||
|         # Update password hash to a more robust one if needed | ||||
|         if updated_password_hash is not None: | ||||
|             user.hashed_password = updated_password_hash | ||||
|             await self.user_db.update(user) | ||||
|             await self.user_db.update(user, {"hashed_password": updated_password_hash}) | ||||
|  | ||||
|         return user | ||||
|  | ||||
|     async def _update(self, user: models.UD, update_dict: Dict[str, Any]) -> models.UD: | ||||
|     async def _update(self, user: models.UP, update_dict: Dict[str, Any]) -> models.UP: | ||||
|         validated_update_dict = {} | ||||
|         for field, value in update_dict.items(): | ||||
|             if field == "email" and value != user.email: | ||||
|                 try: | ||||
|                     await self.get_by_email(value) | ||||
|                     raise UserAlreadyExists() | ||||
|                 except UserNotExists: | ||||
|                     user.email = value | ||||
|                     user.is_verified = False | ||||
|                     validated_update_dict["email"] = value | ||||
|                     validated_update_dict["is_verified"] = False | ||||
|             elif field == "password": | ||||
|                 await self.validate_password(value, user) | ||||
|                 hashed_password = self.password_helper.hash(value) | ||||
|                 user.hashed_password = hashed_password | ||||
|                 validated_update_dict["hashed_password"] = self.password_helper.hash( | ||||
|                     value | ||||
|                 ) | ||||
|             else: | ||||
|                 setattr(user, field, value) | ||||
|         return await self.user_db.update(user) | ||||
|                 validated_update_dict[field] = value | ||||
|         return await self.user_db.update(user, validated_update_dict) | ||||
|  | ||||
|  | ||||
| UserManagerDependency = DependencyCallable[BaseUserManager[models.UC, models.UD]] | ||||
| class UUIDIDMixin: | ||||
|     def parse_id(self, value: Any) -> uuid.UUID: | ||||
|         if isinstance(value, uuid.UUID): | ||||
|             return value | ||||
|         try: | ||||
|             return uuid.UUID(value) | ||||
|         except ValueError as e: | ||||
|             raise InvalidID() from e | ||||
|  | ||||
|  | ||||
| class IntegerIDMixin: | ||||
|     def parse_id(self, value: Any) -> int: | ||||
|         if isinstance(value, float): | ||||
|             raise InvalidID() | ||||
|         try: | ||||
|             return int(value) | ||||
|         except ValueError as e: | ||||
|             raise InvalidID() from e | ||||
|  | ||||
|  | ||||
| UserManagerDependency = DependencyCallable[BaseUserManager[models.UP, models.ID]] | ||||
|  | ||||
| @ -1,81 +1,51 @@ | ||||
| import uuid | ||||
| from typing import List, Optional, TypeVar | ||||
| import sys | ||||
| from typing import Generic, List, Optional, TypeVar | ||||
|  | ||||
| from pydantic import UUID4, BaseModel, EmailStr, Field | ||||
| if sys.version_info < (3, 8): | ||||
|     from typing_extensions import Protocol  # pragma: no cover | ||||
| else: | ||||
|     from typing import Protocol  # pragma: no cover | ||||
|  | ||||
| ID = TypeVar("ID") | ||||
|  | ||||
|  | ||||
| class CreateUpdateDictModel(BaseModel): | ||||
|     def create_update_dict(self): | ||||
|         return self.dict( | ||||
|             exclude_unset=True, | ||||
|             exclude={ | ||||
|                 "id", | ||||
|                 "is_superuser", | ||||
|                 "is_active", | ||||
|                 "is_verified", | ||||
|                 "oauth_accounts", | ||||
|             }, | ||||
|         ) | ||||
| class UserProtocol(Protocol[ID]): | ||||
|     """User protocol that ORM model should follow.""" | ||||
|  | ||||
|     def create_update_dict_superuser(self): | ||||
|         return self.dict(exclude_unset=True, exclude={"id"}) | ||||
|  | ||||
|  | ||||
| class BaseUser(CreateUpdateDictModel): | ||||
|     """Base User model.""" | ||||
|  | ||||
|     id: UUID4 = Field(default_factory=uuid.uuid4) | ||||
|     email: EmailStr | ||||
|     is_active: bool = True | ||||
|     is_superuser: bool = False | ||||
|     is_verified: bool = False | ||||
|  | ||||
|  | ||||
| class BaseUserCreate(CreateUpdateDictModel): | ||||
|     email: EmailStr | ||||
|     password: str | ||||
|     is_active: Optional[bool] = True | ||||
|     is_superuser: Optional[bool] = False | ||||
|     is_verified: Optional[bool] = False | ||||
|  | ||||
|  | ||||
| class BaseUserUpdate(CreateUpdateDictModel): | ||||
|     password: Optional[str] | ||||
|     email: Optional[EmailStr] | ||||
|     is_active: Optional[bool] | ||||
|     is_superuser: Optional[bool] | ||||
|     is_verified: Optional[bool] | ||||
|  | ||||
|  | ||||
| class BaseUserDB(BaseUser): | ||||
|     id: ID | ||||
|     email: str | ||||
|     hashed_password: str | ||||
|     is_active: bool | ||||
|     is_superuser: bool | ||||
|     is_verified: bool | ||||
|  | ||||
|     class Config: | ||||
|         orm_mode = True | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
|  | ||||
| U = TypeVar("U", bound=BaseUser) | ||||
| UC = TypeVar("UC", bound=BaseUserCreate) | ||||
| UU = TypeVar("UU", bound=BaseUserUpdate) | ||||
| UD = TypeVar("UD", bound=BaseUserDB) | ||||
| class OAuthAccountProtocol(Protocol[ID]): | ||||
|     """OAuth account protocol that ORM model should follow.""" | ||||
|  | ||||
|  | ||||
| class BaseOAuthAccount(BaseModel): | ||||
|     """Base OAuth account model.""" | ||||
|  | ||||
|     id: UUID4 = Field(default_factory=uuid.uuid4) | ||||
|     id: ID | ||||
|     oauth_name: str | ||||
|     access_token: str | ||||
|     expires_at: Optional[int] = None | ||||
|     refresh_token: Optional[str] = None | ||||
|     expires_at: Optional[int] | ||||
|     refresh_token: Optional[str] | ||||
|     account_id: str | ||||
|     account_email: str | ||||
|  | ||||
|     class Config: | ||||
|         orm_mode = True | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         ...  # pragma: no cover | ||||
|  | ||||
|  | ||||
| class BaseOAuthAccountMixin(BaseModel): | ||||
|     """Adds OAuth accounts list to a User model.""" | ||||
| UP = TypeVar("UP", bound=UserProtocol) | ||||
| OAP = TypeVar("OAP", bound=OAuthAccountProtocol) | ||||
|  | ||||
|     oauth_accounts: List[BaseOAuthAccount] = [] | ||||
|  | ||||
| class UserOAuthProtocol(UserProtocol[ID], Generic[ID, OAP]): | ||||
|     """User protocol including a list of OAuth accounts.""" | ||||
|  | ||||
|     oauth_accounts: List[OAP] | ||||
|  | ||||
|  | ||||
| UOAP = TypeVar("UOAP", bound=UserOAuthProtocol) | ||||
|  | ||||
| @ -12,7 +12,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel | ||||
|  | ||||
| def get_auth_router( | ||||
|     backend: AuthenticationBackend, | ||||
|     get_user_manager: UserManagerDependency[models.UC, models.UD], | ||||
|     get_user_manager: UserManagerDependency[models.UP, models.ID], | ||||
|     authenticator: Authenticator, | ||||
|     requires_verification: bool = False, | ||||
| ) -> APIRouter: | ||||
| @ -51,8 +51,8 @@ def get_auth_router( | ||||
|     async def login( | ||||
|         response: Response, | ||||
|         credentials: OAuth2PasswordRequestForm = Depends(), | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|         strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), | ||||
|     ): | ||||
|         user = await user_manager.authenticate(credentials) | ||||
|  | ||||
| @ -82,8 +82,8 @@ def get_auth_router( | ||||
|     ) | ||||
|     async def logout( | ||||
|         response: Response, | ||||
|         user_token: Tuple[models.UD, str] = Depends(get_current_user_token), | ||||
|         strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), | ||||
|         user_token: Tuple[models.UP, str] = Depends(get_current_user_token), | ||||
|         strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), | ||||
|     ): | ||||
|         user, token = user_token | ||||
|         return await backend.logout(strategy, user, token, response) | ||||
|  | ||||
| @ -29,7 +29,7 @@ def generate_state_token( | ||||
| def get_oauth_router( | ||||
|     oauth_client: BaseOAuth2, | ||||
|     backend: AuthenticationBackend, | ||||
|     get_user_manager: UserManagerDependency[models.UC, models.UD], | ||||
|     get_user_manager: UserManagerDependency[models.UP, models.ID], | ||||
|     state_secret: SecretType, | ||||
|     redirect_url: str = None, | ||||
| ) -> APIRouter: | ||||
| @ -101,8 +101,8 @@ def get_oauth_router( | ||||
|         access_token_state: Tuple[OAuth2Token, str] = Depends( | ||||
|             oauth2_authorize_callback | ||||
|         ), | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|         strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), | ||||
|     ): | ||||
|         token, state = access_token_state | ||||
|         account_id, account_email = await oauth_client.get_id_email( | ||||
| @ -114,17 +114,16 @@ def get_oauth_router( | ||||
|         except jwt.DecodeError: | ||||
|             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) | ||||
|  | ||||
|         new_oauth_account = models.BaseOAuthAccount( | ||||
|             oauth_name=oauth_client.name, | ||||
|             access_token=token["access_token"], | ||||
|             expires_at=token.get("expires_at"), | ||||
|             refresh_token=token.get("refresh_token"), | ||||
|             account_id=account_id, | ||||
|             account_email=account_email, | ||||
|         user = await user_manager.oauth_callback( | ||||
|             oauth_client.name, | ||||
|             token["access_token"], | ||||
|             account_id, | ||||
|             account_email, | ||||
|             token.get("expires_at"), | ||||
|             token.get("refresh_token"), | ||||
|             request, | ||||
|         ) | ||||
|  | ||||
|         user = await user_manager.oauth_callback(new_oauth_account, request) | ||||
|  | ||||
|         if not user.is_active: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_400_BAD_REQUEST, | ||||
|  | ||||
| @ -2,7 +2,7 @@ from typing import Type | ||||
|  | ||||
| from fastapi import APIRouter, Depends, HTTPException, Request, status | ||||
|  | ||||
| from fastapi_users import models | ||||
| from fastapi_users import models, schemas | ||||
| from fastapi_users.manager import ( | ||||
|     BaseUserManager, | ||||
|     InvalidPasswordException, | ||||
| @ -13,16 +13,16 @@ from fastapi_users.router.common import ErrorCode, ErrorModel | ||||
|  | ||||
|  | ||||
| def get_register_router( | ||||
|     get_user_manager: UserManagerDependency[models.UC, models.UD], | ||||
|     user_model: Type[models.U], | ||||
|     user_create_model: Type[models.UC], | ||||
|     get_user_manager: UserManagerDependency[models.UP, models.ID], | ||||
|     user_schema: Type[schemas.U], | ||||
|     user_create_schema: Type[schemas.UC], | ||||
| ) -> APIRouter: | ||||
|     """Generate a router with the register route.""" | ||||
|     router = APIRouter() | ||||
|  | ||||
|     @router.post( | ||||
|         "/register", | ||||
|         response_model=user_model, | ||||
|         response_model=user_schema, | ||||
|         status_code=status.HTTP_201_CREATED, | ||||
|         name="register:register", | ||||
|         responses={ | ||||
| @ -55,11 +55,13 @@ def get_register_router( | ||||
|     ) | ||||
|     async def register( | ||||
|         request: Request, | ||||
|         user: user_create_model,  # type: ignore | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         user_create: user_create_schema,  # type: ignore | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ): | ||||
|         try: | ||||
|             created_user = await user_manager.create(user, safe=True, request=request) | ||||
|             created_user = await user_manager.create( | ||||
|                 user_create, safe=True, request=request | ||||
|             ) | ||||
|         except UserAlreadyExists: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_400_BAD_REQUEST, | ||||
|  | ||||
| @ -40,7 +40,7 @@ RESET_PASSWORD_RESPONSES: OpenAPIResponseType = { | ||||
|  | ||||
|  | ||||
| def get_reset_password_router( | ||||
|     get_user_manager: UserManagerDependency[models.UC, models.UD] | ||||
|     get_user_manager: UserManagerDependency[models.UP, models.ID], | ||||
| ) -> APIRouter: | ||||
|     """Generate a router with the reset password routes.""" | ||||
|     router = APIRouter() | ||||
| @ -53,7 +53,7 @@ def get_reset_password_router( | ||||
|     async def forgot_password( | ||||
|         request: Request, | ||||
|         email: EmailStr = Body(..., embed=True), | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ): | ||||
|         try: | ||||
|             user = await user_manager.get_by_email(email) | ||||
| @ -76,7 +76,7 @@ def get_reset_password_router( | ||||
|         request: Request, | ||||
|         token: str = Body(...), | ||||
|         password: str = Body(...), | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ): | ||||
|         try: | ||||
|             await user_manager.reset_password(token, password, request) | ||||
|  | ||||
| @ -1,12 +1,12 @@ | ||||
| from typing import Type | ||||
| from typing import Any, Type | ||||
|  | ||||
| from fastapi import APIRouter, Depends, HTTPException, Request, Response, status | ||||
| from pydantic import UUID4 | ||||
|  | ||||
| from fastapi_users import models | ||||
| from fastapi_users import models, schemas | ||||
| from fastapi_users.authentication import Authenticator | ||||
| from fastapi_users.manager import ( | ||||
|     BaseUserManager, | ||||
|     InvalidID, | ||||
|     InvalidPasswordException, | ||||
|     UserAlreadyExists, | ||||
|     UserManagerDependency, | ||||
| @ -16,10 +16,9 @@ from fastapi_users.router.common import ErrorCode, ErrorModel | ||||
|  | ||||
|  | ||||
| def get_users_router( | ||||
|     get_user_manager: UserManagerDependency[models.UC, models.UD], | ||||
|     user_model: Type[models.U], | ||||
|     user_update_model: Type[models.UU], | ||||
|     user_db_model: Type[models.UD], | ||||
|     get_user_manager: UserManagerDependency[models.UP, models.ID], | ||||
|     user_schema: Type[schemas.U], | ||||
|     user_update_schema: Type[schemas.UU], | ||||
|     authenticator: Authenticator, | ||||
|     requires_verification: bool = False, | ||||
| ) -> APIRouter: | ||||
| @ -34,17 +33,18 @@ def get_users_router( | ||||
|     ) | ||||
|  | ||||
|     async def get_user_or_404( | ||||
|         id: UUID4, | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|     ) -> models.UD: | ||||
|         id: Any, | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ) -> models.UP: | ||||
|         try: | ||||
|             return await user_manager.get(id) | ||||
|         except UserNotExists: | ||||
|             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) | ||||
|             parsed_id = user_manager.parse_id(id) | ||||
|             return await user_manager.get(parsed_id) | ||||
|         except (UserNotExists, InvalidID) as e: | ||||
|             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) from e | ||||
|  | ||||
|     @router.get( | ||||
|         "/me", | ||||
|         response_model=user_model, | ||||
|         response_model=user_schema, | ||||
|         name="users:current_user", | ||||
|         responses={ | ||||
|             status.HTTP_401_UNAUTHORIZED: { | ||||
| @ -53,13 +53,13 @@ def get_users_router( | ||||
|         }, | ||||
|     ) | ||||
|     async def me( | ||||
|         user: user_db_model = Depends(get_current_active_user),  # type: ignore | ||||
|         user: models.UP = Depends(get_current_active_user), | ||||
|     ): | ||||
|         return user | ||||
|  | ||||
|     @router.patch( | ||||
|         "/me", | ||||
|         response_model=user_model, | ||||
|         response_model=user_schema, | ||||
|         dependencies=[Depends(get_current_active_user)], | ||||
|         name="users:patch_current_user", | ||||
|         responses={ | ||||
| @ -95,9 +95,9 @@ def get_users_router( | ||||
|     ) | ||||
|     async def update_me( | ||||
|         request: Request, | ||||
|         user_update: user_update_model,  # type: ignore | ||||
|         user: user_db_model = Depends(get_current_active_user),  # type: ignore | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         user_update: user_update_schema,  # type: ignore | ||||
|         user: models.UP = Depends(get_current_active_user), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ): | ||||
|         try: | ||||
|             return await user_manager.update( | ||||
| @ -118,8 +118,8 @@ def get_users_router( | ||||
|             ) | ||||
|  | ||||
|     @router.get( | ||||
|         "/{id:uuid}", | ||||
|         response_model=user_model, | ||||
|         "/{id}", | ||||
|         response_model=user_schema, | ||||
|         dependencies=[Depends(get_current_superuser)], | ||||
|         name="users:user", | ||||
|         responses={ | ||||
| @ -138,8 +138,8 @@ def get_users_router( | ||||
|         return user | ||||
|  | ||||
|     @router.patch( | ||||
|         "/{id:uuid}", | ||||
|         response_model=user_model, | ||||
|         "/{id}", | ||||
|         response_model=user_schema, | ||||
|         dependencies=[Depends(get_current_superuser)], | ||||
|         name="users:patch_user", | ||||
|         responses={ | ||||
| @ -180,10 +180,10 @@ def get_users_router( | ||||
|         }, | ||||
|     ) | ||||
|     async def update_user( | ||||
|         user_update: user_update_model,  # type: ignore | ||||
|         user_update: user_update_schema,  # type: ignore | ||||
|         request: Request, | ||||
|         user=Depends(get_user_or_404), | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ): | ||||
|         try: | ||||
|             return await user_manager.update( | ||||
| @ -204,7 +204,7 @@ def get_users_router( | ||||
|             ) | ||||
|  | ||||
|     @router.delete( | ||||
|         "/{id:uuid}", | ||||
|         "/{id}", | ||||
|         status_code=status.HTTP_204_NO_CONTENT, | ||||
|         response_class=Response, | ||||
|         dependencies=[Depends(get_current_superuser)], | ||||
| @ -223,7 +223,7 @@ def get_users_router( | ||||
|     ) | ||||
|     async def delete_user( | ||||
|         user=Depends(get_user_or_404), | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ): | ||||
|         await user_manager.delete(user) | ||||
|         return None | ||||
|  | ||||
| @ -3,7 +3,7 @@ from typing import Type | ||||
| from fastapi import APIRouter, Body, Depends, HTTPException, Request, status | ||||
| from pydantic import EmailStr | ||||
|  | ||||
| from fastapi_users import models | ||||
| from fastapi_users import models, schemas | ||||
| from fastapi_users.manager import ( | ||||
|     BaseUserManager, | ||||
|     InvalidVerifyToken, | ||||
| @ -16,8 +16,8 @@ from fastapi_users.router.common import ErrorCode, ErrorModel | ||||
|  | ||||
|  | ||||
| def get_verify_router( | ||||
|     get_user_manager: UserManagerDependency[models.UC, models.UD], | ||||
|     user_model: Type[models.U], | ||||
|     get_user_manager: UserManagerDependency[models.UP, models.ID], | ||||
|     user_schema: Type[schemas.U], | ||||
| ): | ||||
|     router = APIRouter() | ||||
|  | ||||
| @ -29,7 +29,7 @@ def get_verify_router( | ||||
|     async def request_verify_token( | ||||
|         request: Request, | ||||
|         email: EmailStr = Body(..., embed=True), | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ): | ||||
|         try: | ||||
|             user = await user_manager.get_by_email(email) | ||||
| @ -41,7 +41,7 @@ def get_verify_router( | ||||
|  | ||||
|     @router.post( | ||||
|         "/verify", | ||||
|         response_model=user_model, | ||||
|         response_model=user_schema, | ||||
|         name="verify:verify", | ||||
|         responses={ | ||||
|             status.HTTP_400_BAD_REQUEST: { | ||||
| @ -69,7 +69,7 @@ def get_verify_router( | ||||
|     async def verify( | ||||
|         request: Request, | ||||
|         token: str = Body(..., embed=True), | ||||
|         user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), | ||||
|         user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), | ||||
|     ): | ||||
|         try: | ||||
|             return await user_manager.verify(token, request) | ||||
|  | ||||
							
								
								
									
										74
									
								
								fastapi_users/schemas.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								fastapi_users/schemas.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,74 @@ | ||||
| from typing import Generic, List, Optional, TypeVar | ||||
|  | ||||
| from pydantic import BaseModel, EmailStr | ||||
|  | ||||
| from fastapi_users import models | ||||
|  | ||||
|  | ||||
| class CreateUpdateDictModel(BaseModel): | ||||
|     def create_update_dict(self): | ||||
|         return self.dict( | ||||
|             exclude_unset=True, | ||||
|             exclude={ | ||||
|                 "id", | ||||
|                 "is_superuser", | ||||
|                 "is_active", | ||||
|                 "is_verified", | ||||
|                 "oauth_accounts", | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def create_update_dict_superuser(self): | ||||
|         return self.dict(exclude_unset=True, exclude={"id"}) | ||||
|  | ||||
|  | ||||
| class BaseUser(Generic[models.ID], CreateUpdateDictModel): | ||||
|     """Base User model.""" | ||||
|  | ||||
|     id: models.ID | ||||
|     email: EmailStr | ||||
|     is_active: bool = True | ||||
|     is_superuser: bool = False | ||||
|     is_verified: bool = False | ||||
|  | ||||
|  | ||||
| class BaseUserCreate(CreateUpdateDictModel): | ||||
|     email: EmailStr | ||||
|     password: str | ||||
|     is_active: Optional[bool] = True | ||||
|     is_superuser: Optional[bool] = False | ||||
|     is_verified: Optional[bool] = False | ||||
|  | ||||
|  | ||||
| class BaseUserUpdate(CreateUpdateDictModel): | ||||
|     password: Optional[str] | ||||
|     email: Optional[EmailStr] | ||||
|     is_active: Optional[bool] | ||||
|     is_superuser: Optional[bool] | ||||
|     is_verified: Optional[bool] | ||||
|  | ||||
|  | ||||
| U = TypeVar("U", bound=BaseUser) | ||||
| UC = TypeVar("UC", bound=BaseUserCreate) | ||||
| UU = TypeVar("UU", bound=BaseUserUpdate) | ||||
|  | ||||
|  | ||||
| class BaseOAuthAccount(Generic[models.ID], BaseModel): | ||||
|     """Base OAuth account model.""" | ||||
|  | ||||
|     id: models.ID | ||||
|     oauth_name: str | ||||
|     access_token: str | ||||
|     expires_at: Optional[int] = None | ||||
|     refresh_token: Optional[str] = None | ||||
|     account_id: str | ||||
|     account_email: str | ||||
|  | ||||
|     class Config: | ||||
|         orm_mode = True | ||||
|  | ||||
|  | ||||
| class BaseOAuthAccountMixin(BaseModel): | ||||
|     """Adds OAuth accounts list to a User model.""" | ||||
|  | ||||
|     oauth_accounts: List[BaseOAuthAccount] = [] | ||||
		Reference in New Issue
	
	Block a user
	 François Voron
					François Voron