Bug fixes and linting for the activation functions addition.

This commit is contained in:
Alec Helbling
2023-01-25 08:40:32 -05:00
parent ce184af78e
commit f56620f047
42 changed files with 1275 additions and 387 deletions

View File

@ -8,7 +8,11 @@ class NeuralNetworkScene(Scene):
def construct(self):
# Make the Layer object
layers = [FeedForwardLayer(3), FeedForwardLayer(5), FeedForwardLayer(3)]
layers = [
FeedForwardLayer(3),
FeedForwardLayer(5),
FeedForwardLayer(3)
]
nn = NeuralNetwork(layers)
nn.scale(2)
nn.move_to(ORIGIN)

View File

@ -0,0 +1,75 @@
from pathlib import Path
from manim import *
from PIL import Image
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1900
config.frame_height = 7.0
config.frame_width = 7.0
ROOT_DIR = Path(__file__).parents[2]
def make_code_snippet():
code_str = """
# Make the neural network
nn = NeuralNetwork([
# ... Layers at start
Convolutional2DLayer(3, 5, 3, activation_function="ReLU"),
# ... Layers at end
])
# Play the animation
self.play(nn.make_forward_pass_animation())
"""
code = Code(
code=code_str,
tab_width=4,
background_stroke_width=1,
background_stroke_color=WHITE,
insert_line_no=False,
style="monokai",
font="Monospace",
background="window",
language="py",
)
code.scale(0.45)
return code
class CombinedScene(ThreeDScene):
def construct(self):
image = Image.open(ROOT_DIR / "assets/mnist/digit.jpeg")
numpy_image = np.asarray(image)
# Make nn
nn = NeuralNetwork(
[
ImageLayer(numpy_image, height=1.5),
Convolutional2DLayer(1, 7),
Convolutional2DLayer(3, 5, 3, activation_function="ReLU"),
Convolutional2DLayer(5, 3, 3, activation_function="ReLU"),
FeedForwardLayer(3),
FeedForwardLayer(1),
],
layer_spacing=0.25,
)
# nn.scale(0.7)
# Center the nn
nn.move_to(ORIGIN)
self.add(nn)
# Make code snippet
code = make_code_snippet()
code.next_to(nn, DOWN)
self.add(code)
nn.move_to(ORIGIN)
# Move everything up
Group(nn, code).move_to(ORIGIN)
# Play animation
forward_pass = nn.make_forward_pass_animation()
self.wait(1)
self.play(forward_pass)

View File

@ -15,7 +15,6 @@ config.frame_height = 7.0
config.frame_width = 7.0
ROOT_DIR = Path(__file__).parents[2]
def make_code_snippet():
code_str = """
# Make nn
@ -45,7 +44,6 @@ def make_code_snippet():
return code
class CombinedScene(ThreeDScene):
def construct(self):
image = Image.open(ROOT_DIR / "assets/mnist/digit.jpeg")

View File

@ -0,0 +1,90 @@
import numpy as np
from collections import deque
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree as ctree
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
class AABB:
"""Axis-aligned bounding box"""
def __init__(self, n_features):
self.limits = np.array([[-np.inf, np.inf]] * n_features)
def split(self, f, v):
left = AABB(self.limits.shape[0])
right = AABB(self.limits.shape[0])
left.limits = self.limits.copy()
right.limits = self.limits.copy()
left.limits[f, 1] = v
right.limits[f, 0] = v
return left, right
def tree_bounds(tree, n_features=None):
"""Compute final decision rule for each node in tree"""
if n_features is None:
n_features = np.max(tree.feature) + 1
aabbs = [AABB(n_features) for _ in range(tree.node_count)]
queue = deque([0])
while queue:
i = queue.pop()
l = tree.children_left[i]
r = tree.children_right[i]
if l != ctree.TREE_LEAF:
aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
queue.extend([l, r])
return aabbs
def decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
"""Extract decision areas.
tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
maxrange: values to insert for [left, right, top, bottom] if the interval is open (+/-inf)
x: index of the feature that goes on the x axis
y: index of the feature that goes on the y axis
n_features: override autodetection of number of features
"""
tree = tree_classifier.tree_
aabbs = tree_bounds(tree, n_features)
maxrange = np.array(maxrange)
rectangles = []
for i in range(len(aabbs)):
if tree.children_left[i] != ctree.TREE_LEAF:
continue
l = aabbs[i].limits
r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
# clip out of bounds indices
"""
if r[0] < maxrange[0]:
r[0] = maxrange[0]
if r[1] > maxrange[1]:
r[1] = maxrange[1]
if r[2] < maxrange[2]:
r[2] = maxrange[2]
if r[3] > maxrange[3]:
r[3] = maxrange[3]
print(r)
"""
rectangles.append(r)
rectangles = np.array(rectangles)
rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
return rectangles
def plot_areas(rectangles):
for rect in rectangles:
color = ["b", "r"][int(rect[4])]
print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
rp = Rectangle(
[rect[0], rect[2]],
rect[1] - rect[0],
rect[3] - rect[2],
color=color,
alpha=0.3,
)
plt.gca().add_artist(rp)

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

View File

@ -0,0 +1,690 @@
from sklearn import datasets
from decision_tree_surface import *
from manim import *
from sklearn.tree import DecisionTreeClassifier
from scipy.stats import entropy
import math
from PIL import Image
iris = datasets.load_iris()
font = "Source Han Sans"
font_scale = 0.75
images = [
Image.open("iris_dataset/SetosaFlower.jpeg"),
Image.open("iris_dataset/VeriscolorFlower.jpeg"),
Image.open("iris_dataset/VirginicaFlower.jpeg"),
]
def entropy(class_labels, base=2):
# compute the class counts
unique, counts = np.unique(class_labels, return_counts=True)
dictionary = dict(zip(unique, counts))
total = 0.0
num_samples = len(class_labels)
for class_index in range(0, 3):
if not class_index in dictionary:
continue
prob = dictionary[class_index] / num_samples
total += prob * math.log(prob, base)
# higher set
return -total
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
# get all polygons of each color
polygon_dict = {
str(BLUE).lower(): [],
str(GREEN).lower(): [],
str(ORANGE).lower(): [],
}
for polygon in all_polygons:
print(polygon_dict)
polygon_dict[str(polygon.color).lower()].append(polygon)
return_polygons = []
for color in colors:
color = str(color).lower()
polygons = polygon_dict[color]
points = set()
for polygon in polygons:
vertices = polygon.get_vertices().tolist()
vertices = [tuple(vert) for vert in vertices]
for pt in vertices:
if pt in points: # Shared vertice, remove it.
points.remove(pt)
else:
points.add(pt)
points = list(points)
sort_x = sorted(points)
sort_y = sorted(points, key=lambda x: x[1])
edges_h = {}
edges_v = {}
i = 0
while i < len(points):
curr_y = sort_y[i][1]
while i < len(points) and sort_y[i][1] == curr_y:
edges_h[sort_y[i]] = sort_y[i + 1]
edges_h[sort_y[i + 1]] = sort_y[i]
i += 2
i = 0
while i < len(points):
curr_x = sort_x[i][0]
while i < len(points) and sort_x[i][0] == curr_x:
edges_v[sort_x[i]] = sort_x[i + 1]
edges_v[sort_x[i + 1]] = sort_x[i]
i += 2
# Get all the polygons.
while edges_h:
# We can start with any point.
polygon = [(edges_h.popitem()[0], 0)]
while True:
curr, e = polygon[-1]
if e == 0:
next_vertex = edges_v.pop(curr)
polygon.append((next_vertex, 1))
else:
next_vertex = edges_h.pop(curr)
polygon.append((next_vertex, 0))
if polygon[-1] == polygon[0]:
# Closed polygon
polygon.pop()
break
# Remove implementation-markers from the polygon.
poly = [point for point, _ in polygon]
for vertex in poly:
if vertex in edges_h:
edges_h.pop(vertex)
if vertex in edges_v:
edges_v.pop(vertex)
polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
return_polygons.append(polygon)
return return_polygons
class IrisDatasetPlot(VGroup):
def __init__(self):
points = iris.data[:, 0:2]
labels = iris.feature_names
targets = iris.target
# Make points
self.point_group = self._make_point_group(points, targets)
# Make axes
self.axes_group = self._make_axes_group(points, labels)
# Make legend
self.legend_group = self._make_legend(
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
)
# Make title
# title_text = "Iris Dataset Plot"
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
# Make all group
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
# scale the groups
self.point_group.scale(1.6)
self.point_group.match_x(self.axes_group)
self.point_group.match_y(self.axes_group)
self.point_group.shift([0.2, 0, 0])
self.axes_group.scale(0.7)
self.all_group.shift([0, 0.2, 0])
@override_animation(Create)
def create_animation(self):
animation_group = AnimationGroup(
# Perform the animations
Create(self.point_group, run_time=2),
Wait(0.5),
Create(self.axes_group, run_time=2),
# add title
# Create(self.title),
Create(self.legend_group),
)
return animation_group
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
point_group = VGroup()
for point_index, point in enumerate(points):
# draw the dot
current_target = targets[point_index]
color = class_colors[current_target]
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
dot.scale(0.5)
point_group.add(dot)
return point_group
def _make_legend(self, class_colors, feature_labels, axes):
legend_group = VGroup()
# Make Text
setosa = Text("Setosa", color=BLUE)
verisicolor = Text("Verisicolor", color=ORANGE)
virginica = Text("Virginica", color=GREEN)
labels = VGroup(setosa, verisicolor, virginica).arrange(
direction=RIGHT, aligned_edge=LEFT, buff=2.0
)
labels.scale(0.5)
legend_group.add(labels)
# surrounding rectangle
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
surrounding_rectangle.move_to(labels)
legend_group.add(surrounding_rectangle)
# shift the legend group
legend_group.move_to(axes)
legend_group.shift([0, -3.0, 0])
legend_group.match_x(axes[0][0])
return legend_group
def _make_axes_group(self, points, labels):
axes_group = VGroup()
# make the axes
x_range = [
np.amin(points, axis=0)[0] - 0.2,
np.amax(points, axis=0)[0] - 0.2,
0.5,
]
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
axes = Axes(
x_range=x_range,
y_range=y_range,
x_length=9,
y_length=6.5,
# axis_config={"number_scale_value":0.75, "include_numbers":True},
tips=False,
).shift([0.5, 0.25, 0])
axes_group.add(axes)
# make axis labels
# x_label
x_label = (
Text(labels[0], font=font)
.match_y(axes.get_axes()[0])
.shift([0.5, -0.75, 0])
.scale(font_scale)
)
axes_group.add(x_label)
# y_label
y_label = (
Text(labels[1], font=font)
.match_x(axes.get_axes()[1])
.shift([-0.75, 0, 0])
.rotate(np.pi / 2)
.scale(font_scale)
)
axes_group.add(y_label)
return axes_group
class DecisionTreeSurface(VGroup):
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
# take the tree and construct the surface from it
self.tree_clf = tree_clf
self.data = data
self.axes = axes
self.class_colors = class_colors
self.surface_rectangles = self.generate_surface_rectangles()
def generate_surface_rectangles(self):
# compute data bounds
left = np.amin(self.data[:, 0]) - 0.2
right = np.amax(self.data[:, 0]) - 0.2
top = np.amax(self.data[:, 1])
bottom = np.amin(self.data[:, 1]) - 0.2
maxrange = [left, right, bottom, top]
rectangles = decision_areas(self.tree_clf, maxrange, x=0, y=1, n_features=2)
# turn the rectangle objects into manim rectangles
def convert_rectangle_to_polygon(rect):
# get the points for the rectangle in the plot coordinate frame
bottom_left = [rect[0], rect[3]]
bottom_right = [rect[1], rect[3]]
top_right = [rect[1], rect[2]]
top_left = [rect[0], rect[2]]
# convert those points into the entire manim coordinates
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
top_right_coord = self.axes.coords_to_point(*top_right)
top_left_coord = self.axes.coords_to_point(*top_left)
points = [
bottom_left_coord,
bottom_right_coord,
top_right_coord,
top_left_coord,
]
# construct a polygon object from those manim coordinates
rectangle = Polygon(
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
)
return rectangle
manim_rectangles = []
for rect in rectangles:
color = self.class_colors[int(rect[4])]
rectangle = convert_rectangle_to_polygon(rect)
manim_rectangles.append(rectangle)
manim_rectangles = merge_overlapping_polygons(
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
)
return manim_rectangles
@override_animation(Create)
def create_override(self):
# play a reveal of all of the surface rectangles
animations = []
for rectangle in self.surface_rectangles:
animations.append(Create(rectangle))
animation_group = AnimationGroup(*animations)
return animation_group
@override_animation(Uncreate)
def uncreate_override(self):
# play a reveal of all of the surface rectangles
animations = []
for rectangle in self.surface_rectangles:
animations.append(Uncreate(rectangle))
animation_group = AnimationGroup(*animations)
return animation_group
class DecisionTree:
"""
Draw a single tree node
"""
def _make_node(
self,
feature,
threshold,
values,
is_leaf=False,
depth=0,
leaf_colors=[BLUE, ORANGE, GREEN],
):
if not is_leaf:
node_text = f"{feature}\n <= {threshold:.3f} cm"
# draw decision text
decision_text = Text(node_text, color=WHITE)
# draw a box
bounding_box = SurroundingRectangle(decision_text, buff=0.3, color=WHITE)
node = VGroup()
node.add(bounding_box)
node.add(decision_text)
# return the node
else:
# plot the appropriate image
class_index = np.argmax(values)
# get image
pil_image = images[class_index]
leaf_group = Group()
node = ImageMobject(pil_image)
node.scale(1.5)
rectangle = Rectangle(
width=node.width + 0.05,
height=node.height + 0.05,
color=leaf_colors[class_index],
stroke_width=6,
)
rectangle.move_to(node.get_center())
rectangle.shift([-0.02, 0.02, 0])
leaf_group.add(rectangle)
leaf_group.add(node)
node = leaf_group
return node
def _make_connection(self, top, bottom, is_leaf=False):
top_node_bottom_location = top.get_center()
top_node_bottom_location[1] -= top.height / 2
bottom_node_top_location = bottom.get_center()
bottom_node_top_location[1] += bottom.height / 2
line = Line(top_node_bottom_location, bottom_node_top_location, color=WHITE)
return line
def _make_tree(self, tree, feature_names=["Sepal Length", "Sepal Width"]):
tree_group = Group()
max_depth = tree.max_depth
# make the base node
feature_name = feature_names[tree.feature[0]]
threshold = tree.threshold[0]
values = tree.value[0]
nodes_map = {}
root_node = self._make_node(feature_name, threshold, values, depth=0)
nodes_map[0] = root_node
tree_group.add(root_node)
# save some information
node_height = root_node.height
node_width = root_node.width
scale_factor = 1.0
edge_map = {}
# tree height
tree_height = scale_factor * node_height * max_depth
tree_width = scale_factor * 2**max_depth * node_width
# traverse tree
def recurse(node, depth, direction, parent_object, parent_node):
# make sure it is a valid node
# make the node object
is_leaf = tree.children_left[node] == tree.children_right[node]
feature_name = feature_names[tree.feature[node]]
threshold = tree.threshold[node]
values = tree.value[node]
node_object = self._make_node(
feature_name, threshold, values, depth=depth, is_leaf=is_leaf
)
nodes_map[node] = node_object
node_height = node_object.height
# set the node position
direction_factor = -1 if direction == "left" else 1
shift_right_amount = (
0.8 * direction_factor * scale_factor * tree_width / (2**depth) / 2
)
if is_leaf:
shift_down_amount = -1.0 * scale_factor * node_height
else:
shift_down_amount = -1.8 * scale_factor * node_height
node_object.match_x(parent_object).match_y(parent_object).shift(
[shift_right_amount, shift_down_amount, 0]
)
tree_group.add(node_object)
# make a connection
connection = self._make_connection(
parent_object, node_object, is_leaf=is_leaf
)
edge_name = str(parent_node) + "," + str(node)
edge_map[edge_name] = connection
tree_group.add(connection)
# recurse
if not is_leaf:
recurse(tree.children_left[node], depth + 1, "left", node_object, node)
recurse(
tree.children_right[node], depth + 1, "right", node_object, node
)
recurse(tree.children_left[0], 1, "left", root_node, 0)
recurse(tree.children_right[0], 1, "right", root_node, 0)
tree_group.scale(0.35)
return tree_group, nodes_map, edge_map
def color_example_path(
self, tree_group, nodes_map, tree, edge_map, example, color=YELLOW, thickness=2
):
# get decision path
decision_path = tree.decision_path(example)[0]
path_indices = decision_path.indices
# highlight edges
for node_index in range(0, len(path_indices) - 1):
current_val = path_indices[node_index]
next_val = path_indices[node_index + 1]
edge_str = str(current_val) + "," + str(next_val)
edge = edge_map[edge_str]
animation_two = AnimationGroup(
nodes_map[current_val].animate.set_color(color)
)
self.play(animation_two, run_time=0.5)
animation_one = AnimationGroup(
edge.animate.set_color(color),
# edge.animate.set_stroke_width(4),
)
self.play(animation_one, run_time=0.5)
# surround the bottom image
last_path_index = path_indices[-1]
last_path_rectangle = nodes_map[last_path_index][0]
self.play(last_path_rectangle.animate.set_color(color))
def create_sklearn_tree(self, max_tree_depth=1):
# learn the decision tree
iris = load_iris()
tree = learn_iris_decision_tree(iris, max_depth=max_tree_depth)
feature_names = iris.feature_names[0:2]
return tree.tree_
def make_tree(self, max_tree_depth=2):
sklearn_tree = self.create_sklearn_tree()
# make the tree
tree_group, nodes_map, edge_map = self._make_tree(
sklearn_tree.tree_, feature_names
)
tree_group.shift([0, 5.5, 0])
return tree_group
# self.add(tree_group)
# self.play(SpinInFromNothing(tree_group), run_time=3)
# self.color_example_path(tree_group, nodes_map, tree, edge_map, iris.data[None, 0, 0:2])
class DecisionTreeSplitScene(Scene):
def make_decision_tree_classifier(self, max_depth=4):
decision_tree = DecisionTreeClassifier(
random_state=1, max_depth=max_depth, max_leaf_nodes=8
)
decision_tree = decision_tree.fit(iris.data[:, :2], iris.target)
# output the decisioin tree in some format
return decision_tree
def make_split_animation(self, data, classes, data_labels, main_axes):
"""
def make_entropy_animation_and_plot(dim=0, num_entropy_values=50):
# calculate the entropy values
axes_group = VGroup()
# make axes
range_vals = [np.amin(data, axis=0)[dim], np.amax(data, axis=0)[dim]]
axes = Axes(x_range=range_vals,
y_range=[0, 1.0],
x_length=9,
y_length=4,
# axis_config={"number_scale_value":0.75, "include_numbers":True},
tips=False,
)
axes_group.add(axes)
# make axis labels
# x_label
x_label = Text(data_labels[dim], font=font) \
.match_y(axes.get_axes()[0]) \
.shift([0.5, -0.75, 0]) \
.scale(font_scale*1.2)
axes_group.add(x_label)
# y_label
y_label = Text("Information Gain", font=font) \
.match_x(axes.get_axes()[1]) \
.shift([-0.75, 0, 0]) \
.rotate(np.pi / 2) \
.scale(font_scale * 1.2)
axes_group.add(y_label)
# line animation
information_gains = []
def entropy_function(split_value):
# lower entropy
lower_set = np.nonzero(data[:, dim] <= split_value)[0]
lower_set = classes[lower_set]
lower_entropy = entropy(lower_set)
# higher entropy
higher_set = np.nonzero(data[:, dim] > split_value)[0]
higher_set = classes[higher_set]
higher_entropy = entropy(higher_set)
# calculate entropies
all_entropy = entropy(classes, base=2)
lower_entropy = entropy(lower_set, base=2)
higher_entropy = entropy(higher_set, base=2)
mean_entropy = (lower_entropy + higher_entropy) / 2
# calculate information gain
lower_prob = len(lower_set) / len(data[:, dim])
higher_prob = len(higher_set) / len(data[:, dim])
info_gain = all_entropy - (lower_prob * lower_entropy + higher_prob * higher_entropy)
information_gains.append((split_value, info_gain))
return info_gain
data_range = np.amin(data[:, dim]), np.amax(data[:, dim])
entropy_graph = axes.get_graph(
entropy_function,
# color=RED,
# x_range=data_range
)
axes_group.add(entropy_graph)
axes_group.shift([4.0, 2, 0])
axes_group.scale(0.5)
dot_animation = Dot(color=WHITE)
axes_group.add(dot_animation)
# make animations
animation_group = AnimationGroup(
Create(axes_group, run_time=2),
Wait(3),
MoveAlongPath(dot_animation, entropy_graph, run_time=20, rate_func=rate_functions.ease_in_out_quad),
Wait(2)
)
return axes_group, animation_group, information_gains
"""
def make_split_line_animation(dim=0):
# make a line along one of the dims and move it up and down
origin_coord = [
np.amin(data, axis=0)[0] - 0.2,
np.amin(data, axis=0)[1] - 0.2,
]
origin_point = main_axes.coords_to_point(*origin_coord)
top_left_coord = [origin_coord[0], np.amax(data, axis=0)[1]]
bottom_right_coord = [np.amax(data, axis=0)[0] - 0.2, origin_coord[1]]
if dim == 0:
other_coord = top_left_coord
moving_line_coord = bottom_right_coord
else:
other_coord = bottom_right_coord
moving_line_coord = top_left_coord
other_point = main_axes.coords_to_point(*other_coord)
moving_line_point = main_axes.coords_to_point(*moving_line_coord)
moving_line = Line(origin_point, other_point, color=RED)
movement_line = Line(origin_point, moving_line_point)
if dim == 0:
movement_line.shift([0, moving_line.height / 2, 0])
else:
movement_line.shift([moving_line.width / 2, 0, 0])
# move the moving line along the movement line
animation = MoveAlongPath(
moving_line,
movement_line,
run_time=20,
rate_func=rate_functions.ease_in_out_quad,
)
return animation, moving_line
# plot the line in white then make it invisible
# make an animation along the line
# make a
# axes_one_group, top_animation_group, info_gains = make_entropy_animation_and_plot(dim=0)
line_movement, first_moving_line = make_split_line_animation(dim=0)
# axes_two_group, bottom_animation_group, _ = make_entropy_animation_and_plot(dim=1)
second_line_movement, second_moving_line = make_split_line_animation(dim=1)
# axes_two_group.shift([0, -3, 0])
animation_group_one = AnimationGroup(
# top_animation_group,
line_movement,
)
animation_group_two = AnimationGroup(
# bottom_animation_group,
second_line_movement,
)
"""
both_axes_group = VGroup(
axes_one_group,
axes_two_group
)
"""
return (
animation_group_one,
animation_group_two,
first_moving_line,
second_moving_line,
None,
None,
)
# both_axes_group, \
# info_gains
def construct(self):
# make the points
iris_dataset_plot = IrisDatasetPlot()
iris_dataset_plot.all_group.scale(1.0)
iris_dataset_plot.all_group.shift([-3, 0.2, 0])
# make the entropy line graph
# entropy_line_graph = self.draw_entropy_line_graph()
# arrange the plots
# do animations
self.play(Create(iris_dataset_plot))
# make the decision tree classifier
decision_tree_classifier = self.make_decision_tree_classifier()
decision_tree_surface = DecisionTreeSurface(
decision_tree_classifier, iris.data, iris_dataset_plot.axes_group[0]
)
self.play(Create(decision_tree_surface))
self.wait(3)
self.play(Uncreate(decision_tree_surface))
main_axes = iris_dataset_plot.axes_group[0]
(
split_animation_one,
split_animation_two,
first_moving_line,
second_moving_line,
both_axes_group,
info_gains,
) = self.make_split_animation(
iris.data[:, 0:2], iris.target, iris.feature_names, main_axes
)
self.play(split_animation_one)
self.wait(0.1)
self.play(Uncreate(first_moving_line))
self.wait(3)
self.play(split_animation_two)
self.wait(0.1)
self.play(Uncreate(second_moving_line))
self.wait(0.1)
# highlight the maximum on top
# sort by second key
"""
highest_info_gain = sorted(info_gains, key=lambda x: x[1])[-1]
highest_info_gain_point = both_axes_group[0][0].coords_to_point(*highest_info_gain)
highlighted_peak = Dot(highest_info_gain_point, color=YELLOW)
# get location of highest info gain point
highest_info_gain_point_in_iris_graph = iris_dataset_plot.axes_group[0].coords_to_point(*[highest_info_gain[0], 0])
first_moving_line.start[0] = highest_info_gain_point_in_iris_graph[0]
first_moving_line.end[0] = highest_info_gain_point_in_iris_graph[0]
self.play(Create(highlighted_peak))
self.play(Create(first_moving_line))
text = Text("Highest Information Gain")
text.scale(0.4)
text.move_to(highlighted_peak)
text.shift([0, 0.5, 0])
self.play(Create(text))
"""
self.wait(1)
# draw the basic tree
decision_tree_classifier = self.make_decision_tree_classifier(max_depth=1)
decision_tree_surface = DecisionTreeSurface(
decision_tree_classifier, iris.data, iris_dataset_plot.axes_group[0]
)
decision_tree_graph, _, _ = DecisionTree()._make_tree(
decision_tree_classifier.tree_
)
decision_tree_graph.match_y(iris_dataset_plot.axes_group)
decision_tree_graph.shift([4, 0, 0])
self.play(Create(decision_tree_surface))
uncreate_animation = AnimationGroup(
# Uncreate(both_axes_group),
# Uncreate(highlighted_peak),
Uncreate(second_moving_line),
# Unwrite(text)
)
self.play(uncreate_animation)
self.wait(0.5)
self.play(FadeIn(decision_tree_graph))
# self.play(FadeIn(highlighted_peak))
self.wait(5)

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="1920pt" height="1080pt" viewBox="0 0 1920 1080" version="1.1">
<g id="surface6">
</g>
</svg>

After

Width:  |  Height:  |  Size: 222 B

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="1920pt" height="1080pt" viewBox="0 0 1920 1080" version="1.1">
<g id="surface1">
</g>
</svg>

After

Width:  |  Height:  |  Size: 222 B

View File

@ -3,6 +3,7 @@ import numpy as np
from collections import deque
from sklearn.tree import _tree as ctree
class AABB:
"""Axis-aligned bounding box"""
@ -19,6 +20,7 @@ class AABB:
return left, right
def tree_bounds(tree, n_features=None):
"""Compute final decision rule for each node in tree"""
if n_features is None:
@ -34,6 +36,7 @@ def tree_bounds(tree, n_features=None):
queue.extend([l, r])
return aabbs
def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
"""Extract decision areas.
@ -70,6 +73,7 @@ def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None)
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
return rectangles
def plot_areas(rectangles):
for rect in rectangles:
color = ["b", "r"][int(rect[4])]
@ -83,6 +87,7 @@ def plot_areas(rectangles):
)
plt.gca().add_artist(rp)
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
# get all polygons of each color
polygon_dict = {
@ -156,6 +161,7 @@ def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
return_polygons.append(polygon)
return return_polygons
class IrisDatasetPlot(VGroup):
def __init__(self, iris):
points = iris.data[:, 0:2]
@ -269,7 +275,6 @@ class IrisDatasetPlot(VGroup):
class DecisionTreeSurface(VGroup):
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
# take the tree and construct the surface from it
self.tree_clf = tree_clf
@ -346,8 +351,8 @@ class DecisionTreeSurface(VGroup):
def make_split_to_animation_map(self):
"""
Returns a dictionary mapping a given split
node to an animation to be played
Returns a dictionary mapping a given split
node to an animation to be played
"""
# Create an initial decision tree surface
# Go through each split node

View File

@ -26,6 +26,7 @@ def compute_node_depths(tree):
return node_depths
def compute_level_order_traversal(tree):
"""Computes level order traversal of a sklearn tree"""
@ -56,6 +57,7 @@ def compute_level_order_traversal(tree):
return sorted_inds
def compute_bfs_traversal(tree):
"""Traverses the tree in BFS order and returns the nodes in order"""
traversal_order = []
@ -73,10 +75,12 @@ def compute_bfs_traversal(tree):
return traversal_order
def compute_best_first_traversal(tree):
"""Traverses the tree according to the best split first order"""
pass
def compute_node_to_parent_mapping(tree):
"""Returns a hashmap mapping node indices to their parent indices"""
node_to_parent = {0: -1} # Root has no parent

View File

@ -9,6 +9,7 @@ from tqdm import tqdm
from manim_ml.probability import GaussianDistribution
def gaussian_proposal(x, sigma=0.2):
"""
Gaussian proposal distribution.
@ -39,7 +40,8 @@ def gaussian_proposal(x, sigma=0.2):
return (x_star, qxx)
class MultidimensionalGaussianPosterior():
class MultidimensionalGaussianPosterior:
"""
N-Dimensional Gaussian distribution with
@ -49,8 +51,7 @@ class MultidimensionalGaussianPosterior():
Prior on mean is U(-500, 500)
"""
def __init__(self, ndim=2, seed=12345, scale=3,
mu=None, var=None):
def __init__(self, ndim=2, seed=12345, scale=3, mu=None, var=None):
"""_summary_
Parameters
@ -71,10 +72,7 @@ class MultidimensionalGaussianPosterior():
self.var = var
if mu is None:
self.mu = scipy.stats.norm(
loc=0,
scale=self.scale
).rvs(ndim)
self.mu = scipy.stats.norm(loc=0, scale=self.scale).rvs(ndim)
else:
self.mu = mu
@ -84,20 +82,18 @@ class MultidimensionalGaussianPosterior():
"""
if np.all(x < 500) and np.all(x > -500):
return scipy.stats.multivariate_normal(
mean=self.mu,
cov=self.var
).logpdf(x)
return scipy.stats.multivariate_normal(mean=self.mu, cov=self.var).logpdf(x)
else:
return -1e6
def metropolis_hastings_sampler(
log_prob_fn=MultidimensionalGaussianPosterior(),
prop_fn=gaussian_proposal,
initial_location : np.ndarray = np.array([0, 0]),
log_prob_fn=MultidimensionalGaussianPosterior(),
prop_fn=gaussian_proposal,
initial_location: np.ndarray = np.array([0, 0]),
iterations=25,
warm_up=0,
ndim=2
ndim=2,
):
"""Samples using a Metropolis-Hastings sampler.
@ -158,17 +154,18 @@ def metropolis_hastings_sampler(
return chain, np.array([]), proposals
class MCMCAxes(Group):
"""Container object for visualizing MCMC on a 2D axis"""
def __init__(
self,
dot_color=BLUE,
self,
dot_color=BLUE,
dot_radius=0.05,
accept_line_color=GREEN,
reject_line_color=RED,
line_color=WHITE,
line_stroke_width=1
line_stroke_width=1,
):
super().__init__()
self.dot_color = dot_color
@ -176,7 +173,7 @@ class MCMCAxes(Group):
self.accept_line_color = accept_line_color
self.reject_line_color = reject_line_color
self.line_color = line_color
self.line_stroke_width=line_stroke_width
self.line_stroke_width = line_stroke_width
# Make the axes
self.axes = Axes(
x_range=[-3, 3],
@ -185,22 +182,16 @@ class MCMCAxes(Group):
y_length=12,
x_axis_config={"stroke_opacity": 0.0},
y_axis_config={"stroke_opacity": 0.0},
tips=False
tips=False,
)
self.add(self.axes)
@override_animation(Create)
def _create_override(self, **kwargs):
"""Overrides Create animation"""
return AnimationGroup(
Create(self.axes)
)
return AnimationGroup(Create(self.axes))
def visualize_gaussian_proposal_about_point(
self,
mean,
cov=None
) -> AnimationGroup:
def visualize_gaussian_proposal_about_point(self, mean, cov=None) -> AnimationGroup:
"""Creates a Gaussian distribution about a certain point
Parameters
@ -216,21 +207,14 @@ class MCMCAxes(Group):
animation of creating the proposal Gaussian distribution
"""
gaussian = GaussianDistribution(
axes=self.axes,
mean=mean,
cov=cov,
dist_theme="gaussian"
axes=self.axes, mean=mean, cov=cov, dist_theme="gaussian"
)
create_guassian = Create(gaussian)
return create_guassian
def make_transition_animation(
self,
start_point,
end_point,
candidate_point,
run_time=0.1
self, start_point, end_point, candidate_point, run_time=0.1
) -> AnimationGroup:
"""Makes an transition animation for a single point on a Markov Chain
@ -255,38 +239,27 @@ class MCMCAxes(Group):
if point_is_rejected:
return AnimationGroup()
else:
create_end_point = Create(
end_point
)
create_end_point = Create(end_point)
create_line = Create(
Line(
start_point,
end_point,
color=self.line_color,
stroke_width=self.line_stroke_width
stroke_width=self.line_stroke_width,
)
)
return AnimationGroup(
create_end_point,
create_line,
lag_ratio=1.0,
run_time=run_time
create_end_point, create_line, lag_ratio=1.0, run_time=run_time
)
def show_ground_truth_gaussian(self, distribution):
"""
"""
""" """
mean = distribution.mu
var = np.eye(2) * distribution.var
distribution_drawing = GaussianDistribution(
self.axes,
mean,
var,
dist_theme="gaussian"
self.axes, mean, var, dist_theme="gaussian"
).set_opacity(0.2)
return AnimationGroup(
Create(distribution_drawing)
)
return AnimationGroup(Create(distribution_drawing))
def visualize_metropolis_hastings_chain_sampling(
self,
@ -295,7 +268,7 @@ class MCMCAxes(Group):
sampling_kwargs={},
):
"""
Makes an animation for visualizing a 2D markov chain using
Makes an animation for visualizing a 2D markov chain using
metropolis hastings samplings
Parameters
@ -318,20 +291,18 @@ class MCMCAxes(Group):
"""
# Compute the chain samples using a Metropolis Hastings Sampler
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
log_prob_fn=log_prob_fn,
prop_fn=prop_fn,
**sampling_kwargs
log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs
)
print(f"MCMC samples: {mcmc_samples}")
print(f"Candidate samples: {candidate_samples}")
# Make the animation for visualizing the chain
# Make the animation for visualizing the chain
animations = []
# Place the initial point
current_point = mcmc_samples[0]
current_point = Dot(
self.axes.coords_to_point(current_point[0], current_point[1]),
color=self.dot_color,
radius=self.dot_radius
radius=self.dot_radius,
)
create_initial_point = Create(current_point)
animations.append(create_initial_point)
@ -346,12 +317,12 @@ class MCMCAxes(Group):
next_point = Dot(
self.axes.coords_to_point(next_sample[0], next_sample[1]),
color=self.dot_color,
radius=self.dot_radius
radius=self.dot_radius,
)
candidate_point = Dot(
self.axes.coords_to_point(candidate_sample[0], candidate_sample[1]),
color=self.dot_color,
radius=self.dot_radius
radius=self.dot_radius,
)
# Make a transition animation
transition_animation = self.make_transition_animation(
@ -361,9 +332,6 @@ class MCMCAxes(Group):
# Setup for next iteration
current_point = next_point
# Make the final animation group
animation_group = AnimationGroup(
*animations,
lag_ratio=1.0
)
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
return animation_group

View File

@ -1,8 +1,7 @@
from manim_ml.neural_network.activation_functions.relu import ReLUFunction
name_to_activation_function_map = {
"ReLU": ReLUFunction()
}
name_to_activation_function_map = {"ReLU": ReLUFunction}
def get_activation_function_by_name(name):
return name_to_activation_function_map[name]
return name_to_activation_function_map[name]

View File

@ -4,12 +4,22 @@ import random
import manim_ml.neural_network.activation_functions.relu as relu
class ActivationFunction(ABC, VGroup):
"""Abstract parent class for defining activation functions"""
def __init__(self, function_name=None, x_range=[-1, 1], y_range=[-1, 1],
x_length=0.5, y_length=0.3, show_function_name=True, active_color=ORANGE,
plot_color=BLUE, rectangle_color=WHITE):
def __init__(
self,
function_name=None,
x_range=[-1, 1],
y_range=[-1, 1],
x_length=0.5,
y_length=0.3,
show_function_name=True,
active_color=ORANGE,
plot_color=BLUE,
rectangle_color=WHITE,
):
super(VGroup, self).__init__()
self.function_name = function_name
self.x_range = x_range
@ -25,7 +35,7 @@ class ActivationFunction(ABC, VGroup):
def construct_activation_function(self):
"""Makes the activation function"""
# Make an axis
# Make an axis
self.axes = Axes(
x_range=self.x_range,
y_range=self.y_range,
@ -35,17 +45,17 @@ class ActivationFunction(ABC, VGroup):
axis_config={
"include_numbers": False,
"stroke_width": 0.5,
"include_ticks": False
}
"include_ticks": False,
},
)
self.add(self.axes)
# Surround the axis with a rounded rectangle.
# Surround the axis with a rounded rectangle.
self.surrounding_rectangle = SurroundingRectangle(
self.axes,
corner_radius=0.05,
buff=0.05,
stroke_width=2.0,
stroke_color=self.rectangle_color
stroke_color=self.rectangle_color,
)
self.add(self.surrounding_rectangle)
# Plot function on axis by applying it and showing in given range
@ -53,17 +63,15 @@ class ActivationFunction(ABC, VGroup):
lambda x: self.apply_function(x),
x_range=self.x_range,
stroke_color=self.plot_color,
stroke_width=2.0
stroke_width=2.0,
)
self.add(self.graph)
# Add the function name
if self.show_function_name:
function_name_text = Text(
self.function_name,
font_size=12,
font="sans-serif"
self.function_name, font_size=12, font="sans-serif"
)
function_name_text.next_to(self.axes, UP*0.5)
function_name_text.next_to(self.axes, UP * 0.5)
self.add(function_name_text)
@abstractmethod
@ -78,29 +86,21 @@ class ActivationFunction(ABC, VGroup):
# TODO: Evaluate the function at the x_val and show a highlighted dot
animation_group = Succession(
AnimationGroup(
ApplyMethod(self.graph.set_color, self.active_color),
ApplyMethod(
self.graph.set_color,
self.active_color
self.surrounding_rectangle.set_stroke_color, self.active_color
),
ApplyMethod(
self.surrounding_rectangle.set_stroke_color,
self.active_color
),
lag_ratio=0.0
lag_ratio=0.0,
),
Wait(1),
AnimationGroup(
ApplyMethod(self.graph.set_color, self.plot_color),
ApplyMethod(
self.graph.set_color,
self.plot_color
self.surrounding_rectangle.set_stroke_color, self.rectangle_color
),
ApplyMethod(
self.surrounding_rectangle.set_stroke_color,
self.rectangle_color
),
lag_ratio=0.0
lag_ratio=0.0,
),
lag_ratio=1.0
lag_ratio=1.0,
)
return animation_group
return animation_group

View File

@ -1,6 +1,9 @@
from manim import *
from manim_ml.neural_network.activation_functions.activation_function import ActivationFunction
from manim_ml.neural_network.activation_functions.activation_function import (
ActivationFunction,
)
class ReLUFunction(ActivationFunction):
"""Rectified Linear Unit Activation Function"""
@ -12,4 +15,4 @@ class ReLUFunction(ActivationFunction):
if x_val < 0:
return 0
else:
return x_val
return x_val

View File

@ -1,11 +1,15 @@
from manim_ml.neural_network.layers.convolutional_2d_to_feed_forward import (
Convolutional2DToFeedForward,
)
from manim_ml.neural_network.layers.convolutional_2d_to_max_pooling_2d import Convolutional2DToMaxPooling2D
from manim_ml.neural_network.layers.convolutional_2d_to_max_pooling_2d import (
Convolutional2DToMaxPooling2D,
)
from manim_ml.neural_network.layers.image_to_convolutional_2d import (
ImageToConvolutional2DLayer,
)
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import MaxPooling2DToConvolutional2D
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import (
MaxPooling2DToConvolutional2D,
)
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
from .convolutional_2d import Convolutional2DLayer
from .feed_forward_to_vector import FeedForwardToVector

View File

@ -1,6 +1,8 @@
from typing import Union
from manim_ml.neural_network.activation_functions import get_activation_function_by_name
from manim_ml.neural_network.activation_functions.activation_function import ActivationFunction
from manim_ml.neural_network.activation_functions.activation_function import (
ActivationFunction,
)
import numpy as np
from manim import *
@ -10,6 +12,7 @@ from manim_ml.neural_network.layers.parent_layers import (
)
from manim_ml.gridded_rectangle import GriddedRectangle
class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
"""Handles rendering a convolutional layer for a nn"""
@ -33,11 +36,11 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
self.num_feature_maps = num_feature_maps
self.filter_color = filter_color
if isinstance(feature_map_size, int):
self.feature_map_size = (feature_map_size, feature_map_size)
self.feature_map_size = (feature_map_size, feature_map_size)
else:
self.feature_map_size = feature_map_size
if isinstance(filter_size, int):
self.filter_size = (filter_size, filter_size)
self.filter_size = (filter_size, filter_size)
else:
self.filter_size = filter_size
self.cell_width = cell_width
@ -50,10 +53,10 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
self.activation_function = activation_function
def construct_layer(
self,
input_layer: 'NeuralNetworkLayer',
output_layer: 'NeuralNetworkLayer',
**kwargs
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs,
):
# Make the feature maps
self.feature_maps = self.construct_feature_maps()
@ -71,7 +74,7 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
if isinstance(self.activation_function, str):
activation_function = get_activation_function_by_name(
self.activation_function
)
)()
else:
assert isinstance(self.activation_function, ActivationFunction)
activation_function = self.activation_function
@ -108,14 +111,8 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
def highlight_and_unhighlight_feature_maps(self):
"""Highlights then unhighlights feature maps"""
return Succession(
ApplyMethod(
self.feature_maps.set_color,
self.pulse_color
),
ApplyMethod(
self.feature_maps.set_color,
self.color
)
ApplyMethod(self.feature_maps.set_color, self.pulse_color),
ApplyMethod(self.feature_maps.set_color, self.color),
)
def make_forward_pass_animation(
@ -146,7 +143,7 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
animation_group = AnimationGroup(
self.activation_function.make_evaluate_animation(),
self.highlight_and_unhighlight_feature_maps(),
lag_ratio=0.0
lag_ratio=0.0,
)
else:
animation_group = AnimationGroup()
@ -160,12 +157,19 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
def get_center(self):
"""Overrides function for getting center
The reason for this is so that the center calculation
does not include the activation function.
The reason for this is so that the center calculation
does not include the activation function.
"""
print("Getting center")
return self.feature_maps.get_center()
def get_width(self):
"""Overrides get width function"""
return self.feature_maps.length_over_dim(0)
def get_height(self):
"""Overrides get height function"""
return self.feature_maps.length_over_dim(1)
@override_animation(Create)
def _create_override(self, **kwargs):
return FadeIn(self.feature_maps)

View File

@ -7,16 +7,14 @@ from manim_ml.gridded_rectangle import GriddedRectangle
from manim.utils.space_ops import rotation_matrix
def get_rotated_shift_vectors(input_layer, normalized=False):
"""Rotates the shift vectors"""
# Make base shift vectors
right_shift = np.array([input_layer.cell_width, 0, 0])
down_shift = np.array([0, -input_layer.cell_width, 0])
# Make rotation matrix
rot_mat = rotation_matrix(
ThreeDLayer.rotation_angle,
ThreeDLayer.rotation_axis
)
rot_mat = rotation_matrix(ThreeDLayer.rotation_angle, ThreeDLayer.rotation_axis)
# Rotate the vectors
right_shift = np.dot(right_shift, rot_mat.T)
down_shift = np.dot(down_shift, rot_mat.T)
@ -27,6 +25,7 @@ def get_rotated_shift_vectors(input_layer, normalized=False):
return right_shift, down_shift
class Filters(VGroup):
"""Group for showing a collection of filters connecting two layers"""
@ -61,8 +60,12 @@ class Filters(VGroup):
def make_input_feature_map_rectangles(self):
rectangles = []
rectangle_width = self.output_layer.filter_size[0] * self.output_layer.cell_width
rectangle_height = self.output_layer.filter_size[1] * self.output_layer.cell_width
rectangle_width = (
self.output_layer.filter_size[0] * self.output_layer.cell_width
)
rectangle_height = (
self.output_layer.filter_size[1] * self.output_layer.cell_width
)
filter_color = self.output_layer.filter_color
for index, feature_map in enumerate(self.input_layer.feature_maps):
@ -263,8 +266,10 @@ class Filters(VGroup):
return passing_flash
class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
"""Feed Forward to Embedding Layer"""
input_class = Convolutional2DLayer
output_class = Convolutional2DLayer
@ -301,7 +306,12 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
self.show_grid_lines = show_grid_lines
self.highlight_color = highlight_color
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs,
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def animate_filters_all_at_once(self, filters):
@ -321,8 +331,12 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
right_shift, down_shift = get_rotated_shift_vectors(self.input_layer)
left_shift = -1 * right_shift
# Make the animation
num_y_moves = int((self.feature_map_size[1] - self.filter_size[1]) / self.stride)
num_x_moves = int((self.feature_map_size[0] - self.filter_size[0]) / self.stride)
num_y_moves = int(
(self.feature_map_size[1] - self.filter_size[1]) / self.stride
)
num_x_moves = int(
(self.feature_map_size[0] - self.filter_size[0]) / self.stride
)
for y_move in range(num_y_moves):
# Go right num_x_moves
for x_move in range(num_x_moves):
@ -347,10 +361,7 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
animations.append(FadeOut(filters))
return Succession(*animations, lag_ratio=1.0)
def animate_filters_one_at_a_time(
self,
highlight_active_feature_map=True
):
def animate_filters_one_at_a_time(self, highlight_active_feature_map=True):
"""Animates each of the filters one at a time"""
animations = []
output_feature_maps = self.output_layer.feature_maps
@ -418,18 +429,12 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
self.stride * num_x_moves * left_shift + self.stride * down_shift
)
# Make the animation
shift_animation = ApplyMethod(
filters.shift,
shift_amount
)
shift_animation = ApplyMethod(filters.shift, shift_amount)
animations.append(shift_animation)
# Do last row move right
for x_move in range(num_x_moves):
# Shift right
shift_animation = ApplyMethod(
filters.shift,
self.stride * right_shift
)
shift_animation = ApplyMethod(filters.shift, self.stride * right_shift)
# shift_animation = self.animate.shift(right_shift)
animations.append(shift_animation)
# Remove the filters
@ -440,18 +445,14 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
# Change the output feature map colors
change_color_animations = []
change_color_animations.append(
ApplyMethod(
feature_map.set_color,
original_feature_map_color
)
ApplyMethod(feature_map.set_color, original_feature_map_color)
)
# Change the input feature map colors
input_feature_maps = self.input_layer.feature_maps
for input_feature_map in input_feature_maps:
change_color_animations.append(
ApplyMethod(
input_feature_map.set_color,
original_feature_map_color
input_feature_map.set_color, original_feature_map_color
)
)
# Combine the animations

View File

@ -17,14 +17,15 @@ class Convolutional2DToFeedForward(ConnectiveLayer, ThreeDLayer):
passing_flash_color=ORANGE,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.passing_flash_color = passing_flash_color
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):

View File

@ -1,15 +1,19 @@
import random
from manim import *
from manim_ml.gridded_rectangle import GriddedRectangle
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import get_rotated_shift_vectors
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import (
get_rotated_shift_vectors,
)
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
"""Feed Forward to Embedding Layer"""
input_class = Convolutional2DLayer
output_class = MaxPooling2DLayer
@ -20,27 +24,18 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
active_color=ORANGE,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.active_color = active_color
def construct_layer(
self,
input_layer: 'NeuralNetworkLayer',
output_layer: 'NeuralNetworkLayer',
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(
self,
layer_args={},
run_time=1.5,
**kwargs
):
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
"""Forward pass animation from conv2d to max pooling"""
cell_width = self.input_layer.cell_width
feature_map_size = self.input_layer.feature_map_size
@ -61,7 +56,7 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
remove_gridded_rectangle_animations = []
for feature_map_index, feature_map in enumerate(feature_maps):
# 1. Draw gridded rectangle with kernel_size x kernel_size
# 1. Draw gridded rectangle with kernel_size x kernel_size
# box regions over the input feature maps.
gridded_rectangle = GriddedRectangle(
color=self.active_color,
@ -71,7 +66,7 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
grid_ystep=cell_width * kernel_size,
grid_stroke_width=grid_stroke_width,
grid_stroke_color=self.active_color,
show_grid_lines=True
show_grid_lines=True,
)
# 2. Randomly highlight one of the cells in the kernel.
highlighted_cells = []
@ -88,39 +83,32 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
height=cell_width,
width=cell_width,
stroke_width=0.0,
fill_opacity=0.7
fill_opacity=0.7,
)
# Move to the correct location
kernel_shift_vector = [
kernel_size * cell_width * kernel_x,
-1 * kernel_size * cell_width * kernel_y,
0
0,
]
cell_shift_vector = [
(cell_index % kernel_size) * cell_width,
-1 * int(cell_index / kernel_size) * cell_width,
0
0,
]
cell_rectangle.next_to(
gridded_rectangle.get_corners_dict()["top_left"],
submobject_to_align=cell_rectangle.get_corners_dict()["top_left"],
buff=0.0
submobject_to_align=cell_rectangle.get_corners_dict()[
"top_left"
],
buff=0.0,
)
cell_rectangle.shift(
kernel_shift_vector
)
cell_rectangle.shift(
cell_shift_vector
)
highlighted_cells.append(
cell_rectangle
)
# Rotate the gridded rectangles so they match the angle
cell_rectangle.shift(kernel_shift_vector)
cell_rectangle.shift(cell_shift_vector)
highlighted_cells.append(cell_rectangle)
# Rotate the gridded rectangles so they match the angle
# of the conv maps
gridded_rectangle_group = VGroup(
gridded_rectangle,
*highlighted_cells
)
gridded_rectangle_group = VGroup(gridded_rectangle, *highlighted_cells)
gridded_rectangle_group.rotate(
ThreeDLayer.rotation_angle,
about_point=gridded_rectangle.get_center(),
@ -129,9 +117,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
gridded_rectangle.next_to(
feature_map.get_corners_dict()["top_left"],
submobject_to_align=gridded_rectangle.get_corners_dict()["top_left"],
buff=0.0
buff=0.0,
)
# 3. Make a create gridded rectangle
# 3. Make a create gridded rectangle
"""
create_rectangle = Create(
gridded_rectangle
@ -185,31 +173,24 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
create_and_remove_cell_animations = Succession(
Create(VGroup(*highlighted_cells)),
Wait(0.5),
Uncreate(VGroup(*highlighted_cells))
Uncreate(VGroup(*highlighted_cells)),
)
return create_and_remove_cell_animations
# 5. Move and resize the gridded rectangle to the output
# feature maps.
# 5. Move and resize the gridded rectangle to the output
# feature maps.
resize_rectangle = Transform(
gridded_rectangle,
self.output_layer.feature_maps[feature_map_index]
gridded_rectangle, self.output_layer.feature_maps[feature_map_index]
)
move_rectangle = gridded_rectangle.animate.move_to(
self.output_layer.feature_maps[feature_map_index]
)
move_and_resize = Succession(
resize_rectangle,
move_rectangle,
lag_ratio=0.0
resize_rectangle, move_rectangle, lag_ratio=0.0
)
move_and_resize_gridded_rectangle_animations.append(
move_and_resize
)
# 6. Make the gridded feature map(s) disappear.
move_and_resize_gridded_rectangle_animations.append(move_and_resize)
# 6. Make the gridded feature map(s) disappear.
remove_gridded_rectangle_animations.append(
Uncreate(
gridded_rectangle_group
)
Uncreate(gridded_rectangle_group)
)
"""
@ -224,5 +205,5 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
# *remove_gridded_rectangle_animations
# ),
# lag_ratio=1.0
lag_ratio=1.0
)
lag_ratio=1.0,
)

View File

@ -25,9 +25,9 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
self.paired_query_mode = paired_query_mode
def construct_layer(
self,
input_layer: 'NeuralNetworkLayer',
output_layer: 'NeuralNetworkLayer',
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
self.axes = Axes(
@ -43,14 +43,13 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
self.axes.move_to(self.get_center())
# Make point cloud
self.point_cloud = self.construct_gaussian_point_cloud(
self.mean,
self.covariance
self.mean, self.covariance
)
self.add(self.point_cloud)
# Make latent distribution
self.latent_distribution = GaussianDistribution(
self.axes, mean=self.mean, cov=self.covariance
) # Use defaults
) # Use defaults
def add_gaussian_distribution(self, gaussian_distribution):
"""Adds given GaussianDistribution to the list"""

View File

@ -3,6 +3,7 @@ from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
class EmbeddingToFeedForward(ConnectiveLayer):
"""Feed Forward to Embedding Layer"""
@ -17,17 +18,18 @@ class EmbeddingToFeedForward(ConnectiveLayer):
dot_radius=0.03,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.feed_forward_layer = output_layer
self.embedding_layer = input_layer
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):

View File

@ -35,7 +35,12 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
self.node_group = VGroup()
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
"""Creates the neural network layer"""
# Add Nodes
for node_number in range(self.num_nodes):

View File

@ -18,17 +18,18 @@ class FeedForwardToEmbedding(ConnectiveLayer):
dot_radius=0.03,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.feed_forward_layer = input_layer
self.embedding_layer = output_layer
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):

View File

@ -5,6 +5,7 @@ from manim import *
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
class FeedForwardToFeedForward(ConnectiveLayer):
"""Layer for connecting FeedForward layer to FeedForwardLayer"""
@ -23,18 +24,19 @@ class FeedForwardToFeedForward(ConnectiveLayer):
camera=None,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.passing_flash = passing_flash
self.edge_color = edge_color
self.dot_radius = dot_radius
self.animation_dot_color = animation_dot_color
self.edge_width = edge_width
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
self.edges = self.construct_edges()
self.add(self.edges)

View File

@ -18,18 +18,19 @@ class FeedForwardToImage(ConnectiveLayer):
dot_radius=0.05,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
self.feed_forward_layer = input_layer
self.image_layer = output_layer
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, **kwargs):

View File

@ -18,18 +18,19 @@ class FeedForwardToVector(ConnectiveLayer):
dot_radius=0.05,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
self.feed_forward_layer = input_layer
self.vector_layer = output_layer
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, **kwargs):

View File

@ -5,6 +5,7 @@ from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
from PIL import Image
class ImageLayer(NeuralNetworkLayer):
"""Single Image Layer for Neural Network"""
@ -19,23 +20,22 @@ class ImageLayer(NeuralNetworkLayer):
Parameters
----------
input_layer :
input_layer :
Input layer
output_layer :
output_layer :
Output layer
"""
if len(np.shape(self.numpy_image)) == 2:
# Assumed Grayscale
self.num_channels = 1
self.image_mobject = GrayscaleImageMobject(
self.numpy_image,
height=self.image_height
self.numpy_image, height=self.image_height
)
elif len(np.shape(self.numpy_image)) == 3:
# Assumed RGB
self.num_channels = 3
self.image_mobject = ImageMobject(self.numpy_image).scale_to_fit_height(
height
self.image_height
)
self.add(self.image_mobject)
@ -63,10 +63,6 @@ class ImageLayer(NeuralNetworkLayer):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
return AnimationGroup()
# def move_to(self, location):
# """Override of move to"""
# self.image_mobject.move_to(location)
def get_right(self):
"""Override get right"""
return self.image_mobject.get_right()

View File

@ -9,6 +9,7 @@ from manim_ml.neural_network.layers.parent_layers import (
)
from manim_ml.gridded_rectangle import GriddedRectangle
class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
"""Handles rendering a convolutional layer for a nn"""
@ -16,16 +17,18 @@ class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
output_class = Convolutional2DLayer
def __init__(
self,
input_layer: ImageLayer,
output_layer: Convolutional2DLayer,
**kwargs
self, input_layer: ImageLayer, output_layer: Convolutional2DLayer, **kwargs
):
super().__init__(input_layer, output_layer, **kwargs)
self.input_layer = input_layer
self.output_layer = output_layer
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs,
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, run_time=5, layer_args={}, **kwargs):

View File

@ -18,18 +18,19 @@ class ImageToFeedForward(ConnectiveLayer):
dot_radius=0.05,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
self.feed_forward_layer = output_layer
self.image_layer = input_layer
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, **kwargs):

View File

@ -1,27 +1,31 @@
from manim import *
from manim_ml.gridded_rectangle import GriddedRectangle
from manim_ml.neural_network.layers.parent_layers import ThreeDLayer, VGroupNeuralNetworkLayer
from manim_ml.neural_network.layers.parent_layers import (
ThreeDLayer,
VGroupNeuralNetworkLayer,
)
class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
"""Max pooling layer for Convolutional2DLayer
Note: This is for a Convolutional2DLayer even though
it is called MaxPooling2DLayer because the 2D corresponds
to the 2 spatial dimensions of the convolution.
to the 2 spatial dimensions of the convolution.
"""
def __init__(
self,
kernel_size=2,
stride=1,
cell_highlight_color=ORANGE,
cell_width=0.2,
filter_spacing=0.1,
color=BLUE,
show_grid_lines=False,
stroke_width=2.0,
**kwargs
self,
kernel_size=2,
stride=1,
cell_highlight_color=ORANGE,
cell_width=0.2,
filter_spacing=0.1,
color=BLUE,
show_grid_lines=False,
stroke_width=2.0,
**kwargs
):
"""Layer object for animating 2D Convolution Max Pooling
@ -42,11 +46,15 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
self.show_grid_lines = show_grid_lines
self.stroke_width = stroke_width
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
# Make the output feature maps
self.feature_maps = self._make_output_feature_maps(
input_layer.num_feature_maps,
input_layer.feature_map_size
input_layer.num_feature_maps, input_layer.feature_map_size
)
self.add(self.feature_maps)
self.rotate(
@ -58,17 +66,13 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
input_layer.feature_map_size[0] / self.kernel_size,
input_layer.feature_map_size[1] / self.kernel_size,
)
def _make_output_feature_maps(
self,
num_input_feature_maps,
input_feature_map_size
):
def _make_output_feature_maps(self, num_input_feature_maps, input_feature_map_size):
"""Makes a set of output feature maps"""
# Compute the size of the feature maps
# Compute the size of the feature maps
output_feature_map_size = (
input_feature_map_size[0] / self.kernel_size,
input_feature_map_size[1] / self.kernel_size
input_feature_map_size[1] / self.kernel_size,
)
# Draw rectangles that are filled in with opacity
feature_maps = []
@ -92,12 +96,10 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
# rectangle.set_z_index(4)
feature_maps.append(rectangle)
return VGroup(
*feature_maps
)
return VGroup(*feature_maps)
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Makes forward pass of Max Pooling Layer.
"""Makes forward pass of Max Pooling Layer.
Parameters
----------

View File

@ -1,7 +1,10 @@
import numpy as np
from manim import *
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D, Filters
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import (
Convolutional2DToConvolutional2D,
Filters,
)
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
@ -9,8 +12,10 @@ from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
from manim.utils.space_ops import rotation_matrix
class MaxPooling2DToConvolutional2D(Convolutional2DToConvolutional2D):
"""Feed Forward to Embedding Layer"""
input_class = MaxPooling2DLayer
output_class = Convolutional2DLayer
@ -25,20 +30,16 @@ class MaxPooling2DToConvolutional2D(Convolutional2DToConvolutional2D):
**kwargs
):
input_layer.num_feature_maps = output_layer.num_feature_maps
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.passing_flash_color = passing_flash_color
self.cell_width = cell_width
self.stroke_width = stroke_width
self.show_grid_lines = show_grid_lines
def construct_layer(
self,
input_layer: 'NeuralNetworkLayer',
output_layer: 'NeuralNetworkLayer',
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
"""Constructs the MaxPooling to Convolution3D layer

View File

@ -22,7 +22,12 @@ class PairedQueryLayer(NeuralNetworkLayer):
self.add(self.assets)
self.add(self.title)
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
@classmethod

View File

@ -18,18 +18,19 @@ class PairedQueryToFeedForward(ConnectiveLayer):
dot_radius=0.02,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
self.paired_query_layer = input_layer
self.feed_forward_layer = output_layer
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, **kwargs):

View File

@ -1,6 +1,7 @@
from manim import *
from abc import ABC, abstractmethod
class NeuralNetworkLayer(ABC, Group):
"""Abstract Neural Network Layer class"""
@ -12,8 +13,12 @@ class NeuralNetworkLayer(ABC, Group):
# self.add(self.title)
@abstractmethod
def construct_layer(self, input_layer: 'NeuralNetworkLayer',
output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs,
):
"""Constructs the layer at network construction time
Parameters
@ -36,6 +41,7 @@ class NeuralNetworkLayer(ABC, Group):
def __repr__(self):
return f"{type(self).__name__}"
class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -49,6 +55,7 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
def _create_override(self):
return super()._create_override()
class ThreeDLayer(ABC):
"""Abstract class for 3D layers"""
@ -58,6 +65,7 @@ class ThreeDLayer(ABC):
rotation_angle = 60 * DEGREES
rotation_axis = [0.0, 0.9, 0.0]
class ConnectiveLayer(VGroupNeuralNetworkLayer):
"""Forward pass animation for a given pair of layers"""
@ -86,6 +94,7 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
+ ")"
)
class BlankConnective(ConnectiveLayer):
"""Connective layer to be used when the given pair of layers is undefined"""
@ -97,4 +106,4 @@ class BlankConnective(ConnectiveLayer):
@override_animation(Create)
def _create_override(self):
return super()._create_override()
return super()._create_override()

View File

@ -26,7 +26,12 @@ class TripletLayer(NeuralNetworkLayer):
self.stroke_width = stroke_width
self.font_size = font_size
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
# Make the assets
self.assets = self.make_assets()
self.add(self.assets)

View File

@ -18,18 +18,19 @@ class TripletToFeedForward(ConnectiveLayer):
dot_radius=0.02,
**kwargs
):
super().__init__(
input_layer,
output_layer,
**kwargs
)
super().__init__(input_layer, output_layer, **kwargs)
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
self.feed_forward_layer = output_layer
self.triplet_layer = input_layer
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, **kwargs):

View File

@ -4,9 +4,10 @@ from manim import *
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
from manim_ml.neural_network.layers import connective_layers_list
def get_connective_layer(input_layer, output_layer):
"""
Deduces the relevant connective layer
Deduces the relevant connective layer
"""
connective_layer_class = None
for candidate_class in connective_layers_list:

View File

@ -12,7 +12,12 @@ class VectorLayer(VGroupNeuralNetworkLayer):
self.num_values = num_values
self.value_func = value_func
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs,
):
# Make the vector
self.vector_label = self.make_vector()
self.add(self.vector_label)

View File

@ -22,6 +22,7 @@ from manim_ml.neural_network.neural_network_transformations import (
RemoveLayer,
)
class NeuralNetwork(Group):
"""Neural Network Visualization Container Class"""
@ -53,17 +54,11 @@ class NeuralNetwork(Group):
# Construct all of the layers
self._construct_input_layers()
# Place the layers
self._place_layers(
layout=layout,
layout_direction=layout_direction
)
self._place_layers(layout=layout, layout_direction=layout_direction)
# Make the connective layers
self.connective_layers, self.all_layers = self._construct_connective_layers()
# Make overhead title
self.title = Text(
self.title_text,
font_size=DEFAULT_FONT_SIZE / 2
)
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2)
self.title.next_to(self, UP, 1.0)
self.add(self.title)
# Place layers at correct z index
@ -76,7 +71,7 @@ class NeuralNetwork(Group):
print(repr(self))
def _construct_input_layers(self):
"""Constructs each of the input layers in context
"""Constructs each of the input layers in context
of their adjacent layers"""
prev_layer = None
next_layer = None
@ -105,64 +100,82 @@ class NeuralNetwork(Group):
previous_layer, EmbeddingLayer
):
if layout_direction == "left_to_right":
shift_vector = np.array([
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
0,
])
shift_vector = np.array(
[
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
0,
]
)
elif layout_direction == "top_to_bottom":
shift_vector = np.array([
0,
-(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
])
shift_vector = np.array(
[
0,
-(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
- 0.2
),
0,
]
)
else:
raise Exception(
f"Unrecognized layout direction: {layout_direction}"
)
else:
if layout_direction == "left_to_right":
shift_vector = np.array([
(
shift_vector = np.array(
[
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing,
0,
0,
])
+ self.layer_spacing,
0,
0,
]
)
elif layout_direction == "top_to_bottom":
shift_vector = np.array([
0,
-(
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing
),
0,
])
shift_vector = np.array(
[
0,
-(
(
previous_layer.get_width() / 2
+ current_layer.get_width() / 2
)
+ self.layer_spacing
),
0,
]
)
else:
raise Exception(
f"Unrecognized layout direction: {layout_direction}"
)
current_layer.shift(shift_vector)
# After all layers have been placed place their activation functions
for current_layer in self.input_layers:
# Place activation function
if hasattr(current_layer, "activation_function"):
if not current_layer.activation_function is None:
current_layer.activation_function.next_to(
current_layer,
direction=UP
up_movement = np.array(
[
0,
current_layer.get_height() / 2
+ current_layer.activation_function.get_height() / 2
+ 0.5 * self.layer_spacing,
0,
]
)
current_layer.activation_function.move_to(
current_layer,
)
current_layer.activation_function.shift(up_movement)
self.add(current_layer.activation_function)
def _construct_connective_layers(self):
@ -228,8 +241,8 @@ class NeuralNetwork(Group):
# Get the layer args
if isinstance(layer, ConnectiveLayer):
"""
NOTE: By default a connective layer will get the combined
layer_args of the layers it is connecting and itself.
NOTE: By default a connective layer will get the combined
layer_args of the layers it is connecting and itself.
"""
before_layer_args = {}
current_layer_args = {}
@ -252,16 +265,11 @@ class NeuralNetwork(Group):
current_layer_args = layer_args[layer]
# Perform the forward pass of the current layer
layer_forward_pass = layer.make_forward_pass_animation(
layer_args=current_layer_args,
run_time=per_layer_runtime,
**kwargs
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
)
all_animations.append(layer_forward_pass)
# Make the animation group
animation_group = Succession(
*all_animations,
lag_ratio=1.0
)
animation_group = Succession(*all_animations, lag_ratio=1.0)
return animation_group
@ -332,6 +340,7 @@ class NeuralNetwork(Group):
string_repr = "NeuralNetwork([\n" + inner_string + "])"
return string_repr
class FeedForwardNeuralNetwork(NeuralNetwork):
"""NeuralNetwork with just feed forward layers"""

View File

@ -2,6 +2,7 @@ from manim import *
import numpy as np
import math
class GaussianDistribution(VGroup):
"""Object for drawing a Gaussian distribution"""