mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-07-06 16:18:17 +08:00
691 lines
26 KiB
Python
691 lines
26 KiB
Python
from sklearn import datasets
|
|
from decision_tree_surface import *
|
|
from manim import *
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from scipy.stats import entropy
|
|
import math
|
|
from PIL import Image
|
|
|
|
# Iris dataset shared by the plots and classifiers defined in this module.
iris = datasets.load_iris()

# Typography used for the axis labels.
font, font_scale = "Source Han Sans", 0.75

# One photo per class, in class-index order; leaf nodes of the drawn decision
# tree pick their picture from this list by predicted class index.
# NOTE(review): "Veriscolor" is presumably the spelling of the file on disk --
# confirm before renaming the path.
images = [
    Image.open(image_path)
    for image_path in (
        "iris_dataset/SetosaFlower.jpeg",
        "iris_dataset/VeriscolorFlower.jpeg",
        "iris_dataset/VirginicaFlower.jpeg",
    )
]
|
|
|
|
|
|
def entropy(class_labels, base=2):
    """Return the Shannon entropy of a collection of class labels.

    NOTE: this definition shadows the ``entropy`` name imported from
    ``scipy.stats`` at the top of the file.

    Args:
        class_labels: Sequence or 1-D array of labels. Any label values that
            ``np.unique`` can count are accepted.
        base: Base of the logarithm (2 yields bits).

    Returns:
        ``-sum(p * log(p, base))`` over the classes present in
        ``class_labels``; zero for an empty or single-class input.
    """
    # compute the class counts
    _, counts = np.unique(class_labels, return_counts=True)
    num_samples = len(class_labels)
    total = 0.0
    # Walk the classes that actually occur instead of assuming labels 0..2,
    # so the probabilities always sum to one.
    for count in counts:
        prob = count / num_samples
        total += prob * math.log(prob, base)
    return -total
|
|
|
|
|
|
def merge_overlapping_polygons(all_polygons, colors=(BLUE, GREEN, ORANGE)):
    """Merge same-colored axis-aligned polygons into outline polygons.

    Polygons are grouped by the lower-cased string form of their ``color``.
    Within a group, every vertex that occurs an even number of times is
    discarded (it is shared between neighbours), and the remaining vertices
    are connected by alternating vertical and horizontal edges to trace the
    outline(s) of the union.

    Args:
        all_polygons: Iterable of objects exposing ``color`` and
            ``get_vertices()``. The tracing assumes every edge is horizontal
            or vertical and that touching polygons share exact vertex
            coordinates.
        colors: Colors to emit merged polygons for, in output order.

    Returns:
        A list of ``Polygon`` objects, grouped in the order of ``colors``.

    Raises:
        IndexError / KeyError: if the surviving vertices of a color cannot be
            paired up into horizontal and vertical edges.
    """
    # get all polygons of each color
    polygon_dict = {str(color).lower(): [] for color in colors}
    for polygon in all_polygons:
        # Polygons whose color was not requested are bucketed but never
        # emitted, instead of raising.
        polygon_dict.setdefault(str(polygon.color).lower(), []).append(polygon)

    return_polygons = []
    for color in colors:
        color = str(color).lower()
        polygons = polygon_dict[color]
        points = set()
        for polygon in polygons:
            vertices = [tuple(vert) for vert in polygon.get_vertices().tolist()]
            for pt in vertices:
                if pt in points:  # Shared vertex, remove it.
                    points.remove(pt)
                else:
                    points.add(pt)
        points = list(points)
        sort_x = sorted(points)
        # Sort by y first and x second: points on the same horizontal line
        # must be ordered by x so that consecutive pairs form real edges
        # (sorting on y alone leaves ties in arbitrary set order).
        sort_y = sorted(points, key=lambda pt: (pt[1], pt[0]))

        edges_h = {}
        edges_v = {}

        # Pair consecutive points on each horizontal line.
        i = 0
        while i < len(points):
            curr_y = sort_y[i][1]
            while i < len(points) and sort_y[i][1] == curr_y:
                edges_h[sort_y[i]] = sort_y[i + 1]
                edges_h[sort_y[i + 1]] = sort_y[i]
                i += 2
        # Pair consecutive points on each vertical line.
        i = 0
        while i < len(points):
            curr_x = sort_x[i][0]
            while i < len(points) and sort_x[i][0] == curr_x:
                edges_v[sort_x[i]] = sort_x[i + 1]
                edges_v[sort_x[i + 1]] = sort_x[i]
                i += 2

        # Get all the polygons.
        while edges_h:
            # We can start with any point. The second tuple element records
            # which kind of edge led here: 0 = horizontal, 1 = vertical.
            polygon = [(edges_h.popitem()[0], 0)]
            while True:
                curr, e = polygon[-1]
                if e == 0:
                    next_vertex = edges_v.pop(curr)
                    polygon.append((next_vertex, 1))
                else:
                    next_vertex = edges_h.pop(curr)
                    polygon.append((next_vertex, 0))
                if polygon[-1] == polygon[0]:
                    # Closed polygon
                    polygon.pop()
                    break
            # Remove implementation-markers from the polygon.
            poly = [point for point, _ in polygon]
            # Drop the reverse-direction edge entries of the traced outline.
            for vertex in poly:
                if vertex in edges_h:
                    edges_h.pop(vertex)
                if vertex in edges_v:
                    edges_v.pop(vertex)
            polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
            return_polygons.append(polygon)
    return return_polygons
|
|
|
|
|
|
class IrisDatasetPlot(VGroup):
    """Scatter plot of the first two iris features with axes and a legend.

    The pieces are exposed as ``point_group``, ``axes_group``,
    ``legend_group`` and ``all_group`` (a ``Group`` of the three). They are
    not added as submobjects of this object; ``Create(plot)`` is overridden
    to animate the pieces instead.
    """

    def __init__(self):
        # Initialise the VGroup base so the object is a well-formed mobject.
        super().__init__()
        points = iris.data[:, 0:2]
        labels = iris.feature_names
        targets = iris.target
        # Make points
        self.point_group = self._make_point_group(points, targets)
        # Make axes
        self.axes_group = self._make_axes_group(points, labels)
        # Make legend
        self.legend_group = self._make_legend(
            [BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
        )
        # Make title
        # title_text = "Iris Dataset Plot"
        # self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
        # Make all group
        self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
        # scale the groups
        # The dots are created at raw data coordinates, so they are scaled and
        # re-centred by hand to sit on top of the axes.
        self.point_group.scale(1.6)
        self.point_group.match_x(self.axes_group)
        self.point_group.match_y(self.axes_group)
        self.point_group.shift([0.2, 0, 0])
        self.axes_group.scale(0.7)
        self.all_group.shift([0, 0.2, 0])

    @override_animation(Create)
    def create_animation(self):
        """Return the animation used in place of ``Create(self)``."""
        animation_group = AnimationGroup(
            # Perform the animations
            Create(self.point_group, run_time=2),
            Wait(0.5),
            Create(self.axes_group, run_time=2),
            # add title
            # Create(self.title),
            Create(self.legend_group),
        )
        return animation_group

    def _make_point_group(self, points, targets, class_colors=(BLUE, ORANGE, GREEN)):
        """Build one half-size ``Dot`` per sample, colored by its class.

        Args:
            points: Indexable of samples; elements 0 and 1 of each sample are
                used as the dot's x and y position.
            targets: Class index of each sample, used to index ``class_colors``.
            class_colors: Color for each class index.

        Returns:
            A ``VGroup`` containing the dots in sample order.
        """
        point_group = VGroup()
        for point, current_target in zip(points, targets):
            # draw the dot
            color = class_colors[current_target]
            dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
            dot.scale(0.5)
            point_group.add(dot)
        return point_group

    def _make_legend(self, class_colors, feature_labels, axes):
        """Build the boxed three-species legend placed below ``axes``.

        Args:
            class_colors: Currently unused; the label colors are fixed to
                BLUE, ORANGE and GREEN.
            feature_labels: Currently unused; the label texts are fixed.
            axes: The axes group; ``axes[0][0]`` is used for horizontal
                alignment.

        Returns:
            A ``VGroup`` holding the labels and their surrounding rectangle.
        """
        legend_group = VGroup()
        # Make Text
        setosa = Text("Setosa", color=BLUE)
        versicolor = Text("Versicolor", color=ORANGE)
        virginica = Text("Virginica", color=GREEN)
        labels = VGroup(setosa, versicolor, virginica).arrange(
            direction=RIGHT, aligned_edge=LEFT, buff=2.0
        )
        labels.scale(0.5)
        legend_group.add(labels)
        # surrounding rectangle
        surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
        surrounding_rectangle.move_to(labels)
        legend_group.add(surrounding_rectangle)
        # shift the legend group
        legend_group.move_to(axes)
        legend_group.shift([0, -3.0, 0])
        legend_group.match_x(axes[0][0])

        return legend_group

    def _make_axes_group(self, points, labels):
        """Build the axes plus x/y axis labels.

        Args:
            points: 2-D array whose columns 0 and 1 determine the x and y
                ranges.
            labels: Sequence whose first two entries are the x and y axis
                label texts.

        Returns:
            A ``VGroup`` of ``[axes, x_label, y_label]``.
        """
        axes_group = VGroup()
        # make the axes
        mins = np.amin(points, axis=0)
        maxs = np.amax(points, axis=0)
        # NOTE(review): the upper x bound subtracts 0.2 rather than adding it.
        # DecisionTreeSurface and the split-line animation in this file use
        # the same bound, so change them together if this is unintended.
        x_range = [mins[0] - 0.2, maxs[0] - 0.2, 0.5]
        y_range = [mins[1] - 0.2, maxs[1], 0.5]
        axes = Axes(
            x_range=x_range,
            y_range=y_range,
            x_length=9,
            y_length=6.5,
            # axis_config={"number_scale_value":0.75, "include_numbers":True},
            tips=False,
        ).shift([0.5, 0.25, 0])
        axes_group.add(axes)
        # make axis labels
        # x_label
        x_label = (
            Text(labels[0], font=font)
            .match_y(axes.get_axes()[0])
            .shift([0.5, -0.75, 0])
            .scale(font_scale)
        )
        axes_group.add(x_label)
        # y_label
        y_label = (
            Text(labels[1], font=font)
            .match_x(axes.get_axes()[1])
            .shift([-0.75, 0, 0])
            .rotate(np.pi / 2)
            .scale(font_scale)
        )
        axes_group.add(y_label)

        return axes_group
|
|
|
|
|
|
class DecisionTreeSurface(VGroup):
    """Colored decision regions of a fitted tree, drawn on a pair of axes.

    The regions are stored in ``surface_rectangles`` rather than added as
    submobjects; ``Create`` and ``Uncreate`` are overridden to animate them.
    """

    def __init__(self, tree_clf, data, axes, class_colors=(BLUE, ORANGE, GREEN)):
        """
        Args:
            tree_clf: Fitted classifier handed to ``decision_areas``.
            data: 2-D array; columns 0 and 1 define the bounds of the surface.
            axes: Axes object providing ``coords_to_point``.
            class_colors: Color for each class index.
        """
        # Initialise the VGroup base so the object is a well-formed mobject.
        super().__init__()
        # take the tree and construct the surface from it
        self.tree_clf = tree_clf
        self.data = data
        self.axes = axes
        self.class_colors = class_colors
        self.surface_rectangles = self.generate_surface_rectangles()

    def generate_surface_rectangles(self):
        """Return the merged, colored polygons covering the decision regions."""
        # compute data bounds
        left = np.amin(self.data[:, 0]) - 0.2
        right = np.amax(self.data[:, 0]) - 0.2
        top = np.amax(self.data[:, 1])
        bottom = np.amin(self.data[:, 1]) - 0.2
        maxrange = [left, right, bottom, top]
        # NOTE(review): decision_areas is presumably provided by the star
        # import from decision_tree_surface -- confirm. Each returned item is
        # indexed below as [x0, x1, y0, y1, class_index].
        rectangles = decision_areas(self.tree_clf, maxrange, x=0, y=1, n_features=2)

        # turn the rectangle objects into manim rectangles
        def convert_rectangle_to_polygon(rect, color):
            # get the points for the rectangle in the plot coordinate frame
            corners = [
                [rect[0], rect[3]],
                [rect[1], rect[3]],
                [rect[1], rect[2]],
                [rect[0], rect[2]],
            ]
            # convert those points into the entire manim coordinates
            points = [self.axes.coords_to_point(*corner) for corner in corners]
            # construct a polygon object from those manim coordinates
            return Polygon(*points, color=color, fill_opacity=0.3, stroke_opacity=0.0)

        manim_rectangles = [
            convert_rectangle_to_polygon(rect, self.class_colors[int(rect[4])])
            for rect in rectangles
        ]

        manim_rectangles = merge_overlapping_polygons(
            manim_rectangles, colors=[BLUE, GREEN, ORANGE]
        )

        return manim_rectangles

    @override_animation(Create)
    def create_override(self):
        """Return the animation used in place of ``Create(self)``."""
        # play a reveal of all of the surface rectangles
        animations = [Create(rectangle) for rectangle in self.surface_rectangles]
        return AnimationGroup(*animations)

    @override_animation(Uncreate)
    def uncreate_override(self):
        """Return the animation used in place of ``Uncreate(self)``."""
        # play a removal of all of the surface rectangles
        animations = [Uncreate(rectangle) for rectangle in self.surface_rectangles]
        return AnimationGroup(*animations)
|
|
|
|
|
|
class DecisionTree:
    """Builds Manim objects that draw a fitted scikit-learn decision tree.

    Internal nodes are drawn as boxed "<feature> <= <threshold>" text and leaf
    nodes as the framed photo of the predicted class.
    """

    def _make_node(
        self,
        feature,
        threshold,
        values,
        is_leaf=False,
        depth=0,
        leaf_colors=(BLUE, ORANGE, GREEN),
    ):
        """Draw a single tree node.

        Args:
            feature: Feature name shown in a decision node.
            threshold: Split threshold shown in a decision node.
            values: Per-class values of the node; the argmax selects the image
                and frame color of a leaf.
            is_leaf: Whether to draw a leaf (image) instead of a decision box.
            depth: Unused.
            leaf_colors: Frame color for each class index.

        Returns:
            A ``VGroup`` ``[box, text]`` for a decision node, or a ``Group``
            ``[frame, image]`` for a leaf.
        """
        if not is_leaf:
            node_text = f"{feature}\n <= {threshold:.3f} cm"
            # draw decision text
            decision_text = Text(node_text, color=WHITE)
            # draw a box
            bounding_box = SurroundingRectangle(decision_text, buff=0.3, color=WHITE)
            node = VGroup()
            node.add(bounding_box)
            node.add(decision_text)
            return node

        # plot the appropriate image
        class_index = np.argmax(values)
        # get image
        pil_image = images[class_index]
        image = ImageMobject(pil_image)
        image.scale(1.5)
        rectangle = Rectangle(
            width=image.width + 0.05,
            height=image.height + 0.05,
            color=leaf_colors[class_index],
            stroke_width=6,
        )
        rectangle.move_to(image.get_center())
        rectangle.shift([-0.02, 0.02, 0])
        # The frame goes first: color_example_path highlights element 0.
        leaf_group = Group()
        leaf_group.add(rectangle)
        leaf_group.add(image)
        return leaf_group

    def _make_connection(self, top, bottom, is_leaf=False):
        """Return a white line from the bottom of ``top`` to the top of ``bottom``.

        ``is_leaf`` is accepted but unused.
        """
        top_node_bottom_location = top.get_center()
        top_node_bottom_location[1] -= top.height / 2
        bottom_node_top_location = bottom.get_center()
        bottom_node_top_location[1] += bottom.height / 2
        line = Line(top_node_bottom_location, bottom_node_top_location, color=WHITE)
        return line

    def _make_tree(self, tree, feature_names=("Sepal Length", "Sepal Width")):
        """Lay out every node and edge of ``tree``.

        Args:
            tree: Low-level tree object exposing ``max_depth``, ``feature``,
                ``threshold``, ``value``, ``children_left`` and
                ``children_right`` (e.g. a classifier's ``tree_``).
            feature_names: Display name for each feature index.

        Returns:
            ``(tree_group, nodes_map, edge_map)``: the ``Group`` of all nodes
            and edges (scaled by 0.35), node id -> node object, and
            ``"parent,child"`` -> connecting line.
        """
        tree_group = Group()
        max_depth = tree.max_depth
        # make the base node
        feature_name = feature_names[tree.feature[0]]
        threshold = tree.threshold[0]
        values = tree.value[0]
        nodes_map = {}
        root_node = self._make_node(feature_name, threshold, values, depth=0)
        nodes_map[0] = root_node
        tree_group.add(root_node)
        # save some information
        node_width = root_node.width
        scale_factor = 1.0
        edge_map = {}
        # tree width: room for every node of the deepest level side by side
        tree_width = scale_factor * 2**max_depth * node_width

        # traverse tree
        def recurse(node, depth, direction, parent_object, parent_node):
            # make the node object
            # a node is a leaf when both child pointers hold the same value
            is_leaf = tree.children_left[node] == tree.children_right[node]
            feature_name = feature_names[tree.feature[node]]
            threshold = tree.threshold[node]
            values = tree.value[node]
            node_object = self._make_node(
                feature_name, threshold, values, depth=depth, is_leaf=is_leaf
            )
            nodes_map[node] = node_object
            node_height = node_object.height
            # set the node position
            direction_factor = -1 if direction == "left" else 1
            # the horizontal offset halves with every level
            shift_right_amount = (
                0.8 * direction_factor * scale_factor * tree_width / (2**depth) / 2
            )
            if is_leaf:
                shift_down_amount = -1.0 * scale_factor * node_height
            else:
                shift_down_amount = -1.8 * scale_factor * node_height
            node_object.match_x(parent_object).match_y(parent_object).shift(
                [shift_right_amount, shift_down_amount, 0]
            )
            tree_group.add(node_object)
            # make a connection
            connection = self._make_connection(
                parent_object, node_object, is_leaf=is_leaf
            )
            edge_name = str(parent_node) + "," + str(node)
            edge_map[edge_name] = connection
            tree_group.add(connection)
            # recurse
            if not is_leaf:
                recurse(tree.children_left[node], depth + 1, "left", node_object, node)
                recurse(
                    tree.children_right[node], depth + 1, "right", node_object, node
                )

        recurse(tree.children_left[0], 1, "left", root_node, 0)
        recurse(tree.children_right[0], 1, "right", root_node, 0)

        tree_group.scale(0.35)
        return tree_group, nodes_map, edge_map

    def color_example_path(
        self,
        tree_group,
        nodes_map,
        tree,
        edge_map,
        example,
        color=YELLOW,
        thickness=2,
        scene=None,
    ):
        """Animate the decision path of ``example`` through the drawn tree.

        Args:
            tree_group: Unused.
            nodes_map: Node id -> node object, as returned by ``_make_tree``.
            tree: Object exposing ``decision_path(example)``.
            edge_map: ``"parent,child"`` -> line, as returned by ``_make_tree``.
            example: Sample(s) passed to ``tree.decision_path``; only the
                first row of the result is used.
            color: Highlight color.
            thickness: Unused.
            scene: Object exposing ``play``. Defaults to ``self``, which only
                works for subclasses that also provide ``play``.
        """
        # DecisionTree itself has no play(); drive the animations through the
        # given scene when there is one.
        player = self if scene is None else scene
        # get decision path
        decision_path = tree.decision_path(example)[0]
        path_indices = decision_path.indices
        # highlight edges
        for current_val, next_val in zip(path_indices[:-1], path_indices[1:]):
            edge_str = str(current_val) + "," + str(next_val)
            edge = edge_map[edge_str]
            animation_two = AnimationGroup(
                nodes_map[current_val].animate.set_color(color)
            )
            player.play(animation_two, run_time=0.5)
            animation_one = AnimationGroup(
                edge.animate.set_color(color),
                # edge.animate.set_stroke_width(4),
            )
            player.play(animation_one, run_time=0.5)
        # surround the bottom image
        last_path_index = path_indices[-1]
        # element 0 of a leaf group is its frame rectangle
        last_path_rectangle = nodes_map[last_path_index][0]
        player.play(last_path_rectangle.animate.set_color(color))

    def create_sklearn_tree(self, max_tree_depth=1):
        """Fit a decision tree on iris and return its low-level ``tree_``."""
        # learn the decision tree
        iris = datasets.load_iris()
        # NOTE(review): learn_iris_decision_tree is presumably provided by the
        # star import from decision_tree_surface -- confirm.
        tree = learn_iris_decision_tree(iris, max_depth=max_tree_depth)
        return tree.tree_

    def make_tree(self, max_tree_depth=2):
        """Fit a tree of the given depth and return its drawn ``Group``."""
        # create_sklearn_tree already returns the low-level tree object
        sklearn_tree = self.create_sklearn_tree(max_tree_depth)
        # make the tree
        tree_group, nodes_map, edge_map = self._make_tree(sklearn_tree)
        tree_group.shift([0, 5.5, 0])
        return tree_group
        # self.add(tree_group)
        # self.play(SpinInFromNothing(tree_group), run_time=3)
        # self.color_example_path(tree_group, nodes_map, tree, edge_map, iris.data[None, 0, 0:2])
|
|
|
|
|
|
class DecisionTreeSplitScene(Scene):
    """Scene showing the iris scatter plot, sweeping split lines and a depth-1 tree."""

    def make_decision_tree_classifier(self, max_depth=4):
        """Fit and return a decision tree on the first two iris features."""
        classifier = DecisionTreeClassifier(
            random_state=1, max_depth=max_depth, max_leaf_nodes=8
        )
        # fit() hands back the fitted estimator, which is what callers receive
        return classifier.fit(iris.data[:, :2], iris.target)

    def make_split_animation(self, data, classes, data_labels, main_axes):
        """Build the two sweeping split-line animations.

        ``classes`` and ``data_labels`` are only referenced by the disabled
        entropy-plot code kept below as string literals.

        Returns:
            A 6-tuple ``(animation_one, animation_two, first_line,
            second_line, None, None)``; the two ``None`` slots stand in for
            the disabled entropy axes group and information gains.
        """

        # Disabled entropy plot, kept as a string literal.
        """
        def make_entropy_animation_and_plot(dim=0, num_entropy_values=50):
            # calculate the entropy values
            axes_group = VGroup()
            # make axes
            range_vals = [np.amin(data, axis=0)[dim], np.amax(data, axis=0)[dim]]
            axes = Axes(x_range=range_vals,
                        y_range=[0, 1.0],
                        x_length=9,
                        y_length=4,
                        # axis_config={"number_scale_value":0.75, "include_numbers":True},
                        tips=False,
                        )
            axes_group.add(axes)
            # make axis labels
            # x_label
            x_label = Text(data_labels[dim], font=font) \
                .match_y(axes.get_axes()[0]) \
                .shift([0.5, -0.75, 0]) \
                .scale(font_scale*1.2)
            axes_group.add(x_label)
            # y_label
            y_label = Text("Information Gain", font=font) \
                .match_x(axes.get_axes()[1]) \
                .shift([-0.75, 0, 0]) \
                .rotate(np.pi / 2) \
                .scale(font_scale * 1.2)

            axes_group.add(y_label)
            # line animation
            information_gains = []
            def entropy_function(split_value):
                # lower entropy
                lower_set = np.nonzero(data[:, dim] <= split_value)[0]
                lower_set = classes[lower_set]
                lower_entropy = entropy(lower_set)
                # higher entropy
                higher_set = np.nonzero(data[:, dim] > split_value)[0]
                higher_set = classes[higher_set]
                higher_entropy = entropy(higher_set)
                # calculate entropies
                all_entropy = entropy(classes, base=2)
                lower_entropy = entropy(lower_set, base=2)
                higher_entropy = entropy(higher_set, base=2)
                mean_entropy = (lower_entropy + higher_entropy) / 2
                # calculate information gain
                lower_prob = len(lower_set) / len(data[:, dim])
                higher_prob = len(higher_set) / len(data[:, dim])
                info_gain = all_entropy - (lower_prob * lower_entropy + higher_prob * higher_entropy)
                information_gains.append((split_value, info_gain))
                return info_gain

            data_range = np.amin(data[:, dim]), np.amax(data[:, dim])

            entropy_graph = axes.get_graph(
                entropy_function,
                # color=RED,
                # x_range=data_range
            )
            axes_group.add(entropy_graph)
            axes_group.shift([4.0, 2, 0])
            axes_group.scale(0.5)
            dot_animation = Dot(color=WHITE)
            axes_group.add(dot_animation)
            # make animations
            animation_group = AnimationGroup(
                Create(axes_group, run_time=2),
                Wait(3),
                MoveAlongPath(dot_animation, entropy_graph, run_time=20, rate_func=rate_functions.ease_in_out_quad),
                Wait(2)
            )

            return axes_group, animation_group, information_gains
        """

        def make_split_line_animation(dim=0):
            # make a line along one of the dims and move it up and down
            lows = np.amin(data, axis=0)
            highs = np.amax(data, axis=0)
            origin_coord = [lows[0] - 0.2, lows[1] - 0.2]
            origin_point = main_axes.coords_to_point(*origin_coord)
            top_left_coord = [origin_coord[0], highs[1]]
            bottom_right_coord = [highs[0] - 0.2, origin_coord[1]]
            # dim 0: a vertical line that travels horizontally; otherwise a
            # horizontal line that travels vertically
            sweeps_horizontally = dim == 0
            if sweeps_horizontally:
                line_end_coord, travel_end_coord = top_left_coord, bottom_right_coord
            else:
                line_end_coord, travel_end_coord = bottom_right_coord, top_left_coord
            line_end_point = main_axes.coords_to_point(*line_end_coord)
            travel_end_point = main_axes.coords_to_point(*travel_end_coord)
            moving_line = Line(origin_point, line_end_point, color=RED)
            movement_line = Line(origin_point, travel_end_point)
            # offset the travel path by half the line's extent
            if sweeps_horizontally:
                movement_line.shift([0, moving_line.height / 2, 0])
            else:
                movement_line.shift([moving_line.width / 2, 0, 0])
            # move the moving line along the movement line
            sweep = MoveAlongPath(
                moving_line,
                movement_line,
                run_time=20,
                rate_func=rate_functions.ease_in_out_quad,
            )
            return sweep, moving_line

        # plot the line in white then make it invisible
        # make an animation along the line
        # make a
        # axes_one_group, top_animation_group, info_gains = make_entropy_animation_and_plot(dim=0)
        line_movement, first_moving_line = make_split_line_animation(dim=0)
        # axes_two_group, bottom_animation_group, _ = make_entropy_animation_and_plot(dim=1)
        second_line_movement, second_moving_line = make_split_line_animation(dim=1)
        # axes_two_group.shift([0, -3, 0])
        animation_group_one = AnimationGroup(
            # top_animation_group,
            line_movement,
        )

        animation_group_two = AnimationGroup(
            # bottom_animation_group,
            second_line_movement,
        )
        """
        both_axes_group = VGroup(
            axes_one_group,
            axes_two_group
        )
        """

        return (
            animation_group_one,
            animation_group_two,
            first_moving_line,
            second_moving_line,
            None,
            None,
        )
        # both_axes_group, \
        # info_gains

    def construct(self):
        """Play the scatter plot, full surface, split sweeps and depth-1 tree."""
        # make the points
        dataset_plot = IrisDatasetPlot()
        dataset_plot.all_group.scale(1.0)
        dataset_plot.all_group.shift([-3, 0.2, 0])
        # make the entropy line graph
        # entropy_line_graph = self.draw_entropy_line_graph()
        # arrange the plots
        # do animations
        self.play(Create(dataset_plot))
        # element 0 of the axes group is the Axes object itself
        main_axes = dataset_plot.axes_group[0]
        # make the decision tree classifier
        full_classifier = self.make_decision_tree_classifier()
        full_surface = DecisionTreeSurface(full_classifier, iris.data, main_axes)
        self.play(Create(full_surface))
        self.wait(3)
        self.play(Uncreate(full_surface))
        (
            split_animation_one,
            split_animation_two,
            first_moving_line,
            second_moving_line,
            both_axes_group,
            info_gains,
        ) = self.make_split_animation(
            iris.data[:, 0:2], iris.target, iris.feature_names, main_axes
        )
        self.play(split_animation_one)
        self.wait(0.1)
        self.play(Uncreate(first_moving_line))
        self.wait(3)
        self.play(split_animation_two)
        self.wait(0.1)
        self.play(Uncreate(second_moving_line))
        self.wait(0.1)
        # highlight the maximum on top
        # sort by second key
        """
        highest_info_gain = sorted(info_gains, key=lambda x: x[1])[-1]
        highest_info_gain_point = both_axes_group[0][0].coords_to_point(*highest_info_gain)
        highlighted_peak = Dot(highest_info_gain_point, color=YELLOW)
        # get location of highest info gain point
        highest_info_gain_point_in_iris_graph = iris_dataset_plot.axes_group[0].coords_to_point(*[highest_info_gain[0], 0])
        first_moving_line.start[0] = highest_info_gain_point_in_iris_graph[0]
        first_moving_line.end[0] = highest_info_gain_point_in_iris_graph[0]
        self.play(Create(highlighted_peak))
        self.play(Create(first_moving_line))
        text = Text("Highest Information Gain")
        text.scale(0.4)
        text.move_to(highlighted_peak)
        text.shift([0, 0.5, 0])
        self.play(Create(text))
        """
        self.wait(1)
        # draw the basic tree
        stump_classifier = self.make_decision_tree_classifier(max_depth=1)
        stump_surface = DecisionTreeSurface(stump_classifier, iris.data, main_axes)
        decision_tree_graph, _, _ = DecisionTree()._make_tree(stump_classifier.tree_)
        decision_tree_graph.match_y(dataset_plot.axes_group)
        decision_tree_graph.shift([4, 0, 0])
        self.play(Create(stump_surface))
        # NOTE(review): second_moving_line was already uncreated above, so this
        # plays the removal a second time -- confirm it is intentional.
        uncreate_animation = AnimationGroup(
            # Uncreate(both_axes_group),
            # Uncreate(highlighted_peak),
            Uncreate(second_moving_line),
            # Unwrite(text)
        )
        self.play(uncreate_animation)
        self.wait(0.5)
        self.play(FadeIn(decision_tree_graph))
        # self.play(FadeIn(highlighted_peak))
        self.wait(5)
|