mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-06 08:59:43 +08:00
General changes, got basic visualization of an activation function working for a
convolutinoal layer.
This commit is contained in:
@ -16,7 +16,6 @@ from manim_ml.one_to_one_sync import OneToOneSync
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class LeafNode(Group):
|
||||
"""Leaf node in tree"""
|
||||
|
||||
@ -51,7 +50,6 @@ class LeafNode(Group):
|
||||
self.add(rectangle)
|
||||
self.add(node)
|
||||
|
||||
|
||||
class SplitNode(VGroup):
|
||||
"""Node for splitting decision in tree"""
|
||||
|
||||
@ -65,7 +63,6 @@ class SplitNode(VGroup):
|
||||
self.add(bounding_box)
|
||||
self.add(decision_text)
|
||||
|
||||
|
||||
class DecisionTreeDiagram(Group):
|
||||
"""Decision Tree Diagram Class for Manim"""
|
||||
|
||||
@ -196,246 +193,157 @@ class DecisionTreeDiagram(Group):
|
||||
def create_level_order_expansion_decision_tree(self, tree):
|
||||
"""Expands the decision tree in level order"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def create_bfs_expansion_decision_tree(self, tree):
|
||||
"""Expands the tree using BFS"""
|
||||
animations = []
|
||||
split_node_animations = {} # Dictionary mapping split node to animation
|
||||
# Compute parent mapping
|
||||
parent_mapping = helpers.compute_node_to_parent_mapping(self.tree)
|
||||
# Create the root node
|
||||
animations.append(Create(self.nodes_map[0]))
|
||||
# Create the root node as most common class
|
||||
placeholder_class_nodes = {}
|
||||
root_node_class_index = np.argmax(
|
||||
self.tree.value[0]
|
||||
)
|
||||
root_placeholder_node = LeafNode(
|
||||
class_index=root_node_class_index,
|
||||
class_colors=self.class_colors,
|
||||
class_image_paths=self.class_image_paths,
|
||||
)
|
||||
root_placeholder_node.move_to(self.nodes_map[0])
|
||||
placeholder_class_nodes[0] = root_placeholder_node
|
||||
root_create_animation = AnimationGroup(
|
||||
FadeIn(root_placeholder_node),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
animations.append(root_create_animation)
|
||||
# Iterate through the nodes
|
||||
queue = [0]
|
||||
while len(queue) > 0:
|
||||
node_index = queue.pop(0)
|
||||
# Check if a node is a split node or not
|
||||
left_child = self.tree.children_left[node_index]
|
||||
right_child = self.tree.children_right[node_index]
|
||||
is_leaf_node = left_child == right_child
|
||||
left_child_index = self.tree.children_left[node_index]
|
||||
right_child_index = self.tree.children_right[node_index]
|
||||
is_leaf_node = left_child_index == right_child_index
|
||||
if not is_leaf_node:
|
||||
# Remove the currently placeholder class node
|
||||
fade_out_animation = FadeOut(
|
||||
placeholder_class_nodes[node_index]
|
||||
)
|
||||
animations.append(fade_out_animation)
|
||||
# Fade in the split node
|
||||
fade_in_animation = FadeIn(
|
||||
self.nodes_map[node_index]
|
||||
)
|
||||
animations.append(fade_in_animation)
|
||||
# Split the node by creating the children and connecting them
|
||||
# to the parent
|
||||
# Get the nodes
|
||||
left_node = self.nodes_map[left_child]
|
||||
right_node = self.nodes_map[right_child]
|
||||
# Get the parent edges
|
||||
left_parent_edge = self.edge_map[f"{node_index},{left_child}"]
|
||||
right_parent_edge = self.edge_map[f"{node_index},{right_child}"]
|
||||
# Create the children
|
||||
# Handle left child
|
||||
assert left_child_index in self.nodes_map.keys()
|
||||
left_node = self.nodes_map[left_child_index]
|
||||
left_parent_edge = self.edge_map[f"{node_index},{left_child_index}"]
|
||||
# Get the children of the left node
|
||||
left_node_left_index = self.tree.children_left[left_child_index]
|
||||
left_node_right_index = self.tree.children_right[left_child_index]
|
||||
left_is_leaf = left_node_left_index == left_node_right_index
|
||||
if left_is_leaf:
|
||||
# If a child is a leaf then just create it
|
||||
left_animation = FadeIn(left_node)
|
||||
else:
|
||||
# If the child is a split node find the dominant class and make a temp
|
||||
left_node_class_index = np.argmax(
|
||||
self.tree.value[left_child_index]
|
||||
)
|
||||
new_leaf_node = LeafNode(
|
||||
class_index=left_node_class_index,
|
||||
class_colors=self.class_colors,
|
||||
class_image_paths=self.class_image_paths,
|
||||
)
|
||||
new_leaf_node.move_to(self.nodes_map[leaf_child_index])
|
||||
placeholder_class_nodes[left_child_index] = new_leaf_node
|
||||
left_animation = AnimationGroup(
|
||||
FadeIn(new_leaf_node),
|
||||
Create(left_parent_edge),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
# Handle right child
|
||||
assert right_child_index in self.nodes_map.keys()
|
||||
right_node = self.nodes_map[right_child_index]
|
||||
right_parent_edge = self.edge_map[f"{node_index},{right_child_index}"]
|
||||
# Get the children of the left node
|
||||
right_node_left_index = self.tree.children_left[right_child_index]
|
||||
right_node_right_index = self.tree.children_right[right_child_index]
|
||||
right_is_leaf = right_node_left_index == right_node_right_index
|
||||
if right_is_leaf:
|
||||
# If a child is a leaf then just create it
|
||||
right_animation = FadeIn(right_node)
|
||||
else:
|
||||
# If the child is a split node find the dominant class and make a temp
|
||||
right_node_class_index = np.argmax(
|
||||
self.tree.value[right_child_index]
|
||||
)
|
||||
new_leaf_node = LeafNode(
|
||||
class_index=right_node_class_index,
|
||||
class_colors=self.class_colors,
|
||||
class_image_paths=self.class_image_paths,
|
||||
)
|
||||
placeholder_class_nodes[right_child_index] = new_leaf_node
|
||||
right_animation = AnimationGroup(
|
||||
FadeIn(new_leaf_node),
|
||||
Create(right_parent_edge),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
# Combine the animations
|
||||
split_animation = AnimationGroup(
|
||||
FadeIn(left_node),
|
||||
FadeIn(right_node),
|
||||
Create(left_parent_edge),
|
||||
Create(right_parent_edge),
|
||||
left_animation,
|
||||
right_animation,
|
||||
lag_ratio=0.0,
|
||||
)
|
||||
animations.append(split_animation)
|
||||
# Add the split animation to the split node dict
|
||||
split_node_animations[node_index] = split_animation
|
||||
# Add the children to the queue
|
||||
if left_child != -1:
|
||||
queue.append(left_child)
|
||||
if right_child != -1:
|
||||
queue.append(right_child)
|
||||
if left_child_index != -1:
|
||||
queue.append(left_child_index)
|
||||
if right_child_index != -1:
|
||||
queue.append(right_child_index)
|
||||
|
||||
return AnimationGroup(*animations, lag_ratio=1.0)
|
||||
return Succession(
|
||||
*animations,
|
||||
lag_ratio=1.0
|
||||
), split_node_animations
|
||||
|
||||
def make_expand_tree_animation(self, node_expand_order):
|
||||
"""
|
||||
Make an animation for expanding the decision tree
|
||||
|
||||
Shows each split node as a leaf node initially, and
|
||||
then when it comes up shows it as a split node. The
|
||||
reason for this is for purposes of animating each of the
|
||||
splits in a decision surface.
|
||||
"""
|
||||
# Show the root node as a leaf node
|
||||
# Iterate through the nodes in the traversal order
|
||||
for node_index in node_expand_order[1:]:
|
||||
# Figure out if it is a leaf or not
|
||||
# If it is not a leaf then remove the placeholder leaf node
|
||||
# then show the split node
|
||||
# If it is a leaf then just show the leaf node
|
||||
pass
|
||||
|
||||
@override_animation(Create)
|
||||
def create_decision_tree(self, traversal_order="bfs"):
|
||||
"""Makes a create animation for the decision tree"""
|
||||
# Comptue the node expand order
|
||||
if traversal_order == "level":
|
||||
return self.create_level_order_expansion_decision_tree(self.tree)
|
||||
node_expand_order = helpers.compute_level_order_traversal(self.tree)
|
||||
elif traversal_order == "bfs":
|
||||
return self.create_bfs_expansion_decision_tree(self.tree)
|
||||
node_expand_order = helpers.compute_bfs_traversal(self.tree)
|
||||
else:
|
||||
raise Exception(f"Uncrecognized traversal: {traversal_order}")
|
||||
|
||||
|
||||
class IrisDatasetPlot(VGroup):
|
||||
def __init__(self, iris):
|
||||
points = iris.data[:, 0:2]
|
||||
labels = iris.feature_names
|
||||
targets = iris.target
|
||||
# Make points
|
||||
self.point_group = self._make_point_group(points, targets)
|
||||
# Make axes
|
||||
self.axes_group = self._make_axes_group(points, labels)
|
||||
# Make legend
|
||||
self.legend_group = self._make_legend(
|
||||
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
|
||||
)
|
||||
# Make title
|
||||
# title_text = "Iris Dataset Plot"
|
||||
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
|
||||
# Make all group
|
||||
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
|
||||
# scale the groups
|
||||
self.point_group.scale(1.6)
|
||||
self.point_group.match_x(self.axes_group)
|
||||
self.point_group.match_y(self.axes_group)
|
||||
self.point_group.shift([0.2, 0, 0])
|
||||
self.axes_group.scale(0.7)
|
||||
self.all_group.shift([0, 0.2, 0])
|
||||
|
||||
@override_animation(Create)
|
||||
def create_animation(self):
|
||||
animation_group = AnimationGroup(
|
||||
# Perform the animations
|
||||
Create(self.point_group, run_time=2),
|
||||
Wait(0.5),
|
||||
Create(self.axes_group, run_time=2),
|
||||
# add title
|
||||
# Create(self.title),
|
||||
Create(self.legend_group),
|
||||
)
|
||||
return animation_group
|
||||
|
||||
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
point_group = VGroup()
|
||||
for point_index, point in enumerate(points):
|
||||
# draw the dot
|
||||
current_target = targets[point_index]
|
||||
color = class_colors[current_target]
|
||||
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
|
||||
dot.scale(0.5)
|
||||
point_group.add(dot)
|
||||
return point_group
|
||||
|
||||
def _make_legend(self, class_colors, feature_labels, axes):
|
||||
legend_group = VGroup()
|
||||
# Make Text
|
||||
setosa = Text("Setosa", color=BLUE)
|
||||
verisicolor = Text("Verisicolor", color=ORANGE)
|
||||
virginica = Text("Virginica", color=GREEN)
|
||||
labels = VGroup(setosa, verisicolor, virginica).arrange(
|
||||
direction=RIGHT, aligned_edge=LEFT, buff=2.0
|
||||
)
|
||||
labels.scale(0.5)
|
||||
legend_group.add(labels)
|
||||
# surrounding rectangle
|
||||
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
|
||||
surrounding_rectangle.move_to(labels)
|
||||
legend_group.add(surrounding_rectangle)
|
||||
# shift the legend group
|
||||
legend_group.move_to(axes)
|
||||
legend_group.shift([0, -3.0, 0])
|
||||
legend_group.match_x(axes[0][0])
|
||||
|
||||
return legend_group
|
||||
|
||||
def _make_axes_group(self, points, labels, font="Source Han Sans", font_scale=0.75):
|
||||
axes_group = VGroup()
|
||||
# make the axes
|
||||
x_range = [
|
||||
np.amin(points, axis=0)[0] - 0.2,
|
||||
np.amax(points, axis=0)[0] - 0.2,
|
||||
0.5,
|
||||
]
|
||||
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
|
||||
axes = Axes(
|
||||
x_range=x_range,
|
||||
y_range=y_range,
|
||||
x_length=9,
|
||||
y_length=6.5,
|
||||
# axis_config={"number_scale_value":0.75, "include_numbers":True},
|
||||
tips=False,
|
||||
).shift([0.5, 0.25, 0])
|
||||
axes_group.add(axes)
|
||||
# make axis labels
|
||||
# x_label
|
||||
x_label = (
|
||||
Text(labels[0], font=font)
|
||||
.match_y(axes.get_axes()[0])
|
||||
.shift([0.5, -0.75, 0])
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(x_label)
|
||||
# y_label
|
||||
y_label = (
|
||||
Text(labels[1], font=font)
|
||||
.match_x(axes.get_axes()[1])
|
||||
.shift([-0.75, 0, 0])
|
||||
.rotate(np.pi / 2)
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(y_label)
|
||||
|
||||
return axes_group
|
||||
|
||||
|
||||
class DecisionTreeSurface(VGroup):
|
||||
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
# take the tree and construct the surface from it
|
||||
self.tree_clf = tree_clf
|
||||
self.data = data
|
||||
self.axes = axes
|
||||
self.class_colors = class_colors
|
||||
self.surface_rectangles = self.generate_surface_rectangles()
|
||||
|
||||
def generate_surface_rectangles(self):
|
||||
# compute data bounds
|
||||
left = np.amin(self.data[:, 0]) - 0.2
|
||||
right = np.amax(self.data[:, 0]) - 0.2
|
||||
top = np.amax(self.data[:, 1])
|
||||
bottom = np.amin(self.data[:, 1]) - 0.2
|
||||
maxrange = [left, right, bottom, top]
|
||||
rectangles = compute_decision_areas(
|
||||
self.tree_clf, maxrange, x=0, y=1, n_features=2
|
||||
)
|
||||
# turn the rectangle objects into manim rectangles
|
||||
def convert_rectangle_to_polygon(rect):
|
||||
# get the points for the rectangle in the plot coordinate frame
|
||||
bottom_left = [rect[0], rect[3]]
|
||||
bottom_right = [rect[1], rect[3]]
|
||||
top_right = [rect[1], rect[2]]
|
||||
top_left = [rect[0], rect[2]]
|
||||
# convert those points into the entire manim coordinates
|
||||
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
|
||||
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
|
||||
top_right_coord = self.axes.coords_to_point(*top_right)
|
||||
top_left_coord = self.axes.coords_to_point(*top_left)
|
||||
points = [
|
||||
bottom_left_coord,
|
||||
bottom_right_coord,
|
||||
top_right_coord,
|
||||
top_left_coord,
|
||||
]
|
||||
# construct a polygon object from those manim coordinates
|
||||
rectangle = Polygon(
|
||||
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
|
||||
)
|
||||
return rectangle
|
||||
|
||||
manim_rectangles = []
|
||||
for rect in rectangles:
|
||||
color = self.class_colors[int(rect[4])]
|
||||
rectangle = convert_rectangle_to_polygon(rect)
|
||||
manim_rectangles.append(rectangle)
|
||||
|
||||
manim_rectangles = merge_overlapping_polygons(
|
||||
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
|
||||
)
|
||||
|
||||
return manim_rectangles
|
||||
|
||||
@override_animation(Create)
|
||||
def create_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Create(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
@override_animation(Uncreate)
|
||||
def uncreate_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Uncreate(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
# Make the animation
|
||||
expand_tree_animation = self.make_expand_tree_animation(node_expand_order)
|
||||
return expand_tree_animation
|
||||
|
||||
class DecisionTreeContainer(OneToOneSync):
|
||||
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
||||
|
Reference in New Issue
Block a user