mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-16 03:40:23 +08:00
Make ID a generic instead of forcing UUIDs
This commit is contained in:
@ -51,7 +51,7 @@ class Authenticator:
|
||||
def __init__(
|
||||
self,
|
||||
backends: Sequence[AuthenticationBackend],
|
||||
get_user_manager: UserManagerDependency[models.UP],
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
):
|
||||
self.backends = backends
|
||||
self.get_user_manager = get_user_manager
|
||||
@ -148,7 +148,7 @@ class Authenticator:
|
||||
async def _authenticate(
|
||||
self,
|
||||
*args,
|
||||
user_manager: BaseUserManager[models.UP],
|
||||
user_manager: BaseUserManager[models.UP, models.ID],
|
||||
optional: bool = False,
|
||||
active: bool = False,
|
||||
verified: bool = False,
|
||||
@ -163,7 +163,7 @@ class Authenticator:
|
||||
for backend in self.backends:
|
||||
if backend in enabled_backends:
|
||||
token = kwargs[name_to_variable_name(backend.name)]
|
||||
strategy: Strategy[models.UP] = kwargs[
|
||||
strategy: Strategy[models.UP, models.ID] = kwargs[
|
||||
name_to_strategy_variable_name(backend.name)
|
||||
]
|
||||
if token is not None:
|
||||
|
@ -33,7 +33,7 @@ class AuthenticationBackend(Generic[models.UP]):
|
||||
self,
|
||||
name: str,
|
||||
transport: Transport,
|
||||
get_strategy: DependencyCallable[Strategy[models.UP]],
|
||||
get_strategy: DependencyCallable[Strategy[models.UP, models.ID]],
|
||||
):
|
||||
self.name = name
|
||||
self.transport = transport
|
||||
@ -41,7 +41,7 @@ class AuthenticationBackend(Generic[models.UP]):
|
||||
|
||||
async def login(
|
||||
self,
|
||||
strategy: Strategy[models.UP],
|
||||
strategy: Strategy[models.UP, models.ID],
|
||||
user: models.UP,
|
||||
response: Response,
|
||||
) -> Any:
|
||||
@ -50,7 +50,7 @@ class AuthenticationBackend(Generic[models.UP]):
|
||||
|
||||
async def logout(
|
||||
self,
|
||||
strategy: Strategy[models.UP],
|
||||
strategy: Strategy[models.UP, models.ID],
|
||||
user: models.UP,
|
||||
token: str,
|
||||
response: Response,
|
||||
|
@ -14,9 +14,9 @@ class StrategyDestroyNotSupportedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Strategy(Protocol, Generic[models.UP]):
|
||||
class Strategy(Protocol, Generic[models.UP, models.ID]):
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
|
||||
) -> Optional[models.UP]:
|
||||
... # pragma: no cover
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TypeVar
|
||||
|
||||
@ -8,12 +7,14 @@ if sys.version_info < (3, 8):
|
||||
else:
|
||||
from typing import Protocol # pragma: no cover
|
||||
|
||||
from fastapi_users import models
|
||||
|
||||
class AccessTokenProtocol(Protocol):
|
||||
|
||||
class AccessTokenProtocol(Protocol[models.ID]):
|
||||
"""Access token protocol that ORM model should follow."""
|
||||
|
||||
token: str
|
||||
user_id: uuid.UUID
|
||||
user_id: models.ID
|
||||
created_at: datetime
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
|
@ -6,7 +6,7 @@ from fastapi_users import models
|
||||
from fastapi_users.authentication.strategy.base import Strategy
|
||||
from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db.models import AP
|
||||
from fastapi_users.manager import BaseUserManager, UserNotExists
|
||||
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
|
||||
|
||||
|
||||
class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
|
||||
@ -17,7 +17,7 @@ class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
|
||||
) -> Optional[models.UP]:
|
||||
if token is None:
|
||||
return None
|
||||
@ -33,9 +33,9 @@ class DatabaseStrategy(Strategy, Generic[models.UP, AP]):
|
||||
return None
|
||||
|
||||
try:
|
||||
user_id = access_token.user_id
|
||||
return await user_manager.get(user_id)
|
||||
except UserNotExists:
|
||||
parsed_id = user_manager.parse_id(access_token.user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (UserNotExists, InvalidID):
|
||||
return None
|
||||
|
||||
async def write_token(self, user: models.UP) -> str:
|
||||
|
@ -1,7 +1,6 @@
|
||||
from typing import Generic, List, Optional
|
||||
|
||||
import jwt
|
||||
from pydantic import UUID4
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication.strategy.base import (
|
||||
@ -9,7 +8,7 @@ from fastapi_users.authentication.strategy.base import (
|
||||
StrategyDestroyNotSupportedError,
|
||||
)
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import BaseUserManager, UserNotExists
|
||||
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
|
||||
|
||||
|
||||
class JWTStrategy(Strategy, Generic[models.UP]):
|
||||
@ -36,7 +35,7 @@ class JWTStrategy(Strategy, Generic[models.UP]):
|
||||
return self.public_key or self.secret
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
|
||||
) -> Optional[models.UP]:
|
||||
if token is None:
|
||||
return None
|
||||
@ -52,11 +51,9 @@ class JWTStrategy(Strategy, Generic[models.UP]):
|
||||
return None
|
||||
|
||||
try:
|
||||
user_uiid = UUID4(user_id)
|
||||
return await user_manager.get(user_uiid)
|
||||
except ValueError:
|
||||
return None
|
||||
except UserNotExists:
|
||||
parsed_id = user_manager.parse_id(user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (UserNotExists, InvalidID):
|
||||
return None
|
||||
|
||||
async def write_token(self, user: models.UP) -> str:
|
||||
|
@ -2,11 +2,10 @@ import secrets
|
||||
from typing import Generic, Optional
|
||||
|
||||
import aioredis
|
||||
from pydantic import UUID4
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication.strategy.base import Strategy
|
||||
from fastapi_users.manager import BaseUserManager, UserNotExists
|
||||
from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists
|
||||
|
||||
|
||||
class RedisStrategy(Strategy, Generic[models.UP]):
|
||||
@ -15,7 +14,7 @@ class RedisStrategy(Strategy, Generic[models.UP]):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
|
||||
) -> Optional[models.UP]:
|
||||
if token is None:
|
||||
return None
|
||||
@ -25,11 +24,9 @@ class RedisStrategy(Strategy, Generic[models.UP]):
|
||||
return None
|
||||
|
||||
try:
|
||||
user_uiid = UUID4(user_id)
|
||||
return await user_manager.get(user_uiid)
|
||||
except ValueError:
|
||||
return None
|
||||
except UserNotExists:
|
||||
parsed_id = user_manager.parse_id(user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (UserNotExists, InvalidID):
|
||||
return None
|
||||
|
||||
async def write_token(self, user: models.UP) -> str:
|
||||
|
@ -1,9 +1,6 @@
|
||||
from fastapi_users.db.base import BaseUserDatabase, UserDatabaseDependency
|
||||
|
||||
__all__ = [
|
||||
"BaseUserDatabase",
|
||||
"UserDatabaseDependency",
|
||||
]
|
||||
__all__ = ["BaseUserDatabase", "UserDatabaseDependency"]
|
||||
|
||||
try: # pragma: no cover
|
||||
from fastapi_users_db_mongodb import MongoDBUserDatabase # noqa: F401
|
||||
|
@ -1,15 +1,13 @@
|
||||
from typing import Any, Dict, Generic, Optional
|
||||
|
||||
from pydantic import UUID4
|
||||
|
||||
from fastapi_users.models import OAP, UP
|
||||
from fastapi_users.models import ID, OAP, UP
|
||||
from fastapi_users.types import DependencyCallable
|
||||
|
||||
|
||||
class BaseUserDatabase(Generic[UP]):
|
||||
class BaseUserDatabase(Generic[UP, ID]):
|
||||
"""Base adapter for retrieving, creating and updating users from a database."""
|
||||
|
||||
async def get(self, id: UUID4) -> Optional[UP]:
|
||||
async def get(self, id: ID) -> Optional[UP]:
|
||||
"""Get a single user by id."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -44,4 +42,4 @@ class BaseUserDatabase(Generic[UP]):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP]]
|
||||
UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP, ID]]
|
||||
|
@ -44,7 +44,7 @@ class FastAPIUsers(Generic[models.UP, schemas.U, schemas.UC, schemas.UU]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
get_user_manager: UserManagerDependency[models.UP],
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
auth_backends: Sequence[AuthenticationBackend],
|
||||
user_schema: Type[schemas.U],
|
||||
user_create_schema: Type[schemas.UC],
|
||||
|
@ -1,9 +1,9 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, Generic, Optional, Union
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import UUID4
|
||||
|
||||
from fastapi_users import models, schemas
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
@ -19,6 +19,10 @@ class FastAPIUsersException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidID(FastAPIUsersException):
|
||||
pass
|
||||
|
||||
|
||||
class UserAlreadyExists(FastAPIUsersException):
|
||||
pass
|
||||
|
||||
@ -48,7 +52,7 @@ class InvalidPasswordException(FastAPIUsersException):
|
||||
self.reason = reason
|
||||
|
||||
|
||||
class BaseUserManager(Generic[models.UP]):
|
||||
class BaseUserManager(Generic[models.UP, models.ID]):
|
||||
"""
|
||||
User management logic.
|
||||
|
||||
@ -70,12 +74,12 @@ class BaseUserManager(Generic[models.UP]):
|
||||
verification_token_lifetime_seconds: int = 3600
|
||||
verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE
|
||||
|
||||
user_db: BaseUserDatabase[models.UP]
|
||||
user_db: BaseUserDatabase[models.UP, models.ID]
|
||||
password_helper: PasswordHelperProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_db: BaseUserDatabase[models.UP],
|
||||
user_db: BaseUserDatabase[models.UP, models.ID],
|
||||
password_helper: Optional[PasswordHelperProtocol] = None,
|
||||
):
|
||||
self.user_db = user_db
|
||||
@ -84,7 +88,17 @@ class BaseUserManager(Generic[models.UP]):
|
||||
else:
|
||||
self.password_helper = password_helper # pragma: no cover
|
||||
|
||||
async def get(self, id: UUID4) -> models.UP:
|
||||
def parse_id(self, value: Any) -> models.ID:
|
||||
"""
|
||||
Parse a value into a correct models.ID instance.
|
||||
|
||||
:param value: The value to parse.
|
||||
:raises InvalidID: The models.ID value is invalid.
|
||||
:return: An models.ID object.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def get(self, id: models.ID) -> models.UP:
|
||||
"""
|
||||
Get a user by id.
|
||||
|
||||
@ -170,7 +184,7 @@ class BaseUserManager(Generic[models.UP]):
|
||||
return created_user
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP]",
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
oauth_name: str,
|
||||
access_token: str,
|
||||
account_id: str,
|
||||
@ -192,7 +206,7 @@ class BaseUserManager(Generic[models.UP]):
|
||||
|
||||
:param oauth_name: Name of the OAuth client.
|
||||
:param access_token: Valid access token for the service provider.
|
||||
:param account_id: ID of the user on the service provider.
|
||||
:param account_id: models.ID of the user on the service provider.
|
||||
:param account_email: E-mail of the user on the service provider.
|
||||
:param expires_at: Optional timestamp at which the access token expires.
|
||||
:param refresh_token: Optional refresh token to get a
|
||||
@ -307,11 +321,11 @@ class BaseUserManager(Generic[models.UP]):
|
||||
raise InvalidVerifyToken()
|
||||
|
||||
try:
|
||||
user_uuid = UUID4(user_id)
|
||||
except ValueError:
|
||||
parsed_id = self.parse_id(user_id)
|
||||
except InvalidID:
|
||||
raise InvalidVerifyToken()
|
||||
|
||||
if user_uuid != user.id:
|
||||
if parsed_id != user.id:
|
||||
raise InvalidVerifyToken()
|
||||
|
||||
if user.is_verified:
|
||||
@ -382,11 +396,11 @@ class BaseUserManager(Generic[models.UP]):
|
||||
raise InvalidResetPasswordToken()
|
||||
|
||||
try:
|
||||
user_uuid = UUID4(user_id)
|
||||
except ValueError:
|
||||
parsed_id = self.parse_id(user_id)
|
||||
except InvalidID:
|
||||
raise InvalidResetPasswordToken()
|
||||
|
||||
user = await self.get(user_uuid)
|
||||
user = await self.get(parsed_id)
|
||||
|
||||
if not user.is_active:
|
||||
raise UserInactive()
|
||||
@ -588,4 +602,14 @@ class BaseUserManager(Generic[models.UP]):
|
||||
return await self.user_db.update(user, validated_update_dict)
|
||||
|
||||
|
||||
UserManagerDependency = DependencyCallable[BaseUserManager[models.UP]]
|
||||
class UUIDIDMixin:
|
||||
def parse_id(self, value: Any) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
try:
|
||||
return uuid.UUID(value)
|
||||
except ValueError as e:
|
||||
raise InvalidID() from e
|
||||
|
||||
|
||||
UserManagerDependency = DependencyCallable[BaseUserManager[models.UP, models.ID]]
|
||||
|
@ -1,5 +1,4 @@
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Generic, List, Optional, TypeVar
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
@ -7,11 +6,13 @@ if sys.version_info < (3, 8):
|
||||
else:
|
||||
from typing import Protocol # pragma: no cover
|
||||
|
||||
ID = TypeVar("ID")
|
||||
|
||||
class UserProtocol(Protocol):
|
||||
|
||||
class UserProtocol(Protocol[ID]):
|
||||
"""User protocol that ORM model should follow."""
|
||||
|
||||
id: uuid.UUID
|
||||
id: ID
|
||||
email: str
|
||||
hashed_password: str
|
||||
is_active: bool
|
||||
@ -22,10 +23,10 @@ class UserProtocol(Protocol):
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
class OAuthAccountProtocol(Protocol):
|
||||
class OAuthAccountProtocol(Protocol[ID]):
|
||||
"""OAuth account protocol that ORM model should follow."""
|
||||
|
||||
id: uuid.UUID
|
||||
id: ID
|
||||
oauth_name: str
|
||||
access_token: str
|
||||
expires_at: Optional[int]
|
||||
@ -41,7 +42,7 @@ UP = TypeVar("UP", bound=UserProtocol)
|
||||
OAP = TypeVar("OAP", bound=OAuthAccountProtocol)
|
||||
|
||||
|
||||
class UserOAuthProtocol(UserProtocol, Generic[OAP]):
|
||||
class UserOAuthProtocol(UserProtocol[ID], Generic[ID, OAP]):
|
||||
"""User protocol including a list of OAuth accounts."""
|
||||
|
||||
oauth_accounts: List[OAP]
|
||||
|
@ -12,7 +12,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
|
||||
|
||||
def get_auth_router(
|
||||
backend: AuthenticationBackend,
|
||||
get_user_manager: UserManagerDependency[models.UP],
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
authenticator: Authenticator,
|
||||
requires_verification: bool = False,
|
||||
) -> APIRouter:
|
||||
@ -51,8 +51,8 @@ def get_auth_router(
|
||||
async def login(
|
||||
response: Response,
|
||||
credentials: OAuth2PasswordRequestForm = Depends(),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP] = Depends(backend.get_strategy),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
):
|
||||
user = await user_manager.authenticate(credentials)
|
||||
|
||||
@ -83,7 +83,7 @@ def get_auth_router(
|
||||
async def logout(
|
||||
response: Response,
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP] = Depends(backend.get_strategy),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
):
|
||||
user, token = user_token
|
||||
return await backend.logout(strategy, user, token, response)
|
||||
|
@ -29,7 +29,7 @@ def generate_state_token(
|
||||
def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
get_user_manager: UserManagerDependency[models.UP],
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
state_secret: SecretType,
|
||||
redirect_url: str = None,
|
||||
) -> APIRouter:
|
||||
@ -101,8 +101,8 @@ def get_oauth_router(
|
||||
access_token_state: Tuple[OAuth2Token, str] = Depends(
|
||||
oauth2_authorize_callback
|
||||
),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP] = Depends(backend.get_strategy),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
):
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
|
@ -13,7 +13,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
|
||||
|
||||
|
||||
def get_register_router(
|
||||
get_user_manager: UserManagerDependency[models.UP],
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
user_schema: Type[schemas.U],
|
||||
user_create_schema: Type[schemas.UC],
|
||||
) -> APIRouter:
|
||||
@ -56,7 +56,7 @@ def get_register_router(
|
||||
async def register(
|
||||
request: Request,
|
||||
user_create: user_create_schema, # type: ignore
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
try:
|
||||
created_user = await user_manager.create(
|
||||
|
@ -40,7 +40,7 @@ RESET_PASSWORD_RESPONSES: OpenAPIResponseType = {
|
||||
|
||||
|
||||
def get_reset_password_router(
|
||||
get_user_manager: UserManagerDependency[models.UP],
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the reset password routes."""
|
||||
router = APIRouter()
|
||||
@ -53,7 +53,7 @@ def get_reset_password_router(
|
||||
async def forgot_password(
|
||||
request: Request,
|
||||
email: EmailStr = Body(..., embed=True),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
try:
|
||||
user = await user_manager.get_by_email(email)
|
||||
@ -76,7 +76,7 @@ def get_reset_password_router(
|
||||
request: Request,
|
||||
token: str = Body(...),
|
||||
password: str = Body(...),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
try:
|
||||
await user_manager.reset_password(token, password, request)
|
||||
|
@ -1,12 +1,12 @@
|
||||
from typing import Type
|
||||
from typing import Any, Type
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from pydantic import UUID4
|
||||
|
||||
from fastapi_users import models, schemas
|
||||
from fastapi_users.authentication import Authenticator
|
||||
from fastapi_users.manager import (
|
||||
BaseUserManager,
|
||||
InvalidID,
|
||||
InvalidPasswordException,
|
||||
UserAlreadyExists,
|
||||
UserManagerDependency,
|
||||
@ -16,7 +16,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
|
||||
|
||||
|
||||
def get_users_router(
|
||||
get_user_manager: UserManagerDependency[models.UP],
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
user_schema: Type[schemas.U],
|
||||
user_update_schema: Type[schemas.UU],
|
||||
authenticator: Authenticator,
|
||||
@ -33,13 +33,14 @@ def get_users_router(
|
||||
)
|
||||
|
||||
async def get_user_or_404(
|
||||
id: UUID4,
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
id: Any,
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
) -> models.UP:
|
||||
try:
|
||||
return await user_manager.get(id)
|
||||
except UserNotExists:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
parsed_id = user_manager.parse_id(id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (UserNotExists, InvalidID) as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) from e
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
@ -96,7 +97,7 @@ def get_users_router(
|
||||
request: Request,
|
||||
user_update: user_update_schema, # type: ignore
|
||||
user: models.UP = Depends(get_current_active_user),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
try:
|
||||
return await user_manager.update(
|
||||
@ -117,7 +118,7 @@ def get_users_router(
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/{id:uuid}",
|
||||
"/{id}",
|
||||
response_model=user_schema,
|
||||
dependencies=[Depends(get_current_superuser)],
|
||||
name="users:user",
|
||||
@ -137,7 +138,7 @@ def get_users_router(
|
||||
return user
|
||||
|
||||
@router.patch(
|
||||
"/{id:uuid}",
|
||||
"/{id}",
|
||||
response_model=user_schema,
|
||||
dependencies=[Depends(get_current_superuser)],
|
||||
name="users:patch_user",
|
||||
@ -182,7 +183,7 @@ def get_users_router(
|
||||
user_update: user_update_schema, # type: ignore
|
||||
request: Request,
|
||||
user=Depends(get_user_or_404),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
try:
|
||||
return await user_manager.update(
|
||||
@ -203,7 +204,7 @@ def get_users_router(
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{id:uuid}",
|
||||
"/{id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
response_class=Response,
|
||||
dependencies=[Depends(get_current_superuser)],
|
||||
@ -222,7 +223,7 @@ def get_users_router(
|
||||
)
|
||||
async def delete_user(
|
||||
user=Depends(get_user_or_404),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
await user_manager.delete(user)
|
||||
return None
|
||||
|
@ -16,7 +16,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel
|
||||
|
||||
|
||||
def get_verify_router(
|
||||
get_user_manager: UserManagerDependency[models.UP],
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
user_schema: Type[schemas.U],
|
||||
):
|
||||
router = APIRouter()
|
||||
@ -29,7 +29,7 @@ def get_verify_router(
|
||||
async def request_verify_token(
|
||||
request: Request,
|
||||
email: EmailStr = Body(..., embed=True),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
try:
|
||||
user = await user_manager.get_by_email(email)
|
||||
@ -69,7 +69,7 @@ def get_verify_router(
|
||||
async def verify(
|
||||
request: Request,
|
||||
token: str = Body(..., embed=True),
|
||||
user_manager: BaseUserManager[models.UP] = Depends(get_user_manager),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
try:
|
||||
return await user_manager.verify(token, request)
|
||||
|
@ -1,7 +1,8 @@
|
||||
import uuid
|
||||
from typing import List, Optional, TypeVar
|
||||
from typing import Generic, List, Optional, TypeVar
|
||||
|
||||
from pydantic import UUID4, BaseModel, EmailStr, Field
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from fastapi_users import models
|
||||
|
||||
|
||||
class CreateUpdateDictModel(BaseModel):
|
||||
@ -21,10 +22,10 @@ class CreateUpdateDictModel(BaseModel):
|
||||
return self.dict(exclude_unset=True, exclude={"id"})
|
||||
|
||||
|
||||
class BaseUser(CreateUpdateDictModel):
|
||||
class BaseUser(Generic[models.ID], CreateUpdateDictModel):
|
||||
"""Base User model."""
|
||||
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4)
|
||||
id: models.ID
|
||||
email: EmailStr
|
||||
is_active: bool = True
|
||||
is_superuser: bool = False
|
||||
@ -52,10 +53,10 @@ UC = TypeVar("UC", bound=BaseUserCreate)
|
||||
UU = TypeVar("UU", bound=BaseUserUpdate)
|
||||
|
||||
|
||||
class BaseOAuthAccount(BaseModel):
|
||||
class BaseOAuthAccount(Generic[models.ID], BaseModel):
|
||||
"""Base OAuth account model."""
|
||||
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4)
|
||||
id: models.ID
|
||||
oauth_name: str
|
||||
access_token: str
|
||||
expires_at: Optional[int] = None
|
||||
|
@ -29,8 +29,10 @@ from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType
|
||||
from fastapi_users.manager import (
|
||||
BaseUserManager,
|
||||
InvalidID,
|
||||
InvalidPasswordException,
|
||||
UserNotExists,
|
||||
UUIDIDMixin,
|
||||
)
|
||||
from fastapi_users.openapi import OpenAPIResponseType
|
||||
from fastapi_users.password import PasswordHelper
|
||||
@ -43,8 +45,11 @@ lancelot_password_hash = password_helper.hash("lancelot")
|
||||
excalibur_password_hash = password_helper.hash("excalibur")
|
||||
|
||||
|
||||
IDType = uuid.UUID
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UserModel(models.UserProtocol):
|
||||
class UserModel(models.UserProtocol[IDType]):
|
||||
email: str
|
||||
hashed_password: str
|
||||
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
|
||||
@ -55,7 +60,7 @@ class UserModel(models.UserProtocol):
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class OAuthAccountModel(models.OAuthAccountProtocol):
|
||||
class OAuthAccountModel(models.OAuthAccountProtocol[IDType]):
|
||||
oauth_name: str
|
||||
access_token: str
|
||||
account_id: str
|
||||
@ -86,7 +91,9 @@ class UserOAuth(User, schemas.BaseOAuthAccountMixin):
|
||||
pass
|
||||
|
||||
|
||||
class BaseTestUserManager(Generic[models.UP], BaseUserManager[models.UP]):
|
||||
class BaseTestUserManager(
|
||||
Generic[models.UP], UUIDIDMixin, BaseUserManager[models.UP, IDType]
|
||||
):
|
||||
reset_password_token_secret = "SECRET"
|
||||
verification_token_secret = "SECRET"
|
||||
|
||||
@ -318,8 +325,8 @@ def mock_user_db(
|
||||
inactive_user: UserModel,
|
||||
superuser: UserModel,
|
||||
verified_superuser: UserModel,
|
||||
) -> BaseUserDatabase[UserModel]:
|
||||
class MockUserDatabase(BaseUserDatabase[UserModel]):
|
||||
) -> BaseUserDatabase[UserModel, IDType]:
|
||||
class MockUserDatabase(BaseUserDatabase[UserModel, IDType]):
|
||||
async def get(self, id: UUID4) -> Optional[UserModel]:
|
||||
if id == user.id:
|
||||
return user
|
||||
@ -370,8 +377,8 @@ def mock_user_db_oauth(
|
||||
inactive_user_oauth: UserOAuthModel,
|
||||
superuser_oauth: UserOAuthModel,
|
||||
verified_superuser_oauth: UserOAuthModel,
|
||||
) -> BaseUserDatabase[UserOAuthModel]:
|
||||
class MockUserDatabase(BaseUserDatabase[UserOAuthModel]):
|
||||
) -> BaseUserDatabase[UserOAuthModel, IDType]:
|
||||
class MockUserDatabase(BaseUserDatabase[UserOAuthModel, IDType]):
|
||||
async def get(self, id: UUID4) -> Optional[UserOAuthModel]:
|
||||
if id == user_oauth.id:
|
||||
return user_oauth
|
||||
@ -518,17 +525,15 @@ class MockTransport(BearerTransport):
|
||||
return {}
|
||||
|
||||
|
||||
class MockStrategy(Strategy[UserModel]):
|
||||
class MockStrategy(Strategy[UserModel, IDType]):
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[UserModel]
|
||||
self, token: Optional[str], user_manager: BaseUserManager[UserModel, IDType]
|
||||
) -> Optional[UserModel]:
|
||||
if token is not None:
|
||||
try:
|
||||
token_uuid = UUID4(token)
|
||||
return await user_manager.get(token_uuid)
|
||||
except ValueError:
|
||||
return None
|
||||
except UserNotExists:
|
||||
parsed_id = user_manager.parse_id(token)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (InvalidID, UserNotExists):
|
||||
return None
|
||||
return None
|
||||
|
||||
|
@ -29,7 +29,7 @@ class MockTransport(Transport):
|
||||
|
||||
class NoneStrategy(Strategy):
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
|
||||
) -> Optional[models.UP]:
|
||||
return None
|
||||
|
||||
@ -39,7 +39,7 @@ class UserStrategy(Strategy, Generic[models.UP]):
|
||||
self.user = user
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
|
||||
) -> Optional[models.UP]:
|
||||
return self.user
|
||||
|
||||
|
@ -21,7 +21,7 @@ class MockTransportLogoutNotSupported(BearerTransport):
|
||||
|
||||
class MockStrategyDestroyNotSupported(Strategy, Generic[models.UP]):
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP]
|
||||
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
|
||||
) -> Optional[models.UP]:
|
||||
return None
|
||||
|
||||
|
@ -3,7 +3,7 @@ import uuid
|
||||
import pytest
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from tests.conftest import OAuthAccountModel, UserModel
|
||||
from tests.conftest import IDType, OAuthAccountModel, UserModel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -11,7 +11,7 @@ from tests.conftest import OAuthAccountModel, UserModel
|
||||
async def test_not_implemented_methods(
|
||||
user: UserModel, oauth_account1: OAuthAccountModel
|
||||
):
|
||||
base_user_db = BaseUserDatabase[UserModel]()
|
||||
base_user_db = BaseUserDatabase[UserModel, IDType]()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await base_user_db.get(uuid.uuid4())
|
||||
|
@ -32,13 +32,14 @@ async def test_app_client(
|
||||
app.include_router(
|
||||
fastapi_users.get_oauth_router(oauth_client, mock_authentication, secret)
|
||||
)
|
||||
app.include_router(fastapi_users.get_users_router(), prefix="/users")
|
||||
app.include_router(fastapi_users.get_verify_router())
|
||||
|
||||
@app.delete("/users/me")
|
||||
def custom_users_route():
|
||||
return None
|
||||
|
||||
app.include_router(fastapi_users.get_users_router(), prefix="/users")
|
||||
app.include_router(fastapi_users.get_verify_router())
|
||||
|
||||
@app.get("/current-user", response_model=User)
|
||||
def current_user(user: UserModel = Depends(fastapi_users.current_user())):
|
||||
return user
|
||||
|
@ -1,5 +1,3 @@
|
||||
import copy
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
@ -18,7 +16,6 @@ from fastapi_users.manager import (
|
||||
UserNotExists,
|
||||
)
|
||||
from tests.conftest import (
|
||||
OAuthAccountModel,
|
||||
UserCreate,
|
||||
UserManagerMock,
|
||||
UserModel,
|
||||
|
Reference in New Issue
Block a user