From 7f6d038d91fa20c6a899ececf5c5dcfe96b8acd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Mar 2022 14:16:49 +0100 Subject: [PATCH] Create a class helper to let configure the password hasher dynamically --- fastapi_users/manager.py | 27 ++++++++++++++++-------- fastapi_users/password.py | 44 +++++++++++++++++++++++++++++---------- pyproject.toml | 1 + tests/conftest.py | 13 ++++++------ tests/test_manager.py | 4 ++-- 5 files changed, 61 insertions(+), 28 deletions(-) diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index c49e2f9f..f464a284 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -5,10 +5,10 @@ from fastapi import Request from fastapi.security import OAuth2PasswordRequestForm from pydantic import UUID4 -from fastapi_users import models, password +from fastapi_users import models from fastapi_users.db import BaseUserDatabase from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt -from fastapi_users.password import generate_password, get_password_hash +from fastapi_users.password import PasswordHelper, PasswordHelperProtocol from fastapi_users.types import DependencyCallable RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset" @@ -73,9 +73,18 @@ class BaseUserManager(Generic[models.UC, models.UD]): verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE user_db: BaseUserDatabase[models.UD] + password_helper: PasswordHelperProtocol - def __init__(self, user_db: BaseUserDatabase[models.UD]): + def __init__( + self, + user_db: BaseUserDatabase[models.UD], + password_helper: Optional[PasswordHelperProtocol] = None, + ): self.user_db = user_db + if password_helper is None: + self.password_helper = PasswordHelper() + else: + self.password_helper = password_helper # pragma: no cover async def get(self, id: UUID4) -> models.UD: """ @@ -145,7 +154,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): if existing_user is not None: raise UserAlreadyExists() - hashed_password = get_password_hash(user.password) + hashed_password = self.password_helper.hash(user.password) user_dict = ( user.create_update_dict() if safe else user.create_update_dict_superuser() ) @@ -188,10 +197,10 @@ class BaseUserManager(Generic[models.UC, models.UD]): await self.user_db.update(user) except UserNotExists: # Create account - password = generate_password() + password = self.password_helper.generate() user = self.user_db_model( email=oauth_account.account_email, - hashed_password=get_password_hash(password), + hashed_password=self.password_helper.hash(password), oauth_accounts=[oauth_account], ) await self.user_db.create(user) @@ -523,10 +532,10 @@ class BaseUserManager(Generic[models.UC, models.UD]): except UserNotExists: # Run the hasher to mitigate timing attack # Inspired from Django: https://code.djangoproject.com/ticket/20760 - password.get_password_hash(credentials.password) + self.password_helper.hash(credentials.password) return None - verified, updated_password_hash = password.verify_and_update_password( + verified, updated_password_hash = self.password_helper.verify_and_update( credentials.password, user.hashed_password ) if not verified: @@ -549,7 +558,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): user.is_verified = False elif field == "password": await self.validate_password(value, user) - hashed_password = get_password_hash(value) + hashed_password = self.password_helper.hash(value) user.hashed_password = hashed_password else: setattr(user, field, value) diff --git a/fastapi_users/password.py b/fastapi_users/password.py index 31cbce38..438dcda8 100644 --- a/fastapi_users/password.py +++ b/fastapi_users/password.py @@ -1,20 +1,42 @@ -from typing import Tuple +import sys +from typing import Optional, Tuple + +if sys.version_info < (3, 8): + from typing_extensions import Protocol # pragma: no cover +else: + from typing import Protocol # pragma: no cover from passlib import pwd from passlib.context import CryptContext -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +class PasswordHelperProtocol(Protocol): + def verify_and_update( + self, plain_password: str, hashed_password: str + ) -> Tuple[bool, str]: + ... # pragma: no cover + + def hash(self, password: str) -> str: + ... # pragma: no cover + + def generate(self) -> str: + ... # pragma: no cover -def verify_and_update_password( - plain_password: str, hashed_password: str -) -> Tuple[bool, str]: - return pwd_context.verify_and_update(plain_password, hashed_password) +class PasswordHelper(PasswordHelperProtocol): + def __init__(self, context: Optional[CryptContext] = None) -> None: + if context is None: + self.context = CryptContext(schemes=["bcrypt"], deprecated="auto") + else: + self.context = context # pragma: no cover + def verify_and_update( + self, plain_password: str, hashed_password: str + ) -> Tuple[bool, str]: + return self.context.verify_and_update(plain_password, hashed_password) -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) + def hash(self, password: str) -> str: + return self.context.hash(password) - -def generate_password() -> str: - return pwd.genword() + def generate(self) -> str: + return pwd.genword() diff --git a/pyproject.toml b/pyproject.toml index 5d7cfd97..5ddc9fbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ dependencies = [ "pyjwt[crypto] ==2.3.0", "python-multipart ==0.0.5", "makefun >=1.11.2,<1.14", + "typing-extensions >=4.1.1; python_version < '3.8'", ] [project.optional-dependencies] diff --git a/tests/conftest.py b/tests/conftest.py index 96365c6d..d94bc94e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,13 +22,14 @@ from fastapi_users.manager import ( ) from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin from fastapi_users.openapi import OpenAPIResponseType -from fastapi_users.password import get_password_hash +from fastapi_users.password import PasswordHelper -guinevere_password_hash = get_password_hash("guinevere") -angharad_password_hash = get_password_hash("angharad") -viviane_password_hash = get_password_hash("viviane") -lancelot_password_hash = get_password_hash("lancelot") -excalibur_password_hash = get_password_hash("excalibur") +password_helper = PasswordHelper() +guinevere_password_hash = password_helper.hash("guinevere") +angharad_password_hash = password_helper.hash("angharad") +viviane_password_hash = password_helper.hash("viviane") +lancelot_password_hash = password_helper.hash("lancelot") +excalibur_password_hash = password_helper.hash("excalibur") class User(models.BaseUser): diff --git a/tests/test_manager.py b/tests/test_manager.py index ed284b11..97f8cd1b 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -536,8 +536,8 @@ class TestAuthenticate: ], user_manager: UserManagerMock, ): - verify_and_update_password_patch = mocker.patch( - "fastapi_users.password.verify_and_update_password" + verify_and_update_password_patch = mocker.patch.object( + user_manager.password_helper, "verify_and_update" ) verify_and_update_password_patch.return_value = (True, "updated_hash") update_spy = mocker.spy(user_manager.user_db, "update")