mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-17 19:49:18 +08:00
475 lines
19 KiB
Python
475 lines
19 KiB
Python
from manim import *
|
|
from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
|
|
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
|
|
from manim_ml.gridded_rectangle import GriddedRectangle
|
|
|
|
from manim.utils.space_ops import rotation_matrix
|
|
|
|
|
|
class Filters(VGroup):
    """Group for showing a collection of filters connecting two layers.

    The group builds three sub-groups from the two layers it is given:

    * ``input_rectangles``: one filter-sized rectangle per input feature map,
    * ``output_rectangles``: one cell-sized rectangle per connected output
      feature map,
    * ``connective_lines``: lines joining matching corners of those rectangles.

    None of the sub-groups is added to this group in ``__init__``; they are
    attached by the ``Create`` override (see ``_create_override``).
    """

    def __init__(
        self,
        input_layer,
        output_layer,
        line_color=ORANGE,
        cell_width=1.0,
        stroke_width=2.0,
        show_grid_lines=False,
        output_feature_map_to_connect=None,  # None means all at once
    ):
        """Build the filter rectangles and the lines connecting them.

        Parameters
        ----------
        input_layer
            Layer providing ``feature_maps``, ``filter_width``,
            ``filter_height``, ``cell_width`` and ``filter_color``.
        output_layer
            Layer providing ``feature_maps``, ``cell_width`` and
            ``filter_color``.
        line_color
            Color of the connective lines.
        cell_width
            Grid step used for the grid lines drawn inside each rectangle.
        stroke_width
            Stroke width of rectangles and lines (grid lines use half of it).
        show_grid_lines
            Forwarded to ``GriddedRectangle``.
        output_feature_map_to_connect
            Index of the single output feature map to connect, or ``None``
            to connect all of them.
        """
        super().__init__()
        self.input_layer = input_layer
        self.output_layer = output_layer
        self.line_color = line_color
        self.cell_width = cell_width
        self.stroke_width = stroke_width
        self.show_grid_lines = show_grid_lines
        self.output_feature_map_to_connect = output_feature_map_to_connect
        # The sub-groups are intentionally NOT added to this group here:
        # _create_override attaches them when the Create animation runs, so
        # the filters do not appear in the scene ahead of time.
        self.input_rectangles = self.make_input_feature_map_rectangles()
        self.output_rectangles = self.make_output_feature_map_rectangles()
        # Must come last: the lines are anchored on the rectangles' corners.
        self.connective_lines = self.make_connective_lines()

    def _make_filter_rectangle(self, feature_map, width, height, color):
        """Build one rotated rectangle pinned to ``feature_map``'s top-left corner."""
        rectangle = GriddedRectangle(
            width=width,
            height=height,
            fill_color=color,
            stroke_color=color,
            fill_opacity=0.2,
            stroke_width=self.stroke_width,
            grid_xstep=self.cell_width,
            grid_ystep=self.cell_width,
            grid_stroke_width=self.stroke_width / 2,
            grid_stroke_color=color,
            show_grid_lines=self.show_grid_lines,
        )
        # Apply the shared ThreeDLayer rotation about the rectangle's center.
        rectangle.rotate(
            ThreeDLayer.rotation_angle,
            about_point=rectangle.get_center(),
            axis=ThreeDLayer.rotation_axis,
        )
        # Align the rectangle's top-left corner with the feature map's.
        rectangle.next_to(
            feature_map.get_corners_dict()["top_left"],
            submobject_to_align=rectangle.get_corners_dict()["top_left"],
            buff=0.0,
        )
        return rectangle

    def make_input_feature_map_rectangles(self):
        """Return a VGroup with one filter-sized rectangle per input feature map."""
        layer = self.input_layer
        width = layer.filter_width * layer.cell_width
        height = layer.filter_height * layer.cell_width
        color = layer.filter_color

        rectangles = []
        for feature_map in layer.feature_maps:
            rectangle = self._make_filter_rectangle(feature_map, width, height, color)
            # Only the input rectangles get an explicit z-index.
            rectangle.set_z_index(5)
            rectangles.append(rectangle)

        return VGroup(*rectangles)

    def make_output_feature_map_rectangles(self):
        """Return a VGroup with one cell-sized rectangle per connected output map.

        When ``output_feature_map_to_connect`` is not ``None`` only the
        feature map at that index gets a rectangle.
        """
        layer = self.output_layer
        side = layer.cell_width
        color = layer.filter_color
        target_index = self.output_feature_map_to_connect

        rectangles = [
            self._make_filter_rectangle(feature_map, side, side, color)
            for index, feature_map in enumerate(layer.feature_maps)
            if target_index is None or index == target_index
        ]

        return VGroup(*rectangles)

    def make_connective_lines(self):
        """Lines connecting input filter with output node.

        Returns a flat VGroup holding, in order: the lines between the first
        and last input rectangles, the lines between the first and last
        output rectangles, and the lines between the last input rectangle and
        the first output rectangle. Each set has one line per corner.
        """
        corner_names = ["top_left", "bottom_left", "top_right", "bottom_right"]

        def make_corner_lines(start_rectangle, end_rectangle):
            """Join each corner of ``start_rectangle`` to the same corner of ``end_rectangle``."""
            start_corners = start_rectangle.get_corners_dict()
            end_corners = end_rectangle.get_corners_dict()
            return [
                Line(
                    start_corners[corner_name].get_center(),
                    end_corners[corner_name].get_center(),
                    color=self.line_color,
                    stroke_width=self.stroke_width,
                )
                for corner_name in corner_names
            ]

        input_lines = make_corner_lines(
            self.input_rectangles[0], self.input_rectangles[-1]
        )
        output_lines = make_corner_lines(
            self.output_rectangles[0], self.output_rectangles[-1]
        )
        input_output_lines = make_corner_lines(
            self.input_rectangles[-1], self.output_rectangles[0]
        )

        return VGroup(*input_lines, *output_lines, *input_output_lines)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """
        NOTE This create override animation
        is a workaround to make sure that the filter
        does not show up in the scene before the create animation.

        Without this override the filters were shown at the beginning
        of the neural network forward pass animation
        instead of just when the filters were supposed to appear.
        I think this is a bug with Succession in the core
        Manim Community Library.

        TODO Fix this
        """

        def add_content(mobject):
            # Attach the pre-built sub-groups only once the animation runs.
            mobject.add(self.input_rectangles)
            mobject.add(self.connective_lines)
            mobject.add(self.output_rectangles)

            return mobject

        # An AnimationGroup of per-part Create animations used to follow this
        # return statement; it was unreachable and has been removed.
        return ApplyFunction(add_content, self)

    def make_pulse_animation(self, shift_amount):
        """Make animation of the filter pulsing.

        NOTE(review): ``shift`` is applied to ``self.connective_lines`` itself,
        not to a copy, so calling this moves the stored lines - confirm that
        this is intended before relying on it.
        """
        pulsing_lines = self.connective_lines.shift(shift_amount).set_stroke_width(
            self.stroke_width * 1.5
        )
        return ShowPassingFlash(
            pulsing_lines,
            time_width=0.2,
            color=RED,
            z_index=10,
        )
|
|
|
|
|
|
class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
    """Connective layer between two ``Convolutional3DLayer`` instances.

    Its forward pass animation creates ``Filters`` groups and slides them
    across the input feature maps in a left-to-right, top-to-bottom scan.
    """

    input_class = Convolutional3DLayer
    output_class = Convolutional3DLayer

    def __init__(
        self,
        input_layer: Convolutional3DLayer,
        output_layer: Convolutional3DLayer,
        color=ORANGE,
        filter_opacity=0.3,
        line_color=ORANGE,
        pulse_color=ORANGE,
        cell_width=0.2,
        show_grid_lines=True,
        highlight_color=ORANGE,
        **kwargs,
    ):
        """Store the display options and cache geometry read from ``input_layer``.

        Parameters
        ----------
        input_layer, output_layer
            The two convolutional layers being connected.
        color
            Passed to ``Filters`` as its ``line_color``.
        filter_opacity, line_color, pulse_color
            Stored on the instance; not read elsewhere in this class.
        cell_width
            Grid step handed to ``Filters``.
        show_grid_lines
            Forwarded to ``Filters``.
        highlight_color
            Color applied to feature maps while they are highlighted.
        **kwargs
            Forwarded to the parent constructor.
        """
        super().__init__(
            input_layer,
            output_layer,
            input_class=Convolutional3DLayer,
            output_class=Convolutional3DLayer,
            **kwargs,
        )
        self.color = color
        # Geometry cached from the input layer.
        self.filter_color = self.input_layer.filter_color
        self.filter_width = self.input_layer.filter_width
        self.filter_height = self.input_layer.filter_height
        self.feature_map_width = self.input_layer.feature_map_width
        self.feature_map_height = self.input_layer.feature_map_height
        self.num_input_feature_maps = self.input_layer.num_feature_maps
        self.num_output_feature_maps = self.output_layer.num_feature_maps
        self.stride = self.input_layer.stride
        # Display options.
        self.filter_opacity = filter_opacity
        # NOTE(review): an earlier assignment from ``input_layer.cell_width``
        # was always overwritten by this one and has been dropped; confirm the
        # constructor argument is the value that should win.
        self.cell_width = cell_width
        self.line_color = line_color
        self.pulse_color = pulse_color
        self.show_grid_lines = show_grid_lines
        self.highlight_color = highlight_color

    def get_rotated_shift_vectors(self):
        """
        Rotates the shift vectors

        Returns ``(right_shift, down_shift)``: one-cell steps (using the input
        layer's ``cell_width``) along +x and -y, rotated by the shared
        ``ThreeDLayer`` rotation.
        """
        cell = self.input_layer.cell_width
        # Transposed once so both row vectors can be rotated with np.dot.
        rotation_t = rotation_matrix(
            ThreeDLayer.rotation_angle, ThreeDLayer.rotation_axis
        ).T
        right_shift = np.dot(np.array([cell, 0, 0]), rotation_t)
        down_shift = np.dot(np.array([0, -cell, 0]), rotation_t)

        return right_shift, down_shift

    def _make_filters(self, output_feature_map_to_connect=None):
        """Build a ``Filters`` group using this layer's display options."""
        return Filters(
            self.input_layer,
            self.output_layer,
            line_color=self.color,
            cell_width=self.cell_width,
            show_grid_lines=self.show_grid_lines,
            output_feature_map_to_connect=output_feature_map_to_connect,
        )

    def _make_scan_animations(self, filters, reset_z_index=False):
        """Return the shift animations that scan ``filters`` over the input.

        Each row moves right ``num_x_moves`` times, then jumps back to the
        left edge and one stride down; a final row moves right only. With
        ``reset_z_index`` a ``set_z_index(5)`` animation precedes every right
        shift except those of the final row.
        """
        right_shift, down_shift = self.get_rotated_shift_vectors()
        left_shift = -1 * right_shift
        num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
        num_x_moves = int((self.feature_map_width - self.filter_width) / self.stride)
        step_right = self.stride * right_shift
        # Back to the left edge and one stride down.
        next_row = self.stride * num_x_moves * left_shift + self.stride * down_shift

        animations = []
        for _ in range(num_y_moves):
            for _ in range(num_x_moves):
                if reset_z_index:
                    animations.append(ApplyMethod(filters.set_z_index, 5))
                animations.append(ApplyMethod(filters.shift, step_right))
            animations.append(ApplyMethod(filters.shift, next_row))
        # Last row: move right only.
        animations.extend(
            ApplyMethod(filters.shift, step_right) for _ in range(num_x_moves)
        )
        return animations

    def animate_filters_all_at_once(self, filters=None):
        """Animates each of the filters all at once.

        ``filters`` is accepted only for backward compatibility and is
        ignored: a fresh ``Filters`` group connecting every output feature
        map is always built. It defaults to ``None`` so the method can be
        called without arguments, as ``make_forward_pass_animation`` does.
        """
        filters = self._make_filters(output_feature_map_to_connect=None)
        animations = [Create(filters)]
        animations.extend(self._make_scan_animations(filters))
        # Remove the filters
        animations.append(FadeOut(filters))
        return Succession(*animations, lag_ratio=1.0)

    def animate_filters_one_at_a_time(self, highlight_active_feature_map=False):
        """Animates each of the filters one at a time.

        For every output feature map a ``Filters`` group connected to just
        that map is created, scanned across the input and faded out. With
        ``highlight_active_feature_map`` the active output map and all input
        maps are set to ``highlight_color`` for the duration of the scan and
        afterwards each one is restored to its own previous color.
        """
        animations = []
        output_feature_maps = self.output_layer.feature_maps
        input_feature_maps = self.input_layer.feature_maps

        for feature_map_index, feature_map in enumerate(output_feature_maps):
            filters = self._make_filters(
                output_feature_map_to_connect=feature_map_index
            )
            animations.append(Create(filters))

            if highlight_active_feature_map:
                highlighted = [feature_map, *input_feature_maps]
                # Colors are captured per mobject so input maps are restored
                # to their own color rather than the output map's.
                original_colors = [mobject.color for mobject in highlighted]
                animations.append(
                    AnimationGroup(
                        *[
                            ApplyMethod(mobject.set_color, self.highlight_color)
                            for mobject in highlighted
                        ],
                        lag_ratio=0.0,
                    )
                )

            animations.extend(
                self._make_scan_animations(filters, reset_z_index=True)
            )
            # Remove the filters
            animations.append(FadeOut(filters))

            if highlight_active_feature_map:
                # Un-highlight the feature maps
                animations.append(
                    AnimationGroup(
                        *[
                            ApplyMethod(mobject.set_color, original_color)
                            for mobject, original_color in zip(
                                highlighted, original_colors
                            )
                        ],
                        lag_ratio=0.0,
                    )
                )

        return Succession(*animations, lag_ratio=1.0)

    def make_forward_pass_animation(
        self,
        layer_args=None,
        all_filters_at_once=False,
        highlight_active_feature_map=False,
        run_time=10.5,
        **kwargs,
    ):
        """Forward pass animation between two Convolutional3D layers.

        Parameters
        ----------
        layer_args, run_time, **kwargs
            Accepted for interface compatibility; not used by this method.
        all_filters_at_once
            If true, animate one ``Filters`` group connecting all output
            feature maps; otherwise animate one group per output feature map.
        highlight_active_feature_map
            Only used in the one-at-a-time mode; see
            ``animate_filters_one_at_a_time``.
        """
        if all_filters_at_once:
            return self.animate_filters_all_at_once()
        return self.animate_filters_one_at_a_time(
            highlight_active_feature_map=highlight_active_feature_map
        )

    def scale(self, scale_factor, **kwargs):
        """Scale the layer and keep ``cell_width`` in step with it.

        Returns whatever the parent ``scale`` returns so calls can be chained.
        """
        self.cell_width *= scale_factor
        return super().scale(scale_factor, **kwargs)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Return an empty ``Succession`` as this layer's ``Create`` animation."""
        return Succession()
|