mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-26 13:22:08 +08:00
Convolutional Layers
This commit is contained in:
@ -0,0 +1,66 @@
|
||||
from cProfile import run
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.convolutional_3d import Convolutional3DLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||
|
||||
class Convolutional3DToConvolutional3D(ConnectiveLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
input_class = Convolutional3DLayer
|
||||
output_class = Convolutional3DLayer
|
||||
|
||||
def __init__(self, input_layer, output_layer, color=WHITE, pulse_color=RED,
|
||||
**kwargs):
|
||||
super().__init__(input_layer, output_layer, input_class=Convolutional3DLayer, output_class=Convolutional3DLayer,
|
||||
**kwargs)
|
||||
self.color = color
|
||||
self.pulse_color = pulse_color
|
||||
|
||||
self.lines = self.make_lines()
|
||||
self.add(self.lines)
|
||||
|
||||
def make_lines(self):
|
||||
"""Make lines connecting the input and output layers"""
|
||||
lines = VGroup()
|
||||
# Get the first and last rectangle
|
||||
input_rectangle = self.input_layer.rectangles[-1]
|
||||
output_rectangle = self.output_layer.rectangles[0]
|
||||
input_vertices = input_rectangle.get_vertices()
|
||||
output_vertices = output_rectangle.get_vertices()
|
||||
# Go through each vertex
|
||||
for vertex_index in range(len(input_vertices)):
|
||||
# Make a line
|
||||
line = Line(
|
||||
start=input_vertices[vertex_index],
|
||||
end=output_vertices[vertex_index],
|
||||
color=self.color,
|
||||
stroke_opacity=0.0
|
||||
)
|
||||
lines.add(line)
|
||||
|
||||
return lines
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||
"""Forward pass animation from conv to conv"""
|
||||
animations = []
|
||||
# Go thorugh the lines
|
||||
for line in self.lines:
|
||||
pulse = ShowPassingFlash(
|
||||
line.copy()
|
||||
.set_color(self.pulse_color)
|
||||
.set_stroke(opacity=1.0),
|
||||
time_width=0.5,
|
||||
run_time=run_time
|
||||
)
|
||||
animations.append(pulse)
|
||||
# Make animation group
|
||||
animation_group = AnimationGroup(
|
||||
*animations,
|
||||
run_time=run_time
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
return AnimationGroup()
|
||||
|
Reference in New Issue
Block a user