mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-18 12:07:46 +08:00
Working initial visualization of a CNN.
This commit is contained in:
@ -1,12 +1,20 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle, CornersRectangle
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
from manim.utils.space_ops import rotation_matrix
|
||||
|
||||
class Filters(VGroup):
|
||||
"""Group for showing a collection of filters connecting two layers"""
|
||||
|
||||
def __init__(self, input_layer, output_layer, line_color=ORANGE, stroke_width=2.0):
|
||||
def __init__(
|
||||
self,
|
||||
input_layer,
|
||||
output_layer,
|
||||
line_color=ORANGE,
|
||||
stroke_width=2.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_layer = input_layer
|
||||
self.output_layer = output_layer
|
||||
@ -29,7 +37,6 @@ class Filters(VGroup):
|
||||
|
||||
for index, feature_map in enumerate(self.input_layer.feature_maps):
|
||||
rectangle = GriddedRectangle(
|
||||
center=feature_map.get_center(),
|
||||
width=rectangle_width,
|
||||
height=rectangle_height,
|
||||
fill_color=filter_color,
|
||||
@ -38,21 +45,20 @@ class Filters(VGroup):
|
||||
z_index=2,
|
||||
stroke_width=self.stroke_width,
|
||||
)
|
||||
# Center on feature map
|
||||
# rectangle.move_to(feature_map.get_center())
|
||||
# Rotate so it is in the yz plane
|
||||
rectangle.rotate(
|
||||
90 * DEGREES,
|
||||
ThreeDLayer.three_d_x_rotation,
|
||||
about_point=rectangle.get_center(),
|
||||
axis=[1, 0, 0]
|
||||
)
|
||||
rectangle.rotate(
|
||||
ThreeDLayer.three_d_y_rotation,
|
||||
about_point=rectangle.get_center(),
|
||||
axis=[0, 1, 0]
|
||||
)
|
||||
# Get the feature map top left corner
|
||||
feature_map_top_left = feature_map.get_corners_dict(inner_rectangle=True)["top_left"]
|
||||
rectangle_top_left = rectangle.get_corners_dict()["top_left"]
|
||||
# Move the rectangle to the corner location
|
||||
rectangle.next_to(
|
||||
feature_map_top_left,
|
||||
submobject_to_align=rectangle_top_left,
|
||||
buff=0.0
|
||||
# Move the rectangle to the corner of the feature map
|
||||
rectangle.move_to(
|
||||
feature_map,
|
||||
aligned_edge=np.array([-1, 1, 0])
|
||||
)
|
||||
|
||||
rectangles.append(rectangle)
|
||||
@ -70,7 +76,6 @@ class Filters(VGroup):
|
||||
|
||||
for index, feature_map in enumerate(self.output_layer.feature_maps):
|
||||
rectangle = GriddedRectangle(
|
||||
center=feature_map.get_center(),
|
||||
width=rectangle_width,
|
||||
height=rectangle_height,
|
||||
fill_color=filter_color,
|
||||
@ -81,21 +86,22 @@ class Filters(VGroup):
|
||||
)
|
||||
# Center on feature map
|
||||
# rectangle.move_to(feature_map.get_center())
|
||||
# Rotate so it is in the yz plane
|
||||
# Rotate the rectangle
|
||||
rectangle.rotate(
|
||||
90 * DEGREES,
|
||||
ThreeDLayer.three_d_x_rotation,
|
||||
about_point=rectangle.get_center(),
|
||||
axis=[1, 0, 0]
|
||||
)
|
||||
rectangle.rotate(
|
||||
ThreeDLayer.three_d_y_rotation,
|
||||
about_point=rectangle.get_center(),
|
||||
axis=[0, 1, 0]
|
||||
)
|
||||
# Get the feature map top left corner
|
||||
feature_map_top_left = feature_map.get_corners_dict(inner_rectangle=True)["top_left"]
|
||||
rectangle_top_left = rectangle.get_corners_dict()["top_left"]
|
||||
# Move the rectangle to the corner location
|
||||
rectangle.next_to(
|
||||
feature_map_top_left,
|
||||
submobject_to_align=rectangle_top_left,
|
||||
buff=0.0
|
||||
rectangle.move_to(
|
||||
feature_map,
|
||||
aligned_edge=np.array([-1, 1, 0])
|
||||
)
|
||||
|
||||
rectangles.append(rectangle)
|
||||
|
||||
feature_map_rectangles = VGroup(*rectangles)
|
||||
@ -202,6 +208,9 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
self.filter_opacity = filter_opacity
|
||||
self.line_color = line_color
|
||||
self.pulse_color = pulse_color
|
||||
# Make filters
|
||||
self.filters = Filters(self.input_layer, self.output_layer)
|
||||
self.add(self.filters)
|
||||
|
||||
def make_filter_propagation_animation(self):
|
||||
"""Make filter propagation animation"""
|
||||
@ -219,32 +228,71 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
|
||||
return animation_group
|
||||
|
||||
def get_rotated_shift_vectors(self):
|
||||
"""
|
||||
Rotates the shift vectors
|
||||
"""
|
||||
x_rot_mat = rotation_matrix(
|
||||
ThreeDLayer.three_d_x_rotation,
|
||||
[1, 0, 0]
|
||||
)
|
||||
y_rot_mat = rotation_matrix(
|
||||
ThreeDLayer.three_d_y_rotation,
|
||||
[0, 1, 0]
|
||||
)
|
||||
# Make base shift vectors
|
||||
right_shift = np.array([self.input_layer.cell_width, 0, 0])
|
||||
down_shift = np.array([0, -self.input_layer.cell_width, 0])
|
||||
# Rotate the vectors
|
||||
right_shift = np.dot(right_shift, x_rot_mat.T)
|
||||
right_shift = np.dot(right_shift, y_rot_mat.T)
|
||||
down_shift = np.dot(down_shift, x_rot_mat.T)
|
||||
down_shift = np.dot(down_shift, y_rot_mat.T)
|
||||
|
||||
return right_shift, down_shift
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=10.5, **kwargs):
|
||||
"""Forward pass animation from conv2d to conv2d"""
|
||||
|
||||
animations = []
|
||||
# Create the filters, output nodes (feature map square), and lines
|
||||
filters = Filters(self.input_layer, self.output_layer)
|
||||
self.add(filters)
|
||||
# Rotate given three_d_phi and three_d_theta
|
||||
# Rotate about center
|
||||
# filters.rotate(110 * DEGREES, about_point=filters.get_center(), axis=[0, 0, 1])
|
||||
"""
|
||||
self.filters.rotate(
|
||||
ThreeDLayer.three_d_x_rotation,
|
||||
about_point=self.filters.get_center(),
|
||||
axis=[1, 0, 0]
|
||||
)
|
||||
|
||||
self.filters.rotate(
|
||||
ThreeDLayer.three_d_y_rotation,
|
||||
about_point=self.filters.get_center(),
|
||||
axis=[0, 1, 0]
|
||||
)
|
||||
"""
|
||||
# Get shift vectors
|
||||
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||
left_shift = -1 * right_shift
|
||||
# filters.rotate(ThreeDLayer.three_d_theta, axis=[0, 0, 1])
|
||||
# filters.rotate(ThreeDLayer.three_d_phi, axis=-filters.get_center())
|
||||
# Make animations for creating the filters, output_nodes, and filter_lines
|
||||
# TODO decide if I want to create the filters at the start of a conv
|
||||
# animation or have them there by default
|
||||
# animations.append(
|
||||
# Create(filters)
|
||||
# )
|
||||
# Make shift amounts
|
||||
right_shift = np.array([0, self.input_layer.cell_width, 0])# * 1.55
|
||||
left_shift = np.array([0, -1*self.input_layer.cell_width, 0])# * 1.55
|
||||
up_shift = np.array([0, 0, -1*self.input_layer.cell_width])# * 1.55
|
||||
down_shift = np.array([0, 0, self.input_layer.cell_width])# * 1.55
|
||||
# Rotate the base shift vectors
|
||||
# Make filter shifting animations
|
||||
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
|
||||
num_x_moves = int((self.feature_map_width - self.filter_width) / self.stride)
|
||||
for y_move in range(num_y_moves):
|
||||
# Go right num_x_moves
|
||||
for x_move in range(num_x_moves):
|
||||
print(right_shift)
|
||||
# Shift right
|
||||
shift_animation = ApplyMethod(
|
||||
filters.shift,
|
||||
self.filters.shift,
|
||||
self.stride * right_shift
|
||||
)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
@ -254,7 +302,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
shift_amount = self.stride * num_x_moves * left_shift + self.stride * down_shift
|
||||
# Make the animation
|
||||
shift_animation = ApplyMethod(
|
||||
filters.shift,
|
||||
self.filters.shift,
|
||||
shift_amount
|
||||
)
|
||||
animations.append(shift_animation)
|
||||
@ -262,7 +310,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
for x_move in range(num_x_moves):
|
||||
# Shift right
|
||||
shift_animation = ApplyMethod(
|
||||
filters.shift,
|
||||
self.filters.shift,
|
||||
self.stride * right_shift
|
||||
)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
|
Reference in New Issue
Block a user