mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-26 12:31:25 +08:00
Fix #391: put user creation logic in a importable function
This commit is contained in:
22
docs/usage/helpers.md
Normal file
22
docs/usage/helpers.md
Normal file
@ -0,0 +1,22 @@
|
||||
# Helpers
|
||||
|
||||
## Create user
|
||||
|
||||
**FastAPI Users** provides a helper function to easily create a user programmatically. They are available from your `FastAPIUsers` instance.
|
||||
|
||||
```py
|
||||
regular_user = await fastapi_users.create_user(
|
||||
UserCreate(
|
||||
email="king.arthur@camelot.bt",
|
||||
password="guinevere",
|
||||
)
|
||||
)
|
||||
|
||||
superuser = await fastapi_users.create_user(
|
||||
UserCreate(
|
||||
email="king.arthur@camelot.bt",
|
||||
password="guinevere",
|
||||
is_superuser=True,
|
||||
)
|
||||
)
|
||||
```
|
@ -11,6 +11,7 @@ from fastapi_users.router import (
|
||||
get_reset_password_router,
|
||||
get_users_router,
|
||||
)
|
||||
from fastapi_users.user import CreateUserProtocol, get_create_user
|
||||
|
||||
try:
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
@ -31,6 +32,7 @@ class FastAPIUsers:
|
||||
:param user_update_model: Pydantic model for updating a user.
|
||||
:param user_db_model: Pydantic model of a DB representation of a user.
|
||||
|
||||
:attribute create_user: Helper function to create a user programmatically.
|
||||
:attribute get_current_user: Dependency callable to inject authenticated user.
|
||||
:attribute get_current_active_user: Dependency callable to inject active user.
|
||||
:attribute get_current_superuser: Dependency callable to inject superuser.
|
||||
@ -38,6 +40,7 @@ class FastAPIUsers:
|
||||
|
||||
db: BaseUserDatabase
|
||||
authenticator: Authenticator
|
||||
create_user: CreateUserProtocol
|
||||
_user_model: Type[models.BaseUser]
|
||||
_user_create_model: Type[models.BaseUserCreate]
|
||||
_user_update_model: Type[models.BaseUserUpdate]
|
||||
@ -61,6 +64,8 @@ class FastAPIUsers:
|
||||
self._user_update_model = user_update_model
|
||||
self._user_db_model = user_db_model
|
||||
|
||||
self.create_user = get_create_user(db, user_db_model)
|
||||
|
||||
self.get_current_user = self.authenticator.get_current_user
|
||||
self.get_current_active_user = self.authenticator.get_current_active_user
|
||||
self.get_current_superuser = self.authenticator.get_current_superuser
|
||||
@ -83,10 +88,9 @@ class FastAPIUsers:
|
||||
after a successful registration.
|
||||
"""
|
||||
return get_register_router(
|
||||
self.db,
|
||||
self.create_user,
|
||||
self._user_model,
|
||||
self._user_create_model,
|
||||
self._user_db_model,
|
||||
after_register,
|
||||
)
|
||||
|
||||
|
@ -31,6 +31,8 @@ class BaseUser(CreateUpdateDictModel):
|
||||
class BaseUserCreate(CreateUpdateDictModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
is_active: Optional[bool] = True
|
||||
is_superuser: Optional[bool] = False
|
||||
|
||||
|
||||
class BaseUserUpdate(BaseUser):
|
||||
|
@ -1,18 +1,16 @@
|
||||
from typing import Callable, Optional, Type, cast
|
||||
from typing import Callable, Optional, Type
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, status
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.password import get_password_hash
|
||||
from fastapi_users.router.common import ErrorCode, run_handler
|
||||
from fastapi_users.user import CreateUserProtocol, UserAlreadyExists
|
||||
|
||||
|
||||
def get_register_router(
|
||||
user_db: BaseUserDatabase[models.BaseUserDB],
|
||||
create_user: CreateUserProtocol,
|
||||
user_model: Type[models.BaseUser],
|
||||
user_create_model: Type[models.BaseUserCreate],
|
||||
user_db_model: Type[models.BaseUserDB],
|
||||
after_register: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the register route."""
|
||||
@ -22,21 +20,14 @@ def get_register_router(
|
||||
"/register", response_model=user_model, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def register(request: Request, user: user_create_model): # type: ignore
|
||||
user = cast(models.BaseUserCreate, user) # Prevent mypy complain
|
||||
existing_user = await user_db.get_by_email(user.email)
|
||||
|
||||
if existing_user is not None:
|
||||
try:
|
||||
created_user = await create_user(user, safe=True)
|
||||
except UserAlreadyExists:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.REGISTER_USER_ALREADY_EXISTS,
|
||||
)
|
||||
|
||||
hashed_password = get_password_hash(user.password)
|
||||
db_user = user_db_model(
|
||||
**user.create_update_dict(), hashed_password=hashed_password
|
||||
)
|
||||
created_user = await user_db.create(db_user)
|
||||
|
||||
if after_register:
|
||||
await run_handler(after_register, created_user, request)
|
||||
|
||||
|
40
fastapi_users/user.py
Normal file
40
fastapi_users/user.py
Normal file
@ -0,0 +1,40 @@
|
||||
from typing import Awaitable, Type
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.password import get_password_hash
|
||||
|
||||
|
||||
class UserAlreadyExists(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CreateUserProtocol(Protocol): # pragma: no cover
|
||||
def __call__(
|
||||
self, user: models.BaseUserCreate, safe: bool = False
|
||||
) -> Awaitable[models.BaseUserDB]:
|
||||
pass
|
||||
|
||||
|
||||
def get_create_user(
|
||||
user_db: BaseUserDatabase[models.BaseUserDB],
|
||||
user_db_model: Type[models.BaseUserDB],
|
||||
) -> CreateUserProtocol:
|
||||
async def create_user(
|
||||
user: models.BaseUserCreate, safe: bool = False
|
||||
) -> models.BaseUserDB:
|
||||
existing_user = await user_db.get_by_email(user.email)
|
||||
|
||||
if existing_user is not None:
|
||||
raise UserAlreadyExists()
|
||||
|
||||
hashed_password = get_password_hash(user.password)
|
||||
user_dict = (
|
||||
user.create_update_dict() if safe else user.create_update_dict_superuser()
|
||||
)
|
||||
db_user = user_db_model(**user_dict, hashed_password=hashed_password)
|
||||
return await user_db.create(db_user)
|
||||
|
||||
return create_user
|
@ -50,6 +50,7 @@ nav:
|
||||
- usage/flow.md
|
||||
- usage/routes.md
|
||||
- usage/dependency-callables.md
|
||||
- usage/helpers.md
|
||||
- Migration:
|
||||
- migration/08_to_1x.md
|
||||
- migration/1x_to_2x.md
|
||||
|
@ -7,6 +7,7 @@ import pytest
|
||||
from fastapi import FastAPI, Request, status
|
||||
|
||||
from fastapi_users.router import ErrorCode, get_register_router
|
||||
from fastapi_users.user import get_create_user
|
||||
from tests.conftest import User, UserCreate, UserDB
|
||||
|
||||
SECRET = "SECRET"
|
||||
@ -29,13 +30,13 @@ def after_register(request):
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_client(
|
||||
mock_user_db, mock_authentication, after_register, get_test_client
|
||||
mock_user_db, after_register, get_test_client
|
||||
) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
create_user = get_create_user(mock_user_db, UserDB)
|
||||
register_router = get_register_router(
|
||||
mock_user_db,
|
||||
create_user,
|
||||
User,
|
||||
UserCreate,
|
||||
UserDB,
|
||||
after_register,
|
||||
)
|
||||
|
||||
|
47
tests/test_user.py
Normal file
47
tests/test_user.py
Normal file
@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
|
||||
from fastapi_users.user import CreateUserProtocol, UserAlreadyExists, get_create_user
|
||||
from tests.conftest import UserCreate, UserDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_user(
|
||||
mock_user_db,
|
||||
) -> CreateUserProtocol:
|
||||
return get_create_user(mock_user_db, UserDB)
|
||||
|
||||
|
||||
@pytest.mark.router
|
||||
@pytest.mark.asyncio
|
||||
class TestCreateUser:
|
||||
@pytest.mark.parametrize(
|
||||
"email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"]
|
||||
)
|
||||
async def test_existing_user(self, email, create_user):
|
||||
user = UserCreate(email=email, password="guinevere")
|
||||
with pytest.raises(UserAlreadyExists):
|
||||
await create_user(user)
|
||||
|
||||
@pytest.mark.parametrize("email", ["lancelot@camelot.bt", "Lancelot@camelot.bt"])
|
||||
async def test_regular_user(self, email, create_user):
|
||||
user = UserCreate(email=email, password="guinevere")
|
||||
created_user = await create_user(user)
|
||||
assert type(created_user) == UserDB
|
||||
|
||||
@pytest.mark.parametrize("safe,result", [(True, False), (False, True)])
|
||||
async def test_superuser(self, create_user, safe, result):
|
||||
user = UserCreate(
|
||||
email="lancelot@camelot.b", password="guinevere", is_superuser=True
|
||||
)
|
||||
created_user = await create_user(user, safe)
|
||||
assert type(created_user) == UserDB
|
||||
assert created_user.is_superuser is result
|
||||
|
||||
@pytest.mark.parametrize("safe,result", [(True, True), (False, False)])
|
||||
async def test_is_active(self, create_user, safe, result):
|
||||
user = UserCreate(
|
||||
email="lancelot@camelot.b", password="guinevere", is_active=False
|
||||
)
|
||||
created_user = await create_user(user, safe)
|
||||
assert type(created_user) == UserDB
|
||||
assert created_user.is_active is result
|
Reference in New Issue
Block a user