Files
ManimML/manim_ml/decision_tree/decision_tree_surface.py

362 lines
12 KiB
Python

from manim import *
import numpy as np
from collections import deque
from sklearn.tree import _tree as ctree
class AABB:
"""Axis-aligned bounding box"""
def __init__(self, n_features):
self.limits = np.array([[-np.inf, np.inf]] * n_features)
def split(self, f, v):
left = AABB(self.limits.shape[0])
right = AABB(self.limits.shape[0])
left.limits = self.limits.copy()
right.limits = self.limits.copy()
left.limits[f, 1] = v
right.limits[f, 0] = v
return left, right
def tree_bounds(tree, n_features=None):
"""Compute final decision rule for each node in tree"""
if n_features is None:
n_features = np.max(tree.feature) + 1
aabbs = [AABB(n_features) for _ in range(tree.node_count)]
queue = deque([0])
while queue:
i = queue.pop()
l = tree.children_left[i]
r = tree.children_right[i]
if l != ctree.TREE_LEAF:
aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
queue.extend([l, r])
return aabbs
def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
"""Extract decision areas.
tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
maxrange: values to insert for [left, right, top, bottom] if the interval is open (+/-inf)
x: index of the feature that goes on the x axis
y: index of the feature that goes on the y axis
n_features: override autodetection of number of features
"""
tree = tree_classifier.tree_
aabbs = tree_bounds(tree, n_features)
maxrange = np.array(maxrange)
rectangles = []
for i in range(len(aabbs)):
if tree.children_left[i] != ctree.TREE_LEAF:
continue
l = aabbs[i].limits
r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
# clip out of bounds indices
"""
if r[0] < maxrange[0]:
r[0] = maxrange[0]
if r[1] > maxrange[1]:
r[1] = maxrange[1]
if r[2] < maxrange[2]:
r[2] = maxrange[2]
if r[3] > maxrange[3]:
r[3] = maxrange[3]
print(r)
"""
rectangles.append(r)
rectangles = np.array(rectangles)
rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
return rectangles
def plot_areas(rectangles):
for rect in rectangles:
color = ["b", "r"][int(rect[4])]
print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
rp = Rectangle(
[rect[0], rect[2]],
rect[1] - rect[0],
rect[3] - rect[2],
color=color,
alpha=0.3,
)
plt.gca().add_artist(rp)
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
# get all polygons of each color
polygon_dict = {
str(BLUE).lower(): [],
str(GREEN).lower(): [],
str(ORANGE).lower(): [],
}
for polygon in all_polygons:
print(polygon_dict)
polygon_dict[str(polygon.color).lower()].append(polygon)
return_polygons = []
for color in colors:
color = str(color).lower()
polygons = polygon_dict[color]
points = set()
for polygon in polygons:
vertices = polygon.get_vertices().tolist()
vertices = [tuple(vert) for vert in vertices]
for pt in vertices:
if pt in points: # Shared vertice, remove it.
points.remove(pt)
else:
points.add(pt)
points = list(points)
sort_x = sorted(points)
sort_y = sorted(points, key=lambda x: x[1])
edges_h = {}
edges_v = {}
i = 0
while i < len(points):
curr_y = sort_y[i][1]
while i < len(points) and sort_y[i][1] == curr_y:
edges_h[sort_y[i]] = sort_y[i + 1]
edges_h[sort_y[i + 1]] = sort_y[i]
i += 2
i = 0
while i < len(points):
curr_x = sort_x[i][0]
while i < len(points) and sort_x[i][0] == curr_x:
edges_v[sort_x[i]] = sort_x[i + 1]
edges_v[sort_x[i + 1]] = sort_x[i]
i += 2
# Get all the polygons.
while edges_h:
# We can start with any point.
polygon = [(edges_h.popitem()[0], 0)]
while True:
curr, e = polygon[-1]
if e == 0:
next_vertex = edges_v.pop(curr)
polygon.append((next_vertex, 1))
else:
next_vertex = edges_h.pop(curr)
polygon.append((next_vertex, 0))
if polygon[-1] == polygon[0]:
# Closed polygon
polygon.pop()
break
# Remove implementation-markers from the polygon.
poly = [point for point, _ in polygon]
for vertex in poly:
if vertex in edges_h:
edges_h.pop(vertex)
if vertex in edges_v:
edges_v.pop(vertex)
polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
return_polygons.append(polygon)
return return_polygons
class IrisDatasetPlot(VGroup):
def __init__(self, iris):
points = iris.data[:, 0:2]
labels = iris.feature_names
targets = iris.target
# Make points
self.point_group = self._make_point_group(points, targets)
# Make axes
self.axes_group = self._make_axes_group(points, labels)
# Make legend
self.legend_group = self._make_legend(
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
)
# Make title
# title_text = "Iris Dataset Plot"
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
# Make all group
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
# scale the groups
self.point_group.scale(1.6)
self.point_group.match_x(self.axes_group)
self.point_group.match_y(self.axes_group)
self.point_group.shift([0.2, 0, 0])
self.axes_group.scale(0.7)
self.all_group.shift([0, 0.2, 0])
@override_animation(Create)
def create_animation(self):
animation_group = AnimationGroup(
# Perform the animations
Create(self.point_group, run_time=2),
Wait(0.5),
Create(self.axes_group, run_time=2),
# add title
# Create(self.title),
Create(self.legend_group),
)
return animation_group
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
point_group = VGroup()
for point_index, point in enumerate(points):
# draw the dot
current_target = targets[point_index]
color = class_colors[current_target]
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
dot.scale(0.5)
point_group.add(dot)
return point_group
def _make_legend(self, class_colors, feature_labels, axes):
legend_group = VGroup()
# Make Text
setosa = Text("Setosa", color=BLUE)
verisicolor = Text("Verisicolor", color=ORANGE)
virginica = Text("Virginica", color=GREEN)
labels = VGroup(setosa, verisicolor, virginica).arrange(
direction=RIGHT, aligned_edge=LEFT, buff=2.0
)
labels.scale(0.5)
legend_group.add(labels)
# surrounding rectangle
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
surrounding_rectangle.move_to(labels)
legend_group.add(surrounding_rectangle)
# shift the legend group
legend_group.move_to(axes)
legend_group.shift([0, -3.0, 0])
legend_group.match_x(axes[0][0])
return legend_group
def _make_axes_group(self, points, labels, font="Source Han Sans", font_scale=0.75):
axes_group = VGroup()
# make the axes
x_range = [
np.amin(points, axis=0)[0] - 0.2,
np.amax(points, axis=0)[0] - 0.2,
0.5,
]
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
axes = Axes(
x_range=x_range,
y_range=y_range,
x_length=9,
y_length=6.5,
# axis_config={"number_scale_value":0.75, "include_numbers":True},
tips=False,
).shift([0.5, 0.25, 0])
axes_group.add(axes)
# make axis labels
# x_label
x_label = (
Text(labels[0], font=font)
.match_y(axes.get_axes()[0])
.shift([0.5, -0.75, 0])
.scale(font_scale)
)
axes_group.add(x_label)
# y_label
y_label = (
Text(labels[1], font=font)
.match_x(axes.get_axes()[1])
.shift([-0.75, 0, 0])
.rotate(np.pi / 2)
.scale(font_scale)
)
axes_group.add(y_label)
return axes_group
class DecisionTreeSurface(VGroup):
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
# take the tree and construct the surface from it
self.tree_clf = tree_clf
self.data = data
self.axes = axes
self.class_colors = class_colors
self.surface_rectangles = self.generate_surface_rectangles()
def generate_surface_rectangles(self):
# compute data bounds
left = np.amin(self.data[:, 0]) - 0.2
right = np.amax(self.data[:, 0]) - 0.2
top = np.amax(self.data[:, 1])
bottom = np.amin(self.data[:, 1]) - 0.2
maxrange = [left, right, bottom, top]
rectangles = compute_decision_areas(
self.tree_clf, maxrange, x=0, y=1, n_features=2
)
# turn the rectangle objects into manim rectangles
def convert_rectangle_to_polygon(rect):
# get the points for the rectangle in the plot coordinate frame
bottom_left = [rect[0], rect[3]]
bottom_right = [rect[1], rect[3]]
top_right = [rect[1], rect[2]]
top_left = [rect[0], rect[2]]
# convert those points into the entire manim coordinates
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
top_right_coord = self.axes.coords_to_point(*top_right)
top_left_coord = self.axes.coords_to_point(*top_left)
points = [
bottom_left_coord,
bottom_right_coord,
top_right_coord,
top_left_coord,
]
# construct a polygon object from those manim coordinates
rectangle = Polygon(
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
)
return rectangle
manim_rectangles = []
for rect in rectangles:
color = self.class_colors[int(rect[4])]
rectangle = convert_rectangle_to_polygon(rect)
manim_rectangles.append(rectangle)
manim_rectangles = merge_overlapping_polygons(
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
)
return manim_rectangles
@override_animation(Create)
def create_override(self):
# play a reveal of all of the surface rectangles
animations = []
for rectangle in self.surface_rectangles:
animations.append(Create(rectangle))
animation_group = AnimationGroup(*animations)
return animation_group
@override_animation(Uncreate)
def uncreate_override(self):
# play a reveal of all of the surface rectangles
animations = []
for rectangle in self.surface_rectangles:
animations.append(Uncreate(rectangle))
animation_group = AnimationGroup(*animations)
return animation_group
def make_split_to_animation_map(self):
"""
Returns a dictionary mapping a given split
node to an animation to be played
"""
# Create an initial decision tree surface
# Go through each split node
# 1. Make a line split animation
# 2. Create the relevant classification areas
# and transform the old ones to them