mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-26 10:07:33 +08:00
Added working convolutional layer.
This commit is contained in:
@ -1,25 +1,106 @@
|
||||
|
||||
from manim import *
|
||||
from torch import _fake_quantize_learnable_per_tensor_affine
|
||||
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
|
||||
|
||||
class ConvolutionalLayer(VGroupNeuralNetworkLayer):
|
||||
"""Handles rendering a convolutional layer for a nn"""
|
||||
|
||||
def __init__(self, num_filters, filter_width, **kwargs):
|
||||
def __init__(self, num_filters, filter_width, filter_height, filter_spacing=0.1, color=BLUE,
|
||||
pulse_color=ORANGE, **kwargs):
|
||||
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
|
||||
self.num_filters = num_filters
|
||||
self.filter_width = filter_width
|
||||
self.filter_height = filter_height
|
||||
self.filter_spacing = filter_spacing
|
||||
self.color = color
|
||||
self.pulse_color = pulse_color
|
||||
|
||||
self._construct_neural_network_layer()
|
||||
self._construct_layer(num_filters=self.num_filters, filter_width=self.filter_width, filter_height=self.filter_height)
|
||||
|
||||
def _construct_neural_network_layer(self):
|
||||
def _construct_layer(self, num_filters=5, filter_width=4, filter_height=4):
|
||||
"""Creates the neural network layer"""
|
||||
pass
|
||||
# Make axes, but hide the lines
|
||||
axes = ThreeDAxes(
|
||||
tips=False,
|
||||
x_length=1,
|
||||
y_length=1,
|
||||
x_axis_config={
|
||||
"include_ticks": False,
|
||||
"stroke_width": 0.0
|
||||
},
|
||||
y_axis_config={
|
||||
"include_ticks": False,
|
||||
"stroke_width": 0.0
|
||||
},
|
||||
z_axis_config={
|
||||
"include_ticks": False,
|
||||
"stroke_width": 0.0
|
||||
}
|
||||
)
|
||||
self.add(axes)
|
||||
# Set the camera angle so that the
|
||||
# self.set_camera_orientation(phi=75 * DEGREES, theta=30 * DEGREES)
|
||||
# Draw rectangles that are filled in with opacity
|
||||
self.rectangles = VGroup()
|
||||
for filter_index in range(num_filters):
|
||||
rectangle = Rectangle(
|
||||
color=self.color,
|
||||
height=filter_height,
|
||||
width=filter_width,
|
||||
fill_color=self.color,
|
||||
fill_opacity=0.2,
|
||||
stroke_color=WHITE,
|
||||
)
|
||||
rectangle.rotate_about_origin((80 - filter_index*0.5) * DEGREES, np.array([0, 1, 0])) # Rotate about z axis
|
||||
rectangle.rotate_about_origin(15 * DEGREES, np.array([1, 0, 0])) # Rotate about x axis
|
||||
rectangle.shift(np.array([filter_index*self.filter_spacing, filter_height*0.5, -3]))
|
||||
|
||||
self.rectangles.add(rectangle)
|
||||
|
||||
self.add(self.rectangles)
|
||||
|
||||
self.corner_lines = self.make_filter_corner_lines()
|
||||
self.add(self.corner_lines)
|
||||
|
||||
def make_filter_corner_lines(self):
|
||||
"""Make filter corner lines"""
|
||||
corner_lines = VGroup()
|
||||
|
||||
first_rectangle = self.rectangles[0]
|
||||
last_rectangle = self.rectangles[-1]
|
||||
first_vertices = first_rectangle.get_vertices()
|
||||
last_vertices = last_rectangle.get_vertices()
|
||||
for vertex_index in range(len(first_vertices)):
|
||||
# Make a line
|
||||
line = Line(
|
||||
start=first_vertices[vertex_index],
|
||||
end=last_vertices[vertex_index],
|
||||
color=WHITE,
|
||||
stroke_opacity=0.0
|
||||
)
|
||||
corner_lines.add(line)
|
||||
|
||||
return corner_lines
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
# make highlight animation
|
||||
return None
|
||||
"""Convolution forward pass animation"""
|
||||
animations = []
|
||||
for line in self.corner_lines:
|
||||
pulse = ShowPassingFlash(
|
||||
line.copy()
|
||||
.set_color(self.pulse_color)
|
||||
.set_stroke(opacity=1.0),
|
||||
time_width=0.5
|
||||
)
|
||||
animations.append(pulse)
|
||||
# Make animation group
|
||||
animation_group = AnimationGroup(
|
||||
*animations
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
pass
|
||||
return FadeIn(self.rectangles)
|
||||
|
Reference in New Issue
Block a user