Files
fastapi-users/tests/conftest.py

513 lines
14 KiB
Python

import asyncio
from typing import Any, AsyncGenerator, Callable, Generic, Optional, Type, Union
from unittest.mock import MagicMock
import httpx
import pytest
from asgi_lifespan import LifespanManager
from fastapi import FastAPI, Response
from httpx_oauth.oauth2 import OAuth2
from pydantic import UUID4, SecretStr
from pytest_mock import MockerFixture
from fastapi_users import models
from fastapi_users.authentication import AuthenticationBackend, BearerTransport
from fastapi_users.authentication.strategy import Strategy
from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType
from fastapi_users.manager import (
BaseUserManager,
InvalidPasswordException,
UserNotExists,
)
from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.password import PasswordHelper
# Shared password helper plus the pre-computed hashes used by the user
# fixtures. Hashing is done once, at import time, so individual fixtures only
# reference the resulting strings. Unpacking order matches the tuple of plain
# passwords below.
password_helper = PasswordHelper()
(
    guinevere_password_hash,
    angharad_password_hash,
    viviane_password_hash,
    lancelot_password_hash,
    excalibur_password_hash,
) = (
    password_helper.hash(plain_password)
    for plain_password in ("guinevere", "angharad", "viviane", "lancelot", "excalibur")
)
class User(models.BaseUser):
    """Test user schema: ``models.BaseUser`` plus an optional ``first_name``."""

    first_name: Union[str, None]
class UserCreate(models.BaseUserCreate):
    """Creation payload for tests: ``models.BaseUserCreate`` plus optional ``first_name``."""

    first_name: Union[str, None]
class UserUpdate(models.BaseUserUpdate):
    """Update payload for tests: ``models.BaseUserUpdate`` plus optional ``first_name``."""

    first_name: Union[str, None]
class UserDB(User, models.BaseUserDB):
    """Database-side test user combining ``User`` fields with ``models.BaseUserDB``."""
class UserOAuth(User, BaseOAuthAccountMixin):
    """Test user schema that also mixes in ``BaseOAuthAccountMixin``."""
class UserDBOAuth(UserOAuth, UserDB):
    """Database-side test user with OAuth accounts (``UserOAuth`` + ``UserDB``)."""
class BaseTestUserManager(
    Generic[models.UC, models.UD], BaseUserManager[models.UC, models.UD]
):
    """User manager base for tests.

    Uses the fixed string ``"SECRET"`` for both reset-password and verification
    tokens, and enforces a minimal password rule (at least 3 characters).
    """

    reset_password_token_secret = "SECRET"
    verification_token_secret = "SECRET"

    async def validate_password(
        self, password: str, user: Union[models.UC, models.UD]
    ) -> None:
        """Accept passwords of length >= 3; otherwise raise ``InvalidPasswordException``."""
        if len(password) >= 3:
            return
        raise InvalidPasswordException(
            reason="Password should be at least 3 characters"
        )
class UserManager(BaseTestUserManager[UserCreate, UserDB]):
    """Concrete test manager for plain (non-OAuth) users."""

    user_db_model = UserDB
class UserManagerOAuth(BaseTestUserManager[UserCreate, UserDBOAuth]):
    """Concrete test manager for users carrying OAuth accounts."""

    user_db_model = UserDBOAuth
class UserManagerMock(UserManager):
    """Typing aid: ``UserManager`` whose hook/flow methods are declared as ``MagicMock``.

    The declared names mirror the methods wrapped with ``mocker.spy`` in the
    ``make_user_manager`` fixture, so tests annotated with this type can use
    mock assertion helpers on them.
    """

    # Lookup and user-facing flows.
    get_by_email: MagicMock
    request_verify: MagicMock
    verify: MagicMock
    forgot_password: MagicMock
    reset_password: MagicMock
    # Lifecycle hooks.
    on_after_register: MagicMock
    on_after_request_verify: MagicMock
    on_after_verify: MagicMock
    on_after_forgot_password: MagicMock
    on_after_reset_password: MagicMock
    on_after_update: MagicMock
    # Internal update helper.
    _update: MagicMock
@pytest.fixture(scope="session")
def event_loop():
    """Force the pytest-asyncio loop to be the main one.

    Session-scoped: yields whatever ``asyncio.get_event_loop()`` returns instead
    of a fresh loop, and does not close it afterwards.
    NOTE(review): calling ``asyncio.get_event_loop()`` with no running loop is
    deprecated on recent Python versions — confirm against the supported range.
    """
    yield asyncio.get_event_loop()
# Signature of the helper returned by the ``async_method_mocker`` fixture.
AsyncMethodMocker = Callable[..., MagicMock]


@pytest.fixture
def async_method_mocker(mocker: MockerFixture) -> AsyncMethodMocker:
    """Return a helper that swaps an async method for an awaitable-returning mock.

    ``helper(object, method, return_value=None)`` sets ``object.<method>`` to a
    ``MagicMock`` whose return value is an already-resolved ``asyncio.Future``
    holding *return_value*. As a result, ``await object.method(...)`` evaluates
    to *return_value*. The mock is returned so callers can assert on its calls.
    """

    def _patch_async_method(
        object: Any,
        method: str,
        return_value: Any = None,
    ) -> MagicMock:
        # A single pre-resolved future is shared by every call to the mock.
        resolved: asyncio.Future = asyncio.Future()
        resolved.set_result(return_value)
        replacement: MagicMock = mocker.MagicMock(
            return_value=resolved, side_effect=None
        )
        setattr(object, method, replacement)
        return replacement

    return _patch_async_method
@pytest.fixture(params=["SECRET", SecretStr("SECRET")])
def secret(request) -> SecretType:
    """Parametrized secret: the plain string ``"SECRET"`` and its ``SecretStr`` form."""
    chosen_secret = request.param
    return chosen_secret
@pytest.fixture
def user() -> UserDB:
    """Regular user ``king.arthur@camelot.bt`` (password ``guinevere``); other flags use model defaults."""
    arthur = UserDB(
        hashed_password=guinevere_password_hash,
        email="king.arthur@camelot.bt",
    )
    return arthur
@pytest.fixture
def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth:
    """OAuth variant of ``king.arthur@camelot.bt`` linked to two OAuth accounts."""
    linked_accounts = [oauth_account1, oauth_account2]
    return UserDBOAuth(
        oauth_accounts=linked_accounts,
        email="king.arthur@camelot.bt",
        hashed_password=guinevere_password_hash,
    )
@pytest.fixture
def inactive_user() -> UserDB:
    """Deactivated user ``percival@camelot.bt`` (password ``angharad``)."""
    percival = UserDB(
        is_active=False,
        email="percival@camelot.bt",
        hashed_password=angharad_password_hash,
    )
    return percival
@pytest.fixture
def inactive_user_oauth(oauth_account3) -> UserDBOAuth:
    """Deactivated OAuth user ``percival@camelot.bt`` with a single OAuth account."""
    return UserDBOAuth(
        is_active=False,
        oauth_accounts=[oauth_account3],
        email="percival@camelot.bt",
        hashed_password=angharad_password_hash,
    )
@pytest.fixture
def verified_user() -> UserDB:
    """Active and verified user ``lake.lady@camelot.bt`` (password ``excalibur``)."""
    lady_of_the_lake = UserDB(
        is_active=True,
        is_verified=True,
        email="lake.lady@camelot.bt",
        hashed_password=excalibur_password_hash,
    )
    return lady_of_the_lake
@pytest.fixture
def verified_user_oauth(oauth_account4) -> UserDBOAuth:
    """OAuth counterpart of ``verified_user``: active, verified, one OAuth account.

    Uses the same e-mail and password hash as ``verified_user`` and attaches
    ``oauth_account4``.
    """
    return UserDBOAuth(
        email="lake.lady@camelot.bt",
        hashed_password=excalibur_password_hash,
        # Mirror ``verified_user``: a "verified" fixture must actually be
        # active and verified, otherwise tests relying on it exercise the
        # inactive/unverified code paths by accident.
        is_active=True,
        is_verified=True,
        oauth_accounts=[oauth_account4],
    )
@pytest.fixture
def superuser() -> UserDB:
    """Superuser ``merlin@camelot.bt`` (password ``viviane``)."""
    merlin = UserDB(
        is_superuser=True,
        email="merlin@camelot.bt",
        hashed_password=viviane_password_hash,
    )
    return merlin
@pytest.fixture
def superuser_oauth() -> UserDBOAuth:
    """OAuth-model superuser ``merlin@camelot.bt`` with no linked OAuth accounts."""
    return UserDBOAuth(
        is_superuser=True,
        oauth_accounts=[],
        email="merlin@camelot.bt",
        hashed_password=viviane_password_hash,
    )
@pytest.fixture
def verified_superuser() -> UserDB:
    """Verified superuser ``the.real.merlin@camelot.bt`` (password ``viviane``)."""
    real_merlin = UserDB(
        is_superuser=True,
        is_verified=True,
        email="the.real.merlin@camelot.bt",
        hashed_password=viviane_password_hash,
    )
    return real_merlin
@pytest.fixture
def verified_superuser_oauth() -> UserDBOAuth:
    """OAuth-model verified superuser ``the.real.merlin@camelot.bt`` with no OAuth accounts.

    NOTE(review): the ``oauth_account5`` fixture (account_email
    ``the.real.merlin@camelot.bt``) is defined in this file but not attached
    here — confirm whether that is intentional.
    """
    return UserDBOAuth(
        is_superuser=True,
        is_verified=True,
        oauth_accounts=[],
        email="the.real.merlin@camelot.bt",
        hashed_password=viviane_password_hash,
    )
@pytest.fixture
def oauth_account1() -> BaseOAuthAccount:
    """OAuth account ``user_oauth1`` on ``service1`` for king.arthur@camelot.bt."""
    account = BaseOAuthAccount(
        account_id="user_oauth1",
        account_email="king.arthur@camelot.bt",
        oauth_name="service1",
        access_token="TOKEN",
        expires_at=1579000751,
    )
    return account
@pytest.fixture
def oauth_account2() -> BaseOAuthAccount:
    """OAuth account ``user_oauth2`` on ``service2`` for king.arthur@camelot.bt."""
    account = BaseOAuthAccount(
        account_id="user_oauth2",
        account_email="king.arthur@camelot.bt",
        oauth_name="service2",
        access_token="TOKEN",
        expires_at=1579000751,
    )
    return account
@pytest.fixture
def oauth_account3() -> BaseOAuthAccount:
    """OAuth account ``inactive_user_oauth1`` on ``service3`` for percival@camelot.bt."""
    account = BaseOAuthAccount(
        account_id="inactive_user_oauth1",
        account_email="percival@camelot.bt",
        oauth_name="service3",
        access_token="TOKEN",
        expires_at=1579000751,
    )
    return account
@pytest.fixture
def oauth_account4() -> BaseOAuthAccount:
    """OAuth account ``verified_user_oauth1`` on ``service4`` for lake.lady@camelot.bt."""
    account = BaseOAuthAccount(
        account_id="verified_user_oauth1",
        account_email="lake.lady@camelot.bt",
        oauth_name="service4",
        access_token="TOKEN",
        expires_at=1579000751,
    )
    return account
@pytest.fixture
def oauth_account5() -> BaseOAuthAccount:
    """OAuth account ``verified_superuser_oauth1`` on ``service5`` for the.real.merlin@camelot.bt."""
    account = BaseOAuthAccount(
        account_id="verified_superuser_oauth1",
        account_email="the.real.merlin@camelot.bt",
        oauth_name="service5",
        access_token="TOKEN",
        expires_at=1579000751,
    )
    return account
@pytest.fixture
def mock_user_db(
    user, verified_user, inactive_user, superuser, verified_superuser
) -> BaseUserDatabase:
    """Provide an in-memory ``BaseUserDatabase[UserDB]`` stub.

    Reads are answered from the five non-OAuth user fixtures, checked in the
    order user, verified_user, inactive_user, superuser, verified_superuser;
    the first match wins. E-mail lookups lower-case both sides before
    comparing. ``create``/``update`` return the object they are given and
    ``delete`` does nothing, so nothing is persisted between calls.
    """
    known_users = (
        user,
        verified_user,
        inactive_user,
        superuser,
        verified_superuser,
    )

    class MockUserDatabase(BaseUserDatabase[UserDB]):
        async def get(self, id: UUID4) -> Optional[UserDB]:
            return next(
                (candidate for candidate in known_users if id == candidate.id),
                None,
            )

        async def get_by_email(self, email: str) -> Optional[UserDB]:
            wanted = email.lower()
            return next(
                (
                    candidate
                    for candidate in known_users
                    if wanted == candidate.email.lower()
                ),
                None,
            )

        async def create(self, user: UserDB) -> UserDB:
            return user

        async def update(self, user: UserDB) -> UserDB:
            return user

        async def delete(self, user: UserDB) -> None:
            pass

    return MockUserDatabase(UserDB)
@pytest.fixture
def mock_user_db_oauth(
    user_oauth,
    verified_user_oauth,
    inactive_user_oauth,
    superuser_oauth,
    verified_superuser_oauth,
) -> BaseUserDatabase:
    """Provide an in-memory ``BaseUserDatabase[UserDBOAuth]`` stub.

    ``get``/``get_by_email`` scan the five OAuth user fixtures in declaration
    order and return the first match; e-mails are compared lower-cased.
    ``get_by_oauth_account`` matches only the *first* OAuth account of
    ``user_oauth`` and of ``inactive_user_oauth``.
    NOTE(review): ``create`` and ``update`` always return ``user_oauth``,
    ignoring their argument — confirm tests rely on this.
    """
    known_users = (
        user_oauth,
        verified_user_oauth,
        inactive_user_oauth,
        superuser_oauth,
        verified_superuser_oauth,
    )

    class MockUserDatabase(BaseUserDatabase[UserDBOAuth]):
        async def get(self, id: UUID4) -> Optional[UserDBOAuth]:
            return next(
                (candidate for candidate in known_users if id == candidate.id),
                None,
            )

        async def get_by_email(self, email: str) -> Optional[UserDBOAuth]:
            wanted = email.lower()
            return next(
                (
                    candidate
                    for candidate in known_users
                    if wanted == candidate.email.lower()
                ),
                None,
            )

        async def get_by_oauth_account(
            self, oauth: str, account_id: str
        ) -> Optional[UserDBOAuth]:
            # Only each owner's first linked account is consulted; any other
            # account (e.g. user_oauth's second one) never matches.
            for owner in (user_oauth, inactive_user_oauth):
                first_account = owner.oauth_accounts[0]
                if (
                    first_account.oauth_name == oauth
                    and first_account.account_id == account_id
                ):
                    return owner
            return None

        async def create(self, user: UserDBOAuth) -> UserDBOAuth:
            return user_oauth

        async def update(self, user: UserDBOAuth) -> UserDBOAuth:
            return user_oauth

        async def delete(self, user: UserDBOAuth) -> None:
            pass

    return MockUserDatabase(UserDBOAuth)
@pytest.fixture
def make_user_manager(mocker: MockerFixture):
    """Return a factory building a user manager with ``mocker.spy`` on its flows/hooks.

    ``factory(user_manager_class, mock_user_db)`` instantiates the manager over
    the given database and wraps each method listed below with a spy, so calls
    pass through while remaining assertable.
    """
    spied_methods = (
        "get_by_email",
        "request_verify",
        "verify",
        "forgot_password",
        "reset_password",
        "on_after_register",
        "on_after_request_verify",
        "on_after_verify",
        "on_after_forgot_password",
        "on_after_reset_password",
        "on_after_update",
        "_update",
    )

    def _make_user_manager(user_manager_class: Type[BaseTestUserManager], mock_user_db):
        user_manager = user_manager_class(mock_user_db)
        for method_name in spied_methods:
            mocker.spy(user_manager, method_name)
        return user_manager

    return _make_user_manager
@pytest.fixture
def user_manager(make_user_manager, mock_user_db):
    """Spied ``UserManager`` backed by the non-OAuth mock database."""
    manager = make_user_manager(UserManager, mock_user_db)
    return manager
@pytest.fixture
def user_manager_oauth(make_user_manager, mock_user_db_oauth):
    """Spied ``UserManagerOAuth`` backed by the OAuth mock database."""
    manager = make_user_manager(UserManagerOAuth, mock_user_db_oauth)
    return manager
@pytest.fixture
def get_user_manager(user_manager):
    """Zero-argument callable returning the shared ``user_manager`` fixture instance."""

    def _get_user_manager():
        # Always hand back the very same spied instance.
        return user_manager

    return _get_user_manager
@pytest.fixture
def get_user_manager_oauth(user_manager_oauth):
    """Zero-argument callable returning the shared ``user_manager_oauth`` instance."""

    def _get_user_manager_oauth():
        # Always hand back the very same spied instance.
        return user_manager_oauth

    return _get_user_manager_oauth
class MockTransport(BearerTransport):
    """Bearer transport whose logout response is ``None`` and documents no success body."""

    def __init__(self, tokenUrl: str):
        super().__init__(tokenUrl)

    async def get_logout_response(self, response: Response) -> Any:
        """Produce no logout response."""
        return None

    @staticmethod
    def get_openapi_logout_responses_success() -> OpenAPIResponseType:
        """Advertise no OpenAPI responses for a successful logout."""
        return dict()
class MockStrategy(Strategy[UserCreate, UserDB]):
    """Token strategy whose token is just the user's id rendered with ``str``."""

    async def read_token(
        self, token: Optional[str], user_manager: BaseUserManager[UserCreate, UserDB]
    ) -> Optional[UserDB]:
        """Resolve *token* to a user, or ``None``.

        ``None`` is returned for a missing token, a token that ``UUID4`` cannot
        parse (``ValueError``), or an id the manager reports as ``UserNotExists``.
        """
        if token is None:
            return None
        try:
            user_id = UUID4(token)
            return await user_manager.get(user_id)
        except (ValueError, UserNotExists):
            return None

    async def write_token(self, user: models.UD) -> str:
        """Issue the user's id as the token."""
        return str(user.id)

    async def destroy_token(self, token: str, user: models.UD) -> None:
        """Nothing to revoke: tokens are stateless ids."""
        return None
def get_mock_authentication(name: str):
    """Build an ``AuthenticationBackend`` named *name* using ``MockTransport``/``MockStrategy``.

    The transport's token URL is ``/login``; a fresh ``MockStrategy`` is created
    each time the strategy factory is invoked.
    """

    def _strategy_factory():
        return MockStrategy()

    bearer = MockTransport(tokenUrl="/login")
    return AuthenticationBackend(
        name=name,
        transport=bearer,
        get_strategy=_strategy_factory,
    )
@pytest.fixture
def mock_authentication():
    """Mock authentication backend registered under the name ``"mock"``."""
    backend = get_mock_authentication(name="mock")
    return backend
@pytest.fixture
def get_test_client():
    """Return an async-generator factory yielding an ``httpx.AsyncClient`` for *app*.

    The app's lifespan is run around the client with ``LifespanManager``, and
    requests go to the app in-process with base URL ``http://app.io``.
    """

    async def _get_test_client(app: FastAPI) -> AsyncGenerator[httpx.AsyncClient, None]:
        async with LifespanManager(app):
            # Route requests through an explicit ASGITransport: the
            # ``AsyncClient(app=...)`` shortcut is deprecated in httpx 0.27 and
            # removed in 0.28. It built exactly this transport internally.
            transport = httpx.ASGITransport(app=app)
            async with httpx.AsyncClient(
                transport=transport, base_url="http://app.io"
            ) as test_client:
                yield test_client

    return _get_test_client
@pytest.fixture
def oauth_client() -> OAuth2:
    """Generic ``OAuth2`` client named ``service1`` with fake camelot.bt endpoints."""
    return OAuth2(
        "CLIENT_ID",
        "CLIENT_SECRET",
        "https://www.camelot.bt/authorize",
        "https://www.camelot.bt/access-token",
        name="service1",
    )