Create a class helper to let configure the password hasher dynamically

This commit is contained in:
François Voron
2022-03-22 14:16:49 +01:00
parent 4f5676b979
commit 7f6d038d91
5 changed files with 61 additions and 28 deletions

View File

@ -5,10 +5,10 @@ from fastapi import Request
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import UUID4 from pydantic import UUID4
from fastapi_users import models, password from fastapi_users import models
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.password import generate_password, get_password_hash from fastapi_users.password import PasswordHelper, PasswordHelperProtocol
from fastapi_users.types import DependencyCallable from fastapi_users.types import DependencyCallable
RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset" RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
@ -73,9 +73,18 @@ class BaseUserManager(Generic[models.UC, models.UD]):
verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE
user_db: BaseUserDatabase[models.UD] user_db: BaseUserDatabase[models.UD]
password_helper: PasswordHelperProtocol
def __init__(self, user_db: BaseUserDatabase[models.UD]): def __init__(
self,
user_db: BaseUserDatabase[models.UD],
password_helper: Optional[PasswordHelperProtocol] = None,
):
self.user_db = user_db self.user_db = user_db
if password_helper is None:
self.password_helper = PasswordHelper()
else:
self.password_helper = password_helper # pragma: no cover
async def get(self, id: UUID4) -> models.UD: async def get(self, id: UUID4) -> models.UD:
""" """
@ -145,7 +154,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
if existing_user is not None: if existing_user is not None:
raise UserAlreadyExists() raise UserAlreadyExists()
hashed_password = get_password_hash(user.password) hashed_password = self.password_helper.hash(user.password)
user_dict = ( user_dict = (
user.create_update_dict() if safe else user.create_update_dict_superuser() user.create_update_dict() if safe else user.create_update_dict_superuser()
) )
@ -188,10 +197,10 @@ class BaseUserManager(Generic[models.UC, models.UD]):
await self.user_db.update(user) await self.user_db.update(user)
except UserNotExists: except UserNotExists:
# Create account # Create account
password = generate_password() password = self.password_helper.generate()
user = self.user_db_model( user = self.user_db_model(
email=oauth_account.account_email, email=oauth_account.account_email,
hashed_password=get_password_hash(password), hashed_password=self.password_helper.hash(password),
oauth_accounts=[oauth_account], oauth_accounts=[oauth_account],
) )
await self.user_db.create(user) await self.user_db.create(user)
@ -523,10 +532,10 @@ class BaseUserManager(Generic[models.UC, models.UD]):
except UserNotExists: except UserNotExists:
# Run the hasher to mitigate timing attack # Run the hasher to mitigate timing attack
# Inspired from Django: https://code.djangoproject.com/ticket/20760 # Inspired from Django: https://code.djangoproject.com/ticket/20760
password.get_password_hash(credentials.password) self.password_helper.hash(credentials.password)
return None return None
verified, updated_password_hash = password.verify_and_update_password( verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password credentials.password, user.hashed_password
) )
if not verified: if not verified:
@ -549,7 +558,7 @@ class BaseUserManager(Generic[models.UC, models.UD]):
user.is_verified = False user.is_verified = False
elif field == "password": elif field == "password":
await self.validate_password(value, user) await self.validate_password(value, user)
hashed_password = get_password_hash(value) hashed_password = self.password_helper.hash(value)
user.hashed_password = hashed_password user.hashed_password = hashed_password
else: else:
setattr(user, field, value) setattr(user, field, value)

View File

@ -1,20 +1,42 @@
from typing import Tuple import sys
from typing import Optional, Tuple
if sys.version_info < (3, 8):
from typing_extensions import Protocol # pragma: no cover
else:
from typing import Protocol # pragma: no cover
from passlib import pwd from passlib import pwd
from passlib.context import CryptContext from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class PasswordHelperProtocol(Protocol):
def verify_and_update(
self, plain_password: str, hashed_password: str
) -> Tuple[bool, str]:
... # pragma: no cover
def hash(self, password: str) -> str:
... # pragma: no cover
def generate(self) -> str:
... # pragma: no cover
def verify_and_update_password( class PasswordHelper(PasswordHelperProtocol):
plain_password: str, hashed_password: str def __init__(self, context: Optional[CryptContext] = None) -> None:
) -> Tuple[bool, str]: if context is None:
return pwd_context.verify_and_update(plain_password, hashed_password) self.context = CryptContext(schemes=["bcrypt"], deprecated="auto")
else:
self.context = context # pragma: no cover
def verify_and_update(
self, plain_password: str, hashed_password: str
) -> Tuple[bool, str]:
return self.context.verify_and_update(plain_password, hashed_password)
def get_password_hash(password: str) -> str: def hash(self, password: str) -> str:
return pwd_context.hash(password) return self.context.hash(password)
def generate(self) -> str:
def generate_password() -> str: return pwd.genword()
return pwd.genword()

View File

@ -75,6 +75,7 @@ dependencies = [
"pyjwt[crypto] ==2.3.0", "pyjwt[crypto] ==2.3.0",
"python-multipart ==0.0.5", "python-multipart ==0.0.5",
"makefun >=1.11.2,<1.14", "makefun >=1.11.2,<1.14",
"typing-extensions >=4.1.1; python_version < '3.8'",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@ -22,13 +22,14 @@ from fastapi_users.manager import (
) )
from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin
from fastapi_users.openapi import OpenAPIResponseType from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.password import get_password_hash from fastapi_users.password import PasswordHelper
guinevere_password_hash = get_password_hash("guinevere") password_helper = PasswordHelper()
angharad_password_hash = get_password_hash("angharad") guinevere_password_hash = password_helper.hash("guinevere")
viviane_password_hash = get_password_hash("viviane") angharad_password_hash = password_helper.hash("angharad")
lancelot_password_hash = get_password_hash("lancelot") viviane_password_hash = password_helper.hash("viviane")
excalibur_password_hash = get_password_hash("excalibur") lancelot_password_hash = password_helper.hash("lancelot")
excalibur_password_hash = password_helper.hash("excalibur")
class User(models.BaseUser): class User(models.BaseUser):

View File

@ -536,8 +536,8 @@ class TestAuthenticate:
], ],
user_manager: UserManagerMock, user_manager: UserManagerMock,
): ):
verify_and_update_password_patch = mocker.patch( verify_and_update_password_patch = mocker.patch.object(
"fastapi_users.password.verify_and_update_password" user_manager.password_helper, "verify_and_update"
) )
verify_and_update_password_patch.return_value = (True, "updated_hash") verify_and_update_password_patch.return_value = (True, "updated_hash")
update_spy = mocker.spy(user_manager.user_db, "update") update_spy = mocker.spy(user_manager.user_db, "update")