mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-25 00:40:54 +08:00
Bug fixes and linting for the activation functions addition.
This commit is contained in:
@ -8,7 +8,11 @@ class NeuralNetworkScene(Scene):
|
||||
|
||||
def construct(self):
|
||||
# Make the Layer object
|
||||
layers = [FeedForwardLayer(3), FeedForwardLayer(5), FeedForwardLayer(3)]
|
||||
layers = [
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(5),
|
||||
FeedForwardLayer(3)
|
||||
]
|
||||
nn = NeuralNetwork(layers)
|
||||
nn.scale(2)
|
||||
nn.move_to(ORIGIN)
|
75
examples/cnn/activation_functions.py
Normal file
75
examples/cnn/activation_functions.py
Normal file
@ -0,0 +1,75 @@
|
||||
from pathlib import Path
|
||||
|
||||
from manim import *
|
||||
from PIL import Image
|
||||
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.image import ImageLayer
|
||||
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||
|
||||
# Make the specific scene
|
||||
config.pixel_height = 1200
|
||||
config.pixel_width = 1900
|
||||
config.frame_height = 7.0
|
||||
config.frame_width = 7.0
|
||||
ROOT_DIR = Path(__file__).parents[2]
|
||||
|
||||
def make_code_snippet():
|
||||
code_str = """
|
||||
# Make the neural network
|
||||
nn = NeuralNetwork([
|
||||
# ... Layers at start
|
||||
Convolutional2DLayer(3, 5, 3, activation_function="ReLU"),
|
||||
# ... Layers at end
|
||||
])
|
||||
# Play the animation
|
||||
self.play(nn.make_forward_pass_animation())
|
||||
"""
|
||||
|
||||
code = Code(
|
||||
code=code_str,
|
||||
tab_width=4,
|
||||
background_stroke_width=1,
|
||||
background_stroke_color=WHITE,
|
||||
insert_line_no=False,
|
||||
style="monokai",
|
||||
font="Monospace",
|
||||
background="window",
|
||||
language="py",
|
||||
)
|
||||
code.scale(0.45)
|
||||
|
||||
return code
|
||||
|
||||
class CombinedScene(ThreeDScene):
|
||||
def construct(self):
|
||||
image = Image.open(ROOT_DIR / "assets/mnist/digit.jpeg")
|
||||
numpy_image = np.asarray(image)
|
||||
# Make nn
|
||||
nn = NeuralNetwork(
|
||||
[
|
||||
ImageLayer(numpy_image, height=1.5),
|
||||
Convolutional2DLayer(1, 7),
|
||||
Convolutional2DLayer(3, 5, 3, activation_function="ReLU"),
|
||||
Convolutional2DLayer(5, 3, 3, activation_function="ReLU"),
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(1),
|
||||
],
|
||||
layer_spacing=0.25,
|
||||
)
|
||||
# nn.scale(0.7)
|
||||
# Center the nn
|
||||
nn.move_to(ORIGIN)
|
||||
self.add(nn)
|
||||
# Make code snippet
|
||||
code = make_code_snippet()
|
||||
code.next_to(nn, DOWN)
|
||||
self.add(code)
|
||||
nn.move_to(ORIGIN)
|
||||
# Move everything up
|
||||
Group(nn, code).move_to(ORIGIN)
|
||||
# Play animation
|
||||
forward_pass = nn.make_forward_pass_animation()
|
||||
self.wait(1)
|
||||
self.play(forward_pass)
|
@ -15,7 +15,6 @@ config.frame_height = 7.0
|
||||
config.frame_width = 7.0
|
||||
ROOT_DIR = Path(__file__).parents[2]
|
||||
|
||||
|
||||
def make_code_snippet():
|
||||
code_str = """
|
||||
# Make nn
|
||||
@ -45,7 +44,6 @@ def make_code_snippet():
|
||||
|
||||
return code
|
||||
|
||||
|
||||
class CombinedScene(ThreeDScene):
|
||||
def construct(self):
|
||||
image = Image.open(ROOT_DIR / "assets/mnist/digit.jpeg")
|
||||
|
90
examples/decision_tree/decision_tree_surface.py
Normal file
90
examples/decision_tree/decision_tree_surface.py
Normal file
@ -0,0 +1,90 @@
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.tree import _tree as ctree
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
|
||||
class AABB:
|
||||
"""Axis-aligned bounding box"""
|
||||
|
||||
def __init__(self, n_features):
|
||||
self.limits = np.array([[-np.inf, np.inf]] * n_features)
|
||||
|
||||
def split(self, f, v):
|
||||
left = AABB(self.limits.shape[0])
|
||||
right = AABB(self.limits.shape[0])
|
||||
left.limits = self.limits.copy()
|
||||
right.limits = self.limits.copy()
|
||||
left.limits[f, 1] = v
|
||||
right.limits[f, 0] = v
|
||||
|
||||
return left, right
|
||||
|
||||
|
||||
def tree_bounds(tree, n_features=None):
|
||||
"""Compute final decision rule for each node in tree"""
|
||||
if n_features is None:
|
||||
n_features = np.max(tree.feature) + 1
|
||||
aabbs = [AABB(n_features) for _ in range(tree.node_count)]
|
||||
queue = deque([0])
|
||||
while queue:
|
||||
i = queue.pop()
|
||||
l = tree.children_left[i]
|
||||
r = tree.children_right[i]
|
||||
if l != ctree.TREE_LEAF:
|
||||
aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
|
||||
queue.extend([l, r])
|
||||
return aabbs
|
||||
|
||||
|
||||
def decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
|
||||
"""Extract decision areas.
|
||||
|
||||
tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
|
||||
maxrange: values to insert for [left, right, top, bottom] if the interval is open (+/-inf)
|
||||
x: index of the feature that goes on the x axis
|
||||
y: index of the feature that goes on the y axis
|
||||
n_features: override autodetection of number of features
|
||||
"""
|
||||
tree = tree_classifier.tree_
|
||||
aabbs = tree_bounds(tree, n_features)
|
||||
maxrange = np.array(maxrange)
|
||||
rectangles = []
|
||||
for i in range(len(aabbs)):
|
||||
if tree.children_left[i] != ctree.TREE_LEAF:
|
||||
continue
|
||||
l = aabbs[i].limits
|
||||
r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
|
||||
# clip out of bounds indices
|
||||
"""
|
||||
if r[0] < maxrange[0]:
|
||||
r[0] = maxrange[0]
|
||||
if r[1] > maxrange[1]:
|
||||
r[1] = maxrange[1]
|
||||
if r[2] < maxrange[2]:
|
||||
r[2] = maxrange[2]
|
||||
if r[3] > maxrange[3]:
|
||||
r[3] = maxrange[3]
|
||||
print(r)
|
||||
"""
|
||||
rectangles.append(r)
|
||||
rectangles = np.array(rectangles)
|
||||
rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
|
||||
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
|
||||
return rectangles
|
||||
|
||||
|
||||
def plot_areas(rectangles):
|
||||
for rect in rectangles:
|
||||
color = ["b", "r"][int(rect[4])]
|
||||
print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
|
||||
rp = Rectangle(
|
||||
[rect[0], rect[2]],
|
||||
rect[1] - rect[0],
|
||||
rect[3] - rect[2],
|
||||
color=color,
|
||||
alpha=0.3,
|
||||
)
|
||||
plt.gca().add_artist(rp)
|
BIN
examples/decision_tree/iris_dataset/SetosaFlower.jpeg
Normal file
BIN
examples/decision_tree/iris_dataset/SetosaFlower.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 23 KiB |
BIN
examples/decision_tree/iris_dataset/VeriscolorFlower.jpeg
Normal file
BIN
examples/decision_tree/iris_dataset/VeriscolorFlower.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
BIN
examples/decision_tree/iris_dataset/VirginicaFlower.jpeg
Normal file
BIN
examples/decision_tree/iris_dataset/VirginicaFlower.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
690
examples/decision_tree/split_scene.py
Normal file
690
examples/decision_tree/split_scene.py
Normal file
@ -0,0 +1,690 @@
|
||||
from sklearn import datasets
|
||||
from decision_tree_surface import *
|
||||
from manim import *
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from scipy.stats import entropy
|
||||
import math
|
||||
from PIL import Image
|
||||
|
||||
iris = datasets.load_iris()
|
||||
font = "Source Han Sans"
|
||||
font_scale = 0.75
|
||||
|
||||
images = [
|
||||
Image.open("iris_dataset/SetosaFlower.jpeg"),
|
||||
Image.open("iris_dataset/VeriscolorFlower.jpeg"),
|
||||
Image.open("iris_dataset/VirginicaFlower.jpeg"),
|
||||
]
|
||||
|
||||
|
||||
def entropy(class_labels, base=2):
|
||||
# compute the class counts
|
||||
unique, counts = np.unique(class_labels, return_counts=True)
|
||||
dictionary = dict(zip(unique, counts))
|
||||
total = 0.0
|
||||
num_samples = len(class_labels)
|
||||
for class_index in range(0, 3):
|
||||
if not class_index in dictionary:
|
||||
continue
|
||||
prob = dictionary[class_index] / num_samples
|
||||
total += prob * math.log(prob, base)
|
||||
# higher set
|
||||
return -total
|
||||
|
||||
|
||||
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
|
||||
# get all polygons of each color
|
||||
polygon_dict = {
|
||||
str(BLUE).lower(): [],
|
||||
str(GREEN).lower(): [],
|
||||
str(ORANGE).lower(): [],
|
||||
}
|
||||
for polygon in all_polygons:
|
||||
print(polygon_dict)
|
||||
polygon_dict[str(polygon.color).lower()].append(polygon)
|
||||
|
||||
return_polygons = []
|
||||
for color in colors:
|
||||
color = str(color).lower()
|
||||
polygons = polygon_dict[color]
|
||||
points = set()
|
||||
for polygon in polygons:
|
||||
vertices = polygon.get_vertices().tolist()
|
||||
vertices = [tuple(vert) for vert in vertices]
|
||||
for pt in vertices:
|
||||
if pt in points: # Shared vertice, remove it.
|
||||
points.remove(pt)
|
||||
else:
|
||||
points.add(pt)
|
||||
points = list(points)
|
||||
sort_x = sorted(points)
|
||||
sort_y = sorted(points, key=lambda x: x[1])
|
||||
|
||||
edges_h = {}
|
||||
edges_v = {}
|
||||
|
||||
i = 0
|
||||
while i < len(points):
|
||||
curr_y = sort_y[i][1]
|
||||
while i < len(points) and sort_y[i][1] == curr_y:
|
||||
edges_h[sort_y[i]] = sort_y[i + 1]
|
||||
edges_h[sort_y[i + 1]] = sort_y[i]
|
||||
i += 2
|
||||
i = 0
|
||||
while i < len(points):
|
||||
curr_x = sort_x[i][0]
|
||||
while i < len(points) and sort_x[i][0] == curr_x:
|
||||
edges_v[sort_x[i]] = sort_x[i + 1]
|
||||
edges_v[sort_x[i + 1]] = sort_x[i]
|
||||
i += 2
|
||||
|
||||
# Get all the polygons.
|
||||
while edges_h:
|
||||
# We can start with any point.
|
||||
polygon = [(edges_h.popitem()[0], 0)]
|
||||
while True:
|
||||
curr, e = polygon[-1]
|
||||
if e == 0:
|
||||
next_vertex = edges_v.pop(curr)
|
||||
polygon.append((next_vertex, 1))
|
||||
else:
|
||||
next_vertex = edges_h.pop(curr)
|
||||
polygon.append((next_vertex, 0))
|
||||
if polygon[-1] == polygon[0]:
|
||||
# Closed polygon
|
||||
polygon.pop()
|
||||
break
|
||||
# Remove implementation-markers from the polygon.
|
||||
poly = [point for point, _ in polygon]
|
||||
for vertex in poly:
|
||||
if vertex in edges_h:
|
||||
edges_h.pop(vertex)
|
||||
if vertex in edges_v:
|
||||
edges_v.pop(vertex)
|
||||
polygon = Polygon(*poly, color=color, fill_opacity=0.3, stroke_opacity=1.0)
|
||||
return_polygons.append(polygon)
|
||||
return return_polygons
|
||||
|
||||
|
||||
class IrisDatasetPlot(VGroup):
|
||||
def __init__(self):
|
||||
points = iris.data[:, 0:2]
|
||||
labels = iris.feature_names
|
||||
targets = iris.target
|
||||
# Make points
|
||||
self.point_group = self._make_point_group(points, targets)
|
||||
# Make axes
|
||||
self.axes_group = self._make_axes_group(points, labels)
|
||||
# Make legend
|
||||
self.legend_group = self._make_legend(
|
||||
[BLUE, ORANGE, GREEN], iris.target_names, self.axes_group
|
||||
)
|
||||
# Make title
|
||||
# title_text = "Iris Dataset Plot"
|
||||
# self.title = Text(title_text).match_y(self.axes_group).shift([0.5, self.axes_group.height / 2 + 0.5, 0])
|
||||
# Make all group
|
||||
self.all_group = Group(self.point_group, self.axes_group, self.legend_group)
|
||||
# scale the groups
|
||||
self.point_group.scale(1.6)
|
||||
self.point_group.match_x(self.axes_group)
|
||||
self.point_group.match_y(self.axes_group)
|
||||
self.point_group.shift([0.2, 0, 0])
|
||||
self.axes_group.scale(0.7)
|
||||
self.all_group.shift([0, 0.2, 0])
|
||||
|
||||
@override_animation(Create)
|
||||
def create_animation(self):
|
||||
animation_group = AnimationGroup(
|
||||
# Perform the animations
|
||||
Create(self.point_group, run_time=2),
|
||||
Wait(0.5),
|
||||
Create(self.axes_group, run_time=2),
|
||||
# add title
|
||||
# Create(self.title),
|
||||
Create(self.legend_group),
|
||||
)
|
||||
return animation_group
|
||||
|
||||
def _make_point_group(self, points, targets, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
point_group = VGroup()
|
||||
for point_index, point in enumerate(points):
|
||||
# draw the dot
|
||||
current_target = targets[point_index]
|
||||
color = class_colors[current_target]
|
||||
dot = Dot(point=np.array([point[0], point[1], 0])).set_color(color)
|
||||
dot.scale(0.5)
|
||||
point_group.add(dot)
|
||||
return point_group
|
||||
|
||||
def _make_legend(self, class_colors, feature_labels, axes):
|
||||
legend_group = VGroup()
|
||||
# Make Text
|
||||
setosa = Text("Setosa", color=BLUE)
|
||||
verisicolor = Text("Verisicolor", color=ORANGE)
|
||||
virginica = Text("Virginica", color=GREEN)
|
||||
labels = VGroup(setosa, verisicolor, virginica).arrange(
|
||||
direction=RIGHT, aligned_edge=LEFT, buff=2.0
|
||||
)
|
||||
labels.scale(0.5)
|
||||
legend_group.add(labels)
|
||||
# surrounding rectangle
|
||||
surrounding_rectangle = SurroundingRectangle(labels, color=WHITE)
|
||||
surrounding_rectangle.move_to(labels)
|
||||
legend_group.add(surrounding_rectangle)
|
||||
# shift the legend group
|
||||
legend_group.move_to(axes)
|
||||
legend_group.shift([0, -3.0, 0])
|
||||
legend_group.match_x(axes[0][0])
|
||||
|
||||
return legend_group
|
||||
|
||||
def _make_axes_group(self, points, labels):
|
||||
axes_group = VGroup()
|
||||
# make the axes
|
||||
x_range = [
|
||||
np.amin(points, axis=0)[0] - 0.2,
|
||||
np.amax(points, axis=0)[0] - 0.2,
|
||||
0.5,
|
||||
]
|
||||
y_range = [np.amin(points, axis=0)[1] - 0.2, np.amax(points, axis=0)[1], 0.5]
|
||||
axes = Axes(
|
||||
x_range=x_range,
|
||||
y_range=y_range,
|
||||
x_length=9,
|
||||
y_length=6.5,
|
||||
# axis_config={"number_scale_value":0.75, "include_numbers":True},
|
||||
tips=False,
|
||||
).shift([0.5, 0.25, 0])
|
||||
axes_group.add(axes)
|
||||
# make axis labels
|
||||
# x_label
|
||||
x_label = (
|
||||
Text(labels[0], font=font)
|
||||
.match_y(axes.get_axes()[0])
|
||||
.shift([0.5, -0.75, 0])
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(x_label)
|
||||
# y_label
|
||||
y_label = (
|
||||
Text(labels[1], font=font)
|
||||
.match_x(axes.get_axes()[1])
|
||||
.shift([-0.75, 0, 0])
|
||||
.rotate(np.pi / 2)
|
||||
.scale(font_scale)
|
||||
)
|
||||
axes_group.add(y_label)
|
||||
|
||||
return axes_group
|
||||
|
||||
|
||||
class DecisionTreeSurface(VGroup):
|
||||
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
# take the tree and construct the surface from it
|
||||
self.tree_clf = tree_clf
|
||||
self.data = data
|
||||
self.axes = axes
|
||||
self.class_colors = class_colors
|
||||
self.surface_rectangles = self.generate_surface_rectangles()
|
||||
|
||||
def generate_surface_rectangles(self):
|
||||
# compute data bounds
|
||||
left = np.amin(self.data[:, 0]) - 0.2
|
||||
right = np.amax(self.data[:, 0]) - 0.2
|
||||
top = np.amax(self.data[:, 1])
|
||||
bottom = np.amin(self.data[:, 1]) - 0.2
|
||||
maxrange = [left, right, bottom, top]
|
||||
rectangles = decision_areas(self.tree_clf, maxrange, x=0, y=1, n_features=2)
|
||||
# turn the rectangle objects into manim rectangles
|
||||
def convert_rectangle_to_polygon(rect):
|
||||
# get the points for the rectangle in the plot coordinate frame
|
||||
bottom_left = [rect[0], rect[3]]
|
||||
bottom_right = [rect[1], rect[3]]
|
||||
top_right = [rect[1], rect[2]]
|
||||
top_left = [rect[0], rect[2]]
|
||||
# convert those points into the entire manim coordinates
|
||||
bottom_left_coord = self.axes.coords_to_point(*bottom_left)
|
||||
bottom_right_coord = self.axes.coords_to_point(*bottom_right)
|
||||
top_right_coord = self.axes.coords_to_point(*top_right)
|
||||
top_left_coord = self.axes.coords_to_point(*top_left)
|
||||
points = [
|
||||
bottom_left_coord,
|
||||
bottom_right_coord,
|
||||
top_right_coord,
|
||||
top_left_coord,
|
||||
]
|
||||
# construct a polygon object from those manim coordinates
|
||||
rectangle = Polygon(
|
||||
*points, color=color, fill_opacity=0.3, stroke_opacity=0.0
|
||||
)
|
||||
return rectangle
|
||||
|
||||
manim_rectangles = []
|
||||
for rect in rectangles:
|
||||
color = self.class_colors[int(rect[4])]
|
||||
rectangle = convert_rectangle_to_polygon(rect)
|
||||
manim_rectangles.append(rectangle)
|
||||
|
||||
manim_rectangles = merge_overlapping_polygons(
|
||||
manim_rectangles, colors=[BLUE, GREEN, ORANGE]
|
||||
)
|
||||
|
||||
return manim_rectangles
|
||||
|
||||
@override_animation(Create)
|
||||
def create_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Create(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
@override_animation(Uncreate)
|
||||
def uncreate_override(self):
|
||||
# play a reveal of all of the surface rectangles
|
||||
animations = []
|
||||
for rectangle in self.surface_rectangles:
|
||||
animations.append(Uncreate(rectangle))
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
|
||||
|
||||
class DecisionTree:
|
||||
"""
|
||||
Draw a single tree node
|
||||
"""
|
||||
|
||||
def _make_node(
|
||||
self,
|
||||
feature,
|
||||
threshold,
|
||||
values,
|
||||
is_leaf=False,
|
||||
depth=0,
|
||||
leaf_colors=[BLUE, ORANGE, GREEN],
|
||||
):
|
||||
if not is_leaf:
|
||||
node_text = f"{feature}\n <= {threshold:.3f} cm"
|
||||
# draw decision text
|
||||
decision_text = Text(node_text, color=WHITE)
|
||||
# draw a box
|
||||
bounding_box = SurroundingRectangle(decision_text, buff=0.3, color=WHITE)
|
||||
node = VGroup()
|
||||
node.add(bounding_box)
|
||||
node.add(decision_text)
|
||||
# return the node
|
||||
else:
|
||||
# plot the appropriate image
|
||||
class_index = np.argmax(values)
|
||||
# get image
|
||||
pil_image = images[class_index]
|
||||
leaf_group = Group()
|
||||
node = ImageMobject(pil_image)
|
||||
node.scale(1.5)
|
||||
rectangle = Rectangle(
|
||||
width=node.width + 0.05,
|
||||
height=node.height + 0.05,
|
||||
color=leaf_colors[class_index],
|
||||
stroke_width=6,
|
||||
)
|
||||
rectangle.move_to(node.get_center())
|
||||
rectangle.shift([-0.02, 0.02, 0])
|
||||
leaf_group.add(rectangle)
|
||||
leaf_group.add(node)
|
||||
node = leaf_group
|
||||
|
||||
return node
|
||||
|
||||
def _make_connection(self, top, bottom, is_leaf=False):
|
||||
top_node_bottom_location = top.get_center()
|
||||
top_node_bottom_location[1] -= top.height / 2
|
||||
bottom_node_top_location = bottom.get_center()
|
||||
bottom_node_top_location[1] += bottom.height / 2
|
||||
line = Line(top_node_bottom_location, bottom_node_top_location, color=WHITE)
|
||||
return line
|
||||
|
||||
def _make_tree(self, tree, feature_names=["Sepal Length", "Sepal Width"]):
|
||||
tree_group = Group()
|
||||
max_depth = tree.max_depth
|
||||
# make the base node
|
||||
feature_name = feature_names[tree.feature[0]]
|
||||
threshold = tree.threshold[0]
|
||||
values = tree.value[0]
|
||||
nodes_map = {}
|
||||
root_node = self._make_node(feature_name, threshold, values, depth=0)
|
||||
nodes_map[0] = root_node
|
||||
tree_group.add(root_node)
|
||||
# save some information
|
||||
node_height = root_node.height
|
||||
node_width = root_node.width
|
||||
scale_factor = 1.0
|
||||
edge_map = {}
|
||||
# tree height
|
||||
tree_height = scale_factor * node_height * max_depth
|
||||
tree_width = scale_factor * 2**max_depth * node_width
|
||||
# traverse tree
|
||||
def recurse(node, depth, direction, parent_object, parent_node):
|
||||
# make sure it is a valid node
|
||||
# make the node object
|
||||
is_leaf = tree.children_left[node] == tree.children_right[node]
|
||||
feature_name = feature_names[tree.feature[node]]
|
||||
threshold = tree.threshold[node]
|
||||
values = tree.value[node]
|
||||
node_object = self._make_node(
|
||||
feature_name, threshold, values, depth=depth, is_leaf=is_leaf
|
||||
)
|
||||
nodes_map[node] = node_object
|
||||
node_height = node_object.height
|
||||
# set the node position
|
||||
direction_factor = -1 if direction == "left" else 1
|
||||
shift_right_amount = (
|
||||
0.8 * direction_factor * scale_factor * tree_width / (2**depth) / 2
|
||||
)
|
||||
if is_leaf:
|
||||
shift_down_amount = -1.0 * scale_factor * node_height
|
||||
else:
|
||||
shift_down_amount = -1.8 * scale_factor * node_height
|
||||
node_object.match_x(parent_object).match_y(parent_object).shift(
|
||||
[shift_right_amount, shift_down_amount, 0]
|
||||
)
|
||||
tree_group.add(node_object)
|
||||
# make a connection
|
||||
connection = self._make_connection(
|
||||
parent_object, node_object, is_leaf=is_leaf
|
||||
)
|
||||
edge_name = str(parent_node) + "," + str(node)
|
||||
edge_map[edge_name] = connection
|
||||
tree_group.add(connection)
|
||||
# recurse
|
||||
if not is_leaf:
|
||||
recurse(tree.children_left[node], depth + 1, "left", node_object, node)
|
||||
recurse(
|
||||
tree.children_right[node], depth + 1, "right", node_object, node
|
||||
)
|
||||
|
||||
recurse(tree.children_left[0], 1, "left", root_node, 0)
|
||||
recurse(tree.children_right[0], 1, "right", root_node, 0)
|
||||
|
||||
tree_group.scale(0.35)
|
||||
return tree_group, nodes_map, edge_map
|
||||
|
||||
def color_example_path(
|
||||
self, tree_group, nodes_map, tree, edge_map, example, color=YELLOW, thickness=2
|
||||
):
|
||||
# get decision path
|
||||
decision_path = tree.decision_path(example)[0]
|
||||
path_indices = decision_path.indices
|
||||
# highlight edges
|
||||
for node_index in range(0, len(path_indices) - 1):
|
||||
current_val = path_indices[node_index]
|
||||
next_val = path_indices[node_index + 1]
|
||||
edge_str = str(current_val) + "," + str(next_val)
|
||||
edge = edge_map[edge_str]
|
||||
animation_two = AnimationGroup(
|
||||
nodes_map[current_val].animate.set_color(color)
|
||||
)
|
||||
self.play(animation_two, run_time=0.5)
|
||||
animation_one = AnimationGroup(
|
||||
edge.animate.set_color(color),
|
||||
# edge.animate.set_stroke_width(4),
|
||||
)
|
||||
self.play(animation_one, run_time=0.5)
|
||||
# surround the bottom image
|
||||
last_path_index = path_indices[-1]
|
||||
last_path_rectangle = nodes_map[last_path_index][0]
|
||||
self.play(last_path_rectangle.animate.set_color(color))
|
||||
|
||||
def create_sklearn_tree(self, max_tree_depth=1):
|
||||
# learn the decision tree
|
||||
iris = load_iris()
|
||||
tree = learn_iris_decision_tree(iris, max_depth=max_tree_depth)
|
||||
feature_names = iris.feature_names[0:2]
|
||||
return tree.tree_
|
||||
|
||||
def make_tree(self, max_tree_depth=2):
|
||||
sklearn_tree = self.create_sklearn_tree()
|
||||
# make the tree
|
||||
tree_group, nodes_map, edge_map = self._make_tree(
|
||||
sklearn_tree.tree_, feature_names
|
||||
)
|
||||
tree_group.shift([0, 5.5, 0])
|
||||
return tree_group
|
||||
# self.add(tree_group)
|
||||
# self.play(SpinInFromNothing(tree_group), run_time=3)
|
||||
# self.color_example_path(tree_group, nodes_map, tree, edge_map, iris.data[None, 0, 0:2])
|
||||
|
||||
|
||||
class DecisionTreeSplitScene(Scene):
|
||||
def make_decision_tree_classifier(self, max_depth=4):
|
||||
decision_tree = DecisionTreeClassifier(
|
||||
random_state=1, max_depth=max_depth, max_leaf_nodes=8
|
||||
)
|
||||
decision_tree = decision_tree.fit(iris.data[:, :2], iris.target)
|
||||
# output the decisioin tree in some format
|
||||
return decision_tree
|
||||
|
||||
def make_split_animation(self, data, classes, data_labels, main_axes):
|
||||
|
||||
"""
|
||||
def make_entropy_animation_and_plot(dim=0, num_entropy_values=50):
|
||||
# calculate the entropy values
|
||||
axes_group = VGroup()
|
||||
# make axes
|
||||
range_vals = [np.amin(data, axis=0)[dim], np.amax(data, axis=0)[dim]]
|
||||
axes = Axes(x_range=range_vals,
|
||||
y_range=[0, 1.0],
|
||||
x_length=9,
|
||||
y_length=4,
|
||||
# axis_config={"number_scale_value":0.75, "include_numbers":True},
|
||||
tips=False,
|
||||
)
|
||||
axes_group.add(axes)
|
||||
# make axis labels
|
||||
# x_label
|
||||
x_label = Text(data_labels[dim], font=font) \
|
||||
.match_y(axes.get_axes()[0]) \
|
||||
.shift([0.5, -0.75, 0]) \
|
||||
.scale(font_scale*1.2)
|
||||
axes_group.add(x_label)
|
||||
# y_label
|
||||
y_label = Text("Information Gain", font=font) \
|
||||
.match_x(axes.get_axes()[1]) \
|
||||
.shift([-0.75, 0, 0]) \
|
||||
.rotate(np.pi / 2) \
|
||||
.scale(font_scale * 1.2)
|
||||
|
||||
axes_group.add(y_label)
|
||||
# line animation
|
||||
information_gains = []
|
||||
def entropy_function(split_value):
|
||||
# lower entropy
|
||||
lower_set = np.nonzero(data[:, dim] <= split_value)[0]
|
||||
lower_set = classes[lower_set]
|
||||
lower_entropy = entropy(lower_set)
|
||||
# higher entropy
|
||||
higher_set = np.nonzero(data[:, dim] > split_value)[0]
|
||||
higher_set = classes[higher_set]
|
||||
higher_entropy = entropy(higher_set)
|
||||
# calculate entropies
|
||||
all_entropy = entropy(classes, base=2)
|
||||
lower_entropy = entropy(lower_set, base=2)
|
||||
higher_entropy = entropy(higher_set, base=2)
|
||||
mean_entropy = (lower_entropy + higher_entropy) / 2
|
||||
# calculate information gain
|
||||
lower_prob = len(lower_set) / len(data[:, dim])
|
||||
higher_prob = len(higher_set) / len(data[:, dim])
|
||||
info_gain = all_entropy - (lower_prob * lower_entropy + higher_prob * higher_entropy)
|
||||
information_gains.append((split_value, info_gain))
|
||||
return info_gain
|
||||
|
||||
data_range = np.amin(data[:, dim]), np.amax(data[:, dim])
|
||||
|
||||
entropy_graph = axes.get_graph(
|
||||
entropy_function,
|
||||
# color=RED,
|
||||
# x_range=data_range
|
||||
)
|
||||
axes_group.add(entropy_graph)
|
||||
axes_group.shift([4.0, 2, 0])
|
||||
axes_group.scale(0.5)
|
||||
dot_animation = Dot(color=WHITE)
|
||||
axes_group.add(dot_animation)
|
||||
# make animations
|
||||
animation_group = AnimationGroup(
|
||||
Create(axes_group, run_time=2),
|
||||
Wait(3),
|
||||
MoveAlongPath(dot_animation, entropy_graph, run_time=20, rate_func=rate_functions.ease_in_out_quad),
|
||||
Wait(2)
|
||||
)
|
||||
|
||||
return axes_group, animation_group, information_gains
|
||||
"""
|
||||
|
||||
def make_split_line_animation(dim=0):
|
||||
# make a line along one of the dims and move it up and down
|
||||
origin_coord = [
|
||||
np.amin(data, axis=0)[0] - 0.2,
|
||||
np.amin(data, axis=0)[1] - 0.2,
|
||||
]
|
||||
origin_point = main_axes.coords_to_point(*origin_coord)
|
||||
top_left_coord = [origin_coord[0], np.amax(data, axis=0)[1]]
|
||||
bottom_right_coord = [np.amax(data, axis=0)[0] - 0.2, origin_coord[1]]
|
||||
if dim == 0:
|
||||
other_coord = top_left_coord
|
||||
moving_line_coord = bottom_right_coord
|
||||
else:
|
||||
other_coord = bottom_right_coord
|
||||
moving_line_coord = top_left_coord
|
||||
other_point = main_axes.coords_to_point(*other_coord)
|
||||
moving_line_point = main_axes.coords_to_point(*moving_line_coord)
|
||||
moving_line = Line(origin_point, other_point, color=RED)
|
||||
movement_line = Line(origin_point, moving_line_point)
|
||||
if dim == 0:
|
||||
movement_line.shift([0, moving_line.height / 2, 0])
|
||||
else:
|
||||
movement_line.shift([moving_line.width / 2, 0, 0])
|
||||
# move the moving line along the movement line
|
||||
animation = MoveAlongPath(
|
||||
moving_line,
|
||||
movement_line,
|
||||
run_time=20,
|
||||
rate_func=rate_functions.ease_in_out_quad,
|
||||
)
|
||||
return animation, moving_line
|
||||
|
||||
# plot the line in white then make it invisible
|
||||
# make an animation along the line
|
||||
# make a
|
||||
# axes_one_group, top_animation_group, info_gains = make_entropy_animation_and_plot(dim=0)
|
||||
line_movement, first_moving_line = make_split_line_animation(dim=0)
|
||||
# axes_two_group, bottom_animation_group, _ = make_entropy_animation_and_plot(dim=1)
|
||||
second_line_movement, second_moving_line = make_split_line_animation(dim=1)
|
||||
# axes_two_group.shift([0, -3, 0])
|
||||
animation_group_one = AnimationGroup(
|
||||
# top_animation_group,
|
||||
line_movement,
|
||||
)
|
||||
|
||||
animation_group_two = AnimationGroup(
|
||||
# bottom_animation_group,
|
||||
second_line_movement,
|
||||
)
|
||||
"""
|
||||
both_axes_group = VGroup(
|
||||
axes_one_group,
|
||||
axes_two_group
|
||||
)
|
||||
"""
|
||||
|
||||
return (
|
||||
animation_group_one,
|
||||
animation_group_two,
|
||||
first_moving_line,
|
||||
second_moving_line,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
# both_axes_group, \
|
||||
# info_gains
|
||||
|
||||
def construct(self):
|
||||
# make the points
|
||||
iris_dataset_plot = IrisDatasetPlot()
|
||||
iris_dataset_plot.all_group.scale(1.0)
|
||||
iris_dataset_plot.all_group.shift([-3, 0.2, 0])
|
||||
# make the entropy line graph
|
||||
# entropy_line_graph = self.draw_entropy_line_graph()
|
||||
# arrange the plots
|
||||
# do animations
|
||||
self.play(Create(iris_dataset_plot))
|
||||
# make the decision tree classifier
|
||||
decision_tree_classifier = self.make_decision_tree_classifier()
|
||||
decision_tree_surface = DecisionTreeSurface(
|
||||
decision_tree_classifier, iris.data, iris_dataset_plot.axes_group[0]
|
||||
)
|
||||
self.play(Create(decision_tree_surface))
|
||||
self.wait(3)
|
||||
self.play(Uncreate(decision_tree_surface))
|
||||
main_axes = iris_dataset_plot.axes_group[0]
|
||||
(
|
||||
split_animation_one,
|
||||
split_animation_two,
|
||||
first_moving_line,
|
||||
second_moving_line,
|
||||
both_axes_group,
|
||||
info_gains,
|
||||
) = self.make_split_animation(
|
||||
iris.data[:, 0:2], iris.target, iris.feature_names, main_axes
|
||||
)
|
||||
self.play(split_animation_one)
|
||||
self.wait(0.1)
|
||||
self.play(Uncreate(first_moving_line))
|
||||
self.wait(3)
|
||||
self.play(split_animation_two)
|
||||
self.wait(0.1)
|
||||
self.play(Uncreate(second_moving_line))
|
||||
self.wait(0.1)
|
||||
# highlight the maximum on top
|
||||
# sort by second key
|
||||
"""
|
||||
highest_info_gain = sorted(info_gains, key=lambda x: x[1])[-1]
|
||||
highest_info_gain_point = both_axes_group[0][0].coords_to_point(*highest_info_gain)
|
||||
highlighted_peak = Dot(highest_info_gain_point, color=YELLOW)
|
||||
# get location of highest info gain point
|
||||
highest_info_gain_point_in_iris_graph = iris_dataset_plot.axes_group[0].coords_to_point(*[highest_info_gain[0], 0])
|
||||
first_moving_line.start[0] = highest_info_gain_point_in_iris_graph[0]
|
||||
first_moving_line.end[0] = highest_info_gain_point_in_iris_graph[0]
|
||||
self.play(Create(highlighted_peak))
|
||||
self.play(Create(first_moving_line))
|
||||
text = Text("Highest Information Gain")
|
||||
text.scale(0.4)
|
||||
text.move_to(highlighted_peak)
|
||||
text.shift([0, 0.5, 0])
|
||||
self.play(Create(text))
|
||||
"""
|
||||
self.wait(1)
|
||||
# draw the basic tree
|
||||
decision_tree_classifier = self.make_decision_tree_classifier(max_depth=1)
|
||||
decision_tree_surface = DecisionTreeSurface(
|
||||
decision_tree_classifier, iris.data, iris_dataset_plot.axes_group[0]
|
||||
)
|
||||
decision_tree_graph, _, _ = DecisionTree()._make_tree(
|
||||
decision_tree_classifier.tree_
|
||||
)
|
||||
decision_tree_graph.match_y(iris_dataset_plot.axes_group)
|
||||
decision_tree_graph.shift([4, 0, 0])
|
||||
self.play(Create(decision_tree_surface))
|
||||
uncreate_animation = AnimationGroup(
|
||||
# Uncreate(both_axes_group),
|
||||
# Uncreate(highlighted_peak),
|
||||
Uncreate(second_moving_line),
|
||||
# Unwrite(text)
|
||||
)
|
||||
self.play(uncreate_animation)
|
||||
self.wait(0.5)
|
||||
self.play(FadeIn(decision_tree_graph))
|
||||
# self.play(FadeIn(highlighted_peak))
|
||||
self.wait(5)
|
5
examples/media/texts/6bb4864e46e7c499.svg
Normal file
5
examples/media/texts/6bb4864e46e7c499.svg
Normal file
@ -0,0 +1,5 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="1920pt" height="1080pt" viewBox="0 0 1920 1080" version="1.1">
|
||||
<g id="surface6">
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 222 B |
5
examples/media/texts/762a31f3ee2af27b.svg
Normal file
5
examples/media/texts/762a31f3ee2af27b.svg
Normal file
@ -0,0 +1,5 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="1920pt" height="1080pt" viewBox="0 0 1920 1080" version="1.1">
|
||||
<g id="surface1">
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 222 B |
@ -3,6 +3,7 @@ import numpy as np
|
||||
from collections import deque
|
||||
from sklearn.tree import _tree as ctree
|
||||
|
||||
|
||||
class AABB:
|
||||
"""Axis-aligned bounding box"""
|
||||
|
||||
@ -19,6 +20,7 @@ class AABB:
|
||||
|
||||
return left, right
|
||||
|
||||
|
||||
def tree_bounds(tree, n_features=None):
|
||||
"""Compute final decision rule for each node in tree"""
|
||||
if n_features is None:
|
||||
@ -34,6 +36,7 @@ def tree_bounds(tree, n_features=None):
|
||||
queue.extend([l, r])
|
||||
return aabbs
|
||||
|
||||
|
||||
def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
|
||||
"""Extract decision areas.
|
||||
|
||||
@ -70,6 +73,7 @@ def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None)
|
||||
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
|
||||
return rectangles
|
||||
|
||||
|
||||
def plot_areas(rectangles):
|
||||
for rect in rectangles:
|
||||
color = ["b", "r"][int(rect[4])]
|
||||
@ -83,6 +87,7 @@ def plot_areas(rectangles):
|
||||
)
|
||||
plt.gca().add_artist(rp)
|
||||
|
||||
|
||||
def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
|
||||
# get all polygons of each color
|
||||
polygon_dict = {
|
||||
@ -156,6 +161,7 @@ def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
|
||||
return_polygons.append(polygon)
|
||||
return return_polygons
|
||||
|
||||
|
||||
class IrisDatasetPlot(VGroup):
|
||||
def __init__(self, iris):
|
||||
points = iris.data[:, 0:2]
|
||||
@ -269,7 +275,6 @@ class IrisDatasetPlot(VGroup):
|
||||
|
||||
|
||||
class DecisionTreeSurface(VGroup):
|
||||
|
||||
def __init__(self, tree_clf, data, axes, class_colors=[BLUE, ORANGE, GREEN]):
|
||||
# take the tree and construct the surface from it
|
||||
self.tree_clf = tree_clf
|
||||
@ -346,8 +351,8 @@ class DecisionTreeSurface(VGroup):
|
||||
|
||||
def make_split_to_animation_map(self):
|
||||
"""
|
||||
Returns a dictionary mapping a given split
|
||||
node to an animation to be played
|
||||
Returns a dictionary mapping a given split
|
||||
node to an animation to be played
|
||||
"""
|
||||
# Create an initial decision tree surface
|
||||
# Go through each split node
|
||||
|
@ -26,6 +26,7 @@ def compute_node_depths(tree):
|
||||
|
||||
return node_depths
|
||||
|
||||
|
||||
def compute_level_order_traversal(tree):
|
||||
"""Computes level order traversal of a sklearn tree"""
|
||||
|
||||
@ -56,6 +57,7 @@ def compute_level_order_traversal(tree):
|
||||
|
||||
return sorted_inds
|
||||
|
||||
|
||||
def compute_bfs_traversal(tree):
|
||||
"""Traverses the tree in BFS order and returns the nodes in order"""
|
||||
traversal_order = []
|
||||
@ -73,10 +75,12 @@ def compute_bfs_traversal(tree):
|
||||
|
||||
return traversal_order
|
||||
|
||||
|
||||
def compute_best_first_traversal(tree):
|
||||
"""Traverses the tree according to the best split first order"""
|
||||
pass
|
||||
|
||||
|
||||
def compute_node_to_parent_mapping(tree):
|
||||
"""Returns a hashmap mapping node indices to their parent indices"""
|
||||
node_to_parent = {0: -1} # Root has no parent
|
||||
|
@ -9,6 +9,7 @@ from tqdm import tqdm
|
||||
|
||||
from manim_ml.probability import GaussianDistribution
|
||||
|
||||
|
||||
def gaussian_proposal(x, sigma=0.2):
|
||||
"""
|
||||
Gaussian proposal distribution.
|
||||
@ -39,7 +40,8 @@ def gaussian_proposal(x, sigma=0.2):
|
||||
|
||||
return (x_star, qxx)
|
||||
|
||||
class MultidimensionalGaussianPosterior():
|
||||
|
||||
class MultidimensionalGaussianPosterior:
|
||||
"""
|
||||
N-Dimensional Gaussian distribution with
|
||||
|
||||
@ -49,8 +51,7 @@ class MultidimensionalGaussianPosterior():
|
||||
Prior on mean is U(-500, 500)
|
||||
"""
|
||||
|
||||
def __init__(self, ndim=2, seed=12345, scale=3,
|
||||
mu=None, var=None):
|
||||
def __init__(self, ndim=2, seed=12345, scale=3, mu=None, var=None):
|
||||
"""_summary_
|
||||
|
||||
Parameters
|
||||
@ -71,10 +72,7 @@ class MultidimensionalGaussianPosterior():
|
||||
self.var = var
|
||||
|
||||
if mu is None:
|
||||
self.mu = scipy.stats.norm(
|
||||
loc=0,
|
||||
scale=self.scale
|
||||
).rvs(ndim)
|
||||
self.mu = scipy.stats.norm(loc=0, scale=self.scale).rvs(ndim)
|
||||
else:
|
||||
self.mu = mu
|
||||
|
||||
@ -84,20 +82,18 @@ class MultidimensionalGaussianPosterior():
|
||||
"""
|
||||
|
||||
if np.all(x < 500) and np.all(x > -500):
|
||||
return scipy.stats.multivariate_normal(
|
||||
mean=self.mu,
|
||||
cov=self.var
|
||||
).logpdf(x)
|
||||
return scipy.stats.multivariate_normal(mean=self.mu, cov=self.var).logpdf(x)
|
||||
else:
|
||||
return -1e6
|
||||
|
||||
|
||||
def metropolis_hastings_sampler(
|
||||
log_prob_fn=MultidimensionalGaussianPosterior(),
|
||||
prop_fn=gaussian_proposal,
|
||||
initial_location : np.ndarray = np.array([0, 0]),
|
||||
log_prob_fn=MultidimensionalGaussianPosterior(),
|
||||
prop_fn=gaussian_proposal,
|
||||
initial_location: np.ndarray = np.array([0, 0]),
|
||||
iterations=25,
|
||||
warm_up=0,
|
||||
ndim=2
|
||||
ndim=2,
|
||||
):
|
||||
"""Samples using a Metropolis-Hastings sampler.
|
||||
|
||||
@ -158,17 +154,18 @@ def metropolis_hastings_sampler(
|
||||
|
||||
return chain, np.array([]), proposals
|
||||
|
||||
|
||||
class MCMCAxes(Group):
|
||||
"""Container object for visualizing MCMC on a 2D axis"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dot_color=BLUE,
|
||||
self,
|
||||
dot_color=BLUE,
|
||||
dot_radius=0.05,
|
||||
accept_line_color=GREEN,
|
||||
reject_line_color=RED,
|
||||
line_color=WHITE,
|
||||
line_stroke_width=1
|
||||
line_stroke_width=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.dot_color = dot_color
|
||||
@ -176,7 +173,7 @@ class MCMCAxes(Group):
|
||||
self.accept_line_color = accept_line_color
|
||||
self.reject_line_color = reject_line_color
|
||||
self.line_color = line_color
|
||||
self.line_stroke_width=line_stroke_width
|
||||
self.line_stroke_width = line_stroke_width
|
||||
# Make the axes
|
||||
self.axes = Axes(
|
||||
x_range=[-3, 3],
|
||||
@ -185,22 +182,16 @@ class MCMCAxes(Group):
|
||||
y_length=12,
|
||||
x_axis_config={"stroke_opacity": 0.0},
|
||||
y_axis_config={"stroke_opacity": 0.0},
|
||||
tips=False
|
||||
tips=False,
|
||||
)
|
||||
self.add(self.axes)
|
||||
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
"""Overrides Create animation"""
|
||||
return AnimationGroup(
|
||||
Create(self.axes)
|
||||
)
|
||||
return AnimationGroup(Create(self.axes))
|
||||
|
||||
def visualize_gaussian_proposal_about_point(
|
||||
self,
|
||||
mean,
|
||||
cov=None
|
||||
) -> AnimationGroup:
|
||||
def visualize_gaussian_proposal_about_point(self, mean, cov=None) -> AnimationGroup:
|
||||
"""Creates a Gaussian distribution about a certain point
|
||||
|
||||
Parameters
|
||||
@ -216,21 +207,14 @@ class MCMCAxes(Group):
|
||||
animation of creating the proposal Gaussian distribution
|
||||
"""
|
||||
gaussian = GaussianDistribution(
|
||||
axes=self.axes,
|
||||
mean=mean,
|
||||
cov=cov,
|
||||
dist_theme="gaussian"
|
||||
axes=self.axes, mean=mean, cov=cov, dist_theme="gaussian"
|
||||
)
|
||||
|
||||
create_guassian = Create(gaussian)
|
||||
return create_guassian
|
||||
|
||||
|
||||
def make_transition_animation(
|
||||
self,
|
||||
start_point,
|
||||
end_point,
|
||||
candidate_point,
|
||||
run_time=0.1
|
||||
self, start_point, end_point, candidate_point, run_time=0.1
|
||||
) -> AnimationGroup:
|
||||
"""Makes an transition animation for a single point on a Markov Chain
|
||||
|
||||
@ -255,38 +239,27 @@ class MCMCAxes(Group):
|
||||
if point_is_rejected:
|
||||
return AnimationGroup()
|
||||
else:
|
||||
create_end_point = Create(
|
||||
end_point
|
||||
)
|
||||
create_end_point = Create(end_point)
|
||||
create_line = Create(
|
||||
Line(
|
||||
start_point,
|
||||
end_point,
|
||||
color=self.line_color,
|
||||
stroke_width=self.line_stroke_width
|
||||
stroke_width=self.line_stroke_width,
|
||||
)
|
||||
)
|
||||
return AnimationGroup(
|
||||
create_end_point,
|
||||
create_line,
|
||||
lag_ratio=1.0,
|
||||
run_time=run_time
|
||||
create_end_point, create_line, lag_ratio=1.0, run_time=run_time
|
||||
)
|
||||
|
||||
def show_ground_truth_gaussian(self, distribution):
|
||||
"""
|
||||
"""
|
||||
""" """
|
||||
mean = distribution.mu
|
||||
var = np.eye(2) * distribution.var
|
||||
distribution_drawing = GaussianDistribution(
|
||||
self.axes,
|
||||
mean,
|
||||
var,
|
||||
dist_theme="gaussian"
|
||||
self.axes, mean, var, dist_theme="gaussian"
|
||||
).set_opacity(0.2)
|
||||
return AnimationGroup(
|
||||
Create(distribution_drawing)
|
||||
)
|
||||
return AnimationGroup(Create(distribution_drawing))
|
||||
|
||||
def visualize_metropolis_hastings_chain_sampling(
|
||||
self,
|
||||
@ -295,7 +268,7 @@ class MCMCAxes(Group):
|
||||
sampling_kwargs={},
|
||||
):
|
||||
"""
|
||||
Makes an animation for visualizing a 2D markov chain using
|
||||
Makes an animation for visualizing a 2D markov chain using
|
||||
metropolis hastings samplings
|
||||
|
||||
Parameters
|
||||
@ -318,20 +291,18 @@ class MCMCAxes(Group):
|
||||
"""
|
||||
# Compute the chain samples using a Metropolis Hastings Sampler
|
||||
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
|
||||
log_prob_fn=log_prob_fn,
|
||||
prop_fn=prop_fn,
|
||||
**sampling_kwargs
|
||||
log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs
|
||||
)
|
||||
print(f"MCMC samples: {mcmc_samples}")
|
||||
print(f"Candidate samples: {candidate_samples}")
|
||||
# Make the animation for visualizing the chain
|
||||
# Make the animation for visualizing the chain
|
||||
animations = []
|
||||
# Place the initial point
|
||||
current_point = mcmc_samples[0]
|
||||
current_point = Dot(
|
||||
self.axes.coords_to_point(current_point[0], current_point[1]),
|
||||
color=self.dot_color,
|
||||
radius=self.dot_radius
|
||||
radius=self.dot_radius,
|
||||
)
|
||||
create_initial_point = Create(current_point)
|
||||
animations.append(create_initial_point)
|
||||
@ -346,12 +317,12 @@ class MCMCAxes(Group):
|
||||
next_point = Dot(
|
||||
self.axes.coords_to_point(next_sample[0], next_sample[1]),
|
||||
color=self.dot_color,
|
||||
radius=self.dot_radius
|
||||
radius=self.dot_radius,
|
||||
)
|
||||
candidate_point = Dot(
|
||||
self.axes.coords_to_point(candidate_sample[0], candidate_sample[1]),
|
||||
color=self.dot_color,
|
||||
radius=self.dot_radius
|
||||
radius=self.dot_radius,
|
||||
)
|
||||
# Make a transition animation
|
||||
transition_animation = self.make_transition_animation(
|
||||
@ -361,9 +332,6 @@ class MCMCAxes(Group):
|
||||
# Setup for next iteration
|
||||
current_point = next_point
|
||||
# Make the final animation group
|
||||
animation_group = AnimationGroup(
|
||||
*animations,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
|
||||
|
||||
return animation_group
|
||||
|
@ -1,8 +1,7 @@
|
||||
from manim_ml.neural_network.activation_functions.relu import ReLUFunction
|
||||
|
||||
name_to_activation_function_map = {
|
||||
"ReLU": ReLUFunction()
|
||||
}
|
||||
name_to_activation_function_map = {"ReLU": ReLUFunction}
|
||||
|
||||
|
||||
def get_activation_function_by_name(name):
|
||||
return name_to_activation_function_map[name]
|
||||
return name_to_activation_function_map[name]
|
||||
|
@ -4,12 +4,22 @@ import random
|
||||
|
||||
import manim_ml.neural_network.activation_functions.relu as relu
|
||||
|
||||
|
||||
class ActivationFunction(ABC, VGroup):
|
||||
"""Abstract parent class for defining activation functions"""
|
||||
|
||||
def __init__(self, function_name=None, x_range=[-1, 1], y_range=[-1, 1],
|
||||
x_length=0.5, y_length=0.3, show_function_name=True, active_color=ORANGE,
|
||||
plot_color=BLUE, rectangle_color=WHITE):
|
||||
def __init__(
|
||||
self,
|
||||
function_name=None,
|
||||
x_range=[-1, 1],
|
||||
y_range=[-1, 1],
|
||||
x_length=0.5,
|
||||
y_length=0.3,
|
||||
show_function_name=True,
|
||||
active_color=ORANGE,
|
||||
plot_color=BLUE,
|
||||
rectangle_color=WHITE,
|
||||
):
|
||||
super(VGroup, self).__init__()
|
||||
self.function_name = function_name
|
||||
self.x_range = x_range
|
||||
@ -25,7 +35,7 @@ class ActivationFunction(ABC, VGroup):
|
||||
|
||||
def construct_activation_function(self):
|
||||
"""Makes the activation function"""
|
||||
# Make an axis
|
||||
# Make an axis
|
||||
self.axes = Axes(
|
||||
x_range=self.x_range,
|
||||
y_range=self.y_range,
|
||||
@ -35,17 +45,17 @@ class ActivationFunction(ABC, VGroup):
|
||||
axis_config={
|
||||
"include_numbers": False,
|
||||
"stroke_width": 0.5,
|
||||
"include_ticks": False
|
||||
}
|
||||
"include_ticks": False,
|
||||
},
|
||||
)
|
||||
self.add(self.axes)
|
||||
# Surround the axis with a rounded rectangle.
|
||||
# Surround the axis with a rounded rectangle.
|
||||
self.surrounding_rectangle = SurroundingRectangle(
|
||||
self.axes,
|
||||
corner_radius=0.05,
|
||||
buff=0.05,
|
||||
stroke_width=2.0,
|
||||
stroke_color=self.rectangle_color
|
||||
stroke_color=self.rectangle_color,
|
||||
)
|
||||
self.add(self.surrounding_rectangle)
|
||||
# Plot function on axis by applying it and showing in given range
|
||||
@ -53,17 +63,15 @@ class ActivationFunction(ABC, VGroup):
|
||||
lambda x: self.apply_function(x),
|
||||
x_range=self.x_range,
|
||||
stroke_color=self.plot_color,
|
||||
stroke_width=2.0
|
||||
stroke_width=2.0,
|
||||
)
|
||||
self.add(self.graph)
|
||||
# Add the function name
|
||||
if self.show_function_name:
|
||||
function_name_text = Text(
|
||||
self.function_name,
|
||||
font_size=12,
|
||||
font="sans-serif"
|
||||
self.function_name, font_size=12, font="sans-serif"
|
||||
)
|
||||
function_name_text.next_to(self.axes, UP*0.5)
|
||||
function_name_text.next_to(self.axes, UP * 0.5)
|
||||
self.add(function_name_text)
|
||||
|
||||
@abstractmethod
|
||||
@ -78,29 +86,21 @@ class ActivationFunction(ABC, VGroup):
|
||||
# TODO: Evaluate the function at the x_val and show a highlighted dot
|
||||
animation_group = Succession(
|
||||
AnimationGroup(
|
||||
ApplyMethod(self.graph.set_color, self.active_color),
|
||||
ApplyMethod(
|
||||
self.graph.set_color,
|
||||
self.active_color
|
||||
self.surrounding_rectangle.set_stroke_color, self.active_color
|
||||
),
|
||||
ApplyMethod(
|
||||
self.surrounding_rectangle.set_stroke_color,
|
||||
self.active_color
|
||||
),
|
||||
lag_ratio=0.0
|
||||
lag_ratio=0.0,
|
||||
),
|
||||
Wait(1),
|
||||
AnimationGroup(
|
||||
ApplyMethod(self.graph.set_color, self.plot_color),
|
||||
ApplyMethod(
|
||||
self.graph.set_color,
|
||||
self.plot_color
|
||||
self.surrounding_rectangle.set_stroke_color, self.rectangle_color
|
||||
),
|
||||
ApplyMethod(
|
||||
self.surrounding_rectangle.set_stroke_color,
|
||||
self.rectangle_color
|
||||
),
|
||||
lag_ratio=0.0
|
||||
lag_ratio=0.0,
|
||||
),
|
||||
lag_ratio=1.0
|
||||
lag_ratio=1.0,
|
||||
)
|
||||
|
||||
return animation_group
|
||||
return animation_group
|
||||
|
@ -1,6 +1,9 @@
|
||||
from manim import *
|
||||
|
||||
from manim_ml.neural_network.activation_functions.activation_function import ActivationFunction
|
||||
from manim_ml.neural_network.activation_functions.activation_function import (
|
||||
ActivationFunction,
|
||||
)
|
||||
|
||||
|
||||
class ReLUFunction(ActivationFunction):
|
||||
"""Rectified Linear Unit Activation Function"""
|
||||
@ -12,4 +15,4 @@ class ReLUFunction(ActivationFunction):
|
||||
if x_val < 0:
|
||||
return 0
|
||||
else:
|
||||
return x_val
|
||||
return x_val
|
||||
|
@ -1,11 +1,15 @@
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_feed_forward import (
|
||||
Convolutional2DToFeedForward,
|
||||
)
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_max_pooling_2d import Convolutional2DToMaxPooling2D
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_max_pooling_2d import (
|
||||
Convolutional2DToMaxPooling2D,
|
||||
)
|
||||
from manim_ml.neural_network.layers.image_to_convolutional_2d import (
|
||||
ImageToConvolutional2DLayer,
|
||||
)
|
||||
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import MaxPooling2DToConvolutional2D
|
||||
from manim_ml.neural_network.layers.max_pooling_2d_to_convolutional_2d import (
|
||||
MaxPooling2DToConvolutional2D,
|
||||
)
|
||||
from .convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D
|
||||
from .convolutional_2d import Convolutional2DLayer
|
||||
from .feed_forward_to_vector import FeedForwardToVector
|
||||
|
@ -1,6 +1,8 @@
|
||||
from typing import Union
|
||||
from manim_ml.neural_network.activation_functions import get_activation_function_by_name
|
||||
from manim_ml.neural_network.activation_functions.activation_function import ActivationFunction
|
||||
from manim_ml.neural_network.activation_functions.activation_function import (
|
||||
ActivationFunction,
|
||||
)
|
||||
import numpy as np
|
||||
from manim import *
|
||||
|
||||
@ -10,6 +12,7 @@ from manim_ml.neural_network.layers.parent_layers import (
|
||||
)
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
|
||||
class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
"""Handles rendering a convolutional layer for a nn"""
|
||||
|
||||
@ -33,11 +36,11 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
self.num_feature_maps = num_feature_maps
|
||||
self.filter_color = filter_color
|
||||
if isinstance(feature_map_size, int):
|
||||
self.feature_map_size = (feature_map_size, feature_map_size)
|
||||
self.feature_map_size = (feature_map_size, feature_map_size)
|
||||
else:
|
||||
self.feature_map_size = feature_map_size
|
||||
if isinstance(filter_size, int):
|
||||
self.filter_size = (filter_size, filter_size)
|
||||
self.filter_size = (filter_size, filter_size)
|
||||
else:
|
||||
self.filter_size = filter_size
|
||||
self.cell_width = cell_width
|
||||
@ -50,10 +53,10 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
self.activation_function = activation_function
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: 'NeuralNetworkLayer',
|
||||
output_layer: 'NeuralNetworkLayer',
|
||||
**kwargs
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs,
|
||||
):
|
||||
# Make the feature maps
|
||||
self.feature_maps = self.construct_feature_maps()
|
||||
@ -71,7 +74,7 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
if isinstance(self.activation_function, str):
|
||||
activation_function = get_activation_function_by_name(
|
||||
self.activation_function
|
||||
)
|
||||
)()
|
||||
else:
|
||||
assert isinstance(self.activation_function, ActivationFunction)
|
||||
activation_function = self.activation_function
|
||||
@ -108,14 +111,8 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
def highlight_and_unhighlight_feature_maps(self):
|
||||
"""Highlights then unhighlights feature maps"""
|
||||
return Succession(
|
||||
ApplyMethod(
|
||||
self.feature_maps.set_color,
|
||||
self.pulse_color
|
||||
),
|
||||
ApplyMethod(
|
||||
self.feature_maps.set_color,
|
||||
self.color
|
||||
)
|
||||
ApplyMethod(self.feature_maps.set_color, self.pulse_color),
|
||||
ApplyMethod(self.feature_maps.set_color, self.color),
|
||||
)
|
||||
|
||||
def make_forward_pass_animation(
|
||||
@ -146,7 +143,7 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
animation_group = AnimationGroup(
|
||||
self.activation_function.make_evaluate_animation(),
|
||||
self.highlight_and_unhighlight_feature_maps(),
|
||||
lag_ratio=0.0
|
||||
lag_ratio=0.0,
|
||||
)
|
||||
else:
|
||||
animation_group = AnimationGroup()
|
||||
@ -160,12 +157,19 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
def get_center(self):
|
||||
"""Overrides function for getting center
|
||||
|
||||
The reason for this is so that the center calculation
|
||||
does not include the activation function.
|
||||
The reason for this is so that the center calculation
|
||||
does not include the activation function.
|
||||
"""
|
||||
print("Getting center")
|
||||
return self.feature_maps.get_center()
|
||||
|
||||
def get_width(self):
|
||||
"""Overrides get width function"""
|
||||
return self.feature_maps.length_over_dim(0)
|
||||
|
||||
def get_height(self):
|
||||
"""Overrides get height function"""
|
||||
return self.feature_maps.length_over_dim(1)
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
return FadeIn(self.feature_maps)
|
||||
|
@ -7,16 +7,14 @@ from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
from manim.utils.space_ops import rotation_matrix
|
||||
|
||||
|
||||
def get_rotated_shift_vectors(input_layer, normalized=False):
|
||||
"""Rotates the shift vectors"""
|
||||
# Make base shift vectors
|
||||
right_shift = np.array([input_layer.cell_width, 0, 0])
|
||||
down_shift = np.array([0, -input_layer.cell_width, 0])
|
||||
# Make rotation matrix
|
||||
rot_mat = rotation_matrix(
|
||||
ThreeDLayer.rotation_angle,
|
||||
ThreeDLayer.rotation_axis
|
||||
)
|
||||
rot_mat = rotation_matrix(ThreeDLayer.rotation_angle, ThreeDLayer.rotation_axis)
|
||||
# Rotate the vectors
|
||||
right_shift = np.dot(right_shift, rot_mat.T)
|
||||
down_shift = np.dot(down_shift, rot_mat.T)
|
||||
@ -27,6 +25,7 @@ def get_rotated_shift_vectors(input_layer, normalized=False):
|
||||
|
||||
return right_shift, down_shift
|
||||
|
||||
|
||||
class Filters(VGroup):
|
||||
"""Group for showing a collection of filters connecting two layers"""
|
||||
|
||||
@ -61,8 +60,12 @@ class Filters(VGroup):
|
||||
|
||||
def make_input_feature_map_rectangles(self):
|
||||
rectangles = []
|
||||
rectangle_width = self.output_layer.filter_size[0] * self.output_layer.cell_width
|
||||
rectangle_height = self.output_layer.filter_size[1] * self.output_layer.cell_width
|
||||
rectangle_width = (
|
||||
self.output_layer.filter_size[0] * self.output_layer.cell_width
|
||||
)
|
||||
rectangle_height = (
|
||||
self.output_layer.filter_size[1] * self.output_layer.cell_width
|
||||
)
|
||||
filter_color = self.output_layer.filter_color
|
||||
|
||||
for index, feature_map in enumerate(self.input_layer.feature_maps):
|
||||
@ -263,8 +266,10 @@ class Filters(VGroup):
|
||||
|
||||
return passing_flash
|
||||
|
||||
|
||||
class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
|
||||
input_class = Convolutional2DLayer
|
||||
output_class = Convolutional2DLayer
|
||||
|
||||
@ -301,7 +306,12 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
self.show_grid_lines = show_grid_lines
|
||||
self.highlight_color = highlight_color
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs,
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def animate_filters_all_at_once(self, filters):
|
||||
@ -321,8 +331,12 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
right_shift, down_shift = get_rotated_shift_vectors(self.input_layer)
|
||||
left_shift = -1 * right_shift
|
||||
# Make the animation
|
||||
num_y_moves = int((self.feature_map_size[1] - self.filter_size[1]) / self.stride)
|
||||
num_x_moves = int((self.feature_map_size[0] - self.filter_size[0]) / self.stride)
|
||||
num_y_moves = int(
|
||||
(self.feature_map_size[1] - self.filter_size[1]) / self.stride
|
||||
)
|
||||
num_x_moves = int(
|
||||
(self.feature_map_size[0] - self.filter_size[0]) / self.stride
|
||||
)
|
||||
for y_move in range(num_y_moves):
|
||||
# Go right num_x_moves
|
||||
for x_move in range(num_x_moves):
|
||||
@ -347,10 +361,7 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
animations.append(FadeOut(filters))
|
||||
return Succession(*animations, lag_ratio=1.0)
|
||||
|
||||
def animate_filters_one_at_a_time(
|
||||
self,
|
||||
highlight_active_feature_map=True
|
||||
):
|
||||
def animate_filters_one_at_a_time(self, highlight_active_feature_map=True):
|
||||
"""Animates each of the filters one at a time"""
|
||||
animations = []
|
||||
output_feature_maps = self.output_layer.feature_maps
|
||||
@ -418,18 +429,12 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
self.stride * num_x_moves * left_shift + self.stride * down_shift
|
||||
)
|
||||
# Make the animation
|
||||
shift_animation = ApplyMethod(
|
||||
filters.shift,
|
||||
shift_amount
|
||||
)
|
||||
shift_animation = ApplyMethod(filters.shift, shift_amount)
|
||||
animations.append(shift_animation)
|
||||
# Do last row move right
|
||||
for x_move in range(num_x_moves):
|
||||
# Shift right
|
||||
shift_animation = ApplyMethod(
|
||||
filters.shift,
|
||||
self.stride * right_shift
|
||||
)
|
||||
shift_animation = ApplyMethod(filters.shift, self.stride * right_shift)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
animations.append(shift_animation)
|
||||
# Remove the filters
|
||||
@ -440,18 +445,14 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer, ThreeDLayer):
|
||||
# Change the output feature map colors
|
||||
change_color_animations = []
|
||||
change_color_animations.append(
|
||||
ApplyMethod(
|
||||
feature_map.set_color,
|
||||
original_feature_map_color
|
||||
)
|
||||
ApplyMethod(feature_map.set_color, original_feature_map_color)
|
||||
)
|
||||
# Change the input feature map colors
|
||||
input_feature_maps = self.input_layer.feature_maps
|
||||
for input_feature_map in input_feature_maps:
|
||||
change_color_animations.append(
|
||||
ApplyMethod(
|
||||
input_feature_map.set_color,
|
||||
original_feature_map_color
|
||||
input_feature_map.set_color, original_feature_map_color
|
||||
)
|
||||
)
|
||||
# Combine the animations
|
||||
|
@ -17,14 +17,15 @@ class Convolutional2DToFeedForward(ConnectiveLayer, ThreeDLayer):
|
||||
passing_flash_color=ORANGE,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.passing_flash_color = passing_flash_color
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||
|
@ -1,15 +1,19 @@
|
||||
import random
|
||||
from manim import *
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import get_rotated_shift_vectors
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import (
|
||||
get_rotated_shift_vectors,
|
||||
)
|
||||
|
||||
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
|
||||
|
||||
class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
|
||||
input_class = Convolutional2DLayer
|
||||
output_class = MaxPooling2DLayer
|
||||
|
||||
@ -20,27 +24,18 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
active_color=ORANGE,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.active_color = active_color
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: 'NeuralNetworkLayer',
|
||||
output_layer: 'NeuralNetworkLayer',
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(
|
||||
self,
|
||||
layer_args={},
|
||||
run_time=1.5,
|
||||
**kwargs
|
||||
):
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||
"""Forward pass animation from conv2d to max pooling"""
|
||||
cell_width = self.input_layer.cell_width
|
||||
feature_map_size = self.input_layer.feature_map_size
|
||||
@ -61,7 +56,7 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
remove_gridded_rectangle_animations = []
|
||||
|
||||
for feature_map_index, feature_map in enumerate(feature_maps):
|
||||
# 1. Draw gridded rectangle with kernel_size x kernel_size
|
||||
# 1. Draw gridded rectangle with kernel_size x kernel_size
|
||||
# box regions over the input feature maps.
|
||||
gridded_rectangle = GriddedRectangle(
|
||||
color=self.active_color,
|
||||
@ -71,7 +66,7 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
grid_ystep=cell_width * kernel_size,
|
||||
grid_stroke_width=grid_stroke_width,
|
||||
grid_stroke_color=self.active_color,
|
||||
show_grid_lines=True
|
||||
show_grid_lines=True,
|
||||
)
|
||||
# 2. Randomly highlight one of the cells in the kernel.
|
||||
highlighted_cells = []
|
||||
@ -88,39 +83,32 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
height=cell_width,
|
||||
width=cell_width,
|
||||
stroke_width=0.0,
|
||||
fill_opacity=0.7
|
||||
fill_opacity=0.7,
|
||||
)
|
||||
# Move to the correct location
|
||||
kernel_shift_vector = [
|
||||
kernel_size * cell_width * kernel_x,
|
||||
-1 * kernel_size * cell_width * kernel_y,
|
||||
0
|
||||
0,
|
||||
]
|
||||
cell_shift_vector = [
|
||||
(cell_index % kernel_size) * cell_width,
|
||||
-1 * int(cell_index / kernel_size) * cell_width,
|
||||
0
|
||||
0,
|
||||
]
|
||||
cell_rectangle.next_to(
|
||||
gridded_rectangle.get_corners_dict()["top_left"],
|
||||
submobject_to_align=cell_rectangle.get_corners_dict()["top_left"],
|
||||
buff=0.0
|
||||
submobject_to_align=cell_rectangle.get_corners_dict()[
|
||||
"top_left"
|
||||
],
|
||||
buff=0.0,
|
||||
)
|
||||
cell_rectangle.shift(
|
||||
kernel_shift_vector
|
||||
)
|
||||
cell_rectangle.shift(
|
||||
cell_shift_vector
|
||||
)
|
||||
highlighted_cells.append(
|
||||
cell_rectangle
|
||||
)
|
||||
# Rotate the gridded rectangles so they match the angle
|
||||
cell_rectangle.shift(kernel_shift_vector)
|
||||
cell_rectangle.shift(cell_shift_vector)
|
||||
highlighted_cells.append(cell_rectangle)
|
||||
# Rotate the gridded rectangles so they match the angle
|
||||
# of the conv maps
|
||||
gridded_rectangle_group = VGroup(
|
||||
gridded_rectangle,
|
||||
*highlighted_cells
|
||||
)
|
||||
gridded_rectangle_group = VGroup(gridded_rectangle, *highlighted_cells)
|
||||
gridded_rectangle_group.rotate(
|
||||
ThreeDLayer.rotation_angle,
|
||||
about_point=gridded_rectangle.get_center(),
|
||||
@ -129,9 +117,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
gridded_rectangle.next_to(
|
||||
feature_map.get_corners_dict()["top_left"],
|
||||
submobject_to_align=gridded_rectangle.get_corners_dict()["top_left"],
|
||||
buff=0.0
|
||||
buff=0.0,
|
||||
)
|
||||
# 3. Make a create gridded rectangle
|
||||
# 3. Make a create gridded rectangle
|
||||
"""
|
||||
create_rectangle = Create(
|
||||
gridded_rectangle
|
||||
@ -185,31 +173,24 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
create_and_remove_cell_animations = Succession(
|
||||
Create(VGroup(*highlighted_cells)),
|
||||
Wait(0.5),
|
||||
Uncreate(VGroup(*highlighted_cells))
|
||||
Uncreate(VGroup(*highlighted_cells)),
|
||||
)
|
||||
return create_and_remove_cell_animations
|
||||
# 5. Move and resize the gridded rectangle to the output
|
||||
# feature maps.
|
||||
# 5. Move and resize the gridded rectangle to the output
|
||||
# feature maps.
|
||||
resize_rectangle = Transform(
|
||||
gridded_rectangle,
|
||||
self.output_layer.feature_maps[feature_map_index]
|
||||
gridded_rectangle, self.output_layer.feature_maps[feature_map_index]
|
||||
)
|
||||
move_rectangle = gridded_rectangle.animate.move_to(
|
||||
self.output_layer.feature_maps[feature_map_index]
|
||||
)
|
||||
move_and_resize = Succession(
|
||||
resize_rectangle,
|
||||
move_rectangle,
|
||||
lag_ratio=0.0
|
||||
resize_rectangle, move_rectangle, lag_ratio=0.0
|
||||
)
|
||||
move_and_resize_gridded_rectangle_animations.append(
|
||||
move_and_resize
|
||||
)
|
||||
# 6. Make the gridded feature map(s) disappear.
|
||||
move_and_resize_gridded_rectangle_animations.append(move_and_resize)
|
||||
# 6. Make the gridded feature map(s) disappear.
|
||||
remove_gridded_rectangle_animations.append(
|
||||
Uncreate(
|
||||
gridded_rectangle_group
|
||||
)
|
||||
Uncreate(gridded_rectangle_group)
|
||||
)
|
||||
|
||||
"""
|
||||
@ -224,5 +205,5 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
# *remove_gridded_rectangle_animations
|
||||
# ),
|
||||
# lag_ratio=1.0
|
||||
lag_ratio=1.0
|
||||
)
|
||||
lag_ratio=1.0,
|
||||
)
|
||||
|
@ -25,9 +25,9 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
self.paired_query_mode = paired_query_mode
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: 'NeuralNetworkLayer',
|
||||
output_layer: 'NeuralNetworkLayer',
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
self.axes = Axes(
|
||||
@ -43,14 +43,13 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
self.axes.move_to(self.get_center())
|
||||
# Make point cloud
|
||||
self.point_cloud = self.construct_gaussian_point_cloud(
|
||||
self.mean,
|
||||
self.covariance
|
||||
self.mean, self.covariance
|
||||
)
|
||||
self.add(self.point_cloud)
|
||||
# Make latent distribution
|
||||
self.latent_distribution = GaussianDistribution(
|
||||
self.axes, mean=self.mean, cov=self.covariance
|
||||
) # Use defaults
|
||||
) # Use defaults
|
||||
|
||||
def add_gaussian_distribution(self, gaussian_distribution):
|
||||
"""Adds given GaussianDistribution to the list"""
|
||||
|
@ -3,6 +3,7 @@ from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||
|
||||
|
||||
class EmbeddingToFeedForward(ConnectiveLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
|
||||
@ -17,17 +18,18 @@ class EmbeddingToFeedForward(ConnectiveLayer):
|
||||
dot_radius=0.03,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.feed_forward_layer = output_layer
|
||||
self.embedding_layer = input_layer
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||
|
@ -35,7 +35,12 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
|
||||
|
||||
self.node_group = VGroup()
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
"""Creates the neural network layer"""
|
||||
# Add Nodes
|
||||
for node_number in range(self.num_nodes):
|
||||
|
@ -18,17 +18,18 @@ class FeedForwardToEmbedding(ConnectiveLayer):
|
||||
dot_radius=0.03,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.feed_forward_layer = input_layer
|
||||
self.embedding_layer = output_layer
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||
|
@ -5,6 +5,7 @@ from manim import *
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||
|
||||
|
||||
class FeedForwardToFeedForward(ConnectiveLayer):
|
||||
"""Layer for connecting FeedForward layer to FeedForwardLayer"""
|
||||
|
||||
@ -23,18 +24,19 @@ class FeedForwardToFeedForward(ConnectiveLayer):
|
||||
camera=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.passing_flash = passing_flash
|
||||
self.edge_color = edge_color
|
||||
self.dot_radius = dot_radius
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.edge_width = edge_width
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
self.edges = self.construct_edges()
|
||||
self.add(self.edges)
|
||||
|
||||
|
@ -18,18 +18,19 @@ class FeedForwardToImage(ConnectiveLayer):
|
||||
dot_radius=0.05,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
self.feed_forward_layer = input_layer
|
||||
self.image_layer = output_layer
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
|
@ -18,18 +18,19 @@ class FeedForwardToVector(ConnectiveLayer):
|
||||
dot_radius=0.05,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
self.feed_forward_layer = input_layer
|
||||
self.vector_layer = output_layer
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
|
@ -5,6 +5,7 @@ from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageLayer(NeuralNetworkLayer):
|
||||
"""Single Image Layer for Neural Network"""
|
||||
|
||||
@ -19,23 +20,22 @@ class ImageLayer(NeuralNetworkLayer):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_layer :
|
||||
input_layer :
|
||||
Input layer
|
||||
output_layer :
|
||||
output_layer :
|
||||
Output layer
|
||||
"""
|
||||
if len(np.shape(self.numpy_image)) == 2:
|
||||
# Assumed Grayscale
|
||||
self.num_channels = 1
|
||||
self.image_mobject = GrayscaleImageMobject(
|
||||
self.numpy_image,
|
||||
height=self.image_height
|
||||
self.numpy_image, height=self.image_height
|
||||
)
|
||||
elif len(np.shape(self.numpy_image)) == 3:
|
||||
# Assumed RGB
|
||||
self.num_channels = 3
|
||||
self.image_mobject = ImageMobject(self.numpy_image).scale_to_fit_height(
|
||||
height
|
||||
self.image_height
|
||||
)
|
||||
self.add(self.image_mobject)
|
||||
|
||||
@ -63,10 +63,6 @@ class ImageLayer(NeuralNetworkLayer):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
return AnimationGroup()
|
||||
|
||||
# def move_to(self, location):
|
||||
# """Override of move to"""
|
||||
# self.image_mobject.move_to(location)
|
||||
|
||||
def get_right(self):
|
||||
"""Override get right"""
|
||||
return self.image_mobject.get_right()
|
||||
|
@ -9,6 +9,7 @@ from manim_ml.neural_network.layers.parent_layers import (
|
||||
)
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
|
||||
class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
"""Handles rendering a convolutional layer for a nn"""
|
||||
|
||||
@ -16,16 +17,18 @@ class ImageToConvolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
output_class = Convolutional2DLayer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_layer: ImageLayer,
|
||||
output_layer: Convolutional2DLayer,
|
||||
**kwargs
|
||||
self, input_layer: ImageLayer, output_layer: Convolutional2DLayer, **kwargs
|
||||
):
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.input_layer = input_layer
|
||||
self.output_layer = output_layer
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs,
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, run_time=5, layer_args={}, **kwargs):
|
||||
|
@ -18,18 +18,19 @@ class ImageToFeedForward(ConnectiveLayer):
|
||||
dot_radius=0.05,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
self.feed_forward_layer = output_layer
|
||||
self.image_layer = input_layer
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
|
@ -1,27 +1,31 @@
|
||||
from manim import *
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
from manim_ml.neural_network.layers.parent_layers import ThreeDLayer, VGroupNeuralNetworkLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import (
|
||||
ThreeDLayer,
|
||||
VGroupNeuralNetworkLayer,
|
||||
)
|
||||
|
||||
|
||||
class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
"""Max pooling layer for Convolutional2DLayer
|
||||
|
||||
Note: This is for a Convolutional2DLayer even though
|
||||
it is called MaxPooling2DLayer because the 2D corresponds
|
||||
to the 2 spatial dimensions of the convolution.
|
||||
to the 2 spatial dimensions of the convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size=2,
|
||||
stride=1,
|
||||
cell_highlight_color=ORANGE,
|
||||
cell_width=0.2,
|
||||
filter_spacing=0.1,
|
||||
color=BLUE,
|
||||
show_grid_lines=False,
|
||||
stroke_width=2.0,
|
||||
**kwargs
|
||||
self,
|
||||
kernel_size=2,
|
||||
stride=1,
|
||||
cell_highlight_color=ORANGE,
|
||||
cell_width=0.2,
|
||||
filter_spacing=0.1,
|
||||
color=BLUE,
|
||||
show_grid_lines=False,
|
||||
stroke_width=2.0,
|
||||
**kwargs
|
||||
):
|
||||
"""Layer object for animating 2D Convolution Max Pooling
|
||||
|
||||
@ -42,11 +46,15 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
self.show_grid_lines = show_grid_lines
|
||||
self.stroke_width = stroke_width
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
# Make the output feature maps
|
||||
self.feature_maps = self._make_output_feature_maps(
|
||||
input_layer.num_feature_maps,
|
||||
input_layer.feature_map_size
|
||||
input_layer.num_feature_maps, input_layer.feature_map_size
|
||||
)
|
||||
self.add(self.feature_maps)
|
||||
self.rotate(
|
||||
@ -58,17 +66,13 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
input_layer.feature_map_size[0] / self.kernel_size,
|
||||
input_layer.feature_map_size[1] / self.kernel_size,
|
||||
)
|
||||
|
||||
def _make_output_feature_maps(
|
||||
self,
|
||||
num_input_feature_maps,
|
||||
input_feature_map_size
|
||||
):
|
||||
|
||||
def _make_output_feature_maps(self, num_input_feature_maps, input_feature_map_size):
|
||||
"""Makes a set of output feature maps"""
|
||||
# Compute the size of the feature maps
|
||||
# Compute the size of the feature maps
|
||||
output_feature_map_size = (
|
||||
input_feature_map_size[0] / self.kernel_size,
|
||||
input_feature_map_size[1] / self.kernel_size
|
||||
input_feature_map_size[1] / self.kernel_size,
|
||||
)
|
||||
# Draw rectangles that are filled in with opacity
|
||||
feature_maps = []
|
||||
@ -92,12 +96,10 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
# rectangle.set_z_index(4)
|
||||
feature_maps.append(rectangle)
|
||||
|
||||
return VGroup(
|
||||
*feature_maps
|
||||
)
|
||||
return VGroup(*feature_maps)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Makes forward pass of Max Pooling Layer.
|
||||
"""Makes forward pass of Max Pooling Layer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -1,7 +1,10 @@
|
||||
import numpy as np
|
||||
from manim import *
|
||||
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import Convolutional2DToConvolutional2D, Filters
|
||||
from manim_ml.neural_network.layers.convolutional_2d_to_convolutional_2d import (
|
||||
Convolutional2DToConvolutional2D,
|
||||
Filters,
|
||||
)
|
||||
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
@ -9,8 +12,10 @@ from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
|
||||
from manim.utils.space_ops import rotation_matrix
|
||||
|
||||
|
||||
class MaxPooling2DToConvolutional2D(Convolutional2DToConvolutional2D):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
|
||||
input_class = MaxPooling2DLayer
|
||||
output_class = Convolutional2DLayer
|
||||
|
||||
@ -25,20 +30,16 @@ class MaxPooling2DToConvolutional2D(Convolutional2DToConvolutional2D):
|
||||
**kwargs
|
||||
):
|
||||
input_layer.num_feature_maps = output_layer.num_feature_maps
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.passing_flash_color = passing_flash_color
|
||||
self.cell_width = cell_width
|
||||
self.stroke_width = stroke_width
|
||||
self.show_grid_lines = show_grid_lines
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: 'NeuralNetworkLayer',
|
||||
output_layer: 'NeuralNetworkLayer',
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs the MaxPooling to Convolution3D layer
|
||||
|
@ -22,7 +22,12 @@ class PairedQueryLayer(NeuralNetworkLayer):
|
||||
self.add(self.assets)
|
||||
self.add(self.title)
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
@classmethod
|
||||
|
@ -18,18 +18,19 @@ class PairedQueryToFeedForward(ConnectiveLayer):
|
||||
dot_radius=0.02,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
self.paired_query_layer = input_layer
|
||||
self.feed_forward_layer = output_layer
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
|
@ -1,6 +1,7 @@
|
||||
from manim import *
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class NeuralNetworkLayer(ABC, Group):
|
||||
"""Abstract Neural Network Layer class"""
|
||||
|
||||
@ -12,8 +13,12 @@ class NeuralNetworkLayer(ABC, Group):
|
||||
# self.add(self.title)
|
||||
|
||||
@abstractmethod
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer',
|
||||
output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs,
|
||||
):
|
||||
"""Constructs the layer at network construction time
|
||||
|
||||
Parameters
|
||||
@ -36,6 +41,7 @@ class NeuralNetworkLayer(ABC, Group):
|
||||
def __repr__(self):
|
||||
return f"{type(self).__name__}"
|
||||
|
||||
|
||||
class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -49,6 +55,7 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
|
||||
def _create_override(self):
|
||||
return super()._create_override()
|
||||
|
||||
|
||||
class ThreeDLayer(ABC):
|
||||
"""Abstract class for 3D layers"""
|
||||
|
||||
@ -58,6 +65,7 @@ class ThreeDLayer(ABC):
|
||||
rotation_angle = 60 * DEGREES
|
||||
rotation_axis = [0.0, 0.9, 0.0]
|
||||
|
||||
|
||||
class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
||||
"""Forward pass animation for a given pair of layers"""
|
||||
|
||||
@ -86,6 +94,7 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
||||
+ ")"
|
||||
)
|
||||
|
||||
|
||||
class BlankConnective(ConnectiveLayer):
|
||||
"""Connective layer to be used when the given pair of layers is undefined"""
|
||||
|
||||
@ -97,4 +106,4 @@ class BlankConnective(ConnectiveLayer):
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self):
|
||||
return super()._create_override()
|
||||
return super()._create_override()
|
||||
|
@ -26,7 +26,12 @@ class TripletLayer(NeuralNetworkLayer):
|
||||
self.stroke_width = stroke_width
|
||||
self.font_size = font_size
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
# Make the assets
|
||||
self.assets = self.make_assets()
|
||||
self.add(self.assets)
|
||||
|
@ -18,18 +18,19 @@ class TripletToFeedForward(ConnectiveLayer):
|
||||
dot_radius=0.02,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
input_layer,
|
||||
output_layer,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(input_layer, output_layer, **kwargs)
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
self.feed_forward_layer = output_layer
|
||||
self.triplet_layer = input_layer
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs
|
||||
):
|
||||
return super().construct_layer(input_layer, output_layer, **kwargs)
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
|
@ -4,9 +4,10 @@ from manim import *
|
||||
from manim_ml.neural_network.layers.parent_layers import BlankConnective, ThreeDLayer
|
||||
from manim_ml.neural_network.layers import connective_layers_list
|
||||
|
||||
|
||||
def get_connective_layer(input_layer, output_layer):
|
||||
"""
|
||||
Deduces the relevant connective layer
|
||||
Deduces the relevant connective layer
|
||||
"""
|
||||
connective_layer_class = None
|
||||
for candidate_class in connective_layers_list:
|
||||
|
@ -12,7 +12,12 @@ class VectorLayer(VGroupNeuralNetworkLayer):
|
||||
self.num_values = num_values
|
||||
self.value_func = value_func
|
||||
|
||||
def construct_layer(self, input_layer: 'NeuralNetworkLayer', output_layer: 'NeuralNetworkLayer', **kwargs):
|
||||
def construct_layer(
|
||||
self,
|
||||
input_layer: "NeuralNetworkLayer",
|
||||
output_layer: "NeuralNetworkLayer",
|
||||
**kwargs,
|
||||
):
|
||||
# Make the vector
|
||||
self.vector_label = self.make_vector()
|
||||
self.add(self.vector_label)
|
||||
|
@ -22,6 +22,7 @@ from manim_ml.neural_network.neural_network_transformations import (
|
||||
RemoveLayer,
|
||||
)
|
||||
|
||||
|
||||
class NeuralNetwork(Group):
|
||||
"""Neural Network Visualization Container Class"""
|
||||
|
||||
@ -53,17 +54,11 @@ class NeuralNetwork(Group):
|
||||
# Construct all of the layers
|
||||
self._construct_input_layers()
|
||||
# Place the layers
|
||||
self._place_layers(
|
||||
layout=layout,
|
||||
layout_direction=layout_direction
|
||||
)
|
||||
self._place_layers(layout=layout, layout_direction=layout_direction)
|
||||
# Make the connective layers
|
||||
self.connective_layers, self.all_layers = self._construct_connective_layers()
|
||||
# Make overhead title
|
||||
self.title = Text(
|
||||
self.title_text,
|
||||
font_size=DEFAULT_FONT_SIZE / 2
|
||||
)
|
||||
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 2)
|
||||
self.title.next_to(self, UP, 1.0)
|
||||
self.add(self.title)
|
||||
# Place layers at correct z index
|
||||
@ -76,7 +71,7 @@ class NeuralNetwork(Group):
|
||||
print(repr(self))
|
||||
|
||||
def _construct_input_layers(self):
|
||||
"""Constructs each of the input layers in context
|
||||
"""Constructs each of the input layers in context
|
||||
of their adjacent layers"""
|
||||
prev_layer = None
|
||||
next_layer = None
|
||||
@ -105,64 +100,82 @@ class NeuralNetwork(Group):
|
||||
previous_layer, EmbeddingLayer
|
||||
):
|
||||
if layout_direction == "left_to_right":
|
||||
shift_vector = np.array([
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
0,
|
||||
])
|
||||
shift_vector = np.array(
|
||||
[
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
0,
|
||||
]
|
||||
)
|
||||
elif layout_direction == "top_to_bottom":
|
||||
shift_vector = np.array([
|
||||
0,
|
||||
-(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
])
|
||||
shift_vector = np.array(
|
||||
[
|
||||
0,
|
||||
-(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
- 0.2
|
||||
),
|
||||
0,
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unrecognized layout direction: {layout_direction}"
|
||||
)
|
||||
else:
|
||||
if layout_direction == "left_to_right":
|
||||
shift_vector = np.array([
|
||||
(
|
||||
shift_vector = np.array(
|
||||
[
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
)
|
||||
+ self.layer_spacing,
|
||||
0,
|
||||
0,
|
||||
])
|
||||
+ self.layer_spacing,
|
||||
0,
|
||||
0,
|
||||
]
|
||||
)
|
||||
elif layout_direction == "top_to_bottom":
|
||||
shift_vector = np.array([
|
||||
0,
|
||||
-(
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
)
|
||||
+ self.layer_spacing
|
||||
),
|
||||
0,
|
||||
])
|
||||
shift_vector = np.array(
|
||||
[
|
||||
0,
|
||||
-(
|
||||
(
|
||||
previous_layer.get_width() / 2
|
||||
+ current_layer.get_width() / 2
|
||||
)
|
||||
+ self.layer_spacing
|
||||
),
|
||||
0,
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unrecognized layout direction: {layout_direction}"
|
||||
)
|
||||
current_layer.shift(shift_vector)
|
||||
|
||||
# After all layers have been placed place their activation functions
|
||||
for current_layer in self.input_layers:
|
||||
# Place activation function
|
||||
if hasattr(current_layer, "activation_function"):
|
||||
if not current_layer.activation_function is None:
|
||||
current_layer.activation_function.next_to(
|
||||
current_layer,
|
||||
direction=UP
|
||||
up_movement = np.array(
|
||||
[
|
||||
0,
|
||||
current_layer.get_height() / 2
|
||||
+ current_layer.activation_function.get_height() / 2
|
||||
+ 0.5 * self.layer_spacing,
|
||||
0,
|
||||
]
|
||||
)
|
||||
current_layer.activation_function.move_to(
|
||||
current_layer,
|
||||
)
|
||||
current_layer.activation_function.shift(up_movement)
|
||||
self.add(current_layer.activation_function)
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
@ -228,8 +241,8 @@ class NeuralNetwork(Group):
|
||||
# Get the layer args
|
||||
if isinstance(layer, ConnectiveLayer):
|
||||
"""
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting and itself.
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting and itself.
|
||||
"""
|
||||
before_layer_args = {}
|
||||
current_layer_args = {}
|
||||
@ -252,16 +265,11 @@ class NeuralNetwork(Group):
|
||||
current_layer_args = layer_args[layer]
|
||||
# Perform the forward pass of the current layer
|
||||
layer_forward_pass = layer.make_forward_pass_animation(
|
||||
layer_args=current_layer_args,
|
||||
run_time=per_layer_runtime,
|
||||
**kwargs
|
||||
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
|
||||
)
|
||||
all_animations.append(layer_forward_pass)
|
||||
# Make the animation group
|
||||
animation_group = Succession(
|
||||
*all_animations,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
animation_group = Succession(*all_animations, lag_ratio=1.0)
|
||||
|
||||
return animation_group
|
||||
|
||||
@ -332,6 +340,7 @@ class NeuralNetwork(Group):
|
||||
string_repr = "NeuralNetwork([\n" + inner_string + "])"
|
||||
return string_repr
|
||||
|
||||
|
||||
class FeedForwardNeuralNetwork(NeuralNetwork):
|
||||
"""NeuralNetwork with just feed forward layers"""
|
||||
|
||||
|
@ -2,6 +2,7 @@ from manim import *
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
class GaussianDistribution(VGroup):
|
||||
"""Object for drawing a Gaussian distribution"""
|
||||
|
||||
|
Reference in New Issue
Block a user