mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-24 02:20:20 +08:00
Working max pooling visualization.
This commit is contained in:
72
examples/cnn/cnn_max_pool.py
Normal file
72
examples/cnn/cnn_max_pool.py
Normal file
@ -0,0 +1,72 @@
|
||||
from manim import *
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.image import ImageLayer
|
||||
from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer
|
||||
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||
|
||||
# Render settings for this example scene: output resolution in pixels,
# then the frame size in scene units.
config.pixel_height, config.pixel_width = 1200, 1900
config.frame_height, config.frame_width = 6.0, 6.0
|
||||
|
||||
def make_code_snippet():
    """Build the code listing that is displayed underneath the network.

    Returns:
        A ``Code`` mobject, scaled to 0.4, showing a short snippet that
        constructs the network and plays its forward pass.
    """
    # The snippet is display text only and is never executed. The layer
    # entries are indented so the rendered listing reads as conventionally
    # formatted Python. The empty first and last entries keep the leading
    # and trailing newline of the displayed text.
    snippet_lines = [
        "",
        "# Make the neural network",
        "nn = NeuralNetwork([",
        "    ImageLayer(image),",
        "    Convolutional2DLayer(1, 8),",
        "    MaxPooling2DLayer(kernel_size=2),",
        "    Convolutional2DLayer(3, 2, 3),",
        "])",
        "# Play the animation",
        "self.play(nn.make_forward_pass_animation())",
        "",
    ]
    code_str = "\n".join(snippet_lines)

    code = Code(
        code=code_str,
        tab_width=4,
        background_stroke_width=1,
        background_stroke_color=WHITE,
        insert_line_no=False,
        style="monokai",
        font="Monospace",
        background="window",
        language="py",
    )
    # Shrink the listing so it fits below the network inside the frame.
    code.scale(0.4)

    return code
|
||||
|
||||
class CombinedScene(ThreeDScene):
    """Scene showing a small CNN with a max pooling layer above its code snippet."""

    def construct(self):
        """Build the network and the code snippet, then play one forward pass."""
        # Imported locally so this block brings its own dependency into scope.
        from pathlib import Path

        # Resolve the image relative to this file (two directories up, then
        # assets/mnist) instead of the current working directory, so the
        # scene renders no matter where the renderer is launched from.
        image_path = (
            Path(__file__).resolve().parents[2] / "assets" / "mnist" / "digit.jpeg"
        )
        # The context manager closes the image file once its pixels have
        # been converted to an array.
        with Image.open(image_path) as image:
            numpy_image = np.asarray(image)
        # Make nn
        nn = NeuralNetwork(
            [
                ImageLayer(numpy_image, height=1.5),
                Convolutional2DLayer(1, 8, filter_spacing=0.32),
                MaxPooling2DLayer(kernel_size=2),
                Convolutional2DLayer(3, 2, 3, filter_spacing=0.32),
            ],
            layer_spacing=0.25,
        )
        # Center the nn
        nn.move_to(ORIGIN)
        self.add(nn)
        # Make code snippet and place it below the network, then center
        # the pair as a single group.
        code = make_code_snippet()
        code.next_to(nn, DOWN)
        Group(code, nn).move_to(ORIGIN)
        self.add(code)
        self.wait(5)
        # Play animation
        forward_pass = nn.make_forward_pass_animation(
            corner_pulses=False, all_filters_at_once=False
        )
        self.wait(1)
        self.play(forward_pass)
|
@ -45,6 +45,7 @@ class GriddedRectangle(VGroup):
|
||||
stroke_width=stroke_width,
|
||||
fill_color=color,
|
||||
fill_opacity=fill_opacity,
|
||||
shade_in_3d=True
|
||||
)
|
||||
self.add(self.rectangle)
|
||||
# Make grid lines
|
||||
@ -94,6 +95,7 @@ class GriddedRectangle(VGroup):
|
||||
stroke_color=self.grid_stroke_color,
|
||||
stroke_width=self.grid_stroke_width,
|
||||
stroke_opacity=self.grid_stroke_opacity,
|
||||
shade_in_3d=True
|
||||
)
|
||||
for i in range(1, count)
|
||||
)
|
||||
|
@ -10,6 +10,22 @@ from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeD
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
|
||||
class Uncreate(Create):
    """Subclass of ``Create`` with different constructor defaults.

    By default it passes ``reverse_rate_function=True``, ``introducer=True``
    and ``remover=True`` to the parent constructor; any further keyword
    arguments are forwarded unchanged.
    """

    def __init__(
        self,
        mobject,
        reverse_rate_function: bool = True,
        introducer: bool = True,
        remover: bool = True,
        **kwargs,
    ) -> None:
        # Fold the explicit flags into the keyword arguments so the parent
        # constructor receives everything in a single mapping.
        forwarded = dict(
            kwargs,
            reverse_rate_function=reverse_rate_function,
            introducer=introducer,
            remover=remover,
        )
        super().__init__(mobject, **forwarded)
|
||||
|
||||
class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
@ -42,17 +58,10 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
kernel_size = self.output_layer.kernel_size
|
||||
feature_maps = self.input_layer.feature_maps
|
||||
grid_stroke_width = 1.0
|
||||
# Get the normalized shift vectors for the convolutional layer
|
||||
"""
|
||||
right_shift, down_shift = get_rotated_shift_vectors(
|
||||
self.input_layer,
|
||||
normalized=True
|
||||
)
|
||||
"""
|
||||
# Make all of the kernel gridded rectangles
|
||||
create_gridded_rectangle_animations = []
|
||||
create_and_remove_cell_animations = []
|
||||
move_and_resize_gridded_rectangle_animations = []
|
||||
transform_gridded_rectangle_animations = []
|
||||
remove_gridded_rectangle_animations = []
|
||||
|
||||
for feature_map_index, feature_map in enumerate(feature_maps):
|
||||
@ -68,6 +77,7 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
grid_stroke_color=self.active_color,
|
||||
show_grid_lines=True,
|
||||
)
|
||||
gridded_rectangle.set_z_index(10)
|
||||
# 2. Randomly highlight one of the cells in the kernel.
|
||||
highlighted_cells = []
|
||||
num_cells_in_kernel = kernel_size * kernel_size
|
||||
@ -82,8 +92,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
color=self.active_color,
|
||||
height=cell_width,
|
||||
width=cell_width,
|
||||
stroke_width=0.0,
|
||||
fill_opacity=0.7,
|
||||
stroke_width=0.0,
|
||||
z_index=10
|
||||
)
|
||||
# Move to the correct location
|
||||
kernel_shift_vector = [
|
||||
@ -108,102 +119,103 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer):
|
||||
highlighted_cells.append(cell_rectangle)
|
||||
# Rotate the gridded rectangles so they match the angle
|
||||
# of the conv maps
|
||||
gridded_rectangle_group = VGroup(gridded_rectangle, *highlighted_cells)
|
||||
gridded_rectangle_group = VGroup(
|
||||
gridded_rectangle,
|
||||
*highlighted_cells
|
||||
)
|
||||
gridded_rectangle_group.rotate(
|
||||
ThreeDLayer.rotation_angle,
|
||||
about_point=gridded_rectangle.get_center(),
|
||||
axis=ThreeDLayer.rotation_axis,
|
||||
)
|
||||
gridded_rectangle.next_to(
|
||||
gridded_rectangle_group.next_to(
|
||||
feature_map.get_corners_dict()["top_left"],
|
||||
submobject_to_align=gridded_rectangle.get_corners_dict()["top_left"],
|
||||
buff=0.0,
|
||||
)
|
||||
# 3. Make a create gridded rectangle
|
||||
"""
|
||||
create_rectangle = Create(
|
||||
gridded_rectangle
|
||||
gridded_rectangle,
|
||||
)
|
||||
create_gridded_rectangle_animations.append(
|
||||
create_rectangle
|
||||
)
|
||||
def add_grid_lines(rectangle):
|
||||
rectangle.color=self.active_color
|
||||
rectangle.height=cell_width * feature_map_size[1]
|
||||
rectangle.width=cell_width * feature_map_size[0]
|
||||
rectangle.grid_xstep=cell_width * kernel_size
|
||||
rectangle.grid_ystep=cell_width * kernel_size
|
||||
rectangle.grid_stroke_width=grid_stroke_width
|
||||
rectangle.grid_stroke_color=self.active_color
|
||||
rectangle.show_grid_lines=True
|
||||
|
||||
return rectangle
|
||||
|
||||
create_gridded_rectangle_animations.append(
|
||||
ApplyFunction(
|
||||
add_grid_lines,
|
||||
gridded_rectangle
|
||||
)
|
||||
)
|
||||
"""
|
||||
# 4. Create and fade out highlighted cells
|
||||
# highlighted_cells_group = VGroup()
|
||||
# NOTE: Another workaround that is hacky
|
||||
# See convolution_2d_to_convolution_2d Filters Create Override for
|
||||
# more information
|
||||
"""
|
||||
def add_highlighted_cells(object):
|
||||
for cell in highlighted_cells:
|
||||
object.add(
|
||||
cell
|
||||
)
|
||||
|
||||
return object
|
||||
|
||||
create_and_remove_cell_animation = Succession(
|
||||
ApplyFunction(add_highlighted_cells, highlighted_cells_group),
|
||||
Wait(0.5),
|
||||
FadeOut(highlighted_cells_group),
|
||||
create_group = AnimationGroup(
|
||||
*[Create(highlighted_cell) for highlighted_cell in highlighted_cells],
|
||||
lag_ratio=1.0
|
||||
)
|
||||
uncreate_group = AnimationGroup(
|
||||
*[Uncreate(highlighted_cell) for highlighted_cell in highlighted_cells],
|
||||
lag_ratio=0.0
|
||||
)
|
||||
create_and_remove_cell_animation = Succession(
|
||||
create_group,
|
||||
Wait(1.0),
|
||||
uncreate_group
|
||||
)
|
||||
create_and_remove_cell_animations.append(
|
||||
create_and_remove_cell_animation
|
||||
)
|
||||
"""
|
||||
create_and_remove_cell_animations = Succession(
|
||||
Create(VGroup(*highlighted_cells)),
|
||||
Wait(0.5),
|
||||
Uncreate(VGroup(*highlighted_cells)),
|
||||
)
|
||||
return create_and_remove_cell_animations
|
||||
# 5. Move and resize the gridded rectangle to the output
|
||||
# feature maps.
|
||||
resize_rectangle = Transform(
|
||||
gridded_rectangle, self.output_layer.feature_maps[feature_map_index]
|
||||
output_gridded_rectangle = GriddedRectangle(
|
||||
color=self.active_color,
|
||||
height=cell_width * feature_map_size[1] / 2,
|
||||
width=cell_width * feature_map_size[0] / 2,
|
||||
grid_xstep=cell_width,
|
||||
grid_ystep=cell_width,
|
||||
grid_stroke_width=grid_stroke_width,
|
||||
grid_stroke_color=self.active_color,
|
||||
show_grid_lines=True,
|
||||
)
|
||||
move_rectangle = gridded_rectangle.animate.move_to(
|
||||
self.output_layer.feature_maps[feature_map_index]
|
||||
output_gridded_rectangle.rotate(
|
||||
ThreeDLayer.rotation_angle,
|
||||
about_point=output_gridded_rectangle.get_center(),
|
||||
axis=ThreeDLayer.rotation_axis,
|
||||
)
|
||||
move_and_resize = Succession(
|
||||
resize_rectangle, move_rectangle, lag_ratio=0.0
|
||||
output_gridded_rectangle.move_to(
|
||||
self.output_layer.feature_maps[feature_map_index].copy()
|
||||
)
|
||||
move_and_resize_gridded_rectangle_animations.append(move_and_resize)
|
||||
transform_rectangle = ReplacementTransform(
|
||||
gridded_rectangle, output_gridded_rectangle,
|
||||
introducer=True,
|
||||
remover=True
|
||||
)
|
||||
transform_gridded_rectangle_animations.append(
|
||||
transform_rectangle,
|
||||
)
|
||||
"""
|
||||
Succession(
|
||||
Uncreate(gridded_rectangle),
|
||||
transform_rectangle,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
"""
|
||||
# 6. Make the gridded feature map(s) disappear.
|
||||
remove_gridded_rectangle_animations.append(
|
||||
Uncreate(gridded_rectangle_group)
|
||||
)
|
||||
|
||||
"""
|
||||
AnimationGroup(
|
||||
*move_and_resize_gridded_rectangle_animations
|
||||
),
|
||||
"""
|
||||
create_gridded_rectangle_animation = AnimationGroup(
|
||||
*create_gridded_rectangle_animations
|
||||
)
|
||||
create_and_remove_cell_animation = AnimationGroup(
|
||||
*create_and_remove_cell_animations
|
||||
)
|
||||
transform_gridded_rectangle_animation = AnimationGroup(
|
||||
*transform_gridded_rectangle_animations
|
||||
)
|
||||
remove_gridded_rectangle_animation = AnimationGroup(
|
||||
*remove_gridded_rectangle_animations
|
||||
)
|
||||
|
||||
return Succession(
|
||||
# *create_gridded_rectangle_animations,
|
||||
create_and_remove_cell_animations,
|
||||
# AnimationGroup(
|
||||
# *remove_gridded_rectangle_animations
|
||||
# ),
|
||||
# lag_ratio=1.0
|
||||
create_gridded_rectangle_animation,
|
||||
Wait(1),
|
||||
create_and_remove_cell_animation,
|
||||
transform_gridded_rectangle_animation,
|
||||
Wait(1),
|
||||
remove_gridded_rectangle_animation,
|
||||
lag_ratio=1.0,
|
||||
)
|
||||
|
@ -22,8 +22,9 @@ class CombinedScene(ThreeDScene):
|
||||
nn = NeuralNetwork([
|
||||
ImageLayer(numpy_image, height=1.5),
|
||||
Convolutional2DLayer(1, 8, filter_spacing=0.32),
|
||||
Convolutional2DLayer(3, 6, 3, filter_spacing=0.32),
|
||||
MaxPooling2DLayer(kernel_size=2),
|
||||
Convolutional2DLayer(3, 3, 2, filter_spacing=0.32),
|
||||
Convolutional2DLayer(5, 2, 2, filter_spacing=0.32),
|
||||
],
|
||||
layer_spacing=0.25,
|
||||
)
|
||||
|
Reference in New Issue
Block a user