diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index e7964ad1..57296bc0 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -3,7 +3,7 @@ from typing import List from fastapi.security import OAuth2PasswordRequestForm from fastapi_users.models import UserDB -from fastapi_users.password import get_password_hash, verify_password +from fastapi_users.password import get_password_hash, verify_and_update_password class UserDBInterface: @@ -21,6 +21,9 @@ class UserDBInterface: async def create(self, user: UserDB) -> UserDB: raise NotImplementedError() + async def update(self, user: UserDB) -> UserDB: + raise NotImplementedError() + async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB: user = await self.get_by_email(credentials.username) @@ -28,7 +31,15 @@ class UserDBInterface: # Inspired from Django: https://code.djangoproject.com/ticket/20760 get_password_hash(credentials.password) - if user is None or not verify_password(credentials.password, user.hashed_password): + if user is None: return None + else: + verified, updated_password_hash = verify_and_update_password(credentials.password, user.hashed_password) + if not verified: + return None + # Update password hash to a more robust one if needed + if updated_password_hash is not None: + user.hashed_password = updated_password_hash + await self.update(user) return user diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py index f6b42953..dc465793 100644 --- a/fastapi_users/db/sqlalchemy.py +++ b/fastapi_users/db/sqlalchemy.py @@ -42,3 +42,8 @@ class SQLAlchemyUserDB(UserDBInterface): query = users.insert().values(**user.dict()) await self.database.execute(query) return user + + async def update(self, user: UserDB) -> UserDB: + query = users.update().where(User.id == user.id).values(**user.dict()) + await self.database.execute(query) + return user diff --git a/fastapi_users/password.py b/fastapi_users/password.py index d627fd48..b2fd200b 100644 --- a/fastapi_users/password.py +++ b/fastapi_users/password.py @@ -1,11 +1,13 @@ +from typing import Tuple + from passlib.context import CryptContext pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto') -def verify_password(plain_password: str, hashed_password: str): - return pwd_context.verify(plain_password, hashed_password) +def verify_and_update_password(plain_password: str, hashed_password: str) -> Tuple[bool, str]: + return pwd_context.verify_and_update(plain_password, hashed_password) -def get_password_hash(password: str): +def get_password_hash(password: str) -> str: return pwd_context.hash(password) diff --git a/tests/test_db_sqlalchemy.py b/tests/test_db_sqlalchemy.py index d16c128c..c92e2cb6 100644 --- a/tests/test_db_sqlalchemy.py +++ b/tests/test_db_sqlalchemy.py @@ -33,16 +33,21 @@ async def test_queries(user, sqlalchemy_user_db): assert user_db.is_superuser is False assert user_db.email == user.email + # Update + user_db.is_superuser = True + await sqlalchemy_user_db.update(user_db) + + # Get by email + email_user = await sqlalchemy_user_db.get_by_email(user.email) + assert email_user.id == user_db.id + assert email_user.is_superuser is True + # List users = await sqlalchemy_user_db.list() assert len(users) == 1 first_user = users[0] assert first_user.id == user_db.id - # Get by email - email_user = await sqlalchemy_user_db.get_by_email(user.email) - assert email_user.id == user_db.id - # Exception when inserting existing email with pytest.raises(sqlite3.IntegrityError): await sqlalchemy_user_db.create(user)