mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-28 19:51:06 +08:00
Reformatted the code using black, allowd for different orientation NNs, made an option for highlighting the active filter in a CNN forward pass.
This commit is contained in:
@ -1,11 +1,14 @@
|
||||
|
||||
def compute_node_depths(tree):
|
||||
"""Computes the depths of nodes for level order traversal"""
|
||||
|
||||
def depth(node_index, current_node_index=0):
|
||||
"""Compute the height of a node"""
|
||||
if current_node_index == node_index:
|
||||
return 0
|
||||
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]:
|
||||
elif (
|
||||
tree.children_left[current_node_index]
|
||||
== tree.children_right[current_node_index]
|
||||
):
|
||||
return -1
|
||||
else:
|
||||
# Compute the height of each subtree
|
||||
@ -23,13 +26,18 @@ def compute_node_depths(tree):
|
||||
|
||||
return node_depths
|
||||
|
||||
|
||||
def compute_level_order_traversal(tree):
|
||||
"""Computes level order traversal of a sklearn tree"""
|
||||
|
||||
def depth(node_index, current_node_index=0):
|
||||
"""Compute the height of a node"""
|
||||
if current_node_index == node_index:
|
||||
return 0
|
||||
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]:
|
||||
elif (
|
||||
tree.children_left[current_node_index]
|
||||
== tree.children_right[current_node_index]
|
||||
):
|
||||
return -1
|
||||
else:
|
||||
# Compute the height of each subtree
|
||||
@ -47,14 +55,15 @@ def compute_level_order_traversal(tree):
|
||||
node_depths = sorted(node_depths, key=lambda x: x[1])
|
||||
sorted_inds = [node_depth[0] for node_depth in node_depths]
|
||||
|
||||
return sorted_inds
|
||||
return sorted_inds
|
||||
|
||||
|
||||
def compute_node_to_parent_mapping(tree):
|
||||
"""Returns a hashmap mapping node indices to their parent indices"""
|
||||
node_to_parent = {0: -1} # Root has no parent
|
||||
node_to_parent = {0: -1} # Root has no parent
|
||||
num_nodes = tree.node_count
|
||||
for node_index in range(num_nodes):
|
||||
# Explore left children
|
||||
# Explore left children
|
||||
left_child_node_index = tree.children_left[node_index]
|
||||
if left_child_node_index != -1:
|
||||
node_to_parent[left_child_node_index] = node_index
|
||||
@ -62,5 +71,5 @@ def compute_node_to_parent_mapping(tree):
|
||||
right_child_node_index = tree.children_right[node_index]
|
||||
if right_child_node_index != -1:
|
||||
node_to_parent[right_child_node_index] = node_index
|
||||
|
||||
|
||||
return node_to_parent
|
||||
|
Reference in New Issue
Block a user