mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-21 22:13:00 +08:00
Working initial visualization of a static decision tree.
This commit is contained in:
@ -5,27 +5,263 @@
|
|||||||
from manim import *
|
from manim import *
|
||||||
from manim_ml.one_to_one_sync import OneToOneSync
|
from manim_ml.one_to_one_sync import OneToOneSync
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
class LeafNode(VGroup):
|
def compute_node_depths(tree):
|
||||||
pass
|
"""Computes the depths of nodes for level order traversal"""
|
||||||
|
def depth(node_index, current_node_index=0):
|
||||||
|
"""Compute the height of a node"""
|
||||||
|
if current_node_index == node_index:
|
||||||
|
return 0
|
||||||
|
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]:
|
||||||
|
return -1
|
||||||
|
else:
|
||||||
|
# Compute the height of each subtree
|
||||||
|
l_depth = depth(node_index, tree.children_left[current_node_index])
|
||||||
|
r_depth = depth(node_index, tree.children_right[current_node_index])
|
||||||
|
# The index is only in one of them
|
||||||
|
if l_depth != -1:
|
||||||
|
return l_depth + 1
|
||||||
|
elif r_depth != -1:
|
||||||
|
return r_depth + 1
|
||||||
|
else:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
node_depths = [depth(index) for index in range(tree.node_count)]
|
||||||
|
|
||||||
class NonLeafNode(VGroup):
|
return node_depths
|
||||||
pass
|
|
||||||
|
|
||||||
|
def compute_level_order_traversal(tree):
|
||||||
|
"""Computes level order traversal of a sklearn tree"""
|
||||||
|
def depth(node_index, current_node_index=0):
|
||||||
|
"""Compute the height of a node"""
|
||||||
|
if current_node_index == node_index:
|
||||||
|
return 0
|
||||||
|
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]:
|
||||||
|
return -1
|
||||||
|
else:
|
||||||
|
# Compute the height of each subtree
|
||||||
|
l_depth = depth(node_index, tree.children_left[current_node_index])
|
||||||
|
r_depth = depth(node_index, tree.children_right[current_node_index])
|
||||||
|
# The index is only in one of them
|
||||||
|
if l_depth != -1:
|
||||||
|
return l_depth + 1
|
||||||
|
elif r_depth != -1:
|
||||||
|
return r_depth + 1
|
||||||
|
else:
|
||||||
|
return -1
|
||||||
|
|
||||||
class DecisionTreeDiagram(Graph):
|
node_depths = [(index, depth(index)) for index in range(tree.node_count)]
|
||||||
"""Decision Tree Digram Class for Manim"""
|
node_depths = sorted(node_depths, key=lambda x: x[1])
|
||||||
|
sorted_inds = [node_depth[0] for node_depth in node_depths]
|
||||||
|
|
||||||
pass
|
return sorted_inds
|
||||||
|
|
||||||
|
def compute_node_to_parent_mapping(tree):
|
||||||
|
"""Returns a hashmap mapping node indices to their parent indices"""
|
||||||
|
node_to_parent = {0: -1} # Root has no parent
|
||||||
|
num_nodes = tree.node_count
|
||||||
|
for node_index in range(num_nodes):
|
||||||
|
# Explore left children
|
||||||
|
left_child_node_index = tree.children_left[node_index]
|
||||||
|
if left_child_node_index != -1:
|
||||||
|
node_to_parent[left_child_node_index] = node_index
|
||||||
|
# Explore right children
|
||||||
|
right_child_node_index = tree.children_right[node_index]
|
||||||
|
if right_child_node_index != -1:
|
||||||
|
node_to_parent[right_child_node_index] = node_index
|
||||||
|
|
||||||
|
return node_to_parent
|
||||||
|
|
||||||
|
class LeafNode(Group):
|
||||||
|
"""Leaf node in tree"""
|
||||||
|
|
||||||
|
def __init__(self, class_index, display_type="image", class_image_paths=[],
|
||||||
|
class_colors=[]):
|
||||||
|
super().__init__()
|
||||||
|
self.display_type = display_type
|
||||||
|
self.class_image_paths = class_image_paths
|
||||||
|
self.class_colors = class_colors
|
||||||
|
assert self.display_type in ["image", "text"]
|
||||||
|
if self.display_type == "image":
|
||||||
|
self._construct_image_node(class_index)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _construct_image_node(self, class_index):
|
||||||
|
"""Make an image node"""
|
||||||
|
# Get image
|
||||||
|
image_path = self.class_image_paths[class_index]
|
||||||
|
pil_image = Image.open(image_path)
|
||||||
|
node = ImageMobject(pil_image)
|
||||||
|
node.scale(1.5)
|
||||||
|
rectangle = Rectangle(
|
||||||
|
width=node.width + 0.05,
|
||||||
|
height=node.height + 0.05,
|
||||||
|
color=self.class_colors[class_index],
|
||||||
|
stroke_width=6
|
||||||
|
)
|
||||||
|
rectangle.move_to(node.get_center())
|
||||||
|
rectangle.shift([-0.02, 0.02, 0])
|
||||||
|
self.add(rectangle)
|
||||||
|
self.add(node)
|
||||||
|
|
||||||
|
class SplitNode(VGroup):
|
||||||
|
"""Node for splitting decision in tree"""
|
||||||
|
|
||||||
|
def __init__(self, feature, threshold):
|
||||||
|
super().__init__()
|
||||||
|
node_text = f"{feature}\n<= {threshold:.2f} cm"
|
||||||
|
# Draw decision text
|
||||||
|
decision_text = Text(
|
||||||
|
node_text,
|
||||||
|
color=WHITE
|
||||||
|
)
|
||||||
|
# Draw the surrounding box
|
||||||
|
bounding_box = SurroundingRectangle(
|
||||||
|
decision_text,
|
||||||
|
buff=0.3,
|
||||||
|
color=WHITE
|
||||||
|
)
|
||||||
|
self.add(bounding_box)
|
||||||
|
self.add(decision_text)
|
||||||
|
|
||||||
|
class DecisionTreeDiagram(Group):
|
||||||
|
"""Decision Tree Diagram Class for Manim"""
|
||||||
|
|
||||||
|
def __init__(self, sklearn_tree, feature_names=None,
|
||||||
|
class_names=None, class_images_paths=None,
|
||||||
|
class_colors=[RED, GREEN, BLUE]):
|
||||||
|
super().__init__()
|
||||||
|
self.tree = sklearn_tree
|
||||||
|
self.feature_names = feature_names
|
||||||
|
self.class_names = class_names
|
||||||
|
self.class_image_paths = class_images_paths
|
||||||
|
self.class_colors = class_colors
|
||||||
|
# Make graph container for the tree
|
||||||
|
tree, _, _ = self._make_tree()
|
||||||
|
self.add(tree)
|
||||||
|
|
||||||
|
def _make_node(
|
||||||
|
self,
|
||||||
|
node_index,
|
||||||
|
):
|
||||||
|
"""Make node"""
|
||||||
|
is_split_node = self.tree.children_left[node_index] != self.tree.children_right[node_index]
|
||||||
|
if is_split_node:
|
||||||
|
node_feature = self.tree.feature[node_index]
|
||||||
|
node_threshold = self.tree.threshold[node_index]
|
||||||
|
node = SplitNode(
|
||||||
|
self.feature_names[node_feature],
|
||||||
|
node_threshold
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Get the most abundant class for the given leaf node
|
||||||
|
# Make the leaf node object
|
||||||
|
tree_class_index = np.argmax(self.tree.value[node_index])
|
||||||
|
node = LeafNode(
|
||||||
|
class_index=tree_class_index,
|
||||||
|
class_colors=self.class_colors,
|
||||||
|
class_image_paths=self.class_image_paths
|
||||||
|
)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def _make_connection(self, top, bottom, is_leaf=False):
|
||||||
|
"""Make a connection from top to bottom"""
|
||||||
|
top_node_bottom_location = top.get_center()
|
||||||
|
top_node_bottom_location[1] -= top.height / 2
|
||||||
|
bottom_node_top_location = bottom.get_center()
|
||||||
|
bottom_node_top_location[1] += bottom.height / 2
|
||||||
|
|
||||||
|
line = Line(
|
||||||
|
top_node_bottom_location,
|
||||||
|
bottom_node_top_location,
|
||||||
|
color=WHITE
|
||||||
|
)
|
||||||
|
|
||||||
|
return line
|
||||||
|
|
||||||
|
def _make_tree(self):
|
||||||
|
"""Construct the tree diagram"""
|
||||||
|
tree_group = Group()
|
||||||
|
max_depth = self.tree.max_depth
|
||||||
|
# Make the root node
|
||||||
|
nodes_map = {}
|
||||||
|
root_node = self._make_node(
|
||||||
|
node_index=0,
|
||||||
|
)
|
||||||
|
nodes_map[0] = root_node
|
||||||
|
tree_group.add(root_node)
|
||||||
|
# Save some information
|
||||||
|
node_height = root_node.height
|
||||||
|
node_width = root_node.width
|
||||||
|
scale_factor = 1.0
|
||||||
|
edge_map = {}
|
||||||
|
# tree height
|
||||||
|
tree_height = scale_factor * node_height * max_depth
|
||||||
|
tree_width = scale_factor * 2 ** max_depth * node_width
|
||||||
|
# traverse tree
|
||||||
|
def recurse(node_index, depth, direction, parent_object, parent_node):
|
||||||
|
# make the node object
|
||||||
|
is_leaf = self.tree.children_left[node_index] == self.tree.children_right[node_index]
|
||||||
|
node_object = self._make_node(node_index=node_index)
|
||||||
|
nodes_map[node_index] = node_object
|
||||||
|
node_height = node_object.height
|
||||||
|
# set the node position
|
||||||
|
direction_factor = -1 if direction == "left" else 1
|
||||||
|
shift_right_amount = 0.9 * direction_factor * scale_factor * tree_width / (2 ** depth) / 2
|
||||||
|
if is_leaf:
|
||||||
|
shift_down_amount = -1.0 * scale_factor * node_height
|
||||||
|
else:
|
||||||
|
shift_down_amount = -1.8 * scale_factor * node_height
|
||||||
|
node_object \
|
||||||
|
.match_x(parent_object) \
|
||||||
|
.match_y(parent_object) \
|
||||||
|
.shift([shift_right_amount, shift_down_amount, 0])
|
||||||
|
tree_group.add(node_object)
|
||||||
|
# make a connection
|
||||||
|
connection = self._make_connection(parent_object, node_object, is_leaf=is_leaf)
|
||||||
|
edge_name = str(parent_node)+","+str(node_index)
|
||||||
|
edge_map[edge_name] = connection
|
||||||
|
tree_group.add(connection)
|
||||||
|
# recurse
|
||||||
|
if not is_leaf:
|
||||||
|
recurse(self.tree.children_left[node_index], depth + 1, "left", node_object, node_index)
|
||||||
|
recurse(self.tree.children_right[node_index], depth + 1, "right", node_object, node_index)
|
||||||
|
|
||||||
|
recurse(self.tree.children_left[0], 1, "left", root_node, 0)
|
||||||
|
recurse(self.tree.children_right[0], 1, "right", root_node, 0)
|
||||||
|
|
||||||
|
tree_group.scale(0.35)
|
||||||
|
return tree_group, nodes_map, edge_map
|
||||||
|
|
||||||
|
def create_level_order_expansion_decision_tree(self, tree):
|
||||||
|
"""Expands the decision tree in level order"""
|
||||||
|
animations = []
|
||||||
|
# Get sklearn nodes in the correct order
|
||||||
|
node_level_order = compute_level_order_traversal(tree)
|
||||||
|
# Add the node to the graph object
|
||||||
|
# Position the node
|
||||||
|
# Expand the graph in the given order.
|
||||||
|
|
||||||
|
return AnimationGroup(
|
||||||
|
*animations,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_animation(Create)
|
||||||
|
def create_decision_tree(self, traversal_order="level"):
|
||||||
|
"""Makes a create animation for the decision tree"""
|
||||||
|
if traversal_order == "level":
|
||||||
|
return self.create_level_order_expansion_decision_tree(self.sklearn_tree)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Uncrecognized traversal: {traversal_order}")
|
||||||
|
|
||||||
class DecisionTreeEmbedding:
|
class DecisionTreeEmbedding:
|
||||||
"""Embedding for the decision tree"""
|
"""Embedding for the decision tree"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DecisionTreeContainer(OneToOneSync):
|
class DecisionTreeContainer(OneToOneSync):
|
||||||
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
||||||
|
|
||||||
|
@ -170,6 +170,7 @@ class NeuralNetwork(Group):
|
|||||||
after_layer_args = {}
|
after_layer_args = {}
|
||||||
if layer.input_layer in layer_args:
|
if layer.input_layer in layer_args:
|
||||||
before_layer_args = layer_args[layer.input_layer]
|
before_layer_args = layer_args[layer.input_layer]
|
||||||
|
if layer in layer_args:
|
||||||
current_layer_args = layer_args[layer]
|
current_layer_args = layer_args[layer]
|
||||||
if layer.output_layer in layer_args:
|
if layer.output_layer in layer_args:
|
||||||
after_layer_args = layer_args[layer.output_layer]
|
after_layer_args = layer_args[layer.output_layer]
|
||||||
|
BIN
tests/images/iris_dataset/SetosaFlower.jpeg
Normal file
BIN
tests/images/iris_dataset/SetosaFlower.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 23 KiB |
BIN
tests/images/iris_dataset/VeriscolorFlower.jpeg
Normal file
BIN
tests/images/iris_dataset/VeriscolorFlower.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
BIN
tests/images/iris_dataset/VirginicaFlower.jpeg
Normal file
BIN
tests/images/iris_dataset/VirginicaFlower.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
49
tests/test_decision_tree.py
Normal file
49
tests/test_decision_tree.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from manim import *
|
||||||
|
from manim_ml.decision_tree.decision_tree import DecisionTreeDiagram
|
||||||
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
|
from sklearn import datasets
|
||||||
|
import sklearn
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
def learn_iris_decision_tree(iris):
|
||||||
|
decision_tree = DecisionTreeClassifier(
|
||||||
|
random_state=1,
|
||||||
|
max_depth=3,
|
||||||
|
max_leaf_nodes=6
|
||||||
|
)
|
||||||
|
decision_tree = decision_tree.fit(iris.data[:, :2], iris.target)
|
||||||
|
# output the decisioin tree in some format
|
||||||
|
return decision_tree
|
||||||
|
|
||||||
|
def make_sklearn_tree(dataset, max_tree_depth=3):
|
||||||
|
tree = learn_iris_decision_tree(dataset)
|
||||||
|
feature_names = dataset.feature_names[0:2]
|
||||||
|
return tree, tree.tree_
|
||||||
|
|
||||||
|
class DecisionTreeScene(Scene):
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
"""Makes a decision tree object"""
|
||||||
|
iris_dataset = datasets.load_iris()
|
||||||
|
clf, sklearn_tree = make_sklearn_tree(iris_dataset)
|
||||||
|
# sklearn.tree.plot_tree(clf, node_ids=True)
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
|
decision_tree = DecisionTreeDiagram(
|
||||||
|
sklearn_tree,
|
||||||
|
class_images_paths=[
|
||||||
|
"images/iris_dataset/SetosaFlower.jpeg",
|
||||||
|
"images/iris_dataset/VeriscolorFlower.jpeg",
|
||||||
|
"images/iris_dataset/VirginicaFlower.jpeg",
|
||||||
|
],
|
||||||
|
class_names=[
|
||||||
|
"Setosa",
|
||||||
|
"Veriscolor",
|
||||||
|
"Virginica"
|
||||||
|
],
|
||||||
|
feature_names=["Sepal Length", "Sepal Width"]
|
||||||
|
)
|
||||||
|
# create_decision_tree = Create(decision_tree)
|
||||||
|
self.add(decision_tree)
|
||||||
|
decision_tree.move_to(ORIGIN)
|
||||||
|
# self.play(create_decision_tree)
|
Reference in New Issue
Block a user