mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-18 12:07:46 +08:00
Overall working 3D convolution visualization.
This commit is contained in:
@ -8,7 +8,7 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
|
||||
def __init__(self, num_feature_maps, feature_map_width, feature_map_height,
|
||||
filter_width, filter_height, cell_width=0.2, filter_spacing=0.1, color=BLUE,
|
||||
pulse_color=ORANGE, filter_color=ORANGE, stride=1, stroke_width=2.0, **kwargs):
|
||||
pulse_color=ORANGE, show_grid_lines=False, filter_color=ORANGE, stride=1, stroke_width=2.0, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.num_feature_maps = num_feature_maps
|
||||
self.feature_map_height = feature_map_height
|
||||
@ -22,20 +22,24 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
self.pulse_color = pulse_color
|
||||
self.stride = stride
|
||||
self.stroke_width = stroke_width
|
||||
self.show_grid_lines = show_grid_lines
|
||||
# Make the feature maps
|
||||
self.feature_maps = self.construct_feature_maps()
|
||||
self.add(self.feature_maps)
|
||||
# Rotate stuff properly
|
||||
# normal_vector = self.feature_maps[0].get_normal_vector()
|
||||
self.rotate(
|
||||
ThreeDLayer.three_d_x_rotation,
|
||||
ThreeDLayer.rotation_angle,
|
||||
about_point=self.get_center(),
|
||||
axis=[1, 0, 0]
|
||||
axis=ThreeDLayer.rotation_axis
|
||||
)
|
||||
"""
|
||||
self.rotate(
|
||||
ThreeDLayer.three_d_y_rotation,
|
||||
about_point=self.get_center(),
|
||||
axis=[0, 1, 0]
|
||||
)
|
||||
"""
|
||||
|
||||
def construct_feature_maps(self):
|
||||
"""Creates the neural network layer"""
|
||||
@ -50,14 +54,17 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
fill_opacity=0.2,
|
||||
stroke_color=self.color,
|
||||
stroke_width=self.stroke_width,
|
||||
# grid_xstep=self.cell_width,
|
||||
# grid_ystep=self.cell_width,
|
||||
# grid_stroke_width=DEFAULT_STROKE_WIDTH/2
|
||||
grid_xstep=self.cell_width,
|
||||
grid_ystep=self.cell_width,
|
||||
grid_stroke_width=self.stroke_width/2,
|
||||
grid_stroke_color=self.color,
|
||||
show_grid_lines=self.show_grid_lines,
|
||||
)
|
||||
# Move the feature map
|
||||
rectangle.move_to(
|
||||
[0, 0, filter_index * self.filter_spacing]
|
||||
)
|
||||
rectangle.set_z_index(4)
|
||||
feature_maps.append(rectangle)
|
||||
|
||||
return VGroup(*feature_maps)
|
||||
|
Reference in New Issue
Block a user