mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-20 12:05:58 +08:00
62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
from manim import *
|
|
|
|
from manim_ml.neural_network.layers.parent_layers import ThreeDLayer, VGroupNeuralNetworkLayer
|
|
|
|
class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
    """Max pooling layer for Convolutional2DLayer.

    Note: This is for a Convolutional2DLayer even though it is called
    MaxPooling2DLayer, because the "2D" refers to the two spatial
    dimensions of the convolution.

    NOTE(review): the feature-map builder, ``construct_layer``, the
    forward pass and the Create override are still unimplemented stubs;
    the layer currently draws nothing and produces no animations.
    """

    def __init__(
        self,
        output_feature_map_size=(4, 4),
        kernel_size=2,
        stride=1,
        cell_highlight_color=ORANGE,
        **kwargs,
    ):
        """Layer object for animating 2D Convolution Max Pooling.

        Parameters
        ----------
        output_feature_map_size : tuple, optional
            Size of the output feature maps, by default (4, 4).
            NOTE(review): presumably (width, height) in cells — confirm.
        kernel_size : int or tuple, optional
            Width/Height of max pooling kernel, by default 2
        stride : int, optional
            Stride of the max pooling operation, by default 1
        cell_highlight_color : color, optional
            Color presumably meant for highlighting the selected cell of
            each kernel window during the forward pass (not used yet),
            by default ORANGE.
        **kwargs
            Forwarded unchanged to the parent layer constructors.
        """
        super().__init__(**kwargs)
        self.output_feature_map_size = output_feature_map_size
        self.kernel_size = kernel_size
        self.stride = stride
        self.cell_highlight_color = cell_highlight_color
        # Make the output feature maps. The builder is still a stub that
        # returns None, and manim's ``add`` rejects non-Mobject arguments,
        # so only attach the result once something has actually been built.
        feature_maps = self._make_output_feature_maps()
        if feature_maps is not None:
            self.add(feature_maps)

    def construct_layer(self, input_layer, output_layer):
        """Constructs the layer in the context of adjacent layers.

        TODO: not yet implemented; currently a no-op.
        """
        pass

    def _make_output_feature_maps(self):
        """Makes a set of output feature maps.

        TODO: not yet implemented; currently returns None, which
        ``__init__`` treats as "nothing to add".
        """
        # Compute the size of the feature maps
        return None

    def make_forward_pass_animation(self, layer_args=None, **kwargs):
        """Makes forward pass of Max Pooling Layer.

        Parameters
        ----------
        layer_args : dict, optional
            Extra per-layer arguments for the forward pass. ``None`` (the
            default) is treated as an empty dict. Currently unused.
        **kwargs
            Currently unused.

        Returns
        -------
        None
            TODO: not yet implemented; no animation is produced yet.
        """
        # Use a fresh dict per call instead of a shared mutable default.
        layer_args = {} if layer_args is None else layer_args
        # Planned steps:
        # 1. Draw gridded rectangle with kernel_size x kernel_size
        #    box regions over the input feature map.
        # 2. Randomly highlight one of the cells in the kernel.
        # 3. Move and resize the gridded rectangle to the output
        #    feature maps.
        # 4. Make the gridded feature map(s) disappear.
        pass

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Create animation for the MaxPooling operation.

        TODO: not yet implemented; currently returns None.
        """
        pass
|