# Mirror of https://github.com/fastapi-users/fastapi-users.git
# (synced 2025-08-14 18:58:10 +08:00) - 570 lines, 17 KiB, Python.
import asyncio
|
|
import dataclasses
|
|
import uuid
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any, Callable, Generic, Optional, Union
|
|
from unittest.mock import MagicMock
|
|
|
|
import httpx
|
|
import pytest
|
|
from asgi_lifespan import LifespanManager
|
|
from fastapi import FastAPI, Response
|
|
from httpx_oauth.oauth2 import OAuth2
|
|
from pydantic import UUID4, SecretStr
|
|
from pytest_mock import MockerFixture
|
|
|
|
from fastapi_users import exceptions, models, schemas
|
|
from fastapi_users.authentication import AuthenticationBackend, BearerTransport
|
|
from fastapi_users.authentication.strategy import Strategy
|
|
from fastapi_users.db import BaseUserDatabase
|
|
from fastapi_users.jwt import SecretType
|
|
from fastapi_users.manager import BaseUserManager, UUIDIDMixin
|
|
from fastapi_users.openapi import OpenAPIResponseType
|
|
from fastapi_users.password import PasswordHelper
|
|
|
|
# One helper hashes every fixture password once, at import time, so the user
# fixtures below can share the resulting hashes.
password_helper = PasswordHelper()

# Hashes are produced in this exact order; lancelot_password_hash is not
# referenced by the fixtures in this file.
(
    guinevere_password_hash,
    angharad_password_hash,
    viviane_password_hash,
    lancelot_password_hash,
    excalibur_password_hash,
) = (
    password_helper.hash(plain_text)
    for plain_text in ("guinevere", "angharad", "viviane", "lancelot", "excalibur")
)


# Identifier type used for users and OAuth accounts throughout these fixtures.
IDType = UUID4
|
|
|
|
|
|
@dataclasses.dataclass
class UserModel(models.UserProtocol[IDType]):
    """Plain dataclass user record served by the mock user databases.

    Only ``email`` and ``hashed_password`` are required; every other field
    has a default.
    """

    email: str
    hashed_password: str
    # A fresh random UUID is generated per instance unless one is passed in.
    id: IDType = dataclasses.field(default_factory=uuid.uuid4)
    is_active: bool = True
    is_superuser: bool = False
    is_verified: bool = False
    first_name: Union[str, None] = None
|
|
|
|
|
|
@dataclasses.dataclass
class OAuthAccountModel(models.OAuthAccountProtocol[IDType]):
    """Plain dataclass record for an OAuth account linked to a user.

    ``oauth_name``, ``access_token``, ``account_id`` and ``account_email`` are
    required; the rest default.
    """

    oauth_name: str
    access_token: str
    account_id: str
    account_email: str
    # A fresh random UUID is generated per instance unless one is passed in.
    id: IDType = dataclasses.field(default_factory=uuid.uuid4)
    expires_at: Union[int, None] = None
    refresh_token: Union[str, None] = None
|
|
|
|
|
|
@dataclasses.dataclass
class UserOAuthModel(UserModel):
    """``UserModel`` extended with the list of its linked OAuth accounts."""

    # default_factory gives every instance its own list instead of a shared one.
    oauth_accounts: list[OAuthAccountModel] = dataclasses.field(
        default_factory=list
    )
|
|
|
|
|
|
class User(schemas.BaseUser[IDType]):
    """User schema extending ``schemas.BaseUser`` with an optional first name."""

    first_name: Union[str, None] = None


class UserCreate(schemas.BaseUserCreate):
    """Creation schema extending ``schemas.BaseUserCreate`` with ``first_name``."""

    first_name: Union[str, None] = None


class UserUpdate(schemas.BaseUserUpdate):
    """Update schema extending ``schemas.BaseUserUpdate`` with ``first_name``."""

    first_name: Union[str, None] = None


class UserOAuth(User, schemas.BaseOAuthAccountMixin):
    """``User`` schema combined with ``schemas.BaseOAuthAccountMixin``."""
|
|
|
|
|
|
class BaseTestUserManager(
    Generic[models.UP], UUIDIDMixin, BaseUserManager[models.UP, IDType]
):
    """User manager shared by the tests.

    Uses UUID ids (via ``UUIDIDMixin``), fixed token secrets, and a minimal
    password-length rule.
    """

    reset_password_token_secret = "SECRET"
    verification_token_secret = "SECRET"

    async def validate_password(
        self, password: str, user: Union[schemas.UC, models.UP]
    ) -> None:
        """Accept any password of three or more characters.

        :raises exceptions.InvalidPasswordException: if ``password`` is shorter
            than three characters.
        """
        long_enough = len(password) >= 3
        if long_enough:
            return
        raise exceptions.InvalidPasswordException(
            reason="Password should be at least 3 characters"
        )
|
|
|
|
|
|
class UserManager(BaseTestUserManager[UserModel]):
    """Test user manager operating on plain ``UserModel`` records."""


class UserManagerOAuth(BaseTestUserManager[UserOAuthModel]):
    """Test user manager operating on OAuth-aware ``UserOAuthModel`` records."""
|
|
|
|
|
|
class UserManagerMock(BaseTestUserManager[models.UP]):
    """Typing view of a test user manager whose methods are spied on.

    ``make_user_manager`` wraps these same methods with ``mocker.spy``; the
    ``MagicMock`` annotations below let tests call mock assertions on them.
    Presumably used as a cast/annotation target in tests (TODO confirm).
    """

    # Lookup and token flows.
    get_by_email: MagicMock
    request_verify: MagicMock
    verify: MagicMock
    forgot_password: MagicMock
    reset_password: MagicMock
    # Lifecycle hooks.
    on_after_register: MagicMock
    on_after_request_verify: MagicMock
    on_after_verify: MagicMock
    on_after_forgot_password: MagicMock
    on_after_reset_password: MagicMock
    on_after_update: MagicMock
    on_before_delete: MagicMock
    on_after_delete: MagicMock
    on_after_login: MagicMock
    # Internal update helper.
    _update: MagicMock
|
|
|
|
|
|
# Call signature of the factory returned by the ``async_method_mocker`` fixture.
AsyncMethodMocker = Callable[..., MagicMock]


@pytest.fixture
def async_method_mocker(mocker: MockerFixture) -> AsyncMethodMocker:
    """Provide a helper that replaces an attribute with an awaitable-returning mock.

    The helper sets ``object.<method>`` to a ``MagicMock`` whose every call
    returns the same already-resolved ``asyncio.Future`` holding
    ``return_value``, and hands that mock back so tests can inspect calls.
    """

    def _async_method_mocker(
        object: Any,
        method: str,
        return_value: Any = None,
    ) -> MagicMock:
        replacement: MagicMock = mocker.MagicMock()

        # A completed future can be awaited any number of times, so one
        # instance serves every call.
        settled: asyncio.Future = asyncio.Future()
        settled.set_result(return_value)
        replacement.return_value = settled
        replacement.side_effect = None

        setattr(object, method, replacement)
        return replacement

    return _async_method_mocker
|
|
|
|
|
|
@pytest.fixture(params=["SECRET", SecretStr("SECRET")])
def secret(request) -> SecretType:
    """Run dependent tests once with a plain ``str`` and once with ``SecretStr``."""
    chosen: SecretType = request.param
    return chosen
|
|
|
|
|
|
@pytest.fixture
def user() -> UserModel:
    """Active, unverified regular user (password "guinevere")."""
    return UserModel(
        hashed_password=guinevere_password_hash,
        email="king.arthur@camelot.bt",
    )


@pytest.fixture
def user_oauth(
    oauth_account1: OAuthAccountModel, oauth_account2: OAuthAccountModel
) -> UserOAuthModel:
    """OAuth-model counterpart of ``user`` linked to two OAuth accounts."""
    linked_accounts = [oauth_account1, oauth_account2]
    return UserOAuthModel(
        email="king.arthur@camelot.bt",
        hashed_password=guinevere_password_hash,
        oauth_accounts=linked_accounts,
    )
|
|
|
|
|
|
@pytest.fixture
def inactive_user() -> UserModel:
    """Inactive, unverified regular user (password "angharad")."""
    return UserModel(
        is_active=False,
        email="percival@camelot.bt",
        hashed_password=angharad_password_hash,
    )


@pytest.fixture
def inactive_user_oauth(oauth_account3: OAuthAccountModel) -> UserOAuthModel:
    """OAuth-model counterpart of ``inactive_user`` linked to ``oauth_account3``."""
    linked_accounts = [oauth_account3]
    return UserOAuthModel(
        email="percival@camelot.bt",
        hashed_password=angharad_password_hash,
        is_active=False,
        oauth_accounts=linked_accounts,
    )
|
|
|
|
|
|
@pytest.fixture
def verified_user() -> UserModel:
    """Active, verified regular user (password "excalibur")."""
    return UserModel(
        email="lake.lady@camelot.bt",
        hashed_password=excalibur_password_hash,
        is_verified=True,
        is_active=True,
    )


@pytest.fixture
def verified_user_oauth(oauth_account4: OAuthAccountModel) -> UserOAuthModel:
    """OAuth-model user sharing ``verified_user``'s email, linked to ``oauth_account4``.

    NOTE(review): unlike ``verified_user``, this record sets ``is_active=False``
    and leaves ``is_verified`` at its default (``False``). Confirm whether any
    test depends on that before aligning it with the fixture's name.
    """
    linked_accounts = [oauth_account4]
    return UserOAuthModel(
        email="lake.lady@camelot.bt",
        hashed_password=excalibur_password_hash,
        is_active=False,
        oauth_accounts=linked_accounts,
    )
|
|
|
|
|
|
@pytest.fixture
def superuser() -> UserModel:
    """Active, unverified superuser (password "viviane")."""
    return UserModel(
        is_superuser=True,
        email="merlin@camelot.bt",
        hashed_password=viviane_password_hash,
    )


@pytest.fixture
def superuser_oauth() -> UserOAuthModel:
    """OAuth-model counterpart of ``superuser`` with no linked OAuth accounts."""
    return UserOAuthModel(
        email="merlin@camelot.bt",
        hashed_password=viviane_password_hash,
        is_superuser=True,
        oauth_accounts=list(),
    )
|
|
|
|
|
|
@pytest.fixture
def verified_superuser() -> UserModel:
    """Active, verified superuser (password "viviane")."""
    return UserModel(
        is_superuser=True,
        is_verified=True,
        email="the.real.merlin@camelot.bt",
        hashed_password=viviane_password_hash,
    )


@pytest.fixture
def verified_superuser_oauth() -> UserOAuthModel:
    """OAuth-model counterpart of ``verified_superuser`` with no linked accounts."""
    return UserOAuthModel(
        email="the.real.merlin@camelot.bt",
        hashed_password=viviane_password_hash,
        is_superuser=True,
        is_verified=True,
        oauth_accounts=list(),
    )
|
|
|
|
|
|
def _oauth_account(
    oauth_name: str, account_id: str, account_email: str
) -> OAuthAccountModel:
    """Build an OAuth account sharing the fixtures' common token and expiry."""
    return OAuthAccountModel(
        oauth_name=oauth_name,
        access_token="TOKEN",
        expires_at=1579000751,
        account_id=account_id,
        account_email=account_email,
    )


@pytest.fixture
def oauth_account1() -> OAuthAccountModel:
    """``service1`` account for king.arthur@camelot.bt (first of ``user_oauth``)."""
    return _oauth_account("service1", "user_oauth1", "king.arthur@camelot.bt")


@pytest.fixture
def oauth_account2() -> OAuthAccountModel:
    """``service2`` account for king.arthur@camelot.bt (second of ``user_oauth``)."""
    return _oauth_account("service2", "user_oauth2", "king.arthur@camelot.bt")


@pytest.fixture
def oauth_account3() -> OAuthAccountModel:
    """``service3`` account for percival@camelot.bt (``inactive_user_oauth``)."""
    return _oauth_account("service3", "inactive_user_oauth1", "percival@camelot.bt")


@pytest.fixture
def oauth_account4() -> OAuthAccountModel:
    """``service4`` account for lake.lady@camelot.bt (``verified_user_oauth``)."""
    return _oauth_account(
        "service4", "verified_user_oauth1", "lake.lady@camelot.bt"
    )


@pytest.fixture
def oauth_account5() -> OAuthAccountModel:
    """``service5`` account for the.real.merlin@camelot.bt.

    Not linked to any user fixture defined here.
    """
    return _oauth_account(
        "service5", "verified_superuser_oauth1", "the.real.merlin@camelot.bt"
    )
|
|
|
|
|
|
@pytest.fixture
def mock_user_db(
    user: UserModel,
    verified_user: UserModel,
    inactive_user: UserModel,
    superuser: UserModel,
    verified_superuser: UserModel,
) -> BaseUserDatabase[UserModel, IDType]:
    """Provide an in-memory user database over the five ``UserModel`` fixtures.

    Lookups scan the fixture users in signature order. ``create`` returns a
    new ``UserModel`` that later lookups will not find. ``delete`` does nothing.
    """
    known_users = (user, verified_user, inactive_user, superuser, verified_superuser)

    class MockUserDatabase(BaseUserDatabase[UserModel, IDType]):
        async def get(self, id: UUID4) -> Optional[UserModel]:
            """Return the fixture user whose ``id`` equals ``id``, else ``None``."""
            return next(
                (candidate for candidate in known_users if id == candidate.id),
                None,
            )

        async def get_by_email(self, email: str) -> Optional[UserModel]:
            """Return the fixture user whose email matches case-insensitively."""
            wanted = email.lower()
            return next(
                (
                    candidate
                    for candidate in known_users
                    if wanted == candidate.email.lower()
                ),
                None,
            )

        async def create(self, create_dict: dict[str, Any]) -> UserModel:
            """Build (but do not store) a ``UserModel`` from ``create_dict``."""
            return UserModel(**create_dict)

        async def update(
            self, user: UserModel, update_dict: dict[str, Any]
        ) -> UserModel:
            """Set each ``update_dict`` entry as an attribute on ``user`` in place."""
            for attribute, new_value in update_dict.items():
                setattr(user, attribute, new_value)
            return user

        async def delete(self, user: UserModel) -> None:
            """No-op: there is no backing store to remove from."""

    return MockUserDatabase()
|
|
|
|
|
|
@pytest.fixture
def mock_user_db_oauth(
    user_oauth: UserOAuthModel,
    verified_user_oauth: UserOAuthModel,
    inactive_user_oauth: UserOAuthModel,
    superuser_oauth: UserOAuthModel,
    verified_superuser_oauth: UserOAuthModel,
) -> BaseUserDatabase[UserOAuthModel, IDType]:
    """Provide an in-memory, OAuth-aware user database over the ``*_oauth`` fixtures.

    ``get`` and ``get_by_email`` scan the five fixture users in signature
    order. ``create`` returns a new ``UserOAuthModel`` that later lookups will
    not find. ``delete`` does nothing.
    """
    known_users = (
        user_oauth,
        verified_user_oauth,
        inactive_user_oauth,
        superuser_oauth,
        verified_superuser_oauth,
    )

    class MockUserDatabase(BaseUserDatabase[UserOAuthModel, IDType]):
        async def get(self, id: UUID4) -> Optional[UserOAuthModel]:
            """Return the fixture user whose ``id`` equals ``id``, else ``None``."""
            return next(
                (candidate for candidate in known_users if id == candidate.id),
                None,
            )

        async def get_by_email(self, email: str) -> Optional[UserOAuthModel]:
            """Return the fixture user whose email matches case-insensitively."""
            lower_email = email.lower()
            return next(
                (
                    candidate
                    for candidate in known_users
                    if lower_email == candidate.email.lower()
                ),
                None,
            )

        async def get_by_oauth_account(
            self, oauth: str, account_id: str
        ) -> Optional[UserOAuthModel]:
            """Find a user by OAuth provider name and provider account id.

            Only the FIRST linked account of ``user_oauth`` and then of
            ``inactive_user_oauth`` is consulted; other users and other linked
            accounts never match.
            """
            for candidate in (user_oauth, inactive_user_oauth):
                first_account = candidate.oauth_accounts[0]
                if (
                    first_account.oauth_name == oauth
                    and first_account.account_id == account_id
                ):
                    return candidate
            return None

        async def create(self, create_dict: dict[str, Any]) -> UserOAuthModel:
            """Build (but do not store) a ``UserOAuthModel`` from ``create_dict``."""
            return UserOAuthModel(**create_dict)

        async def update(
            self, user: UserOAuthModel, update_dict: dict[str, Any]
        ) -> UserOAuthModel:
            """Set each ``update_dict`` entry as an attribute on ``user`` in place."""
            for field, value in update_dict.items():
                setattr(user, field, value)
            return user

        async def delete(self, user: UserOAuthModel) -> None:
            """No-op: there is no backing store to remove from."""

        async def add_oauth_account(
            self, user: UserOAuthModel, create_dict: dict[str, Any]
        ) -> UserOAuthModel:
            """Append a new ``OAuthAccountModel`` built from ``create_dict`` to ``user``."""
            oauth_account = OAuthAccountModel(**create_dict)
            user.oauth_accounts.append(oauth_account)
            return user

        async def update_oauth_account(  # type: ignore
            self,
            user: UserOAuthModel,
            oauth_account: OAuthAccountModel,
            update_dict: dict[str, Any],
        ) -> UserOAuthModel:
            """Apply ``update_dict`` to ``oauth_account`` and store it on ``user``.

            The entry of ``user.oauth_accounts`` with the same ``account_id``
            and ``oauth_name`` is replaced by ``oauth_account``. Other entries
            are kept in order.
            """
            for field, value in update_dict.items():
                setattr(oauth_account, field, value)
            # Store the rebuilt list so the update is visible even when
            # ``oauth_account`` is not the same object held by ``user``.
            user.oauth_accounts = [
                oauth_account
                if (
                    existing.account_id == oauth_account.account_id
                    and existing.oauth_name == oauth_account.oauth_name
                )
                else existing
                for existing in user.oauth_accounts
            ]
            return user

    return MockUserDatabase()
|
|
|
|
|
|
@pytest.fixture
def make_user_manager(mocker: MockerFixture):
    """Return a factory that builds a user manager with spies on its methods.

    Each listed method is wrapped with ``mocker.spy``, in the order below, so
    tests can assert on calls while the real implementation still runs.
    """
    spied_methods = (
        "get_by_email",
        "request_verify",
        "verify",
        "forgot_password",
        "reset_password",
        "on_after_register",
        "on_after_request_verify",
        "on_after_verify",
        "on_after_forgot_password",
        "on_after_reset_password",
        "on_after_update",
        "on_before_delete",
        "on_after_delete",
        "on_after_login",
        "_update",
    )

    def _make_user_manager(user_manager_class: type[BaseTestUserManager], mock_user_db):
        manager = user_manager_class(mock_user_db)
        for method_name in spied_methods:
            mocker.spy(manager, method_name)
        return manager

    return _make_user_manager
|
|
|
|
|
|
@pytest.fixture
def user_manager(make_user_manager, mock_user_db):
    """Spied ``UserManager`` backed by ``mock_user_db``."""
    manager = make_user_manager(UserManager, mock_user_db)
    return manager


@pytest.fixture
def user_manager_oauth(make_user_manager, mock_user_db_oauth):
    """Spied ``UserManagerOAuth`` backed by ``mock_user_db_oauth``."""
    manager = make_user_manager(UserManagerOAuth, mock_user_db_oauth)
    return manager
|
|
|
|
|
|
@pytest.fixture
def get_user_manager(user_manager):
    """Return a zero-argument callable that always yields the ``user_manager`` fixture."""

    def _get_user_manager():
        return user_manager

    return _get_user_manager


@pytest.fixture
def get_user_manager_oauth(user_manager_oauth):
    """Return a zero-argument callable that always yields ``user_manager_oauth``."""

    def _get_user_manager_oauth():
        return user_manager_oauth

    return _get_user_manager_oauth
|
|
|
|
|
|
class MockTransport(BearerTransport):
    """``BearerTransport`` variant with a trivial logout response."""

    def __init__(self, tokenUrl: str):
        # Delegate straight to BearerTransport; kept explicit for readability.
        super().__init__(tokenUrl)

    async def get_logout_response(self) -> Any:
        """Return a fresh ``Response()`` built with no arguments."""
        logout_response = Response()
        return logout_response

    @staticmethod
    def get_openapi_logout_responses_success() -> OpenAPIResponseType:
        """Declare no extra OpenAPI responses for a successful logout."""
        no_responses: OpenAPIResponseType = {}
        return no_responses
|
|
|
|
|
|
class MockStrategy(Strategy[UserModel, IDType]):
    """Trivial token strategy: a token is the user's id rendered as text."""

    async def read_token(
        self, token: Optional[str], user_manager: BaseUserManager[UserModel, IDType]
    ) -> Optional[UserModel]:
        """Map ``token`` back to a user.

        Returns ``None`` when no token is given, when ``user_manager.parse_id``
        raises ``InvalidID``, or when ``user_manager.get`` raises
        ``UserNotExists``.
        """
        if token is None:
            return None
        try:
            return await user_manager.get(user_manager.parse_id(token))
        except (exceptions.InvalidID, exceptions.UserNotExists):
            return None

    async def write_token(self, user: UserModel) -> str:
        """Issue a token equal to ``str(user.id)``."""
        return str(user.id)

    async def destroy_token(self, token: str, user: UserModel) -> None:
        """No-op: tokens are just ids, so there is nothing to revoke."""
        return None
|
|
|
|
|
|
def get_mock_authentication(name: str):
    """Build an ``AuthenticationBackend`` named ``name`` using the mock pieces.

    Uses ``MockTransport`` with token URL ``/login`` and a fresh
    ``MockStrategy`` per strategy request.
    """
    transport = MockTransport(tokenUrl="/login")
    return AuthenticationBackend(
        name=name,
        transport=transport,
        get_strategy=lambda: MockStrategy(),
    )


@pytest.fixture
def mock_authentication():
    """Mock authentication backend named ``"mock"``."""
    backend = get_mock_authentication(name="mock")
    return backend
|
|
|
|
|
|
@pytest.fixture
def get_test_client():
    """Return a factory that yields an ``httpx.AsyncClient`` wired to ``app``.

    The app's lifespan is run by ``LifespanManager`` around the client, and
    requests go to the app in-process through ``httpx.ASGITransport`` with
    base URL ``http://app.io``.
    """

    async def _get_test_client(app: FastAPI) -> AsyncGenerator[httpx.AsyncClient, None]:
        # ``AsyncClient(app=...)`` was deprecated in httpx 0.27 and removed in
        # 0.28; an explicit ASGITransport is the supported replacement.
        transport = httpx.ASGITransport(app=app)
        async with LifespanManager(app):
            async with httpx.AsyncClient(
                transport=transport, base_url="http://app.io"
            ) as test_client:
                yield test_client

    return _get_test_client
|
|
|
|
|
|
@pytest.fixture
def oauth_client() -> OAuth2:
    """Generic ``OAuth2`` client named ``service1`` with fake camelot.bt endpoints.

    Positional arguments are client id, client secret, authorize endpoint and
    access-token endpoint, in that order.
    """
    return OAuth2(
        "CLIENT_ID",
        "CLIENT_SECRET",
        "https://www.camelot.bt/authorize",
        "https://www.camelot.bt/access-token",
        name="service1",
    )
|