from manim import * from manim_ml.neural_network.layers.convolutional import ConvolutionalLayer from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer class ConvolutionalToConvolutional(ConnectiveLayer): """Feed Forward to Embedding Layer""" input_class = ConvolutionalLayer output_class = ConvolutionalLayer def __init__(self, input_layer, output_layer, color=WHITE, pulse_color=RED, **kwargs): super().__init__(input_layer, output_layer, input_class=ConvolutionalLayer, output_class=ConvolutionalLayer, **kwargs) self.color = color self.pulse_color = pulse_color self.lines = self.make_lines() self.add(self.lines) def make_lines(self): """Make lines connecting the input and output layers""" lines = VGroup() # Get the first and last rectangle input_rectangle = self.input_layer.rectangles[-1] output_rectangle = self.output_layer.rectangles[0] input_vertices = input_rectangle.get_vertices() output_vertices = output_rectangle.get_vertices() # Go through each vertex for vertex_index in range(len(input_vertices)): # Make a line line = Line( start=input_vertices[vertex_index], end=output_vertices[vertex_index], color=self.color, stroke_opacity=0.0 ) lines.add(line) return lines def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs): """Forward pass animation from conv to conv""" animations = [] # Go thorugh the lines for line in self.lines: pulse = ShowPassingFlash( line.copy() .set_color(self.pulse_color) .set_stroke(opacity=1.0), time_width=0.5 ) animations.append(pulse) # Make animation group animation_group = AnimationGroup( *animations ) return animation_group @override_animation(Create) def _create_override(self, **kwargs): return AnimationGroup()