mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-06 00:48:56 +08:00
General changes, got basic visualization of an activation function working for a
convolutinoal layer.
This commit is contained in:
356
manim_ml/decision_tree/decision_tree_surface.py
Normal file
356
manim_ml/decision_tree/decision_tree_surface.py
Normal file
@ -0,0 +1,356 @@
|
||||
from manim import *
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from sklearn.tree import _tree as ctree
|
||||
|
||||
class AABB:
|
||||
"""Axis-aligned bounding box"""
|
||||
|
||||
def __init__(self, n_features):
|
||||
self.limits = np.array([[-np.inf, np.inf]] * n_features)
|
||||
|
||||
def split(self, f, v):
|
||||
left = AABB(self.limits.shape[0])
|
||||
right = AABB(self.limits.shape[0])
|
||||
left.limits = self.limits.copy()
|
||||
right.limits = self.limits.copy()
|
||||
left.limits[f, 1] = v
|
||||
right.limits[f, 0] = v
|
||||
|
||||
return left, right
|
||||
|
||||
def tree_bounds(tree, n_features=None):
|
||||
"""Compute final decision rule for each node in tree"""
|
||||
if n_features is None:
|
||||
n_features = np.max(tree.feature) + 1
|
||||
aabbs = [AABB(n_features) for _ in range(tree.node_count)]
|
||||
queue = deque([0])
|
||||
while queue:
|
||||
i = queue.pop()
|
||||
l = tree.children_left[i]
|
||||
r = tree.children_right[i]
|
||||
if l != ctree.TREE_LEAF:
|
||||
aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
|
||||
queue.extend([l, r])
|
||||
return aabbs
|
||||
|
||||
def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
|
||||
"""Extract decision areas.
|
||||
|
||||
tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
|
||||
maxrange: values to insert for [left, right, top, bottom] if the interval is open (+/-inf)
|
||||
x: index of the feature that goes on the x axis
|
||||
y: index of the feature that goes on the y axis
|
||||
n_features: override autodetection of number of features
|
||||
"""
|
||||
tree = tree_classifier.tree_
|
||||
aabbs = tree_bounds(tree, n_features)
|
||||
maxrange = np.array(maxrange)
|
||||
rectangles = []
|
||||
for i in range(len(aabbs)):
|
||||
if tree.children_left[i] != ctree.TREE_LEAF:
|
||||
continue
|
||||
l = aabbs[i].limits
|
||||
r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
|
||||
# clip out of bounds indices
|
||||
"""
|
||||
if r[0] < maxrange[0]:
|
||||
r[0] = maxrange[0]
|
||||
if r[1] > maxrange[1]:
|
||||
r[1] = maxrange[1]
|
||||
if r[2] < maxrange[2]:
|
||||
r[2] = maxrange[2]
|
||||
if r[3] > maxrange[3]:
|
||||
r[3] = maxrange[3]
|
||||
print(r)
|
||||
"""
|
||||
rectangles.append(r)
|
||||
rectangles = np.array(rectangles)
|
||||
rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
|
||||
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
|
||||
return rectangles
|
||||
|
||||
def plot_areas(rectangles):
|
||||
for rect in rectangles:
|
||||
color = ["b", "r"][int(rect[4])]
|
||||
print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
|
||||
rp = Rectangle(
|
||||
[rect[0], rect[2]],
|
||||
rect[1] - rect[0],
|
||||
rect[3] - rect[2],
|
||||
color=color,
|
||||
alpha=0.3,
|
||||
)
|
||||
plt.gca().add_artist(rp)
|
||||
|
||||
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
|
||||
# get all polygons of each color
|
||||
polygon_dict = {
|
||||
str(BLUE).lower(): [],
|
||||
str(GREEN).lower(): [],
|
||||
str(ORANGE).lower(): [],
|
||||
}
|
||||
for polygon in all_polygons:
|
||||
print(polygon_dict)
|
||||
polygon_dict[str(polygon.color).lower()].append(polygon)
|
||||
|
||||
return_polygons = []
|
||||
for color in colors:
|
||||
color = str(color).lower()
|
||||
polygons = polygon_dict[color]
|
||||
points = set()
|
||||
for polygon in polygons:
|
||||
vertices = polygon.get_vertices().tolist()
|
||||
vertices = [tuple(vert) for vert in vertices]
|
||||
for pt in vertices:
|
||||
if pt in points: # Shared vertice, remove it.
|
||||
points.remove(pt)
|
||||
else:
|
||||
points.add(pt)
|
||||
points = list(points)
|
||||
sort_x = sorted(points)
|
||||
sort_y = sorted(points, key=lambda x: x[1])
|
||||
|
||||
edges_h = {}
|
||||
edges_v = {}
|
||||
|
||||
i = 0
|
||||
while i < len(points):
|
||||
curr_y = sort_y[i][1]
|
||||
while i < len(points) and sort_y[i][1] == curr_y:
|
||||
edges_h[sort_y[i]] = sort_y[i + 1]
|
||||
edges_h[sort_y[i + 1]] = sort_y[i]
|
||||
i += 2
|
||||
i = 0
|
||||
while i < len(points):
|
||||
curr_x = sort_x[i][0]
|
||||
while i < len(points) and sort_x[i][0] == curr_x:
|
||||
edges_v[sort_x[i]] = sort_x[i + 1]
|
||||
edges_v[sort_x[i + 1]] = sort_x[i]
|
||||
i += 2
|
||||
|
||||
# Get all the polygons.
|
||||
while edges_h:
|
||||
# We can start with any point.
|
||||
polygon = [(edges_h.popitem()[0], 0)]
|
||||
while True:
|
||||
curr, e = polygon[-1]
|
||||
if e == 0:
|
||||
next_vertex = edges_v.pop(curr)
|
||||
polygon.append((next_vertex, 1))
|
||||
else:
|
||||
next_vertex = edges_h.pop(curr)
|
||||
polygon.append((next_vertex, 0))
|
||||
if polygon[-1] == polygon[0]:
|
||||
# Closed polygon
|
||||
polygon.pop()
|
||||
break
|
||||
# Remove implementation-markers from the polygon.
|
||||
poly = [point for point, _ in polygon]
|
||||
for vertex in poly:
|
||||
if vertex in edges_h:
|
||||
edges_h.pop(vertex)
|
||||
if vertex in edges_v:
|
||||
edges_v.pop(vertex)
|
||||
polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
|
||||
return_polygons.append(polygon)
|
||||
return return_polygons
|
||||
|
||||
class IrisDatasetPlot(VGroup):
|
||||
def __init__(self, iris):
|
||||
points = iris.data[:, 0:2]
|
||||
labels = iris.feature_names
|
||||
targets = iris.target
|
||||
# Make points
|
||||
self.point_group = self._make_point_group(points, targets)
|
||||
# Make axes
|
||||
self.axes_group = self._make_axes_group(points, labels)
|
||||
# Make legend
|
||||
self.legend_group = self._make_legend(
|
||||
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
|
||||
)
|
||||
# Make title
|
||||
# title_text = "Iris Dataset Plot"
|
||||
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
|
||||
# Make all group
|
||||
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
|
||||
# scale the groups
|
||||
self.point_group.scale(1.6)
|
||||
self.point_group.match_x(self.axes_group)
|
||||
self.point_group.match_y(self.axes_group)
|
||||
self.point_group.shift([0.2, 0, 0])
|
||||
self.axes_group.scale(0.7)
|
||||
self.all_group.shift([0, 0.2, 0])
|
||||
|
||||
@override_animation(Create)
|
||||
def create_animation(self):
|
||||
animation_group = AnimationGroup(
|
||||
# Perform the animations
|
||||
Create(self.point_group, run_time=2),
|
||||
Wait(0.5),
|
||||
Create(self.axes_group, run_time=2),
|
||||
# add title
|
||||
# Create(self.title),
|
||||
Create(self.legend_group),
|
||||
)
|
||||
return animation_group
|
||||
|
||||
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
point_group = VGroup()
|
||||
for point_index, point in enumerate(points):
|
||||
# draw the dot
|
||||
current_target = targets[point_index]
|
||||
color = class_colors[current_target]
|
||||
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
|
||||
dot.scale(0.5)
|
||||
point_group.add(dot)
|
||||
return point_group
|
||||
|
||||
def _make_legend(self, class_colors, feature_labels, axes):
|
||||
legend_group = VGroup()
|
||||
# Make Text
|
||||
setosa = Text("Setosa", color=BLUE)
|
||||
verisicolor = Text("Verisicolor", color=ORANGE)
|
||||
virginica = Text("Virginica", color=GREEN)
|
||||
labels = VGroup(setosa, verisicolor, virginica).arrange(
|
||||
direction=RIGHT, aligned_edge=LEFT, buff=2.0
|
||||
)
|
||||
labels.scale(0.5)
|
||||
legend_group.add(labels)
|
||||
# surrounding rectangle
|
||||
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
|
||||
surrounding_rectangle.move_to(labels)
|
||||
legend_group.add(surrounding_rectangle)
|
||||
# shift the legend group
|
||||
legend_group.move_to(axes)
|
||||
legend_group.shift([0, -3.0, 0])
|
||||
legend_group.match_x(axes[0][0])
|
||||
|
||||
return legend_group
|
||||
|
||||
def _make_axes_group(self, points, labels, font="Source Han Sans", font_scale=0.75):
|
||||
axes_group = VGroup()
|
||||
# make the axes
|
||||
x_range = [
|
||||
np.amin(points, axis=0)[0] - 0.2,
|
||||
np.amax(points, axis=0)[0] - 0.2,
|
||||
0.5,
|
||||
]
|
||||
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
|
||||
axes = Axes(
|
||||
x_range=x_range,
|
||||
y_range=y_range,
|
||||
x_length=9,
|
||||
y_length=6.5,
|
||||
# axis_config={"number_scale_value":0.75, "include_numbers":True},
|
||||
tips=False,
|
||||
).shift([0.5, 0.25, 0])
|
||||
axes_group.add(axes)
|
||||
# make axis labels
|
||||
# x_label
|
||||
x_label = (
|
||||
Text(labels[0], font=font)
|
||||
.match_y(axes.get_axes()[0])
|
||||
.shift([0.5, -0.75, 0])
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(x_label)
|
||||
# y_label
|
||||
y_label = (
|
||||
Text(labels[1], font=font)
|
||||
.match_x(axes.get_axes()[1])
|
||||
.shift([-0.75, 0, 0])
|
||||
.rotate(np.pi / 2)
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(y_label)
|
||||
|
||||
return axes_group
|
||||
|
||||
|
||||
class DecisionTreeSurface(VGroup):
|
||||
|
||||
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
# take the tree and construct the surface from it
|
||||
self.tree_clf = tree_clf
|
||||
self.data = data
|
||||
self.axes = axes
|
||||
self.class_colors = class_colors
|
||||
self.surface_rectangles = self.generate_surface_rectangles()
|
||||
|
||||
def generate_surface_rectangles(self):
|
||||
# compute data bounds
|
||||
left = np.amin(self.data[:, 0]) - 0.2
|
||||
right = np.amax(self.data[:, 0]) - 0.2
|
||||
top = np.amax(self.data[:, 1])
|
||||
bottom = np.amin(self.data[:, 1]) - 0.2
|
||||
maxrange = [left, right, bottom, top]
|
||||
rectangles = compute_decision_areas(
|
||||
self.tree_clf, maxrange, x=0, y=1, n_features=2
|
||||
)
|
||||
# turn the rectangle objects into manim rectangles
|
||||
def convert_rectangle_to_polygon(rect):
|
||||
# get the points for the rectangle in the plot coordinate frame
|
||||
bottom_left = [rect[0], rect[3]]
|
||||
bottom_right = [rect[1], rect[3]]
|
||||
top_right = [rect[1], rect[2]]
|
||||
top_left = [rect[0], rect[2]]
|
||||
# convert those points into the entire manim coordinates
|
||||
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
|
||||
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
|
||||
top_right_coord = self.axes.coords_to_point(*top_right)
|
||||
top_left_coord = self.axes.coords_to_point(*top_left)
|
||||
points = [
|
||||
bottom_left_coord,
|
||||
bottom_right_coord,
|
||||
top_right_coord,
|
||||
top_left_coord,
|
||||
]
|
||||
# construct a polygon object from those manim coordinates
|
||||
rectangle = Polygon(
|
||||
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
|
||||
)
|
||||
return rectangle
|
||||
|
||||
manim_rectangles = []
|
||||
for rect in rectangles:
|
||||
color = self.class_colors[int(rect[4])]
|
||||
rectangle = convert_rectangle_to_polygon(rect)
|
||||
manim_rectangles.append(rectangle)
|
||||
|
||||
manim_rectangles = merge_overlapping_polygons(
|
||||
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
|
||||
)
|
||||
|
||||
return manim_rectangles
|
||||
|
||||
@override_animation(Create)
|
||||
def create_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Create(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
@override_animation(Uncreate)
|
||||
def uncreate_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Uncreate(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
def make_split_to_animation_map(self):
|
||||
"""
|
||||
Returns a dictionary mapping a given split
|
||||
node to an animation to be played
|
||||
"""
|
||||
# Create an initial decision tree surface
|
||||
# Go through each split node
|
||||
# 1. Make a line split animation
|
||||
# 2. Create the relevant classification areas
|
||||
# and transform the old ones to them
|
Reference in New Issue
Block a user