mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-19 04:41:57 +08:00
Reformatted the code using black, allowd for different orientation NNs, made an option for highlighting the active filter in a CNN forward pass.
This commit is contained in:
@ -260,6 +260,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
pulse_color=ORANGE,
|
||||
cell_width=0.2,
|
||||
show_grid_lines=True,
|
||||
highlight_color=ORANGE,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -284,6 +285,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
self.line_color = line_color
|
||||
self.pulse_color = pulse_color
|
||||
self.show_grid_lines = show_grid_lines
|
||||
self.highlight_color = highlight_color
|
||||
|
||||
def get_rotated_shift_vectors(self):
|
||||
"""
|
||||
@ -344,11 +346,11 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
animations.append(FadeOut(filters))
|
||||
return Succession(*animations, lag_ratio=1.0)
|
||||
|
||||
def animate_filters_one_at_a_time(self):
|
||||
def animate_filters_one_at_a_time(self, highlight_active_feature_map=False):
|
||||
"""Animates each of the filters one at a time"""
|
||||
animations = []
|
||||
output_feature_maps = self.output_layer.feature_maps
|
||||
for filter_index in range(len(output_feature_maps)):
|
||||
for feature_map_index in range(len(output_feature_maps)):
|
||||
# Make filters
|
||||
filters = Filters(
|
||||
self.input_layer,
|
||||
@ -356,9 +358,28 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
line_color=self.color,
|
||||
cell_width=self.cell_width,
|
||||
show_grid_lines=self.show_grid_lines,
|
||||
output_feature_map_to_connect=filter_index, # None means all at once
|
||||
output_feature_map_to_connect=feature_map_index, # None means all at once
|
||||
)
|
||||
animations.append(Create(filters))
|
||||
# Highlight the feature map
|
||||
if highlight_active_feature_map:
|
||||
feature_map = output_feature_maps[feature_map_index]
|
||||
original_feature_map_color = feature_map.color
|
||||
# Change the output feature map colors
|
||||
change_color_animations = []
|
||||
change_color_animations.append(
|
||||
ApplyMethod(feature_map.set_color, self.highlight_color)
|
||||
)
|
||||
# Change the input feature map colors
|
||||
input_feature_maps = self.input_layer.feature_maps
|
||||
for input_feature_map in input_feature_maps:
|
||||
change_color_animations.append(
|
||||
ApplyMethod(input_feature_map.set_color, self.highlight_color)
|
||||
)
|
||||
# Combine the animations
|
||||
animations.append(
|
||||
AnimationGroup(*change_color_animations, lag_ratio=0.0)
|
||||
)
|
||||
# Get the rotated shift vectors
|
||||
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||
left_shift = -1 * right_shift
|
||||
@ -403,11 +424,36 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
animations.append(shift_animation)
|
||||
# Remove the filters
|
||||
animations.append(FadeOut(filters))
|
||||
# Un-highlight the feature map
|
||||
if highlight_active_feature_map:
|
||||
feature_map = output_feature_maps[feature_map_index]
|
||||
# Change the output feature map colors
|
||||
change_color_animations = []
|
||||
change_color_animations.append(
|
||||
ApplyMethod(feature_map.set_color, original_feature_map_color)
|
||||
)
|
||||
# Change the input feature map colors
|
||||
input_feature_maps = self.input_layer.feature_maps
|
||||
for input_feature_map in input_feature_maps:
|
||||
change_color_animations.append(
|
||||
ApplyMethod(
|
||||
input_feature_map.set_color, original_feature_map_color
|
||||
)
|
||||
)
|
||||
# Combine the animations
|
||||
animations.append(
|
||||
AnimationGroup(*change_color_animations, lag_ratio=0.0)
|
||||
)
|
||||
|
||||
return Succession(*animations, lag_ratio=1.0)
|
||||
|
||||
def make_forward_pass_animation(
|
||||
self, layer_args={}, all_filters_at_once=False, run_time=10.5, **kwargs
|
||||
self,
|
||||
layer_args={},
|
||||
all_filters_at_once=False,
|
||||
highlight_active_feature_map=False,
|
||||
run_time=10.5,
|
||||
**kwargs,
|
||||
):
|
||||
"""Forward pass animation from conv2d to conv2d"""
|
||||
print(f"All filters at once: {all_filters_at_once}")
|
||||
@ -415,7 +461,9 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
if all_filters_at_once:
|
||||
return self.animate_filters_all_at_once()
|
||||
else:
|
||||
return self.animate_filters_one_at_a_time()
|
||||
return self.animate_filters_one_at_a_time(
|
||||
highlight_active_feature_map=highlight_active_feature_map
|
||||
)
|
||||
|
||||
def scale(self, scale_factor, **kwargs):
|
||||
self.cell_width *= scale_factor
|
||||
|
Reference in New Issue
Block a user