mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-11 15:28:18 +08:00
Working initial visualization of a static decision tree.
This commit is contained in:
manim_ml
tests
@ -5,27 +5,263 @@
|
||||
from manim import *
|
||||
from manim_ml.one_to_one_sync import OneToOneSync
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
class LeafNode(VGroup):
|
||||
pass
|
||||
def compute_node_depths(tree):
|
||||
"""Computes the depths of nodes for level order traversal"""
|
||||
def depth(node_index, current_node_index=0):
|
||||
"""Compute the height of a node"""
|
||||
if current_node_index == node_index:
|
||||
return 0
|
||||
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]:
|
||||
return -1
|
||||
else:
|
||||
# Compute the height of each subtree
|
||||
l_depth = depth(node_index, tree.children_left[current_node_index])
|
||||
r_depth = depth(node_index, tree.children_right[current_node_index])
|
||||
# The index is only in one of them
|
||||
if l_depth != -1:
|
||||
return l_depth + 1
|
||||
elif r_depth != -1:
|
||||
return r_depth + 1
|
||||
else:
|
||||
return -1
|
||||
|
||||
node_depths = [depth(index) for index in range(tree.node_count)]
|
||||
|
||||
class NonLeafNode(VGroup):
|
||||
pass
|
||||
return node_depths
|
||||
|
||||
def compute_level_order_traversal(tree):
|
||||
"""Computes level order traversal of a sklearn tree"""
|
||||
def depth(node_index, current_node_index=0):
|
||||
"""Compute the height of a node"""
|
||||
if current_node_index == node_index:
|
||||
return 0
|
||||
elif tree.children_left[current_node_index] == tree.children_right[current_node_index]:
|
||||
return -1
|
||||
else:
|
||||
# Compute the height of each subtree
|
||||
l_depth = depth(node_index, tree.children_left[current_node_index])
|
||||
r_depth = depth(node_index, tree.children_right[current_node_index])
|
||||
# The index is only in one of them
|
||||
if l_depth != -1:
|
||||
return l_depth + 1
|
||||
elif r_depth != -1:
|
||||
return r_depth + 1
|
||||
else:
|
||||
return -1
|
||||
|
||||
class DecisionTreeDiagram(Graph):
|
||||
"""Decision Tree Digram Class for Manim"""
|
||||
node_depths = [(index, depth(index)) for index in range(tree.node_count)]
|
||||
node_depths = sorted(node_depths, key=lambda x: x[1])
|
||||
sorted_inds = [node_depth[0] for node_depth in node_depths]
|
||||
|
||||
pass
|
||||
return sorted_inds
|
||||
|
||||
def compute_node_to_parent_mapping(tree):
|
||||
"""Returns a hashmap mapping node indices to their parent indices"""
|
||||
node_to_parent = {0: -1} # Root has no parent
|
||||
num_nodes = tree.node_count
|
||||
for node_index in range(num_nodes):
|
||||
# Explore left children
|
||||
left_child_node_index = tree.children_left[node_index]
|
||||
if left_child_node_index != -1:
|
||||
node_to_parent[left_child_node_index] = node_index
|
||||
# Explore right children
|
||||
right_child_node_index = tree.children_right[node_index]
|
||||
if right_child_node_index != -1:
|
||||
node_to_parent[right_child_node_index] = node_index
|
||||
|
||||
return node_to_parent
|
||||
|
||||
class LeafNode(Group):
|
||||
"""Leaf node in tree"""
|
||||
|
||||
def __init__(self, class_index, display_type="image", class_image_paths=[],
|
||||
class_colors=[]):
|
||||
super().__init__()
|
||||
self.display_type = display_type
|
||||
self.class_image_paths = class_image_paths
|
||||
self.class_colors = class_colors
|
||||
assert self.display_type in ["image", "text"]
|
||||
if self.display_type == "image":
|
||||
self._construct_image_node(class_index)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _construct_image_node(self, class_index):
|
||||
"""Make an image node"""
|
||||
# Get image
|
||||
image_path = self.class_image_paths[class_index]
|
||||
pil_image = Image.open(image_path)
|
||||
node = ImageMobject(pil_image)
|
||||
node.scale(1.5)
|
||||
rectangle = Rectangle(
|
||||
width=node.width + 0.05,
|
||||
height=node.height + 0.05,
|
||||
color=self.class_colors[class_index],
|
||||
stroke_width=6
|
||||
)
|
||||
rectangle.move_to(node.get_center())
|
||||
rectangle.shift([-0.02, 0.02, 0])
|
||||
self.add(rectangle)
|
||||
self.add(node)
|
||||
|
||||
class SplitNode(VGroup):
|
||||
"""Node for splitting decision in tree"""
|
||||
|
||||
def __init__(self, feature, threshold):
|
||||
super().__init__()
|
||||
node_text = f"{feature}\n<= {threshold:.2f} cm"
|
||||
# Draw decision text
|
||||
decision_text = Text(
|
||||
node_text,
|
||||
color=WHITE
|
||||
)
|
||||
# Draw the surrounding box
|
||||
bounding_box = SurroundingRectangle(
|
||||
decision_text,
|
||||
buff=0.3,
|
||||
color=WHITE
|
||||
)
|
||||
self.add(bounding_box)
|
||||
self.add(decision_text)
|
||||
|
||||
class DecisionTreeDiagram(Group):
|
||||
"""Decision Tree Diagram Class for Manim"""
|
||||
|
||||
def __init__(self, sklearn_tree, feature_names=None,
|
||||
class_names=None, class_images_paths=None,
|
||||
class_colors=[RED, GREEN, BLUE]):
|
||||
super().__init__()
|
||||
self.tree = sklearn_tree
|
||||
self.feature_names = feature_names
|
||||
self.class_names = class_names
|
||||
self.class_image_paths = class_images_paths
|
||||
self.class_colors = class_colors
|
||||
# Make graph container for the tree
|
||||
tree, _, _ = self._make_tree()
|
||||
self.add(tree)
|
||||
|
||||
def _make_node(
|
||||
self,
|
||||
node_index,
|
||||
):
|
||||
"""Make node"""
|
||||
is_split_node = self.tree.children_left[node_index] != self.tree.children_right[node_index]
|
||||
if is_split_node:
|
||||
node_feature = self.tree.feature[node_index]
|
||||
node_threshold = self.tree.threshold[node_index]
|
||||
node = SplitNode(
|
||||
self.feature_names[node_feature],
|
||||
node_threshold
|
||||
)
|
||||
else:
|
||||
# Get the most abundant class for the given leaf node
|
||||
# Make the leaf node object
|
||||
tree_class_index = np.argmax(self.tree.value[node_index])
|
||||
node = LeafNode(
|
||||
class_index=tree_class_index,
|
||||
class_colors=self.class_colors,
|
||||
class_image_paths=self.class_image_paths
|
||||
)
|
||||
return node
|
||||
|
||||
def _make_connection(self, top, bottom, is_leaf=False):
|
||||
"""Make a connection from top to bottom"""
|
||||
top_node_bottom_location = top.get_center()
|
||||
top_node_bottom_location[1] -= top.height / 2
|
||||
bottom_node_top_location = bottom.get_center()
|
||||
bottom_node_top_location[1] += bottom.height / 2
|
||||
|
||||
line = Line(
|
||||
top_node_bottom_location,
|
||||
bottom_node_top_location,
|
||||
color=WHITE
|
||||
)
|
||||
|
||||
return line
|
||||
|
||||
def _make_tree(self):
|
||||
"""Construct the tree diagram"""
|
||||
tree_group = Group()
|
||||
max_depth = self.tree.max_depth
|
||||
# Make the root node
|
||||
nodes_map = {}
|
||||
root_node = self._make_node(
|
||||
node_index=0,
|
||||
)
|
||||
nodes_map[0] = root_node
|
||||
tree_group.add(root_node)
|
||||
# Save some information
|
||||
node_height = root_node.height
|
||||
node_width = root_node.width
|
||||
scale_factor = 1.0
|
||||
edge_map = {}
|
||||
# tree height
|
||||
tree_height = scale_factor * node_height * max_depth
|
||||
tree_width = scale_factor * 2 ** max_depth * node_width
|
||||
# traverse tree
|
||||
def recurse(node_index, depth, direction, parent_object, parent_node):
|
||||
# make the node object
|
||||
is_leaf = self.tree.children_left[node_index] == self.tree.children_right[node_index]
|
||||
node_object = self._make_node(node_index=node_index)
|
||||
nodes_map[node_index] = node_object
|
||||
node_height = node_object.height
|
||||
# set the node position
|
||||
direction_factor = -1 if direction == "left" else 1
|
||||
shift_right_amount = 0.9 * direction_factor * scale_factor * tree_width / (2 ** depth) / 2
|
||||
if is_leaf:
|
||||
shift_down_amount = -1.0 * scale_factor * node_height
|
||||
else:
|
||||
shift_down_amount = -1.8 * scale_factor * node_height
|
||||
node_object \
|
||||
.match_x(parent_object) \
|
||||
.match_y(parent_object) \
|
||||
.shift([shift_right_amount, shift_down_amount, 0])
|
||||
tree_group.add(node_object)
|
||||
# make a connection
|
||||
connection = self._make_connection(parent_object, node_object, is_leaf=is_leaf)
|
||||
edge_name = str(parent_node)+","+str(node_index)
|
||||
edge_map[edge_name] = connection
|
||||
tree_group.add(connection)
|
||||
# recurse
|
||||
if not is_leaf:
|
||||
recurse(self.tree.children_left[node_index], depth + 1, "left", node_object, node_index)
|
||||
recurse(self.tree.children_right[node_index], depth + 1, "right", node_object, node_index)
|
||||
|
||||
recurse(self.tree.children_left[0], 1, "left", root_node, 0)
|
||||
recurse(self.tree.children_right[0], 1, "right", root_node, 0)
|
||||
|
||||
tree_group.scale(0.35)
|
||||
return tree_group, nodes_map, edge_map
|
||||
|
||||
def create_level_order_expansion_decision_tree(self, tree):
|
||||
"""Expands the decision tree in level order"""
|
||||
animations = []
|
||||
# Get sklearn nodes in the correct order
|
||||
node_level_order = compute_level_order_traversal(tree)
|
||||
# Add the node to the graph object
|
||||
# Position the node
|
||||
# Expand the graph in the given order.
|
||||
|
||||
return AnimationGroup(
|
||||
*animations,
|
||||
)
|
||||
|
||||
@override_animation(Create)
|
||||
def create_decision_tree(self, traversal_order="level"):
|
||||
"""Makes a create animation for the decision tree"""
|
||||
if traversal_order == "level":
|
||||
return self.create_level_order_expansion_decision_tree(self.sklearn_tree)
|
||||
else:
|
||||
raise Exception(f"Uncrecognized traversal: {traversal_order}")
|
||||
|
||||
class DecisionTreeEmbedding:
|
||||
"""Embedding for the decision tree"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DecisionTreeContainer(OneToOneSync):
|
||||
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
||||
|
||||
|
@ -170,7 +170,8 @@ class NeuralNetwork(Group):
|
||||
after_layer_args = {}
|
||||
if layer.input_layer in layer_args:
|
||||
before_layer_args = layer_args[layer.input_layer]
|
||||
current_layer_args = layer_args[layer]
|
||||
if layer in layer_args:
|
||||
current_layer_args = layer_args[layer]
|
||||
if layer.output_layer in layer_args:
|
||||
after_layer_args = layer_args[layer.output_layer]
|
||||
# Merge the two dicts
|
||||
|
BIN
tests/images/iris_dataset/SetosaFlower.jpeg
Normal file
BIN
tests/images/iris_dataset/SetosaFlower.jpeg
Normal file
Binary file not shown.
After ![]() (image error) Size: 23 KiB |
BIN
tests/images/iris_dataset/VeriscolorFlower.jpeg
Normal file
BIN
tests/images/iris_dataset/VeriscolorFlower.jpeg
Normal file
Binary file not shown.
After ![]() (image error) Size: 31 KiB |
BIN
tests/images/iris_dataset/VirginicaFlower.jpeg
Normal file
BIN
tests/images/iris_dataset/VirginicaFlower.jpeg
Normal file
Binary file not shown.
After ![]() (image error) Size: 30 KiB |
49
tests/test_decision_tree.py
Normal file
49
tests/test_decision_tree.py
Normal file
@ -0,0 +1,49 @@
|
||||
from manim import *
|
||||
from manim_ml.decision_tree.decision_tree import DecisionTreeDiagram
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn import datasets
|
||||
import sklearn
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def learn_iris_decision_tree(iris):
|
||||
decision_tree = DecisionTreeClassifier(
|
||||
random_state=1,
|
||||
max_depth=3,
|
||||
max_leaf_nodes=6
|
||||
)
|
||||
decision_tree = decision_tree.fit(iris.data[:, :2], iris.target)
|
||||
# output the decisioin tree in some format
|
||||
return decision_tree
|
||||
|
||||
def make_sklearn_tree(dataset, max_tree_depth=3):
|
||||
tree = learn_iris_decision_tree(dataset)
|
||||
feature_names = dataset.feature_names[0:2]
|
||||
return tree, tree.tree_
|
||||
|
||||
class DecisionTreeScene(Scene):
|
||||
|
||||
def construct(self):
|
||||
"""Makes a decision tree object"""
|
||||
iris_dataset = datasets.load_iris()
|
||||
clf, sklearn_tree = make_sklearn_tree(iris_dataset)
|
||||
# sklearn.tree.plot_tree(clf, node_ids=True)
|
||||
# plt.show()
|
||||
|
||||
decision_tree = DecisionTreeDiagram(
|
||||
sklearn_tree,
|
||||
class_images_paths=[
|
||||
"images/iris_dataset/SetosaFlower.jpeg",
|
||||
"images/iris_dataset/VeriscolorFlower.jpeg",
|
||||
"images/iris_dataset/VirginicaFlower.jpeg",
|
||||
],
|
||||
class_names=[
|
||||
"Setosa",
|
||||
"Veriscolor",
|
||||
"Virginica"
|
||||
],
|
||||
feature_names=["Sepal Length", "Sepal Width"]
|
||||
)
|
||||
# create_decision_tree = Create(decision_tree)
|
||||
self.add(decision_tree)
|
||||
decision_tree.move_to(ORIGIN)
|
||||
# self.play(create_decision_tree)
|
Reference in New Issue
Block a user