mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-26 02:05:18 +08:00
Added working convolutional layer.
This commit is contained in:
@ -0,0 +1,63 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.convolutional import ConvolutionalLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||
|
||||
class ConvolutionalToConvolutional(ConnectiveLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
input_class = ConvolutionalLayer
|
||||
output_class = ConvolutionalLayer
|
||||
|
||||
def __init__(self, input_layer, output_layer, color=WHITE, pulse_color=RED,
|
||||
**kwargs):
|
||||
super().__init__(input_layer, output_layer, input_class=ConvolutionalLayer, output_class=ConvolutionalLayer,
|
||||
**kwargs)
|
||||
self.color = color
|
||||
self.pulse_color = pulse_color
|
||||
|
||||
self.lines = self.make_lines()
|
||||
self.add(self.lines)
|
||||
|
||||
def make_lines(self):
|
||||
"""Make lines connecting the input and output layers"""
|
||||
lines = VGroup()
|
||||
# Get the first and last rectangle
|
||||
input_rectangle = self.input_layer.rectangles[-1]
|
||||
output_rectangle = self.output_layer.rectangles[0]
|
||||
input_vertices = input_rectangle.get_vertices()
|
||||
output_vertices = output_rectangle.get_vertices()
|
||||
# Go through each vertex
|
||||
for vertex_index in range(len(input_vertices)):
|
||||
# Make a line
|
||||
line = Line(
|
||||
start=input_vertices[vertex_index],
|
||||
end=output_vertices[vertex_index],
|
||||
color=self.color,
|
||||
stroke_opacity=0.0
|
||||
)
|
||||
lines.add(line)
|
||||
|
||||
return lines
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||
"""Forward pass animation from conv to conv"""
|
||||
animations = []
|
||||
# Go thorugh the lines
|
||||
for line in self.lines:
|
||||
pulse = ShowPassingFlash(
|
||||
line.copy()
|
||||
.set_color(self.pulse_color)
|
||||
.set_stroke(opacity=1.0),
|
||||
time_width=0.5
|
||||
)
|
||||
animations.append(pulse)
|
||||
# Make animation group
|
||||
animation_group = AnimationGroup(
|
||||
*animations
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
return AnimationGroup()
|
||||
|
||||
|
Reference in New Issue
Block a user