diff --git a/sqlmodel_crud_manager/crud.py b/sqlmodel_crud_manager/crud.py index 8cb8da7..3b7c1d0 100644 --- a/sqlmodel_crud_manager/crud.py +++ b/sqlmodel_crud_manager/crud.py @@ -4,6 +4,7 @@ from typing import List, TypeVar from fastapi import HTTPException, status from sqlalchemy.engine.base import Engine from sqlmodel import Session, SQLModel, select +from sqlmodel import update as sqlmodel_update from sqlmodel.sql.expression import Select from sqlmodel_crud_manager.decorator import for_all_methods, raise_as_http_exception @@ -387,6 +388,56 @@ class CRUDManager: else: return self.create(object, db=db) + def create_or_update_multiple_by_fields( + self, + objects: List[ModelCreateType], + fields: List[str], + db: Session = None, + ) -> List[ModelType]: + """ + The function `create_or_update_multiple_by_fields` creates or updates a list of + model objects based on specified fields. + + Arguments: + + * `objects`: The `objects` parameter is a list of objects that you want to + create or update in the database. It should be of type `List[ModelCreateType]`, + which is the type of the objects that can be created in the database. + * `fields`: The `fields` parameter is a list of strings that represents + the fields (attributes) of the object that are used to check for existing + records in the database. These fields are used to query the database and + determine if a record with the same values already exists. + * `db`: The `db` parameter is an optional argument of type `Session`. It + represents the database session that will be used for database operations. + If no session is provided, the method will use the default session stored in + the `self.db` attribute. + + Returns: + + The function `create_or_update_multiple_by_fields` returns a list of + `ModelType` objects. + """ + self.db = db or self.db + for field in fields: + self.__validate_field_exists(field) + + objects_to_create = [] + objects_to_update = [] + for object in objects: + if obj := self.get_by_fields( + {field: getattr(object, field) for field in fields}, + db=db, + ): + new_object = self.model.model_validate(object) + new_object.id = obj.id + objects_to_update.append(new_object) + else: + objects_to_create.append(object) + + objects_created = self.create_multiple(objects_to_create, db=db) + objects_updated = self.update_multiple(objects_to_update, db=db) + + return objects_created + objects_updated def update(self, input_object: ModelType, db: Session = None) -> None: """