Files
fastapi-users/tests/conftest.py
2024-11-03 12:51:32 +00:00

570 lines
17 KiB
Python

import asyncio
import dataclasses
import uuid
from collections.abc import AsyncGenerator
from typing import Any, Callable, Generic, Optional, Union
from unittest.mock import MagicMock
import httpx
import pytest
from asgi_lifespan import LifespanManager
from fastapi import FastAPI, Response
from httpx_oauth.oauth2 import OAuth2
from pydantic import UUID4, SecretStr
from pytest_mock import MockerFixture
from fastapi_users import exceptions, models, schemas
from fastapi_users.authentication import AuthenticationBackend, BearerTransport
from fastapi_users.authentication.strategy import Strategy
from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType
from fastapi_users.manager import BaseUserManager, UUIDIDMixin
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.password import PasswordHelper
password_helper = PasswordHelper()
guinevere_password_hash = password_helper.hash("guinevere")
angharad_password_hash = password_helper.hash("angharad")
viviane_password_hash = password_helper.hash("viviane")
lancelot_password_hash = password_helper.hash("lancelot")
excalibur_password_hash = password_helper.hash("excalibur")
IDType = UUID4
@dataclasses.dataclass
class UserModel(models.UserProtocol[IDType]):
email: str
hashed_password: str
id: IDType = dataclasses.field(default_factory=uuid.uuid4)
is_active: bool = True
is_superuser: bool = False
is_verified: bool = False
first_name: Optional[str] = None
@dataclasses.dataclass
class OAuthAccountModel(models.OAuthAccountProtocol[IDType]):
oauth_name: str
access_token: str
account_id: str
account_email: str
id: IDType = dataclasses.field(default_factory=uuid.uuid4)
expires_at: Optional[int] = None
refresh_token: Optional[str] = None
@dataclasses.dataclass
class UserOAuthModel(UserModel):
oauth_accounts: list[OAuthAccountModel] = dataclasses.field(default_factory=list)
class User(schemas.BaseUser[IDType]):
first_name: Optional[str] = None
class UserCreate(schemas.BaseUserCreate):
first_name: Optional[str] = None
class UserUpdate(schemas.BaseUserUpdate):
first_name: Optional[str] = None
class UserOAuth(User, schemas.BaseOAuthAccountMixin):
pass
class BaseTestUserManager(
Generic[models.UP], UUIDIDMixin, BaseUserManager[models.UP, IDType]
):
reset_password_token_secret = "SECRET"
verification_token_secret = "SECRET"
async def validate_password(
self, password: str, user: Union[schemas.UC, models.UP]
) -> None:
if len(password) < 3:
raise exceptions.InvalidPasswordException(
reason="Password should be at least 3 characters"
)
class UserManager(BaseTestUserManager[UserModel]):
pass
class UserManagerOAuth(BaseTestUserManager[UserOAuthModel]):
pass
class UserManagerMock(BaseTestUserManager[models.UP]):
get_by_email: MagicMock
request_verify: MagicMock
verify: MagicMock
forgot_password: MagicMock
reset_password: MagicMock
on_after_register: MagicMock
on_after_request_verify: MagicMock
on_after_verify: MagicMock
on_after_forgot_password: MagicMock
on_after_reset_password: MagicMock
on_after_update: MagicMock
on_before_delete: MagicMock
on_after_delete: MagicMock
on_after_login: MagicMock
_update: MagicMock
AsyncMethodMocker = Callable[..., MagicMock]
@pytest.fixture
def async_method_mocker(mocker: MockerFixture) -> AsyncMethodMocker:
def _async_method_mocker(
object: Any,
method: str,
return_value: Any = None,
) -> MagicMock:
mock: MagicMock = mocker.MagicMock()
future: asyncio.Future = asyncio.Future()
future.set_result(return_value)
mock.return_value = future
mock.side_effect = None
setattr(object, method, mock)
return mock
return _async_method_mocker
@pytest.fixture(params=["SECRET", SecretStr("SECRET")])
def secret(request) -> SecretType:
return request.param
@pytest.fixture
def user() -> UserModel:
return UserModel(
email="king.arthur@camelot.bt",
hashed_password=guinevere_password_hash,
)
@pytest.fixture
def user_oauth(
oauth_account1: OAuthAccountModel, oauth_account2: OAuthAccountModel
) -> UserOAuthModel:
return UserOAuthModel(
email="king.arthur@camelot.bt",
hashed_password=guinevere_password_hash,
oauth_accounts=[oauth_account1, oauth_account2],
)
@pytest.fixture
def inactive_user() -> UserModel:
return UserModel(
email="percival@camelot.bt",
hashed_password=angharad_password_hash,
is_active=False,
)
@pytest.fixture
def inactive_user_oauth(oauth_account3: OAuthAccountModel) -> UserOAuthModel:
return UserOAuthModel(
email="percival@camelot.bt",
hashed_password=angharad_password_hash,
is_active=False,
oauth_accounts=[oauth_account3],
)
@pytest.fixture
def verified_user() -> UserModel:
return UserModel(
email="lake.lady@camelot.bt",
hashed_password=excalibur_password_hash,
is_active=True,
is_verified=True,
)
@pytest.fixture
def verified_user_oauth(oauth_account4: OAuthAccountModel) -> UserOAuthModel:
return UserOAuthModel(
email="lake.lady@camelot.bt",
hashed_password=excalibur_password_hash,
is_active=False,
oauth_accounts=[oauth_account4],
)
@pytest.fixture
def superuser() -> UserModel:
return UserModel(
email="merlin@camelot.bt",
hashed_password=viviane_password_hash,
is_superuser=True,
)
@pytest.fixture
def superuser_oauth() -> UserOAuthModel:
return UserOAuthModel(
email="merlin@camelot.bt",
hashed_password=viviane_password_hash,
is_superuser=True,
oauth_accounts=[],
)
@pytest.fixture
def verified_superuser() -> UserModel:
return UserModel(
email="the.real.merlin@camelot.bt",
hashed_password=viviane_password_hash,
is_superuser=True,
is_verified=True,
)
@pytest.fixture
def verified_superuser_oauth() -> UserOAuthModel:
return UserOAuthModel(
email="the.real.merlin@camelot.bt",
hashed_password=viviane_password_hash,
is_superuser=True,
is_verified=True,
oauth_accounts=[],
)
@pytest.fixture
def oauth_account1() -> OAuthAccountModel:
return OAuthAccountModel(
oauth_name="service1",
access_token="TOKEN",
expires_at=1579000751,
account_id="user_oauth1",
account_email="king.arthur@camelot.bt",
)
@pytest.fixture
def oauth_account2() -> OAuthAccountModel:
return OAuthAccountModel(
oauth_name="service2",
access_token="TOKEN",
expires_at=1579000751,
account_id="user_oauth2",
account_email="king.arthur@camelot.bt",
)
@pytest.fixture
def oauth_account3() -> OAuthAccountModel:
return OAuthAccountModel(
oauth_name="service3",
access_token="TOKEN",
expires_at=1579000751,
account_id="inactive_user_oauth1",
account_email="percival@camelot.bt",
)
@pytest.fixture
def oauth_account4() -> OAuthAccountModel:
return OAuthAccountModel(
oauth_name="service4",
access_token="TOKEN",
expires_at=1579000751,
account_id="verified_user_oauth1",
account_email="lake.lady@camelot.bt",
)
@pytest.fixture
def oauth_account5() -> OAuthAccountModel:
return OAuthAccountModel(
oauth_name="service5",
access_token="TOKEN",
expires_at=1579000751,
account_id="verified_superuser_oauth1",
account_email="the.real.merlin@camelot.bt",
)
@pytest.fixture
def mock_user_db(
user: UserModel,
verified_user: UserModel,
inactive_user: UserModel,
superuser: UserModel,
verified_superuser: UserModel,
) -> BaseUserDatabase[UserModel, IDType]:
class MockUserDatabase(BaseUserDatabase[UserModel, IDType]):
async def get(self, id: UUID4) -> Optional[UserModel]:
if id == user.id:
return user
if id == verified_user.id:
return verified_user
if id == inactive_user.id:
return inactive_user
if id == superuser.id:
return superuser
if id == verified_superuser.id:
return verified_superuser
return None
async def get_by_email(self, email: str) -> Optional[UserModel]:
lower_email = email.lower()
if lower_email == user.email.lower():
return user
if lower_email == verified_user.email.lower():
return verified_user
if lower_email == inactive_user.email.lower():
return inactive_user
if lower_email == superuser.email.lower():
return superuser
if lower_email == verified_superuser.email.lower():
return verified_superuser
return None
async def create(self, create_dict: dict[str, Any]) -> UserModel:
return UserModel(**create_dict)
async def update(
self, user: UserModel, update_dict: dict[str, Any]
) -> UserModel:
for field, value in update_dict.items():
setattr(user, field, value)
return user
async def delete(self, user: UserModel) -> None:
pass
return MockUserDatabase()
@pytest.fixture
def mock_user_db_oauth(
user_oauth: UserOAuthModel,
verified_user_oauth: UserOAuthModel,
inactive_user_oauth: UserOAuthModel,
superuser_oauth: UserOAuthModel,
verified_superuser_oauth: UserOAuthModel,
) -> BaseUserDatabase[UserOAuthModel, IDType]:
class MockUserDatabase(BaseUserDatabase[UserOAuthModel, IDType]):
async def get(self, id: UUID4) -> Optional[UserOAuthModel]:
if id == user_oauth.id:
return user_oauth
if id == verified_user_oauth.id:
return verified_user_oauth
if id == inactive_user_oauth.id:
return inactive_user_oauth
if id == superuser_oauth.id:
return superuser_oauth
if id == verified_superuser_oauth.id:
return verified_superuser_oauth
return None
async def get_by_email(self, email: str) -> Optional[UserOAuthModel]:
lower_email = email.lower()
if lower_email == user_oauth.email.lower():
return user_oauth
if lower_email == verified_user_oauth.email.lower():
return verified_user_oauth
if lower_email == inactive_user_oauth.email.lower():
return inactive_user_oauth
if lower_email == superuser_oauth.email.lower():
return superuser_oauth
if lower_email == verified_superuser_oauth.email.lower():
return verified_superuser_oauth
return None
async def get_by_oauth_account(
self, oauth: str, account_id: str
) -> Optional[UserOAuthModel]:
user_oauth_account = user_oauth.oauth_accounts[0]
if (
user_oauth_account.oauth_name == oauth
and user_oauth_account.account_id == account_id
):
return user_oauth
inactive_user_oauth_account = inactive_user_oauth.oauth_accounts[0]
if (
inactive_user_oauth_account.oauth_name == oauth
and inactive_user_oauth_account.account_id == account_id
):
return inactive_user_oauth
return None
async def create(self, create_dict: dict[str, Any]) -> UserOAuthModel:
return UserOAuthModel(**create_dict)
async def update(
self, user: UserOAuthModel, update_dict: dict[str, Any]
) -> UserOAuthModel:
for field, value in update_dict.items():
setattr(user, field, value)
return user
async def delete(self, user: UserOAuthModel) -> None:
pass
async def add_oauth_account(
self, user: UserOAuthModel, create_dict: dict[str, Any]
) -> UserOAuthModel:
oauth_account = OAuthAccountModel(**create_dict)
user.oauth_accounts.append(oauth_account)
return user
async def update_oauth_account( # type: ignore
self,
user: UserOAuthModel,
oauth_account: OAuthAccountModel,
update_dict: dict[str, Any],
) -> UserOAuthModel:
for field, value in update_dict.items():
setattr(oauth_account, field, value)
updated_oauth_accounts = []
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == oauth_account.account_id
and existing_oauth_account.oauth_name == oauth_account.oauth_name
):
updated_oauth_accounts.append(oauth_account)
else:
updated_oauth_accounts.append(existing_oauth_account)
return user
return MockUserDatabase()
@pytest.fixture
def make_user_manager(mocker: MockerFixture):
def _make_user_manager(user_manager_class: type[BaseTestUserManager], mock_user_db):
user_manager = user_manager_class(mock_user_db)
mocker.spy(user_manager, "get_by_email")
mocker.spy(user_manager, "request_verify")
mocker.spy(user_manager, "verify")
mocker.spy(user_manager, "forgot_password")
mocker.spy(user_manager, "reset_password")
mocker.spy(user_manager, "on_after_register")
mocker.spy(user_manager, "on_after_request_verify")
mocker.spy(user_manager, "on_after_verify")
mocker.spy(user_manager, "on_after_forgot_password")
mocker.spy(user_manager, "on_after_reset_password")
mocker.spy(user_manager, "on_after_update")
mocker.spy(user_manager, "on_before_delete")
mocker.spy(user_manager, "on_after_delete")
mocker.spy(user_manager, "on_after_login")
mocker.spy(user_manager, "_update")
return user_manager
return _make_user_manager
@pytest.fixture
def user_manager(make_user_manager, mock_user_db):
return make_user_manager(UserManager, mock_user_db)
@pytest.fixture
def user_manager_oauth(make_user_manager, mock_user_db_oauth):
return make_user_manager(UserManagerOAuth, mock_user_db_oauth)
@pytest.fixture
def get_user_manager(user_manager):
def _get_user_manager():
return user_manager
return _get_user_manager
@pytest.fixture
def get_user_manager_oauth(user_manager_oauth):
def _get_user_manager_oauth():
return user_manager_oauth
return _get_user_manager_oauth
class MockTransport(BearerTransport):
def __init__(self, tokenUrl: str):
super().__init__(tokenUrl)
async def get_logout_response(self) -> Any:
return Response()
@staticmethod
def get_openapi_logout_responses_success() -> OpenAPIResponseType:
return {}
class MockStrategy(Strategy[UserModel, IDType]):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[UserModel, IDType]
) -> Optional[UserModel]:
if token is not None:
try:
parsed_id = user_manager.parse_id(token)
return await user_manager.get(parsed_id)
except (exceptions.InvalidID, exceptions.UserNotExists):
return None
return None
async def write_token(self, user: UserModel) -> str:
return str(user.id)
async def destroy_token(self, token: str, user: UserModel) -> None:
return None
def get_mock_authentication(name: str):
return AuthenticationBackend(
name=name,
transport=MockTransport(tokenUrl="/login"),
get_strategy=lambda: MockStrategy(),
)
@pytest.fixture
def mock_authentication():
return get_mock_authentication(name="mock")
@pytest.fixture
def get_test_client():
async def _get_test_client(app: FastAPI) -> AsyncGenerator[httpx.AsyncClient, None]:
async with LifespanManager(app):
async with httpx.AsyncClient(
app=app, base_url="http://app.io"
) as test_client:
yield test_client
return _get_test_client
@pytest.fixture
def oauth_client() -> OAuth2:
CLIENT_ID = "CLIENT_ID"
CLIENT_SECRET = "CLIENT_SECRET"
AUTHORIZE_ENDPOINT = "https://www.camelot.bt/authorize"
ACCESS_TOKEN_ENDPOINT = "https://www.camelot.bt/access-token"
return OAuth2(
CLIENT_ID,
CLIENT_SECRET,
AUTHORIZE_ENDPOINT,
ACCESS_TOKEN_ENDPOINT,
name="service1",
)