Files
ManimML/manim_ml/neural_network/layers/convolutional3d_to_convolutional3d.py
2022-05-15 13:42:21 -04:00

67 lines
2.3 KiB
Python

from cProfile import run
from manim import *
from manim_ml.neural_network.layers.convolutional_3d import Convolutional3DLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
class Convolutional3DToConvolutional3D(ConnectiveLayer):
"""Feed Forward to Embedding Layer"""
input_class = Convolutional3DLayer
output_class = Convolutional3DLayer
def __init__(self, input_layer, output_layer, color=WHITE, pulse_color=RED,
**kwargs):
super().__init__(input_layer, output_layer, input_class=Convolutional3DLayer, output_class=Convolutional3DLayer,
**kwargs)
self.color = color
self.pulse_color = pulse_color
self.lines = self.make_lines()
self.add(self.lines)
def make_lines(self):
"""Make lines connecting the input and output layers"""
lines = VGroup()
# Get the first and last rectangle
input_rectangle = self.input_layer.rectangles[-1]
output_rectangle = self.output_layer.rectangles[0]
input_vertices = input_rectangle.get_vertices()
output_vertices = output_rectangle.get_vertices()
# Go through each vertex
for vertex_index in range(len(input_vertices)):
# Make a line
line = Line(
start=input_vertices[vertex_index],
end=output_vertices[vertex_index],
color=self.color,
stroke_opacity=0.0
)
lines.add(line)
return lines
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
"""Forward pass animation from conv to conv"""
animations = []
# Go thorugh the lines
for line in self.lines:
pulse = ShowPassingFlash(
line.copy()
.set_color(self.pulse_color)
.set_stroke(opacity=1.0),
time_width=0.5,
run_time=run_time
)
animations.append(pulse)
# Make animation group
animation_group = AnimationGroup(
*animations,
run_time=run_time
)
return animation_group
@override_animation(Create)
def _create_override(self, **kwargs):
return AnimationGroup()