Working initial visualization of a CNN.

This commit is contained in:
Alec Helbling
2022-12-29 14:09:16 -05:00
parent 330ba170a0
commit 8cee86e884
18 changed files with 384 additions and 236 deletions

View File

@ -1,12 +1,20 @@
from manim import *
from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
from manim_ml.gridded_rectangle import GriddedRectangle, CornersRectangle
from manim_ml.gridded_rectangle import GriddedRectangle
from manim.utils.space_ops import rotation_matrix
class Filters(VGroup):
"""Group for showing a collection of filters connecting two layers"""
def __init__(self, input_layer, output_layer, line_color=ORANGE, stroke_width=2.0):
def __init__(
self,
input_layer,
output_layer,
line_color=ORANGE,
stroke_width=2.0,
):
super().__init__()
self.input_layer = input_layer
self.output_layer = output_layer
@ -29,7 +37,6 @@ class Filters(VGroup):
for index, feature_map in enumerate(self.input_layer.feature_maps):
rectangle = GriddedRectangle(
center=feature_map.get_center(),
width=rectangle_width,
height=rectangle_height,
fill_color=filter_color,
@ -38,21 +45,20 @@ class Filters(VGroup):
z_index=2,
stroke_width=self.stroke_width,
)
# Center on feature map
# rectangle.move_to(feature_map.get_center())
# Rotate so it is in the yz plane
rectangle.rotate(
90 * DEGREES,
ThreeDLayer.three_d_x_rotation,
about_point=rectangle.get_center(),
axis=[1, 0, 0]
)
rectangle.rotate(
ThreeDLayer.three_d_y_rotation,
about_point=rectangle.get_center(),
axis=[0, 1, 0]
)
# Get the feature map top left corner
feature_map_top_left = feature_map.get_corners_dict(inner_rectangle=True)["top_left"]
rectangle_top_left = rectangle.get_corners_dict()["top_left"]
# Move the rectangle to the corner location
rectangle.next_to(
feature_map_top_left,
submobject_to_align=rectangle_top_left,
buff=0.0
# Move the rectangle to the corner of the feature map
rectangle.move_to(
feature_map,
aligned_edge=np.array([-1, 1, 0])
)
rectangles.append(rectangle)
@ -70,7 +76,6 @@ class Filters(VGroup):
for index, feature_map in enumerate(self.output_layer.feature_maps):
rectangle = GriddedRectangle(
center=feature_map.get_center(),
width=rectangle_width,
height=rectangle_height,
fill_color=filter_color,
@ -81,21 +86,22 @@ class Filters(VGroup):
)
# Center on feature map
# rectangle.move_to(feature_map.get_center())
# Rotate so it is in the yz plane
# Rotate the rectangle
rectangle.rotate(
90 * DEGREES,
ThreeDLayer.three_d_x_rotation,
about_point=rectangle.get_center(),
axis=[1, 0, 0]
)
rectangle.rotate(
ThreeDLayer.three_d_y_rotation,
about_point=rectangle.get_center(),
axis=[0, 1, 0]
)
# Get the feature map top left corner
feature_map_top_left = feature_map.get_corners_dict(inner_rectangle=True)["top_left"]
rectangle_top_left = rectangle.get_corners_dict()["top_left"]
# Move the rectangle to the corner location
rectangle.next_to(
feature_map_top_left,
submobject_to_align=rectangle_top_left,
buff=0.0
rectangle.move_to(
feature_map,
aligned_edge=np.array([-1, 1, 0])
)
rectangles.append(rectangle)
feature_map_rectangles = VGroup(*rectangles)
@ -202,6 +208,9 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
self.filter_opacity = filter_opacity
self.line_color = line_color
self.pulse_color = pulse_color
# Make filters
self.filters = Filters(self.input_layer, self.output_layer)
self.add(self.filters)
def make_filter_propagation_animation(self):
"""Make filter propagation animation"""
@ -219,32 +228,71 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
return animation_group
def get_rotated_shift_vectors(self):
"""
Rotates the shift vectors
"""
x_rot_mat = rotation_matrix(
ThreeDLayer.three_d_x_rotation,
[1, 0, 0]
)
y_rot_mat = rotation_matrix(
ThreeDLayer.three_d_y_rotation,
[0, 1, 0]
)
# Make base shift vectors
right_shift = np.array([self.input_layer.cell_width, 0, 0])
down_shift = np.array([0, -self.input_layer.cell_width, 0])
# Rotate the vectors
right_shift = np.dot(right_shift, x_rot_mat.T)
right_shift = np.dot(right_shift, y_rot_mat.T)
down_shift = np.dot(down_shift, x_rot_mat.T)
down_shift = np.dot(down_shift, y_rot_mat.T)
return right_shift, down_shift
def make_forward_pass_animation(self, layer_args={}, run_time=10.5, **kwargs):
"""Forward pass animation from conv2d to conv2d"""
animations = []
# Create the filters, output nodes (feature map square), and lines
filters = Filters(self.input_layer, self.output_layer)
self.add(filters)
# Rotate given three_d_phi and three_d_theta
# Rotate about center
# filters.rotate(110 * DEGREES, about_point=filters.get_center(), axis=[0, 0, 1])
"""
self.filters.rotate(
ThreeDLayer.three_d_x_rotation,
about_point=self.filters.get_center(),
axis=[1, 0, 0]
)
self.filters.rotate(
ThreeDLayer.three_d_y_rotation,
about_point=self.filters.get_center(),
axis=[0, 1, 0]
)
"""
# Get shift vectors
right_shift, down_shift = self.get_rotated_shift_vectors()
left_shift = -1 * right_shift
# filters.rotate(ThreeDLayer.three_d_theta, axis=[0, 0, 1])
# filters.rotate(ThreeDLayer.three_d_phi, axis=-filters.get_center())
# Make animations for creating the filters, output_nodes, and filter_lines
# TODO decide if I want to create the filters at the start of a conv
# animation or have them there by default
# animations.append(
# Create(filters)
# )
# Make shift amounts
right_shift = np.array([0, self.input_layer.cell_width, 0])# * 1.55
left_shift = np.array([0, -1*self.input_layer.cell_width, 0])# * 1.55
up_shift = np.array([0, 0, -1*self.input_layer.cell_width])# * 1.55
down_shift = np.array([0, 0, self.input_layer.cell_width])# * 1.55
# Rotate the base shift vectors
# Make filter shifting animations
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
num_x_moves = int((self.feature_map_width - self.filter_width) / self.stride)
for y_move in range(num_y_moves):
# Go right num_x_moves
for x_move in range(num_x_moves):
print(right_shift)
# Shift right
shift_animation = ApplyMethod(
filters.shift,
self.filters.shift,
self.stride * right_shift
)
# shift_animation = self.animate.shift(right_shift)
@ -254,7 +302,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
shift_amount = self.stride * num_x_moves * left_shift + self.stride * down_shift
# Make the animation
shift_animation = ApplyMethod(
filters.shift,
self.filters.shift,
shift_amount
)
animations.append(shift_animation)
@ -262,7 +310,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
for x_move in range(num_x_moves):
# Shift right
shift_animation = ApplyMethod(
filters.shift,
self.filters.shift,
self.stride * right_shift
)
# shift_animation = self.animate.shift(right_shift)