diff --git a/sqlmodel_crud_manager/crud.py b/sqlmodel_crud_manager/crud.py index 22fda47..5ae59f7 100644 --- a/sqlmodel_crud_manager/crud.py +++ b/sqlmodel_crud_manager/crud.py @@ -52,6 +52,25 @@ class CRUDManager: ) return obj + def get_by_ids(self, ids: list[int]) -> list[ModelType]: + """ + The function retrieves a list of model objects from the database based + on their primary keys. + + Arguments: + + * `ids`: The parameter `ids` is a list of integers. It is used to + identify a list of objects in the database based on their primary key + values. + + Returns: + + The `get_by_ids` method is returning a list of objects of type + `ModelType`. + """ + query = select(self.model).where(self.model.id.in_(ids)) + return self.db.exec(query).all() + def list(self, query: QueryLike = None) -> list[ModelType]: """ The function returns a list of all the records in the database that diff --git a/tests/test_crud.py b/tests/test_crud.py index 9337616..65a9745 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -49,6 +49,16 @@ def test_get(last_id): assert hero.is_alive is True +def test_get_by_ids(last_id): + heroes = crud.get_by_ids([last_id]) + assert len(heroes) == 1 + assert heroes[0].id == last_id + assert heroes[0].name == HERO_NAME + assert heroes[0].secret_name == SECRET_NAME + assert heroes[0].age is None + assert heroes[0].is_alive is True + + def test_list(last_id): heroes = crud.list() assert len(heroes) == 1