Working BFS order tree expansion animation.

This commit is contained in:
Alec Helbling
2023-01-04 22:10:51 -05:00
parent e33f98373a
commit 3d8df61f76
2 changed files with 52 additions and 11 deletions

View File

@ -140,8 +140,8 @@ class DecisionTreeDiagram(Group):
self.class_image_paths = class_images_paths
self.class_colors = class_colors
# Make graph container for the tree
tree, _, _ = self._make_tree()
self.add(tree)
self.tree_group, self.nodes_map, self.edge_map = self._make_tree()
self.add(self.tree_group)
def _make_node(
self,
@ -238,22 +238,63 @@ class DecisionTreeDiagram(Group):
def create_level_order_expansion_decision_tree(self, tree):
"""Expands the decision tree in level order"""
raise NotImplementedError()
def create_bfs_expansion_decision_tree(self, tree):
"""Expands the tree using BFS"""
animations = []
# Get sklearn nodes in the correct order
node_level_order = compute_level_order_traversal(tree)
# Add the node to the graph object
# Position the node
# Expand the graph in the given order.
# Compute parent mapping
parent_mapping = compute_node_to_parent_mapping(self.tree)
# Create the root node
animations.append(
Create(self.nodes_map[0])
)
# Iterate through the nodes
queue = [0]
while len(queue) > 0:
node_index = queue.pop(0)
# Check if a node is a split node or not
left_child = self.tree.children_left[node_index]
right_child = self.tree.children_right[node_index]
is_leaf_node = left_child == right_child
if not is_leaf_node:
# Split the node by creating the children and connecting them
# to the parent
# Get the nodes
left_node = self.nodes_map[left_child]
right_node = self.nodes_map[right_child]
# Get the parent edges
left_parent_edge = self.edge_map[f"{node_index},{left_child}"]
right_parent_edge = self.edge_map[f"{node_index},{right_child}"]
# Create the children
split_animation = AnimationGroup(
FadeIn(left_node),
FadeIn(right_node),
Create(left_parent_edge),
Create(right_parent_edge),
lag_ratio=0.0
)
animations.append(
split_animation
)
# Add the children to the queue
if left_child != -1:
queue.append(left_child)
if right_child != -1:
queue.append(right_child)
return AnimationGroup(
*animations,
lag_ratio=1.0
)
@override_animation(Create)
def create_decision_tree(self, traversal_order="level"):
def create_decision_tree(self, traversal_order="bfs"):
"""Makes a create animation for the decision tree"""
if traversal_order == "level":
return self.create_level_order_expansion_decision_tree(self.sklearn_tree)
return self.create_level_order_expansion_decision_tree(self.tree)
elif traversal_order == "bfs":
return self.create_bfs_expansion_decision_tree(self.tree)
else:
raise Exception(f"Uncrecognized traversal: {traversal_order}")

View File

@ -43,7 +43,7 @@ class DecisionTreeScene(Scene):
],
feature_names=["Sepal Length", "Sepal Width"]
)
# create_decision_tree = Create(decision_tree)
self.add(decision_tree)
decision_tree.move_to(ORIGIN)
create_decision_tree = Create(decision_tree, traversal_order="bfs")
self.play(create_decision_tree)
# self.play(create_decision_tree)