Add New functionalities

This commit is contained in:
Edkar Chachati
2024-01-24 20:48:49 -04:00
parent a87f364e94
commit cafa40957e
3 changed files with 239 additions and 15 deletions

View File

@ -6,12 +6,15 @@ from sqlalchemy.engine.base import Engine
from sqlmodel import Session, SQLModel, select
from sqlmodel.sql.expression import Select
from sqlmodel_crud_manager.decorator import for_all_methods, raise_as_http_exception
ModelType = TypeVar("ModelType", bound=SQLModel)
ModelCreateType = TypeVar("ModelCreateType", bound=SQLModel)
QueryLike = TypeVar("QueryLike", bound=Select)
@dataclass
@for_all_methods(raise_as_http_exception)
class CRUDManager:
model: ModelType
@ -28,6 +31,19 @@ class CRUDManager:
self.model = model
self.db = Session(engine)
def __validate_field_exists(self, field: str) -> None:
if field not in self.model.model_fields:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{self.model} does not have a {field} field",
)
def __raise_not_found(self, detail: str) -> None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def get(self, pk: int, db: Session = None) -> ModelType:
"""
The function retrieves a model object from the database based on its
@ -45,12 +61,27 @@ class CRUDManager:
"""
self.db = db or self.db
query = select(self.model).where(self.model.id == pk)
if obj := self.db.exec(query).one_or_none():
return self.db.exec(query).one_or_none()
def get_or_404(self, pk: int, db: Session = None) -> ModelType:
"""
The function retrieves a model object from the database based on its
primary key and raises an exception if the object is not found.
Arguments:
* `pk`: The parameter `pk` stands for "primary key" and it is of type
`int`. It is used to identify a specific object in the database based
on its primary key value.
Returns:
The `get` method is returning an instance of the `ModelType` class.
"""
if obj := self.get(pk, db=db):
print(obj)
return obj
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"{self.model} with ID: {pk} Not Found",
)
self.__raise_not_found(f"{self.model.__name__} with id {pk} not found")
def get_by_ids(self, ids: list[int], db: Session = None) -> list[ModelType]:
"""
@ -72,7 +103,13 @@ class CRUDManager:
query = select(self.model).where(self.model.id.in_(ids))
return self.db.exec(query).all()
def get_by_field(self, field: str, value: str, db: Session = None) -> ModelType:
def get_by_field(
self,
field: str,
value: str,
allows_multiple: bool = False,
db: Session = None,
) -> ModelType:
"""
The function retrieves a model object from the database based on a
field and a value.
@ -90,13 +127,83 @@ class CRUDManager:
The `get_by_field` method is returning an object of type `ModelType`.
"""
self.db = db or self.db
self.__validate_field_exists(field)
query = select(self.model).where(getattr(self.model, field) == value)
if obj := self.db.exec(query).one_or_none():
if allows_multiple:
return self.db.exec(query).all()
return self.db.exec(query).one_or_none()
def get_by_fields(
self,
fields: dict[str, str],
*,
allows_multiple: bool = False,
db: Session = None,
) -> list[ModelType] | ModelType:
"""
The function retrieves a list of model objects from the database based
on a dictionary of fields and values.
Arguments:
* `fields`: The parameter `fields` is a dictionary of strings that
represents the name of a field in the database table and the value of
that field.
Returns:
The `get_by_fields` method is returning a list of objects of type
`ModelType`.
"""
self.db = db or self.db
query = select(self.model)
for field, value in fields.items():
self.__validate_field_exists(field)
query = query.where(getattr(self.model, field) == value)
if allows_multiple:
return self.db.exec(query).all()
return self.db.exec(query).one_or_none()
def get_or_create(
self,
object: ModelCreateType,
search_field: str = "id",
db: Session = None,
) -> ModelType:
"""
The function `get_or_create` checks if an object exists in the database based
on a specified search field, and either returns the object if it exists
or creates a new object if it doesn't.
Arguments:
* `object`: The `object` parameter is the object that you want to get or
create in the database. It should be of type `ModelCreateType`, which is the
type of the object that you want to create.
* `search_field`: The `search_field` parameter is a string that specifies the
field to search for when checking if an object already exists in the database.
By default, it is set to "id", meaning it will search for an object with the
same id as the one being passed in.
* `db`: The `db` parameter is an optional parameter of type `Session`.
It represents the database session that will be used for the database operations
If no session is provided, it will use the default session.
Returns:
The function `get_or_create` returns an instance of `ModelType`.
"""
self.db = db or self.db
self.__validate_field_exists(search_field)
if obj := self.get_by_field(
search_field,
getattr(object, search_field),
db=db,
):
return obj
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"{self.model} with {field}: {value} Not Found",
)
else:
return self.create(object, db=db)
def list(self, query: QueryLike = None, db: Session = None) -> list[ModelType]:
"""
@ -138,6 +245,47 @@ class CRUDManager:
self.db.refresh(obj)
return obj
def create_or_update(
self,
object: ModelCreateType,
search_field: str = "id",
db: Session = None,
) -> ModelType:
"""
The function `create_or_update` checks if an object exists in the database based
on a specified search field, and either updates the object if it exists
or creates a new object if it doesn't.
Arguments:
* `object`: The `object` parameter is the object that you want to create or
update in the database. It should be of type `ModelCreateType`, which is the
type of the object that you want to create.
* `search_field`: The `search_field` parameter is a string that specifies the
field to search for when checking if an object already exists in the database.
By default, it is set to "id", meaning it will search for an object with the
same id as the one being passed in.
* `db`: The `db` parameter is an optional parameter of type `Session`.
It represents the database session that will be used for the database operations
If no session is provided, it will use the default session.
Returns:
The function `create_or_update` returns an instance of `ModelType`.
"""
self.db = db or self.db
self.__validate_field_exists(search_field)
if obj := self.get_by_field(
search_field,
getattr(object, search_field),
db=db,
):
object.id = obj.id
return self.update(object, db=db)
else:
return self.create(object, db=db)
def update(self, input_object: ModelType, db: Session = None) -> ModelType:
"""
The function updates a database object with the values from an input
@ -181,7 +329,7 @@ class CRUDManager:
the database.
"""
self.db = db or self.db
db_object = self.get(pk)
db_object = self.get_or_404(pk)
self.db.delete(db_object)
self.db.commit()
return db_object

View File

@ -0,0 +1,35 @@
from fastapi import HTTPException
def raise_as_http_exception(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if e.__class__.__name__ == "HTTPException":
raise e
else:
raise HTTPException(status_code=500, detail=str(e)) from e
return wrapper
def raise_404_if_none(func, detail="Not Found"):
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if result is None:
raise HTTPException(status_code=404, detail=detail)
return result
return wrapper
def for_all_methods(decorator):
def decorate(cls):
for attr in cls.__dict__: # there's propably a better way to do this
if callable(getattr(cls, attr)):
setattr(cls, attr, decorator(getattr(cls, attr)))
return cls
return decorate

View File

@ -49,6 +49,18 @@ def test_get(last_id):
assert hero.is_alive is True
def test_get_or_404(last_id):
hero = crud.get_or_404(last_id)
assert hero.id == last_id
assert hero.name == HERO_NAME
assert hero.secret_name == SECRET_NAME
assert hero.age is None
assert hero.is_alive is True
with pytest.raises(HTTPException):
crud.get_or_404(last_id + 1)
def test_get_by_ids(last_id):
heroes = crud.get_by_ids([last_id])
assert len(heroes) == 1
@ -68,10 +80,39 @@ def test_get_by_field(last_id):
assert heroes.is_alive is True
def test_get_by_fields(last_id):
heroes = crud.get_by_fields({"name": HERO_NAME, "secret_name": SECRET_NAME})
assert heroes.id == last_id
assert heroes.name == HERO_NAME
assert heroes.secret_name == SECRET_NAME
assert heroes.age is None
assert heroes.is_alive is True
def test_get_or_create(last_id):
hero = HeroCreate(name=HERO_NAME, secret_name=SECRET_NAME)
hero = crud.get_or_create(hero, search_field="name")
assert hero.id is not None
assert hero.id == last_id
assert hero.name == HERO_NAME
hero = HeroCreate(name="NotExistent", secret_name=SECRET_NAME)
hero = crud.get_or_create(hero, search_field="name")
assert hero.id is not None
assert hero.id != last_id
assert hero.name == "NotExistent"
hero = HeroCreate(name="NotExistentV2", secret_name=SECRET_NAME)
with pytest.raises(HTTPException):
crud.get_or_create(hero, search_field="id")
with pytest.raises(HTTPException):
crud.get_or_create(hero, search_field="NotExistent")
def test_list(last_id):
heroes = crud.list()
assert len(heroes) == 1
assert heroes[0].id == last_id
assert len(heroes) == 2
assert heroes[-1].id == last_id
assert heroes[0].name == HERO_NAME
assert heroes[0].secret_name == SECRET_NAME
assert heroes[0].age is None
@ -98,4 +139,4 @@ def test_delete(last_id):
assert hero.age == 30
assert hero.is_alive is True
with pytest.raises(HTTPException):
crud.get(last_id)
crud.get_or_404(last_id)