Files
ManimML/examples/decision_tree/split_scene.py

691 lines
26 KiB
Python

from sklearn import datasets
from decision_tree_surface import *
from manim import *
from sklearn.tree import DecisionTreeClassifier
from scipy.stats import entropy
import math
from PIL import Image
iris = datasets.load_iris()
font = "Source Han Sans"
font_scale = 0.75
images = [
Image.open("iris_dataset/SetosaFlower.jpeg"),
Image.open("iris_dataset/VeriscolorFlower.jpeg"),
Image.open("iris_dataset/VirginicaFlower.jpeg"),
]
def entropy(class_labels, base=2):
# compute the class counts
unique, counts = np.unique(class_labels, return_counts=True)
dictionary = dict(zip(unique, counts))
total = 0.0
num_samples = len(class_labels)
for class_index in range(0, 3):
if not class_index in dictionary:
continue
prob = dictionary[class_index] / num_samples
total += prob * math.log(prob, base)
# higher set
return -total
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
# get all polygons of each color
polygon_dict = {
str(BLUE).lower(): [],
str(GREEN).lower(): [],
str(ORANGE).lower(): [],
}
for polygon in all_polygons:
print(polygon_dict)
polygon_dict[str(polygon.color).lower()].append(polygon)
return_polygons = []
for color in colors:
color = str(color).lower()
polygons = polygon_dict[color]
points = set()
for polygon in polygons:
vertices = polygon.get_vertices().tolist()
vertices = [tuple(vert) for vert in vertices]
for pt in vertices:
if pt in points: # Shared vertice, remove it.
points.remove(pt)
else:
points.add(pt)
points = list(points)
sort_x = sorted(points)
sort_y = sorted(points, key=lambda x: x[1])
edges_h = {}
edges_v = {}
i = 0
while i < len(points):
curr_y = sort_y[i][1]
while i < len(points) and sort_y[i][1] == curr_y:
edges_h[sort_y[i]] = sort_y[i + 1]
edges_h[sort_y[i + 1]] = sort_y[i]
i += 2
i = 0
while i < len(points):
curr_x = sort_x[i][0]
while i < len(points) and sort_x[i][0] == curr_x:
edges_v[sort_x[i]] = sort_x[i + 1]
edges_v[sort_x[i + 1]] = sort_x[i]
i += 2
# Get all the polygons.
while edges_h:
# We can start with any point.
polygon = [(edges_h.popitem()[0], 0)]
while True:
curr, e = polygon[-1]
if e == 0:
next_vertex = edges_v.pop(curr)
polygon.append((next_vertex, 1))
else:
next_vertex = edges_h.pop(curr)
polygon.append((next_vertex, 0))
if polygon[-1] == polygon[0]:
# Closed polygon
polygon.pop()
break
# Remove implementation-markers from the polygon.
poly = [point for point, _ in polygon]
for vertex in poly:
if vertex in edges_h:
edges_h.pop(vertex)
if vertex in edges_v:
edges_v.pop(vertex)
polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
return_polygons.append(polygon)
return return_polygons
class IrisDatasetPlot(VGroup):
def __init__(self):
points = iris.data[:, 0:2]
labels = iris.feature_names
targets = iris.target
# Make points
self.point_group = self._make_point_group(points, targets)
# Make axes
self.axes_group = self._make_axes_group(points, labels)
# Make legend
self.legend_group = self._make_legend(
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
)
# Make title
# title_text = "Iris Dataset Plot"
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
# Make all group
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
# scale the groups
self.point_group.scale(1.6)
self.point_group.match_x(self.axes_group)
self.point_group.match_y(self.axes_group)
self.point_group.shift([0.2, 0, 0])
self.axes_group.scale(0.7)
self.all_group.shift([0, 0.2, 0])
@override_animation(Create)
def create_animation(self):
animation_group = AnimationGroup(
# Perform the animations
Create(self.point_group, run_time=2),
Wait(0.5),
Create(self.axes_group, run_time=2),
# add title
# Create(self.title),
Create(self.legend_group),
)
return animation_group
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
point_group = VGroup()
for point_index, point in enumerate(points):
# draw the dot
current_target = targets[point_index]
color = class_colors[current_target]
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
dot.scale(0.5)
point_group.add(dot)
return point_group
def _make_legend(self, class_colors, feature_labels, axes):
legend_group = VGroup()
# Make Text
setosa = Text("Setosa", color=BLUE)
verisicolor = Text("Verisicolor", color=ORANGE)
virginica = Text("Virginica", color=GREEN)
labels = VGroup(setosa, verisicolor, virginica).arrange(
direction=RIGHT, aligned_edge=LEFT, buff=2.0
)
labels.scale(0.5)
legend_group.add(labels)
# surrounding rectangle
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
surrounding_rectangle.move_to(labels)
legend_group.add(surrounding_rectangle)
# shift the legend group
legend_group.move_to(axes)
legend_group.shift([0, -3.0, 0])
legend_group.match_x(axes[0][0])
return legend_group
def _make_axes_group(self, points, labels):
axes_group = VGroup()
# make the axes
x_range = [
np.amin(points, axis=0)[0] - 0.2,
np.amax(points, axis=0)[0] - 0.2,
0.5,
]
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
axes = Axes(
x_range=x_range,
y_range=y_range,
x_length=9,
y_length=6.5,
# axis_config={"number_scale_value":0.75, "include_numbers":True},
tips=False,
).shift([0.5, 0.25, 0])
axes_group.add(axes)
# make axis labels
# x_label
x_label = (
Text(labels[0], font=font)
.match_y(axes.get_axes()[0])
.shift([0.5, -0.75, 0])
.scale(font_scale)
)
axes_group.add(x_label)
# y_label
y_label = (
Text(labels[1], font=font)
.match_x(axes.get_axes()[1])
.shift([-0.75, 0, 0])
.rotate(np.pi / 2)
.scale(font_scale)
)
axes_group.add(y_label)
return axes_group
class DecisionTreeSurface(VGroup):
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
# take the tree and construct the surface from it
self.tree_clf = tree_clf
self.data = data
self.axes = axes
self.class_colors = class_colors
self.surface_rectangles = self.generate_surface_rectangles()
def generate_surface_rectangles(self):
# compute data bounds
left = np.amin(self.data[:, 0]) - 0.2
right = np.amax(self.data[:, 0]) - 0.2
top = np.amax(self.data[:, 1])
bottom = np.amin(self.data[:, 1]) - 0.2
maxrange = [left, right, bottom, top]
rectangles = decision_areas(self.tree_clf, maxrange, x=0, y=1, n_features=2)
# turn the rectangle objects into manim rectangles
def convert_rectangle_to_polygon(rect):
# get the points for the rectangle in the plot coordinate frame
bottom_left = [rect[0], rect[3]]
bottom_right = [rect[1], rect[3]]
top_right = [rect[1], rect[2]]
top_left = [rect[0], rect[2]]
# convert those points into the entire manim coordinates
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
top_right_coord = self.axes.coords_to_point(*top_right)
top_left_coord = self.axes.coords_to_point(*top_left)
points = [
bottom_left_coord,
bottom_right_coord,
top_right_coord,
top_left_coord,
]
# construct a polygon object from those manim coordinates
rectangle = Polygon(
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
)
return rectangle
manim_rectangles = []
for rect in rectangles:
color = self.class_colors[int(rect[4])]
rectangle = convert_rectangle_to_polygon(rect)
manim_rectangles.append(rectangle)
manim_rectangles = merge_overlapping_polygons(
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
)
return manim_rectangles
@override_animation(Create)
def create_override(self):
# play a reveal of all of the surface rectangles
animations = []
for rectangle in self.surface_rectangles:
animations.append(Create(rectangle))
animation_group = AnimationGroup(*animations)
return animation_group
@override_animation(Uncreate)
def uncreate_override(self):
# play a reveal of all of the surface rectangles
animations = []
for rectangle in self.surface_rectangles:
animations.append(Uncreate(rectangle))
animation_group = AnimationGroup(*animations)
return animation_group
class DecisionTree:
"""
Draw a single tree node
"""
def _make_node(
self,
feature,
threshold,
values,
is_leaf=False,
depth=0,
leaf_colors=[BLUE, ORANGE, GREEN],
):
if not is_leaf:
node_text = f"{feature}\n <= {threshold:.3f} cm"
# draw decision text
decision_text = Text(node_text, color=WHITE)
# draw a box
bounding_box = SurroundingRectangle(decision_text, buff=0.3, color=WHITE)
node = VGroup()
node.add(bounding_box)
node.add(decision_text)
# return the node
else:
# plot the appropriate image
class_index = np.argmax(values)
# get image
pil_image = images[class_index]
leaf_group = Group()
node = ImageMobject(pil_image)
node.scale(1.5)
rectangle = Rectangle(
width=node.width + 0.05,
height=node.height + 0.05,
color=leaf_colors[class_index],
stroke_width=6,
)
rectangle.move_to(node.get_center())
rectangle.shift([-0.02, 0.02, 0])
leaf_group.add(rectangle)
leaf_group.add(node)
node = leaf_group
return node
def _make_connection(self, top, bottom, is_leaf=False):
top_node_bottom_location = top.get_center()
top_node_bottom_location[1] -= top.height / 2
bottom_node_top_location = bottom.get_center()
bottom_node_top_location[1] += bottom.height / 2
line = Line(top_node_bottom_location, bottom_node_top_location, color=WHITE)
return line
def _make_tree(self, tree, feature_names=["Sepal Length", "Sepal Width"]):
tree_group = Group()
max_depth = tree.max_depth
# make the base node
feature_name = feature_names[tree.feature[0]]
threshold = tree.threshold[0]
values = tree.value[0]
nodes_map = {}
root_node = self._make_node(feature_name, threshold, values, depth=0)
nodes_map[0] = root_node
tree_group.add(root_node)
# save some information
node_height = root_node.height
node_width = root_node.width
scale_factor = 1.0
edge_map = {}
# tree height
tree_height = scale_factor * node_height * max_depth
tree_width = scale_factor * 2**max_depth * node_width
# traverse tree
def recurse(node, depth, direction, parent_object, parent_node):
# make sure it is a valid node
# make the node object
is_leaf = tree.children_left[node] == tree.children_right[node]
feature_name = feature_names[tree.feature[node]]
threshold = tree.threshold[node]
values = tree.value[node]
node_object = self._make_node(
feature_name, threshold, values, depth=depth, is_leaf=is_leaf
)
nodes_map[node] = node_object
node_height = node_object.height
# set the node position
direction_factor = -1 if direction == "left" else 1
shift_right_amount = (
0.8 * direction_factor * scale_factor * tree_width / (2**depth) / 2
)
if is_leaf:
shift_down_amount = -1.0 * scale_factor * node_height
else:
shift_down_amount = -1.8 * scale_factor * node_height
node_object.match_x(parent_object).match_y(parent_object).shift(
[shift_right_amount, shift_down_amount, 0]
)
tree_group.add(node_object)
# make a connection
connection = self._make_connection(
parent_object, node_object, is_leaf=is_leaf
)
edge_name = str(parent_node) + "," + str(node)
edge_map[edge_name] = connection
tree_group.add(connection)
# recurse
if not is_leaf:
recurse(tree.children_left[node], depth + 1, "left", node_object, node)
recurse(
tree.children_right[node], depth + 1, "right", node_object, node
)
recurse(tree.children_left[0], 1, "left", root_node, 0)
recurse(tree.children_right[0], 1, "right", root_node, 0)
tree_group.scale(0.35)
return tree_group, nodes_map, edge_map
def color_example_path(
self, tree_group, nodes_map, tree, edge_map, example, color=YELLOW, thickness=2
):
# get decision path
decision_path = tree.decision_path(example)[0]
path_indices = decision_path.indices
# highlight edges
for node_index in range(0, len(path_indices) - 1):
current_val = path_indices[node_index]
next_val = path_indices[node_index + 1]
edge_str = str(current_val) + "," + str(next_val)
edge = edge_map[edge_str]
animation_two = AnimationGroup(
nodes_map[current_val].animate.set_color(color)
)
self.play(animation_two, run_time=0.5)
animation_one = AnimationGroup(
edge.animate.set_color(color),
# edge.animate.set_stroke_width(4),
)
self.play(animation_one, run_time=0.5)
# surround the bottom image
last_path_index = path_indices[-1]
last_path_rectangle = nodes_map[last_path_index][0]
self.play(last_path_rectangle.animate.set_color(color))
def create_sklearn_tree(self, max_tree_depth=1):
# learn the decision tree
iris = load_iris()
tree = learn_iris_decision_tree(iris, max_depth=max_tree_depth)
feature_names = iris.feature_names[0:2]
return tree.tree_
def make_tree(self, max_tree_depth=2):
sklearn_tree = self.create_sklearn_tree()
# make the tree
tree_group, nodes_map, edge_map = self._make_tree(
sklearn_tree.tree_, feature_names
)
tree_group.shift([0, 5.5, 0])
return tree_group
# self.add(tree_group)
# self.play(SpinInFromNothing(tree_group), run_time=3)
# self.color_example_path(tree_group, nodes_map, tree, edge_map, iris.data[None, 0, 0:2])
class DecisionTreeSplitScene(Scene):
def make_decision_tree_classifier(self, max_depth=4):
decision_tree = DecisionTreeClassifier(
random_state=1, max_depth=max_depth, max_leaf_nodes=8
)
decision_tree = decision_tree.fit(iris.data[:, :2], iris.target)
# output the decisioin tree in some format
return decision_tree
def make_split_animation(self, data, classes, data_labels, main_axes):
"""
def make_entropy_animation_and_plot(dim=0, num_entropy_values=50):
# calculate the entropy values
axes_group = VGroup()
# make axes
range_vals = [np.amin(data, axis=0)[dim], np.amax(data, axis=0)[dim]]
axes = Axes(x_range=range_vals,
y_range=[0, 1.0],
x_length=9,
y_length=4,
# axis_config={"number_scale_value":0.75, "include_numbers":True},
tips=False,
)
axes_group.add(axes)
# make axis labels
# x_label
x_label = Text(data_labels[dim], font=font) \
.match_y(axes.get_axes()[0]) \
.shift([0.5, -0.75, 0]) \
.scale(font_scale*1.2)
axes_group.add(x_label)
# y_label
y_label = Text("Information Gain", font=font) \
.match_x(axes.get_axes()[1]) \
.shift([-0.75, 0, 0]) \
.rotate(np.pi / 2) \
.scale(font_scale * 1.2)
axes_group.add(y_label)
# line animation
information_gains = []
def entropy_function(split_value):
# lower entropy
lower_set = np.nonzero(data[:, dim] <= split_value)[0]
lower_set = classes[lower_set]
lower_entropy = entropy(lower_set)
# higher entropy
higher_set = np.nonzero(data[:, dim] > split_value)[0]
higher_set = classes[higher_set]
higher_entropy = entropy(higher_set)
# calculate entropies
all_entropy = entropy(classes, base=2)
lower_entropy = entropy(lower_set, base=2)
higher_entropy = entropy(higher_set, base=2)
mean_entropy = (lower_entropy + higher_entropy) / 2
# calculate information gain
lower_prob = len(lower_set) / len(data[:, dim])
higher_prob = len(higher_set) / len(data[:, dim])
info_gain = all_entropy - (lower_prob * lower_entropy + higher_prob * higher_entropy)
information_gains.append((split_value, info_gain))
return info_gain
data_range = np.amin(data[:, dim]), np.amax(data[:, dim])
entropy_graph = axes.get_graph(
entropy_function,
# color=RED,
# x_range=data_range
)
axes_group.add(entropy_graph)
axes_group.shift([4.0, 2, 0])
axes_group.scale(0.5)
dot_animation = Dot(color=WHITE)
axes_group.add(dot_animation)
# make animations
animation_group = AnimationGroup(
Create(axes_group, run_time=2),
Wait(3),
MoveAlongPath(dot_animation, entropy_graph, run_time=20, rate_func=rate_functions.ease_in_out_quad),
Wait(2)
)
return axes_group, animation_group, information_gains
"""
def make_split_line_animation(dim=0):
# make a line along one of the dims and move it up and down
origin_coord = [
np.amin(data, axis=0)[0] - 0.2,
np.amin(data, axis=0)[1] - 0.2,
]
origin_point = main_axes.coords_to_point(*origin_coord)
top_left_coord = [origin_coord[0], np.amax(data, axis=0)[1]]
bottom_right_coord = [np.amax(data, axis=0)[0] - 0.2, origin_coord[1]]
if dim == 0:
other_coord = top_left_coord
moving_line_coord = bottom_right_coord
else:
other_coord = bottom_right_coord
moving_line_coord = top_left_coord
other_point = main_axes.coords_to_point(*other_coord)
moving_line_point = main_axes.coords_to_point(*moving_line_coord)
moving_line = Line(origin_point, other_point, color=RED)
movement_line = Line(origin_point, moving_line_point)
if dim == 0:
movement_line.shift([0, moving_line.height / 2, 0])
else:
movement_line.shift([moving_line.width / 2, 0, 0])
# move the moving line along the movement line
animation = MoveAlongPath(
moving_line,
movement_line,
run_time=20,
rate_func=rate_functions.ease_in_out_quad,
)
return animation, moving_line
# plot the line in white then make it invisible
# make an animation along the line
# make a
# axes_one_group, top_animation_group, info_gains = make_entropy_animation_and_plot(dim=0)
line_movement, first_moving_line = make_split_line_animation(dim=0)
# axes_two_group, bottom_animation_group, _ = make_entropy_animation_and_plot(dim=1)
second_line_movement, second_moving_line = make_split_line_animation(dim=1)
# axes_two_group.shift([0, -3, 0])
animation_group_one = AnimationGroup(
# top_animation_group,
line_movement,
)
animation_group_two = AnimationGroup(
# bottom_animation_group,
second_line_movement,
)
"""
both_axes_group = VGroup(
axes_one_group,
axes_two_group
)
"""
return (
animation_group_one,
animation_group_two,
first_moving_line,
second_moving_line,
None,
None,
)
# both_axes_group, \
# info_gains
def construct(self):
# make the points
iris_dataset_plot = IrisDatasetPlot()
iris_dataset_plot.all_group.scale(1.0)
iris_dataset_plot.all_group.shift([-3, 0.2, 0])
# make the entropy line graph
# entropy_line_graph = self.draw_entropy_line_graph()
# arrange the plots
# do animations
self.play(Create(iris_dataset_plot))
# make the decision tree classifier
decision_tree_classifier = self.make_decision_tree_classifier()
decision_tree_surface = DecisionTreeSurface(
decision_tree_classifier, iris.data, iris_dataset_plot.axes_group[0]
)
self.play(Create(decision_tree_surface))
self.wait(3)
self.play(Uncreate(decision_tree_surface))
main_axes = iris_dataset_plot.axes_group[0]
(
split_animation_one,
split_animation_two,
first_moving_line,
second_moving_line,
both_axes_group,
info_gains,
) = self.make_split_animation(
iris.data[:, 0:2], iris.target, iris.feature_names, main_axes
)
self.play(split_animation_one)
self.wait(0.1)
self.play(Uncreate(first_moving_line))
self.wait(3)
self.play(split_animation_two)
self.wait(0.1)
self.play(Uncreate(second_moving_line))
self.wait(0.1)
# highlight the maximum on top
# sort by second key
"""
highest_info_gain = sorted(info_gains, key=lambda x: x[1])[-1]
highest_info_gain_point = both_axes_group[0][0].coords_to_point(*highest_info_gain)
highlighted_peak = Dot(highest_info_gain_point, color=YELLOW)
# get location of highest info gain point
highest_info_gain_point_in_iris_graph = iris_dataset_plot.axes_group[0].coords_to_point(*[highest_info_gain[0], 0])
first_moving_line.start[0] = highest_info_gain_point_in_iris_graph[0]
first_moving_line.end[0] = highest_info_gain_point_in_iris_graph[0]
self.play(Create(highlighted_peak))
self.play(Create(first_moving_line))
text = Text("Highest Information Gain")
text.scale(0.4)
text.move_to(highlighted_peak)
text.shift([0, 0.5, 0])
self.play(Create(text))
"""
self.wait(1)
# draw the basic tree
decision_tree_classifier = self.make_decision_tree_classifier(max_depth=1)
decision_tree_surface = DecisionTreeSurface(
decision_tree_classifier, iris.data, iris_dataset_plot.axes_group[0]
)
decision_tree_graph, _, _ = DecisionTree()._make_tree(
decision_tree_classifier.tree_
)
decision_tree_graph.match_y(iris_dataset_plot.axes_group)
decision_tree_graph.shift([4, 0, 0])
self.play(Create(decision_tree_surface))
uncreate_animation = AnimationGroup(
# Uncreate(both_axes_group),
# Uncreate(highlighted_peak),
Uncreate(second_moving_line),
# Unwrite(text)
)
self.play(uncreate_animation)
self.wait(0.5)
self.play(FadeIn(decision_tree_graph))
# self.play(FadeIn(highlighted_peak))
self.wait(5)