mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Move forgot/reset password logic to manager
This commit is contained in:
@ -100,32 +100,9 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
|
||||
after_verification,
|
||||
)
|
||||
|
||||
def get_reset_password_router(
|
||||
self,
|
||||
reset_password_token_secret: SecretType,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
after_forgot_password: Optional[
|
||||
Callable[[models.UD, str, Request], None]
|
||||
] = None,
|
||||
after_reset_password: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Return a reset password process router.
|
||||
|
||||
:param reset_password_token_secret: Secret to encode reset password token.
|
||||
:param reset_password_token_lifetime_seconds: Lifetime of reset password token.
|
||||
:param after_forgot_password: Optional function called after a successful
|
||||
forgot password request.
|
||||
:param after_reset_password: Optional function called after a successful
|
||||
password reset.
|
||||
"""
|
||||
return get_reset_password_router(
|
||||
self.get_user_manager,
|
||||
reset_password_token_secret,
|
||||
reset_password_token_lifetime_seconds,
|
||||
after_forgot_password,
|
||||
after_reset_password,
|
||||
)
|
||||
def get_reset_password_router(self) -> APIRouter:
|
||||
"""Return a reset password process router."""
|
||||
return get_reset_password_router(self.get_user_manager)
|
||||
|
||||
def get_auth_router(
|
||||
self, backend: BaseAuthentication, requires_verification: bool = False
|
||||
|
@ -1,13 +1,17 @@
|
||||
from typing import Any, Callable, Dict, Generic, Optional, Type, Union
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic.types import UUID4
|
||||
from pydantic import UUID4
|
||||
|
||||
from fastapi_users import models, password
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.password import get_password_hash
|
||||
|
||||
RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
|
||||
|
||||
|
||||
class FastAPIUsersException(Exception):
|
||||
pass
|
||||
@ -21,10 +25,18 @@ class UserNotExists(FastAPIUsersException):
|
||||
pass
|
||||
|
||||
|
||||
class UserInactive(FastAPIUsersException):
|
||||
pass
|
||||
|
||||
|
||||
class UserAlreadyVerified(FastAPIUsersException):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidResetPasswordToken(FastAPIUsersException):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidPasswordException(FastAPIUsersException):
|
||||
def __init__(self, reason: Any) -> None:
|
||||
self.reason = reason
|
||||
@ -35,6 +47,10 @@ class BaseUserManager(Generic[models.UC, models.UD]):
|
||||
user_db_model: Type[models.UD]
|
||||
user_db: BaseUserDatabase[models.UD]
|
||||
|
||||
reset_password_token_secret: SecretType
|
||||
reset_password_token_lifetime_seconds: int = 3600
|
||||
reset_password_token_audience: str = RESET_PASSWORD_TOKEN_AUDIENCE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_db_model: Type[models.UD],
|
||||
@ -88,6 +104,53 @@ class BaseUserManager(Generic[models.UC, models.UD]):
|
||||
|
||||
return created_user
|
||||
|
||||
async def forgot_password(
|
||||
self, user: models.UD, request: Optional[Request] = None
|
||||
) -> None:
|
||||
if not user.is_active:
|
||||
raise UserInactive()
|
||||
|
||||
token_data = {"user_id": str(user.id), "aud": RESET_PASSWORD_TOKEN_AUDIENCE}
|
||||
token = generate_jwt(
|
||||
token_data,
|
||||
self.reset_password_token_secret,
|
||||
self.reset_password_token_lifetime_seconds,
|
||||
)
|
||||
await self.on_after_forgot_password(user, token, request)
|
||||
|
||||
async def reset_password(
|
||||
self, token: str, password: str, request: Optional[Request] = None
|
||||
) -> models.UD:
|
||||
try:
|
||||
data = decode_jwt(
|
||||
token,
|
||||
self.reset_password_token_secret,
|
||||
[self.reset_password_token_audience],
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
raise InvalidResetPasswordToken()
|
||||
|
||||
try:
|
||||
user_id = data["user_id"]
|
||||
except KeyError:
|
||||
raise InvalidResetPasswordToken()
|
||||
|
||||
try:
|
||||
user_uuid = UUID4(user_id)
|
||||
except ValueError:
|
||||
raise InvalidResetPasswordToken()
|
||||
|
||||
user = await self.get(user_uuid)
|
||||
|
||||
if not user.is_active:
|
||||
raise UserInactive()
|
||||
|
||||
updated_user = await self._update(user, {"password": password})
|
||||
|
||||
await self.on_after_reset_password(user, request)
|
||||
|
||||
return updated_user
|
||||
|
||||
async def verify(self, user: models.UD) -> models.UD:
|
||||
if user.is_verified:
|
||||
raise UserAlreadyVerified()
|
||||
@ -117,6 +180,16 @@ class BaseUserManager(Generic[models.UC, models.UD]):
|
||||
) -> None:
|
||||
return # pragma: no cover
|
||||
|
||||
async def on_after_forgot_password(
|
||||
self, user: models.UD, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
return # pragma: no cover
|
||||
|
||||
async def on_after_reset_password(
|
||||
self, user: models.UD, request: Optional[Request] = None
|
||||
) -> None:
|
||||
return # pragma: no cover
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[models.UD]:
|
||||
|
@ -1,29 +1,20 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Request, status
|
||||
from pydantic import UUID4, EmailStr
|
||||
from pydantic import EmailStr
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import (
|
||||
BaseUserManager,
|
||||
InvalidPasswordException,
|
||||
InvalidResetPasswordToken,
|
||||
UserInactive,
|
||||
UserManagerDependency,
|
||||
UserNotExists,
|
||||
)
|
||||
from fastapi_users.password import get_password_hash
|
||||
from fastapi_users.router.common import ErrorCode, run_handler
|
||||
|
||||
RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
|
||||
from fastapi_users.router.common import ErrorCode
|
||||
|
||||
|
||||
def get_reset_password_router(
|
||||
get_user_manager: UserManagerDependency[models.UC, models.UD],
|
||||
reset_password_token_secret: SecretType,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
after_forgot_password: Optional[Callable[[models.UD, str, Request], None]] = None,
|
||||
after_reset_password: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
get_user_manager: UserManagerDependency[models.UC, models.UD]
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the reset password routes."""
|
||||
router = APIRouter()
|
||||
@ -39,15 +30,10 @@ def get_reset_password_router(
|
||||
except UserNotExists:
|
||||
return None
|
||||
|
||||
if user.is_active:
|
||||
token_data = {"user_id": str(user.id), "aud": RESET_PASSWORD_TOKEN_AUDIENCE}
|
||||
token = generate_jwt(
|
||||
token_data,
|
||||
reset_password_token_secret,
|
||||
reset_password_token_lifetime_seconds,
|
||||
)
|
||||
if after_forgot_password:
|
||||
await run_handler(after_forgot_password, user, token, request)
|
||||
try:
|
||||
await user_manager.forgot_password(user, request)
|
||||
except UserInactive:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
@ -59,57 +45,19 @@ def get_reset_password_router(
|
||||
user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager),
|
||||
):
|
||||
try:
|
||||
data = decode_jwt(
|
||||
token, reset_password_token_secret, [RESET_PASSWORD_TOKEN_AUDIENCE]
|
||||
)
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN,
|
||||
)
|
||||
|
||||
try:
|
||||
user_uiid = UUID4(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN,
|
||||
)
|
||||
|
||||
try:
|
||||
user = await user_manager.get(user_uiid)
|
||||
except UserNotExists:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN,
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN,
|
||||
)
|
||||
|
||||
try:
|
||||
await user_manager.validate_password(password, user)
|
||||
except InvalidPasswordException as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": ErrorCode.RESET_PASSWORD_INVALID_PASSWORD,
|
||||
"reason": e.reason,
|
||||
},
|
||||
)
|
||||
|
||||
user.hashed_password = get_password_hash(password)
|
||||
await user_manager.user_db.update(user)
|
||||
if after_reset_password:
|
||||
await run_handler(after_reset_password, user, request)
|
||||
except jwt.PyJWTError:
|
||||
await user_manager.reset_password(token, password, request)
|
||||
except (InvalidResetPasswordToken, UserNotExists, UserInactive):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN,
|
||||
)
|
||||
except InvalidPasswordException as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": ErrorCode.RESET_PASSWORD_INVALID_PASSWORD,
|
||||
"reason": e.reason,
|
||||
},
|
||||
)
|
||||
|
||||
return router
|
||||
|
@ -5,12 +5,11 @@ from unittest.mock import MagicMock
|
||||
import httpx
|
||||
import pytest
|
||||
from asgi_lifespan import LifespanManager
|
||||
from fastapi import Depends, FastAPI, Request, Response
|
||||
from fastapi import Depends, FastAPI, Response
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from httpx_oauth.oauth2 import OAuth2
|
||||
from pydantic import UUID4, SecretStr
|
||||
from pytest_mock import MockerFixture
|
||||
from starlette.applications import ASGIApp
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
@ -56,6 +55,8 @@ class UserDBOAuth(UserOAuth, UserDB):
|
||||
|
||||
|
||||
class UserManager(BaseUserManager[UserCreate, UserDB]):
|
||||
reset_password_token_secret = "SECRET"
|
||||
|
||||
async def validate_password(
|
||||
self, password: str, user: Union[UserCreate, UserDB]
|
||||
) -> None:
|
||||
@ -64,14 +65,25 @@ class UserManager(BaseUserManager[UserCreate, UserDB]):
|
||||
reason="Password should be at least 3 characters"
|
||||
)
|
||||
|
||||
async def on_after_register(
|
||||
self, user: UserDB, request: Optional[Request] = None
|
||||
) -> None:
|
||||
return
|
||||
def mock_method(self, name: str):
|
||||
mock = MagicMock()
|
||||
|
||||
future: asyncio.Future = asyncio.Future()
|
||||
future.set_result(None)
|
||||
mock.return_value = future
|
||||
mock.side_effect = None
|
||||
|
||||
setattr(self, name, mock)
|
||||
|
||||
|
||||
class UserManagerMock(UserManager):
|
||||
get_by_email: MagicMock
|
||||
forgot_password: MagicMock
|
||||
reset_password: MagicMock
|
||||
on_after_register: MagicMock
|
||||
on_after_forgot_password: MagicMock
|
||||
on_after_reset_password: MagicMock
|
||||
_update: MagicMock
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -366,14 +378,20 @@ def get_mock_user_db_oauth(mock_user_db_oauth):
|
||||
@pytest.fixture
|
||||
def user_manager(mocker: MockerFixture, mock_user_db):
|
||||
user_manager = UserManager(UserDB, mock_user_db)
|
||||
mocker.spy(user_manager, "get_by_email")
|
||||
mocker.spy(user_manager, "forgot_password")
|
||||
mocker.spy(user_manager, "reset_password")
|
||||
mocker.spy(user_manager, "on_after_register")
|
||||
mocker.spy(user_manager, "on_after_forgot_password")
|
||||
mocker.spy(user_manager, "on_after_reset_password")
|
||||
mocker.spy(user_manager, "_update")
|
||||
return user_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_user_manager(get_mock_user_db):
|
||||
def _get_user_manager(user_db=Depends(get_mock_user_db)):
|
||||
return UserManager(UserDB, user_db)
|
||||
def get_user_manager(user_manager):
|
||||
def _get_user_manager():
|
||||
return user_manager
|
||||
|
||||
return _get_user_manager
|
||||
|
||||
@ -416,7 +434,7 @@ def mock_authentication():
|
||||
|
||||
@pytest.fixture
|
||||
def get_test_client():
|
||||
async def _get_test_client(app: ASGIApp) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
async def _get_test_client(app: FastAPI) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
async with LifespanManager(app):
|
||||
async with httpx.AsyncClient(
|
||||
app=app, base_url="http://app.io"
|
||||
|
@ -28,7 +28,7 @@ async def test_app_client(
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(fastapi_users.get_register_router())
|
||||
app.include_router(fastapi_users.get_reset_password_router(secret))
|
||||
app.include_router(fastapi_users.get_reset_password_router())
|
||||
app.include_router(fastapi_users.get_auth_router(mock_authentication))
|
||||
app.include_router(fastapi_users.get_oauth_router(oauth_client, secret))
|
||||
app.include_router(fastapi_users.get_users_router(), prefix="/users")
|
||||
|
@ -4,10 +4,31 @@ import pytest
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from fastapi_users.manager import UserAlreadyExists, UserAlreadyVerified
|
||||
from fastapi_users.jwt import decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import (
|
||||
InvalidPasswordException,
|
||||
InvalidResetPasswordToken,
|
||||
UserAlreadyExists,
|
||||
UserAlreadyVerified,
|
||||
UserInactive,
|
||||
UserNotExists,
|
||||
)
|
||||
from tests.conftest import UserCreate, UserDB, UserManagerMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def forgot_password_token(user_manager: UserManagerMock):
|
||||
def _forgot_password_token(
|
||||
user_id=None, lifetime=user_manager.reset_password_token_lifetime_seconds
|
||||
):
|
||||
data = {"aud": "fastapi-users:reset"}
|
||||
if user_id is not None:
|
||||
data["user_id"] = str(user_id)
|
||||
return generate_jwt(data, user_manager.reset_password_token_secret, lifetime)
|
||||
|
||||
return _forgot_password_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_oauth2_password_request_form() -> Callable[
|
||||
[str, str], OAuth2PasswordRequestForm
|
||||
@ -77,9 +98,101 @@ class TestVerifyUser:
|
||||
assert user.is_verified
|
||||
|
||||
|
||||
@pytest.mark.db
|
||||
@pytest.mark.asyncio
|
||||
class TestForgotPassword:
|
||||
async def test_user_inactive(
|
||||
self, user_manager: UserManagerMock, inactive_user: UserDB
|
||||
):
|
||||
with pytest.raises(UserInactive):
|
||||
await user_manager.forgot_password(inactive_user)
|
||||
assert user_manager.on_after_forgot_password.called is False
|
||||
|
||||
async def test_user_active(self, user_manager: UserManagerMock, user: UserDB):
|
||||
await user_manager.forgot_password(user)
|
||||
assert user_manager.on_after_forgot_password.called is True
|
||||
|
||||
actual_user = user_manager.on_after_forgot_password.call_args[0][0]
|
||||
actual_token = user_manager.on_after_forgot_password.call_args[0][1]
|
||||
|
||||
assert actual_user.id == user.id
|
||||
decoded_token = decode_jwt(
|
||||
actual_token,
|
||||
user_manager.reset_password_token_secret,
|
||||
audience=[user_manager.reset_password_token_audience],
|
||||
)
|
||||
assert decoded_token["user_id"] == str(user.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestResetPassword:
|
||||
async def test_invalid_token(self, user_manager: UserManagerMock):
|
||||
with pytest.raises(InvalidResetPasswordToken):
|
||||
await user_manager.reset_password("foo", "guinevere")
|
||||
assert user_manager._update.called is False
|
||||
assert user_manager.on_after_reset_password.called is False
|
||||
|
||||
@pytest.mark.parametrize("user_id", [None, "foo"])
|
||||
async def test_valid_token_bad_payload(
|
||||
self, user_id: str, user_manager: UserManagerMock, forgot_password_token
|
||||
):
|
||||
with pytest.raises(InvalidResetPasswordToken):
|
||||
await user_manager.reset_password(
|
||||
forgot_password_token(user_id), "guinevere"
|
||||
)
|
||||
assert user_manager._update.called is False
|
||||
assert user_manager.on_after_reset_password.called is False
|
||||
|
||||
async def test_not_existing_user(
|
||||
self, user_manager: UserManagerMock, forgot_password_token
|
||||
):
|
||||
with pytest.raises(UserNotExists):
|
||||
await user_manager.reset_password(
|
||||
forgot_password_token("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
|
||||
"guinevere",
|
||||
)
|
||||
assert user_manager._update.called is False
|
||||
assert user_manager.on_after_reset_password.called is False
|
||||
|
||||
async def test_inactive_user(
|
||||
self,
|
||||
inactive_user: UserDB,
|
||||
user_manager: UserManagerMock,
|
||||
forgot_password_token,
|
||||
):
|
||||
with pytest.raises(UserInactive):
|
||||
await user_manager.reset_password(
|
||||
forgot_password_token(inactive_user.id),
|
||||
"guinevere",
|
||||
)
|
||||
assert user_manager._update.called is False
|
||||
assert user_manager.on_after_reset_password.called is False
|
||||
|
||||
async def test_invalid_password(
|
||||
self, user: UserDB, user_manager: UserManagerMock, forgot_password_token
|
||||
):
|
||||
with pytest.raises(InvalidPasswordException):
|
||||
await user_manager.reset_password(
|
||||
forgot_password_token(user.id),
|
||||
"h",
|
||||
)
|
||||
assert user_manager.on_after_reset_password.called is False
|
||||
|
||||
async def test_valid_user_password(
|
||||
self, user: UserDB, user_manager: UserManagerMock, forgot_password_token
|
||||
):
|
||||
await user_manager.reset_password(forgot_password_token(user.id), "holygrail")
|
||||
|
||||
assert user_manager._update.called is True
|
||||
update_dict = user_manager._update.call_args[0][1]
|
||||
assert update_dict == {"password": "holygrail"}
|
||||
|
||||
assert user_manager.on_after_reset_password.called is True
|
||||
actual_user = user_manager.on_after_reset_password.call_args[0][0]
|
||||
assert actual_user.id == user.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthenticate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_user(
|
||||
self,
|
||||
create_oauth2_password_request_form: Callable[
|
||||
@ -91,7 +204,6 @@ class TestAuthenticate:
|
||||
user = await user_manager.authenticate(form)
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_password(
|
||||
self,
|
||||
create_oauth2_password_request_form: Callable[
|
||||
@ -103,7 +215,6 @@ class TestAuthenticate:
|
||||
user = await user_manager.authenticate(form)
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_credentials(
|
||||
self,
|
||||
create_oauth2_password_request_form: Callable[
|
||||
@ -118,7 +229,6 @@ class TestAuthenticate:
|
||||
assert user is not None
|
||||
assert user.email == "king.arthur@camelot.bt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_password_hash(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
|
@ -1,71 +1,25 @@
|
||||
from typing import Any, AsyncGenerator, Dict, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import asynctest
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi import FastAPI, status
|
||||
|
||||
from fastapi_users.jwt import decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import (
|
||||
InvalidPasswordException,
|
||||
InvalidResetPasswordToken,
|
||||
UserInactive,
|
||||
UserNotExists,
|
||||
)
|
||||
from fastapi_users.router import ErrorCode, get_reset_password_router
|
||||
from tests.conftest import UserDB
|
||||
|
||||
LIFETIME = 3600
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def forgot_password_token(secret):
|
||||
def _forgot_password_token(user_id=None, lifetime=LIFETIME):
|
||||
data = {"aud": "fastapi-users:reset"}
|
||||
if user_id is not None:
|
||||
data["user_id"] = str(user_id)
|
||||
return generate_jwt(data, secret, lifetime)
|
||||
|
||||
return _forgot_password_token
|
||||
|
||||
|
||||
def after_forgot_password_sync():
|
||||
return MagicMock(return_value=None)
|
||||
|
||||
|
||||
def after_forgot_password_async():
|
||||
return asynctest.CoroutineMock(return_value=None)
|
||||
|
||||
|
||||
@pytest.fixture(params=[after_forgot_password_sync, after_forgot_password_async])
|
||||
def after_forgot_password(request):
|
||||
return request.param()
|
||||
|
||||
|
||||
def after_reset_password_sync():
|
||||
return MagicMock(return_value=None)
|
||||
|
||||
|
||||
def after_reset_password_async():
|
||||
return asynctest.CoroutineMock(return_value=None)
|
||||
|
||||
|
||||
@pytest.fixture(params=[after_reset_password_sync, after_reset_password_async])
|
||||
def after_reset_password(request):
|
||||
return request.param()
|
||||
from tests.conftest import UserManagerMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_client(
|
||||
secret,
|
||||
get_user_manager,
|
||||
after_forgot_password,
|
||||
after_reset_password,
|
||||
get_test_client,
|
||||
get_user_manager, get_test_client
|
||||
) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
reset_router = get_reset_password_router(
|
||||
get_user_manager,
|
||||
secret,
|
||||
LIFETIME,
|
||||
after_forgot_password,
|
||||
after_reset_password,
|
||||
)
|
||||
reset_router = get_reset_password_router(get_user_manager)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(reset_router)
|
||||
@ -78,220 +32,117 @@ async def test_app_client(
|
||||
@pytest.mark.asyncio
|
||||
class TestForgotPassword:
|
||||
async def test_empty_body(
|
||||
self, test_app_client: httpx.AsyncClient, after_forgot_password
|
||||
self, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock
|
||||
):
|
||||
response = await test_app_client.post("/forgot-password", json={})
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
assert after_forgot_password.called is False
|
||||
assert user_manager.forgot_password.called is False
|
||||
|
||||
async def test_not_existing_user(
|
||||
self, test_app_client: httpx.AsyncClient, after_forgot_password
|
||||
self, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock
|
||||
):
|
||||
user_manager.get_by_email.side_effect = UserNotExists()
|
||||
json = {"email": "lancelot@camelot.bt"}
|
||||
response = await test_app_client.post("/forgot-password", json=json)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
assert after_forgot_password.called is False
|
||||
assert user_manager.forgot_password.called is False
|
||||
|
||||
async def test_inactive_user(
|
||||
self, test_app_client: httpx.AsyncClient, after_forgot_password
|
||||
self, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock
|
||||
):
|
||||
user_manager.forgot_password.side_effect = UserInactive()
|
||||
json = {"email": "percival@camelot.bt"}
|
||||
response = await test_app_client.post("/forgot-password", json=json)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
assert after_forgot_password.called is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"]
|
||||
)
|
||||
async def test_existing_user(
|
||||
self,
|
||||
secret,
|
||||
email,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
after_forgot_password,
|
||||
user,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
json = {"email": email}
|
||||
user_manager.mock_method("forgot_password")
|
||||
json = {"email": "king.arthur@camelot.bt"}
|
||||
response = await test_app_client.post("/forgot-password", json=json)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
assert after_forgot_password.called is True
|
||||
|
||||
actual_user = after_forgot_password.call_args[0][0]
|
||||
assert actual_user.id == user.id
|
||||
actual_token = after_forgot_password.call_args[0][1]
|
||||
|
||||
decoded_token = decode_jwt(
|
||||
actual_token,
|
||||
secret,
|
||||
audience=["fastapi-users:reset"],
|
||||
)
|
||||
assert decoded_token["user_id"] == str(user.id)
|
||||
request = after_forgot_password.call_args[0][2]
|
||||
assert isinstance(request, Request)
|
||||
|
||||
|
||||
@pytest.mark.router
|
||||
@pytest.mark.asyncio
|
||||
class TestResetPassword:
|
||||
async def test_empty_body(
|
||||
self, test_app_client: httpx.AsyncClient, after_reset_password
|
||||
self,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
response = await test_app_client.post("/reset-password", json={})
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
assert after_reset_password.called is False
|
||||
assert user_manager.reset_password.called is False
|
||||
|
||||
async def test_missing_token(
|
||||
self, test_app_client: httpx.AsyncClient, after_reset_password
|
||||
self, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock
|
||||
):
|
||||
json = {"password": "guinevere"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
assert after_reset_password.called is False
|
||||
assert user_manager.reset_password.called is False
|
||||
|
||||
async def test_missing_password(
|
||||
self, test_app_client: httpx.AsyncClient, after_reset_password
|
||||
self,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
json = {"token": "foo"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
assert after_reset_password.called is False
|
||||
assert user_manager.reset_password.called is False
|
||||
|
||||
async def test_invalid_token(
|
||||
self, test_app_client: httpx.AsyncClient, after_reset_password
|
||||
self,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
user_manager.reset_password.side_effect = InvalidResetPasswordToken()
|
||||
json = {"token": "foo", "password": "guinevere"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
|
||||
assert after_reset_password.called is False
|
||||
|
||||
async def test_valid_token_missing_user_id_payload(
|
||||
self,
|
||||
mocker,
|
||||
mock_user_db,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
forgot_password_token,
|
||||
after_reset_password,
|
||||
):
|
||||
mocker.spy(mock_user_db, "update")
|
||||
|
||||
json = {"token": forgot_password_token(), "password": "holygrail"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
|
||||
assert mock_user_db.update.called is False
|
||||
assert after_reset_password.called is False
|
||||
|
||||
async def test_valid_token_invalid_uuid(
|
||||
self,
|
||||
mocker,
|
||||
mock_user_db,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
forgot_password_token,
|
||||
after_reset_password,
|
||||
):
|
||||
mocker.spy(mock_user_db, "update")
|
||||
|
||||
json = {"token": forgot_password_token("foo"), "password": "holygrail"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
|
||||
assert mock_user_db.update.called is False
|
||||
assert after_reset_password.called is False
|
||||
|
||||
async def test_valid_token_not_existing_user(
|
||||
self,
|
||||
mocker,
|
||||
mock_user_db,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
forgot_password_token,
|
||||
after_reset_password,
|
||||
):
|
||||
mocker.spy(mock_user_db, "update")
|
||||
|
||||
json = {
|
||||
"token": forgot_password_token("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
|
||||
"password": "holygrail",
|
||||
}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
|
||||
assert mock_user_db.update.called is False
|
||||
assert after_reset_password.called is False
|
||||
|
||||
async def test_inactive_user(
|
||||
self,
|
||||
mocker,
|
||||
mock_user_db,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
forgot_password_token,
|
||||
inactive_user: UserDB,
|
||||
after_reset_password,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
mocker.spy(mock_user_db, "update")
|
||||
|
||||
json = {
|
||||
"token": forgot_password_token(inactive_user.id),
|
||||
"password": "holygrail",
|
||||
}
|
||||
user_manager.reset_password.side_effect = UserInactive()
|
||||
json = {"token": "foo", "password": "guinevere"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
|
||||
assert mock_user_db.update.called is False
|
||||
assert after_reset_password.called is False
|
||||
|
||||
async def test_invalid_password(
|
||||
self,
|
||||
mocker,
|
||||
mock_user_db,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
forgot_password_token,
|
||||
user: UserDB,
|
||||
after_reset_password,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
mocker.spy(mock_user_db, "update")
|
||||
|
||||
json = {
|
||||
"token": forgot_password_token(user.id),
|
||||
"password": "h",
|
||||
}
|
||||
user_manager.reset_password.side_effect = InvalidPasswordException(
|
||||
reason="Invalid"
|
||||
)
|
||||
json = {"token": "foo", "password": "guinevere"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
assert data["detail"] == {
|
||||
"code": ErrorCode.RESET_PASSWORD_INVALID_PASSWORD,
|
||||
"reason": "Password should be at least 3 characters",
|
||||
"reason": "Invalid",
|
||||
}
|
||||
assert mock_user_db.update.called is False
|
||||
assert after_reset_password.called is False
|
||||
|
||||
async def test_existing_user(
|
||||
async def test_valid_user_password(
|
||||
self,
|
||||
mocker,
|
||||
mock_user_db,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
forgot_password_token,
|
||||
user: UserDB,
|
||||
after_reset_password,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
mocker.spy(mock_user_db, "update")
|
||||
current_hashed_password = user.hashed_password
|
||||
|
||||
json = {"token": forgot_password_token(user.id), "password": "holygrail"}
|
||||
user_manager.mock_method("reset_password")
|
||||
json = {"token": "foo", "password": "guinevere"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert mock_user_db.update.called is True
|
||||
|
||||
updated_user = mock_user_db.update.call_args[0][0]
|
||||
assert updated_user.hashed_password != current_hashed_password
|
||||
|
||||
assert after_reset_password.called is True
|
||||
actual_user = after_reset_password.call_args[0][0]
|
||||
assert actual_user.id == updated_user.id
|
||||
request = after_reset_password.call_args[0][1]
|
||||
assert isinstance(request, Request)
|
||||
|
Reference in New Issue
Block a user