mirror of
https://github.com/EChachati/SQLModel-CRUD-manager.git
synced 2025-08-15 03:32:32 +08:00
Add New functionalities
This commit is contained in:
@ -6,12 +6,15 @@ from sqlalchemy.engine.base import Engine
|
|||||||
from sqlmodel import Session, SQLModel, select
|
from sqlmodel import Session, SQLModel, select
|
||||||
from sqlmodel.sql.expression import Select
|
from sqlmodel.sql.expression import Select
|
||||||
|
|
||||||
|
from sqlmodel_crud_manager.decorator import for_all_methods, raise_as_http_exception
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=SQLModel)
|
ModelType = TypeVar("ModelType", bound=SQLModel)
|
||||||
ModelCreateType = TypeVar("ModelCreateType", bound=SQLModel)
|
ModelCreateType = TypeVar("ModelCreateType", bound=SQLModel)
|
||||||
QueryLike = TypeVar("QueryLike", bound=Select)
|
QueryLike = TypeVar("QueryLike", bound=Select)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@for_all_methods(raise_as_http_exception)
|
||||||
class CRUDManager:
|
class CRUDManager:
|
||||||
model: ModelType
|
model: ModelType
|
||||||
|
|
||||||
@ -28,6 +31,19 @@ class CRUDManager:
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.db = Session(engine)
|
self.db = Session(engine)
|
||||||
|
|
||||||
|
def __validate_field_exists(self, field: str) -> None:
|
||||||
|
if field not in self.model.model_fields:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"{self.model} does not have a {field} field",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __raise_not_found(self, detail: str) -> None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
|
||||||
def get(self, pk: int, db: Session = None) -> ModelType:
|
def get(self, pk: int, db: Session = None) -> ModelType:
|
||||||
"""
|
"""
|
||||||
The function retrieves a model object from the database based on its
|
The function retrieves a model object from the database based on its
|
||||||
@ -45,12 +61,27 @@ class CRUDManager:
|
|||||||
"""
|
"""
|
||||||
self.db = db or self.db
|
self.db = db or self.db
|
||||||
query = select(self.model).where(self.model.id == pk)
|
query = select(self.model).where(self.model.id == pk)
|
||||||
if obj := self.db.exec(query).one_or_none():
|
return self.db.exec(query).one_or_none()
|
||||||
|
|
||||||
|
def get_or_404(self, pk: int, db: Session = None) -> ModelType:
|
||||||
|
"""
|
||||||
|
The function retrieves a model object from the database based on its
|
||||||
|
primary key and raises an exception if the object is not found.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* `pk`: The parameter `pk` stands for "primary key" and it is of type
|
||||||
|
`int`. It is used to identify a specific object in the database based
|
||||||
|
on its primary key value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
The `get` method is returning an instance of the `ModelType` class.
|
||||||
|
"""
|
||||||
|
if obj := self.get(pk, db=db):
|
||||||
|
print(obj)
|
||||||
return obj
|
return obj
|
||||||
raise HTTPException(
|
self.__raise_not_found(f"{self.model.__name__} with id {pk} not found")
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"{self.model} with ID: {pk} Not Found",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_by_ids(self, ids: list[int], db: Session = None) -> list[ModelType]:
|
def get_by_ids(self, ids: list[int], db: Session = None) -> list[ModelType]:
|
||||||
"""
|
"""
|
||||||
@ -72,7 +103,13 @@ class CRUDManager:
|
|||||||
query = select(self.model).where(self.model.id.in_(ids))
|
query = select(self.model).where(self.model.id.in_(ids))
|
||||||
return self.db.exec(query).all()
|
return self.db.exec(query).all()
|
||||||
|
|
||||||
def get_by_field(self, field: str, value: str, db: Session = None) -> ModelType:
|
def get_by_field(
|
||||||
|
self,
|
||||||
|
field: str,
|
||||||
|
value: str,
|
||||||
|
allows_multiple: bool = False,
|
||||||
|
db: Session = None,
|
||||||
|
) -> ModelType:
|
||||||
"""
|
"""
|
||||||
The function retrieves a model object from the database based on a
|
The function retrieves a model object from the database based on a
|
||||||
field and a value.
|
field and a value.
|
||||||
@ -90,13 +127,83 @@ class CRUDManager:
|
|||||||
The `get_by_field` method is returning an object of type `ModelType`.
|
The `get_by_field` method is returning an object of type `ModelType`.
|
||||||
"""
|
"""
|
||||||
self.db = db or self.db
|
self.db = db or self.db
|
||||||
|
self.__validate_field_exists(field)
|
||||||
|
|
||||||
query = select(self.model).where(getattr(self.model, field) == value)
|
query = select(self.model).where(getattr(self.model, field) == value)
|
||||||
if obj := self.db.exec(query).one_or_none():
|
if allows_multiple:
|
||||||
|
return self.db.exec(query).all()
|
||||||
|
return self.db.exec(query).one_or_none()
|
||||||
|
|
||||||
|
def get_by_fields(
|
||||||
|
self,
|
||||||
|
fields: dict[str, str],
|
||||||
|
*,
|
||||||
|
allows_multiple: bool = False,
|
||||||
|
db: Session = None,
|
||||||
|
) -> list[ModelType] | ModelType:
|
||||||
|
"""
|
||||||
|
The function retrieves a list of model objects from the database based
|
||||||
|
on a dictionary of fields and values.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* `fields`: The parameter `fields` is a dictionary of strings that
|
||||||
|
represents the name of a field in the database table and the value of
|
||||||
|
that field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
The `get_by_fields` method is returning a list of objects of type
|
||||||
|
`ModelType`.
|
||||||
|
"""
|
||||||
|
self.db = db or self.db
|
||||||
|
query = select(self.model)
|
||||||
|
for field, value in fields.items():
|
||||||
|
self.__validate_field_exists(field)
|
||||||
|
query = query.where(getattr(self.model, field) == value)
|
||||||
|
if allows_multiple:
|
||||||
|
return self.db.exec(query).all()
|
||||||
|
return self.db.exec(query).one_or_none()
|
||||||
|
|
||||||
|
def get_or_create(
|
||||||
|
self,
|
||||||
|
object: ModelCreateType,
|
||||||
|
search_field: str = "id",
|
||||||
|
db: Session = None,
|
||||||
|
) -> ModelType:
|
||||||
|
"""
|
||||||
|
The function `get_or_create` checks if an object exists in the database based
|
||||||
|
on a specified search field, and either returns the object if it exists
|
||||||
|
or creates a new object if it doesn't.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* `object`: The `object` parameter is the object that you want to get or
|
||||||
|
create in the database. It should be of type `ModelCreateType`, which is the
|
||||||
|
type of the object that you want to create.
|
||||||
|
* `search_field`: The `search_field` parameter is a string that specifies the
|
||||||
|
field to search for when checking if an object already exists in the database.
|
||||||
|
By default, it is set to "id", meaning it will search for an object with the
|
||||||
|
same id as the one being passed in.
|
||||||
|
* `db`: The `db` parameter is an optional parameter of type `Session`.
|
||||||
|
It represents the database session that will be used for the database operations
|
||||||
|
If no session is provided, it will use the default session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
The function `get_or_create` returns an instance of `ModelType`.
|
||||||
|
"""
|
||||||
|
self.db = db or self.db
|
||||||
|
self.__validate_field_exists(search_field)
|
||||||
|
|
||||||
|
if obj := self.get_by_field(
|
||||||
|
search_field,
|
||||||
|
getattr(object, search_field),
|
||||||
|
db=db,
|
||||||
|
):
|
||||||
return obj
|
return obj
|
||||||
raise HTTPException(
|
else:
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
return self.create(object, db=db)
|
||||||
detail=f"{self.model} with {field}: {value} Not Found",
|
|
||||||
)
|
|
||||||
|
|
||||||
def list(self, query: QueryLike = None, db: Session = None) -> list[ModelType]:
|
def list(self, query: QueryLike = None, db: Session = None) -> list[ModelType]:
|
||||||
"""
|
"""
|
||||||
@ -138,6 +245,47 @@ class CRUDManager:
|
|||||||
self.db.refresh(obj)
|
self.db.refresh(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def create_or_update(
|
||||||
|
self,
|
||||||
|
object: ModelCreateType,
|
||||||
|
search_field: str = "id",
|
||||||
|
db: Session = None,
|
||||||
|
) -> ModelType:
|
||||||
|
"""
|
||||||
|
The function `create_or_update` checks if an object exists in the database based
|
||||||
|
on a specified search field, and either updates the object if it exists
|
||||||
|
or creates a new object if it doesn't.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* `object`: The `object` parameter is the object that you want to create or
|
||||||
|
update in the database. It should be of type `ModelCreateType`, which is the
|
||||||
|
type of the object that you want to create.
|
||||||
|
* `search_field`: The `search_field` parameter is a string that specifies the
|
||||||
|
field to search for when checking if an object already exists in the database.
|
||||||
|
By default, it is set to "id", meaning it will search for an object with the
|
||||||
|
same id as the one being passed in.
|
||||||
|
* `db`: The `db` parameter is an optional parameter of type `Session`.
|
||||||
|
It represents the database session that will be used for the database operations
|
||||||
|
If no session is provided, it will use the default session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
The function `create_or_update` returns an instance of `ModelType`.
|
||||||
|
"""
|
||||||
|
self.db = db or self.db
|
||||||
|
self.__validate_field_exists(search_field)
|
||||||
|
|
||||||
|
if obj := self.get_by_field(
|
||||||
|
search_field,
|
||||||
|
getattr(object, search_field),
|
||||||
|
db=db,
|
||||||
|
):
|
||||||
|
object.id = obj.id
|
||||||
|
return self.update(object, db=db)
|
||||||
|
else:
|
||||||
|
return self.create(object, db=db)
|
||||||
|
|
||||||
def update(self, input_object: ModelType, db: Session = None) -> ModelType:
|
def update(self, input_object: ModelType, db: Session = None) -> ModelType:
|
||||||
"""
|
"""
|
||||||
The function updates a database object with the values from an input
|
The function updates a database object with the values from an input
|
||||||
@ -181,7 +329,7 @@ class CRUDManager:
|
|||||||
the database.
|
the database.
|
||||||
"""
|
"""
|
||||||
self.db = db or self.db
|
self.db = db or self.db
|
||||||
db_object = self.get(pk)
|
db_object = self.get_or_404(pk)
|
||||||
self.db.delete(db_object)
|
self.db.delete(db_object)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return db_object
|
return db_object
|
||||||
|
35
sqlmodel_crud_manager/decorator.py
Normal file
35
sqlmodel_crud_manager/decorator.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def raise_as_http_exception(func):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
if e.__class__.__name__ == "HTTPException":
|
||||||
|
raise e
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def raise_404_if_none(func, detail="Not Found"):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(status_code=404, detail=detail)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def for_all_methods(decorator):
|
||||||
|
def decorate(cls):
|
||||||
|
for attr in cls.__dict__: # there's propably a better way to do this
|
||||||
|
if callable(getattr(cls, attr)):
|
||||||
|
setattr(cls, attr, decorator(getattr(cls, attr)))
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorate
|
@ -49,6 +49,18 @@ def test_get(last_id):
|
|||||||
assert hero.is_alive is True
|
assert hero.is_alive is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_or_404(last_id):
|
||||||
|
hero = crud.get_or_404(last_id)
|
||||||
|
assert hero.id == last_id
|
||||||
|
assert hero.name == HERO_NAME
|
||||||
|
assert hero.secret_name == SECRET_NAME
|
||||||
|
assert hero.age is None
|
||||||
|
assert hero.is_alive is True
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
crud.get_or_404(last_id + 1)
|
||||||
|
|
||||||
|
|
||||||
def test_get_by_ids(last_id):
|
def test_get_by_ids(last_id):
|
||||||
heroes = crud.get_by_ids([last_id])
|
heroes = crud.get_by_ids([last_id])
|
||||||
assert len(heroes) == 1
|
assert len(heroes) == 1
|
||||||
@ -68,10 +80,39 @@ def test_get_by_field(last_id):
|
|||||||
assert heroes.is_alive is True
|
assert heroes.is_alive is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_by_fields(last_id):
|
||||||
|
heroes = crud.get_by_fields({"name": HERO_NAME, "secret_name": SECRET_NAME})
|
||||||
|
assert heroes.id == last_id
|
||||||
|
assert heroes.name == HERO_NAME
|
||||||
|
assert heroes.secret_name == SECRET_NAME
|
||||||
|
assert heroes.age is None
|
||||||
|
assert heroes.is_alive is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_or_create(last_id):
|
||||||
|
hero = HeroCreate(name=HERO_NAME, secret_name=SECRET_NAME)
|
||||||
|
hero = crud.get_or_create(hero, search_field="name")
|
||||||
|
assert hero.id is not None
|
||||||
|
assert hero.id == last_id
|
||||||
|
assert hero.name == HERO_NAME
|
||||||
|
|
||||||
|
hero = HeroCreate(name="NotExistent", secret_name=SECRET_NAME)
|
||||||
|
hero = crud.get_or_create(hero, search_field="name")
|
||||||
|
assert hero.id is not None
|
||||||
|
assert hero.id != last_id
|
||||||
|
assert hero.name == "NotExistent"
|
||||||
|
|
||||||
|
hero = HeroCreate(name="NotExistentV2", secret_name=SECRET_NAME)
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
crud.get_or_create(hero, search_field="id")
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
crud.get_or_create(hero, search_field="NotExistent")
|
||||||
|
|
||||||
|
|
||||||
def test_list(last_id):
|
def test_list(last_id):
|
||||||
heroes = crud.list()
|
heroes = crud.list()
|
||||||
assert len(heroes) == 1
|
assert len(heroes) == 2
|
||||||
assert heroes[0].id == last_id
|
assert heroes[-1].id == last_id
|
||||||
assert heroes[0].name == HERO_NAME
|
assert heroes[0].name == HERO_NAME
|
||||||
assert heroes[0].secret_name == SECRET_NAME
|
assert heroes[0].secret_name == SECRET_NAME
|
||||||
assert heroes[0].age is None
|
assert heroes[0].age is None
|
||||||
@ -98,4 +139,4 @@ def test_delete(last_id):
|
|||||||
assert hero.age == 30
|
assert hero.age == 30
|
||||||
assert hero.is_alive is True
|
assert hero.is_alive is True
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
crud.get(last_id)
|
crud.get_or_404(last_id)
|
||||||
|
Reference in New Issue
Block a user