Files
ManimML/manim_ml/neural_network/layers/convolutional_to_convolutional.py
2022-04-29 14:36:14 -04:00

64 lines
2.2 KiB
Python

from manim import *
from manim_ml.neural_network.layers.convolutional import ConvolutionalLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
class ConvolutionalToConvolutional(ConnectiveLayer):
    """Connective layer drawn between two ConvolutionalLayer objects.

    One line is built per vertex of the input layer's last rectangle, ending
    at the same-index vertex of the output layer's first rectangle. The lines
    are created with zero stroke opacity; the forward pass animation flashes
    recolored, fully opaque copies of them.
    """

    input_class = ConvolutionalLayer
    output_class = ConvolutionalLayer

    def __init__(self, input_layer, output_layer, color=WHITE, pulse_color=RED,
                 **kwargs):
        """Store the colors, build the connecting lines and add them to self.

        Args:
            input_layer: layer whose last rectangle the lines start from.
            output_layer: layer whose first rectangle the lines end at.
            color: color given to the connecting lines.
            pulse_color: color applied to the line copies that are flashed
                during the forward pass animation.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(
            input_layer,
            output_layer,
            input_class=ConvolutionalLayer,
            output_class=ConvolutionalLayer,
            **kwargs
        )
        self.color = color
        self.pulse_color = pulse_color
        self.lines = self.make_lines()
        self.add(self.lines)

    def make_lines(self):
        """Make lines connecting the input and output layers.

        Returns:
            A VGroup holding one Line per paired vertex.
        """
        lines = VGroup()
        # Connect the last rectangle of the input layer to the first
        # rectangle of the output layer.
        input_rectangle = self.input_layer.rectangles[-1]
        output_rectangle = self.output_layer.rectangles[0]
        # Pair vertices by position. Both rectangles are expected to expose
        # the same number of vertices; zip stops at the shorter sequence.
        vertex_pairs = zip(
            input_rectangle.get_vertices(),
            output_rectangle.get_vertices(),
        )
        for start_vertex, end_vertex in vertex_pairs:
            lines.add(
                Line(
                    start=start_vertex,
                    end=end_vertex,
                    color=self.color,
                    # Created fully transparent; only the flashed copies
                    # made in make_forward_pass_animation are opaque.
                    stroke_opacity=0.0,
                )
            )
        return lines

    def make_forward_pass_animation(self, layer_args=None, run_time=1.5, **kwargs):
        """Forward pass animation from conv to conv.

        Args:
            layer_args: accepted for interface compatibility; not used by
                this layer. Defaults to None rather than a shared mutable
                dict.
            run_time: duration given to each line's passing flash.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            An AnimationGroup containing one ShowPassingFlash per line.
        """
        # Flash a recolored, opaque copy of every line so the stored lines
        # themselves are left unmodified.
        animations = [
            ShowPassingFlash(
                line.copy()
                .set_color(self.pulse_color)
                .set_stroke(opacity=1.0),
                time_width=0.5,
                run_time=run_time,
            )
            for line in self.lines
        ]
        return AnimationGroup(*animations)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Return an empty AnimationGroup as this layer's Create override."""
        return AnimationGroup()