mirror of
https://github.com/EChachati/SQLModel-CRUD-manager.git
synced 2025-08-14 11:00:28 +08:00
Add New functionalities
This commit is contained in:
@ -6,12 +6,15 @@ from sqlalchemy.engine.base import Engine
|
||||
from sqlmodel import Session, SQLModel, select
|
||||
from sqlmodel.sql.expression import Select
|
||||
|
||||
from sqlmodel_crud_manager.decorator import for_all_methods, raise_as_http_exception
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=SQLModel)
|
||||
ModelCreateType = TypeVar("ModelCreateType", bound=SQLModel)
|
||||
QueryLike = TypeVar("QueryLike", bound=Select)
|
||||
|
||||
|
||||
@dataclass
|
||||
@for_all_methods(raise_as_http_exception)
|
||||
class CRUDManager:
|
||||
model: ModelType
|
||||
|
||||
@ -28,6 +31,19 @@ class CRUDManager:
|
||||
self.model = model
|
||||
self.db = Session(engine)
|
||||
|
||||
def __validate_field_exists(self, field: str) -> None:
|
||||
if field not in self.model.model_fields:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{self.model} does not have a {field} field",
|
||||
)
|
||||
|
||||
def __raise_not_found(self, detail: str) -> None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
def get(self, pk: int, db: Session = None) -> ModelType:
|
||||
"""
|
||||
The function retrieves a model object from the database based on its
|
||||
@ -45,12 +61,27 @@ class CRUDManager:
|
||||
"""
|
||||
self.db = db or self.db
|
||||
query = select(self.model).where(self.model.id == pk)
|
||||
if obj := self.db.exec(query).one_or_none():
|
||||
return self.db.exec(query).one_or_none()
|
||||
|
||||
def get_or_404(self, pk: int, db: Session = None) -> ModelType:
|
||||
"""
|
||||
The function retrieves a model object from the database based on its
|
||||
primary key and raises an exception if the object is not found.
|
||||
|
||||
Arguments:
|
||||
|
||||
* `pk`: The parameter `pk` stands for "primary key" and it is of type
|
||||
`int`. It is used to identify a specific object in the database based
|
||||
on its primary key value.
|
||||
|
||||
Returns:
|
||||
|
||||
The `get` method is returning an instance of the `ModelType` class.
|
||||
"""
|
||||
if obj := self.get(pk, db=db):
|
||||
print(obj)
|
||||
return obj
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"{self.model} with ID: {pk} Not Found",
|
||||
)
|
||||
self.__raise_not_found(f"{self.model.__name__} with id {pk} not found")
|
||||
|
||||
def get_by_ids(self, ids: list[int], db: Session = None) -> list[ModelType]:
|
||||
"""
|
||||
@ -72,7 +103,13 @@ class CRUDManager:
|
||||
query = select(self.model).where(self.model.id.in_(ids))
|
||||
return self.db.exec(query).all()
|
||||
|
||||
def get_by_field(self, field: str, value: str, db: Session = None) -> ModelType:
|
||||
def get_by_field(
|
||||
self,
|
||||
field: str,
|
||||
value: str,
|
||||
allows_multiple: bool = False,
|
||||
db: Session = None,
|
||||
) -> ModelType:
|
||||
"""
|
||||
The function retrieves a model object from the database based on a
|
||||
field and a value.
|
||||
@ -90,13 +127,83 @@ class CRUDManager:
|
||||
The `get_by_field` method is returning an object of type `ModelType`.
|
||||
"""
|
||||
self.db = db or self.db
|
||||
self.__validate_field_exists(field)
|
||||
|
||||
query = select(self.model).where(getattr(self.model, field) == value)
|
||||
if obj := self.db.exec(query).one_or_none():
|
||||
if allows_multiple:
|
||||
return self.db.exec(query).all()
|
||||
return self.db.exec(query).one_or_none()
|
||||
|
||||
def get_by_fields(
|
||||
self,
|
||||
fields: dict[str, str],
|
||||
*,
|
||||
allows_multiple: bool = False,
|
||||
db: Session = None,
|
||||
) -> list[ModelType] | ModelType:
|
||||
"""
|
||||
The function retrieves a list of model objects from the database based
|
||||
on a dictionary of fields and values.
|
||||
|
||||
Arguments:
|
||||
|
||||
* `fields`: The parameter `fields` is a dictionary of strings that
|
||||
represents the name of a field in the database table and the value of
|
||||
that field.
|
||||
|
||||
Returns:
|
||||
|
||||
The `get_by_fields` method is returning a list of objects of type
|
||||
`ModelType`.
|
||||
"""
|
||||
self.db = db or self.db
|
||||
query = select(self.model)
|
||||
for field, value in fields.items():
|
||||
self.__validate_field_exists(field)
|
||||
query = query.where(getattr(self.model, field) == value)
|
||||
if allows_multiple:
|
||||
return self.db.exec(query).all()
|
||||
return self.db.exec(query).one_or_none()
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
object: ModelCreateType,
|
||||
search_field: str = "id",
|
||||
db: Session = None,
|
||||
) -> ModelType:
|
||||
"""
|
||||
The function `get_or_create` checks if an object exists in the database based
|
||||
on a specified search field, and either returns the object if it exists
|
||||
or creates a new object if it doesn't.
|
||||
|
||||
Arguments:
|
||||
|
||||
* `object`: The `object` parameter is the object that you want to get or
|
||||
create in the database. It should be of type `ModelCreateType`, which is the
|
||||
type of the object that you want to create.
|
||||
* `search_field`: The `search_field` parameter is a string that specifies the
|
||||
field to search for when checking if an object already exists in the database.
|
||||
By default, it is set to "id", meaning it will search for an object with the
|
||||
same id as the one being passed in.
|
||||
* `db`: The `db` parameter is an optional parameter of type `Session`.
|
||||
It represents the database session that will be used for the database operations
|
||||
If no session is provided, it will use the default session.
|
||||
|
||||
Returns:
|
||||
|
||||
The function `get_or_create` returns an instance of `ModelType`.
|
||||
"""
|
||||
self.db = db or self.db
|
||||
self.__validate_field_exists(search_field)
|
||||
|
||||
if obj := self.get_by_field(
|
||||
search_field,
|
||||
getattr(object, search_field),
|
||||
db=db,
|
||||
):
|
||||
return obj
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"{self.model} with {field}: {value} Not Found",
|
||||
)
|
||||
else:
|
||||
return self.create(object, db=db)
|
||||
|
||||
def list(self, query: QueryLike = None, db: Session = None) -> list[ModelType]:
|
||||
"""
|
||||
@ -138,6 +245,47 @@ class CRUDManager:
|
||||
self.db.refresh(obj)
|
||||
return obj
|
||||
|
||||
def create_or_update(
|
||||
self,
|
||||
object: ModelCreateType,
|
||||
search_field: str = "id",
|
||||
db: Session = None,
|
||||
) -> ModelType:
|
||||
"""
|
||||
The function `create_or_update` checks if an object exists in the database based
|
||||
on a specified search field, and either updates the object if it exists
|
||||
or creates a new object if it doesn't.
|
||||
|
||||
Arguments:
|
||||
|
||||
* `object`: The `object` parameter is the object that you want to create or
|
||||
update in the database. It should be of type `ModelCreateType`, which is the
|
||||
type of the object that you want to create.
|
||||
* `search_field`: The `search_field` parameter is a string that specifies the
|
||||
field to search for when checking if an object already exists in the database.
|
||||
By default, it is set to "id", meaning it will search for an object with the
|
||||
same id as the one being passed in.
|
||||
* `db`: The `db` parameter is an optional parameter of type `Session`.
|
||||
It represents the database session that will be used for the database operations
|
||||
If no session is provided, it will use the default session.
|
||||
|
||||
Returns:
|
||||
|
||||
The function `create_or_update` returns an instance of `ModelType`.
|
||||
"""
|
||||
self.db = db or self.db
|
||||
self.__validate_field_exists(search_field)
|
||||
|
||||
if obj := self.get_by_field(
|
||||
search_field,
|
||||
getattr(object, search_field),
|
||||
db=db,
|
||||
):
|
||||
object.id = obj.id
|
||||
return self.update(object, db=db)
|
||||
else:
|
||||
return self.create(object, db=db)
|
||||
|
||||
def update(self, input_object: ModelType, db: Session = None) -> ModelType:
|
||||
"""
|
||||
The function updates a database object with the values from an input
|
||||
@ -181,7 +329,7 @@ class CRUDManager:
|
||||
the database.
|
||||
"""
|
||||
self.db = db or self.db
|
||||
db_object = self.get(pk)
|
||||
db_object = self.get_or_404(pk)
|
||||
self.db.delete(db_object)
|
||||
self.db.commit()
|
||||
return db_object
|
||||
|
35
sqlmodel_crud_manager/decorator.py
Normal file
35
sqlmodel_crud_manager/decorator.py
Normal file
@ -0,0 +1,35 @@
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def raise_as_http_exception(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ == "HTTPException":
|
||||
raise e
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def raise_404_if_none(func, detail="Not Found"):
|
||||
def wrapper(*args, **kwargs):
|
||||
result = func(*args, **kwargs)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail=detail)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def for_all_methods(decorator):
|
||||
def decorate(cls):
|
||||
for attr in cls.__dict__: # there's propably a better way to do this
|
||||
if callable(getattr(cls, attr)):
|
||||
setattr(cls, attr, decorator(getattr(cls, attr)))
|
||||
return cls
|
||||
|
||||
return decorate
|
@ -49,6 +49,18 @@ def test_get(last_id):
|
||||
assert hero.is_alive is True
|
||||
|
||||
|
||||
def test_get_or_404(last_id):
|
||||
hero = crud.get_or_404(last_id)
|
||||
assert hero.id == last_id
|
||||
assert hero.name == HERO_NAME
|
||||
assert hero.secret_name == SECRET_NAME
|
||||
assert hero.age is None
|
||||
assert hero.is_alive is True
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
crud.get_or_404(last_id + 1)
|
||||
|
||||
|
||||
def test_get_by_ids(last_id):
|
||||
heroes = crud.get_by_ids([last_id])
|
||||
assert len(heroes) == 1
|
||||
@ -68,10 +80,39 @@ def test_get_by_field(last_id):
|
||||
assert heroes.is_alive is True
|
||||
|
||||
|
||||
def test_get_by_fields(last_id):
|
||||
heroes = crud.get_by_fields({"name": HERO_NAME, "secret_name": SECRET_NAME})
|
||||
assert heroes.id == last_id
|
||||
assert heroes.name == HERO_NAME
|
||||
assert heroes.secret_name == SECRET_NAME
|
||||
assert heroes.age is None
|
||||
assert heroes.is_alive is True
|
||||
|
||||
|
||||
def test_get_or_create(last_id):
|
||||
hero = HeroCreate(name=HERO_NAME, secret_name=SECRET_NAME)
|
||||
hero = crud.get_or_create(hero, search_field="name")
|
||||
assert hero.id is not None
|
||||
assert hero.id == last_id
|
||||
assert hero.name == HERO_NAME
|
||||
|
||||
hero = HeroCreate(name="NotExistent", secret_name=SECRET_NAME)
|
||||
hero = crud.get_or_create(hero, search_field="name")
|
||||
assert hero.id is not None
|
||||
assert hero.id != last_id
|
||||
assert hero.name == "NotExistent"
|
||||
|
||||
hero = HeroCreate(name="NotExistentV2", secret_name=SECRET_NAME)
|
||||
with pytest.raises(HTTPException):
|
||||
crud.get_or_create(hero, search_field="id")
|
||||
with pytest.raises(HTTPException):
|
||||
crud.get_or_create(hero, search_field="NotExistent")
|
||||
|
||||
|
||||
def test_list(last_id):
|
||||
heroes = crud.list()
|
||||
assert len(heroes) == 1
|
||||
assert heroes[0].id == last_id
|
||||
assert len(heroes) == 2
|
||||
assert heroes[-1].id == last_id
|
||||
assert heroes[0].name == HERO_NAME
|
||||
assert heroes[0].secret_name == SECRET_NAME
|
||||
assert heroes[0].age is None
|
||||
@ -98,4 +139,4 @@ def test_delete(last_id):
|
||||
assert hero.age == 30
|
||||
assert hero.is_alive is True
|
||||
with pytest.raises(HTTPException):
|
||||
crud.get(last_id)
|
||||
crud.get_or_404(last_id)
|
||||
|
Reference in New Issue
Block a user