mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-19 12:49:18 +08:00
Made a workaround to make sure the filters in the CNN 2 CNN layers
don't appear at the beggining of a forward pass animation. I think this has to do with a bug in Succession in the core ManimCommunity library.
This commit is contained in:
@ -22,11 +22,11 @@ class Filters(VGroup):
|
||||
self.stroke_width = stroke_width
|
||||
# Make the filter
|
||||
self.input_rectangles = self.make_input_feature_map_rectangles()
|
||||
self.add(self.input_rectangles)
|
||||
# self.add(self.input_rectangles)
|
||||
self.output_rectangles = self.make_output_feature_map_rectangles()
|
||||
self.add(self.output_rectangles)
|
||||
# self.add(self.output_rectangles)
|
||||
self.connective_lines = self.make_connective_lines()
|
||||
self.add(self.connective_lines)
|
||||
# self.add(self.connective_lines)
|
||||
|
||||
def make_input_feature_map_rectangles(self):
|
||||
rectangles = []
|
||||
@ -185,6 +185,34 @@ class Filters(VGroup):
|
||||
|
||||
return connective_lines
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
"""
|
||||
NOTE This create override animation
|
||||
is a workaround to make sure that the filter
|
||||
does not show up in the scene before the create animation.
|
||||
|
||||
Without this override the filters were shown at the beginning
|
||||
of the neural network forward pass animimation
|
||||
instead of just when the filters were supposed to appear.
|
||||
I think this is a bug with Succession in the core
|
||||
Manim Community Library.
|
||||
|
||||
TODO Fix this
|
||||
"""
|
||||
def add_content(object):
|
||||
object.add(self.input_rectangles)
|
||||
object.add(self.connective_lines)
|
||||
object.add(self.output_rectangles)
|
||||
|
||||
return object
|
||||
|
||||
return ApplyFunction(
|
||||
add_content,
|
||||
self
|
||||
)
|
||||
|
||||
|
||||
class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
input_class = Convolutional3DLayer
|
||||
@ -208,9 +236,6 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
self.filter_opacity = filter_opacity
|
||||
self.line_color = line_color
|
||||
self.pulse_color = pulse_color
|
||||
# Make filters
|
||||
self.filters = Filters(self.input_layer, self.output_layer)
|
||||
self.add(self.filters)
|
||||
|
||||
def make_filter_propagation_animation(self):
|
||||
"""Make filter propagation animation"""
|
||||
@ -251,26 +276,17 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
|
||||
return right_shift, down_shift
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=10.5, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={},
|
||||
all_filters_at_once=False, run_time=10.5, **kwargs):
|
||||
"""Forward pass animation from conv2d to conv2d"""
|
||||
|
||||
animations = []
|
||||
# Rotate given three_d_phi and three_d_theta
|
||||
# Rotate about center
|
||||
# filters.rotate(110 * DEGREES, about_point=filters.get_center(), axis=[0, 0, 1])
|
||||
"""
|
||||
self.filters.rotate(
|
||||
ThreeDLayer.three_d_x_rotation,
|
||||
about_point=self.filters.get_center(),
|
||||
axis=[1, 0, 0]
|
||||
# Make filters
|
||||
filters = Filters(self.input_layer, self.output_layer)
|
||||
filters.set_z_index(self.input_layer.feature_maps[0].get_z_index() + 1)
|
||||
# self.add(filters)
|
||||
animations.append(
|
||||
Create(filters)
|
||||
)
|
||||
|
||||
self.filters.rotate(
|
||||
ThreeDLayer.three_d_y_rotation,
|
||||
about_point=self.filters.get_center(),
|
||||
axis=[0, 1, 0]
|
||||
)
|
||||
"""
|
||||
# Get shift vectors
|
||||
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||
left_shift = -1 * right_shift
|
||||
@ -279,9 +295,6 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
# Make animations for creating the filters, output_nodes, and filter_lines
|
||||
# TODO decide if I want to create the filters at the start of a conv
|
||||
# animation or have them there by default
|
||||
# animations.append(
|
||||
# Create(filters)
|
||||
# )
|
||||
# Rotate the base shift vectors
|
||||
# Make filter shifting animations
|
||||
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
|
||||
@ -289,10 +302,9 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
for y_move in range(num_y_moves):
|
||||
# Go right num_x_moves
|
||||
for x_move in range(num_x_moves):
|
||||
print(right_shift)
|
||||
# Shift right
|
||||
shift_animation = ApplyMethod(
|
||||
self.filters.shift,
|
||||
filters.shift,
|
||||
self.stride * right_shift
|
||||
)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
@ -302,7 +314,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
shift_amount = self.stride * num_x_moves * left_shift + self.stride * down_shift
|
||||
# Make the animation
|
||||
shift_animation = ApplyMethod(
|
||||
self.filters.shift,
|
||||
filters.shift,
|
||||
shift_amount
|
||||
)
|
||||
animations.append(shift_animation)
|
||||
@ -310,12 +322,15 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
for x_move in range(num_x_moves):
|
||||
# Shift right
|
||||
shift_animation = ApplyMethod(
|
||||
self.filters.shift,
|
||||
filters.shift,
|
||||
self.stride * right_shift
|
||||
)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
animations.append(shift_animation)
|
||||
|
||||
# Remove the filters
|
||||
animations.append(
|
||||
FadeOut(filters)
|
||||
)
|
||||
# Remove filters
|
||||
return Succession(
|
||||
*animations,
|
||||
@ -332,4 +347,4 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
return AnimationGroup()
|
||||
return Succession()
|
||||
|
Reference in New Issue
Block a user