from manim import *
from manim_ml.neural_network.layers.parent_layers import (
    ThreeDLayer,
    VGroupNeuralNetworkLayer,
)
from manim_ml.gridded_rectangle import GriddedRectangle
import numpy as np


class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
    """Handles rendering a convolutional layer for a nn.

    The layer is drawn as a stack of translucent rectangles (one per
    feature map) spaced along the z axis, then rotated about the x and y
    axes by the angles stored on ``ThreeDLayer``.
    """

    def __init__(
        self,
        num_feature_maps,
        feature_map_width,
        feature_map_height,
        filter_width,
        filter_height,
        cell_width=0.2,
        filter_spacing=0.1,
        color=BLUE,
        pulse_color=ORANGE,
        filter_color=ORANGE,
        stride=1,
        stroke_width=2.0,
        **kwargs,
    ):
        """Store the layer configuration and build the feature map stack.

        Args:
            num_feature_maps: Number of rectangles to draw.
            feature_map_width: Feature map width, in cells.
            feature_map_height: Feature map height, in cells.
            filter_width: Stored on the instance; not used by this class.
            filter_height: Stored on the instance; not used by this class.
            cell_width: Side length of one cell; multiplied by the feature
                map width/height to size each rectangle.
            filter_spacing: z offset between consecutive feature maps.
            color: Stroke and fill color of the feature maps.
            pulse_color: Stored on the instance; not used by this class.
            filter_color: Stored on the instance; not used by this class.
            stride: Stored on the instance; not used by this class.
            stroke_width: Stroke width of each feature map rectangle.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self.num_feature_maps = num_feature_maps
        self.feature_map_height = feature_map_height
        self.filter_color = filter_color
        self.feature_map_width = feature_map_width
        self.filter_width = filter_width
        self.filter_height = filter_height
        self.cell_width = cell_width
        self.filter_spacing = filter_spacing
        self.color = color
        self.pulse_color = pulse_color
        self.stride = stride
        self.stroke_width = stroke_width
        # Make the feature maps
        self.feature_maps = self.construct_feature_maps()
        self.add(self.feature_maps)
        # Rotate the whole layer: first about the x axis, then about the
        # y axis, each time around the layer's current center.
        self.rotate(
            ThreeDLayer.three_d_x_rotation,
            about_point=self.get_center(),
            axis=[1, 0, 0],
        )
        self.rotate(
            ThreeDLayer.three_d_y_rotation,
            about_point=self.get_center(),
            axis=[0, 1, 0],
        )

    def construct_feature_maps(self):
        """Build one filled rectangle per feature map.

        Returns:
            A ``VGroup`` holding ``num_feature_maps`` rectangles, the i-th
            one centered at ``[0, 0, i * filter_spacing]``.
        """
        feature_maps = []
        for filter_index in range(self.num_feature_maps):
            # Rectangle size is expressed in cells, so convert to scene
            # units with cell_width.
            rectangle = GriddedRectangle(
                color=self.color,
                height=self.feature_map_height * self.cell_width,
                width=self.feature_map_width * self.cell_width,
                fill_color=self.color,
                fill_opacity=0.2,
                stroke_color=self.color,
                stroke_width=self.stroke_width,
            )
            # Stack the feature maps along the z axis.
            rectangle.move_to([0, 0, filter_index * self.filter_spacing])
            feature_maps.append(rectangle)

        return VGroup(*feature_maps)

    def make_forward_pass_animation(
        self, run_time=5, corner_pulses=False, layer_args=None, **kwargs
    ):
        """Convolution forward pass animation.

        Note: most of this animation is done in the
        Convolution3DToConvolution3D layer, so this layer itself only
        contributes an empty animation group.

        Args:
            run_time: Accepted for interface compatibility; unused here.
            corner_pulses: Must be falsy; corner pulses are not implemented.
            layer_args: Accepted for interface compatibility; unused here.
                Defaults to ``None`` rather than a shared mutable dict.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An empty ``AnimationGroup``.

        Raises:
            NotImplementedError: If ``corner_pulses`` is truthy.
        """
        print(f"Corner pulses: {corner_pulses}")
        if corner_pulses:
            # The layer never builds the corner lines such pulses would
            # travel along, so this mode cannot be rendered.
            raise NotImplementedError(
                "corner_pulses is not implemented for Convolutional3DLayer"
            )
        return AnimationGroup()

    def scale(self, scale_factor, **kwargs):
        """Scale the layer and keep ``cell_width`` in sync.

        Args:
            scale_factor: Factor applied to both ``cell_width`` and the
                underlying mobject.
            **kwargs: Forwarded to the parent ``scale``.

        Returns:
            Whatever the parent ``scale`` returns, so that calls can be
            chained as with other Manim mobjects.
        """
        self.cell_width *= scale_factor
        return super().scale(scale_factor, **kwargs)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Use a fade-in of the feature maps when ``Create`` is requested."""
        return FadeIn(self.feature_maps)