diff --git a/sqlmodel_crud_manager/crud.py b/sqlmodel_crud_manager/crud.py index 447cc42..e8604d1 100644 --- a/sqlmodel_crud_manager/crud.py +++ b/sqlmodel_crud_manager/crud.py @@ -6,12 +6,15 @@ from sqlalchemy.engine.base import Engine from sqlmodel import Session, SQLModel, select from sqlmodel.sql.expression import Select +from sqlmodel_crud_manager.decorator import for_all_methods, raise_as_http_exception + ModelType = TypeVar("ModelType", bound=SQLModel) ModelCreateType = TypeVar("ModelCreateType", bound=SQLModel) QueryLike = TypeVar("QueryLike", bound=Select) @dataclass +@for_all_methods(raise_as_http_exception) class CRUDManager: model: ModelType @@ -28,6 +31,19 @@ class CRUDManager: self.model = model self.db = Session(engine) + def __validate_field_exists(self, field: str) -> None: + if field not in self.model.model_fields: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"{self.model} does not have a {field} field", + ) + + def __raise_not_found(self, detail: str) -> None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) + def get(self, pk: int, db: Session = None) -> ModelType: """ The function retrieves a model object from the database based on its @@ -45,12 +61,27 @@ class CRUDManager: """ self.db = db or self.db query = select(self.model).where(self.model.id == pk) - if obj := self.db.exec(query).one_or_none(): + return self.db.exec(query).one_or_none() + + def get_or_404(self, pk: int, db: Session = None) -> ModelType: + """ + The function retrieves a model object from the database based on its + primary key and raises an exception if the object is not found. + + Arguments: + + * `pk`: The parameter `pk` stands for "primary key" and it is of type + `int`. It is used to identify a specific object in the database based + on its primary key value. + + Returns: + + The `get` method is returning an instance of the `ModelType` class. + """ + if obj := self.get(pk, db=db): + print(obj) return obj - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"{self.model} with ID: {pk} Not Found", - ) + self.__raise_not_found(f"{self.model.__name__} with id {pk} not found") def get_by_ids(self, ids: list[int], db: Session = None) -> list[ModelType]: """ @@ -72,7 +103,13 @@ class CRUDManager: query = select(self.model).where(self.model.id.in_(ids)) return self.db.exec(query).all() - def get_by_field(self, field: str, value: str, db: Session = None) -> ModelType: + def get_by_field( + self, + field: str, + value: str, + allows_multiple: bool = False, + db: Session = None, + ) -> ModelType: """ The function retrieves a model object from the database based on a field and a value. @@ -90,13 +127,83 @@ class CRUDManager: The `get_by_field` method is returning an object of type `ModelType`. """ self.db = db or self.db + self.__validate_field_exists(field) + query = select(self.model).where(getattr(self.model, field) == value) - if obj := self.db.exec(query).one_or_none(): + if allows_multiple: + return self.db.exec(query).all() + return self.db.exec(query).one_or_none() + + def get_by_fields( + self, + fields: dict[str, str], + *, + allows_multiple: bool = False, + db: Session = None, + ) -> list[ModelType] | ModelType: + """ + The function retrieves a list of model objects from the database based + on a dictionary of fields and values. + + Arguments: + + * `fields`: The parameter `fields` is a dictionary of strings that + represents the name of a field in the database table and the value of + that field. + + Returns: + + The `get_by_fields` method is returning a list of objects of type + `ModelType`. + """ + self.db = db or self.db + query = select(self.model) + for field, value in fields.items(): + self.__validate_field_exists(field) + query = query.where(getattr(self.model, field) == value) + if allows_multiple: + return self.db.exec(query).all() + return self.db.exec(query).one_or_none() + + def get_or_create( + self, + object: ModelCreateType, + search_field: str = "id", + db: Session = None, + ) -> ModelType: + """ + The function `get_or_create` checks if an object exists in the database based + on a specified search field, and either returns the object if it exists + or creates a new object if it doesn't. + + Arguments: + + * `object`: The `object` parameter is the object that you want to get or + create in the database. It should be of type `ModelCreateType`, which is the + type of the object that you want to create. + * `search_field`: The `search_field` parameter is a string that specifies the + field to search for when checking if an object already exists in the database. + By default, it is set to "id", meaning it will search for an object with the + same id as the one being passed in. + * `db`: The `db` parameter is an optional parameter of type `Session`. + It represents the database session that will be used for the database operations + If no session is provided, it will use the default session. + + Returns: + + The function `get_or_create` returns an instance of `ModelType`. + """ + self.db = db or self.db + self.__validate_field_exists(search_field) + + if obj := self.get_by_field( + search_field, + getattr(object, search_field), + db=db, + ): return obj - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"{self.model} with {field}: {value} Not Found", - ) + else: + return self.create(object, db=db) def list(self, query: QueryLike = None, db: Session = None) -> list[ModelType]: """ @@ -138,6 +245,47 @@ class CRUDManager: self.db.refresh(obj) return obj + def create_or_update( + self, + object: ModelCreateType, + search_field: str = "id", + db: Session = None, + ) -> ModelType: + """ + The function `create_or_update` checks if an object exists in the database based + on a specified search field, and either updates the object if it exists + or creates a new object if it doesn't. + + Arguments: + + * `object`: The `object` parameter is the object that you want to create or + update in the database. It should be of type `ModelCreateType`, which is the + type of the object that you want to create. + * `search_field`: The `search_field` parameter is a string that specifies the + field to search for when checking if an object already exists in the database. + By default, it is set to "id", meaning it will search for an object with the + same id as the one being passed in. + * `db`: The `db` parameter is an optional parameter of type `Session`. + It represents the database session that will be used for the database operations + If no session is provided, it will use the default session. + + Returns: + + The function `create_or_update` returns an instance of `ModelType`. + """ + self.db = db or self.db + self.__validate_field_exists(search_field) + + if obj := self.get_by_field( + search_field, + getattr(object, search_field), + db=db, + ): + object.id = obj.id + return self.update(object, db=db) + else: + return self.create(object, db=db) + def update(self, input_object: ModelType, db: Session = None) -> ModelType: """ The function updates a database object with the values from an input @@ -181,7 +329,7 @@ class CRUDManager: the database. """ self.db = db or self.db - db_object = self.get(pk) + db_object = self.get_or_404(pk) self.db.delete(db_object) self.db.commit() return db_object diff --git a/sqlmodel_crud_manager/decorator.py b/sqlmodel_crud_manager/decorator.py new file mode 100644 index 0000000..754f1b6 --- /dev/null +++ b/sqlmodel_crud_manager/decorator.py @@ -0,0 +1,35 @@ +from fastapi import HTTPException + + +def raise_as_http_exception(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + if e.__class__.__name__ == "HTTPException": + raise e + + else: + raise HTTPException(status_code=500, detail=str(e)) from e + + return wrapper + + +def raise_404_if_none(func, detail="Not Found"): + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + if result is None: + raise HTTPException(status_code=404, detail=detail) + return result + + return wrapper + + +def for_all_methods(decorator): + def decorate(cls): + for attr in cls.__dict__: # there's propably a better way to do this + if callable(getattr(cls, attr)): + setattr(cls, attr, decorator(getattr(cls, attr))) + return cls + + return decorate diff --git a/tests/test_crud.py b/tests/test_crud.py index d42ca32..7ffb283 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -49,6 +49,18 @@ def test_get(last_id): assert hero.is_alive is True +def test_get_or_404(last_id): + hero = crud.get_or_404(last_id) + assert hero.id == last_id + assert hero.name == HERO_NAME + assert hero.secret_name == SECRET_NAME + assert hero.age is None + assert hero.is_alive is True + + with pytest.raises(HTTPException): + crud.get_or_404(last_id + 1) + + def test_get_by_ids(last_id): heroes = crud.get_by_ids([last_id]) assert len(heroes) == 1 @@ -68,10 +80,39 @@ def test_get_by_field(last_id): assert heroes.is_alive is True +def test_get_by_fields(last_id): + heroes = crud.get_by_fields({"name": HERO_NAME, "secret_name": SECRET_NAME}) + assert heroes.id == last_id + assert heroes.name == HERO_NAME + assert heroes.secret_name == SECRET_NAME + assert heroes.age is None + assert heroes.is_alive is True + + +def test_get_or_create(last_id): + hero = HeroCreate(name=HERO_NAME, secret_name=SECRET_NAME) + hero = crud.get_or_create(hero, search_field="name") + assert hero.id is not None + assert hero.id == last_id + assert hero.name == HERO_NAME + + hero = HeroCreate(name="NotExistent", secret_name=SECRET_NAME) + hero = crud.get_or_create(hero, search_field="name") + assert hero.id is not None + assert hero.id != last_id + assert hero.name == "NotExistent" + + hero = HeroCreate(name="NotExistentV2", secret_name=SECRET_NAME) + with pytest.raises(HTTPException): + crud.get_or_create(hero, search_field="id") + with pytest.raises(HTTPException): + crud.get_or_create(hero, search_field="NotExistent") + + def test_list(last_id): heroes = crud.list() - assert len(heroes) == 1 - assert heroes[0].id == last_id + assert len(heroes) == 2 + assert heroes[-1].id == last_id assert heroes[0].name == HERO_NAME assert heroes[0].secret_name == SECRET_NAME assert heroes[0].age is None @@ -98,4 +139,4 @@ def test_delete(last_id): assert hero.age == 30 assert hero.is_alive is True with pytest.raises(HTTPException): - crud.get(last_id) + crud.get_or_404(last_id)