diff --git a/backend/utils/serializers.py b/backend/utils/serializers.py index 8a5d805..63ffe61 100644 --- a/backend/utils/serializers.py +++ b/backend/utils/serializers.py @@ -3,9 +3,10 @@ from decimal import Decimal from typing import Any, Sequence, TypeVar -import msgspec - +from fastapi.encoders import decimal_encoder +from msgspec import json from sqlalchemy import Row, RowMapping +from sqlalchemy.orm import ColumnProperty, SynonymProperty, class_mapper from starlette.responses import JSONResponse RowData = Row | RowMapping | Any @@ -20,15 +21,13 @@ def select_columns_serialize(row: R) -> dict: :param row: :return: """ - obj_dict = {} + result = {} for column in row.__table__.columns.keys(): - val = getattr(row, column) - if isinstance(val, Decimal): - if val % 1 == 0: - val = int(val) - val = float(val) - obj_dict[column] = val - return obj_dict + v = getattr(row, column) + if isinstance(v, Decimal): + v = decimal_encoder(v) + result[column] = v + return result def select_list_serialize(row: Sequence[R]) -> list: @@ -38,22 +37,35 @@ def select_list_serialize(row: Sequence[R]) -> list: :param row: :return: """ - ret_list = [select_columns_serialize(_) for _ in row] - return ret_list + result = [select_columns_serialize(_) for _ in row] + return result -def select_as_dict(row: R) -> dict: +def select_as_dict(row: R, use_alias: bool = False) -> dict: """ Converting SQLAlchemy select to dict, which can contain relational data, depends on the properties of the select object itself + If set use_alias is True, the column name will be returned as alias, + If alias doesn't exist in columns, we don't recommend setting it to True + :param row: + :param use_alias: :return: """ - obj_dict = row.__dict__ - if '_sa_instance_state' in obj_dict: - del obj_dict['_sa_instance_state'] - return obj_dict + if not use_alias: + result = row.__dict__ + if '_sa_instance_state' in result: + del result['_sa_instance_state'] + return result + else: + result = {} + mapper = class_mapper(row.__class__) + for prop in mapper.iterate_properties: + if isinstance(prop, (ColumnProperty, SynonymProperty)): + key = prop.key + result[key] = getattr(row, key) + return result class MsgSpecJSONResponse(JSONResponse): @@ -62,4 +74,4 @@ class MsgSpecJSONResponse(JSONResponse): """ def render(self, content: Any) -> bytes: - return msgspec.json.encode(content) + return json.encode(content)