Working max pooling visualization.

This commit is contained in:
Alec Helbling
2023-01-26 18:39:38 -05:00
parent 11d39a34e5
commit 46958ea293
4 changed files with 162 additions and 75 deletions

View File

@ -0,0 +1,72 @@
from manim import *
from PIL import Image
import numpy as np
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
# Make the specific scene
# Global manim render settings; these assignments run when the module is loaded.
# Output resolution in pixels (1900 wide by 1200 high).
config.pixel_height = 1200
config.pixel_width = 1900
# Scene frame size in manim units (6.0 by 6.0).
# NOTE(review): the frame is square while the pixel dimensions are not —
# confirm this combination is intentional.
config.frame_height = 6.0
config.frame_width = 6.0
def make_code_snippet():
    """Build the code listing that is displayed alongside the network.

    Returns:
        A manim ``Code`` mobject holding the example snippet, drawn with
        the "monokai" style on a "window" background and scaled by a
        factor of 0.4.
    """
    snippet_source = """
# Make the neural network
nn = NeuralNetwork([
ImageLayer(image),
Convolutional2DLayer(1, 8),
MaxPooling2DLayer(kernel_size=2),
Convolutional2DLayer(3, 2, 3),
])
# Play the animation
self.play(nn.make_forward_pass_animation())
"""
    # Gather the rendering options in one place before building the mobject.
    code_options = {
        "code": snippet_source,
        "language": "py",
        "style": "monokai",
        "font": "Monospace",
        "tab_width": 4,
        "insert_line_no": False,
        "background": "window",
        "background_stroke_width": 1,
        "background_stroke_color": WHITE,
    }
    snippet = Code(**code_options)
    # Shrink the listing so it fits underneath the network.
    snippet.scale(0.4)
    return snippet
class CombinedScene(ThreeDScene):
    """Scene showing a small CNN with a max pooling layer above its code."""

    def construct(self):
        """Add the network and its code snippet, then play a forward pass."""
        # Load the input digit and convert it to a numpy array.
        digit_image = Image.open("../../assets/mnist/digit.jpeg")
        digit_array = np.asarray(digit_image)
        # Describe the layers: image -> conv -> max pool -> conv.
        layers = [
            ImageLayer(digit_array, height=1.5),
            Convolutional2DLayer(1, 8, filter_spacing=0.32),
            MaxPooling2DLayer(kernel_size=2),
            Convolutional2DLayer(3, 2, 3, filter_spacing=0.32),
        ]
        network = NeuralNetwork(layers, layer_spacing=0.25)
        # Put the network in the middle of the frame and show it.
        network.move_to(ORIGIN)
        self.add(network)
        # Place the code listing under the network, then re-center the pair.
        snippet = make_code_snippet()
        snippet.next_to(network, DOWN)
        Group(snippet, network).move_to(ORIGIN)
        self.add(snippet)
        self.wait(5)
        # Build the forward pass animation, pause briefly, then play it.
        forward_pass_animation = network.make_forward_pass_animation(
            all_filters_at_once=False,
            corner_pulses=False,
        )
        self.wait(1)
        self.play(forward_pass_animation)

View File

@ -45,6 +45,7 @@ class GriddedRectangle(VGroup):
stroke_width=stroke_width,
fill_color=color,
fill_opacity=fill_opacity,
shade_in_3d=True
)
self.add(self.rectangle)
# Make grid lines
@ -94,6 +95,7 @@ class GriddedRectangle(VGroup):
stroke_color=self.grid_stroke_color,
stroke_width=self.grid_stroke_width,
stroke_opacity=self.grid_stroke_opacity,
shade_in_3d=True
)
for i in range(1, count)
)

View File

@ -10,6 +10,22 @@ from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeD
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
class Uncreate(Create):
    """``Create`` variant whose defaults reverse the rate function.

    By default the animation is flagged as both an introducer and a
    remover. All arguments are forwarded unchanged to ``Create``.
    """

    def __init__(
        self,
        mobject,
        reverse_rate_function: bool = True,
        introducer: bool = True,
        remover: bool = True,
        **kwargs,
    ) -> None:
        # Group the three flags so they are forwarded together with any
        # extra keyword arguments supplied by the caller.
        animation_flags = {
            "reverse_rate_function": reverse_rate_function,
            "introducer": introducer,
            "remover": remover,
        }
        super().__init__(mobject, **animation_flags, **kwargs)
class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
"""Feed Forward to Embedding Layer"""
@ -42,17 +58,10 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
kernel_size = self.output_layer.kernel_size
feature_maps = self.input_layer.feature_maps
grid_stroke_width = 1.0
# Get the normalized shift vectors for the convolutional layer
"""
right_shift, down_shift = get_rotated_shift_vectors(
self.input_layer,
normalized=True
)
"""
# Make all of the kernel gridded rectangles
create_gridded_rectangle_animations = []
create_and_remove_cell_animations = []
move_and_resize_gridded_rectangle_animations = []
transform_gridded_rectangle_animations = []
remove_gridded_rectangle_animations = []
for feature_map_index, feature_map in enumerate(feature_maps):
@ -68,6 +77,7 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
grid_stroke_color=self.active_color,
show_grid_lines=True,
)
gridded_rectangle.set_z_index(10)
# 2. Randomly highlight one of the cells in the kernel.
highlighted_cells = []
num_cells_in_kernel = kernel_size * kernel_size
@ -82,8 +92,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
color=self.active_color,
height=cell_width,
width=cell_width,
stroke_width=0.0,
fill_opacity=0.7,
stroke_width=0.0,
z_index=10
)
# Move to the correct location
kernel_shift_vector = [
@ -108,102 +119,103 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
highlighted_cells.append(cell_rectangle)
# Rotate the gridded rectangles so they match the angle
# of the conv maps
gridded_rectangle_group = VGroup(gridded_rectangle, *highlighted_cells)
gridded_rectangle_group = VGroup(
gridded_rectangle,
*highlighted_cells
)
gridded_rectangle_group.rotate(
ThreeDLayer.rotation_angle,
about_point=gridded_rectangle.get_center(),
axis=ThreeDLayer.rotation_axis,
)
gridded_rectangle.next_to(
gridded_rectangle_group.next_to(
feature_map.get_corners_dict()["top_left"],
submobject_to_align=gridded_rectangle.get_corners_dict()["top_left"],
buff=0.0,
)
# 3. Make a create gridded rectangle
"""
create_rectangle = Create(
gridded_rectangle
gridded_rectangle,
)
create_gridded_rectangle_animations.append(
create_rectangle
)
def add_grid_lines(rectangle):
rectangle.color=self.active_color
rectangle.height=cell_width * feature_map_size[1]
rectangle.width=cell_width * feature_map_size[0]
rectangle.grid_xstep=cell_width * kernel_size
rectangle.grid_ystep=cell_width * kernel_size
rectangle.grid_stroke_width=grid_stroke_width
rectangle.grid_stroke_color=self.active_color
rectangle.show_grid_lines=True
return rectangle
create_gridded_rectangle_animations.append(
ApplyFunction(
add_grid_lines,
gridded_rectangle
)
)
"""
# 4. Create and fade out highlighted cells
# highlighted_cells_group = VGroup()
# NOTE: Another workaround that is hacky
# See convolution_2d_to_convolution_2d Filters Create Override for
# more information
"""
def add_highlighted_cells(object):
for cell in highlighted_cells:
object.add(
cell
)
return object
create_and_remove_cell_animation = Succession(
ApplyFunction(add_highlighted_cells, highlighted_cells_group),
Wait(0.5),
FadeOut(highlighted_cells_group),
create_group = AnimationGroup(
*[Create(highlighted_cell) for highlighted_cell in highlighted_cells],
lag_ratio=1.0
)
uncreate_group = AnimationGroup(
*[Uncreate(highlighted_cell) for highlighted_cell in highlighted_cells],
lag_ratio=0.0
)
create_and_remove_cell_animation = Succession(
create_group,
Wait(1.0),
uncreate_group
)
create_and_remove_cell_animations.append(
create_and_remove_cell_animation
)
"""
create_and_remove_cell_animations = Succession(
Create(VGroup(*highlighted_cells)),
Wait(0.5),
Uncreate(VGroup(*highlighted_cells)),
)
return create_and_remove_cell_animations
# 5. Move and resize the gridded rectangle to the output
# feature maps.
resize_rectangle = Transform(
gridded_rectangle, self.output_layer.feature_maps[feature_map_index]
output_gridded_rectangle = GriddedRectangle(
color=self.active_color,
height=cell_width * feature_map_size[1] / 2,
width=cell_width * feature_map_size[0] / 2,
grid_xstep=cell_width,
grid_ystep=cell_width,
grid_stroke_width=grid_stroke_width,
grid_stroke_color=self.active_color,
show_grid_lines=True,
)
move_rectangle = gridded_rectangle.animate.move_to(
self.output_layer.feature_maps[feature_map_index]
output_gridded_rectangle.rotate(
ThreeDLayer.rotation_angle,
about_point=output_gridded_rectangle.get_center(),
axis=ThreeDLayer.rotation_axis,
)
move_and_resize = Succession(
resize_rectangle, move_rectangle, lag_ratio=0.0
output_gridded_rectangle.move_to(
self.output_layer.feature_maps[feature_map_index].copy()
)
move_and_resize_gridded_rectangle_animations.append(move_and_resize)
transform_rectangle = ReplacementTransform(
gridded_rectangle, output_gridded_rectangle,
introducer=True,
remover=True
)
transform_gridded_rectangle_animations.append(
transform_rectangle,
)
"""
Succession(
Uncreate(gridded_rectangle),
transform_rectangle,
lag_ratio=1.0
)
"""
# 6. Make the gridded feature map(s) disappear.
remove_gridded_rectangle_animations.append(
Uncreate(gridded_rectangle_group)
)
"""
AnimationGroup(
*move_and_resize_gridded_rectangle_animations
),
"""
create_gridded_rectangle_animation = AnimationGroup(
*create_gridded_rectangle_animations
)
create_and_remove_cell_animation = AnimationGroup(
*create_and_remove_cell_animations
)
transform_gridded_rectangle_animation = AnimationGroup(
*transform_gridded_rectangle_animations
)
remove_gridded_rectangle_animation = AnimationGroup(
*remove_gridded_rectangle_animations
)
return Succession(
# *create_gridded_rectangle_animations,
create_and_remove_cell_animations,
# AnimationGroup(
# *remove_gridded_rectangle_animations
# ),
# lag_ratio=1.0
create_gridded_rectangle_animation,
Wait(1),
create_and_remove_cell_animation,
transform_gridded_rectangle_animation,
Wait(1),
remove_gridded_rectangle_animation,
lag_ratio=1.0,
)

View File

@ -22,8 +22,9 @@ class CombinedScene(ThreeDScene):
nn = NeuralNetwork([
ImageLayer(numpy_image, height=1.5),
Convolutional2DLayer(1, 8, filter_spacing=0.32),
Convolutional2DLayer(3, 6, 3, filter_spacing=0.32),
MaxPooling2DLayer(kernel_size=2),
Convolutional2DLayer(3, 3, 2, filter_spacing=0.32),
Convolutional2DLayer(5, 2, 2, filter_spacing=0.32),
],
layer_spacing=0.25,
)