mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2026-03-13 07:49:55 +08:00
Move validate_password into UserManager
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
import asynctest
|
||||
import httpx
|
||||
import pytest
|
||||
from asgi_lifespan import LifespanManager
|
||||
@@ -15,13 +14,10 @@ from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType
|
||||
from fastapi_users.manager import (
|
||||
InvalidPasswordException,
|
||||
UserManager,
|
||||
UserNotExists,
|
||||
ValidatePasswordProtocol,
|
||||
)
|
||||
from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin, BaseUserDB
|
||||
from fastapi_users.manager import InvalidPasswordException
|
||||
from fastapi_users.manager import UserManager as BaseUserManager
|
||||
from fastapi_users.manager import UserNotExists
|
||||
from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin
|
||||
from fastapi_users.password import get_password_hash
|
||||
|
||||
guinevere_password_hash = get_password_hash("guinevere")
|
||||
@@ -55,6 +51,16 @@ class UserDBOAuth(UserOAuth, UserDB):
|
||||
pass
|
||||
|
||||
|
||||
class UserManager(BaseUserManager[UserCreate, UserDB]):
|
||||
async def validate_password(
|
||||
self, password: str, user: Union[UserCreate, UserDB]
|
||||
) -> None:
|
||||
if len(password) < 3:
|
||||
raise InvalidPasswordException(
|
||||
reason="Password should be at least 3 characters"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Force the pytest-asyncio loop to be the main one."""
|
||||
@@ -218,17 +224,6 @@ def oauth_account5() -> BaseOAuthAccount:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def validate_password() -> ValidatePasswordProtocol:
|
||||
async def _validate_password(password: str, user: models.UD) -> None:
|
||||
if len(password) < 3:
|
||||
raise InvalidPasswordException(
|
||||
reason="Password should be at least 3 characters"
|
||||
)
|
||||
|
||||
return asynctest.CoroutineMock(wraps=_validate_password)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_db(
|
||||
user, verified_user, inactive_user, superuser, verified_superuser
|
||||
@@ -356,32 +351,32 @@ def get_mock_user_db_oauth(mock_user_db_oauth):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_manager(mock_user_db, validate_password):
|
||||
return UserManager(UserDB, mock_user_db, validate_password)
|
||||
def user_manager(mock_user_db):
|
||||
return UserManager(UserDB, mock_user_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_user_manager(get_mock_user_db, validate_password):
|
||||
def get_user_manager(get_mock_user_db):
|
||||
def _get_user_manager(user_db=Depends(get_mock_user_db)):
|
||||
return UserManager(UserDB, user_db, validate_password)
|
||||
return UserManager(UserDB, user_db)
|
||||
|
||||
return _get_user_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_user_manager_oauth(get_mock_user_db_oauth, validate_password):
|
||||
def get_user_manager_oauth(get_mock_user_db_oauth):
|
||||
def _get_user_manager_oauth(user_db=Depends(get_mock_user_db_oauth)):
|
||||
return UserManager(UserDBOAuth, user_db, validate_password)
|
||||
return UserManager(UserDBOAuth, user_db)
|
||||
|
||||
return _get_user_manager_oauth
|
||||
|
||||
|
||||
class MockAuthentication(BaseAuthentication[str]):
|
||||
class MockAuthentication(BaseAuthentication[str, UserCreate, UserDB]):
|
||||
def __init__(self, name: str = "mock"):
|
||||
super().__init__(name, logout=True)
|
||||
self.scheme = OAuth2PasswordBearer("/login", auto_error=False)
|
||||
|
||||
async def __call__(self, credentials: Optional[str], user_manager: UserManager):
|
||||
async def __call__(self, credentials: Optional[str], user_manager: BaseUserManager):
|
||||
if credentials is not None:
|
||||
try:
|
||||
token_uuid = UUID4(credentials)
|
||||
@@ -392,10 +387,10 @@ class MockAuthentication(BaseAuthentication[str]):
|
||||
return None
|
||||
return None
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
return {"token": user.id}
|
||||
|
||||
async def get_logout_response(self, user: BaseUserDB, response: Response):
|
||||
async def get_logout_response(self, user: UserDB, response: Response):
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user