Files
2022-04-29 14:36:14 -04:00

107 lines
3.8 KiB
Python

from manim import *
from torch import _fake_quantize_learnable_per_tensor_affine
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
class ConvolutionalLayer(VGroupNeuralNetworkLayer):
    """Render a convolutional layer as a stack of tilted, translucent filters.

    Each filter is drawn as a filled ``Rectangle``. The rectangles are
    rotated and offset from one another by ``filter_spacing`` so the stack
    reads as a 3D volume. Four invisible "corner lines" join the matching
    vertices of the first and last rectangle; they are only shown (as
    passing flashes) during the forward pass animation.

    Attributes set on the instance:
        num_filters, filter_width, filter_height, filter_spacing, color,
        pulse_color: the constructor arguments, stored unchanged.
        rectangles: ``VGroup`` holding one ``Rectangle`` per filter.
        corner_lines: ``VGroup`` of the lines joining first/last rectangle
            vertices (empty when there are no filters).
    """

    def __init__(self, num_filters, filter_width, filter_height,
                 filter_spacing=0.1, color=BLUE, pulse_color=ORANGE,
                 **kwargs):
        """Store the layer configuration and build its mobjects.

        Args:
            num_filters: Number of filter rectangles to draw.
            filter_width: Width passed to each ``Rectangle``.
            filter_height: Height passed to each ``Rectangle``.
            filter_spacing: Offset along x between consecutive rectangles.
            color: Color and fill color of the rectangles.
            pulse_color: Color of the flash shown during the forward pass.
            **kwargs: Forwarded to the parent class initializer.
        """
        # Zero-argument super() starts the MRO lookup at the direct base
        # class. The previous ``super(VGroupNeuralNetworkLayer, self)``
        # started *after* it, so the base class's own __init__ was skipped.
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.filter_width = filter_width
        self.filter_height = filter_height
        self.filter_spacing = filter_spacing
        self.color = color
        self.pulse_color = pulse_color
        self._construct_layer(
            num_filters=self.num_filters,
            filter_width=self.filter_width,
            filter_height=self.filter_height,
        )

    def _construct_layer(self, num_filters=5, filter_width=4, filter_height=4):
        """Create the axes, filter rectangles and corner lines.

        Args:
            num_filters: Number of rectangles to create.
            filter_width: Width of each rectangle.
            filter_height: Height of each rectangle; half of it is also
                used as the y offset of every rectangle.
        """
        # Axes are added to the layer but drawn with zero stroke width and
        # no ticks, so they are not visible.
        hidden_axis_config = {"include_ticks": False, "stroke_width": 0.0}
        axes = ThreeDAxes(
            tips=False,
            x_length=1,
            y_length=1,
            # Separate copies so the three axes never share one dict.
            x_axis_config=dict(hidden_axis_config),
            y_axis_config=dict(hidden_axis_config),
            z_axis_config=dict(hidden_axis_config),
        )
        self.add(axes)

        # No camera orientation is set here; the 3D look comes from
        # rotating each rectangle individually.
        y_axis = np.array([0, 1, 0])
        x_axis = np.array([1, 0, 0])
        self.rectangles = VGroup()
        for filter_index in range(num_filters):
            rectangle = Rectangle(
                color=self.color,
                height=filter_height,
                width=filter_width,
                fill_color=self.color,
                fill_opacity=0.2,
                stroke_color=WHITE,
            )
            # Rotate about the y axis; the angle shrinks by half a degree
            # per filter so successive filters are not exactly parallel.
            rectangle.rotate_about_origin(
                (80 - filter_index * 0.5) * DEGREES, y_axis
            )
            # Rotate about the x axis.
            rectangle.rotate_about_origin(15 * DEGREES, x_axis)
            # Step each filter along x by the configured spacing.
            rectangle.shift(
                np.array([
                    filter_index * self.filter_spacing,
                    filter_height * 0.5,
                    -3,
                ])
            )
            self.rectangles.add(rectangle)
        self.add(self.rectangles)

        self.corner_lines = self.make_filter_corner_lines()
        self.add(self.corner_lines)

    def make_filter_corner_lines(self):
        """Build lines joining matching vertices of the first and last filter.

        The lines are created with ``stroke_opacity=0.0`` so they are
        invisible until the forward pass animation flashes a copy of them.

        Returns:
            A ``VGroup`` with one ``Line`` per rectangle vertex, or an empty
            ``VGroup`` when the layer has no rectangles.
        """
        corner_lines = VGroup()
        # With zero filters there is no first/last rectangle to index.
        if not self.rectangles.submobjects:
            return corner_lines

        first_vertices = self.rectangles[0].get_vertices()
        last_vertices = self.rectangles[-1].get_vertices()
        for start_vertex, end_vertex in zip(first_vertices, last_vertices):
            corner_lines.add(
                Line(
                    start=start_vertex,
                    end=end_vertex,
                    color=WHITE,
                    stroke_opacity=0.0,
                )
            )
        return corner_lines

    def make_forward_pass_animation(self, layer_args=None, **kwargs):
        """Build the forward pass animation for this layer.

        A visible, ``pulse_color`` copy of every corner line is flashed with
        ``ShowPassingFlash``; the flashes are combined in an
        ``AnimationGroup``.

        Args:
            layer_args: Accepted for interface compatibility; not used by
                this layer. Defaults to ``None`` rather than a shared
                mutable ``{}``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            An ``AnimationGroup`` containing one flash per corner line.
        """
        pulses = [
            ShowPassingFlash(
                line.copy()
                .set_color(self.pulse_color)
                .set_stroke(opacity=1.0),
                time_width=0.5,
            )
            for line in self.corner_lines
        ]
        return AnimationGroup(*pulses)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Return the animation used when ``Create`` is applied to the layer.

        Only the filter rectangles are faded in; ``kwargs`` are ignored.
        """
        return FadeIn(self.rectangles)