Optimize the serialization of join query results (#1058)

* Optimize the serialization of join query results

* Update some naming
This commit is contained in:
Wu Clan
2026-02-04 20:20:20 +08:00
committed by GitHub
parent 220de51aad
commit ab1f60120a

View File

@@ -9,6 +9,8 @@ from sqlalchemy import Row, RowMapping
from sqlalchemy.orm import ColumnProperty, SynonymProperty, class_mapper
from starlette.responses import JSONResponse
from backend.common.log import log
RowData = Row[Any] | RowMapping | Any
R = TypeVar('R', bound=RowData)
@@ -79,68 +81,74 @@ def select_join_serialize( # noqa: C901
return_as_dict: bool = False,
) -> dict[str, Any] | list[dict[str, Any]] | tuple[Any, ...] | list[tuple[Any, ...]] | None:
"""
将 SQLAlchemy 连接查询结果序列化为字典或支持属性访问的 namedtuple
将 SQLAlchemy 连接查询结果序列化为字典或 namedtuple
扁平序列化``relationships=None``
| 将所有查询结果平铺到同一层级,不进行嵌套处理
输出Result(name='Alice', dept=Dept(...))
扁平序列化relationships=None
所有结果平铺到同一层级,不嵌套
Result(name='Alice', dept=Dept(...))
嵌套序列化:``relationships=['User-m2o-Dept', 'User-m2m-Role:permissions', 'Role-m2m-Menu']``
| 根据指定的关系类型将数据嵌套组织,支持层级结构
| row = select(User, Dept, Role).join(...).all()
输出Result(name='Alice', dept=Dept(...), permissions=[Role(..., menus=[Menu(...)])])
嵌套序列化:
根据关系类型嵌套组织数据,支持层级结构
relationships=['User-m2o-Dept', 'User-m2m-Role:permissions']
输出Result(name='Alice', dept=Dept(...), permissions=[Role(...)])
关系格式source_model-type-target_model[:custom_name]
- type: o2m(一对多), m2o(多对一), o2o(一对一), m2m(多对多)
- o2m/m2m: 目标字段名自动加 's' 复数化
- m2o/o2o: 目标字段名保持单数
- custom_name: 自定义目标字段名
:param row: SQLAlchemy 查询结果
:param relationships: 表之间的虚拟关系
source_model_class-type-target_model_class[:custom_name], type: o2m/m2o/o2o/m2m
- o2m (一对多): 目标模型类名会自动添加's'变为复数形式 (如: dept->depts)
- m2o (多对一): 目标模型类名保持单数形式 (如: user->user)
- o2o (一对一): 目标模型类名保持单数形式 (如: profile->profile)
- m2m (多对多): 目标模型类名会自动添加's'变为复数形式 (如: role->roles)
- 自定义名称: 可以通过在关系字符串末尾添加 ':custom_name' 来指定自定义的目标字段名
例如: 'User-m2m-Role:permissions' 会将角色数据放在 'permissions' 字段而不是默认的 'roles'
:param return_as_dict: False 返回 namedtupleTrue 返回 dict
:param relationships: 关系定义列表
:param return_as_dict: True 返回字典False 返回 namedtuple
:return:
"""
list_relationship_types = {'o2m', 'm2m'}
all_relationship_types = {'o2m', 'm2o', 'o2o', 'm2m'}
def get_relation_key(model_name: str, rel_type: str, custom_field: str | None = None) -> str:
"""获取关系键名"""
return custom_field or (model_name if rel_type in ('o2o', 'm2o') else f'{model_name}s')
def get_obj_id(target_obj: Any) -> int | str:
return getattr(target_obj, 'id', None) or id(target_obj)
def extract_row_elements(row_data: Any) -> tuple:
return row_data if hasattr(row_data, '__getitem__') else (row_data,)
def get_relationship_key(model: str, relationship_type: str, custom_field: str | None) -> str:
return custom_field or (model if relationship_type not in list_relationship_types else f'{model}s')
def parse_relationships(relationship_list: list[str]) -> tuple[dict, dict, dict]:
"""解析关系定义"""
if not relationship_list:
return {}, {}, {}
parsed_relation_graph = defaultdict(dict)
parsed_reverse_relation = {}
parsed_custom_names = {}
graph = defaultdict(dict)
reverse = {}
customs = {}
for rel_str in relationship_list:
parts = rel_str.split(':', 1)
rel_part = parts[0].strip()
field_custom_name = parts[1].strip() if len(parts) > 1 else None
custom_name = parts[1].strip() if len(parts) > 1 else None
rel_info = rel_part.split('-')
if len(rel_info) != 3:
info = rel_part.split('-')
if len(info) != 3:
log.warning(f'Invalid relationship: "{rel_str}", expected "source-type-target[:custom]"')
continue
source_model, rel_type, target_model = (info.lower() for info in rel_info)
if rel_type not in ('o2m', 'm2o', 'o2o', 'm2m'):
src, parsed_type, dst = (x.lower() for x in info)
if parsed_type not in all_relationship_types:
log.warning(
f'Invalid relationship type: "{parsed_type}" in "{rel_str}", '
f'must be one of: {", ".join(all_relationship_types)}'
)
continue
parsed_relation_graph[source_model][target_model] = rel_type
parsed_reverse_relation[target_model] = source_model
if field_custom_name:
parsed_custom_names[source_model, target_model] = field_custom_name
graph[src][dst] = parsed_type
reverse[dst] = src
if custom_name:
customs[src, dst] = custom_name
return parsed_relation_graph, parsed_reverse_relation, parsed_custom_names
return graph, reverse, customs
def get_model_columns(model_obj: Any) -> list[str]:
"""获取模型列名"""
mapper = class_mapper(type(model_obj))
return [
prop.key
@@ -148,17 +156,25 @@ def select_join_serialize( # noqa: C901
if isinstance(prop, (ColumnProperty, SynonymProperty)) and hasattr(model_obj, prop.key)
]
def get_unique_objects(objs: list[Any], key_attr: str = 'id') -> list[Any]:
"""根据键属性去重对象列表"""
def dedupe_objects(obj_list: list[Any]) -> list[Any]:
seen = set()
unique = []
for item in objs:
item_id = getattr(item, key_attr, None)
for item in obj_list:
item_id = getattr(item, 'id', None)
if item_id is not None and item_id not in seen:
seen.add(item_id)
unique.append(item)
return unique
def build_namedtuple(name: str, data: dict) -> Any:
if return_as_dict or name not in namedtuple_cache:
return None
for field in namedtuple_cache[name]._fields:
if field not in data:
data[field] = None
return namedtuple_cache[name](**data)
# 输入验证
if not row:
return None
@@ -166,211 +182,207 @@ def select_join_serialize( # noqa: C901
if not rows_list:
return None
# 获取主对象信息
first_row = rows_list[0]
main_obj = first_row[0] if hasattr(first_row, '__getitem__') and first_row else first_row
if main_obj is None:
# 主对象信息
first_row = extract_row_elements(rows_list[0])
primary_obj = first_row[0]
if primary_obj is None:
return None
main_obj_name = type(main_obj).__name__.lower()
main_columns = get_model_columns(main_obj)
primary_obj_name = type(primary_obj).__name__.lower()
primary_columns = get_model_columns(primary_obj)
# 解析关系
# 关系解析
relation_graph, reverse_relation, custom_names = parse_relationships(relationships or [])
has_relationships = bool(relation_graph)
# 预处理所有模型类型和列信息
# 预处理模型信息
model_info = {}
cls_idxs = {}
cls_idx = {}
for preprocess_row in rows_list:
preprocess_row_items = preprocess_row if hasattr(preprocess_row, '__getitem__') else (preprocess_row,)
for idx, row_obj in enumerate(preprocess_row_items):
if row_obj is None:
for row_item in rows_list:
row_elements = extract_row_elements(row_item)
for idx, element in enumerate(row_elements):
if element is None:
continue
obj_class_name = type(row_obj).__name__.lower()
if obj_class_name not in model_info:
model_info[obj_class_name] = get_model_columns(row_obj)
if obj_class_name not in cls_idxs:
cls_idxs[obj_class_name] = idx
element_cls = type(element).__name__.lower()
if element_cls not in model_info:
model_info[element_cls] = get_model_columns(element)
if element_cls not in cls_idx:
cls_idx[element_cls] = idx
# 数据收集和分组
main_data = {}
grouped_data = defaultdict(lambda: defaultdict(list))
# 数据分组
main_objects = {}
children_objects = defaultdict(lambda: defaultdict(list))
for data_row in rows_list:
data_row_items = data_row if hasattr(data_row, '__getitem__') else (data_row,)
if not data_row_items or data_row_items[0] is None:
for row_item in rows_list:
row_elements = extract_row_elements(row_item)
if not row_elements or row_elements[0] is None:
continue
main_obj = data_row_items[0]
main_id = getattr(main_obj, 'id', None) or id(main_obj)
main_obj = row_elements[0]
main_id = get_obj_id(main_obj)
if main_id not in main_data:
main_data[main_id] = main_obj
if main_id not in main_objects:
main_objects[main_id] = main_obj
# 收集子对象
for child_obj in data_row_items[1:]:
for child_obj in row_elements[1:]:
if child_obj is None:
continue
child_class_name = type(child_obj).__name__.lower()
grouped_data[main_id][child_class_name].append(child_obj)
child_type = type(child_obj).__name__.lower()
children_objects[main_id][child_type].append(child_obj)
if not main_data:
if not main_objects:
return None
# 预生成 namedtuple 类型
# namedtuple 类型预生成
namedtuple_cache = {}
if not return_as_dict:
for cls_name, columns in model_info.items():
if columns:
# 为嵌套关系预计算完整字段列表
full_columns = columns.copy()
if has_relationships:
for target_class, relation_type in relation_graph.get(cls_name, {}).items():
field_name = custom_names.get((cls_name, target_class))
rel_key = get_relation_key(target_class, relation_type, field_name)
full_columns.append(rel_key)
full_columns = sorted(set(full_columns)) # 去重并排序
namedtuple_cache[cls_name] = namedtuple(cls_name.capitalize(), full_columns or columns) # noqa: PYI024
def build_flat_result(build_main_id: int, build_main_obj: Any) -> dict[str, Any]: # noqa: C901
"""构建扁平化结果"""
flat_result = {col: getattr(build_main_obj, col, None) for col in main_columns}
for class_name in sorted(grouped_data[build_main_id]):
if class_name == main_obj_name:
for model_name, model_columns in model_info.items():
if not model_columns:
continue
flat_objs = get_unique_objects(grouped_data[build_main_id][class_name])
cls_columns = model_info.get(class_name, [])
field_list = model_columns.copy()
if has_relationships:
for target, target_rtype in relation_graph.get(model_name, {}).items():
nt_key = get_relationship_key(target, target_rtype, custom_names.get((model_name, target)))
field_list.append(nt_key)
field_list = list(dict.fromkeys(field_list))
if not flat_objs:
flat_result[class_name] = []
elif len(flat_objs) == 1:
obj_data = {col: getattr(flat_objs[0], col, None) for col in cls_columns}
# 确保 namedtuple 所需的所有字段都存在
if not return_as_dict and class_name in namedtuple_cache:
nt_fields = getattr(namedtuple_cache[class_name], '_fields', [])
for field in nt_fields:
if field not in obj_data:
obj_data[field] = None
flat_result[class_name] = obj_data if return_as_dict else namedtuple_cache[class_name](**obj_data)
namedtuple_cache[model_name] = namedtuple(model_name.capitalize(), field_list) # noqa: PYI024
# 嵌套关系层级结构(一次性构建)
hierarchy = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
if has_relationships:
for row_item in rows_list:
row_elements = extract_row_elements(row_item)
if not row_elements or row_elements[0] is None:
continue
main_id = get_obj_id(row_elements[0])
m_type_name = type(row_elements[0]).__name__.lower()
for idx, rel_obj in enumerate(row_elements[1:], 1): # noqa: B007
if rel_obj is None:
continue
rel_type_name = type(rel_obj).__name__.lower()
if rel_type_name in reverse_relation:
parent_type = reverse_relation[rel_type_name]
parent_idx = cls_idx.get(parent_type)
parent = (
row_elements[parent_idx] if parent_idx is not None and parent_idx < len(row_elements) else None
)
elif rel_type_name in relation_graph.get(m_type_name, {}):
parent = row_elements[0]
else:
continue
if parent is None:
continue
parent_pk = getattr(parent, 'id', None)
if parent_pk is not None:
hierarchy[main_id][rel_type_name][parent_pk].append(rel_obj)
# 结果构建函数
def build_flat(target_id: int, target_obj: Any) -> dict[str, Any]:
result = {col: getattr(target_obj, col, None) for col in primary_columns}
for cls_type in children_objects[target_id]:
if cls_type == primary_obj_name:
continue
unique_children = dedupe_objects(children_objects[target_id][cls_type])
child_columns = model_info.get(cls_type, [])
count = len(unique_children)
field_key = cls_type if count <= 1 else f'{cls_type}s'
if count == 0:
result[field_key] = []
elif count == 1:
obj_data = {col: getattr(unique_children[0], col, None) for col in child_columns}
result[field_key] = obj_data if return_as_dict else build_namedtuple(cls_type, obj_data)
else:
if return_as_dict:
flat_result[class_name] = [
{col: getattr(flat_obj, col, None) for col in cls_columns} for flat_obj in flat_objs
]
result[field_key] = [{col: getattr(c, col, None) for col in child_columns} for c in unique_children]
else:
nested_result_list = []
for nested_obj in flat_objs:
obj_data = {col: getattr(nested_obj, col, None) for col in cls_columns}
# 确保 namedtuple 所需的所有字段都存在
if class_name in namedtuple_cache:
nt_fields = getattr(namedtuple_cache[class_name], '_fields', [])
for field in nt_fields:
if field not in obj_data:
obj_data[field] = None
nested_result_list.append(namedtuple_cache[class_name](**obj_data))
flat_result[class_name] = nested_result_list
result[field_key] = [
build_namedtuple(cls_type, {col: getattr(c, col, None) for col in child_columns})
for c in unique_children
]
return flat_result
return result
def build_nested_result(nested_main_id: int, nested_main_obj: Any) -> dict[str, Any]: # noqa: C901
"""构建嵌套化结果"""
nested_result = {col: getattr(nested_main_obj, col, None) for col in main_columns}
def build_nested(target_id: int, target_obj: Any) -> dict[str, Any]:
result = {col: getattr(target_obj, col, None) for col in primary_columns}
current_hierarchy = hierarchy.get(target_id, defaultdict(lambda: defaultdict(list)))
# 构建关系层级数据结构
hierarchy = defaultdict(lambda: defaultdict(list))
for iter_row in rows_list:
iter_row_items = iter_row if hasattr(iter_row, '__getitem__') else (iter_row,)
if not iter_row_items or iter_row_items[0] is None:
continue
iter_main_id = getattr(iter_row_items[0], 'id', None) or id(iter_row_items[0])
if iter_main_id != nested_main_id:
continue
for _i, related_obj in enumerate(iter_row_items[1:], 1):
if related_obj is None:
continue
related_class_name = type(related_obj).__name__.lower()
if related_class_name in reverse_relation:
parent_cls = reverse_relation[related_class_name]
parent_idx = cls_idxs.get(parent_cls, 0)
if parent_idx < len(iter_row_items):
parent_obj = iter_row_items[parent_idx]
if parent_obj is not None:
parent_obj_id = getattr(parent_obj, 'id', None)
if parent_obj_id is not None:
hierarchy[related_class_name][parent_obj_id].append(related_obj)
def build_recursive(current_cls_name: str, current_parent_id: int) -> list:
"""递归构建嵌套数据"""
recursive_objs = get_unique_objects(hierarchy[current_cls_name].get(current_parent_id, []))
if not recursive_objs:
def recursive_build(cls_name: str, pk: int) -> list:
nested_dict = current_hierarchy.get(cls_name)
if nested_dict is None:
return []
objs = dedupe_objects(nested_dict.get(pk, []))
if not objs:
return []
recursive_result = []
for nested_obj in recursive_objs:
# 基础数据
obj_data = {col: getattr(nested_obj, col, None) for col in model_info[current_cls_name]}
output = []
for item in objs:
item_data = {col: getattr(item, col, None) for col in model_info[cls_name]}
# 处理子关系
for child_cls, child_rel_type in relation_graph.get(current_cls_name, {}).items():
child_parent_id = getattr(nested_obj, 'id', None)
if child_parent_id is None:
for sub_type, sub_rel_type in relation_graph.get(cls_name, {}).items():
sub_pk = getattr(item, 'id', None)
if sub_pk is None:
continue
child_list = build_recursive(child_cls, child_parent_id)
child_key = get_relation_key(
child_cls, child_rel_type, custom_names.get((current_cls_name, child_cls))
)
sub_list = recursive_build(sub_type, sub_pk)
sub_key = get_relationship_key(sub_type, sub_rel_type, custom_names.get((cls_name, sub_type)))
if child_rel_type in ('m2o', 'o2o'):
obj_data[child_key] = child_list[0] if child_list else None
if sub_rel_type not in list_relationship_types:
item_data[sub_key] = sub_list[0] if sub_list else None
else:
obj_data[child_key] = child_list
item_data[sub_key] = sub_list
if not return_as_dict and current_cls_name in namedtuple_cache:
nt_fields = getattr(namedtuple_cache[current_cls_name], '_fields', [])
for field in nt_fields:
if field not in obj_data:
obj_data[field] = None
output.append(item_data if return_as_dict else build_namedtuple(cls_name, item_data))
recursive_result.append(obj_data if return_as_dict else namedtuple_cache[current_cls_name](**obj_data))
return output
return recursive_result
for top_type, top_rtype in relation_graph.get(primary_obj_name, {}).items():
instances = recursive_build(top_type, target_id)
top_key = get_relationship_key(top_type, top_rtype, custom_names.get((primary_obj_name, top_type)))
# 构建顶级关系
for top_cls_name, top_rel_type in relation_graph.get(main_obj_name, {}).items():
instances = build_recursive(top_cls_name, nested_main_id)
key = get_relation_key(top_cls_name, top_rel_type, custom_names.get((main_obj_name, top_cls_name)))
if top_rel_type in ('m2o', 'o2o'):
nested_result[key] = instances[0] if instances else None
if top_rtype not in list_relationship_types:
result[top_key] = instances[0] if instances else None
else:
nested_result[key] = instances
result[top_key] = instances
return nested_result
return result
# 构建最终结果
final_result_list = []
for current_main_id in sorted(main_data.keys()):
current_main_obj = main_data[current_main_id]
# 最终结果构建
final_results = []
processed_ids = set()
if has_relationships:
final_result_data = build_nested_result(current_main_id, current_main_obj)
else:
final_result_data = build_flat_result(current_main_id, current_main_obj)
for row_item in rows_list:
row_elements = extract_row_elements(row_item)
if not row_elements or row_elements[0] is None:
continue
main_obj = row_elements[0]
main_id = get_obj_id(main_obj)
if main_id not in main_objects or main_id in processed_ids:
continue
processed_ids.add(main_id)
result_data = build_nested(main_id, main_obj) if has_relationships else build_flat(main_id, main_obj)
if not return_as_dict:
all_fields = list(final_result_data.keys())
result_type = namedtuple('Result', all_fields) # noqa: PYI024
final_result_list.append(result_type(**final_result_data))
result_type = namedtuple('Result', result_data.keys()) # noqa: PYI024
final_results.append(result_type(**result_data))
else:
final_result_list.append(final_result_data)
final_results.append(result_data)
return final_result_list[0] if len(final_result_list) == 1 else final_result_list
return final_results[0] if len(final_results) == 1 else final_results