Refactor OAuth logic into manager

This commit is contained in:
François Voron
2021-09-15 11:57:17 +02:00
parent 9673e0a5fd
commit 3bdae94869
7 changed files with 293 additions and 303 deletions

View File

@ -126,7 +126,6 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
oauth_client: BaseOAuth2,
state_secret: SecretType,
redirect_url: str = None,
after_register: Optional[Callable[[models.UD, Request], None]] = None,
) -> APIRouter:
"""
Return an OAuth router for a given OAuth client.
@ -135,17 +134,13 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
:param state_secret: Secret used to encode the state JWT.
:param redirect_url: Optional arbitrary redirect URL for the OAuth2 flow.
If not given, the URL to the callback endpoint will be generated.
:param after_register: Optional function called
after a successful registration.
"""
return get_oauth_router(
oauth_client,
self.get_user_manager,
self._user_db_model,
self.authenticator,
state_secret,
redirect_url,
after_register,
)
def get_users_router(

View File

@ -8,7 +8,7 @@ from pydantic import UUID4
from fastapi_users import models, password
from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.password import get_password_hash
from fastapi_users.password import generate_password, get_password_hash
RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
@ -104,6 +104,42 @@ class BaseUserManager(Generic[models.UC, models.UD]):
return created_user
async def oauth_callback(
self, oauth_account: models.BaseOAuthAccount, request: Optional[Request] = None
) -> models.UD:
try:
user = await self.get_by_oauth_account(
oauth_account.oauth_name, oauth_account.account_id
)
except UserNotExists:
try:
# Link account
user = await self.get_by_email(oauth_account.account_email)
user.oauth_accounts.append(oauth_account) # type: ignore
await self.user_db.update(user)
except UserNotExists:
# Create account
password = generate_password()
user = self.user_db_model(
email=oauth_account.account_email,
hashed_password=get_password_hash(password),
oauth_accounts=[oauth_account],
)
await self.user_db.create(user)
await self.on_after_register(user, request)
else:
# Update oauth
updated_oauth_accounts = []
for existing_oauth_account in user.oauth_accounts: # type: ignore
if existing_oauth_account.account_id == oauth_account.account_id:
updated_oauth_accounts.append(oauth_account)
else:
updated_oauth_accounts.append(existing_oauth_account)
user.oauth_accounts = updated_oauth_accounts # type: ignore
await self.user_db.update(user)
return user
async def forgot_password(
self, user: models.UD, request: Optional[Request] = None
) -> None:

View File

@ -1,4 +1,4 @@
from typing import Callable, Dict, List, Optional, Type
from typing import Dict, List
import jwt
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
@ -8,9 +8,8 @@ from httpx_oauth.oauth2 import BaseOAuth2
from fastapi_users import models
from fastapi_users.authentication import Authenticator
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.manager import BaseUserManager, UserManagerDependency, UserNotExists
from fastapi_users.password import generate_password, get_password_hash
from fastapi_users.router.common import ErrorCode, run_handler
from fastapi_users.manager import BaseUserManager, UserManagerDependency
from fastapi_users.router.common import ErrorCode
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
@ -25,11 +24,9 @@ def generate_state_token(
def get_oauth_router(
oauth_client: BaseOAuth2,
get_user_manager: UserManagerDependency[models.UC, models.UD],
user_db_model: Type[models.UD],
authenticator: Authenticator,
state_secret: SecretType,
redirect_url: str = None,
after_register: Optional[Callable[[models.UD, Request], None]] = None,
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
@ -104,37 +101,7 @@ def get_oauth_router(
account_email=account_email,
)
try:
user = await user_manager.get_by_oauth_account(
oauth_client.name, account_id
)
except UserNotExists:
try:
# Link account
user = await user_manager.get_by_email(account_email)
user.oauth_accounts.append(new_oauth_account) # type: ignore
await user_manager.user_db.update(user)
except UserNotExists:
# Create account
password = generate_password()
user = user_db_model(
email=account_email,
hashed_password=get_password_hash(password),
oauth_accounts=[new_oauth_account],
)
await user_manager.user_db.create(user)
if after_register:
await run_handler(after_register, user, request)
else:
# Update oauth
updated_oauth_accounts = []
for oauth_account in user.oauth_accounts: # type: ignore
if oauth_account.account_id == account_id:
updated_oauth_accounts.append(new_oauth_account)
else:
updated_oauth_accounts.append(oauth_account)
user.oauth_accounts = updated_oauth_accounts # type: ignore
await user_manager.user_db.update(user)
user = await user_manager.oauth_callback(new_oauth_account, request)
if not user.is_active:
raise HTTPException(

View File

@ -1,5 +1,5 @@
import asyncio
from typing import AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, Callable, List, Optional, Union
from unittest.mock import MagicMock
import httpx
@ -65,16 +65,6 @@ class UserManager(BaseUserManager[UserCreate, UserDB]):
reason="Password should be at least 3 characters"
)
def mock_method(self, name: str):
mock = MagicMock()
future: asyncio.Future = asyncio.Future()
future.set_result(None)
mock.return_value = future
mock.side_effect = None
setattr(self, name, mock)
class UserManagerMock(UserManager):
get_by_email: MagicMock
@ -94,6 +84,30 @@ def event_loop():
yield loop
AsyncMethodMocker = Callable[..., MagicMock]
@pytest.fixture
def async_method_mocker(mocker: MockerFixture) -> AsyncMethodMocker:
def _async_method_mocker(
object: Any,
method: str,
return_value: Any = None,
) -> MagicMock:
mock: MagicMock = mocker.MagicMock()
future: asyncio.Future = asyncio.Future()
future.set_result(return_value)
mock.return_value = future
mock.side_effect = None
setattr(object, method, mock)
return mock
return _async_method_mocker
@pytest.fixture(params=["SECRET", SecretStr("SECRET")])
def secret(request) -> SecretType:
return request.param
@ -361,33 +375,30 @@ def mock_user_db_oauth(
@pytest.fixture
def get_mock_user_db(mock_user_db):
def _get_mock_user_db():
yield mock_user_db
def make_user_manager(mocker: MockerFixture):
def _make_user_manager(user_db_model, mock_user_db):
user_manager = UserManager(user_db_model, mock_user_db)
mocker.spy(user_manager, "get_by_email")
mocker.spy(user_manager, "forgot_password")
mocker.spy(user_manager, "reset_password")
mocker.spy(user_manager, "on_after_register")
mocker.spy(user_manager, "on_after_update")
mocker.spy(user_manager, "on_after_forgot_password")
mocker.spy(user_manager, "on_after_reset_password")
mocker.spy(user_manager, "_update")
return user_manager
return _get_mock_user_db
return _make_user_manager
@pytest.fixture
def get_mock_user_db_oauth(mock_user_db_oauth):
def _get_mock_user_db_oauth():
yield mock_user_db_oauth
return _get_mock_user_db_oauth
def user_manager(make_user_manager, mock_user_db):
return make_user_manager(UserDB, mock_user_db)
@pytest.fixture
def user_manager(mocker: MockerFixture, mock_user_db):
user_manager = UserManager(UserDB, mock_user_db)
mocker.spy(user_manager, "get_by_email")
mocker.spy(user_manager, "forgot_password")
mocker.spy(user_manager, "reset_password")
mocker.spy(user_manager, "on_after_register")
mocker.spy(user_manager, "on_after_update")
mocker.spy(user_manager, "on_after_forgot_password")
mocker.spy(user_manager, "on_after_reset_password")
mocker.spy(user_manager, "_update")
return user_manager
def user_manager_oauth(make_user_manager, mock_user_db_oauth):
return make_user_manager(UserDBOAuth, mock_user_db_oauth)
@pytest.fixture
@ -399,9 +410,9 @@ def get_user_manager(user_manager):
@pytest.fixture
def get_user_manager_oauth(get_mock_user_db_oauth):
def _get_user_manager_oauth(user_db=Depends(get_mock_user_db_oauth)):
return UserManager(UserDBOAuth, user_db)
def get_user_manager_oauth(user_manager_oauth):
def _get_user_manager_oauth():
return user_manager_oauth
return _get_user_manager_oauth

View File

@ -1,9 +1,10 @@
from typing import Callable
from typing import Callable, cast
import pytest
from fastapi.security import OAuth2PasswordRequestForm
from pytest_mock import MockerFixture
from fastapi_users import models
from fastapi_users.jwt import decode_jwt, generate_jwt
from fastapi_users.manager import (
InvalidPasswordException,
@ -13,7 +14,7 @@ from fastapi_users.manager import (
UserInactive,
UserNotExists,
)
from tests.conftest import UserCreate, UserDB, UserManagerMock, UserUpdate
from tests.conftest import UserCreate, UserDB, UserDBOAuth, UserManagerMock, UserUpdate
@pytest.fixture
@ -85,6 +86,61 @@ class TestCreateUser:
assert user_manager.on_after_register.called is True
@pytest.mark.asyncio
class TestOAuthCallback:
async def test_existing_user_with_oauth(
self, user_manager_oauth: UserManagerMock, user_oauth: UserDBOAuth
):
oauth_account = models.BaseOAuthAccount(
**user_oauth.oauth_accounts[0].dict(exclude={"id", "access_token"}),
access_token="UPDATED_TOKEN"
)
user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account))
assert user.id == user_oauth.id
assert len(user.oauth_accounts) == 2
assert user.oauth_accounts[0].oauth_name == "service1"
assert user.oauth_accounts[0].access_token == "UPDATED_TOKEN"
assert user.oauth_accounts[1].access_token == "TOKEN"
assert user.oauth_accounts[1].oauth_name == "service2"
assert user_manager_oauth.on_after_register.called is False
async def test_existing_user_without_oauth(
self, user_manager_oauth: UserManagerMock, superuser_oauth: UserDBOAuth
):
oauth_account = models.BaseOAuthAccount(
oauth_name="service1",
access_token="TOKEN",
expires_at=1579000751,
account_id="superuser_oauth1",
account_email=superuser_oauth.email,
)
user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account))
assert user.id == superuser_oauth.id
assert len(user.oauth_accounts) == 1
assert user.oauth_accounts[0].id == oauth_account.id
assert user_manager_oauth.on_after_register.called is False
async def test_new_user(self, user_manager_oauth: UserManagerMock):
oauth_account = models.BaseOAuthAccount(
oauth_name="service1",
access_token="TOKEN",
expires_at=1579000751,
account_id="new_user_oauth1",
account_email="galahad@camelot.bt",
)
user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account))
assert user.email == "galahad@camelot.bt"
assert len(user.oauth_accounts) == 1
assert user.oauth_accounts[0].id == oauth_account.id
assert user_manager_oauth.on_after_register.called is True
@pytest.mark.asyncio
class TestUpdateUser:
async def test_safe_update(self, user: UserDB, user_manager: UserManagerMock):

View File

@ -1,28 +1,18 @@
from typing import Any, AsyncGenerator, Dict, cast
from unittest.mock import MagicMock
import asynctest
import httpx
import pytest
from fastapi import FastAPI, Request, status
from fastapi import FastAPI, status
from httpx_oauth.oauth2 import BaseOAuth2
from fastapi_users.authentication import Authenticator
from fastapi_users.router.common import ErrorCode
from fastapi_users.router.oauth import generate_state_token, get_oauth_router
from tests.conftest import MockAuthentication, UserDB
def after_register_sync():
return MagicMock(return_value=None)
def after_register_async():
return asynctest.CoroutineMock(return_value=None)
@pytest.fixture(params=[after_register_sync, after_register_async])
def after_register(request):
return request.param()
from tests.conftest import (
AsyncMethodMocker,
MockAuthentication,
UserDB,
UserManagerMock,
)
@pytest.fixture
@ -31,7 +21,6 @@ def get_test_app_client(
get_user_manager_oauth,
mock_authentication,
oauth_client,
after_register,
get_test_client,
):
async def _get_test_app_client(
@ -45,11 +34,9 @@ def get_test_app_client(
oauth_router = get_oauth_router(
oauth_client,
get_user_manager_oauth,
UserDB,
authenticator,
secret,
redirect_url,
after_register,
)
app = FastAPI()
@ -80,64 +67,86 @@ async def test_app_client_redirect_url(get_test_app_client):
@pytest.mark.asyncio
class TestAuthorize:
async def test_missing_authentication_backend(
self, test_app_client: httpx.AsyncClient, oauth_client
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
oauth_client: BaseOAuth2,
):
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
mock.return_value = "AUTHORIZATION_URL"
response = await test_app_client.get(
"/authorize",
params={"scopes": ["scope1", "scope2"]},
)
async_method_mocker(
oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
)
response = await test_app_client.get(
"/authorize",
params={"scopes": ["scope1", "scope2"]},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
async def test_wrong_authentication_backend(
self, test_app_client: httpx.AsyncClient, oauth_client
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
oauth_client: BaseOAuth2,
):
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
mock.return_value = "AUTHORIZATION_URL"
response = await test_app_client.get(
"/authorize",
params={
"authentication_backend": "foo",
"scopes": ["scope1", "scope2"],
},
)
async_method_mocker(
oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
)
response = await test_app_client.get(
"/authorize",
params={
"authentication_backend": "foo",
"scopes": ["scope1", "scope2"],
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
async def test_success(self, test_app_client: httpx.AsyncClient, oauth_client):
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
mock.return_value = "AUTHORIZATION_URL"
response = await test_app_client.get(
"/authorize",
params={
"authentication_backend": "mock",
"scopes": ["scope1", "scope2"],
},
)
async def test_success(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
oauth_client: BaseOAuth2,
):
get_authorization_url_mock = async_method_mocker(
oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
)
response = await test_app_client.get(
"/authorize",
params={
"authentication_backend": "mock",
"scopes": ["scope1", "scope2"],
},
)
assert response.status_code == status.HTTP_200_OK
mock.assert_awaited_once()
get_authorization_url_mock.assert_called_once()
data = response.json()
assert "authorization_url" in data
async def test_with_redirect_url(
self, test_app_client_redirect_url: httpx.AsyncClient, oauth_client
self,
async_method_mocker: AsyncMethodMocker,
test_app_client_redirect_url: httpx.AsyncClient,
oauth_client: BaseOAuth2,
):
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
mock.return_value = "AUTHORIZATION_URL"
response = await test_app_client_redirect_url.get(
"/authorize",
params={
"authentication_backend": "mock",
"scopes": ["scope1", "scope2"],
},
)
get_authorization_url_mock = async_method_mocker(
oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
)
response = await test_app_client_redirect_url.get(
"/authorize",
params={
"authentication_backend": "mock",
"scopes": ["scope1", "scope2"],
},
)
assert response.status_code == status.HTTP_200_OK
mock.assert_awaited_once()
get_authorization_url_mock.assert_called_once()
data = response.json()
assert "authorization_url" in data
@ -156,196 +165,110 @@ class TestAuthorize:
class TestCallback:
async def test_invalid_state(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
access_token,
oauth_client,
user_oauth,
after_register,
oauth_client: BaseOAuth2,
user_oauth: UserDB,
access_token: str,
):
with asynctest.patch.object(
oauth_client, "get_access_token"
) as get_access_token_mock:
get_access_token_mock.return_value = access_token
with asynctest.patch.object(
oauth_client, "get_id_email"
) as get_id_email_mock:
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
response = await test_app_client.get(
"/callback",
params={"code": "CODE", "state": "STATE"},
)
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
get_id_email_mock = async_method_mocker(
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
)
get_id_email_mock.assert_awaited_once_with("TOKEN")
response = await test_app_client.get(
"/callback",
params={"code": "CODE", "state": "STATE"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert after_register.called is False
get_id_email_mock.assert_called_once_with("TOKEN")
async def test_existing_user_with_oauth(
async def test_active_user(
self,
mock_user_db_oauth,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
access_token,
oauth_client,
user_oauth,
after_register,
oauth_client: BaseOAuth2,
user_oauth: UserDB,
user_manager_oauth: UserManagerMock,
access_token: str,
):
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
with asynctest.patch.object(
oauth_client, "get_access_token"
) as get_access_token_mock:
get_access_token_mock.return_value = access_token
with asynctest.patch.object(
oauth_client, "get_id_email"
) as get_id_email_mock:
with asynctest.patch.object(
mock_user_db_oauth, "update"
) as user_update_mock:
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
response = await test_app_client.get(
"/callback",
params={"code": "CODE", "state": state_jwt},
)
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
async_method_mocker(
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
)
async_method_mocker(
user_manager_oauth, "oauth_callback", return_value=user_oauth
)
response = await test_app_client.get(
"/callback",
params={"code": "CODE", "state": state_jwt},
)
assert response.status_code == status.HTTP_200_OK
get_id_email_mock.assert_awaited_once_with("TOKEN")
user_update_mock.assert_awaited_once()
data = cast(Dict[str, Any], response.json())
assert data["token"] == str(user_oauth.id)
assert after_register.called is False
async def test_existing_user_without_oauth(
self,
mock_user_db_oauth,
test_app_client: httpx.AsyncClient,
access_token,
oauth_client,
superuser_oauth,
after_register,
):
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
with asynctest.patch.object(
oauth_client, "get_access_token"
) as get_access_token_mock:
get_access_token_mock.return_value = access_token
with asynctest.patch.object(
oauth_client, "get_id_email"
) as get_id_email_mock:
with asynctest.patch.object(
mock_user_db_oauth, "update"
) as user_update_mock:
get_id_email_mock.return_value = (
"superuser_oauth1",
superuser_oauth.email,
)
response = await test_app_client.get(
"/callback",
params={"code": "CODE", "state": state_jwt},
)
get_id_email_mock.assert_awaited_once_with("TOKEN")
user_update_mock.assert_awaited_once()
data = cast(Dict[str, Any], response.json())
assert data["token"] == str(superuser_oauth.id)
assert after_register.called is False
async def test_unknown_user(
self,
mock_user_db_oauth,
test_app_client: httpx.AsyncClient,
access_token,
oauth_client,
after_register,
):
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
with asynctest.patch.object(
oauth_client, "get_access_token"
) as get_access_token_mock:
get_access_token_mock.return_value = access_token
with asynctest.patch.object(
oauth_client, "get_id_email"
) as get_id_email_mock:
with asynctest.patch.object(
mock_user_db_oauth, "create"
) as user_create_mock:
get_id_email_mock.return_value = (
"unknown_user_oauth1",
"galahad@camelot.bt",
)
response = await test_app_client.get(
"/callback",
params={"code": "CODE", "state": state_jwt},
)
get_id_email_mock.assert_awaited_once_with("TOKEN")
user_create_mock.assert_awaited_once()
data = cast(Dict[str, Any], response.json())
assert "token" in data
assert after_register.called is True
actual_user = after_register.call_args[0][0]
assert str(actual_user.id) == data["token"]
request = after_register.call_args[0][1]
assert isinstance(request, Request)
async def test_inactive_user(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
access_token,
oauth_client,
inactive_user_oauth,
after_register,
oauth_client: BaseOAuth2,
inactive_user_oauth: UserDB,
user_manager_oauth: UserManagerMock,
access_token: str,
):
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
with asynctest.patch.object(
oauth_client, "get_access_token"
) as get_access_token_mock:
get_access_token_mock.return_value = access_token
with asynctest.patch.object(
oauth_client, "get_id_email"
) as get_id_email_mock:
get_id_email_mock.return_value = (
"inactive_user_oauth1",
inactive_user_oauth.email,
)
response = await test_app_client.get(
"/callback",
params={"code": "CODE", "state": state_jwt},
)
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
async_method_mocker(
oauth_client,
"get_id_email",
return_value=("user_oauth1", inactive_user_oauth.email),
)
async_method_mocker(
user_manager_oauth, "oauth_callback", return_value=inactive_user_oauth
)
response = await test_app_client.get(
"/callback",
params={"code": "CODE", "state": state_jwt},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = cast(Dict[str, Any], response.json())
assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
assert after_register.called is False
async def test_redirect_url_router(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client_redirect_url: httpx.AsyncClient,
access_token,
oauth_client,
user_oauth,
oauth_client: BaseOAuth2,
user_oauth: UserDB,
user_manager_oauth: UserManagerMock,
access_token: str,
):
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
with asynctest.patch.object(
oauth_client, "get_access_token"
) as get_access_token_mock:
get_access_token_mock.return_value = access_token
with asynctest.patch.object(
oauth_client, "get_id_email"
) as get_id_email_mock:
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
response = await test_app_client_redirect_url.get(
"/callback",
params={"code": "CODE", "state": state_jwt},
)
get_access_token_mock = async_method_mocker(
oauth_client, "get_access_token", return_value=access_token
)
async_method_mocker(
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
)
async_method_mocker(
user_manager_oauth, "oauth_callback", return_value=user_oauth
)
get_access_token_mock.assert_awaited_once_with(
response = await test_app_client_redirect_url.get(
"/callback",
params={"code": "CODE", "state": state_jwt},
)
assert response.status_code == status.HTTP_200_OK
get_access_token_mock.assert_called_once_with(
"CODE", "http://www.tintagel.bt/callback"
)
data = cast(Dict[str, Any], response.json())
data = cast(Dict[str, Any], response.json())
assert data["token"] == str(user_oauth.id)

View File

@ -11,7 +11,7 @@ from fastapi_users.manager import (
UserNotExists,
)
from fastapi_users.router import ErrorCode, get_reset_password_router
from tests.conftest import UserManagerMock
from tests.conftest import AsyncMethodMocker, UserManagerMock
@pytest.fixture
@ -57,10 +57,11 @@ class TestForgotPassword:
async def test_existing_user(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
user_manager: UserManagerMock,
):
user_manager.mock_method("forgot_password")
async_method_mocker(user_manager, "forgot_password", return_value=None)
json = {"email": "king.arthur@camelot.bt"}
response = await test_app_client.post("/forgot-password", json=json)
assert response.status_code == status.HTTP_202_ACCEPTED
@ -139,10 +140,11 @@ class TestResetPassword:
async def test_valid_user_password(
self,
async_method_mocker: AsyncMethodMocker,
test_app_client: httpx.AsyncClient,
user_manager: UserManagerMock,
):
user_manager.mock_method("reset_password")
async_method_mocker(user_manager, "reset_password", return_value=None)
json = {"token": "foo", "password": "guinevere"}
response = await test_app_client.post("/reset-password", json=json)
assert response.status_code == status.HTTP_200_OK