From ab1f60120a2a48121a48047ffe1cebf58d8ae806 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Wed, 4 Feb 2026 20:20:20 +0800 Subject: [PATCH] Optimize the serialization of join query results (#1058) * Optimize the serialization of join query results * Update some naming --- backend/utils/serializers.py | 412 ++++++++++++++++++----------------- 1 file changed, 212 insertions(+), 200 deletions(-) diff --git a/backend/utils/serializers.py b/backend/utils/serializers.py index 20bcb656..d539debb 100644 --- a/backend/utils/serializers.py +++ b/backend/utils/serializers.py @@ -9,6 +9,8 @@ from sqlalchemy import Row, RowMapping from sqlalchemy.orm import ColumnProperty, SynonymProperty, class_mapper from starlette.responses import JSONResponse +from backend.common.log import log + RowData = Row[Any] | RowMapping | Any R = TypeVar('R', bound=RowData) @@ -79,68 +81,74 @@ def select_join_serialize( # noqa: C901 return_as_dict: bool = False, ) -> dict[str, Any] | list[dict[str, Any]] | tuple[Any, ...] | list[tuple[Any, ...]] | None: """ - 将 SQLAlchemy 连接查询结果序列化为字典或支持属性访问的 namedtuple + 将 SQLAlchemy 连接查询结果序列化为字典或 namedtuple - 扁平序列化:``relationships=None`` - | 将所有查询结果平铺到同一层级,不进行嵌套处理 - 输出:Result(name='Alice', dept=Dept(...)) + 扁平序列化(relationships=None): + 所有结果平铺到同一层级,不嵌套 + 例:Result(name='Alice', dept=Dept(...)) - 嵌套序列化:``relationships=['User-m2o-Dept', 'User-m2m-Role:permissions', 'Role-m2m-Menu']`` - | 根据指定的关系类型将数据嵌套组织,支持层级结构 - | row = select(User, Dept, Role).join(...).all() - 输出:Result(name='Alice', dept=Dept(...), permissions=[Role(..., menus=[Menu(...)])]) + 嵌套序列化: + 根据关系类型嵌套组织数据,支持层级结构 + 例:relationships=['User-m2o-Dept', 'User-m2m-Role:permissions'] + 输出:Result(name='Alice', dept=Dept(...), permissions=[Role(...)]) + + 关系格式:source_model-type-target_model[:custom_name] + - type: o2m(一对多), m2o(多对一), o2o(一对一), m2m(多对多) + - o2m/m2m: 目标字段名自动加 's' 复数化 + - m2o/o2o: 目标字段名保持单数 + - custom_name: 自定义目标字段名 :param row: SQLAlchemy 查询结果 - :param relationships: 表之间的虚拟关系 - - source_model_class-type-target_model_class[:custom_name], type: o2m/m2o/o2o/m2m - - - o2m (一对多): 目标模型类名会自动添加's'变为复数形式 (如: dept->depts) - - m2o (多对一): 目标模型类名保持单数形式 (如: user->user) - - o2o (一对一): 目标模型类名保持单数形式 (如: profile->profile) - - m2m (多对多): 目标模型类名会自动添加's'变为复数形式 (如: role->roles) - - 自定义名称: 可以通过在关系字符串末尾添加 ':custom_name' 来指定自定义的目标字段名 - 例如: 'User-m2m-Role:permissions' 会将角色数据放在 'permissions' 字段而不是默认的 'roles' - - :param return_as_dict: False 返回 namedtuple,True 返回 dict + :param relationships: 关系定义列表 + :param return_as_dict: True 返回字典,False 返回 namedtuple :return: """ + list_relationship_types = {'o2m', 'm2m'} + all_relationship_types = {'o2m', 'm2o', 'o2o', 'm2m'} - def get_relation_key(model_name: str, rel_type: str, custom_field: str | None = None) -> str: - """获取关系键名""" - return custom_field or (model_name if rel_type in ('o2o', 'm2o') else f'{model_name}s') + def get_obj_id(target_obj: Any) -> int | str: + return getattr(target_obj, 'id', None) or id(target_obj) + + def extract_row_elements(row_data: Any) -> tuple: + return row_data if hasattr(row_data, '__getitem__') else (row_data,) + + def get_relationship_key(model: str, relationship_type: str, custom_field: str | None) -> str: + return custom_field or (model if relationship_type not in list_relationship_types else f'{model}s') def parse_relationships(relationship_list: list[str]) -> tuple[dict, dict, dict]: - """解析关系定义""" if not relationship_list: return {}, {}, {} - parsed_relation_graph = defaultdict(dict) - parsed_reverse_relation = {} - parsed_custom_names = {} + graph = defaultdict(dict) + reverse = {} + customs = {} for rel_str in relationship_list: parts = rel_str.split(':', 1) rel_part = parts[0].strip() - field_custom_name = parts[1].strip() if len(parts) > 1 else None + custom_name = parts[1].strip() if len(parts) > 1 else None - rel_info = rel_part.split('-') - if len(rel_info) != 3: + info = rel_part.split('-') + if len(info) != 3: + log.warning(f'Invalid relationship: "{rel_str}", expected "source-type-target[:custom]"') continue - source_model, rel_type, target_model = (info.lower() for info in rel_info) - if rel_type not in ('o2m', 'm2o', 'o2o', 'm2m'): + src, parsed_type, dst = (x.lower() for x in info) + if parsed_type not in all_relationship_types: + log.warning( + f'Invalid relationship type: "{parsed_type}" in "{rel_str}", ' + f'must be one of: {", ".join(all_relationship_types)}' + ) continue - parsed_relation_graph[source_model][target_model] = rel_type - parsed_reverse_relation[target_model] = source_model - if field_custom_name: - parsed_custom_names[source_model, target_model] = field_custom_name + graph[src][dst] = parsed_type + reverse[dst] = src + if custom_name: + customs[src, dst] = custom_name - return parsed_relation_graph, parsed_reverse_relation, parsed_custom_names + return graph, reverse, customs def get_model_columns(model_obj: Any) -> list[str]: - """获取模型列名""" mapper = class_mapper(type(model_obj)) return [ prop.key @@ -148,17 +156,25 @@ def select_join_serialize( # noqa: C901 if isinstance(prop, (ColumnProperty, SynonymProperty)) and hasattr(model_obj, prop.key) ] - def get_unique_objects(objs: list[Any], key_attr: str = 'id') -> list[Any]: - """根据键属性去重对象列表""" + def dedupe_objects(obj_list: list[Any]) -> list[Any]: seen = set() unique = [] - for item in objs: - item_id = getattr(item, key_attr, None) + for item in obj_list: + item_id = getattr(item, 'id', None) if item_id is not None and item_id not in seen: seen.add(item_id) unique.append(item) return unique + def build_namedtuple(name: str, data: dict) -> Any: + if return_as_dict or name not in namedtuple_cache: + return None + for field in namedtuple_cache[name]._fields: + if field not in data: + data[field] = None + return namedtuple_cache[name](**data) + + # 输入验证 if not row: return None @@ -166,211 +182,207 @@ def select_join_serialize( # noqa: C901 if not rows_list: return None - # 获取主对象信息 - first_row = rows_list[0] - main_obj = first_row[0] if hasattr(first_row, '__getitem__') and first_row else first_row - if main_obj is None: + # 主对象信息 + first_row = extract_row_elements(rows_list[0]) + primary_obj = first_row[0] + if primary_obj is None: return None - main_obj_name = type(main_obj).__name__.lower() - main_columns = get_model_columns(main_obj) + primary_obj_name = type(primary_obj).__name__.lower() + primary_columns = get_model_columns(primary_obj) - # 解析关系 + # 关系解析 relation_graph, reverse_relation, custom_names = parse_relationships(relationships or []) has_relationships = bool(relation_graph) - # 预处理所有模型类型和列信息 + # 预处理模型信息 model_info = {} - cls_idxs = {} + cls_idx = {} - for preprocess_row in rows_list: - preprocess_row_items = preprocess_row if hasattr(preprocess_row, '__getitem__') else (preprocess_row,) - for idx, row_obj in enumerate(preprocess_row_items): - if row_obj is None: + for row_item in rows_list: + row_elements = extract_row_elements(row_item) + for idx, element in enumerate(row_elements): + if element is None: continue - obj_class_name = type(row_obj).__name__.lower() - if obj_class_name not in model_info: - model_info[obj_class_name] = get_model_columns(row_obj) - if obj_class_name not in cls_idxs: - cls_idxs[obj_class_name] = idx + element_cls = type(element).__name__.lower() + if element_cls not in model_info: + model_info[element_cls] = get_model_columns(element) + if element_cls not in cls_idx: + cls_idx[element_cls] = idx - # 数据收集和分组 - main_data = {} - grouped_data = defaultdict(lambda: defaultdict(list)) + # 数据分组 + main_objects = {} + children_objects = defaultdict(lambda: defaultdict(list)) - for data_row in rows_list: - data_row_items = data_row if hasattr(data_row, '__getitem__') else (data_row,) - if not data_row_items or data_row_items[0] is None: + for row_item in rows_list: + row_elements = extract_row_elements(row_item) + if not row_elements or row_elements[0] is None: continue - main_obj = data_row_items[0] - main_id = getattr(main_obj, 'id', None) or id(main_obj) + main_obj = row_elements[0] + main_id = get_obj_id(main_obj) - if main_id not in main_data: - main_data[main_id] = main_obj + if main_id not in main_objects: + main_objects[main_id] = main_obj - # 收集子对象 - for child_obj in data_row_items[1:]: + for child_obj in row_elements[1:]: if child_obj is None: continue - child_class_name = type(child_obj).__name__.lower() - grouped_data[main_id][child_class_name].append(child_obj) + child_type = type(child_obj).__name__.lower() + children_objects[main_id][child_type].append(child_obj) - if not main_data: + if not main_objects: return None - # 预生成 namedtuple 类型 + # namedtuple 类型预生成 namedtuple_cache = {} if not return_as_dict: - for cls_name, columns in model_info.items(): - if columns: - # 为嵌套关系预计算完整字段列表 - full_columns = columns.copy() - if has_relationships: - for target_class, relation_type in relation_graph.get(cls_name, {}).items(): - field_name = custom_names.get((cls_name, target_class)) - rel_key = get_relation_key(target_class, relation_type, field_name) - full_columns.append(rel_key) - full_columns = sorted(set(full_columns)) # 去重并排序 - - namedtuple_cache[cls_name] = namedtuple(cls_name.capitalize(), full_columns or columns) # noqa: PYI024 - - def build_flat_result(build_main_id: int, build_main_obj: Any) -> dict[str, Any]: # noqa: C901 - """构建扁平化结果""" - flat_result = {col: getattr(build_main_obj, col, None) for col in main_columns} - - for class_name in sorted(grouped_data[build_main_id]): - if class_name == main_obj_name: + for model_name, model_columns in model_info.items(): + if not model_columns: continue - flat_objs = get_unique_objects(grouped_data[build_main_id][class_name]) - cls_columns = model_info.get(class_name, []) + field_list = model_columns.copy() + if has_relationships: + for target, target_rtype in relation_graph.get(model_name, {}).items(): + nt_key = get_relationship_key(target, target_rtype, custom_names.get((model_name, target))) + field_list.append(nt_key) + field_list = list(dict.fromkeys(field_list)) - if not flat_objs: - flat_result[class_name] = [] - elif len(flat_objs) == 1: - obj_data = {col: getattr(flat_objs[0], col, None) for col in cls_columns} - # 确保 namedtuple 所需的所有字段都存在 - if not return_as_dict and class_name in namedtuple_cache: - nt_fields = getattr(namedtuple_cache[class_name], '_fields', []) - for field in nt_fields: - if field not in obj_data: - obj_data[field] = None - flat_result[class_name] = obj_data if return_as_dict else namedtuple_cache[class_name](**obj_data) + namedtuple_cache[model_name] = namedtuple(model_name.capitalize(), field_list) # noqa: PYI024 + + # 嵌套关系层级结构(一次性构建) + hierarchy = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + + if has_relationships: + for row_item in rows_list: + row_elements = extract_row_elements(row_item) + if not row_elements or row_elements[0] is None: + continue + + main_id = get_obj_id(row_elements[0]) + m_type_name = type(row_elements[0]).__name__.lower() + + for idx, rel_obj in enumerate(row_elements[1:], 1): # noqa: B007 + if rel_obj is None: + continue + + rel_type_name = type(rel_obj).__name__.lower() + + if rel_type_name in reverse_relation: + parent_type = reverse_relation[rel_type_name] + parent_idx = cls_idx.get(parent_type) + parent = ( + row_elements[parent_idx] if parent_idx is not None and parent_idx < len(row_elements) else None + ) + elif rel_type_name in relation_graph.get(m_type_name, {}): + parent = row_elements[0] + else: + continue + + if parent is None: + continue + + parent_pk = getattr(parent, 'id', None) + if parent_pk is not None: + hierarchy[main_id][rel_type_name][parent_pk].append(rel_obj) + + # 结果构建函数 + def build_flat(target_id: int, target_obj: Any) -> dict[str, Any]: + result = {col: getattr(target_obj, col, None) for col in primary_columns} + + for cls_type in children_objects[target_id]: + if cls_type == primary_obj_name: + continue + + unique_children = dedupe_objects(children_objects[target_id][cls_type]) + child_columns = model_info.get(cls_type, []) + + count = len(unique_children) + field_key = cls_type if count <= 1 else f'{cls_type}s' + + if count == 0: + result[field_key] = [] + elif count == 1: + obj_data = {col: getattr(unique_children[0], col, None) for col in child_columns} + result[field_key] = obj_data if return_as_dict else build_namedtuple(cls_type, obj_data) else: if return_as_dict: - flat_result[class_name] = [ - {col: getattr(flat_obj, col, None) for col in cls_columns} for flat_obj in flat_objs - ] + result[field_key] = [{col: getattr(c, col, None) for col in child_columns} for c in unique_children] else: - nested_result_list = [] - for nested_obj in flat_objs: - obj_data = {col: getattr(nested_obj, col, None) for col in cls_columns} - # 确保 namedtuple 所需的所有字段都存在 - if class_name in namedtuple_cache: - nt_fields = getattr(namedtuple_cache[class_name], '_fields', []) - for field in nt_fields: - if field not in obj_data: - obj_data[field] = None - nested_result_list.append(namedtuple_cache[class_name](**obj_data)) - flat_result[class_name] = nested_result_list + result[field_key] = [ + build_namedtuple(cls_type, {col: getattr(c, col, None) for col in child_columns}) + for c in unique_children + ] - return flat_result + return result - def build_nested_result(nested_main_id: int, nested_main_obj: Any) -> dict[str, Any]: # noqa: C901 - """构建嵌套化结果""" - nested_result = {col: getattr(nested_main_obj, col, None) for col in main_columns} + def build_nested(target_id: int, target_obj: Any) -> dict[str, Any]: + result = {col: getattr(target_obj, col, None) for col in primary_columns} + current_hierarchy = hierarchy.get(target_id, defaultdict(lambda: defaultdict(list))) - # 构建关系层级数据结构 - hierarchy = defaultdict(lambda: defaultdict(list)) - for iter_row in rows_list: - iter_row_items = iter_row if hasattr(iter_row, '__getitem__') else (iter_row,) - if not iter_row_items or iter_row_items[0] is None: - continue - - iter_main_id = getattr(iter_row_items[0], 'id', None) or id(iter_row_items[0]) - if iter_main_id != nested_main_id: - continue - - for _i, related_obj in enumerate(iter_row_items[1:], 1): - if related_obj is None: - continue - related_class_name = type(related_obj).__name__.lower() - - if related_class_name in reverse_relation: - parent_cls = reverse_relation[related_class_name] - parent_idx = cls_idxs.get(parent_cls, 0) - if parent_idx < len(iter_row_items): - parent_obj = iter_row_items[parent_idx] - if parent_obj is not None: - parent_obj_id = getattr(parent_obj, 'id', None) - if parent_obj_id is not None: - hierarchy[related_class_name][parent_obj_id].append(related_obj) - - def build_recursive(current_cls_name: str, current_parent_id: int) -> list: - """递归构建嵌套数据""" - recursive_objs = get_unique_objects(hierarchy[current_cls_name].get(current_parent_id, [])) - if not recursive_objs: + def recursive_build(cls_name: str, pk: int) -> list: + nested_dict = current_hierarchy.get(cls_name) + if nested_dict is None: + return [] + objs = dedupe_objects(nested_dict.get(pk, [])) + if not objs: return [] - recursive_result = [] - for nested_obj in recursive_objs: - # 基础数据 - obj_data = {col: getattr(nested_obj, col, None) for col in model_info[current_cls_name]} + output = [] + for item in objs: + item_data = {col: getattr(item, col, None) for col in model_info[cls_name]} - # 处理子关系 - for child_cls, child_rel_type in relation_graph.get(current_cls_name, {}).items(): - child_parent_id = getattr(nested_obj, 'id', None) - if child_parent_id is None: + for sub_type, sub_rel_type in relation_graph.get(cls_name, {}).items(): + sub_pk = getattr(item, 'id', None) + if sub_pk is None: continue - child_list = build_recursive(child_cls, child_parent_id) - child_key = get_relation_key( - child_cls, child_rel_type, custom_names.get((current_cls_name, child_cls)) - ) + sub_list = recursive_build(sub_type, sub_pk) + sub_key = get_relationship_key(sub_type, sub_rel_type, custom_names.get((cls_name, sub_type))) - if child_rel_type in ('m2o', 'o2o'): - obj_data[child_key] = child_list[0] if child_list else None + if sub_rel_type not in list_relationship_types: + item_data[sub_key] = sub_list[0] if sub_list else None else: - obj_data[child_key] = child_list + item_data[sub_key] = sub_list - if not return_as_dict and current_cls_name in namedtuple_cache: - nt_fields = getattr(namedtuple_cache[current_cls_name], '_fields', []) - for field in nt_fields: - if field not in obj_data: - obj_data[field] = None + output.append(item_data if return_as_dict else build_namedtuple(cls_name, item_data)) - recursive_result.append(obj_data if return_as_dict else namedtuple_cache[current_cls_name](**obj_data)) + return output - return recursive_result + for top_type, top_rtype in relation_graph.get(primary_obj_name, {}).items(): + instances = recursive_build(top_type, target_id) + top_key = get_relationship_key(top_type, top_rtype, custom_names.get((primary_obj_name, top_type))) - # 构建顶级关系 - for top_cls_name, top_rel_type in relation_graph.get(main_obj_name, {}).items(): - instances = build_recursive(top_cls_name, nested_main_id) - key = get_relation_key(top_cls_name, top_rel_type, custom_names.get((main_obj_name, top_cls_name))) - - if top_rel_type in ('m2o', 'o2o'): - nested_result[key] = instances[0] if instances else None + if top_rtype not in list_relationship_types: + result[top_key] = instances[0] if instances else None else: - nested_result[key] = instances + result[top_key] = instances - return nested_result + return result - # 构建最终结果 - final_result_list = [] - for current_main_id in sorted(main_data.keys()): - current_main_obj = main_data[current_main_id] + # 最终结果构建 + final_results = [] + processed_ids = set() - if has_relationships: - final_result_data = build_nested_result(current_main_id, current_main_obj) - else: - final_result_data = build_flat_result(current_main_id, current_main_obj) + for row_item in rows_list: + row_elements = extract_row_elements(row_item) + if not row_elements or row_elements[0] is None: + continue + + main_obj = row_elements[0] + main_id = get_obj_id(main_obj) + + if main_id not in main_objects or main_id in processed_ids: + continue + + processed_ids.add(main_id) + + result_data = build_nested(main_id, main_obj) if has_relationships else build_flat(main_id, main_obj) if not return_as_dict: - all_fields = list(final_result_data.keys()) - result_type = namedtuple('Result', all_fields) # noqa: PYI024 - final_result_list.append(result_type(**final_result_data)) + result_type = namedtuple('Result', result_data.keys()) # noqa: PYI024 + final_results.append(result_type(**result_data)) else: - final_result_list.append(final_result_data) + final_results.append(result_data) - return final_result_list[0] if len(final_result_list) == 1 else final_result_list + return final_results[0] if len(final_results) == 1 else final_results