add parameter db session optional

This commit is contained in:
Edkar Chachati
2023-12-24 21:16:52 -04:00
parent c2ccbc3ab7
commit cd233609bf
3 changed files with 19 additions and 15 deletions

View File

@ -28,7 +28,7 @@ class CRUDManager:
self.model = model
self.db = Session(engine)
def get(self, pk: int) -> ModelType:
def get(self, pk: int, db: Session = None) -> ModelType:
"""
The function retrieves a model object from the database based on its
primary key and raises an exception if the object is not found.
@ -43,6 +43,7 @@ class CRUDManager:
The `get` method is returning an instance of the `ModelType` class.
"""
self.db = db or self.db
query = select(self.model).where(self.model.id == pk)
obj = self.db.exec(query).one_or_none()
if obj is None:
@ -52,7 +53,7 @@ class CRUDManager:
)
return obj
def get_by_ids(self, ids: list[int]) -> list[ModelType]:
def get_by_ids(self, ids: list[int], db: Session = None) -> list[ModelType]:
"""
The function retrieves a list of model objects from the database based
on their primary keys.
@ -68,10 +69,11 @@ class CRUDManager:
The `get_by_ids` method is returning a list of objects of type
`ModelType`.
"""
self.db = db or self.db
query = select(self.model).where(self.model.id.in_(ids))
return self.db.exec(query).all()
def list(self, query: QueryLike = None) -> list[ModelType]:
def list(self, query: QueryLike = None, db: Session = None) -> list[ModelType]:
"""
The function returns a list of all the records in the database that
match the given query.
@ -87,11 +89,11 @@ class CRUDManager:
a list of objects of type `ModelType`.
"""
if query is None:
query = select(self.model)
self.db = db or self.db
query = query or select(self.model)
return self.db.exec(query).all()
def create(self, object: ModelCreateType) -> ModelType:
def create(self, object: ModelCreateType, db: Session = None) -> ModelType:
"""
The function creates a new object in the database and returns it.
@ -104,14 +106,14 @@ class CRUDManager:
The `create` method is returning an object of type `ModelType`.
"""
with self.db:
obj = self.model.model_validate(object)
self.db.add(obj)
self.db.commit()
self.db.refresh(obj)
self.db = db or self.db
obj = self.model.model_validate(object)
self.db.add(obj)
self.db.commit()
self.db.refresh(obj)
return obj
def update(self, input_object: ModelType) -> ModelType:
def update(self, input_object: ModelType, db: Session = None) -> ModelType:
"""
The function updates a database object with the values from an input
object and returns the updated object.
@ -127,6 +129,7 @@ class CRUDManager:
The `update` method is returning the `db_object` after it has been
updated in the database.
"""
self.db = db or self.db
db_object = self.get(input_object.id)
for field in input_object.model_fields:
setattr(db_object, field, getattr(input_object, field))
@ -135,7 +138,7 @@ class CRUDManager:
self.db.refresh(db_object)
return db_object
def delete(self, pk: int) -> ModelType:
def delete(self, pk: int, db: Session = None) -> ModelType:
"""
The function deletes a database object with a given primary key and
returns the deleted object.
@ -152,6 +155,7 @@ class CRUDManager:
The `delete` method is returning the `db_object` that was deleted from
the database.
"""
self.db = db or self.db
db_object = self.get(pk)
self.db.delete(db_object)
self.db.commit()