mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-20 03:57:40 +08:00
51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
from manim import *
|
|
from matplotlib import animation
|
|
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
|
|
|
|
|
|
class Convolutional2DLayer(VGroupNeuralNetworkLayer):
    """Neural-network layer drawn as a single gridded feature-map rectangle.

    The layer records its feature-map and filter settings on the instance and
    builds one ``Rectangle`` whose grid lines are spaced ``cell_width`` apart.
    The filter settings (``filter_width``, ``filter_height``,
    ``filter_color``) and ``stride`` are only stored by this class; none of
    the methods defined here read them back.
    """

    def __init__(
        self,
        feature_map_height,
        feature_map_width,
        filter_width,
        filter_height,
        stride=1,
        cell_width=0.5,
        feature_map_color=BLUE,
        filter_color=ORANGE,
        **kwargs
    ):
        """Store the layer configuration and build the feature-map rectangle.

        Args:
            feature_map_height: Multiplied by ``cell_width`` to obtain the
                height of the feature-map rectangle.
            feature_map_width: Multiplied by ``cell_width`` to obtain the
                width of the feature-map rectangle.
            filter_width: Stored on the instance as ``self.filter_width``.
            filter_height: Stored on the instance as ``self.filter_height``.
            stride: Stored on the instance as ``self.stride``.
            cell_width: Grid spacing of the rectangle on both axes, and the
                scale factor applied to the feature-map dimensions.
            feature_map_color: Color passed to the feature-map rectangle.
            filter_color: Stored on the instance as ``self.filter_color``.
            **kwargs: Forwarded unchanged to the superclass initializer.
        """
        # NOTE(review): naming VGroupNeuralNetworkLayer here starts the MRO
        # lookup *after* that class, so its own __init__ is skipped.
        # Presumably intentional -- confirm against parent_layers before
        # switching to a plain super().__init__(**kwargs).
        super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)

        # Feature-map geometry and appearance.
        self.cell_width = cell_width
        self.feature_map_width = feature_map_width
        self.feature_map_height = feature_map_height
        self.feature_map_color = feature_map_color

        # Filter settings; recorded only, not used by the methods below.
        self.stride = stride
        self.filter_width = filter_width
        self.filter_height = filter_height
        self.filter_color = filter_color

        # Build the drawable once every attribute it reads is in place.
        self.construct_feature_map()

    def construct_feature_map(self):
        """Create the gridded feature-map rectangle and add it to this layer.

        The rectangle is stored as ``self.feature_map``.
        """
        step = self.cell_width
        grid_rectangle = Rectangle(
            color=self.feature_map_color,
            width=self.feature_map_width * step,
            height=self.feature_map_height * step,
            # Same spacing on both axes, so the grid cells are square.
            grid_xstep=step,
            grid_ystep=step,
        )

        self.feature_map = grid_rectangle
        self.add(grid_rectangle)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Return a ``FadeIn`` of the feature map, registered as the ``Create`` override.

        The keyword arguments are accepted but not used.
        """
        fade_in = FadeIn(self.feature_map)
        return fade_in

    def make_forward_pass_animation(self, **kwargs):
        """Return the forward-pass animation, currently an empty ``AnimationGroup``.

        The keyword arguments are accepted but not used.
        """
        empty_group = AnimationGroup()
        return empty_group
|