diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index 57296bc0..9f966e8b 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -15,6 +15,9 @@ class UserDBInterface: async def list(self) -> List[UserDB]: raise NotImplementedError() + async def get(self, id: str) -> UserDB: + raise NotImplementedError() + async def get_by_email(self, email: str) -> UserDB: raise NotImplementedError() diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py index dc465793..54766f36 100644 --- a/fastapi_users/db/sqlalchemy.py +++ b/fastapi_users/db/sqlalchemy.py @@ -34,6 +34,10 @@ class SQLAlchemyUserDB(UserDBInterface): query = users.select() return await self.database.fetch_all(query) + async def get(self, id: str) -> UserDB: + query = users.select().where(User.id == id) + return await self.database.fetch_one(query) + async def get_by_email(self, email: str) -> UserDB: query = users.select().where(User.email == email) return await self.database.fetch_one(query) diff --git a/tests/test_db_sqlalchemy.py b/tests/test_db_sqlalchemy.py index c92e2cb6..e6dce1b5 100644 --- a/tests/test_db_sqlalchemy.py +++ b/tests/test_db_sqlalchemy.py @@ -37,10 +37,14 @@ async def test_queries(user, sqlalchemy_user_db): user_db.is_superuser = True await sqlalchemy_user_db.update(user_db) + # Get by id + id_user = await sqlalchemy_user_db.get(user.id) + assert id_user.id == user_db.id + assert id_user.is_superuser is True + # Get by email email_user = await sqlalchemy_user_db.get_by_email(user.email) assert email_user.id == user_db.id - assert email_user.is_superuser is True # List users = await sqlalchemy_user_db.list()