Reformatted the code using black, allowd for different orientation NNs, made an option for highlighting the active filter in a CNN forward pass.

This commit is contained in:
Alec Helbling
2023-01-09 15:52:37 +09:00
parent 39b0b133ce
commit ba63116b37
19 changed files with 485 additions and 283 deletions

View File

@ -1,11 +1,14 @@
def compute_node_depths(tree):
"""Computes the depths of nodes for level order traversal"""
def depth(node_index, current_node_index=0):
"""Compute the height of a node"""
if current_node_index == node_index:
return 0
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]:
elif (
tree.children_left[current_node_index]
== tree.children_right[current_node_index]
):
return -1
else:
# Compute the height of each subtree
@ -23,13 +26,18 @@ def compute_node_depths(tree):
return node_depths
def compute_level_order_traversal(tree):
"""Computes level order traversal of a sklearn tree"""
def depth(node_index, current_node_index=0):
"""Compute the height of a node"""
if current_node_index == node_index:
return 0
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]:
elif (
tree.children_left[current_node_index]
== tree.children_right[current_node_index]
):
return -1
else:
# Compute the height of each subtree
@ -47,14 +55,15 @@ def compute_level_order_traversal(tree):
node_depths = sorted(node_depths, key=lambda x: x[1])
sorted_inds = [node_depth[0] for node_depth in node_depths]
return sorted_inds
return sorted_inds
def compute_node_to_parent_mapping(tree):
"""Returns a hashmap mapping node indices to their parent indices"""
node_to_parent = {0: -1} # Root has no parent
node_to_parent = {0: -1} # Root has no parent
num_nodes = tree.node_count
for node_index in range(num_nodes):
# Explore left children
# Explore left children
left_child_node_index = tree.children_left[node_index]
if left_child_node_index != -1:
node_to_parent[left_child_node_index] = node_index
@ -62,5 +71,5 @@ def compute_node_to_parent_mapping(tree):
right_child_node_index = tree.children_right[node_index]
if right_child_node_index != -1:
node_to_parent[right_child_node_index] = node_index
return node_to_parent