mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-22 13:06:46 +08:00
457 lines
16 KiB
Python
457 lines
16 KiB
Python
"""
|
|
Module for visualizing decision trees in Manim.
|
|
It parses a decision tree classifier from sklearn.
|
|
|
|
TODO return a map from nodes to split animation for BFS tree expansion
|
|
TODO reimplement the decision 2D decision tree surface drawing.
|
|
"""
|
|
from manim import *
|
|
from manim_ml.decision_tree.classification_areas import (
|
|
compute_decision_areas,
|
|
merge_overlapping_polygons,
|
|
)
|
|
import manim_ml.decision_tree.helpers as helpers
|
|
from manim_ml.one_to_one_sync import OneToOneSync
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
|
|
class LeafNode(Group):
|
|
"""Leaf node in tree"""
|
|
|
|
def __init__(
|
|
self, class_index, display_type="image", class_image_paths=[], class_colors=[]
|
|
):
|
|
super().__init__()
|
|
self.display_type = display_type
|
|
self.class_image_paths = class_image_paths
|
|
self.class_colors = class_colors
|
|
assert self.display_type in ["image", "text"]
|
|
if self.display_type == "image":
|
|
self._construct_image_node(class_index)
|
|
else:
|
|
raise NotImplementedError()
|
|
|
|
def _construct_image_node(self, class_index):
|
|
"""Make an image node"""
|
|
# Get image
|
|
image_path = self.class_image_paths[class_index]
|
|
pil_image = Image.open(image_path)
|
|
node = ImageMobject(pil_image)
|
|
node.scale(1.5)
|
|
rectangle = Rectangle(
|
|
width=node.width + 0.05,
|
|
height=node.height + 0.05,
|
|
color=self.class_colors[class_index],
|
|
stroke_width=6,
|
|
)
|
|
rectangle.move_to(node.get_center())
|
|
rectangle.shift([-0.02, 0.02, 0])
|
|
self.add(rectangle)
|
|
self.add(node)
|
|
|
|
|
|
class SplitNode(VGroup):
|
|
"""Node for splitting decision in tree"""
|
|
|
|
def __init__(self, feature, threshold):
|
|
super().__init__()
|
|
node_text = f"{feature}\n<= {threshold:.2f} cm"
|
|
# Draw decision text
|
|
decision_text = Text(node_text, color=WHITE)
|
|
# Draw the surrounding box
|
|
bounding_box = SurroundingRectangle(decision_text, buff=0.3, color=WHITE)
|
|
self.add(bounding_box)
|
|
self.add(decision_text)
|
|
|
|
|
|
class DecisionTreeDiagram(Group):
|
|
"""Decision Tree Diagram Class for Manim"""
|
|
|
|
def __init__(
|
|
self,
|
|
sklearn_tree,
|
|
feature_names=None,
|
|
class_names=None,
|
|
class_images_paths=None,
|
|
class_colors=[RED, GREEN, BLUE],
|
|
):
|
|
super().__init__()
|
|
self.tree = sklearn_tree
|
|
self.feature_names = feature_names
|
|
self.class_names = class_names
|
|
self.class_image_paths = class_images_paths
|
|
self.class_colors = class_colors
|
|
# Make graph container for the tree
|
|
self.tree_group, self.nodes_map, self.edge_map = self._make_tree()
|
|
self.add(self.tree_group)
|
|
|
|
def _make_node(
|
|
self,
|
|
node_index,
|
|
):
|
|
"""Make node"""
|
|
is_split_node = (
|
|
self.tree.children_left[node_index] != self.tree.children_right[node_index]
|
|
)
|
|
if is_split_node:
|
|
node_feature = self.tree.feature[node_index]
|
|
node_threshold = self.tree.threshold[node_index]
|
|
node = SplitNode(self.feature_names[node_feature], node_threshold)
|
|
else:
|
|
# Get the most abundant class for the given leaf node
|
|
# Make the leaf node object
|
|
tree_class_index = np.argmax(self.tree.value[node_index])
|
|
node = LeafNode(
|
|
class_index=tree_class_index,
|
|
class_colors=self.class_colors,
|
|
class_image_paths=self.class_image_paths,
|
|
)
|
|
return node
|
|
|
|
def _make_connection(self, top, bottom, is_leaf=False):
|
|
"""Make a connection from top to bottom"""
|
|
top_node_bottom_location = top.get_center()
|
|
top_node_bottom_location[1] -= top.height / 2
|
|
bottom_node_top_location = bottom.get_center()
|
|
bottom_node_top_location[1] += bottom.height / 2
|
|
|
|
line = Line(top_node_bottom_location, bottom_node_top_location, color=WHITE)
|
|
|
|
return line
|
|
|
|
def _make_tree(self):
|
|
"""Construct the tree diagram"""
|
|
tree_group = Group()
|
|
max_depth = self.tree.max_depth
|
|
# Make the root node
|
|
nodes_map = {}
|
|
root_node = self._make_node(
|
|
node_index=0,
|
|
)
|
|
nodes_map[0] = root_node
|
|
tree_group.add(root_node)
|
|
# Save some information
|
|
node_height = root_node.height
|
|
node_width = root_node.width
|
|
scale_factor = 1.0
|
|
edge_map = {}
|
|
# tree height
|
|
tree_height = scale_factor * node_height * max_depth
|
|
tree_width = scale_factor * 2**max_depth * node_width
|
|
# traverse tree
|
|
def recurse(node_index, depth, direction, parent_object, parent_node):
|
|
# make the node object
|
|
is_leaf = (
|
|
self.tree.children_left[node_index]
|
|
== self.tree.children_right[node_index]
|
|
)
|
|
node_object = self._make_node(node_index=node_index)
|
|
nodes_map[node_index] = node_object
|
|
node_height = node_object.height
|
|
# set the node position
|
|
direction_factor = -1 if direction == "left" else 1
|
|
shift_right_amount = (
|
|
0.9 * direction_factor * scale_factor * tree_width / (2**depth) / 2
|
|
)
|
|
if is_leaf:
|
|
shift_down_amount = -1.0 * scale_factor * node_height
|
|
else:
|
|
shift_down_amount = -1.8 * scale_factor * node_height
|
|
node_object.match_x(parent_object).match_y(parent_object).shift(
|
|
[shift_right_amount, shift_down_amount, 0]
|
|
)
|
|
tree_group.add(node_object)
|
|
# make a connection
|
|
connection = self._make_connection(
|
|
parent_object, node_object, is_leaf=is_leaf
|
|
)
|
|
edge_name = str(parent_node) + "," + str(node_index)
|
|
edge_map[edge_name] = connection
|
|
tree_group.add(connection)
|
|
# recurse
|
|
if not is_leaf:
|
|
recurse(
|
|
self.tree.children_left[node_index],
|
|
depth + 1,
|
|
"left",
|
|
node_object,
|
|
node_index,
|
|
)
|
|
recurse(
|
|
self.tree.children_right[node_index],
|
|
depth + 1,
|
|
"right",
|
|
node_object,
|
|
node_index,
|
|
)
|
|
|
|
recurse(self.tree.children_left[0], 1, "left", root_node, 0)
|
|
recurse(self.tree.children_right[0], 1, "right", root_node, 0)
|
|
|
|
tree_group.scale(0.35)
|
|
return tree_group, nodes_map, edge_map
|
|
|
|
def create_level_order_expansion_decision_tree(self, tree):
|
|
"""Expands the decision tree in level order"""
|
|
raise NotImplementedError()
|
|
|
|
def create_bfs_expansion_decision_tree(self, tree):
|
|
"""Expands the tree using BFS"""
|
|
animations = []
|
|
# Compute parent mapping
|
|
parent_mapping = helpers.compute_node_to_parent_mapping(self.tree)
|
|
# Create the root node
|
|
animations.append(Create(self.nodes_map[0]))
|
|
# Iterate through the nodes
|
|
queue = [0]
|
|
while len(queue) > 0:
|
|
node_index = queue.pop(0)
|
|
# Check if a node is a split node or not
|
|
left_child = self.tree.children_left[node_index]
|
|
right_child = self.tree.children_right[node_index]
|
|
is_leaf_node = left_child == right_child
|
|
if not is_leaf_node:
|
|
# Split the node by creating the children and connecting them
|
|
# to the parent
|
|
# Get the nodes
|
|
left_node = self.nodes_map[left_child]
|
|
right_node = self.nodes_map[right_child]
|
|
# Get the parent edges
|
|
left_parent_edge = self.edge_map[f"{node_index},{left_child}"]
|
|
right_parent_edge = self.edge_map[f"{node_index},{right_child}"]
|
|
# Create the children
|
|
split_animation = AnimationGroup(
|
|
FadeIn(left_node),
|
|
FadeIn(right_node),
|
|
Create(left_parent_edge),
|
|
Create(right_parent_edge),
|
|
lag_ratio=0.0,
|
|
)
|
|
animations.append(split_animation)
|
|
# Add the children to the queue
|
|
if left_child != -1:
|
|
queue.append(left_child)
|
|
if right_child != -1:
|
|
queue.append(right_child)
|
|
|
|
return AnimationGroup(*animations, lag_ratio=1.0)
|
|
|
|
@override_animation(Create)
|
|
def create_decision_tree(self, traversal_order="bfs"):
|
|
"""Makes a create animation for the decision tree"""
|
|
if traversal_order == "level":
|
|
return self.create_level_order_expansion_decision_tree(self.tree)
|
|
elif traversal_order == "bfs":
|
|
return self.create_bfs_expansion_decision_tree(self.tree)
|
|
else:
|
|
raise Exception(f"Uncrecognized traversal: {traversal_order}")
|
|
|
|
|
|
class IrisDatasetPlot(VGroup):
|
|
def __init__(self, iris):
|
|
points = iris.data[:, 0:2]
|
|
labels = iris.feature_names
|
|
targets = iris.target
|
|
# Make points
|
|
self.point_group = self._make_point_group(points, targets)
|
|
# Make axes
|
|
self.axes_group = self._make_axes_group(points, labels)
|
|
# Make legend
|
|
self.legend_group = self._make_legend(
|
|
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
|
|
)
|
|
# Make title
|
|
# title_text = "Iris Dataset Plot"
|
|
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
|
|
# Make all group
|
|
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
|
|
# scale the groups
|
|
self.point_group.scale(1.6)
|
|
self.point_group.match_x(self.axes_group)
|
|
self.point_group.match_y(self.axes_group)
|
|
self.point_group.shift([0.2, 0, 0])
|
|
self.axes_group.scale(0.7)
|
|
self.all_group.shift([0, 0.2, 0])
|
|
|
|
@override_animation(Create)
|
|
def create_animation(self):
|
|
animation_group = AnimationGroup(
|
|
# Perform the animations
|
|
Create(self.point_group, run_time=2),
|
|
Wait(0.5),
|
|
Create(self.axes_group, run_time=2),
|
|
# add title
|
|
# Create(self.title),
|
|
Create(self.legend_group),
|
|
)
|
|
return animation_group
|
|
|
|
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
|
|
point_group = VGroup()
|
|
for point_index, point in enumerate(points):
|
|
# draw the dot
|
|
current_target = targets[point_index]
|
|
color = class_colors[current_target]
|
|
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
|
|
dot.scale(0.5)
|
|
point_group.add(dot)
|
|
return point_group
|
|
|
|
def _make_legend(self, class_colors, feature_labels, axes):
|
|
legend_group = VGroup()
|
|
# Make Text
|
|
setosa = Text("Setosa", color=BLUE)
|
|
verisicolor = Text("Verisicolor", color=ORANGE)
|
|
virginica = Text("Virginica", color=GREEN)
|
|
labels = VGroup(setosa, verisicolor, virginica).arrange(
|
|
direction=RIGHT, aligned_edge=LEFT, buff=2.0
|
|
)
|
|
labels.scale(0.5)
|
|
legend_group.add(labels)
|
|
# surrounding rectangle
|
|
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
|
|
surrounding_rectangle.move_to(labels)
|
|
legend_group.add(surrounding_rectangle)
|
|
# shift the legend group
|
|
legend_group.move_to(axes)
|
|
legend_group.shift([0, -3.0, 0])
|
|
legend_group.match_x(axes[0][0])
|
|
|
|
return legend_group
|
|
|
|
def _make_axes_group(self, points, labels, font="Source Han Sans", font_scale=0.75):
|
|
axes_group = VGroup()
|
|
# make the axes
|
|
x_range = [
|
|
np.amin(points, axis=0)[0] - 0.2,
|
|
np.amax(points, axis=0)[0] - 0.2,
|
|
0.5,
|
|
]
|
|
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
|
|
axes = Axes(
|
|
x_range=x_range,
|
|
y_range=y_range,
|
|
x_length=9,
|
|
y_length=6.5,
|
|
# axis_config={"number_scale_value":0.75, "include_numbers":True},
|
|
tips=False,
|
|
).shift([0.5, 0.25, 0])
|
|
axes_group.add(axes)
|
|
# make axis labels
|
|
# x_label
|
|
x_label = (
|
|
Text(labels[0], font=font)
|
|
.match_y(axes.get_axes()[0])
|
|
.shift([0.5, -0.75, 0])
|
|
.scale(font_scale)
|
|
)
|
|
axes_group.add(x_label)
|
|
# y_label
|
|
y_label = (
|
|
Text(labels[1], font=font)
|
|
.match_x(axes.get_axes()[1])
|
|
.shift([-0.75, 0, 0])
|
|
.rotate(np.pi / 2)
|
|
.scale(font_scale)
|
|
)
|
|
axes_group.add(y_label)
|
|
|
|
return axes_group
|
|
|
|
|
|
class DecisionTreeSurface(VGroup):
|
|
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
|
|
# take the tree and construct the surface from it
|
|
self.tree_clf = tree_clf
|
|
self.data = data
|
|
self.axes = axes
|
|
self.class_colors = class_colors
|
|
self.surface_rectangles = self.generate_surface_rectangles()
|
|
|
|
def generate_surface_rectangles(self):
|
|
# compute data bounds
|
|
left = np.amin(self.data[:, 0]) - 0.2
|
|
right = np.amax(self.data[:, 0]) - 0.2
|
|
top = np.amax(self.data[:, 1])
|
|
bottom = np.amin(self.data[:, 1]) - 0.2
|
|
maxrange = [left, right, bottom, top]
|
|
rectangles = compute_decision_areas(
|
|
self.tree_clf, maxrange, x=0, y=1, n_features=2
|
|
)
|
|
# turn the rectangle objects into manim rectangles
|
|
def convert_rectangle_to_polygon(rect):
|
|
# get the points for the rectangle in the plot coordinate frame
|
|
bottom_left = [rect[0], rect[3]]
|
|
bottom_right = [rect[1], rect[3]]
|
|
top_right = [rect[1], rect[2]]
|
|
top_left = [rect[0], rect[2]]
|
|
# convert those points into the entire manim coordinates
|
|
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
|
|
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
|
|
top_right_coord = self.axes.coords_to_point(*top_right)
|
|
top_left_coord = self.axes.coords_to_point(*top_left)
|
|
points = [
|
|
bottom_left_coord,
|
|
bottom_right_coord,
|
|
top_right_coord,
|
|
top_left_coord,
|
|
]
|
|
# construct a polygon object from those manim coordinates
|
|
rectangle = Polygon(
|
|
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
|
|
)
|
|
return rectangle
|
|
|
|
manim_rectangles = []
|
|
for rect in rectangles:
|
|
color = self.class_colors[int(rect[4])]
|
|
rectangle = convert_rectangle_to_polygon(rect)
|
|
manim_rectangles.append(rectangle)
|
|
|
|
manim_rectangles = merge_overlapping_polygons(
|
|
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
|
|
)
|
|
|
|
return manim_rectangles
|
|
|
|
@override_animation(Create)
|
|
def create_override(self):
|
|
# play a reveal of all of the surface rectangles
|
|
animations = []
|
|
for rectangle in self.surface_rectangles:
|
|
animations.append(Create(rectangle))
|
|
animation_group = AnimationGroup(*animations)
|
|
|
|
return animation_group
|
|
|
|
@override_animation(Uncreate)
|
|
def uncreate_override(self):
|
|
# play a reveal of all of the surface rectangles
|
|
animations = []
|
|
for rectangle in self.surface_rectangles:
|
|
animations.append(Uncreate(rectangle))
|
|
animation_group = AnimationGroup(*animations)
|
|
|
|
return animation_group
|
|
|
|
|
|
class DecisionTreeContainer(OneToOneSync):
|
|
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
|
|
|
def __init__(self, sklearn_tree, points, classes):
|
|
self.sklearn_tree = sklearn_tree
|
|
self.points = points
|
|
self.classes = classes
|
|
|
|
def make_unfold_tree_animation(self):
|
|
"""Unfolds the tree through an in order traversal
|
|
|
|
This animations unfolds the tree diagram as well as showing the splitting
|
|
of a shaded region in the Decision Tree embedding.
|
|
"""
|
|
# Draw points in the embedding
|
|
# Start the tree splitting animation
|
|
pass
|