from manim import * from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer from manim_ml.gridded_rectangle import GriddedRectangle from manim.utils.space_ops import rotation_matrix class Filters(VGroup): """Group for showing a collection of filters connecting two layers""" def __init__( self, input_layer, output_layer, line_color=ORANGE, cell_width=1.0, stroke_width=2.0, show_grid_lines=False, output_feature_map_to_connect=None, # None means all at once ): super().__init__() self.input_layer = input_layer self.output_layer = output_layer self.line_color = line_color self.cell_width = cell_width self.stroke_width = stroke_width self.show_grid_lines = show_grid_lines self.output_feature_map_to_connect = output_feature_map_to_connect # Make the filter self.input_rectangles = self.make_input_feature_map_rectangles() # self.input_rectangles.set_z_index(5) # self.add(self.input_rectangles) self.output_rectangles = self.make_output_feature_map_rectangles() # self.output_rectangles.set_z_index(5) # self.add(self.output_rectangles) self.connective_lines = self.make_connective_lines() # self.connective_lines.set_z_index(5) # self.add(self.connective_lines) def make_input_feature_map_rectangles(self): rectangles = [] rectangle_width = self.input_layer.filter_width * self.input_layer.cell_width rectangle_height = self.input_layer.filter_height * self.input_layer.cell_width filter_color = self.input_layer.filter_color for index, feature_map in enumerate(self.input_layer.feature_maps): rectangle = GriddedRectangle( width=rectangle_width, height=rectangle_height, fill_color=filter_color, stroke_color=filter_color, fill_opacity=0.2, stroke_width=self.stroke_width, grid_xstep=self.cell_width, grid_ystep=self.cell_width, grid_stroke_width=self.stroke_width / 2, grid_stroke_color=filter_color, show_grid_lines=self.show_grid_lines, ) # normal_vector = rectangle.get_normal_vector() rectangle.rotate( ThreeDLayer.rotation_angle, about_point=rectangle.get_center(), axis=ThreeDLayer.rotation_axis, ) # Move the rectangle to the corner of the feature map rectangle.next_to( feature_map.get_corners_dict()["top_left"], submobject_to_align=rectangle.get_corners_dict()["top_left"], buff=0.0 # aligned_edge=feature_map.get_corners_dict()["top_left"].get_center() ) rectangle.set_z_index(5) rectangles.append(rectangle) feature_map_rectangles = VGroup(*rectangles) return feature_map_rectangles def make_output_feature_map_rectangles(self): rectangles = [] rectangle_width = self.output_layer.cell_width rectangle_height = self.output_layer.cell_width filter_color = self.output_layer.filter_color for index, feature_map in enumerate(self.output_layer.feature_maps): # Make sure current feature map is the right filte if not self.output_feature_map_to_connect is None: if index != self.output_feature_map_to_connect: continue # Make the rectangle rectangle = GriddedRectangle( width=rectangle_width, height=rectangle_height, fill_color=filter_color, fill_opacity=0.2, stroke_color=filter_color, stroke_width=self.stroke_width, grid_xstep=self.cell_width, grid_ystep=self.cell_width, grid_stroke_width=self.stroke_width / 2, grid_stroke_color=filter_color, show_grid_lines=self.show_grid_lines, ) # Rotate the rectangle rectangle.rotate( ThreeDLayer.rotation_angle, about_point=rectangle.get_center(), axis=ThreeDLayer.rotation_axis, ) # Move the rectangle to the corner location rectangle.next_to( feature_map.get_corners_dict()["top_left"], submobject_to_align=rectangle.get_corners_dict()["top_left"], buff=0.0 # aligned_edge=feature_map.get_corners_dict()["top_left"].get_center() ) rectangles.append(rectangle) feature_map_rectangles = VGroup(*rectangles) return feature_map_rectangles def make_connective_lines(self): """Lines connecting input filter with output node""" corner_names = ["top_left", "bottom_left", "top_right", "bottom_right"] def make_input_connective_lines(): """Makes connective lines between the corners of the input filters""" first_input_rectangle = self.input_rectangles[0] last_input_rectangle = self.input_rectangles[-1] # Get the corner dots for each rectangle first_input_corners = first_input_rectangle.get_corners_dict() last_input_corners = last_input_rectangle.get_corners_dict() # Iterate through each corner and make the lines lines = [] for corner_name in corner_names: line = Line( first_input_corners[corner_name].get_center(), last_input_corners[corner_name].get_center(), color=self.line_color, stroke_width=self.stroke_width, ) lines.append(line) return VGroup(*lines) def make_output_connective_lines(): """Makes connective lines between the corners of the output filters""" first_output_rectangle = self.output_rectangles[0] last_output_rectangle = self.output_rectangles[-1] # Get the corner dots for each rectangle first_output_corners = first_output_rectangle.get_corners_dict() last_output_corners = last_output_rectangle.get_corners_dict() # Iterate through each corner and make the lines lines = [] for corner_name in corner_names: line = Line( first_output_corners[corner_name].get_center(), last_output_corners[corner_name].get_center(), color=self.line_color, stroke_width=self.stroke_width, ) lines.append(line) return VGroup(*lines) def make_input_to_output_connective_lines(): """Make connective lines between last input filter and first output filter""" # Choose the correct feature map to link to input_rectangle = self.input_rectangles[-1] output_rectangle = self.output_rectangles[0] # Get the corner dots for each rectangle input_corners = input_rectangle.get_corners_dict() output_corners = output_rectangle.get_corners_dict() # Iterate through each corner and make the lines lines = [] for corner_name in corner_names: line = Line( input_corners[corner_name].get_center(), output_corners[corner_name].get_center(), color=self.line_color, stroke_width=self.stroke_width, ) lines.append(line) return VGroup(*lines) input_lines = make_input_connective_lines() output_lines = make_output_connective_lines() input_output_lines = make_input_to_output_connective_lines() connective_lines = VGroup(*input_lines, *output_lines, *input_output_lines) return connective_lines @override_animation(Create) def _create_override(self, **kwargs): """ NOTE This create override animation is a workaround to make sure that the filter does not show up in the scene before the create animation. Without this override the filters were shown at the beginning of the neural network forward pass animimation instead of just when the filters were supposed to appear. I think this is a bug with Succession in the core Manim Community Library. TODO Fix this """ def add_content(object): object.add(self.input_rectangles) object.add(self.connective_lines) object.add(self.output_rectangles) return object return ApplyFunction(add_content, self) return AnimationGroup( Create(self.input_rectangles), Create(self.connective_lines), Create(self.output_rectangles), lag_ratio=0.0, ) def make_pulse_animation(self, shift_amount): """Make animation of the filter pulsing""" passing_flash = ShowPassingFlash( self.connective_lines.shift(shift_amount).set_stroke_width( self.stroke_width * 1.5 ), time_width=0.2, color=RED, z_index=10, ) return passing_flash class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer): """Feed Forward to Embedding Layer""" input_class = Convolutional3DLayer output_class = Convolutional3DLayer def __init__( self, input_layer: Convolutional3DLayer, output_layer: Convolutional3DLayer, color=ORANGE, filter_opacity=0.3, line_color=ORANGE, pulse_color=ORANGE, cell_width=0.2, show_grid_lines=True, highlight_color=ORANGE, **kwargs, ): super().__init__( input_layer, output_layer, input_class=Convolutional3DLayer, output_class=Convolutional3DLayer, **kwargs, ) self.color = color self.filter_color = self.input_layer.filter_color self.filter_width = self.input_layer.filter_width self.filter_height = self.input_layer.filter_height self.feature_map_width = self.input_layer.feature_map_width self.feature_map_height = self.input_layer.feature_map_height self.num_input_feature_maps = self.input_layer.num_feature_maps self.num_output_feature_maps = self.output_layer.num_feature_maps self.cell_width = self.input_layer.cell_width self.stride = self.input_layer.stride self.filter_opacity = filter_opacity self.cell_width = cell_width self.line_color = line_color self.pulse_color = pulse_color self.show_grid_lines = show_grid_lines self.highlight_color = highlight_color def get_rotated_shift_vectors(self): """ Rotates the shift vectors """ # Make base shift vectors right_shift = np.array([self.input_layer.cell_width, 0, 0]) down_shift = np.array([0, -self.input_layer.cell_width, 0]) # Make rotation matrix rot_mat = rotation_matrix(ThreeDLayer.rotation_angle, ThreeDLayer.rotation_axis) # Rotate the vectors right_shift = np.dot(right_shift, rot_mat.T) down_shift = np.dot(down_shift, rot_mat.T) return right_shift, down_shift def animate_filters_all_at_once(self, filters): """Animates each of the filters all at once""" animations = [] # Make filters filters = Filters( self.input_layer, self.output_layer, line_color=self.color, cell_width=self.cell_width, show_grid_lines=self.show_grid_lines, output_feature_map_to_connect=None, # None means all at once ) animations.append(Create(filters)) # Get the rotated shift vectors right_shift, down_shift = self.get_rotated_shift_vectors() left_shift = -1 * right_shift # Make the animation num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride) num_x_moves = int((self.feature_map_width - self.filter_width) / self.stride) for y_move in range(num_y_moves): # Go right num_x_moves for x_move in range(num_x_moves): # Shift right shift_animation = ApplyMethod(filters.shift, self.stride * right_shift) # shift_animation = self.animate.shift(right_shift) animations.append(shift_animation) # Go back left num_x_moves and down one shift_amount = ( self.stride * num_x_moves * left_shift + self.stride * down_shift ) # Make the animation shift_animation = ApplyMethod(filters.shift, shift_amount) animations.append(shift_animation) # Do last row move right for x_move in range(num_x_moves): # Shift right shift_animation = ApplyMethod(filters.shift, self.stride * right_shift) # shift_animation = self.animate.shift(right_shift) animations.append(shift_animation) # Remove the filters animations.append(FadeOut(filters)) return Succession(*animations, lag_ratio=1.0) def animate_filters_one_at_a_time(self, highlight_active_feature_map=False): """Animates each of the filters one at a time""" animations = [] output_feature_maps = self.output_layer.feature_maps for feature_map_index in range(len(output_feature_maps)): # Make filters filters = Filters( self.input_layer, self.output_layer, line_color=self.color, cell_width=self.cell_width, show_grid_lines=self.show_grid_lines, output_feature_map_to_connect=feature_map_index, # None means all at once ) animations.append(Create(filters)) # Highlight the feature map if highlight_active_feature_map: feature_map = output_feature_maps[feature_map_index] original_feature_map_color = feature_map.color # Change the output feature map colors change_color_animations = [] change_color_animations.append( ApplyMethod(feature_map.set_color, self.highlight_color) ) # Change the input feature map colors input_feature_maps = self.input_layer.feature_maps for input_feature_map in input_feature_maps: change_color_animations.append( ApplyMethod(input_feature_map.set_color, self.highlight_color) ) # Combine the animations animations.append( AnimationGroup(*change_color_animations, lag_ratio=0.0) ) # Get the rotated shift vectors right_shift, down_shift = self.get_rotated_shift_vectors() left_shift = -1 * right_shift # Make the animation num_y_moves = int( (self.feature_map_height - self.filter_height) / self.stride ) num_x_moves = int( (self.feature_map_width - self.filter_width) / self.stride ) for y_move in range(num_y_moves): # Go right num_x_moves for x_move in range(num_x_moves): # Make a pulse animation for the corners """ pulse_animation = filters.make_pulse_animation( shift_amount=shift_amount ) animations.append(pulse_animation) """ z_index_animation = ApplyMethod(filters.set_z_index, 5) animations.append(z_index_animation) # Shift right shift_animation = ApplyMethod( filters.shift, self.stride * right_shift ) # shift_animation = self.animate.shift(right_shift) animations.append(shift_animation) # Go back left num_x_moves and down one shift_amount = ( self.stride * num_x_moves * left_shift + self.stride * down_shift ) # Make the animation shift_animation = ApplyMethod(filters.shift, shift_amount) animations.append(shift_animation) # Do last row move right for x_move in range(num_x_moves): # Shift right shift_animation = ApplyMethod(filters.shift, self.stride * right_shift) # shift_animation = self.animate.shift(right_shift) animations.append(shift_animation) # Remove the filters animations.append(FadeOut(filters)) # Un-highlight the feature map if highlight_active_feature_map: feature_map = output_feature_maps[feature_map_index] # Change the output feature map colors change_color_animations = [] change_color_animations.append( ApplyMethod(feature_map.set_color, original_feature_map_color) ) # Change the input feature map colors input_feature_maps = self.input_layer.feature_maps for input_feature_map in input_feature_maps: change_color_animations.append( ApplyMethod( input_feature_map.set_color, original_feature_map_color ) ) # Combine the animations animations.append( AnimationGroup(*change_color_animations, lag_ratio=0.0) ) return Succession(*animations, lag_ratio=1.0) def make_forward_pass_animation( self, layer_args={}, all_filters_at_once=False, highlight_active_feature_map=False, run_time=10.5, **kwargs, ): """Forward pass animation from conv2d to conv2d""" print(f"All filters at once: {all_filters_at_once}") # Make filter shifting animations if all_filters_at_once: return self.animate_filters_all_at_once() else: return self.animate_filters_one_at_a_time( highlight_active_feature_map=highlight_active_feature_map ) def scale(self, scale_factor, **kwargs): self.cell_width *= scale_factor super().scale(scale_factor, **kwargs) @override_animation(Create) def _create_override(self, **kwargs): return Succession()