mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-20 20:16:32 +08:00
Added padding.
Fixed a bug with ImageLayerToConvolutional2D Padding example
This commit is contained in:
@ -12,6 +12,89 @@ from manim_ml.neural_network.layers.parent_layers import (
|
||||
)
|
||||
from manim_ml.gridded_rectangle import GriddedRectangle
|
||||
|
||||
class FeatureMap(VGroup):
|
||||
"""Class for making a feature map"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
color=ORANGE,
|
||||
feature_map_size=None,
|
||||
fill_color=ORANGE,
|
||||
fill_opacity=0.2,
|
||||
cell_width=0.2,
|
||||
padding=(0, 0),
|
||||
stroke_width=2.0,
|
||||
show_grid_lines=False,
|
||||
padding_dashed=False
|
||||
):
|
||||
super().__init__()
|
||||
self.color = color
|
||||
self.feature_map_size = feature_map_size
|
||||
self.fill_color = fill_color
|
||||
self.fill_opacity = fill_opacity
|
||||
self.cell_width = cell_width
|
||||
self.padding = padding
|
||||
self.stroke_width = stroke_width
|
||||
self.show_grid_lines = show_grid_lines
|
||||
self.padding_dashed = padding_dashed
|
||||
# Check if we have non-zero padding
|
||||
if padding[0] > 0 or padding[1] > 0:
|
||||
# Make the exterior rectangle dashed
|
||||
width_with_padding = (self.feature_map_size[0] + self.padding[0] * 2) * self.cell_width
|
||||
height_with_padding = (self.feature_map_size[1] + self.padding[1] * 2) * self.cell_width
|
||||
self.untransformed_width = width_with_padding
|
||||
self.untransformed_height = height_with_padding
|
||||
|
||||
self.exterior_rectangle = GriddedRectangle(
|
||||
color=self.color,
|
||||
width=width_with_padding,
|
||||
height=height_with_padding,
|
||||
fill_color=self.color,
|
||||
fill_opacity=self.fill_opacity,
|
||||
stroke_color=self.color,
|
||||
stroke_width=self.stroke_width,
|
||||
grid_xstep=self.cell_width,
|
||||
grid_ystep=self.cell_width,
|
||||
grid_stroke_width=self.stroke_width / 2,
|
||||
grid_stroke_color=self.color,
|
||||
show_grid_lines=self.show_grid_lines,
|
||||
dotted_lines=self.padding_dashed
|
||||
)
|
||||
self.add(self.exterior_rectangle)
|
||||
# Add an interior rectangle with no fill color
|
||||
self.interior_rectangle = GriddedRectangle(
|
||||
color=self.color,
|
||||
fill_opacity=0.0,
|
||||
width=self.feature_map_size[0] * self.cell_width,
|
||||
height=self.feature_map_size[1] * self.cell_width,
|
||||
stroke_width=self.stroke_width
|
||||
)
|
||||
self.add(self.interior_rectangle)
|
||||
else:
|
||||
# Just make an exterior rectangle with no dashes.
|
||||
self.untransformed_height = self.feature_map_size[1] * self.cell_width,
|
||||
self.untransformed_width = self.feature_map_size[0] * self.cell_width,
|
||||
# Make the exterior rectangle
|
||||
self.exterior_rectangle = GriddedRectangle(
|
||||
color=self.color,
|
||||
height=self.feature_map_size[1] * self.cell_width,
|
||||
width=self.feature_map_size[0] * self.cell_width,
|
||||
fill_color=self.color,
|
||||
fill_opacity=self.fill_opacity,
|
||||
stroke_color=self.color,
|
||||
stroke_width=self.stroke_width,
|
||||
grid_xstep=self.cell_width,
|
||||
grid_ystep=self.cell_width,
|
||||
grid_stroke_width=self.stroke_width / 2,
|
||||
grid_stroke_color=self.color,
|
||||
show_grid_lines=self.show_grid_lines,
|
||||
)
|
||||
self.add(self.exterior_rectangle)
|
||||
|
||||
def get_corners_dict(self):
|
||||
"""Returns a dictionary of the corners"""
|
||||
# Sort points through clockwise rotation of a vector in the xy plane
|
||||
return self.exterior_rectangle.get_corners_dict()
|
||||
|
||||
class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
"""Handles rendering a convolutional layer for a nn"""
|
||||
@ -24,33 +107,48 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
cell_width=0.2,
|
||||
filter_spacing=0.1,
|
||||
color=BLUE,
|
||||
pulse_color=ORANGE,
|
||||
show_grid_lines=False,
|
||||
active_color=ORANGE,
|
||||
filter_color=ORANGE,
|
||||
show_grid_lines=False,
|
||||
fill_opacity=0.3,
|
||||
stride=1,
|
||||
stroke_width=2.0,
|
||||
activation_function=None,
|
||||
padding=0,
|
||||
padding_dashed=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.num_feature_maps = num_feature_maps
|
||||
self.filter_color = filter_color
|
||||
if isinstance(padding, tuple):
|
||||
assert len(padding) == 2
|
||||
self.padding = padding
|
||||
elif isinstance(padding, int):
|
||||
self.padding = (padding, padding)
|
||||
else:
|
||||
raise Exception(f"Unrecognized type for padding: {type(padding)}")
|
||||
|
||||
if isinstance(feature_map_size, int):
|
||||
self.feature_map_size = (feature_map_size, feature_map_size)
|
||||
else:
|
||||
self.feature_map_size = feature_map_size
|
||||
|
||||
if isinstance(filter_size, int):
|
||||
self.filter_size = (filter_size, filter_size)
|
||||
else:
|
||||
self.filter_size = filter_size
|
||||
|
||||
self.cell_width = cell_width
|
||||
self.filter_spacing = filter_spacing
|
||||
self.color = color
|
||||
self.pulse_color = pulse_color
|
||||
self.active_color = active_color
|
||||
self.stride = stride
|
||||
self.stroke_width = stroke_width
|
||||
self.show_grid_lines = show_grid_lines
|
||||
self.activation_function = activation_function
|
||||
self.fill_opacity = fill_opacity
|
||||
self.padding_dashed = padding_dashed
|
||||
|
||||
def construct_layer(
|
||||
self,
|
||||
@ -92,12 +190,14 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
# Draw rectangles that are filled in with opacity
|
||||
feature_maps = []
|
||||
for filter_index in range(self.num_feature_maps):
|
||||
rectangle = GriddedRectangle(
|
||||
# Check if we need to add padding
|
||||
"""
|
||||
feature_map = GriddedRectangle(
|
||||
color=self.color,
|
||||
height=self.feature_map_size[1] * self.cell_width,
|
||||
width=self.feature_map_size[0] * self.cell_width,
|
||||
fill_color=self.color,
|
||||
fill_opacity=0.2,
|
||||
fill_opacity=self.fill_opacity,
|
||||
stroke_color=self.color,
|
||||
stroke_width=self.stroke_width,
|
||||
grid_xstep=self.cell_width,
|
||||
@ -106,52 +206,44 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
grid_stroke_color=self.color,
|
||||
show_grid_lines=self.show_grid_lines,
|
||||
)
|
||||
"""
|
||||
# feature_map = GriddedRectangle()
|
||||
feature_map = FeatureMap(
|
||||
color=self.color,
|
||||
feature_map_size=self.feature_map_size,
|
||||
cell_width=self.cell_width,
|
||||
fill_color=self.color,
|
||||
fill_opacity=self.fill_opacity,
|
||||
padding=self.padding,
|
||||
padding_dashed=self.padding_dashed
|
||||
)
|
||||
# Move the feature map
|
||||
rectangle.move_to([0, 0, filter_index * self.filter_spacing])
|
||||
feature_map.move_to([0, 0, filter_index * self.filter_spacing])
|
||||
# rectangle.set_z_index(4)
|
||||
feature_maps.append(rectangle)
|
||||
feature_maps.append(feature_map)
|
||||
|
||||
return VGroup(*feature_maps)
|
||||
|
||||
def highlight_and_unhighlight_feature_maps(self):
|
||||
"""Highlights then unhighlights feature maps"""
|
||||
return Succession(
|
||||
ApplyMethod(self.feature_maps.set_color, self.pulse_color),
|
||||
ApplyMethod(self.feature_maps.set_color, self.active_color),
|
||||
ApplyMethod(self.feature_maps.set_color, self.color),
|
||||
)
|
||||
|
||||
def make_forward_pass_animation(
|
||||
self, run_time=5, corner_pulses=False, layer_args={}, **kwargs
|
||||
self, run_time=5, layer_args={}, **kwargs
|
||||
):
|
||||
"""Convolution forward pass animation"""
|
||||
# Note: most of this animation is done in the Convolution3DToConvolution3D layer
|
||||
if corner_pulses:
|
||||
raise NotImplementedError()
|
||||
passing_flashes = []
|
||||
for line in self.corner_lines:
|
||||
pulse = ShowPassingFlash(
|
||||
line.copy().set_color(self.pulse_color).set_stroke(opacity=1.0),
|
||||
time_width=0.5,
|
||||
run_time=run_time,
|
||||
rate_func=rate_functions.linear,
|
||||
)
|
||||
passing_flashes.append(pulse)
|
||||
|
||||
# per_filter_run_time = run_time / len(self.feature_maps)
|
||||
# Make animation group
|
||||
if not self.activation_function is None:
|
||||
animation_group = AnimationGroup(
|
||||
*passing_flashes,
|
||||
# filter_flashes
|
||||
self.activation_function.make_evaluate_animation(),
|
||||
self.highlight_and_unhighlight_feature_maps(),
|
||||
lag_ratio=0.0,
|
||||
)
|
||||
else:
|
||||
if not self.activation_function is None:
|
||||
animation_group = AnimationGroup(
|
||||
self.activation_function.make_evaluate_animation(),
|
||||
self.highlight_and_unhighlight_feature_maps(),
|
||||
lag_ratio=0.0,
|
||||
)
|
||||
else:
|
||||
animation_group = AnimationGroup()
|
||||
animation_group = AnimationGroup()
|
||||
|
||||
return animation_group
|
||||
|
||||
|
Reference in New Issue
Block a user