diff --git a/docs/usage/helpers.md b/docs/usage/helpers.md new file mode 100644 index 00000000..623dc9e9 --- /dev/null +++ b/docs/usage/helpers.md @@ -0,0 +1,22 @@ +# Helpers + +## Create user + +**FastAPI Users** provides a helper function to easily create a user programmatically. They are available from your `FastAPIUsers` instance. + +```py +regular_user = await fastapi_users.create_user( + UserCreate( + email="king.arthur@camelot.bt", + password="guinevere", + ) +) + +superuser = await fastapi_users.create_user( + UserCreate( + email="king.arthur@camelot.bt", + password="guinevere", + is_superuser=True, + ) +) +``` diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index e606ca05..51d93c4d 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -11,6 +11,7 @@ from fastapi_users.router import ( get_reset_password_router, get_users_router, ) +from fastapi_users.user import CreateUserProtocol, get_create_user try: from httpx_oauth.oauth2 import BaseOAuth2 @@ -31,6 +32,7 @@ class FastAPIUsers: :param user_update_model: Pydantic model for updating a user. :param user_db_model: Pydantic model of a DB representation of a user. + :attribute create_user: Helper function to create a user programmatically. :attribute get_current_user: Dependency callable to inject authenticated user. :attribute get_current_active_user: Dependency callable to inject active user. :attribute get_current_superuser: Dependency callable to inject superuser. @@ -38,6 +40,7 @@ class FastAPIUsers: db: BaseUserDatabase authenticator: Authenticator + create_user: CreateUserProtocol _user_model: Type[models.BaseUser] _user_create_model: Type[models.BaseUserCreate] _user_update_model: Type[models.BaseUserUpdate] @@ -61,6 +64,8 @@ class FastAPIUsers: self._user_update_model = user_update_model self._user_db_model = user_db_model + self.create_user = get_create_user(db, user_db_model) + self.get_current_user = self.authenticator.get_current_user self.get_current_active_user = self.authenticator.get_current_active_user self.get_current_superuser = self.authenticator.get_current_superuser @@ -83,10 +88,9 @@ class FastAPIUsers: after a successful registration. """ return get_register_router( - self.db, + self.create_user, self._user_model, self._user_create_model, - self._user_db_model, after_register, ) diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 833a891f..6f843c10 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -31,6 +31,8 @@ class BaseUser(CreateUpdateDictModel): class BaseUserCreate(CreateUpdateDictModel): email: EmailStr password: str + is_active: Optional[bool] = True + is_superuser: Optional[bool] = False class BaseUserUpdate(BaseUser): diff --git a/fastapi_users/router/register.py b/fastapi_users/router/register.py index 53beed01..6ce58a8c 100644 --- a/fastapi_users/router/register.py +++ b/fastapi_users/router/register.py @@ -1,18 +1,16 @@ -from typing import Callable, Optional, Type, cast +from typing import Callable, Optional, Type from fastapi import APIRouter, HTTPException, Request, status from fastapi_users import models -from fastapi_users.db import BaseUserDatabase -from fastapi_users.password import get_password_hash from fastapi_users.router.common import ErrorCode, run_handler +from fastapi_users.user import CreateUserProtocol, UserAlreadyExists def get_register_router( - user_db: BaseUserDatabase[models.BaseUserDB], + create_user: CreateUserProtocol, user_model: Type[models.BaseUser], user_create_model: Type[models.BaseUserCreate], - user_db_model: Type[models.BaseUserDB], after_register: Optional[Callable[[models.UD, Request], None]] = None, ) -> APIRouter: """Generate a router with the register route.""" @@ -22,21 +20,14 @@ def get_register_router( "/register", response_model=user_model, status_code=status.HTTP_201_CREATED ) async def register(request: Request, user: user_create_model): # type: ignore - user = cast(models.BaseUserCreate, user) # Prevent mypy complain - existing_user = await user_db.get_by_email(user.email) - - if existing_user is not None: + try: + created_user = await create_user(user, safe=True) + except UserAlreadyExists: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ErrorCode.REGISTER_USER_ALREADY_EXISTS, ) - hashed_password = get_password_hash(user.password) - db_user = user_db_model( - **user.create_update_dict(), hashed_password=hashed_password - ) - created_user = await user_db.create(db_user) - if after_register: await run_handler(after_register, created_user, request) diff --git a/fastapi_users/user.py b/fastapi_users/user.py new file mode 100644 index 00000000..86c96b4d --- /dev/null +++ b/fastapi_users/user.py @@ -0,0 +1,40 @@ +from typing import Awaitable, Type + +from typing_extensions import Protocol + +from fastapi_users import models +from fastapi_users.db import BaseUserDatabase +from fastapi_users.password import get_password_hash + + +class UserAlreadyExists(Exception): + pass + + +class CreateUserProtocol(Protocol): # pragma: no cover + def __call__( + self, user: models.BaseUserCreate, safe: bool = False + ) -> Awaitable[models.BaseUserDB]: + pass + + +def get_create_user( + user_db: BaseUserDatabase[models.BaseUserDB], + user_db_model: Type[models.BaseUserDB], +) -> CreateUserProtocol: + async def create_user( + user: models.BaseUserCreate, safe: bool = False + ) -> models.BaseUserDB: + existing_user = await user_db.get_by_email(user.email) + + if existing_user is not None: + raise UserAlreadyExists() + + hashed_password = get_password_hash(user.password) + user_dict = ( + user.create_update_dict() if safe else user.create_update_dict_superuser() + ) + db_user = user_db_model(**user_dict, hashed_password=hashed_password) + return await user_db.create(db_user) + + return create_user diff --git a/mkdocs.yml b/mkdocs.yml index 5fc97ad3..3876f1bc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -50,6 +50,7 @@ nav: - usage/flow.md - usage/routes.md - usage/dependency-callables.md + - usage/helpers.md - Migration: - migration/08_to_1x.md - migration/1x_to_2x.md diff --git a/tests/test_router_register.py b/tests/test_router_register.py index efaddbba..65fd7a5e 100644 --- a/tests/test_router_register.py +++ b/tests/test_router_register.py @@ -7,6 +7,7 @@ import pytest from fastapi import FastAPI, Request, status from fastapi_users.router import ErrorCode, get_register_router +from fastapi_users.user import get_create_user from tests.conftest import User, UserCreate, UserDB SECRET = "SECRET" @@ -29,13 +30,13 @@ def after_register(request): @pytest.fixture @pytest.mark.asyncio async def test_app_client( - mock_user_db, mock_authentication, after_register, get_test_client + mock_user_db, after_register, get_test_client ) -> AsyncGenerator[httpx.AsyncClient, None]: + create_user = get_create_user(mock_user_db, UserDB) register_router = get_register_router( - mock_user_db, + create_user, User, UserCreate, - UserDB, after_register, ) diff --git a/tests/test_user.py b/tests/test_user.py new file mode 100644 index 00000000..b21e14f3 --- /dev/null +++ b/tests/test_user.py @@ -0,0 +1,47 @@ +import pytest + +from fastapi_users.user import CreateUserProtocol, UserAlreadyExists, get_create_user +from tests.conftest import UserCreate, UserDB + + +@pytest.fixture +def create_user( + mock_user_db, +) -> CreateUserProtocol: + return get_create_user(mock_user_db, UserDB) + + +@pytest.mark.router +@pytest.mark.asyncio +class TestCreateUser: + @pytest.mark.parametrize( + "email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"] + ) + async def test_existing_user(self, email, create_user): + user = UserCreate(email=email, password="guinevere") + with pytest.raises(UserAlreadyExists): + await create_user(user) + + @pytest.mark.parametrize("email", ["lancelot@camelot.bt", "Lancelot@camelot.bt"]) + async def test_regular_user(self, email, create_user): + user = UserCreate(email=email, password="guinevere") + created_user = await create_user(user) + assert type(created_user) == UserDB + + @pytest.mark.parametrize("safe,result", [(True, False), (False, True)]) + async def test_superuser(self, create_user, safe, result): + user = UserCreate( + email="lancelot@camelot.b", password="guinevere", is_superuser=True + ) + created_user = await create_user(user, safe) + assert type(created_user) == UserDB + assert created_user.is_superuser is result + + @pytest.mark.parametrize("safe,result", [(True, True), (False, False)]) + async def test_is_active(self, create_user, safe, result): + user = UserCreate( + email="lancelot@camelot.b", password="guinevere", is_active=False + ) + created_user = await create_user(user, safe) + assert type(created_user) == UserDB + assert created_user.is_active is result