mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-25 08:56:23 +08:00
107 lines
3.8 KiB
Python
107 lines
3.8 KiB
Python
|
|
from manim import *
|
|
from torch import _fake_quantize_learnable_per_tensor_affine
|
|
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
|
|
|
|
class ConvolutionalLayer(VGroupNeuralNetworkLayer):
|
|
"""Handles rendering a convolutional layer for a nn"""
|
|
|
|
def __init__(self, num_filters, filter_width, filter_height, filter_spacing=0.1, color=BLUE,
|
|
pulse_color=ORANGE, **kwargs):
|
|
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
|
|
self.num_filters = num_filters
|
|
self.filter_width = filter_width
|
|
self.filter_height = filter_height
|
|
self.filter_spacing = filter_spacing
|
|
self.color = color
|
|
self.pulse_color = pulse_color
|
|
|
|
self._construct_layer(num_filters=self.num_filters, filter_width=self.filter_width, filter_height=self.filter_height)
|
|
|
|
def _construct_layer(self, num_filters=5, filter_width=4, filter_height=4):
|
|
"""Creates the neural network layer"""
|
|
# Make axes, but hide the lines
|
|
axes = ThreeDAxes(
|
|
tips=False,
|
|
x_length=1,
|
|
y_length=1,
|
|
x_axis_config={
|
|
"include_ticks": False,
|
|
"stroke_width": 0.0
|
|
},
|
|
y_axis_config={
|
|
"include_ticks": False,
|
|
"stroke_width": 0.0
|
|
},
|
|
z_axis_config={
|
|
"include_ticks": False,
|
|
"stroke_width": 0.0
|
|
}
|
|
)
|
|
self.add(axes)
|
|
# Set the camera angle so that the
|
|
# self.set_camera_orientation(phi=75 * DEGREES, theta=30 * DEGREES)
|
|
# Draw rectangles that are filled in with opacity
|
|
self.rectangles = VGroup()
|
|
for filter_index in range(num_filters):
|
|
rectangle = Rectangle(
|
|
color=self.color,
|
|
height=filter_height,
|
|
width=filter_width,
|
|
fill_color=self.color,
|
|
fill_opacity=0.2,
|
|
stroke_color=WHITE,
|
|
)
|
|
rectangle.rotate_about_origin((80 - filter_index*0.5) * DEGREES, np.array([0, 1, 0])) # Rotate about z axis
|
|
rectangle.rotate_about_origin(15 * DEGREES, np.array([1, 0, 0])) # Rotate about x axis
|
|
rectangle.shift(np.array([filter_index*self.filter_spacing, filter_height*0.5, -3]))
|
|
|
|
self.rectangles.add(rectangle)
|
|
|
|
self.add(self.rectangles)
|
|
|
|
self.corner_lines = self.make_filter_corner_lines()
|
|
self.add(self.corner_lines)
|
|
|
|
def make_filter_corner_lines(self):
|
|
"""Make filter corner lines"""
|
|
corner_lines = VGroup()
|
|
|
|
first_rectangle = self.rectangles[0]
|
|
last_rectangle = self.rectangles[-1]
|
|
first_vertices = first_rectangle.get_vertices()
|
|
last_vertices = last_rectangle.get_vertices()
|
|
for vertex_index in range(len(first_vertices)):
|
|
# Make a line
|
|
line = Line(
|
|
start=first_vertices[vertex_index],
|
|
end=last_vertices[vertex_index],
|
|
color=WHITE,
|
|
stroke_opacity=0.0
|
|
)
|
|
corner_lines.add(line)
|
|
|
|
return corner_lines
|
|
|
|
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
|
"""Convolution forward pass animation"""
|
|
animations = []
|
|
for line in self.corner_lines:
|
|
pulse = ShowPassingFlash(
|
|
line.copy()
|
|
.set_color(self.pulse_color)
|
|
.set_stroke(opacity=1.0),
|
|
time_width=0.5
|
|
)
|
|
animations.append(pulse)
|
|
# Make animation group
|
|
animation_group = AnimationGroup(
|
|
*animations
|
|
)
|
|
|
|
return animation_group
|
|
|
|
@override_animation(Create)
|
|
def _create_override(self, **kwargs):
|
|
return FadeIn(self.rectangles)
|