diff --git a/pyproject.toml b/pyproject.toml index f6d7883..0740606 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "SQLModel-CRUD-manager" -version = "0.1.6" +version = "0.1.7" description = "The SQLModel CRUD Manager is a Python library that facilitates common Create, Read, Update, and Delete (CRUD) operations on SQLModel entities within a FastAPI application. This library simplifies database interactions and provides an easy-to-use interface for managing SQLModel entities." authors = ["Edkar Chachati "] license = "MIT" diff --git a/sqlmodel_crud_manager/crud.py b/sqlmodel_crud_manager/crud.py index 974133c..447cc42 100644 --- a/sqlmodel_crud_manager/crud.py +++ b/sqlmodel_crud_manager/crud.py @@ -45,13 +45,12 @@ class CRUDManager: """ self.db = db or self.db query = select(self.model).where(self.model.id == pk) - obj = self.db.exec(query).one_or_none() - if obj is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"{self.model} with ID: {pk} Not Found", - ) - return obj + if obj := self.db.exec(query).one_or_none(): + return obj + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"{self.model} with ID: {pk} Not Found", + ) def get_by_ids(self, ids: list[int], db: Session = None) -> list[ModelType]: """ @@ -73,6 +72,32 @@ class CRUDManager: query = select(self.model).where(self.model.id.in_(ids)) return self.db.exec(query).all() + def get_by_field(self, field: str, value: str, db: Session = None) -> ModelType: + """ + The function retrieves a model object from the database based on a + field and a value. + + Arguments: + + * `field`: The parameter `field` is a string that represents the name + of a field in the database table. + + * `value`: The parameter `value` is a string that represents the value + of a field in the database table. + + Returns: + + The `get_by_field` method is returning an object of type `ModelType`. + """ + self.db = db or self.db + query = select(self.model).where(getattr(self.model, field) == value) + if obj := self.db.exec(query).one_or_none(): + return obj + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"{self.model} with {field}: {value} Not Found", + ) + def list(self, query: QueryLike = None, db: Session = None) -> list[ModelType]: """ The function returns a list of all the records in the database that diff --git a/tests/test_crud.py b/tests/test_crud.py index c3fa5c7..d42ca32 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -59,6 +59,15 @@ def test_get_by_ids(last_id): assert heroes[0].is_alive is True +def test_get_by_field(last_id): + heroes = crud.get_by_field("name", HERO_NAME) + assert heroes.id == last_id + assert heroes.name == HERO_NAME + assert heroes.secret_name == SECRET_NAME + assert heroes.age is None + assert heroes.is_alive is True + + def test_list(last_id): heroes = crud.list() assert len(heroes) == 1