mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-20 03:57:40 +08:00
67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
from cProfile import run
|
|
from manim import *
|
|
from manim_ml.neural_network.layers.convolutional_3d import Convolutional3DLayer
|
|
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
|
|
|
class Convolutional3DToConvolutional3D(ConnectiveLayer):
|
|
"""Feed Forward to Embedding Layer"""
|
|
input_class = Convolutional3DLayer
|
|
output_class = Convolutional3DLayer
|
|
|
|
def __init__(self, input_layer, output_layer, color=WHITE, pulse_color=RED,
|
|
**kwargs):
|
|
super().__init__(input_layer, output_layer, input_class=Convolutional3DLayer, output_class=Convolutional3DLayer,
|
|
**kwargs)
|
|
self.color = color
|
|
self.pulse_color = pulse_color
|
|
|
|
self.lines = self.make_lines()
|
|
self.add(self.lines)
|
|
|
|
def make_lines(self):
|
|
"""Make lines connecting the input and output layers"""
|
|
lines = VGroup()
|
|
# Get the first and last rectangle
|
|
input_rectangle = self.input_layer.rectangles[-1]
|
|
output_rectangle = self.output_layer.rectangles[0]
|
|
input_vertices = input_rectangle.get_vertices()
|
|
output_vertices = output_rectangle.get_vertices()
|
|
# Go through each vertex
|
|
for vertex_index in range(len(input_vertices)):
|
|
# Make a line
|
|
line = Line(
|
|
start=input_vertices[vertex_index],
|
|
end=output_vertices[vertex_index],
|
|
color=self.color,
|
|
stroke_opacity=0.0
|
|
)
|
|
lines.add(line)
|
|
|
|
return lines
|
|
|
|
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
|
"""Forward pass animation from conv to conv"""
|
|
animations = []
|
|
# Go thorugh the lines
|
|
for line in self.lines:
|
|
pulse = ShowPassingFlash(
|
|
line.copy()
|
|
.set_color(self.pulse_color)
|
|
.set_stroke(opacity=1.0),
|
|
time_width=0.5,
|
|
run_time=run_time
|
|
)
|
|
animations.append(pulse)
|
|
# Make animation group
|
|
animation_group = AnimationGroup(
|
|
*animations,
|
|
run_time=run_time
|
|
)
|
|
|
|
return animation_group
|
|
|
|
@override_animation(Create)
|
|
def _create_override(self, **kwargs):
|
|
return AnimationGroup()
|
|
|