mirror of
				https://github.com/fastapi-users/fastapi-users.git
				synced 2025-11-04 06:37:51 +08:00 
			
		
		
		
	Implement password hash migration
This commit is contained in:
		@ -3,7 +3,7 @@ from typing import List
 | 
				
			|||||||
from fastapi.security import OAuth2PasswordRequestForm
 | 
					from fastapi.security import OAuth2PasswordRequestForm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from fastapi_users.models import UserDB
 | 
					from fastapi_users.models import UserDB
 | 
				
			||||||
from fastapi_users.password import get_password_hash, verify_password
 | 
					from fastapi_users.password import get_password_hash, verify_and_update_password
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class UserDBInterface:
 | 
					class UserDBInterface:
 | 
				
			||||||
@ -21,6 +21,9 @@ class UserDBInterface:
 | 
				
			|||||||
    async def create(self, user: UserDB) -> UserDB:
 | 
					    async def create(self, user: UserDB) -> UserDB:
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def update(self, user: UserDB) -> UserDB:
 | 
				
			||||||
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB:
 | 
					    async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB:
 | 
				
			||||||
        user = await self.get_by_email(credentials.username)
 | 
					        user = await self.get_by_email(credentials.username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -28,7 +31,15 @@ class UserDBInterface:
 | 
				
			|||||||
        # Inspired from Django: https://code.djangoproject.com/ticket/20760
 | 
					        # Inspired from Django: https://code.djangoproject.com/ticket/20760
 | 
				
			||||||
        get_password_hash(credentials.password)
 | 
					        get_password_hash(credentials.password)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if user is None or not verify_password(credentials.password, user.hashed_password):
 | 
					        if user is None:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            verified, updated_password_hash = verify_and_update_password(credentials.password, user.hashed_password)
 | 
				
			||||||
 | 
					            if not verified:
 | 
				
			||||||
 | 
					                return None
 | 
				
			||||||
 | 
					            # Update password hash to a more robust one if needed
 | 
				
			||||||
 | 
					            if updated_password_hash is not None:
 | 
				
			||||||
 | 
					                user.hashed_password = updated_password_hash
 | 
				
			||||||
 | 
					                await self.update(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return user
 | 
					        return user
 | 
				
			||||||
 | 
				
			|||||||
@ -42,3 +42,8 @@ class SQLAlchemyUserDB(UserDBInterface):
 | 
				
			|||||||
        query = users.insert().values(**user.dict())
 | 
					        query = users.insert().values(**user.dict())
 | 
				
			||||||
        await self.database.execute(query)
 | 
					        await self.database.execute(query)
 | 
				
			||||||
        return user
 | 
					        return user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def update(self, user: UserDB) -> UserDB:
 | 
				
			||||||
 | 
					        query = users.update().where(User.id == user.id).values(**user.dict())
 | 
				
			||||||
 | 
					        await self.database.execute(query)
 | 
				
			||||||
 | 
					        return user
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,13 @@
 | 
				
			|||||||
 | 
					from typing import Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from passlib.context import CryptContext
 | 
					from passlib.context import CryptContext
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
 | 
					pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def verify_password(plain_password: str, hashed_password: str):
 | 
					def verify_and_update_password(plain_password: str, hashed_password: str) -> Tuple[bool, str]:
 | 
				
			||||||
    return pwd_context.verify(plain_password, hashed_password)
 | 
					    return pwd_context.verify_and_update(plain_password, hashed_password)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_password_hash(password: str):
 | 
					def get_password_hash(password: str) -> str:
 | 
				
			||||||
    return pwd_context.hash(password)
 | 
					    return pwd_context.hash(password)
 | 
				
			||||||
 | 
				
			|||||||
@ -33,16 +33,21 @@ async def test_queries(user, sqlalchemy_user_db):
 | 
				
			|||||||
    assert user_db.is_superuser is False
 | 
					    assert user_db.is_superuser is False
 | 
				
			||||||
    assert user_db.email == user.email
 | 
					    assert user_db.email == user.email
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Update
 | 
				
			||||||
 | 
					    user_db.is_superuser = True
 | 
				
			||||||
 | 
					    await sqlalchemy_user_db.update(user_db)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Get by email
 | 
				
			||||||
 | 
					    email_user = await sqlalchemy_user_db.get_by_email(user.email)
 | 
				
			||||||
 | 
					    assert email_user.id == user_db.id
 | 
				
			||||||
 | 
					    assert email_user.is_superuser is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # List
 | 
					    # List
 | 
				
			||||||
    users = await sqlalchemy_user_db.list()
 | 
					    users = await sqlalchemy_user_db.list()
 | 
				
			||||||
    assert len(users) == 1
 | 
					    assert len(users) == 1
 | 
				
			||||||
    first_user = users[0]
 | 
					    first_user = users[0]
 | 
				
			||||||
    assert first_user.id == user_db.id
 | 
					    assert first_user.id == user_db.id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Get by email
 | 
					 | 
				
			||||||
    email_user = await sqlalchemy_user_db.get_by_email(user.email)
 | 
					 | 
				
			||||||
    assert email_user.id == user_db.id
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Exception when inserting existing email
 | 
					    # Exception when inserting existing email
 | 
				
			||||||
    with pytest.raises(sqlite3.IntegrityError):
 | 
					    with pytest.raises(sqlite3.IntegrityError):
 | 
				
			||||||
        await sqlalchemy_user_db.create(user)
 | 
					        await sqlalchemy_user_db.create(user)
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user