add parameter db session optional

This commit is contained in:
Edkar Chachati
2023-12-24 21:16:52 -04:00
parent c2ccbc3ab7
commit cd233609bf
3 changed files with 19 additions and 15 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "SQLModel-CRUD-manager" name = "SQLModel-CRUD-manager"
version = "0.1.4" version = "0.1.6"
description = "The SQLModel CRUD Manager is a Python library that facilitates common Create, Read, Update, and Delete (CRUD) operations on SQLModel entities within a FastAPI application. This library simplifies database interactions and provides an easy-to-use interface for managing SQLModel entities." description = "The SQLModel CRUD Manager is a Python library that facilitates common Create, Read, Update, and Delete (CRUD) operations on SQLModel entities within a FastAPI application. This library simplifies database interactions and provides an easy-to-use interface for managing SQLModel entities."
authors = ["Edkar Chachati <chachati28@gmail.com>"] authors = ["Edkar Chachati <chachati28@gmail.com>"]
license = "MIT" license = "MIT"

View File

@ -28,7 +28,7 @@ class CRUDManager:
self.model = model self.model = model
self.db = Session(engine) self.db = Session(engine)
def get(self, pk: int) -> ModelType: def get(self, pk: int, db: Session = None) -> ModelType:
""" """
The function retrieves a model object from the database based on its The function retrieves a model object from the database based on its
primary key and raises an exception if the object is not found. primary key and raises an exception if the object is not found.
@ -43,6 +43,7 @@ class CRUDManager:
The `get` method is returning an instance of the `ModelType` class. The `get` method is returning an instance of the `ModelType` class.
""" """
self.db = db or self.db
query = select(self.model).where(self.model.id == pk) query = select(self.model).where(self.model.id == pk)
obj = self.db.exec(query).one_or_none() obj = self.db.exec(query).one_or_none()
if obj is None: if obj is None:
@ -52,7 +53,7 @@ class CRUDManager:
) )
return obj return obj
def get_by_ids(self, ids: list[int]) -> list[ModelType]: def get_by_ids(self, ids: list[int], db: Session = None) -> list[ModelType]:
""" """
The function retrieves a list of model objects from the database based The function retrieves a list of model objects from the database based
on their primary keys. on their primary keys.
@ -68,10 +69,11 @@ class CRUDManager:
The `get_by_ids` method is returning a list of objects of type The `get_by_ids` method is returning a list of objects of type
`ModelType`. `ModelType`.
""" """
self.db = db or self.db
query = select(self.model).where(self.model.id.in_(ids)) query = select(self.model).where(self.model.id.in_(ids))
return self.db.exec(query).all() return self.db.exec(query).all()
def list(self, query: QueryLike = None) -> list[ModelType]: def list(self, query: QueryLike = None, db: Session = None) -> list[ModelType]:
""" """
The function returns a list of all the records in the database that The function returns a list of all the records in the database that
match the given query. match the given query.
@ -87,11 +89,11 @@ class CRUDManager:
a list of objects of type `ModelType`. a list of objects of type `ModelType`.
""" """
if query is None: self.db = db or self.db
query = select(self.model) query = query or select(self.model)
return self.db.exec(query).all() return self.db.exec(query).all()
def create(self, object: ModelCreateType) -> ModelType: def create(self, object: ModelCreateType, db: Session = None) -> ModelType:
""" """
The function creates a new object in the database and returns it. The function creates a new object in the database and returns it.
@ -104,14 +106,14 @@ class CRUDManager:
The `create` method is returning an object of type `ModelType`. The `create` method is returning an object of type `ModelType`.
""" """
with self.db: self.db = db or self.db
obj = self.model.model_validate(object) obj = self.model.model_validate(object)
self.db.add(obj) self.db.add(obj)
self.db.commit() self.db.commit()
self.db.refresh(obj) self.db.refresh(obj)
return obj return obj
def update(self, input_object: ModelType) -> ModelType: def update(self, input_object: ModelType, db: Session = None) -> ModelType:
""" """
The function updates a database object with the values from an input The function updates a database object with the values from an input
object and returns the updated object. object and returns the updated object.
@ -127,6 +129,7 @@ class CRUDManager:
The `update` method is returning the `db_object` after it has been The `update` method is returning the `db_object` after it has been
updated in the database. updated in the database.
""" """
self.db = db or self.db
db_object = self.get(input_object.id) db_object = self.get(input_object.id)
for field in input_object.model_fields: for field in input_object.model_fields:
setattr(db_object, field, getattr(input_object, field)) setattr(db_object, field, getattr(input_object, field))
@ -135,7 +138,7 @@ class CRUDManager:
self.db.refresh(db_object) self.db.refresh(db_object)
return db_object return db_object
def delete(self, pk: int) -> ModelType: def delete(self, pk: int, db: Session = None) -> ModelType:
""" """
The function deletes a database object with a given primary key and The function deletes a database object with a given primary key and
returns the deleted object. returns the deleted object.
@ -152,6 +155,7 @@ class CRUDManager:
The `delete` method is returning the `db_object` that was deleted from The `delete` method is returning the `db_object` that was deleted from
the database. the database.
""" """
self.db = db or self.db
db_object = self.get(pk) db_object = self.get(pk)
self.db.delete(db_object) self.db.delete(db_object)
self.db.commit() self.db.commit()

View File

@ -20,7 +20,7 @@ SECRET_NAME = "Dive Wilson"
HERO_NAME = "Deadpond" HERO_NAME = "Deadpond"
engine = create_engine("sqlite:///tests/testing.db", echo=True) engine = create_engine("sqlite:///tests/testing.db")
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)
crud = CRUDManager(Hero, engine) crud = CRUDManager(Hero, engine)