mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-07-04 07:15:49 +08:00
General changes, got basic visualization of an activation function working for a
convolutinoal layer.
This commit is contained in:
@ -56,7 +56,7 @@ class CombinedScene(ThreeDScene):
|
||||
ImageLayer(numpy_image, height=1.5),
|
||||
Convolutional2DLayer(1, 7, 3, filter_spacing=0.32),
|
||||
Convolutional2DLayer(3, 5, 3, filter_spacing=0.32),
|
||||
Convolutional2DLayer(5, 3, 1, filter_spacing=0.18),
|
||||
Convolutional2DLayer(5, 3, 3, filter_spacing=0.18),
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(3),
|
||||
],
|
||||
@ -66,12 +66,13 @@ class CombinedScene(ThreeDScene):
|
||||
nn.move_to(ORIGIN)
|
||||
self.add(nn)
|
||||
# Make code snippet
|
||||
code = make_code_snippet()
|
||||
code.next_to(nn, DOWN)
|
||||
self.add(code)
|
||||
# code = make_code_snippet()
|
||||
# code.next_to(nn, DOWN)
|
||||
# self.add(code)
|
||||
# Group it all
|
||||
group = Group(nn, code)
|
||||
group.move_to(ORIGIN)
|
||||
# group = Group(nn, code)
|
||||
# group.move_to(ORIGIN)
|
||||
nn.move_to(ORIGIN)
|
||||
# Play animation
|
||||
forward_pass = nn.make_forward_pass_animation(
|
||||
corner_pulses=False, all_filters_at_once=False
|
||||
|
@ -15,7 +15,6 @@ config.frame_height = 7.0
|
||||
config.frame_width = 7.0
|
||||
ROOT_DIR = Path(__file__).parents[2]
|
||||
|
||||
|
||||
def make_code_snippet():
|
||||
code_str = """
|
||||
# Make nn
|
||||
|
@ -1,162 +0,0 @@
|
||||
from manim import *
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from sklearn.tree import _tree as ctree
|
||||
|
||||
|
||||
class AABB:
|
||||
"""Axis-aligned bounding box"""
|
||||
|
||||
def __init__(self, n_features):
|
||||
self.limits = np.array([[-np.inf, np.inf]] * n_features)
|
||||
|
||||
def split(self, f, v):
|
||||
left = AABB(self.limits.shape[0])
|
||||
right = AABB(self.limits.shape[0])
|
||||
left.limits = self.limits.copy()
|
||||
right.limits = self.limits.copy()
|
||||
left.limits[f, 1] = v
|
||||
right.limits[f, 0] = v
|
||||
|
||||
return left, right
|
||||
|
||||
|
||||
def tree_bounds(tree, n_features=None):
|
||||
"""Compute final decision rule for each node in tree"""
|
||||
if n_features is None:
|
||||
n_features = np.max(tree.feature) + 1
|
||||
aabbs = [AABB(n_features) for _ in range(tree.node_count)]
|
||||
queue = deque([0])
|
||||
while queue:
|
||||
i = queue.pop()
|
||||
l = tree.children_left[i]
|
||||
r = tree.children_right[i]
|
||||
if l != ctree.TREE_LEAF:
|
||||
aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
|
||||
queue.extend([l, r])
|
||||
return aabbs
|
||||
|
||||
|
||||
def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
|
||||
"""Extract decision areas.
|
||||
|
||||
tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
|
||||
maxrange: values to insert for [left, right, top, bottom] if the interval is open (+/-inf)
|
||||
x: index of the feature that goes on the x axis
|
||||
y: index of the feature that goes on the y axis
|
||||
n_features: override autodetection of number of features
|
||||
"""
|
||||
tree = tree_classifier.tree_
|
||||
aabbs = tree_bounds(tree, n_features)
|
||||
maxrange = np.array(maxrange)
|
||||
rectangles = []
|
||||
for i in range(len(aabbs)):
|
||||
if tree.children_left[i] != ctree.TREE_LEAF:
|
||||
continue
|
||||
l = aabbs[i].limits
|
||||
r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
|
||||
# clip out of bounds indices
|
||||
"""
|
||||
if r[0] < maxrange[0]:
|
||||
r[0] = maxrange[0]
|
||||
if r[1] > maxrange[1]:
|
||||
r[1] = maxrange[1]
|
||||
if r[2] < maxrange[2]:
|
||||
r[2] = maxrange[2]
|
||||
if r[3] > maxrange[3]:
|
||||
r[3] = maxrange[3]
|
||||
print(r)
|
||||
"""
|
||||
rectangles.append(r)
|
||||
rectangles = np.array(rectangles)
|
||||
rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
|
||||
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
|
||||
return rectangles
|
||||
|
||||
|
||||
def plot_areas(rectangles):
|
||||
for rect in rectangles:
|
||||
color = ["b", "r"][int(rect[4])]
|
||||
print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
|
||||
rp = Rectangle(
|
||||
[rect[0], rect[2]],
|
||||
rect[1] - rect[0],
|
||||
rect[3] - rect[2],
|
||||
color=color,
|
||||
alpha=0.3,
|
||||
)
|
||||
plt.gca().add_artist(rp)
|
||||
|
||||
|
||||
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
|
||||
# get all polygons of each color
|
||||
polygon_dict = {
|
||||
str(BLUE).lower(): [],
|
||||
str(GREEN).lower(): [],
|
||||
str(ORANGE).lower(): [],
|
||||
}
|
||||
for polygon in all_polygons:
|
||||
print(polygon_dict)
|
||||
polygon_dict[str(polygon.color).lower()].append(polygon)
|
||||
|
||||
return_polygons = []
|
||||
for color in colors:
|
||||
color = str(color).lower()
|
||||
polygons = polygon_dict[color]
|
||||
points = set()
|
||||
for polygon in polygons:
|
||||
vertices = polygon.get_vertices().tolist()
|
||||
vertices = [tuple(vert) for vert in vertices]
|
||||
for pt in vertices:
|
||||
if pt in points: # Shared vertice, remove it.
|
||||
points.remove(pt)
|
||||
else:
|
||||
points.add(pt)
|
||||
points = list(points)
|
||||
sort_x = sorted(points)
|
||||
sort_y = sorted(points, key=lambda x: x[1])
|
||||
|
||||
edges_h = {}
|
||||
edges_v = {}
|
||||
|
||||
i = 0
|
||||
while i < len(points):
|
||||
curr_y = sort_y[i][1]
|
||||
while i < len(points) and sort_y[i][1] == curr_y:
|
||||
edges_h[sort_y[i]] = sort_y[i + 1]
|
||||
edges_h[sort_y[i + 1]] = sort_y[i]
|
||||
i += 2
|
||||
i = 0
|
||||
while i < len(points):
|
||||
curr_x = sort_x[i][0]
|
||||
while i < len(points) and sort_x[i][0] == curr_x:
|
||||
edges_v[sort_x[i]] = sort_x[i + 1]
|
||||
edges_v[sort_x[i + 1]] = sort_x[i]
|
||||
i += 2
|
||||
|
||||
# Get all the polygons.
|
||||
while edges_h:
|
||||
# We can start with any point.
|
||||
polygon = [(edges_h.popitem()[0], 0)]
|
||||
while True:
|
||||
curr, e = polygon[-1]
|
||||
if e == 0:
|
||||
next_vertex = edges_v.pop(curr)
|
||||
polygon.append((next_vertex, 1))
|
||||
else:
|
||||
next_vertex = edges_h.pop(curr)
|
||||
polygon.append((next_vertex, 0))
|
||||
if polygon[-1] == polygon[0]:
|
||||
# Closed polygon
|
||||
polygon.pop()
|
||||
break
|
||||
# Remove implementation-markers from the polygon.
|
||||
poly = [point for point, _ in polygon]
|
||||
for vertex in poly:
|
||||
if vertex in edges_h:
|
||||
edges_h.pop(vertex)
|
||||
if vertex in edges_v:
|
||||
edges_v.pop(vertex)
|
||||
polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
|
||||
return_polygons.append(polygon)
|
||||
return return_polygons
|
@ -16,7 +16,6 @@ from manim_ml.one_to_one_sync import OneToOneSync
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class LeafNode(Group):
|
||||
"""Leaf node in tree"""
|
||||
|
||||
@ -51,7 +50,6 @@ class LeafNode(Group):
|
||||
self.add(rectangle)
|
||||
self.add(node)
|
||||
|
||||
|
||||
class SplitNode(VGroup):
|
||||
"""Node for splitting decision in tree"""
|
||||
|
||||
@ -65,7 +63,6 @@ class SplitNode(VGroup):
|
||||
self.add(bounding_box)
|
||||
self.add(decision_text)
|
||||
|
||||
|
||||
class DecisionTreeDiagram(Group):
|
||||
"""Decision Tree Diagram Class for Manim"""
|
||||
|
||||
@ -196,246 +193,157 @@ class DecisionTreeDiagram(Group):
|
||||
def create_level_order_expansion_decision_tree(self, tree):
|
||||
"""Expands the decision tree in level order"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def create_bfs_expansion_decision_tree(self, tree):
|
||||
"""Expands the tree using BFS"""
|
||||
animations = []
|
||||
split_node_animations = {} # Dictionary mapping split node to animation
|
||||
# Compute parent mapping
|
||||
parent_mapping = helpers.compute_node_to_parent_mapping(self.tree)
|
||||
# Create the root node
|
||||
animations.append(Create(self.nodes_map[0]))
|
||||
# Create the root node as most common class
|
||||
placeholder_class_nodes = {}
|
||||
root_node_class_index = np.argmax(
|
||||
self.tree.value[0]
|
||||
)
|
||||
root_placeholder_node = LeafNode(
|
||||
class_index=root_node_class_index,
|
||||
class_colors=self.class_colors,
|
||||
class_image_paths=self.class_image_paths,
|
||||
)
|
||||
root_placeholder_node.move_to(self.nodes_map[0])
|
||||
placeholder_class_nodes[0] = root_placeholder_node
|
||||
root_create_animation = AnimationGroup(
|
||||
FadeIn(root_placeholder_node),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
animations.append(root_create_animation)
|
||||
# Iterate through the nodes
|
||||
queue = [0]
|
||||
while len(queue) > 0:
|
||||
node_index = queue.pop(0)
|
||||
# Check if a node is a split node or not
|
||||
left_child = self.tree.children_left[node_index]
|
||||
right_child = self.tree.children_right[node_index]
|
||||
is_leaf_node = left_child == right_child
|
||||
left_child_index = self.tree.children_left[node_index]
|
||||
right_child_index = self.tree.children_right[node_index]
|
||||
is_leaf_node = left_child_index == right_child_index
|
||||
if not is_leaf_node:
|
||||
# Remove the currently placeholder class node
|
||||
fade_out_animation = FadeOut(
|
||||
placeholder_class_nodes[node_index]
|
||||
)
|
||||
animations.append(fade_out_animation)
|
||||
# Fade in the split node
|
||||
fade_in_animation = FadeIn(
|
||||
self.nodes_map[node_index]
|
||||
)
|
||||
animations.append(fade_in_animation)
|
||||
# Split the node by creating the children and connecting them
|
||||
# to the parent
|
||||
# Get the nodes
|
||||
left_node = self.nodes_map[left_child]
|
||||
right_node = self.nodes_map[right_child]
|
||||
# Get the parent edges
|
||||
left_parent_edge = self.edge_map[f"{node_index},{left_child}"]
|
||||
right_parent_edge = self.edge_map[f"{node_index},{right_child}"]
|
||||
# Create the children
|
||||
# Handle left child
|
||||
assert left_child_index in self.nodes_map.keys()
|
||||
left_node = self.nodes_map[left_child_index]
|
||||
left_parent_edge = self.edge_map[f"{node_index},{left_child_index}"]
|
||||
# Get the children of the left node
|
||||
left_node_left_index = self.tree.children_left[left_child_index]
|
||||
left_node_right_index = self.tree.children_right[left_child_index]
|
||||
left_is_leaf = left_node_left_index == left_node_right_index
|
||||
if left_is_leaf:
|
||||
# If a child is a leaf then just create it
|
||||
left_animation = FadeIn(left_node)
|
||||
else:
|
||||
# If the child is a split node find the dominant class and make a temp
|
||||
left_node_class_index = np.argmax(
|
||||
self.tree.value[left_child_index]
|
||||
)
|
||||
new_leaf_node = LeafNode(
|
||||
class_index=left_node_class_index,
|
||||
class_colors=self.class_colors,
|
||||
class_image_paths=self.class_image_paths,
|
||||
)
|
||||
new_leaf_node.move_to(self.nodes_map[leaf_child_index])
|
||||
placeholder_class_nodes[left_child_index] = new_leaf_node
|
||||
left_animation = AnimationGroup(
|
||||
FadeIn(new_leaf_node),
|
||||
Create(left_parent_edge),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
# Handle right child
|
||||
assert right_child_index in self.nodes_map.keys()
|
||||
right_node = self.nodes_map[right_child_index]
|
||||
right_parent_edge = self.edge_map[f"{node_index},{right_child_index}"]
|
||||
# Get the children of the left node
|
||||
right_node_left_index = self.tree.children_left[right_child_index]
|
||||
right_node_right_index = self.tree.children_right[right_child_index]
|
||||
right_is_leaf = right_node_left_index == right_node_right_index
|
||||
if right_is_leaf:
|
||||
# If a child is a leaf then just create it
|
||||
right_animation = FadeIn(right_node)
|
||||
else:
|
||||
# If the child is a split node find the dominant class and make a temp
|
||||
right_node_class_index = np.argmax(
|
||||
self.tree.value[right_child_index]
|
||||
)
|
||||
new_leaf_node = LeafNode(
|
||||
class_index=right_node_class_index,
|
||||
class_colors=self.class_colors,
|
||||
class_image_paths=self.class_image_paths,
|
||||
)
|
||||
placeholder_class_nodes[right_child_index] = new_leaf_node
|
||||
right_animation = AnimationGroup(
|
||||
FadeIn(new_leaf_node),
|
||||
Create(right_parent_edge),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
# Combine the animations
|
||||
split_animation = AnimationGroup(
|
||||
FadeIn(left_node),
|
||||
FadeIn(right_node),
|
||||
Create(left_parent_edge),
|
||||
Create(right_parent_edge),
|
||||
left_animation,
|
||||
right_animation,
|
||||
lag_ratio=0.0,
|
||||
)
|
||||
animations.append(split_animation)
|
||||
# Add the split animation to the split node dict
|
||||
split_node_animations[node_index] = split_animation
|
||||
# Add the children to the queue
|
||||
if left_child != -1:
|
||||
queue.append(left_child)
|
||||
if right_child != -1:
|
||||
queue.append(right_child)
|
||||
if left_child_index != -1:
|
||||
queue.append(left_child_index)
|
||||
if right_child_index != -1:
|
||||
queue.append(right_child_index)
|
||||
|
||||
return AnimationGroup(*animations, lag_ratio=1.0)
|
||||
return Succession(
|
||||
*animations,
|
||||
lag_ratio=1.0
|
||||
), split_node_animations
|
||||
|
||||
def make_expand_tree_animation(self, node_expand_order):
|
||||
"""
|
||||
Make an animation for expanding the decision tree
|
||||
|
||||
Shows each split node as a leaf node initially, and
|
||||
then when it comes up shows it as a split node. The
|
||||
reason for this is for purposes of animating each of the
|
||||
splits in a decision surface.
|
||||
"""
|
||||
# Show the root node as a leaf node
|
||||
# Iterate through the nodes in the traversal order
|
||||
for node_index in node_expand_order[1:]:
|
||||
# Figure out if it is a leaf or not
|
||||
# If it is not a leaf then remove the placeholder leaf node
|
||||
# then show the split node
|
||||
# If it is a leaf then just show the leaf node
|
||||
pass
|
||||
|
||||
@override_animation(Create)
|
||||
def create_decision_tree(self, traversal_order="bfs"):
|
||||
"""Makes a create animation for the decision tree"""
|
||||
# Comptue the node expand order
|
||||
if traversal_order == "level":
|
||||
return self.create_level_order_expansion_decision_tree(self.tree)
|
||||
node_expand_order = helpers.compute_level_order_traversal(self.tree)
|
||||
elif traversal_order == "bfs":
|
||||
return self.create_bfs_expansion_decision_tree(self.tree)
|
||||
node_expand_order = helpers.compute_bfs_traversal(self.tree)
|
||||
else:
|
||||
raise Exception(f"Uncrecognized traversal: {traversal_order}")
|
||||
|
||||
|
||||
class IrisDatasetPlot(VGroup):
|
||||
def __init__(self, iris):
|
||||
points = iris.data[:, 0:2]
|
||||
labels = iris.feature_names
|
||||
targets = iris.target
|
||||
# Make points
|
||||
self.point_group = self._make_point_group(points, targets)
|
||||
# Make axes
|
||||
self.axes_group = self._make_axes_group(points, labels)
|
||||
# Make legend
|
||||
self.legend_group = self._make_legend(
|
||||
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
|
||||
)
|
||||
# Make title
|
||||
# title_text = "Iris Dataset Plot"
|
||||
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
|
||||
# Make all group
|
||||
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
|
||||
# scale the groups
|
||||
self.point_group.scale(1.6)
|
||||
self.point_group.match_x(self.axes_group)
|
||||
self.point_group.match_y(self.axes_group)
|
||||
self.point_group.shift([0.2, 0, 0])
|
||||
self.axes_group.scale(0.7)
|
||||
self.all_group.shift([0, 0.2, 0])
|
||||
|
||||
@override_animation(Create)
|
||||
def create_animation(self):
|
||||
animation_group = AnimationGroup(
|
||||
# Perform the animations
|
||||
Create(self.point_group, run_time=2),
|
||||
Wait(0.5),
|
||||
Create(self.axes_group, run_time=2),
|
||||
# add title
|
||||
# Create(self.title),
|
||||
Create(self.legend_group),
|
||||
)
|
||||
return animation_group
|
||||
|
||||
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
point_group = VGroup()
|
||||
for point_index, point in enumerate(points):
|
||||
# draw the dot
|
||||
current_target = targets[point_index]
|
||||
color = class_colors[current_target]
|
||||
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
|
||||
dot.scale(0.5)
|
||||
point_group.add(dot)
|
||||
return point_group
|
||||
|
||||
def _make_legend(self, class_colors, feature_labels, axes):
|
||||
legend_group = VGroup()
|
||||
# Make Text
|
||||
setosa = Text("Setosa", color=BLUE)
|
||||
verisicolor = Text("Verisicolor", color=ORANGE)
|
||||
virginica = Text("Virginica", color=GREEN)
|
||||
labels = VGroup(setosa, verisicolor, virginica).arrange(
|
||||
direction=RIGHT, aligned_edge=LEFT, buff=2.0
|
||||
)
|
||||
labels.scale(0.5)
|
||||
legend_group.add(labels)
|
||||
# surrounding rectangle
|
||||
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
|
||||
surrounding_rectangle.move_to(labels)
|
||||
legend_group.add(surrounding_rectangle)
|
||||
# shift the legend group
|
||||
legend_group.move_to(axes)
|
||||
legend_group.shift([0, -3.0, 0])
|
||||
legend_group.match_x(axes[0][0])
|
||||
|
||||
return legend_group
|
||||
|
||||
def _make_axes_group(self, points, labels, font="Source Han Sans", font_scale=0.75):
|
||||
axes_group = VGroup()
|
||||
# make the axes
|
||||
x_range = [
|
||||
np.amin(points, axis=0)[0] - 0.2,
|
||||
np.amax(points, axis=0)[0] - 0.2,
|
||||
0.5,
|
||||
]
|
||||
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
|
||||
axes = Axes(
|
||||
x_range=x_range,
|
||||
y_range=y_range,
|
||||
x_length=9,
|
||||
y_length=6.5,
|
||||
# axis_config={"number_scale_value":0.75, "include_numbers":True},
|
||||
tips=False,
|
||||
).shift([0.5, 0.25, 0])
|
||||
axes_group.add(axes)
|
||||
# make axis labels
|
||||
# x_label
|
||||
x_label = (
|
||||
Text(labels[0], font=font)
|
||||
.match_y(axes.get_axes()[0])
|
||||
.shift([0.5, -0.75, 0])
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(x_label)
|
||||
# y_label
|
||||
y_label = (
|
||||
Text(labels[1], font=font)
|
||||
.match_x(axes.get_axes()[1])
|
||||
.shift([-0.75, 0, 0])
|
||||
.rotate(np.pi / 2)
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(y_label)
|
||||
|
||||
return axes_group
|
||||
|
||||
|
||||
class DecisionTreeSurface(VGroup):
|
||||
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
# take the tree and construct the surface from it
|
||||
self.tree_clf = tree_clf
|
||||
self.data = data
|
||||
self.axes = axes
|
||||
self.class_colors = class_colors
|
||||
self.surface_rectangles = self.generate_surface_rectangles()
|
||||
|
||||
def generate_surface_rectangles(self):
|
||||
# compute data bounds
|
||||
left = np.amin(self.data[:, 0]) - 0.2
|
||||
right = np.amax(self.data[:, 0]) - 0.2
|
||||
top = np.amax(self.data[:, 1])
|
||||
bottom = np.amin(self.data[:, 1]) - 0.2
|
||||
maxrange = [left, right, bottom, top]
|
||||
rectangles = compute_decision_areas(
|
||||
self.tree_clf, maxrange, x=0, y=1, n_features=2
|
||||
)
|
||||
# turn the rectangle objects into manim rectangles
|
||||
def convert_rectangle_to_polygon(rect):
|
||||
# get the points for the rectangle in the plot coordinate frame
|
||||
bottom_left = [rect[0], rect[3]]
|
||||
bottom_right = [rect[1], rect[3]]
|
||||
top_right = [rect[1], rect[2]]
|
||||
top_left = [rect[0], rect[2]]
|
||||
# convert those points into the entire manim coordinates
|
||||
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
|
||||
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
|
||||
top_right_coord = self.axes.coords_to_point(*top_right)
|
||||
top_left_coord = self.axes.coords_to_point(*top_left)
|
||||
points = [
|
||||
bottom_left_coord,
|
||||
bottom_right_coord,
|
||||
top_right_coord,
|
||||
top_left_coord,
|
||||
]
|
||||
# construct a polygon object from those manim coordinates
|
||||
rectangle = Polygon(
|
||||
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
|
||||
)
|
||||
return rectangle
|
||||
|
||||
manim_rectangles = []
|
||||
for rect in rectangles:
|
||||
color = self.class_colors[int(rect[4])]
|
||||
rectangle = convert_rectangle_to_polygon(rect)
|
||||
manim_rectangles.append(rectangle)
|
||||
|
||||
manim_rectangles = merge_overlapping_polygons(
|
||||
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
|
||||
)
|
||||
|
||||
return manim_rectangles
|
||||
|
||||
@override_animation(Create)
|
||||
def create_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Create(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
@override_animation(Uncreate)
|
||||
def uncreate_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Uncreate(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
# Make the animation
|
||||
expand_tree_animation = self.make_expand_tree_animation(node_expand_order)
|
||||
return expand_tree_animation
|
||||
|
||||
class DecisionTreeContainer(OneToOneSync):
|
||||
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
||||
|
356
manim_ml/decision_tree/decision_tree_surface.py
Normal file
356
manim_ml/decision_tree/decision_tree_surface.py
Normal file
@ -0,0 +1,356 @@
|
||||
from manim import *
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from sklearn.tree import _tree as ctree
|
||||
|
||||
class AABB:
|
||||
"""Axis-aligned bounding box"""
|
||||
|
||||
def __init__(self, n_features):
|
||||
self.limits = np.array([[-np.inf, np.inf]] * n_features)
|
||||
|
||||
def split(self, f, v):
|
||||
left = AABB(self.limits.shape[0])
|
||||
right = AABB(self.limits.shape[0])
|
||||
left.limits = self.limits.copy()
|
||||
right.limits = self.limits.copy()
|
||||
left.limits[f, 1] = v
|
||||
right.limits[f, 0] = v
|
||||
|
||||
return left, right
|
||||
|
||||
def tree_bounds(tree, n_features=None):
|
||||
"""Compute final decision rule for each node in tree"""
|
||||
if n_features is None:
|
||||
n_features = np.max(tree.feature) + 1
|
||||
aabbs = [AABB(n_features) for _ in range(tree.node_count)]
|
||||
queue = deque([0])
|
||||
while queue:
|
||||
i = queue.pop()
|
||||
l = tree.children_left[i]
|
||||
r = tree.children_right[i]
|
||||
if l != ctree.TREE_LEAF:
|
||||
aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
|
||||
queue.extend([l, r])
|
||||
return aabbs
|
||||
|
||||
def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
|
||||
"""Extract decision areas.
|
||||
|
||||
tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
|
||||
maxrange: values to insert for [left, right, top, bottom] if the interval is open (+/-inf)
|
||||
x: index of the feature that goes on the x axis
|
||||
y: index of the feature that goes on the y axis
|
||||
n_features: override autodetection of number of features
|
||||
"""
|
||||
tree = tree_classifier.tree_
|
||||
aabbs = tree_bounds(tree, n_features)
|
||||
maxrange = np.array(maxrange)
|
||||
rectangles = []
|
||||
for i in range(len(aabbs)):
|
||||
if tree.children_left[i] != ctree.TREE_LEAF:
|
||||
continue
|
||||
l = aabbs[i].limits
|
||||
r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
|
||||
# clip out of bounds indices
|
||||
"""
|
||||
if r[0] < maxrange[0]:
|
||||
r[0] = maxrange[0]
|
||||
if r[1] > maxrange[1]:
|
||||
r[1] = maxrange[1]
|
||||
if r[2] < maxrange[2]:
|
||||
r[2] = maxrange[2]
|
||||
if r[3] > maxrange[3]:
|
||||
r[3] = maxrange[3]
|
||||
print(r)
|
||||
"""
|
||||
rectangles.append(r)
|
||||
rectangles = np.array(rectangles)
|
||||
rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
|
||||
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
|
||||
return rectangles
|
||||
|
||||
def plot_areas(rectangles):
|
||||
for rect in rectangles:
|
||||
color = ["b", "r"][int(rect[4])]
|
||||
print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
|
||||
rp = Rectangle(
|
||||
[rect[0], rect[2]],
|
||||
rect[1] - rect[0],
|
||||
rect[3] - rect[2],
|
||||
color=color,
|
||||
alpha=0.3,
|
||||
)
|
||||
plt.gca().add_artist(rp)
|
||||
|
||||
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
|
||||
# get all polygons of each color
|
||||
polygon_dict = {
|
||||
str(BLUE).lower(): [],
|
||||
str(GREEN).lower(): [],
|
||||
str(ORANGE).lower(): [],
|
||||
}
|
||||
for polygon in all_polygons:
|
||||
print(polygon_dict)
|
||||
polygon_dict[str(polygon.color).lower()].append(polygon)
|
||||
|
||||
return_polygons = []
|
||||
for color in colors:
|
||||
color = str(color).lower()
|
||||
polygons = polygon_dict[color]
|
||||
points = set()
|
||||
for polygon in polygons:
|
||||
vertices = polygon.get_vertices().tolist()
|
||||
vertices = [tuple(vert) for vert in vertices]
|
||||
for pt in vertices:
|
||||
if pt in points: # Shared vertice, remove it.
|
||||
points.remove(pt)
|
||||
else:
|
||||
points.add(pt)
|
||||
points = list(points)
|
||||
sort_x = sorted(points)
|
||||
sort_y = sorted(points, key=lambda x: x[1])
|
||||
|
||||
edges_h = {}
|
||||
edges_v = {}
|
||||
|
||||
i = 0
|
||||
while i < len(points):
|
||||
curr_y = sort_y[i][1]
|
||||
while i < len(points) and sort_y[i][1] == curr_y:
|
||||
edges_h[sort_y[i]] = sort_y[i + 1]
|
||||
edges_h[sort_y[i + 1]] = sort_y[i]
|
||||
i += 2
|
||||
i = 0
|
||||
while i < len(points):
|
||||
curr_x = sort_x[i][0]
|
||||
while i < len(points) and sort_x[i][0] == curr_x:
|
||||
edges_v[sort_x[i]] = sort_x[i + 1]
|
||||
edges_v[sort_x[i + 1]] = sort_x[i]
|
||||
i += 2
|
||||
|
||||
# Get all the polygons.
|
||||
while edges_h:
|
||||
# We can start with any point.
|
||||
polygon = [(edges_h.popitem()[0], 0)]
|
||||
while True:
|
||||
curr, e = polygon[-1]
|
||||
if e == 0:
|
||||
next_vertex = edges_v.pop(curr)
|
||||
polygon.append((next_vertex, 1))
|
||||
else:
|
||||
next_vertex = edges_h.pop(curr)
|
||||
polygon.append((next_vertex, 0))
|
||||
if polygon[-1] == polygon[0]:
|
||||
# Closed polygon
|
||||
polygon.pop()
|
||||
break
|
||||
# Remove implementation-markers from the polygon.
|
||||
poly = [point for point, _ in polygon]
|
||||
for vertex in poly:
|
||||
if vertex in edges_h:
|
||||
edges_h.pop(vertex)
|
||||
if vertex in edges_v:
|
||||
edges_v.pop(vertex)
|
||||
polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
|
||||
return_polygons.append(polygon)
|
||||
return return_polygons
|
||||
|
||||
class IrisDatasetPlot(VGroup):
|
||||
def __init__(self, iris):
|
||||
points = iris.data[:, 0:2]
|
||||
labels = iris.feature_names
|
||||
targets = iris.target
|
||||
# Make points
|
||||
self.point_group = self._make_point_group(points, targets)
|
||||
# Make axes
|
||||
self.axes_group = self._make_axes_group(points, labels)
|
||||
# Make legend
|
||||
self.legend_group = self._make_legend(
|
||||
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
|
||||
)
|
||||
# Make title
|
||||
# title_text = "Iris Dataset Plot"
|
||||
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
|
||||
# Make all group
|
||||
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
|
||||
# scale the groups
|
||||
self.point_group.scale(1.6)
|
||||
self.point_group.match_x(self.axes_group)
|
||||
self.point_group.match_y(self.axes_group)
|
||||
self.point_group.shift([0.2, 0, 0])
|
||||
self.axes_group.scale(0.7)
|
||||
self.all_group.shift([0, 0.2, 0])
|
||||
|
||||
@override_animation(Create)
|
||||
def create_animation(self):
|
||||
animation_group = AnimationGroup(
|
||||
# Perform the animations
|
||||
Create(self.point_group, run_time=2),
|
||||
Wait(0.5),
|
||||
Create(self.axes_group, run_time=2),
|
||||
# add title
|
||||
# Create(self.title),
|
||||
Create(self.legend_group),
|
||||
)
|
||||
return animation_group
|
||||
|
||||
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
point_group = VGroup()
|
||||
for point_index, point in enumerate(points):
|
||||
# draw the dot
|
||||
current_target = targets[point_index]
|
||||
color = class_colors[current_target]
|
||||
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
|
||||
dot.scale(0.5)
|
||||
point_group.add(dot)
|
||||
return point_group
|
||||
|
||||
def _make_legend(self, class_colors, feature_labels, axes):
|
||||
legend_group = VGroup()
|
||||
# Make Text
|
||||
setosa = Text("Setosa", color=BLUE)
|
||||
verisicolor = Text("Verisicolor", color=ORANGE)
|
||||
virginica = Text("Virginica", color=GREEN)
|
||||
labels = VGroup(setosa, verisicolor, virginica).arrange(
|
||||
direction=RIGHT, aligned_edge=LEFT, buff=2.0
|
||||
)
|
||||
labels.scale(0.5)
|
||||
legend_group.add(labels)
|
||||
# surrounding rectangle
|
||||
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
|
||||
surrounding_rectangle.move_to(labels)
|
||||
legend_group.add(surrounding_rectangle)
|
||||
# shift the legend group
|
||||
legend_group.move_to(axes)
|
||||
legend_group.shift([0, -3.0, 0])
|
||||
legend_group.match_x(axes[0][0])
|
||||
|
||||
return legend_group
|
||||
|
||||
def _make_axes_group(self, points, labels, font="Source Han Sans", font_scale=0.75):
|
||||
axes_group = VGroup()
|
||||
# make the axes
|
||||
x_range = [
|
||||
np.amin(points, axis=0)[0] - 0.2,
|
||||
np.amax(points, axis=0)[0] - 0.2,
|
||||
0.5,
|
||||
]
|
||||
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
|
||||
axes = Axes(
|
||||
x_range=x_range,
|
||||
y_range=y_range,
|
||||
x_length=9,
|
||||
y_length=6.5,
|
||||
# axis_config={"number_scale_value":0.75, "include_numbers":True},
|
||||
tips=False,
|
||||
).shift([0.5, 0.25, 0])
|
||||
axes_group.add(axes)
|
||||
# make axis labels
|
||||
# x_label
|
||||
x_label = (
|
||||
Text(labels[0], font=font)
|
||||
.match_y(axes.get_axes()[0])
|
||||
.shift([0.5, -0.75, 0])
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(x_label)
|
||||
# y_label
|
||||
y_label = (
|
||||
Text(labels[1], font=font)
|
||||
.match_x(axes.get_axes()[1])
|
||||
.shift([-0.75, 0, 0])
|
||||
.rotate(np.pi / 2)
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(y_label)
|
||||
|
||||
return axes_group
|
||||
|
||||
|
||||
class DecisionTreeSurface(VGroup):
|
||||
|
||||
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
# take the tree and construct the surface from it
|
||||
self.tree_clf = tree_clf
|
||||
self.data = data
|
||||
self.axes = axes
|
||||
self.class_colors = class_colors
|
||||
self.surface_rectangles = self.generate_surface_rectangles()
|
||||
|
||||
def generate_surface_rectangles(self):
|
||||
# compute data bounds
|
||||
left = np.amin(self.data[:, 0]) - 0.2
|
||||
right = np.amax(self.data[:, 0]) - 0.2
|
||||
top = np.amax(self.data[:, 1])
|
||||
bottom = np.amin(self.data[:, 1]) - 0.2
|
||||
maxrange = [left, right, bottom, top]
|
||||
rectangles = compute_decision_areas(
|
||||
self.tree_clf, maxrange, x=0, y=1, n_features=2
|
||||
)
|
||||
# turn the rectangle objects into manim rectangles
|
||||
def convert_rectangle_to_polygon(rect):
|
||||
# get the points for the rectangle in the plot coordinate frame
|
||||
bottom_left = [rect[0], rect[3]]
|
||||
bottom_right = [rect[1], rect[3]]
|
||||
top_right = [rect[1], rect[2]]
|
||||
top_left = [rect[0], rect[2]]
|
||||
# convert those points into the entire manim coordinates
|
||||
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
|
||||
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
|
||||
top_right_coord = self.axes.coords_to_point(*top_right)
|
||||
top_left_coord = self.axes.coords_to_point(*top_left)
|
||||
points = [
|
||||
bottom_left_coord,
|
||||
bottom_right_coord,
|
||||
top_right_coord,
|
||||
top_left_coord,
|
||||
]
|
||||
# construct a polygon object from those manim coordinates
|
||||
rectangle = Polygon(
|
||||
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
|
||||
)
|
||||
return rectangle
|
||||
|
||||
manim_rectangles = []
|
||||
for rect in rectangles:
|
||||
color = self.class_colors[int(rect[4])]
|
||||
rectangle = convert_rectangle_to_polygon(rect)
|
||||
manim_rectangles.append(rectangle)
|
||||
|
||||
manim_rectangles = merge_overlapping_polygons(
|
||||
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
|
||||
)
|
||||
|
||||
return manim_rectangles
|
||||
|
||||
@override_animation(Create)
|
||||
def create_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Create(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
@override_animation(Uncreate)
|
||||
def uncreate_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Uncreate(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
def make_split_to_animation_map(self):
|
||||
"""
|
||||
Returns a dictionary mapping a given split
|
||||
node to an animation to be played
|
||||
"""
|
||||
# Create an initial decision tree surface
|
||||
# Go through each split node
|
||||
# 1. Make a line split animation
|
||||
# 2. Create the relevant classification areas
|
||||
# and transform the old ones to them
|
@ -26,7 +26,6 @@ def compute_node_depths(tree):
|
||||
|
||||
return node_depths
|
||||
|
||||
|
||||
def compute_level_order_traversal(tree):
|
||||
"""Computes level order traversal of a sklearn tree"""
|
||||
|
||||
@ -57,6 +56,26 @@ def compute_level_order_traversal(tree):
|
||||
|
||||
return sorted_inds
|
||||
|
||||
def compute_bfs_traversal(tree):
|
||||
"""Traverses the tree in BFS order and returns the nodes in order"""
|
||||
traversal_order = []
|
||||
tree_root_index = 0
|
||||
queue = [tree_root_index]
|
||||
while len(queue) > 0:
|
||||
current_index = queue.pop(0)
|
||||
traversal_order.append(current_index)
|
||||
left_child_index = self.tree.children_left[node_index]
|
||||
right_child_index = self.tree.children_right[node_index]
|
||||
is_leaf_node = left_child_index == right_child_index
|
||||
if not is_leaf_node:
|
||||
queue.append(left_child_index)
|
||||
queue.append(right_child_index)
|
||||
|
||||
return traversal_order
|
||||
|
||||
def compute_best_first_traversal(tree):
|
||||
"""Traverses the tree according to the best split first order"""
|
||||
pass
|
||||
|
||||
def compute_node_to_parent_mapping(tree):
|
||||
"""Returns a hashmap mapping node indices to their parent indices"""
|
||||
|
0
manim_ml/decision_tree/synced_surfa
Normal file
0
manim_ml/decision_tree/synced_surfa
Normal file
369
manim_ml/diffusion/mcmc.py
Normal file
369
manim_ml/diffusion/mcmc.py
Normal file
@ -0,0 +1,369 @@
|
||||
"""
|
||||
Tool for animating Markov Chain Monte Carlo simulations in 2D.
|
||||
"""
|
||||
from manim import *
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.stats
|
||||
from tqdm import tqdm
|
||||
|
||||
from manim_ml.probability import GaussianDistribution
|
||||
|
||||
def gaussian_proposal(x, sigma=0.2):
|
||||
"""
|
||||
Gaussian proposal distribution.
|
||||
|
||||
Draw new parameters from Gaussian distribution with
|
||||
mean at current position and standard deviation sigma.
|
||||
|
||||
Since the mean is the current position and the standard
|
||||
deviation is fixed. This proposal is symmetric so the ratio
|
||||
of proposal densities is 1.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : np.ndarray or list
|
||||
point to center proposal around
|
||||
sigma : float, optional
|
||||
standard deviation of gaussian for proposal, by default 0.1
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
propossed point
|
||||
"""
|
||||
# Draw x_star
|
||||
x_star = x + np.random.randn(len(x)) * sigma
|
||||
# proposal ratio factor is 1 since jump is symmetric
|
||||
qxx = 1
|
||||
|
||||
return (x_star, qxx)
|
||||
|
||||
class MultidimensionalGaussianPosterior():
|
||||
"""
|
||||
N-Dimensional Gaussian distribution with
|
||||
|
||||
mu ~ Normal(0, 10)
|
||||
var ~ LogNormal(0, 1.5)
|
||||
|
||||
Prior on mean is U(-500, 500)
|
||||
"""
|
||||
|
||||
def __init__(self, ndim=2, seed=12345, scale=3,
|
||||
mu=None, var=None):
|
||||
"""_summary_
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ndim : int, optional
|
||||
_description_, by default 2
|
||||
seed : int, optional
|
||||
_description_, by default 12345
|
||||
scale : int, optional
|
||||
_description_, by default 10
|
||||
"""
|
||||
np.random.seed(seed)
|
||||
self.scale = scale
|
||||
|
||||
if var is None:
|
||||
self.var = 10 ** (np.random.randn(ndim) * 1.5)
|
||||
else:
|
||||
self.var = var
|
||||
|
||||
if mu is None:
|
||||
self.mu = scipy.stats.norm(
|
||||
loc=0,
|
||||
scale=self.scale
|
||||
).rvs(ndim)
|
||||
else:
|
||||
self.mu = mu
|
||||
|
||||
def __call__(self, x):
|
||||
"""
|
||||
Call multivariate normal posterior.
|
||||
"""
|
||||
|
||||
if np.all(x < 500) and np.all(x > -500):
|
||||
return scipy.stats.multivariate_normal(
|
||||
mean=self.mu,
|
||||
cov=self.var
|
||||
).logpdf(x)
|
||||
else:
|
||||
return -1e6
|
||||
|
||||
def metropolis_hastings_sampler(
|
||||
log_prob_fn=MultidimensionalGaussianPosterior(),
|
||||
prop_fn=gaussian_proposal,
|
||||
initial_location : np.ndarray = np.array([0, 0]),
|
||||
iterations=25,
|
||||
warm_up=0,
|
||||
ndim=2
|
||||
):
|
||||
"""Samples using a Metropolis-Hastings sampler.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
log_prob_fn : function, optional
|
||||
Function to compute log-posterior, by default MultidimensionalGaussianPosterior
|
||||
prop_fn : function, optional
|
||||
Function to compute proposal location, by default gaussian_proposal
|
||||
initial_location : np.ndarray, optional
|
||||
initial location for the chain
|
||||
iterations : int, optional
|
||||
number of iterations of the markov chain, by default 100
|
||||
warm_up : int, optional,
|
||||
number of warm up iterations
|
||||
|
||||
Returns
|
||||
-------
|
||||
samples : np.ndarray
|
||||
numpy array of 2D samples of length `iterations`
|
||||
warm_up_samples : np.ndarray
|
||||
numpy array of 2D warm up samples of length `warm_up`
|
||||
candidate_samples: np.ndarray
|
||||
numpy array of the candidate samples for each time step
|
||||
"""
|
||||
assert warm_up == 0, "Warmup not implemented yet"
|
||||
# initialize chain, acceptance rate and lnprob
|
||||
chain = np.zeros((iterations, ndim))
|
||||
proposals = np.zeros((iterations, ndim))
|
||||
lnprob = np.zeros(iterations)
|
||||
accept_rate = np.zeros(iterations)
|
||||
# first samples
|
||||
chain[0] = initial_location
|
||||
proposals[0] = initial_location
|
||||
lnprob0 = log_prob_fn(initial_location)
|
||||
lnprob[0] = lnprob0
|
||||
# start loop
|
||||
x0 = initial_location
|
||||
naccept = 0
|
||||
for ii in range(1, iterations):
|
||||
# propose
|
||||
x_star, factor = prop_fn(x0)
|
||||
# draw random uniform number
|
||||
u = np.random.uniform(0, 1)
|
||||
# compute hastings ratio
|
||||
lnprob_star = log_prob_fn(x_star)
|
||||
H = np.exp(lnprob_star - lnprob0) * factor
|
||||
# accept/reject step (update acceptance counter)
|
||||
if u < H:
|
||||
x0 = x_star
|
||||
lnprob0 = lnprob_star
|
||||
naccept += 1
|
||||
# update chain
|
||||
chain[ii] = x0
|
||||
proposals[ii] = x_star
|
||||
lnprob[ii] = lnprob0
|
||||
accept_rate[ii] = naccept / ii
|
||||
|
||||
return chain, np.array([]), proposals
|
||||
|
||||
class MCMCAxes(Group):
|
||||
"""Container object for visualizing MCMC on a 2D axis"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dot_color=BLUE,
|
||||
dot_radius=0.05,
|
||||
accept_line_color=GREEN,
|
||||
reject_line_color=RED,
|
||||
line_color=WHITE,
|
||||
line_stroke_width=1
|
||||
):
|
||||
super().__init__()
|
||||
self.dot_color = dot_color
|
||||
self.dot_radius = dot_radius
|
||||
self.accept_line_color = accept_line_color
|
||||
self.reject_line_color = reject_line_color
|
||||
self.line_color = line_color
|
||||
self.line_stroke_width=line_stroke_width
|
||||
# Make the axes
|
||||
self.axes = Axes(
|
||||
x_range=[-3, 3],
|
||||
y_range=[-3, 3],
|
||||
x_length=12,
|
||||
y_length=12,
|
||||
x_axis_config={"stroke_opacity": 0.0},
|
||||
y_axis_config={"stroke_opacity": 0.0},
|
||||
tips=False
|
||||
)
|
||||
self.add(self.axes)
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
"""Overrides Create animation"""
|
||||
return AnimationGroup(
|
||||
Create(self.axes)
|
||||
)
|
||||
|
||||
def visualize_gaussian_proposal_about_point(
|
||||
self,
|
||||
mean,
|
||||
cov=None
|
||||
) -> AnimationGroup:
|
||||
"""Creates a Gaussian distribution about a certain point
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : np.ndarray
|
||||
mean of proposal distribution
|
||||
cov : np.ndarray
|
||||
covariance matrix of proposal distribution
|
||||
|
||||
Returns
|
||||
-------
|
||||
AnimationGroup
|
||||
animation of creating the proposal Gaussian distribution
|
||||
"""
|
||||
gaussian = GaussianDistribution(
|
||||
axes=self.axes,
|
||||
mean=mean,
|
||||
cov=cov,
|
||||
dist_theme="gaussian"
|
||||
)
|
||||
|
||||
create_guassian = Create(gaussian)
|
||||
return create_guassian
|
||||
|
||||
def make_transition_animation(
|
||||
self,
|
||||
start_point,
|
||||
end_point,
|
||||
candidate_point,
|
||||
run_time=0.1
|
||||
) -> AnimationGroup:
|
||||
"""Makes an transition animation for a single point on a Markov Chain
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_point: Dot
|
||||
Start point of the transition
|
||||
end_point : Dot
|
||||
End point of the transition
|
||||
|
||||
Returns
|
||||
-------
|
||||
AnimationGroup
|
||||
Animation of the transition from start to end
|
||||
"""
|
||||
start_location = self.axes.point_to_coords(start_point.get_center())
|
||||
end_location = self.axes.point_to_coords(end_point.get_center())
|
||||
candidate_location = self.axes.point_to_coords(candidate_point.get_center())
|
||||
# Figure out if a point is accepted or rejected
|
||||
# point_is_rejected = not candidate_location == end_location
|
||||
point_is_rejected = False
|
||||
if point_is_rejected:
|
||||
return AnimationGroup()
|
||||
else:
|
||||
create_end_point = Create(
|
||||
end_point
|
||||
)
|
||||
create_line = Create(
|
||||
Line(
|
||||
start_point,
|
||||
end_point,
|
||||
color=self.line_color,
|
||||
stroke_width=self.line_stroke_width
|
||||
)
|
||||
)
|
||||
return AnimationGroup(
|
||||
create_end_point,
|
||||
create_line,
|
||||
lag_ratio=1.0,
|
||||
run_time=run_time
|
||||
)
|
||||
|
||||
def show_ground_truth_gaussian(self, distribution):
|
||||
"""
|
||||
"""
|
||||
mean = distribution.mu
|
||||
var = np.eye(2) * distribution.var
|
||||
distribution_drawing = GaussianDistribution(
|
||||
self.axes,
|
||||
mean,
|
||||
var,
|
||||
dist_theme="gaussian"
|
||||
).set_opacity(0.2)
|
||||
return AnimationGroup(
|
||||
Create(distribution_drawing)
|
||||
)
|
||||
|
||||
def visualize_metropolis_hastings_chain_sampling(
|
||||
self,
|
||||
log_prob_fn=MultidimensionalGaussianPosterior(),
|
||||
prop_fn=gaussian_proposal,
|
||||
sampling_kwargs={},
|
||||
):
|
||||
"""
|
||||
Makes an animation for visualizing a 2D markov chain using
|
||||
metropolis hastings samplings
|
||||
|
||||
Parameters
|
||||
----------
|
||||
axes : manim.mobject.graphing.coordinate_systems.Axes
|
||||
Manim 2D axes to plot the chain on
|
||||
log_prob_fn : function, optional
|
||||
Function to compute log-posterior, by default MultidmensionalGaussianPosterior
|
||||
prop_fn : function, optional
|
||||
Function to compute proposal location, by default gaussian_proposal
|
||||
initial_location : list, optional
|
||||
initial location for the markov chain, by default None
|
||||
iterations : int, optional
|
||||
number of iterations of the markov chain, by default 100
|
||||
|
||||
Returns
|
||||
-------
|
||||
animation : AnimationGroup
|
||||
animation for creating the markov chain
|
||||
"""
|
||||
# Compute the chain samples using a Metropolis Hastings Sampler
|
||||
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
|
||||
log_prob_fn=log_prob_fn,
|
||||
prop_fn=prop_fn,
|
||||
**sampling_kwargs
|
||||
)
|
||||
print(f"MCMC samples: {mcmc_samples}")
|
||||
print(f"Candidate samples: {candidate_samples}")
|
||||
# Make the animation for visualizing the chain
|
||||
animations = []
|
||||
# Place the initial point
|
||||
current_point = mcmc_samples[0]
|
||||
current_point = Dot(
|
||||
self.axes.coords_to_point(current_point[0], current_point[1]),
|
||||
color=self.dot_color,
|
||||
radius=self.dot_radius
|
||||
)
|
||||
create_initial_point = Create(current_point)
|
||||
animations.append(create_initial_point)
|
||||
# Show the initial point's proposal distribution
|
||||
# NOTE: visualize the warm up and the iterations
|
||||
num_iterations = len(mcmc_samples) + len(warm_up_samples)
|
||||
for iteration in tqdm(range(1, num_iterations)):
|
||||
next_sample = mcmc_samples[iteration]
|
||||
print(f"Next sample: {next_sample}")
|
||||
candidate_sample = candidate_samples[iteration - 1]
|
||||
# Make the next point
|
||||
next_point = Dot(
|
||||
self.axes.coords_to_point(next_sample[0], next_sample[1]),
|
||||
color=self.dot_color,
|
||||
radius=self.dot_radius
|
||||
)
|
||||
candidate_point = Dot(
|
||||
self.axes.coords_to_point(candidate_sample[0], candidate_sample[1]),
|
||||
color=self.dot_color,
|
||||
radius=self.dot_radius
|
||||
)
|
||||
# Make a transition animation
|
||||
transition_animation = self.make_transition_animation(
|
||||
current_point, next_point, candidate_point
|
||||
)
|
||||
animations.append(transition_animation)
|
||||
# Setup for next iteration
|
||||
current_point = next_point
|
||||
# Make the final animation group
|
||||
animation_group = AnimationGroup(
|
||||
*animations,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
8
manim_ml/neural_network/activation_functions/__init__.py
Normal file
8
manim_ml/neural_network/activation_functions/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from manim_ml.neural_network.activation_functions.relu import ReLUFunction
|
||||
|
||||
name_to_activation_function_map = {
|
||||
"ReLU": ReLUFunction()
|
||||
}
|
||||
|
||||
def get_activation_function_by_name(name):
|
||||
return name_to_activation_function_map[name]
|
@ -0,0 +1,106 @@
|
||||
from manim import *
|
||||
from abc import ABC, abstractmethod
|
||||
import random
|
||||
|
||||
import manim_ml.neural_network.activation_functions.relu as relu
|
||||
|
||||
class ActivationFunction(ABC, VGroup):
|
||||
"""Abstract parent class for defining activation functions"""
|
||||
|
||||
def __init__(self, function_name=None, x_range=[-1, 1], y_range=[-1, 1],
|
||||
x_length=0.5, y_length=0.3, show_function_name=True, active_color=ORANGE,
|
||||
plot_color=BLUE, rectangle_color=WHITE):
|
||||
super(VGroup, self).__init__()
|
||||
self.function_name = function_name
|
||||
self.x_range = x_range
|
||||
self.y_range = y_range
|
||||
self.x_length = x_length
|
||||
self.y_length = y_length
|
||||
self.show_function_name = show_function_name
|
||||
self.active_color = active_color
|
||||
self.plot_color = plot_color
|
||||
self.rectangle_color = rectangle_color
|
||||
|
||||
self.construct_activation_function()
|
||||
|
||||
def construct_activation_function(self):
|
||||
"""Makes the activation function"""
|
||||
# Make an axis
|
||||
self.axes = Axes(
|
||||
x_range=self.x_range,
|
||||
y_range=self.y_range,
|
||||
x_length=self.x_length,
|
||||
y_length=self.y_length,
|
||||
tips=False,
|
||||
axis_config={
|
||||
"include_numbers": False,
|
||||
"stroke_width": 0.5,
|
||||
"include_ticks": False
|
||||
}
|
||||
)
|
||||
self.add(self.axes)
|
||||
# Surround the axis with a rounded rectangle.
|
||||
self.surrounding_rectangle = SurroundingRectangle(
|
||||
self.axes,
|
||||
corner_radius=0.05,
|
||||
buff=0.05,
|
||||
stroke_width=2.0,
|
||||
stroke_color=self.rectangle_color
|
||||
)
|
||||
self.add(self.surrounding_rectangle)
|
||||
# Plot function on axis by applying it and showing in given range
|
||||
self.graph = self.axes.plot(
|
||||
lambda x: self.apply_function(x),
|
||||
x_range=self.x_range,
|
||||
stroke_color=self.plot_color,
|
||||
stroke_width=2.0
|
||||
)
|
||||
self.add(self.graph)
|
||||
# Add the function name
|
||||
if self.show_function_name:
|
||||
function_name_text = Text(
|
||||
self.function_name,
|
||||
font_size=12,
|
||||
font="sans-serif"
|
||||
)
|
||||
function_name_text.next_to(self.axes, UP*0.5)
|
||||
self.add(function_name_text)
|
||||
|
||||
@abstractmethod
|
||||
def apply_function(self, x_val):
|
||||
"""Evaluates function at given x_val"""
|
||||
if x_val == None:
|
||||
x_val = random.uniform(self.x_range[0], self.x_range[1])
|
||||
|
||||
def make_evaluate_animation(self, x_val=None):
|
||||
"""Evaluates the function at a random point in the x_range"""
|
||||
# Highlight the graph
|
||||
# TODO: Evaluate the function at the x_val and show a highlighted dot
|
||||
animation_group = Succession(
|
||||
AnimationGroup(
|
||||
ApplyMethod(
|
||||
self.graph.set_color,
|
||||
self.active_color
|
||||
),
|
||||
ApplyMethod(
|
||||
self.surrounding_rectangle.set_stroke_color,
|
||||
self.active_color
|
||||
),
|
||||
lag_ratio=0.0
|
||||
),
|
||||
Wait(1),
|
||||
AnimationGroup(
|
||||
ApplyMethod(
|
||||
self.graph.set_color,
|
||||
self.plot_color
|
||||
),
|
||||
ApplyMethod(
|
||||
self.surrounding_rectangle.set_stroke_color,
|
||||
self.rectangle_color
|
||||
),
|
||||
lag_ratio=0.0
|
||||
),
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
15
manim_ml/neural_network/activation_functions/relu.py
Normal file
15
manim_ml/neural_network/activation_functions/relu.py
Normal file
@ -0,0 +1,15 @@
|
||||
from manim import *
|
||||
|
||||
from manim_ml.neural_network.activation_functions.activation_function import ActivationFunction
|
||||
|
||||
class ReLUFunction(ActivationFunction):
|
||||
"""Rectified Linear Unit Activation Function"""
|
||||
|
||||
def __init__(self, function_name="ReLU", x_range=[-1, 1], y_range=[-1, 1]):
|
||||
super().__init__(function_name, x_range, y_range)
|
||||
|
||||
def apply_function(self, x_val):
|
||||
if x_val < 0:
|
||||
return 0
|
||||
else:
|
||||
return x_val
|
@ -1,9 +1,11 @@
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_feed_forward import (
|
||||
Convolutional2DToFeedForward,
|
||||
)
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_max_pooling_2d import Convolutional2DToMaxPooling2D
|
||||
from manim_ml.neural_network.layers.image_to_convolutional_2d import (
|
||||
ImageToConvolutional2DLayer,
|
||||
)
|
||||
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import MaxPooling2DToConvolutional2D
|
||||
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
|
||||
from .convolutional_2d import Convolutional2DLayer
|
||||
from .feed_forward_to_vector import FeedForwardToVector
|
||||
@ -37,4 +39,6 @@ connective_layers_list = (
|
||||
Convolutional2DToConvolutional2D,
|
||||
ImageToConvolutional2DLayer,
|
||||
Convolutional2DToFeedForward,
|
||||
Convolutional2DToMaxPooling2D,
|
||||
MaxPooling2DToConvolutional2D,
|
||||
)
|
||||
|
@ -1,4 +1,6 @@
|
||||
from typing import Union
|
||||
from manim_ml.neural_network.activation_functions import get_activation_function_by_name
|
||||
from manim_ml.neural_network.activation_functions.activation_function import ActivationFunction
|
||||
import numpy as np
|
||||
from manim import *
|
||||
|
||||
@ -24,6 +26,7 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
filter_color=ORANGE,
|
||||
stride=1,
|
||||
stroke_width=2.0,
|
||||
activation_function=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@ -44,6 +47,7 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
self.stride = stride
|
||||
self.stroke_width = stroke_width
|
||||
self.show_grid_lines = show_grid_lines
|
||||
self.activation_function = activation_function
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
@ -61,6 +65,19 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
about_point=self.get_center(),
|
||||
axis=ThreeDLayer.rotation_axis,
|
||||
)
|
||||
# Add the activation function
|
||||
if not self.activation_function is None:
|
||||
# Check if it is a string
|
||||
if isinstance(self.activation_function, str):
|
||||
activation_function = get_activation_function_by_name(
|
||||
self.activation_function
|
||||
)
|
||||
else:
|
||||
assert isinstance(self.activation_function, ActivationFunction)
|
||||
activation_function = self.activation_function
|
||||
# Plot the function above the rest of the layer
|
||||
self.activation_function = activation_function
|
||||
self.add(self.activation_function)
|
||||
|
||||
def construct_feature_maps(self):
|
||||
"""Creates the neural network layer"""
|
||||
@ -88,6 +105,19 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
|
||||
return VGroup(*feature_maps)
|
||||
|
||||
def highlight_and_unhighlight_feature_maps(self):
|
||||
"""Highlights then unhighlights feature maps"""
|
||||
return Succession(
|
||||
ApplyMethod(
|
||||
self.feature_maps.set_color,
|
||||
self.pulse_color
|
||||
),
|
||||
ApplyMethod(
|
||||
self.feature_maps.set_color,
|
||||
self.color
|
||||
)
|
||||
)
|
||||
|
||||
def make_forward_pass_animation(
|
||||
self, run_time=5, corner_pulses=False, layer_args={}, **kwargs
|
||||
):
|
||||
@ -112,7 +142,14 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
# filter_flashes
|
||||
)
|
||||
else:
|
||||
animation_group = AnimationGroup()
|
||||
if not self.activation_function is None:
|
||||
animation_group = AnimationGroup(
|
||||
self.activation_function.make_evaluate_animation(),
|
||||
self.highlight_and_unhighlight_feature_maps(),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
else:
|
||||
animation_group = AnimationGroup()
|
||||
|
||||
return animation_group
|
||||
|
||||
@ -120,6 +157,15 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
self.cell_width *= scale_factor
|
||||
super().scale(scale_factor, **kwargs)
|
||||
|
||||
def get_center(self):
|
||||
"""Overrides function for getting center
|
||||
|
||||
The reason for this is so that the center calculation
|
||||
does not include the activation function.
|
||||
"""
|
||||
print("Getting center")
|
||||
return self.feature_maps.get_center()
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
return FadeIn(self.feature_maps)
|
||||
|
@ -1,3 +1,5 @@
|
||||
import numpy as np
|
||||
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
|
||||
@ -5,6 +7,26 @@ from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
from manim.utils.space_ops import rotation_matrix
|
||||
|
||||
def get_rotated_shift_vectors(input_layer, normalized=False):
|
||||
"""Rotates the shift vectors"""
|
||||
# Make base shift vectors
|
||||
right_shift = np.array([input_layer.cell_width, 0, 0])
|
||||
down_shift = np.array([0, -input_layer.cell_width, 0])
|
||||
# Make rotation matrix
|
||||
rot_mat = rotation_matrix(
|
||||
ThreeDLayer.rotation_angle,
|
||||
ThreeDLayer.rotation_axis
|
||||
)
|
||||
# Rotate the vectors
|
||||
right_shift = np.dot(right_shift, rot_mat.T)
|
||||
down_shift = np.dot(down_shift, rot_mat.T)
|
||||
# Normalize the vectors
|
||||
if normalized:
|
||||
right_shift = right_shift / np.linalg.norm(right_shift)
|
||||
down_shift = down_shift / np.linalg.norm(down_shift)
|
||||
|
||||
return right_shift, down_shift
|
||||
|
||||
class Filters(VGroup):
|
||||
"""Group for showing a collection of filters connecting two layers"""
|
||||
|
||||
@ -39,10 +61,9 @@ class Filters(VGroup):
|
||||
|
||||
def make_input_feature_map_rectangles(self):
|
||||
rectangles = []
|
||||
|
||||
rectangle_width = self.input_layer.filter_size[0] * self.input_layer.cell_width
|
||||
rectangle_height = self.input_layer.filter_size[1] * self.input_layer.cell_width
|
||||
filter_color = self.input_layer.filter_color
|
||||
rectangle_width = self.output_layer.filter_size[0] * self.output_layer.cell_width
|
||||
rectangle_height = self.output_layer.filter_size[1] * self.output_layer.cell_width
|
||||
filter_color = self.output_layer.filter_color
|
||||
|
||||
for index, feature_map in enumerate(self.input_layer.feature_maps):
|
||||
rectangle = GriddedRectangle(
|
||||
@ -87,7 +108,7 @@ class Filters(VGroup):
|
||||
filter_color = self.output_layer.filter_color
|
||||
|
||||
for index, feature_map in enumerate(self.output_layer.feature_maps):
|
||||
# Make sure current feature map is the right filte
|
||||
# Make sure current feature map is the right filter
|
||||
if not self.output_feature_map_to_connect is None:
|
||||
if index != self.output_feature_map_to_connect:
|
||||
continue
|
||||
@ -206,7 +227,7 @@ class Filters(VGroup):
|
||||
does not show up in the scene before the create animation.
|
||||
|
||||
Without this override the filters were shown at the beginning
|
||||
of the neural network forward pass animimation
|
||||
of the neural network forward pass animation
|
||||
instead of just when the filters were supposed to appear.
|
||||
I think this is a bug with Succession in the core
|
||||
Manim Community Library.
|
||||
@ -242,10 +263,8 @@ class Filters(VGroup):
|
||||
|
||||
return passing_flash
|
||||
|
||||
|
||||
class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
|
||||
input_class = Convolutional2DLayer
|
||||
output_class = Convolutional2DLayer
|
||||
|
||||
@ -265,18 +284,16 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=Convolutional2DLayer,
|
||||
output_class=Convolutional2DLayer,
|
||||
**kwargs,
|
||||
)
|
||||
self.color = color
|
||||
self.filter_color = self.input_layer.filter_color
|
||||
self.filter_size = self.input_layer.filter_size
|
||||
self.filter_color = self.output_layer.filter_color
|
||||
self.filter_size = self.output_layer.filter_size
|
||||
self.feature_map_size = self.input_layer.feature_map_size
|
||||
self.num_input_feature_maps = self.input_layer.num_feature_maps
|
||||
self.num_output_feature_maps = self.output_layer.num_feature_maps
|
||||
self.cell_width = self.input_layer.cell_width
|
||||
self.stride = self.input_layer.stride
|
||||
self.cell_width = self.output_layer.cell_width
|
||||
self.stride = self.output_layer.stride
|
||||
self.filter_opacity = filter_opacity
|
||||
self.cell_width = cell_width
|
||||
self.line_color = line_color
|
||||
@ -287,21 +304,6 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def get_rotated_shift_vectors(self):
|
||||
"""
|
||||
Rotates the shift vectors
|
||||
"""
|
||||
# Make base shift vectors
|
||||
right_shift = np.array([self.input_layer.cell_width, 0, 0])
|
||||
down_shift = np.array([0, -self.input_layer.cell_width, 0])
|
||||
# Make rotation matrix
|
||||
rot_mat = rotation_matrix(ThreeDLayer.rotation_angle, ThreeDLayer.rotation_axis)
|
||||
# Rotate the vectors
|
||||
right_shift = np.dot(right_shift, rot_mat.T)
|
||||
down_shift = np.dot(down_shift, rot_mat.T)
|
||||
|
||||
return right_shift, down_shift
|
||||
|
||||
def animate_filters_all_at_once(self, filters):
|
||||
"""Animates each of the filters all at once"""
|
||||
animations = []
|
||||
@ -316,7 +318,7 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
)
|
||||
animations.append(Create(filters))
|
||||
# Get the rotated shift vectors
|
||||
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||
right_shift, down_shift = get_rotated_shift_vectors(self.input_layer)
|
||||
left_shift = -1 * right_shift
|
||||
# Make the animation
|
||||
num_y_moves = int((self.feature_map_size[1] - self.filter_size[1]) / self.stride)
|
||||
@ -328,7 +330,6 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
shift_animation = ApplyMethod(filters.shift, self.stride * right_shift)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
animations.append(shift_animation)
|
||||
|
||||
# Go back left num_x_moves and down one
|
||||
shift_amount = (
|
||||
self.stride * num_x_moves * left_shift + self.stride * down_shift
|
||||
@ -346,7 +347,10 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
animations.append(FadeOut(filters))
|
||||
return Succession(*animations, lag_ratio=1.0)
|
||||
|
||||
def animate_filters_one_at_a_time(self, highlight_active_feature_map=False):
|
||||
def animate_filters_one_at_a_time(
|
||||
self,
|
||||
highlight_active_feature_map=True
|
||||
):
|
||||
"""Animates each of the filters one at a time"""
|
||||
animations = []
|
||||
output_feature_maps = self.output_layer.feature_maps
|
||||
@ -381,7 +385,7 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
AnimationGroup(*change_color_animations, lag_ratio=0.0)
|
||||
)
|
||||
# Get the rotated shift vectors
|
||||
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||
right_shift, down_shift = get_rotated_shift_vectors(self.input_layer)
|
||||
left_shift = -1 * right_shift
|
||||
# Make the animation
|
||||
num_y_moves = int(
|
||||
@ -414,12 +418,18 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
self.stride * num_x_moves * left_shift + self.stride * down_shift
|
||||
)
|
||||
# Make the animation
|
||||
shift_animation = ApplyMethod(filters.shift, shift_amount)
|
||||
shift_animation = ApplyMethod(
|
||||
filters.shift,
|
||||
shift_amount
|
||||
)
|
||||
animations.append(shift_animation)
|
||||
# Do last row move right
|
||||
for x_move in range(num_x_moves):
|
||||
# Shift right
|
||||
shift_animation = ApplyMethod(filters.shift, self.stride * right_shift)
|
||||
shift_animation = ApplyMethod(
|
||||
filters.shift,
|
||||
self.stride * right_shift
|
||||
)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
animations.append(shift_animation)
|
||||
# Remove the filters
|
||||
@ -430,14 +440,18 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
# Change the output feature map colors
|
||||
change_color_animations = []
|
||||
change_color_animations.append(
|
||||
ApplyMethod(feature_map.set_color, original_feature_map_color)
|
||||
ApplyMethod(
|
||||
feature_map.set_color,
|
||||
original_feature_map_color
|
||||
)
|
||||
)
|
||||
# Change the input feature map colors
|
||||
input_feature_maps = self.input_layer.feature_maps
|
||||
for input_feature_map in input_feature_maps:
|
||||
change_color_animations.append(
|
||||
ApplyMethod(
|
||||
input_feature_map.set_color, original_feature_map_color
|
||||
input_feature_map.set_color,
|
||||
original_feature_map_color
|
||||
)
|
||||
)
|
||||
# Combine the animations
|
||||
@ -451,7 +465,7 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
self,
|
||||
layer_args={},
|
||||
all_filters_at_once=False,
|
||||
highlight_active_feature_map=False,
|
||||
highlight_active_feature_map=True,
|
||||
run_time=10.5,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -20,8 +20,6 @@ class Convolutional2DToFeedForward(ConnectiveLayer, ThreeDLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=Convolutional2DLayer,
|
||||
output_class=Convolutional2DLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.passing_flash_color = passing_flash_color
|
||||
|
@ -0,0 +1,228 @@
|
||||
import random
|
||||
from manim import *
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import get_rotated_shift_vectors
|
||||
|
||||
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
|
||||
class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
input_class = Convolutional2DLayer
|
||||
output_class = MaxPooling2DLayer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_layer: Convolutional2DLayer,
|
||||
output_layer: MaxPooling2DLayer,
|
||||
active_color=ORANGE,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
self.active_color = active_color
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: 'NeuralNetworkLayer',
|
||||
output_layer: 'NeuralNetworkLayer',
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(
|
||||
self,
|
||||
layer_args={},
|
||||
run_time=1.5,
|
||||
**kwargs
|
||||
):
|
||||
"""Forward pass animation from conv2d to max pooling"""
|
||||
cell_width = self.input_layer.cell_width
|
||||
feature_map_size = self.input_layer.feature_map_size
|
||||
kernel_size = self.output_layer.kernel_size
|
||||
feature_maps = self.input_layer.feature_maps
|
||||
grid_stroke_width = 1.0
|
||||
# Get the normalized shift vectors for the convolutional layer
|
||||
"""
|
||||
right_shift, down_shift = get_rotated_shift_vectors(
|
||||
self.input_layer,
|
||||
normalized=True
|
||||
)
|
||||
"""
|
||||
# Make all of the kernel gridded rectangles
|
||||
create_gridded_rectangle_animations = []
|
||||
create_and_remove_cell_animations = []
|
||||
move_and_resize_gridded_rectangle_animations = []
|
||||
remove_gridded_rectangle_animations = []
|
||||
|
||||
for feature_map_index, feature_map in enumerate(feature_maps):
|
||||
# 1. Draw gridded rectangle with kernel_size x kernel_size
|
||||
# box regions over the input feature maps.
|
||||
gridded_rectangle = GriddedRectangle(
|
||||
color=self.active_color,
|
||||
height=cell_width * feature_map_size[1],
|
||||
width=cell_width * feature_map_size[0],
|
||||
grid_xstep=cell_width * kernel_size,
|
||||
grid_ystep=cell_width * kernel_size,
|
||||
grid_stroke_width=grid_stroke_width,
|
||||
grid_stroke_color=self.active_color,
|
||||
show_grid_lines=True
|
||||
)
|
||||
# 2. Randomly highlight one of the cells in the kernel.
|
||||
highlighted_cells = []
|
||||
num_cells_in_kernel = kernel_size * kernel_size
|
||||
num_x_kernels = int(feature_map_size[0] / kernel_size)
|
||||
num_y_kernels = int(feature_map_size[1] / kernel_size)
|
||||
for kernel_x in range(0, num_x_kernels):
|
||||
for kernel_y in range(0, num_y_kernels):
|
||||
# Choose a random cell index
|
||||
cell_index = random.randint(0, num_cells_in_kernel - 1)
|
||||
# Make a rectangle in that cell
|
||||
cell_rectangle = GriddedRectangle(
|
||||
color=self.active_color,
|
||||
height=cell_width,
|
||||
width=cell_width,
|
||||
stroke_width=0.0,
|
||||
fill_opacity=0.7
|
||||
)
|
||||
# Move to the correct location
|
||||
kernel_shift_vector = [
|
||||
kernel_size * cell_width * kernel_x,
|
||||
-1 * kernel_size * cell_width * kernel_y,
|
||||
0
|
||||
]
|
||||
cell_shift_vector = [
|
||||
(cell_index % kernel_size) * cell_width,
|
||||
-1 * int(cell_index / kernel_size) * cell_width,
|
||||
0
|
||||
]
|
||||
cell_rectangle.next_to(
|
||||
gridded_rectangle.get_corners_dict()["top_left"],
|
||||
submobject_to_align=cell_rectangle.get_corners_dict()["top_left"],
|
||||
buff=0.0
|
||||
)
|
||||
cell_rectangle.shift(
|
||||
kernel_shift_vector
|
||||
)
|
||||
cell_rectangle.shift(
|
||||
cell_shift_vector
|
||||
)
|
||||
highlighted_cells.append(
|
||||
cell_rectangle
|
||||
)
|
||||
# Rotate the gridded rectangles so they match the angle
|
||||
# of the conv maps
|
||||
gridded_rectangle_group = VGroup(
|
||||
gridded_rectangle,
|
||||
*highlighted_cells
|
||||
)
|
||||
gridded_rectangle_group.rotate(
|
||||
ThreeDLayer.rotation_angle,
|
||||
about_point=gridded_rectangle.get_center(),
|
||||
axis=ThreeDLayer.rotation_axis,
|
||||
)
|
||||
gridded_rectangle.next_to(
|
||||
feature_map.get_corners_dict()["top_left"],
|
||||
submobject_to_align=gridded_rectangle.get_corners_dict()["top_left"],
|
||||
buff=0.0
|
||||
)
|
||||
# 3. Make a create gridded rectangle
|
||||
"""
|
||||
create_rectangle = Create(
|
||||
gridded_rectangle
|
||||
)
|
||||
create_gridded_rectangle_animations.append(
|
||||
create_rectangle
|
||||
)
|
||||
def add_grid_lines(rectangle):
|
||||
rectangle.color=self.active_color
|
||||
rectangle.height=cell_width * feature_map_size[1]
|
||||
rectangle.width=cell_width * feature_map_size[0]
|
||||
rectangle.grid_xstep=cell_width * kernel_size
|
||||
rectangle.grid_ystep=cell_width * kernel_size
|
||||
rectangle.grid_stroke_width=grid_stroke_width
|
||||
rectangle.grid_stroke_color=self.active_color
|
||||
rectangle.show_grid_lines=True
|
||||
|
||||
return rectangle
|
||||
|
||||
create_gridded_rectangle_animations.append(
|
||||
ApplyFunction(
|
||||
add_grid_lines,
|
||||
gridded_rectangle
|
||||
)
|
||||
)
|
||||
"""
|
||||
# 4. Create and fade out highlighted cells
|
||||
# highlighted_cells_group = VGroup()
|
||||
# NOTE: Another workaround that is hacky
|
||||
# See convolution_2d_to_convolution_2d Filters Create Override for
|
||||
# more information
|
||||
"""
|
||||
def add_highlighted_cells(object):
|
||||
for cell in highlighted_cells:
|
||||
object.add(
|
||||
cell
|
||||
)
|
||||
|
||||
return object
|
||||
|
||||
create_and_remove_cell_animation = Succession(
|
||||
ApplyFunction(add_highlighted_cells, highlighted_cells_group),
|
||||
Wait(0.5),
|
||||
FadeOut(highlighted_cells_group),
|
||||
lag_ratio=1.0
|
||||
)
|
||||
create_and_remove_cell_animations.append(
|
||||
create_and_remove_cell_animation
|
||||
)
|
||||
"""
|
||||
create_and_remove_cell_animations = Succession(
|
||||
Create(VGroup(*highlighted_cells)),
|
||||
Wait(0.5),
|
||||
Uncreate(VGroup(*highlighted_cells))
|
||||
)
|
||||
return create_and_remove_cell_animations
|
||||
# 5. Move and resize the gridded rectangle to the output
|
||||
# feature maps.
|
||||
resize_rectangle = Transform(
|
||||
gridded_rectangle,
|
||||
self.output_layer.feature_maps[feature_map_index]
|
||||
)
|
||||
move_rectangle = gridded_rectangle.animate.move_to(
|
||||
self.output_layer.feature_maps[feature_map_index]
|
||||
)
|
||||
move_and_resize = Succession(
|
||||
resize_rectangle,
|
||||
move_rectangle,
|
||||
lag_ratio=0.0
|
||||
)
|
||||
move_and_resize_gridded_rectangle_animations.append(
|
||||
move_and_resize
|
||||
)
|
||||
# 6. Make the gridded feature map(s) disappear.
|
||||
remove_gridded_rectangle_animations.append(
|
||||
Uncreate(
|
||||
gridded_rectangle_group
|
||||
)
|
||||
)
|
||||
|
||||
"""
|
||||
AnimationGroup(
|
||||
*move_and_resize_gridded_rectangle_animations
|
||||
),
|
||||
"""
|
||||
return Succession(
|
||||
# *create_gridded_rectangle_animations,
|
||||
create_and_remove_cell_animations,
|
||||
# AnimationGroup(
|
||||
# *remove_gridded_rectangle_animations
|
||||
# ),
|
||||
# lag_ratio=1.0
|
||||
lag_ratio=1.0
|
||||
)
|
@ -3,7 +3,6 @@ from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||
|
||||
|
||||
class EmbeddingToFeedForward(ConnectiveLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
|
||||
@ -21,8 +20,6 @@ class EmbeddingToFeedForward(ConnectiveLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=EmbeddingLayer,
|
||||
output_class=FeedForwardLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.feed_forward_layer = output_layer
|
||||
|
@ -21,8 +21,6 @@ class FeedForwardToEmbedding(ConnectiveLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=FeedForwardLayer,
|
||||
output_class=EmbeddingLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.feed_forward_layer = input_layer
|
||||
|
@ -26,8 +26,6 @@ class FeedForwardToFeedForward(ConnectiveLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=FeedForwardLayer,
|
||||
output_class=FeedForwardLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.passing_flash = passing_flash
|
||||
|
@ -21,8 +21,6 @@ class FeedForwardToImage(ConnectiveLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=FeedForwardLayer,
|
||||
output_class=ImageLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
|
@ -21,8 +21,6 @@ class FeedForwardToVector(ConnectiveLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=FeedForwardLayer,
|
||||
output_class=VectorLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
|
@ -9,7 +9,6 @@ from manim_ml.neural_network.layers.parent_layers import (
|
||||
)
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
|
||||
class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
"""Handles rendering a convolutional layer for a nn"""
|
||||
|
||||
@ -17,7 +16,10 @@ class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
output_class = Convolutional2DLayer
|
||||
|
||||
def __init__(
|
||||
self, input_layer: ImageLayer, output_layer: Convolutional2DLayer, **kwargs
|
||||
self,
|
||||
input_layer: ImageLayer,
|
||||
output_layer: Convolutional2DLayer,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.input_layer = input_layer
|
||||
|
@ -21,8 +21,6 @@ class ImageToFeedForward(ConnectiveLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=ImageLayer,
|
||||
output_class=FeedForwardLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
|
@ -1,4 +1,5 @@
|
||||
from manim import *
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
from manim_ml.neural_network.layers.parent_layers import ThreeDLayer, VGroupNeuralNetworkLayer
|
||||
|
||||
@ -10,8 +11,18 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
to the 2 spatial dimensions of the convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, output_feature_map_size=(4, 4), kernel_size=2, stride=1,
|
||||
cell_highlight_color=ORANGE, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size=2,
|
||||
stride=1,
|
||||
cell_highlight_color=ORANGE,
|
||||
cell_width=0.2,
|
||||
filter_spacing=0.1,
|
||||
color=BLUE,
|
||||
show_grid_lines=False,
|
||||
stroke_width=2.0,
|
||||
**kwargs
|
||||
):
|
||||
"""Layer object for animating 2D Convolution Max Pooling
|
||||
|
||||
Parameters
|
||||
@ -22,20 +33,68 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
Stride of the max pooling operation, by default 1
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.output_feature_map_size = output_feature_map_size
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.cell_highlight_color = cell_highlight_color
|
||||
self.cell_width = cell_width
|
||||
self.filter_spacing = filter_spacing
|
||||
self.color = color
|
||||
self.show_grid_lines = show_grid_lines
|
||||
self.stroke_width = stroke_width
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
# Make the output feature maps
|
||||
feature_maps = self._make_output_feature_maps()
|
||||
self.add(feature_maps)
|
||||
self.feature_maps = self._make_output_feature_maps(
|
||||
input_layer.num_feature_maps,
|
||||
input_layer.feature_map_size
|
||||
)
|
||||
self.add(self.feature_maps)
|
||||
self.rotate(
|
||||
ThreeDLayer.rotation_angle,
|
||||
about_point=self.get_center(),
|
||||
axis=ThreeDLayer.rotation_axis,
|
||||
)
|
||||
self.feature_map_size = (
|
||||
input_layer.feature_map_size[0] / self.kernel_size,
|
||||
input_layer.feature_map_size[1] / self.kernel_size,
|
||||
)
|
||||
|
||||
def _make_output_feature_maps(self):
|
||||
def _make_output_feature_maps(
|
||||
self,
|
||||
num_input_feature_maps,
|
||||
input_feature_map_size
|
||||
):
|
||||
"""Makes a set of output feature maps"""
|
||||
# Compute the size of the feature maps
|
||||
pass
|
||||
output_feature_map_size = (
|
||||
input_feature_map_size[0] / self.kernel_size,
|
||||
input_feature_map_size[1] / self.kernel_size
|
||||
)
|
||||
# Draw rectangles that are filled in with opacity
|
||||
feature_maps = []
|
||||
for filter_index in range(num_input_feature_maps):
|
||||
rectangle = GriddedRectangle(
|
||||
color=self.color,
|
||||
height=output_feature_map_size[1] * self.cell_width,
|
||||
width=output_feature_map_size[0] * self.cell_width,
|
||||
fill_color=self.color,
|
||||
fill_opacity=0.2,
|
||||
stroke_color=self.color,
|
||||
stroke_width=self.stroke_width,
|
||||
grid_xstep=self.cell_width,
|
||||
grid_ystep=self.cell_width,
|
||||
grid_stroke_width=self.stroke_width / 2,
|
||||
grid_stroke_color=self.color,
|
||||
show_grid_lines=self.show_grid_lines,
|
||||
)
|
||||
# Move the feature map
|
||||
rectangle.move_to([0, 0, filter_index * self.filter_spacing])
|
||||
# rectangle.set_z_index(4)
|
||||
feature_maps.append(rectangle)
|
||||
|
||||
return VGroup(
|
||||
*feature_maps
|
||||
)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Makes forward pass of Max Pooling Layer.
|
||||
@ -45,13 +104,7 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
layer_args : dict, optional
|
||||
_description_, by default {}
|
||||
"""
|
||||
# 1. Draw gridded rectangle with kernel_size x kernel_size
|
||||
# box regions over the input feature map.
|
||||
# 2. Randomly highlight one of the cells in the kernel.
|
||||
# 3. Move and resize the gridded rectangle to the output
|
||||
# feature maps.
|
||||
# 4. Make the gridded feature map(s) disappear.
|
||||
pass
|
||||
return AnimationGroup()
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
|
@ -0,0 +1,53 @@
|
||||
import numpy as np
|
||||
from manim import *
|
||||
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D, Filters
|
||||
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
|
||||
from manim.utils.space_ops import rotation_matrix
|
||||
|
||||
class MaxPooling2DToConvolutional2D(Convolutional2DToConvolutional2D):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
input_class = MaxPooling2DLayer
|
||||
output_class = Convolutional2DLayer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_layer: MaxPooling2DLayer,
|
||||
output_layer: Convolutional2DLayer,
|
||||
passing_flash_color=ORANGE,
|
||||
cell_width=1.0,
|
||||
stroke_width=2.0,
|
||||
show_grid_lines=False,
|
||||
**kwargs
|
||||
):
|
||||
input_layer.num_feature_maps = output_layer.num_feature_maps
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
self.passing_flash_color = passing_flash_color
|
||||
self.cell_width = cell_width
|
||||
self.stroke_width = stroke_width
|
||||
self.show_grid_lines = show_grid_lines
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: 'NeuralNetworkLayer',
|
||||
output_layer: 'NeuralNetworkLayer',
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs the MaxPooling to Convolution3D layer
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_layer : NeuralNetworkLayer
|
||||
input layer
|
||||
output_layer : NeuralNetworkLayer
|
||||
output layer
|
||||
"""
|
||||
pass
|
@ -21,8 +21,6 @@ class PairedQueryToFeedForward(ConnectiveLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=PairedQueryLayer,
|
||||
output_class=FeedForwardLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
|
@ -31,7 +31,7 @@ class NeuralNetworkLayer(ABC, Group):
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self):
|
||||
return AnimationGroup()
|
||||
return Succession()
|
||||
|
||||
def __repr__(self):
|
||||
return f"{type(self).__name__}"
|
||||
|
@ -21,8 +21,6 @@ class TripletToFeedForward(ConnectiveLayer):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
input_class=TripletLayer,
|
||||
output_class=FeedForwardLayer,
|
||||
**kwargs
|
||||
)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
|
@ -99,7 +99,7 @@ class NeuralNetwork(Group):
|
||||
for layer_index in range(1, len(self.input_layers)):
|
||||
previous_layer = self.input_layers[layer_index - 1]
|
||||
current_layer = self.input_layers[layer_index]
|
||||
current_layer.move_to(previous_layer)
|
||||
current_layer.move_to(previous_layer.get_center())
|
||||
# TODO Temp fix
|
||||
if isinstance(current_layer, EmbeddingLayer) or isinstance(
|
||||
previous_layer, EmbeddingLayer
|
||||
@ -156,6 +156,14 @@ class NeuralNetwork(Group):
|
||||
f"Unrecognized layout direction: {layout_direction}"
|
||||
)
|
||||
current_layer.shift(shift_vector)
|
||||
# Place activation function
|
||||
if hasattr(current_layer, "activation_function"):
|
||||
if not current_layer.activation_function is None:
|
||||
current_layer.activation_function.next_to(
|
||||
current_layer,
|
||||
direction=UP
|
||||
)
|
||||
self.add(current_layer.activation_function)
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
"""Draws connecting lines between layers"""
|
||||
@ -220,8 +228,8 @@ class NeuralNetwork(Group):
|
||||
# Get the layer args
|
||||
if isinstance(layer, ConnectiveLayer):
|
||||
"""
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting and itself.
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting and itself.
|
||||
"""
|
||||
before_layer_args = {}
|
||||
current_layer_args = {}
|
||||
@ -244,11 +252,16 @@ class NeuralNetwork(Group):
|
||||
current_layer_args = layer_args[layer]
|
||||
# Perform the forward pass of the current layer
|
||||
layer_forward_pass = layer.make_forward_pass_animation(
|
||||
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
|
||||
layer_args=current_layer_args,
|
||||
run_time=per_layer_runtime,
|
||||
**kwargs
|
||||
)
|
||||
all_animations.append(layer_forward_pass)
|
||||
# Make the animation group
|
||||
animation_group = Succession(*all_animations, lag_ratio=1.0)
|
||||
animation_group = Succession(
|
||||
*all_animations,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
|
@ -2,7 +2,6 @@ from manim import *
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
class GaussianDistribution(VGroup):
|
||||
"""Object for drawing a Gaussian distribution"""
|
||||
|
||||
@ -89,7 +88,7 @@ class GaussianDistribution(VGroup):
|
||||
height=ellipse_height,
|
||||
color=color,
|
||||
fill_opacity=opacity,
|
||||
stroke_width=0.0,
|
||||
stroke_width=2.0,
|
||||
)
|
||||
ellipse.move_to(mean)
|
||||
ellipse.rotate(rotation)
|
||||
|
38
tests/test_activation_function.py
Normal file
38
tests/test_activation_function.py
Normal file
@ -0,0 +1,38 @@
|
||||
from manim import *
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.image import ImageLayer
|
||||
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||
|
||||
# Make the specific scene
|
||||
config.pixel_height = 1200
|
||||
config.pixel_width = 1900
|
||||
config.frame_height = 6.0
|
||||
config.frame_width = 6.0
|
||||
|
||||
class CombinedScene(ThreeDScene):
|
||||
def construct(self):
|
||||
image = Image.open("../assets/mnist/digit.jpeg")
|
||||
numpy_image = np.asarray(image)
|
||||
# Make nn
|
||||
nn = NeuralNetwork([
|
||||
ImageLayer(numpy_image, height=1.5),
|
||||
Convolutional2DLayer(1, 7, filter_spacing=0.32),
|
||||
Convolutional2DLayer(3, 5, 3, filter_spacing=0.32, activation_function="ReLU"),
|
||||
FeedForwardLayer(3),
|
||||
],
|
||||
layer_spacing=0.25,
|
||||
)
|
||||
# Center the nn
|
||||
nn.move_to(ORIGIN)
|
||||
self.add(nn)
|
||||
# Play animation
|
||||
forward_pass = nn.make_forward_pass_animation(
|
||||
corner_pulses=False,
|
||||
all_filters_at_once=False
|
||||
)
|
||||
self.wait(1)
|
||||
self.play(forward_pass)
|
@ -34,10 +34,14 @@ class Simple3DConvScene(ThreeDScene):
|
||||
# Make nn
|
||||
layers = [
|
||||
Convolutional2DLayer(
|
||||
1, feature_map_size=3, filter_size=3
|
||||
num_feature_maps=1,
|
||||
feature_map_size=3,
|
||||
filter_size=3
|
||||
),
|
||||
Convolutional2DLayer(
|
||||
1, feature_map_size=3, filter_size=3
|
||||
num_feature_maps=1,
|
||||
feature_map_size=3,
|
||||
filter_size=3
|
||||
),
|
||||
]
|
||||
nn = NeuralNetwork(layers)
|
||||
@ -59,12 +63,11 @@ class CombinedScene(ThreeDScene):
|
||||
image = Image.open("../assets/mnist/digit.jpeg")
|
||||
numpy_image = np.asarray(image)
|
||||
# Make nn
|
||||
nn = NeuralNetwork(
|
||||
[
|
||||
nn = NeuralNetwork([
|
||||
ImageLayer(numpy_image, height=1.5),
|
||||
Convolutional2DLayer(1, 7, 3, filter_spacing=0.32),
|
||||
Convolutional2DLayer(1, 7, filter_spacing=0.32),
|
||||
Convolutional2DLayer(3, 5, 3, filter_spacing=0.32),
|
||||
Convolutional2DLayer(5, 3, 1, filter_spacing=0.18),
|
||||
Convolutional2DLayer(5, 3, 3, filter_spacing=0.18),
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(3),
|
||||
],
|
||||
|
@ -19,12 +19,36 @@ class CombinedScene(ThreeDScene):
|
||||
image = Image.open("../assets/mnist/digit.jpeg")
|
||||
numpy_image = np.asarray(image)
|
||||
# Make nn
|
||||
nn = NeuralNetwork(
|
||||
[
|
||||
nn = NeuralNetwork([
|
||||
ImageLayer(numpy_image, height=1.5),
|
||||
Convolutional2DLayer(1, 8, 8, 3, 3, filter_spacing=0.32),
|
||||
Convolutional2DLayer(1, 8, filter_spacing=0.32),
|
||||
MaxPooling2DLayer(kernel_size=2),
|
||||
Convolutional2DLayer(3, 3, 2, filter_spacing=0.32),
|
||||
],
|
||||
layer_spacing=0.25,
|
||||
)
|
||||
# Center the nn
|
||||
nn.move_to(ORIGIN)
|
||||
self.add(nn)
|
||||
self.wait(5)
|
||||
# Play animation
|
||||
forward_pass = nn.make_forward_pass_animation(
|
||||
corner_pulses=False, all_filters_at_once=False
|
||||
)
|
||||
print(forward_pass)
|
||||
print(forward_pass.animations)
|
||||
self.wait(1)
|
||||
self.play(forward_pass)
|
||||
|
||||
class SmallNetwork(ThreeDScene):
|
||||
def construct(self):
|
||||
image = Image.open("../assets/mnist/digit.jpeg")
|
||||
numpy_image = np.asarray(image)
|
||||
# Make nn
|
||||
nn = NeuralNetwork([
|
||||
ImageLayer(numpy_image, height=1.5),
|
||||
Convolutional2DLayer(1, 8, filter_spacing=0.32),
|
||||
MaxPooling2DLayer(kernel_size=2),
|
||||
Convolutional2DLayer(3, 5, 5, 3, 3, filter_spacing=0.32),
|
||||
],
|
||||
layer_spacing=0.25,
|
||||
)
|
||||
@ -36,4 +60,4 @@ class CombinedScene(ThreeDScene):
|
||||
corner_pulses=False, all_filters_at_once=False
|
||||
)
|
||||
self.wait(1)
|
||||
self.play(forward_pass)
|
||||
self.play(forward_pass)
|
19
tests/test_succession.py
Normal file
19
tests/test_succession.py
Normal file
@ -0,0 +1,19 @@
|
||||
from manim import *
|
||||
|
||||
class TestSuccession(Scene):
|
||||
|
||||
def construct(self):
|
||||
white_dot = Dot(color=WHITE)
|
||||
white_dot.shift(UP)
|
||||
|
||||
red_dot = Dot(color=RED)
|
||||
|
||||
self.play(
|
||||
Succession(
|
||||
Create(white_dot),
|
||||
white_dot.animate.shift(RIGHT),
|
||||
Create(red_dot),
|
||||
Wait(1),
|
||||
Uncreate(red_dot),
|
||||
)
|
||||
)
|
Reference in New Issue
Block a user