mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-17 18:55:54 +08:00
Working BFS order tree expansion animation.
This commit is contained in:
@ -140,8 +140,8 @@ class DecisionTreeDiagram(Group):
|
||||
self.class_image_paths = class_images_paths
|
||||
self.class_colors = class_colors
|
||||
# Make graph container for the tree
|
||||
tree, _, _ = self._make_tree()
|
||||
self.add(tree)
|
||||
self.tree_group, self.nodes_map, self.edge_map = self._make_tree()
|
||||
self.add(self.tree_group)
|
||||
|
||||
def _make_node(
|
||||
self,
|
||||
@ -238,22 +238,63 @@ class DecisionTreeDiagram(Group):
|
||||
|
||||
def create_level_order_expansion_decision_tree(self, tree):
|
||||
"""Expands the decision tree in level order"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_bfs_expansion_decision_tree(self, tree):
|
||||
"""Expands the tree using BFS"""
|
||||
animations = []
|
||||
# Get sklearn nodes in the correct order
|
||||
node_level_order = compute_level_order_traversal(tree)
|
||||
# Add the node to the graph object
|
||||
# Position the node
|
||||
# Expand the graph in the given order.
|
||||
# Compute parent mapping
|
||||
parent_mapping = compute_node_to_parent_mapping(self.tree)
|
||||
# Create the root node
|
||||
animations.append(
|
||||
Create(self.nodes_map[0])
|
||||
)
|
||||
# Iterate through the nodes
|
||||
queue = [0]
|
||||
while len(queue) > 0:
|
||||
node_index = queue.pop(0)
|
||||
# Check if a node is a split node or not
|
||||
left_child = self.tree.children_left[node_index]
|
||||
right_child = self.tree.children_right[node_index]
|
||||
is_leaf_node = left_child == right_child
|
||||
if not is_leaf_node:
|
||||
# Split the node by creating the children and connecting them
|
||||
# to the parent
|
||||
# Get the nodes
|
||||
left_node = self.nodes_map[left_child]
|
||||
right_node = self.nodes_map[right_child]
|
||||
# Get the parent edges
|
||||
left_parent_edge = self.edge_map[f"{node_index},{left_child}"]
|
||||
right_parent_edge = self.edge_map[f"{node_index},{right_child}"]
|
||||
# Create the children
|
||||
split_animation = AnimationGroup(
|
||||
FadeIn(left_node),
|
||||
FadeIn(right_node),
|
||||
Create(left_parent_edge),
|
||||
Create(right_parent_edge),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
animations.append(
|
||||
split_animation
|
||||
)
|
||||
# Add the children to the queue
|
||||
if left_child != -1:
|
||||
queue.append(left_child)
|
||||
if right_child != -1:
|
||||
queue.append(right_child)
|
||||
|
||||
return AnimationGroup(
|
||||
*animations,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
@override_animation(Create)
|
||||
def create_decision_tree(self, traversal_order="level"):
|
||||
def create_decision_tree(self, traversal_order="bfs"):
|
||||
"""Makes a create animation for the decision tree"""
|
||||
if traversal_order == "level":
|
||||
return self.create_level_order_expansion_decision_tree(self.sklearn_tree)
|
||||
return self.create_level_order_expansion_decision_tree(self.tree)
|
||||
elif traversal_order == "bfs":
|
||||
return self.create_bfs_expansion_decision_tree(self.tree)
|
||||
else:
|
||||
raise Exception(f"Uncrecognized traversal: {traversal_order}")
|
||||
|
||||
|
@ -43,7 +43,7 @@ class DecisionTreeScene(Scene):
|
||||
],
|
||||
feature_names=["Sepal Length", "Sepal Width"]
|
||||
)
|
||||
# create_decision_tree = Create(decision_tree)
|
||||
self.add(decision_tree)
|
||||
decision_tree.move_to(ORIGIN)
|
||||
create_decision_tree = Create(decision_tree, traversal_order="bfs")
|
||||
self.play(create_decision_tree)
|
||||
# self.play(create_decision_tree)
|
||||
|
Reference in New Issue
Block a user