diff --git a/examples/cnn/cnn_max_pool.py b/examples/cnn/cnn_max_pool.py new file mode 100644 index 0000000..3faf358 --- /dev/null +++ b/examples/cnn/cnn_max_pool.py @@ -0,0 +1,72 @@ +from manim import * +from PIL import Image +import numpy as np + +from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer +from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer +from manim_ml.neural_network.layers.image import ImageLayer +from manim_ml.neural_network.layers.max_pooling_2d import MaxPooling2DLayer +from manim_ml.neural_network.neural_network import NeuralNetwork + +# Make the specific scene +config.pixel_height = 1200 +config.pixel_width = 1900 +config.frame_height = 6.0 +config.frame_width = 6.0 + +def make_code_snippet(): + code_str = """ + # Make the neural network + nn = NeuralNetwork([ + ImageLayer(image), + Convolutional2DLayer(1, 8), + MaxPooling2DLayer(kernel_size=2), + Convolutional2DLayer(3, 2, 3), + ]) + # Play the animation + self.play(nn.make_forward_pass_animation()) + """ + + code = Code( + code=code_str, + tab_width=4, + background_stroke_width=1, + background_stroke_color=WHITE, + insert_line_no=False, + style="monokai", + font="Monospace", + background="window", + language="py", + ) + code.scale(0.4) + + return code + +class CombinedScene(ThreeDScene): + def construct(self): + image = Image.open("../../assets/mnist/digit.jpeg") + numpy_image = np.asarray(image) + # Make nn + nn = NeuralNetwork([ + ImageLayer(numpy_image, height=1.5), + Convolutional2DLayer(1, 8, filter_spacing=0.32), + MaxPooling2DLayer(kernel_size=2), + Convolutional2DLayer(3, 2, 3, filter_spacing=0.32), + ], + layer_spacing=0.25, + ) + # Center the nn + nn.move_to(ORIGIN) + self.add(nn) + # Make code snippet + code = make_code_snippet() + code.next_to(nn, DOWN) + Group(code, nn).move_to(ORIGIN) + self.add(code) + self.wait(5) + # Play animation + forward_pass = nn.make_forward_pass_animation( + corner_pulses=False, all_filters_at_once=False + ) + self.wait(1) + self.play(forward_pass) \ No newline at end of file diff --git a/manim_ml/gridded_rectangle.py b/manim_ml/gridded_rectangle.py index 35228da..29123bf 100644 --- a/manim_ml/gridded_rectangle.py +++ b/manim_ml/gridded_rectangle.py @@ -45,6 +45,7 @@ class GriddedRectangle(VGroup): stroke_width=stroke_width, fill_color=color, fill_opacity=fill_opacity, + shade_in_3d=True ) self.add(self.rectangle) # Make grid lines @@ -94,6 +95,7 @@ class GriddedRectangle(VGroup): stroke_color=self.grid_stroke_color, stroke_width=self.grid_stroke_width, stroke_opacity=self.grid_stroke_opacity, + shade_in_3d=True ) for i in range(1, count) ) diff --git a/manim_ml/neural_network/layers/convolutional_2d_to_max_pooling_2d.py b/manim_ml/neural_network/layers/convolutional_2d_to_max_pooling_2d.py index 3e4e4ec..adc40d9 100644 --- a/manim_ml/neural_network/layers/convolutional_2d_to_max_pooling_2d.py +++ b/manim_ml/neural_network/layers/convolutional_2d_to_max_pooling_2d.py @@ -10,6 +10,22 @@ from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeD from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer +class Uncreate(Create): + def __init__( + self, + mobject, + reverse_rate_function: bool = True, + introducer: bool = True, + remover: bool = True, + **kwargs, + ) -> None: + super().__init__( + mobject, + reverse_rate_function=reverse_rate_function, + introducer=introducer, + remover=remover, + **kwargs, + ) class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer): """Feed Forward to Embedding Layer""" @@ -42,17 +58,10 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer): kernel_size = self.output_layer.kernel_size feature_maps = self.input_layer.feature_maps grid_stroke_width = 1.0 - # Get the normalized shift vectors for the convolutional layer - """ - right_shift, down_shift = get_rotated_shift_vectors( - self.input_layer, - normalized=True - ) - """ # Make all of the kernel gridded rectangles create_gridded_rectangle_animations = [] create_and_remove_cell_animations = [] - move_and_resize_gridded_rectangle_animations = [] + transform_gridded_rectangle_animations = [] remove_gridded_rectangle_animations = [] for feature_map_index, feature_map in enumerate(feature_maps): @@ -68,6 +77,7 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer): grid_stroke_color=self.active_color, show_grid_lines=True, ) + gridded_rectangle.set_z_index(10) # 2. Randomly highlight one of the cells in the kernel. highlighted_cells = [] num_cells_in_kernel = kernel_size * kernel_size @@ -82,8 +92,9 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer): color=self.active_color, height=cell_width, width=cell_width, - stroke_width=0.0, fill_opacity=0.7, + stroke_width=0.0, + z_index=10 ) # Move to the correct location kernel_shift_vector = [ @@ -108,102 +119,103 @@ class Convolutional2DToMaxPooling2D(ConnectiveLayer, ThreeDLayer): highlighted_cells.append(cell_rectangle) # Rotate the gridded rectangles so they match the angle # of the conv maps - gridded_rectangle_group = VGroup(gridded_rectangle, *highlighted_cells) + gridded_rectangle_group = VGroup( + gridded_rectangle, + *highlighted_cells + ) gridded_rectangle_group.rotate( ThreeDLayer.rotation_angle, about_point=gridded_rectangle.get_center(), axis=ThreeDLayer.rotation_axis, ) - gridded_rectangle.next_to( + gridded_rectangle_group.next_to( feature_map.get_corners_dict()["top_left"], submobject_to_align=gridded_rectangle.get_corners_dict()["top_left"], buff=0.0, ) # 3. Make a create gridded rectangle - """ create_rectangle = Create( - gridded_rectangle + gridded_rectangle, ) create_gridded_rectangle_animations.append( create_rectangle ) - def add_grid_lines(rectangle): - rectangle.color=self.active_color - rectangle.height=cell_width * feature_map_size[1] - rectangle.width=cell_width * feature_map_size[0] - rectangle.grid_xstep=cell_width * kernel_size - rectangle.grid_ystep=cell_width * kernel_size - rectangle.grid_stroke_width=grid_stroke_width - rectangle.grid_stroke_color=self.active_color - rectangle.show_grid_lines=True - - return rectangle - - create_gridded_rectangle_animations.append( - ApplyFunction( - add_grid_lines, - gridded_rectangle - ) - ) - """ # 4. Create and fade out highlighted cells - # highlighted_cells_group = VGroup() - # NOTE: Another workaround that is hacky - # See convolution_2d_to_convolution_2d Filters Create Override for - # more information - """ - def add_highlighted_cells(object): - for cell in highlighted_cells: - object.add( - cell - ) - - return object - - create_and_remove_cell_animation = Succession( - ApplyFunction(add_highlighted_cells, highlighted_cells_group), - Wait(0.5), - FadeOut(highlighted_cells_group), + create_group = AnimationGroup( + *[Create(highlighted_cell) for highlighted_cell in highlighted_cells], lag_ratio=1.0 ) + uncreate_group = AnimationGroup( + *[Uncreate(highlighted_cell) for highlighted_cell in highlighted_cells], + lag_ratio=0.0 + ) + create_and_remove_cell_animation = Succession( + create_group, + Wait(1.0), + uncreate_group + ) create_and_remove_cell_animations.append( create_and_remove_cell_animation ) - """ - create_and_remove_cell_animations = Succession( - Create(VGroup(*highlighted_cells)), - Wait(0.5), - Uncreate(VGroup(*highlighted_cells)), - ) - return create_and_remove_cell_animations # 5. Move and resize the gridded rectangle to the output # feature maps. - resize_rectangle = Transform( - gridded_rectangle, self.output_layer.feature_maps[feature_map_index] + output_gridded_rectangle = GriddedRectangle( + color=self.active_color, + height=cell_width * feature_map_size[1] / 2, + width=cell_width * feature_map_size[0] / 2, + grid_xstep=cell_width, + grid_ystep=cell_width, + grid_stroke_width=grid_stroke_width, + grid_stroke_color=self.active_color, + show_grid_lines=True, ) - move_rectangle = gridded_rectangle.animate.move_to( - self.output_layer.feature_maps[feature_map_index] + output_gridded_rectangle.rotate( + ThreeDLayer.rotation_angle, + about_point=output_gridded_rectangle.get_center(), + axis=ThreeDLayer.rotation_axis, ) - move_and_resize = Succession( - resize_rectangle, move_rectangle, lag_ratio=0.0 + output_gridded_rectangle.move_to( + self.output_layer.feature_maps[feature_map_index].copy() ) - move_and_resize_gridded_rectangle_animations.append(move_and_resize) + transform_rectangle = ReplacementTransform( + gridded_rectangle, output_gridded_rectangle, + introducer=True, + remover=True + ) + transform_gridded_rectangle_animations.append( + transform_rectangle, + ) + """ + Succession( + Uncreate(gridded_rectangle), + transform_rectangle, + lag_ratio=1.0 + ) + """ # 6. Make the gridded feature map(s) disappear. remove_gridded_rectangle_animations.append( Uncreate(gridded_rectangle_group) ) - """ - AnimationGroup( - *move_and_resize_gridded_rectangle_animations - ), - """ + create_gridded_rectangle_animation = AnimationGroup( + *create_gridded_rectangle_animations + ) + create_and_remove_cell_animation = AnimationGroup( + *create_and_remove_cell_animations + ) + transform_gridded_rectangle_animation = AnimationGroup( + *transform_gridded_rectangle_animations + ) + remove_gridded_rectangle_animation = AnimationGroup( + *remove_gridded_rectangle_animations + ) + return Succession( - # *create_gridded_rectangle_animations, - create_and_remove_cell_animations, - # AnimationGroup( - # *remove_gridded_rectangle_animations - # ), - # lag_ratio=1.0 + create_gridded_rectangle_animation, + Wait(1), + create_and_remove_cell_animation, + transform_gridded_rectangle_animation, + Wait(1), + remove_gridded_rectangle_animation, lag_ratio=1.0, ) diff --git a/tests/test_max_pool.py b/tests/test_max_pool.py index 7c51b34..db97eea 100644 --- a/tests/test_max_pool.py +++ b/tests/test_max_pool.py @@ -22,8 +22,9 @@ class CombinedScene(ThreeDScene): nn = NeuralNetwork([ ImageLayer(numpy_image, height=1.5), Convolutional2DLayer(1, 8, filter_spacing=0.32), + Convolutional2DLayer(3, 6, 3, filter_spacing=0.32), MaxPooling2DLayer(kernel_size=2), - Convolutional2DLayer(3, 3, 2, filter_spacing=0.32), + Convolutional2DLayer(5, 2, 2, filter_spacing=0.32), ], layer_spacing=0.25, )