mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-21 12:37:01 +08:00
99 lines
3.4 KiB
Python
99 lines
3.4 KiB
Python
def compute_node_depths(tree):
    """Compute the depth of every node, indexed by node index.

    The tree is walked once from the root (index 0) with an explicit stack,
    so the cost is linear in the number of nodes and does not depend on the
    interpreter's recursion limit.

    Args:
        tree: Object exposing ``node_count`` plus the index-addressable
            ``children_left`` and ``children_right`` sequences (presumably an
            sklearn ``Tree``; verify against callers).

    Returns:
        A list of length ``tree.node_count`` where entry ``i`` is the number
        of edges between the root and node ``i``. Nodes that cannot be
        reached from the root keep the value ``-1``.
    """
    node_depths = [-1] * tree.node_count
    if not node_depths:
        return node_depths

    children_left = tree.children_left
    children_right = tree.children_right

    node_depths[0] = 0
    pending = [0]
    while pending:
        current_index = pending.pop()
        left_child = children_left[current_index]
        right_child = children_right[current_index]
        # A node whose two child entries are equal is treated as a leaf.
        if left_child == right_child:
            continue
        child_depth = node_depths[current_index] + 1
        for child_index in (left_child, right_child):
            node_depths[child_index] = child_depth
            pending.append(child_index)

    return node_depths
|
|
|
|
|
|
def compute_level_order_traversal(tree):
    """Return all node indices ordered by increasing depth.

    Depths are computed in a single iterative pass from the root (index 0),
    then the node indices are sorted by depth. The sort is stable, so nodes
    sharing a depth stay in ascending index order.

    Args:
        tree: Object exposing ``node_count`` plus the index-addressable
            ``children_left`` and ``children_right`` sequences (described as
            an sklearn tree by the original docstring).

    Returns:
        A list containing every index in ``range(tree.node_count)`` sorted by
        depth. Nodes that cannot be reached from the root have depth ``-1``
        and therefore come first.
    """
    node_count = tree.node_count
    node_depths = [-1] * node_count

    if node_count:
        children_left = tree.children_left
        children_right = tree.children_right

        node_depths[0] = 0
        pending = [0]
        while pending:
            current_index = pending.pop()
            left_child = children_left[current_index]
            right_child = children_right[current_index]
            # A node whose two child entries are equal is treated as a leaf.
            if left_child == right_child:
                continue
            child_depth = node_depths[current_index] + 1
            for child_index in (left_child, right_child):
                node_depths[child_index] = child_depth
                pending.append(child_index)

    # sorted() is stable: ties on depth keep ascending index order.
    return sorted(range(node_count), key=node_depths.__getitem__)
|
|
|
|
|
|
def compute_bfs_traversal(tree):
    """Traverse the tree breadth-first and return the node indices in order.

    The traversal starts at the root (index 0). For every visited node the
    left child is enqueued before the right child.

    Args:
        tree: Object exposing the index-addressable ``children_left`` and
            ``children_right`` sequences.

    Returns:
        A list of node indices in the order they were visited.
    """
    from collections import deque

    traversal_order = []
    tree_root_index = 0
    # deque gives O(1) pops from the front; list.pop(0) is O(n).
    queue = deque([tree_root_index])
    while queue:
        current_index = queue.popleft()
        traversal_order.append(current_index)
        left_child_index = tree.children_left[current_index]
        right_child_index = tree.children_right[current_index]
        # A node whose two child entries are equal is treated as a leaf.
        is_leaf_node = left_child_index == right_child_index
        if not is_leaf_node:
            queue.append(left_child_index)
            queue.append(right_child_index)

    return traversal_order
|
|
|
|
|
|
def compute_best_first_traversal(tree):
    """Placeholder for a best-split-first traversal of the tree.

    Not implemented yet: the argument is ignored and ``None`` is returned.
    """
    return None
|
|
|
|
|
|
def compute_node_to_parent_mapping(tree):
    """Build a dict that maps each node index to the index of its parent.

    Every node index in ``range(tree.node_count)`` is inspected; each child
    entry different from ``-1`` is recorded as a child of that node. The root
    (index 0) is mapped to ``-1`` because it has no parent.
    """
    parent_of = {0: -1}
    for parent_index in range(tree.node_count):
        child_pair = (
            tree.children_left[parent_index],
            tree.children_right[parent_index],
        )
        # Left is handled before right, so dict insertion order is preserved.
        for child_index in child_pair:
            if child_index != -1:
                parent_of[child_index] = parent_index
    return parent_of
|