mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-09-19 22:34:28 +08:00
Overall working 3D convolution visualization.
This commit is contained in:
@ -7,8 +7,8 @@ class GriddedRectangle(VGroup):
|
|||||||
def __init__(self, color=ORANGE, height=2.0, width=4.0,
|
def __init__(self, color=ORANGE, height=2.0, width=4.0,
|
||||||
mark_paths_closed=True, close_new_points=True,
|
mark_paths_closed=True, close_new_points=True,
|
||||||
grid_xstep=None, grid_ystep=None, grid_stroke_width=0.0, #DEFAULT_STROKE_WIDTH/2,
|
grid_xstep=None, grid_ystep=None, grid_stroke_width=0.0, #DEFAULT_STROKE_WIDTH/2,
|
||||||
grid_stroke_color=None, grid_stroke_opacity=None,
|
grid_stroke_color=ORANGE, grid_stroke_opacity=1.0,
|
||||||
stroke_width=2.0, fill_opacity=0.2, **kwargs):
|
stroke_width=2.0, fill_opacity=0.2, show_grid_lines=False, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Fields
|
# Fields
|
||||||
self.mark_paths_closed = mark_paths_closed
|
self.mark_paths_closed = mark_paths_closed
|
||||||
@ -17,11 +17,12 @@ class GriddedRectangle(VGroup):
|
|||||||
self.grid_ystep = grid_ystep
|
self.grid_ystep = grid_ystep
|
||||||
self.grid_stroke_width = grid_stroke_width
|
self.grid_stroke_width = grid_stroke_width
|
||||||
self.grid_stroke_color = grid_stroke_color
|
self.grid_stroke_color = grid_stroke_color
|
||||||
self.grid_stroke_opacity = grid_stroke_opacity
|
self.grid_stroke_opacity = grid_stroke_opacity if show_grid_lines else 0.0
|
||||||
self.stroke_width = stroke_width
|
self.stroke_width = stroke_width
|
||||||
self.rotation_angles = [0, 0, 0]
|
self.rotation_angles = [0, 0, 0]
|
||||||
self.rectangle_width = width
|
self.rectangle_width = width
|
||||||
self.rectangle_height = height
|
self.rectangle_height = height
|
||||||
|
self.show_grid_lines = show_grid_lines
|
||||||
# Make rectangle
|
# Make rectangle
|
||||||
self.rectangle = Rectangle(
|
self.rectangle = Rectangle(
|
||||||
width=width,
|
width=width,
|
||||||
@ -29,37 +30,64 @@ class GriddedRectangle(VGroup):
|
|||||||
color=color,
|
color=color,
|
||||||
stroke_width=stroke_width,
|
stroke_width=stroke_width,
|
||||||
fill_color=color,
|
fill_color=color,
|
||||||
fill_opacity=fill_opacity
|
fill_opacity=fill_opacity,
|
||||||
)
|
)
|
||||||
self.add(self.rectangle)
|
self.add(self.rectangle)
|
||||||
|
# Make grid lines
|
||||||
|
grid_lines = self.make_grid_lines()
|
||||||
|
self.add(grid_lines)
|
||||||
|
# Make corner rectangles
|
||||||
|
self.corners_dict = self.make_corners_dict()
|
||||||
|
self.add(*self.corners_dict.values())
|
||||||
|
|
||||||
|
def make_corners_dict(self):
|
||||||
|
"""Make corners dictionary"""
|
||||||
|
corners_dict = {
|
||||||
|
"top_right": Dot(
|
||||||
|
self.rectangle.get_corner([1, 1, 0]),
|
||||||
|
fill_opacity=0.0,
|
||||||
|
radius=0.0
|
||||||
|
),
|
||||||
|
"top_left": Dot(
|
||||||
|
self.rectangle.get_corner([-1, 1, 0]),
|
||||||
|
fill_opacity=0.0,
|
||||||
|
radius=0.0
|
||||||
|
),
|
||||||
|
"bottom_left": Dot(
|
||||||
|
self.rectangle.get_corner([-1, -1, 0]),
|
||||||
|
fill_opacity=0.0,
|
||||||
|
radius=0.0
|
||||||
|
),
|
||||||
|
"bottom_right": Dot(
|
||||||
|
self.rectangle.get_corner([1, -1, 0]),
|
||||||
|
fill_opacity=0.0,
|
||||||
|
radius=0.0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
return corners_dict
|
||||||
|
|
||||||
def get_corners_dict(self):
|
def get_corners_dict(self):
|
||||||
"""Returns a dictionary of the corners"""
|
"""Returns a dictionary of the corners"""
|
||||||
# Sort points through clockwise rotation of a vector in the xy plane
|
# Sort points through clockwise rotation of a vector in the xy plane
|
||||||
return{
|
return self.corners_dict
|
||||||
"top_right": Dot(self.rectangle.get_corner([1, 1, 0])),
|
|
||||||
"top_left": Dot(self.rectangle.get_corner([-1, 1, 0])),
|
|
||||||
"bottom_left": Dot(self.rectangle.get_corner([-1, -1, 0])),
|
|
||||||
"bottom_right": Dot(self.rectangle.get_corner([1, -1, 0])),
|
|
||||||
}
|
|
||||||
|
|
||||||
def make_grid_lines(self):
|
def make_grid_lines(self):
|
||||||
"""Make grid lines in rectangle"""
|
"""Make grid lines in rectangle"""
|
||||||
grid_lines = VGroup()
|
grid_lines = VGroup()
|
||||||
width = self.width
|
|
||||||
height = self.width
|
|
||||||
|
|
||||||
v = self.inner_rectangle.get_vertices()
|
v = self.rectangle.get_vertices()
|
||||||
if self.grid_xstep is not None:
|
if self.grid_xstep is not None:
|
||||||
grid_xstep = abs(self.grid_xstep)
|
grid_xstep = abs(self.grid_xstep)
|
||||||
count = int(width / grid_xstep)
|
count = int(self.width / grid_xstep)
|
||||||
grid = VGroup(
|
grid = VGroup(
|
||||||
*(
|
*(
|
||||||
Line(
|
Line(
|
||||||
v[1] + i * grid_xstep * RIGHT,
|
v[1] + i * grid_xstep * RIGHT,
|
||||||
v[1] + i * grid_xstep * RIGHT + height * DOWN,
|
v[1] + i * grid_xstep * RIGHT + self.height * DOWN,
|
||||||
color=self.color,
|
stroke_color=self.grid_stroke_color,
|
||||||
stroke_width=self.grid_stroke_width
|
stroke_width=self.grid_stroke_width,
|
||||||
|
stroke_opacity = self.grid_stroke_opacity
|
||||||
)
|
)
|
||||||
for i in range(1, count)
|
for i in range(1, count)
|
||||||
)
|
)
|
||||||
@ -68,14 +96,15 @@ class GriddedRectangle(VGroup):
|
|||||||
|
|
||||||
if self.grid_ystep is not None:
|
if self.grid_ystep is not None:
|
||||||
grid_ystep = abs(self.grid_ystep)
|
grid_ystep = abs(self.grid_ystep)
|
||||||
count = int(height / grid_ystep)
|
count = int(self.height / grid_ystep)
|
||||||
grid = VGroup(
|
grid = VGroup(
|
||||||
*(
|
*(
|
||||||
Line(
|
Line(
|
||||||
v[1] + i * grid_ystep * DOWN,
|
v[1] + i * grid_ystep * DOWN,
|
||||||
v[1] + i * grid_ystep * DOWN + width * RIGHT,
|
v[1] + i * grid_ystep * DOWN + self.width * RIGHT,
|
||||||
color=self.color,
|
stroke_color=self.grid_stroke_color,
|
||||||
stroke_width = self.grid_stroke_width
|
stroke_width = self.grid_stroke_width,
|
||||||
|
stroke_opacity = self.grid_stroke_opacity
|
||||||
)
|
)
|
||||||
for i in range(1, count)
|
for i in range(1, count)
|
||||||
)
|
)
|
||||||
@ -86,3 +115,12 @@ class GriddedRectangle(VGroup):
|
|||||||
|
|
||||||
def get_center(self):
|
def get_center(self):
|
||||||
return self.rectangle.get_center()
|
return self.rectangle.get_center()
|
||||||
|
|
||||||
|
def get_normal_vector(self):
|
||||||
|
vertex_1 = self.rectangle.get_vertices()[0]
|
||||||
|
vertex_2 = self.rectangle.get_vertices()[1]
|
||||||
|
vertex_3 = self.rectangle.get_vertices()[2]
|
||||||
|
# First vector
|
||||||
|
normal_vector = np.cross((vertex_1 - vertex_2), (vertex_1 - vertex_3))
|
||||||
|
|
||||||
|
return normal_vector
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from manim_ml.neural_network.layers.convolutional_3d_to_feed_forward import Convolutional3DToFeedForward
|
||||||
from manim_ml.neural_network.layers.image_to_convolutional3d import ImageToConvolutional3DLayer
|
from manim_ml.neural_network.layers.image_to_convolutional3d import ImageToConvolutional3DLayer
|
||||||
from .convolutional3d_to_convolutional3d import Convolutional3DToConvolutional3D
|
from .convolutional3d_to_convolutional3d import Convolutional3DToConvolutional3D
|
||||||
from .convolutional2d_to_convolutional2d import Convolutional2DToConvolutional2D
|
from .convolutional2d_to_convolutional2d import Convolutional2DToConvolutional2D
|
||||||
@ -32,4 +33,5 @@ connective_layers_list = (
|
|||||||
Convolutional3DToConvolutional3D,
|
Convolutional3DToConvolutional3D,
|
||||||
Convolutional2DToConvolutional2D,
|
Convolutional2DToConvolutional2D,
|
||||||
ImageToConvolutional3DLayer,
|
ImageToConvolutional3DLayer,
|
||||||
|
Convolutional3DToFeedForward
|
||||||
)
|
)
|
||||||
|
@ -62,8 +62,7 @@ class Convolutional2DToConvolutional2D(ConnectiveLayer):
|
|||||||
|
|
||||||
def make_filter_propagation_animation(self):
|
def make_filter_propagation_animation(self):
|
||||||
"""Make filter propagation animation"""
|
"""Make filter propagation animation"""
|
||||||
old_z_index = self.filter_lines.z_index
|
lines_copy = self.filter_lines.copy().set_color(ORANGE)
|
||||||
lines_copy = self.filter_lines.copy().set_color(ORANGE).set_z_index(old_z_index + 1)
|
|
||||||
animation_group = AnimationGroup(
|
animation_group = AnimationGroup(
|
||||||
Create(lines_copy, lag_ratio=0.0),
|
Create(lines_copy, lag_ratio=0.0),
|
||||||
# FadeOut(self.filter_lines),
|
# FadeOut(self.filter_lines),
|
||||||
|
@ -8,7 +8,7 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
|
|
||||||
def __init__(self, num_feature_maps, feature_map_width, feature_map_height,
|
def __init__(self, num_feature_maps, feature_map_width, feature_map_height,
|
||||||
filter_width, filter_height, cell_width=0.2, filter_spacing=0.1, color=BLUE,
|
filter_width, filter_height, cell_width=0.2, filter_spacing=0.1, color=BLUE,
|
||||||
pulse_color=ORANGE, filter_color=ORANGE, stride=1, stroke_width=2.0, **kwargs):
|
pulse_color=ORANGE, show_grid_lines=False, filter_color=ORANGE, stride=1, stroke_width=2.0, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.num_feature_maps = num_feature_maps
|
self.num_feature_maps = num_feature_maps
|
||||||
self.feature_map_height = feature_map_height
|
self.feature_map_height = feature_map_height
|
||||||
@ -22,20 +22,24 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
self.pulse_color = pulse_color
|
self.pulse_color = pulse_color
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.stroke_width = stroke_width
|
self.stroke_width = stroke_width
|
||||||
|
self.show_grid_lines = show_grid_lines
|
||||||
# Make the feature maps
|
# Make the feature maps
|
||||||
self.feature_maps = self.construct_feature_maps()
|
self.feature_maps = self.construct_feature_maps()
|
||||||
self.add(self.feature_maps)
|
self.add(self.feature_maps)
|
||||||
# Rotate stuff properly
|
# Rotate stuff properly
|
||||||
|
# normal_vector = self.feature_maps[0].get_normal_vector()
|
||||||
self.rotate(
|
self.rotate(
|
||||||
ThreeDLayer.three_d_x_rotation,
|
ThreeDLayer.rotation_angle,
|
||||||
about_point=self.get_center(),
|
about_point=self.get_center(),
|
||||||
axis=[1, 0, 0]
|
axis=ThreeDLayer.rotation_axis
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
self.rotate(
|
self.rotate(
|
||||||
ThreeDLayer.three_d_y_rotation,
|
ThreeDLayer.three_d_y_rotation,
|
||||||
about_point=self.get_center(),
|
about_point=self.get_center(),
|
||||||
axis=[0, 1, 0]
|
axis=[0, 1, 0]
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
def construct_feature_maps(self):
|
def construct_feature_maps(self):
|
||||||
"""Creates the neural network layer"""
|
"""Creates the neural network layer"""
|
||||||
@ -50,14 +54,17 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
fill_opacity=0.2,
|
fill_opacity=0.2,
|
||||||
stroke_color=self.color,
|
stroke_color=self.color,
|
||||||
stroke_width=self.stroke_width,
|
stroke_width=self.stroke_width,
|
||||||
# grid_xstep=self.cell_width,
|
grid_xstep=self.cell_width,
|
||||||
# grid_ystep=self.cell_width,
|
grid_ystep=self.cell_width,
|
||||||
# grid_stroke_width=DEFAULT_STROKE_WIDTH/2
|
grid_stroke_width=self.stroke_width/2,
|
||||||
|
grid_stroke_color=self.color,
|
||||||
|
show_grid_lines=self.show_grid_lines,
|
||||||
)
|
)
|
||||||
# Move the feature map
|
# Move the feature map
|
||||||
rectangle.move_to(
|
rectangle.move_to(
|
||||||
[0, 0, filter_index * self.filter_spacing]
|
[0, 0, filter_index * self.filter_spacing]
|
||||||
)
|
)
|
||||||
|
rectangle.set_z_index(4)
|
||||||
feature_maps.append(rectangle)
|
feature_maps.append(rectangle)
|
||||||
|
|
||||||
return VGroup(*feature_maps)
|
return VGroup(*feature_maps)
|
||||||
|
@ -13,19 +13,28 @@ class Filters(VGroup):
|
|||||||
input_layer,
|
input_layer,
|
||||||
output_layer,
|
output_layer,
|
||||||
line_color=ORANGE,
|
line_color=ORANGE,
|
||||||
|
cell_width=1.0,
|
||||||
stroke_width=2.0,
|
stroke_width=2.0,
|
||||||
|
show_grid_lines=False,
|
||||||
|
output_feature_map_to_connect=None # None means all at once
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_layer = input_layer
|
self.input_layer = input_layer
|
||||||
self.output_layer = output_layer
|
self.output_layer = output_layer
|
||||||
self.line_color = line_color
|
self.line_color = line_color
|
||||||
|
self.cell_width = cell_width
|
||||||
self.stroke_width = stroke_width
|
self.stroke_width = stroke_width
|
||||||
|
self.show_grid_lines = show_grid_lines
|
||||||
|
self.output_feature_map_to_connect = output_feature_map_to_connect
|
||||||
# Make the filter
|
# Make the filter
|
||||||
self.input_rectangles = self.make_input_feature_map_rectangles()
|
self.input_rectangles = self.make_input_feature_map_rectangles()
|
||||||
|
# self.input_rectangles.set_z_index(5)
|
||||||
# self.add(self.input_rectangles)
|
# self.add(self.input_rectangles)
|
||||||
self.output_rectangles = self.make_output_feature_map_rectangles()
|
self.output_rectangles = self.make_output_feature_map_rectangles()
|
||||||
|
# self.output_rectangles.set_z_index(5)
|
||||||
# self.add(self.output_rectangles)
|
# self.add(self.output_rectangles)
|
||||||
self.connective_lines = self.make_connective_lines()
|
self.connective_lines = self.make_connective_lines()
|
||||||
|
# self.connective_lines.set_z_index(5)
|
||||||
# self.add(self.connective_lines)
|
# self.add(self.connective_lines)
|
||||||
|
|
||||||
def make_input_feature_map_rectangles(self):
|
def make_input_feature_map_rectangles(self):
|
||||||
@ -42,24 +51,27 @@ class Filters(VGroup):
|
|||||||
fill_color=filter_color,
|
fill_color=filter_color,
|
||||||
stroke_color=filter_color,
|
stroke_color=filter_color,
|
||||||
fill_opacity=0.2,
|
fill_opacity=0.2,
|
||||||
z_index=2,
|
|
||||||
stroke_width=self.stroke_width,
|
stroke_width=self.stroke_width,
|
||||||
|
grid_xstep=self.cell_width,
|
||||||
|
grid_ystep=self.cell_width,
|
||||||
|
grid_stroke_width=self.stroke_width / 2,
|
||||||
|
grid_stroke_color=filter_color,
|
||||||
|
show_grid_lines=self.show_grid_lines,
|
||||||
)
|
)
|
||||||
|
# normal_vector = rectangle.get_normal_vector()
|
||||||
rectangle.rotate(
|
rectangle.rotate(
|
||||||
ThreeDLayer.three_d_x_rotation,
|
ThreeDLayer.rotation_angle,
|
||||||
about_point=rectangle.get_center(),
|
about_point=rectangle.get_center(),
|
||||||
axis=[1, 0, 0]
|
axis=ThreeDLayer.rotation_axis
|
||||||
)
|
|
||||||
rectangle.rotate(
|
|
||||||
ThreeDLayer.three_d_y_rotation,
|
|
||||||
about_point=rectangle.get_center(),
|
|
||||||
axis=[0, 1, 0]
|
|
||||||
)
|
)
|
||||||
# Move the rectangle to the corner of the feature map
|
# Move the rectangle to the corner of the feature map
|
||||||
rectangle.move_to(
|
rectangle.next_to(
|
||||||
feature_map,
|
feature_map.get_corners_dict()["top_left"],
|
||||||
aligned_edge=np.array([-1, 1, 0])
|
submobject_to_align=rectangle.get_corners_dict()["top_left"],
|
||||||
|
buff=0.0
|
||||||
|
# aligned_edge=feature_map.get_corners_dict()["top_left"].get_center()
|
||||||
)
|
)
|
||||||
|
rectangle.set_z_index(5)
|
||||||
|
|
||||||
rectangles.append(rectangle)
|
rectangles.append(rectangle)
|
||||||
|
|
||||||
@ -75,32 +87,36 @@ class Filters(VGroup):
|
|||||||
filter_color = self.output_layer.filter_color
|
filter_color = self.output_layer.filter_color
|
||||||
|
|
||||||
for index, feature_map in enumerate(self.output_layer.feature_maps):
|
for index, feature_map in enumerate(self.output_layer.feature_maps):
|
||||||
|
# Make sure current feature map is the right filte
|
||||||
|
if not self.output_feature_map_to_connect is None:
|
||||||
|
if index != self.output_feature_map_to_connect:
|
||||||
|
continue
|
||||||
|
# Make the rectangle
|
||||||
rectangle = GriddedRectangle(
|
rectangle = GriddedRectangle(
|
||||||
width=rectangle_width,
|
width=rectangle_width,
|
||||||
height=rectangle_height,
|
height=rectangle_height,
|
||||||
fill_color=filter_color,
|
fill_color=filter_color,
|
||||||
stroke_color=filter_color,
|
|
||||||
fill_opacity=0.2,
|
fill_opacity=0.2,
|
||||||
|
stroke_color=filter_color,
|
||||||
stroke_width=self.stroke_width,
|
stroke_width=self.stroke_width,
|
||||||
z_index=2,
|
grid_xstep=self.cell_width,
|
||||||
|
grid_ystep=self.cell_width,
|
||||||
|
grid_stroke_width=self.stroke_width/2,
|
||||||
|
grid_stroke_color=filter_color,
|
||||||
|
show_grid_lines=self.show_grid_lines,
|
||||||
)
|
)
|
||||||
# Center on feature map
|
|
||||||
# rectangle.move_to(feature_map.get_center())
|
|
||||||
# Rotate the rectangle
|
# Rotate the rectangle
|
||||||
rectangle.rotate(
|
rectangle.rotate(
|
||||||
ThreeDLayer.three_d_x_rotation,
|
ThreeDLayer.rotation_angle,
|
||||||
about_point=rectangle.get_center(),
|
about_point=rectangle.get_center(),
|
||||||
axis=[1, 0, 0]
|
axis=ThreeDLayer.rotation_axis
|
||||||
)
|
|
||||||
rectangle.rotate(
|
|
||||||
ThreeDLayer.three_d_y_rotation,
|
|
||||||
about_point=rectangle.get_center(),
|
|
||||||
axis=[0, 1, 0]
|
|
||||||
)
|
)
|
||||||
# Move the rectangle to the corner location
|
# Move the rectangle to the corner location
|
||||||
rectangle.move_to(
|
rectangle.next_to(
|
||||||
feature_map,
|
feature_map.get_corners_dict()["top_left"],
|
||||||
aligned_edge=np.array([-1, 1, 0])
|
submobject_to_align=rectangle.get_corners_dict()["top_left"],
|
||||||
|
buff=0.0
|
||||||
|
# aligned_edge=feature_map.get_corners_dict()["top_left"].get_center()
|
||||||
)
|
)
|
||||||
rectangles.append(rectangle)
|
rectangles.append(rectangle)
|
||||||
|
|
||||||
@ -127,7 +143,7 @@ class Filters(VGroup):
|
|||||||
first_input_corners[corner_name].get_center(),
|
first_input_corners[corner_name].get_center(),
|
||||||
last_input_corners[corner_name].get_center(),
|
last_input_corners[corner_name].get_center(),
|
||||||
color=self.line_color,
|
color=self.line_color,
|
||||||
stroke_width=self.stroke_width
|
stroke_width=self.stroke_width,
|
||||||
)
|
)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
|
|
||||||
@ -147,7 +163,7 @@ class Filters(VGroup):
|
|||||||
first_output_corners[corner_name].get_center(),
|
first_output_corners[corner_name].get_center(),
|
||||||
last_output_corners[corner_name].get_center(),
|
last_output_corners[corner_name].get_center(),
|
||||||
color=self.line_color,
|
color=self.line_color,
|
||||||
stroke_width=self.stroke_width
|
stroke_width=self.stroke_width,
|
||||||
)
|
)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
|
|
||||||
@ -155,19 +171,20 @@ class Filters(VGroup):
|
|||||||
|
|
||||||
def make_input_to_output_connective_lines():
|
def make_input_to_output_connective_lines():
|
||||||
"""Make connective lines between last input filter and first output filter"""
|
"""Make connective lines between last input filter and first output filter"""
|
||||||
last_input_rectangle = self.input_rectangles[-1]
|
# Choose the correct feature map to link to
|
||||||
first_output_rectangle = self.output_rectangles[0]
|
input_rectangle = self.input_rectangles[-1]
|
||||||
|
output_rectangle = self.output_rectangles[0]
|
||||||
# Get the corner dots for each rectangle
|
# Get the corner dots for each rectangle
|
||||||
last_input_corners = last_input_rectangle.get_corners_dict()
|
input_corners = input_rectangle.get_corners_dict()
|
||||||
first_output_corners = first_output_rectangle.get_corners_dict()
|
output_corners = output_rectangle.get_corners_dict()
|
||||||
# Iterate through each corner and make the lines
|
# Iterate through each corner and make the lines
|
||||||
lines = []
|
lines = []
|
||||||
for corner_name in corner_names:
|
for corner_name in corner_names:
|
||||||
line = Line(
|
line = Line(
|
||||||
last_input_corners[corner_name].get_center(),
|
input_corners[corner_name].get_center(),
|
||||||
first_output_corners[corner_name].get_center(),
|
output_corners[corner_name].get_center(),
|
||||||
color=self.line_color,
|
color=self.line_color,
|
||||||
stroke_width=self.stroke_width
|
stroke_width=self.stroke_width,
|
||||||
)
|
)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
|
|
||||||
@ -211,7 +228,23 @@ class Filters(VGroup):
|
|||||||
add_content,
|
add_content,
|
||||||
self
|
self
|
||||||
)
|
)
|
||||||
|
return AnimationGroup(
|
||||||
|
Create(self.input_rectangles),
|
||||||
|
Create(self.connective_lines),
|
||||||
|
Create(self.output_rectangles),
|
||||||
|
lag_ratio=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def make_pulse_animation(self, shift_amount):
|
||||||
|
"""Make animation of the filter pulsing"""
|
||||||
|
passing_flash = ShowPassingFlash(
|
||||||
|
self.connective_lines.shift(shift_amount).set_stroke_width(self.stroke_width*1.5),
|
||||||
|
time_width=0.2,
|
||||||
|
color=RED,
|
||||||
|
z_index=10
|
||||||
|
)
|
||||||
|
|
||||||
|
return passing_flash
|
||||||
|
|
||||||
class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||||
"""Feed Forward to Embedding Layer"""
|
"""Feed Forward to Embedding Layer"""
|
||||||
@ -219,8 +252,8 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
output_class = Convolutional3DLayer
|
output_class = Convolutional3DLayer
|
||||||
|
|
||||||
def __init__(self, input_layer: Convolutional3DLayer, output_layer: Convolutional3DLayer,
|
def __init__(self, input_layer: Convolutional3DLayer, output_layer: Convolutional3DLayer,
|
||||||
color=WHITE, filter_opacity=0.3, line_color=WHITE,
|
color=ORANGE, filter_opacity=0.3, line_color=ORANGE,
|
||||||
pulse_color=ORANGE, **kwargs):
|
pulse_color=ORANGE, cell_width=0.2, show_grid_lines=True, **kwargs):
|
||||||
super().__init__(input_layer, output_layer, input_class=Convolutional3DLayer,
|
super().__init__(input_layer, output_layer, input_class=Convolutional3DLayer,
|
||||||
output_class=Convolutional3DLayer, **kwargs)
|
output_class=Convolutional3DLayer, **kwargs)
|
||||||
self.color = color
|
self.color = color
|
||||||
@ -234,69 +267,48 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
self.cell_width = self.input_layer.cell_width
|
self.cell_width = self.input_layer.cell_width
|
||||||
self.stride = self.input_layer.stride
|
self.stride = self.input_layer.stride
|
||||||
self.filter_opacity = filter_opacity
|
self.filter_opacity = filter_opacity
|
||||||
|
self.cell_width = cell_width
|
||||||
self.line_color = line_color
|
self.line_color = line_color
|
||||||
self.pulse_color = pulse_color
|
self.pulse_color = pulse_color
|
||||||
|
self.show_grid_lines = show_grid_lines
|
||||||
def make_filter_propagation_animation(self):
|
|
||||||
"""Make filter propagation animation"""
|
|
||||||
# TODO implement this
|
|
||||||
raise NotImplementedError()
|
|
||||||
# Deprecated code
|
|
||||||
old_z_index = self.filter_lines.z_index
|
|
||||||
lines_copy = self.filter_lines.copy().set_color(ORANGE).set_z_index(old_z_index + 1)
|
|
||||||
animation_group = AnimationGroup(
|
|
||||||
Create(lines_copy, lag_ratio=0.0),
|
|
||||||
# FadeOut(self.filter_lines),
|
|
||||||
FadeOut(lines_copy),
|
|
||||||
lag_ratio=1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
return animation_group
|
|
||||||
|
|
||||||
def get_rotated_shift_vectors(self):
|
def get_rotated_shift_vectors(self):
|
||||||
"""
|
"""
|
||||||
Rotates the shift vectors
|
Rotates the shift vectors
|
||||||
"""
|
"""
|
||||||
x_rot_mat = rotation_matrix(
|
|
||||||
ThreeDLayer.three_d_x_rotation,
|
|
||||||
[1, 0, 0]
|
|
||||||
)
|
|
||||||
y_rot_mat = rotation_matrix(
|
|
||||||
ThreeDLayer.three_d_y_rotation,
|
|
||||||
[0, 1, 0]
|
|
||||||
)
|
|
||||||
# Make base shift vectors
|
# Make base shift vectors
|
||||||
right_shift = np.array([self.input_layer.cell_width, 0, 0])
|
right_shift = np.array([self.input_layer.cell_width, 0, 0])
|
||||||
down_shift = np.array([0, -self.input_layer.cell_width, 0])
|
down_shift = np.array([0, -self.input_layer.cell_width, 0])
|
||||||
|
# Make rotation matrix
|
||||||
|
rot_mat = rotation_matrix(
|
||||||
|
ThreeDLayer.rotation_angle,
|
||||||
|
ThreeDLayer.rotation_axis
|
||||||
|
)
|
||||||
# Rotate the vectors
|
# Rotate the vectors
|
||||||
right_shift = np.dot(right_shift, x_rot_mat.T)
|
right_shift = np.dot(right_shift, rot_mat.T)
|
||||||
right_shift = np.dot(right_shift, y_rot_mat.T)
|
down_shift = np.dot(down_shift, rot_mat.T)
|
||||||
down_shift = np.dot(down_shift, x_rot_mat.T)
|
|
||||||
down_shift = np.dot(down_shift, y_rot_mat.T)
|
|
||||||
|
|
||||||
return right_shift, down_shift
|
return right_shift, down_shift
|
||||||
|
|
||||||
def make_forward_pass_animation(self, layer_args={},
|
def animate_filters_all_at_once(self, filters):
|
||||||
all_filters_at_once=False, run_time=10.5, **kwargs):
|
"""Animates each of the filters all at once"""
|
||||||
"""Forward pass animation from conv2d to conv2d"""
|
|
||||||
animations = []
|
animations = []
|
||||||
# Make filters
|
# Make filters
|
||||||
filters = Filters(self.input_layer, self.output_layer)
|
filters = Filters(
|
||||||
filters.set_z_index(self.input_layer.feature_maps[0].get_z_index() + 1)
|
self.input_layer,
|
||||||
# self.add(filters)
|
self.output_layer,
|
||||||
|
line_color=self.color,
|
||||||
|
cell_width=self.cell_width,
|
||||||
|
show_grid_lines=self.show_grid_lines,
|
||||||
|
output_feature_map_to_connect=None # None means all at once
|
||||||
|
)
|
||||||
animations.append(
|
animations.append(
|
||||||
Create(filters)
|
Create(filters)
|
||||||
)
|
)
|
||||||
# Get shift vectors
|
# Get the rotated shift vectors
|
||||||
right_shift, down_shift = self.get_rotated_shift_vectors()
|
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||||
left_shift = -1 * right_shift
|
left_shift = -1 * right_shift
|
||||||
# filters.rotate(ThreeDLayer.three_d_theta, axis=[0, 0, 1])
|
# Make the animation
|
||||||
# filters.rotate(ThreeDLayer.three_d_phi, axis=-filters.get_center())
|
|
||||||
# Make animations for creating the filters, output_nodes, and filter_lines
|
|
||||||
# TODO decide if I want to create the filters at the start of a conv
|
|
||||||
# animation or have them there by default
|
|
||||||
# Rotate the base shift vectors
|
|
||||||
# Make filter shifting animations
|
|
||||||
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
|
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
|
||||||
num_x_moves = int((self.feature_map_width - self.filter_width) / self.stride)
|
num_x_moves = int((self.feature_map_width - self.filter_width) / self.stride)
|
||||||
for y_move in range(num_y_moves):
|
for y_move in range(num_y_moves):
|
||||||
@ -331,15 +343,93 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
animations.append(
|
animations.append(
|
||||||
FadeOut(filters)
|
FadeOut(filters)
|
||||||
)
|
)
|
||||||
# Remove filters
|
|
||||||
return Succession(
|
return Succession(
|
||||||
*animations,
|
*animations,
|
||||||
lag_ratio=1.0
|
lag_ratio=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_z_index(self, z_index, family=False):
|
def animate_filters_one_at_a_time(self):
|
||||||
"""Override set_z_index"""
|
"""Animates each of the filters one at a time"""
|
||||||
super().set_z_index(4)
|
animations = []
|
||||||
|
output_feature_maps = self.output_layer.feature_maps
|
||||||
|
for filter_index in range(len(output_feature_maps)):
|
||||||
|
# Make filters
|
||||||
|
filters = Filters(
|
||||||
|
self.input_layer,
|
||||||
|
self.output_layer,
|
||||||
|
line_color=self.color,
|
||||||
|
cell_width=self.cell_width,
|
||||||
|
show_grid_lines=self.show_grid_lines,
|
||||||
|
output_feature_map_to_connect=filter_index # None means all at once
|
||||||
|
)
|
||||||
|
animations.append(
|
||||||
|
Create(filters)
|
||||||
|
)
|
||||||
|
# Get the rotated shift vectors
|
||||||
|
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||||
|
left_shift = -1 * right_shift
|
||||||
|
# Make the animation
|
||||||
|
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
|
||||||
|
num_x_moves = int((self.feature_map_width - self.filter_width) / self.stride)
|
||||||
|
for y_move in range(num_y_moves):
|
||||||
|
# Go right num_x_moves
|
||||||
|
for x_move in range(num_x_moves):
|
||||||
|
# Make a pulse animation for the corners
|
||||||
|
"""
|
||||||
|
pulse_animation = filters.make_pulse_animation(
|
||||||
|
shift_amount=shift_amount
|
||||||
|
)
|
||||||
|
animations.append(pulse_animation)
|
||||||
|
"""
|
||||||
|
z_index_animation = ApplyMethod(
|
||||||
|
filters.set_z_index,
|
||||||
|
5
|
||||||
|
)
|
||||||
|
animations.append(z_index_animation)
|
||||||
|
# Shift right
|
||||||
|
shift_animation = ApplyMethod(
|
||||||
|
filters.shift,
|
||||||
|
self.stride * right_shift
|
||||||
|
)
|
||||||
|
# shift_animation = self.animate.shift(right_shift)
|
||||||
|
animations.append(shift_animation)
|
||||||
|
|
||||||
|
# Go back left num_x_moves and down one
|
||||||
|
shift_amount = self.stride * num_x_moves * left_shift + self.stride * down_shift
|
||||||
|
# Make the animation
|
||||||
|
shift_animation = ApplyMethod(
|
||||||
|
filters.shift,
|
||||||
|
shift_amount
|
||||||
|
)
|
||||||
|
animations.append(shift_animation)
|
||||||
|
# Do last row move right
|
||||||
|
for x_move in range(num_x_moves):
|
||||||
|
# Shift right
|
||||||
|
shift_animation = ApplyMethod(
|
||||||
|
filters.shift,
|
||||||
|
self.stride * right_shift
|
||||||
|
)
|
||||||
|
# shift_animation = self.animate.shift(right_shift)
|
||||||
|
animations.append(shift_animation)
|
||||||
|
# Remove the filters
|
||||||
|
animations.append(
|
||||||
|
FadeOut(filters)
|
||||||
|
)
|
||||||
|
|
||||||
|
return Succession(
|
||||||
|
*animations,
|
||||||
|
lag_ratio=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def make_forward_pass_animation(self, layer_args={},
|
||||||
|
all_filters_at_once=False, run_time=10.5, **kwargs):
|
||||||
|
"""Forward pass animation from conv2d to conv2d"""
|
||||||
|
print(f"All filters at once: {all_filters_at_once}")
|
||||||
|
# Make filter shifting animations
|
||||||
|
if all_filters_at_once:
|
||||||
|
return self.animate_filters_all_at_once()
|
||||||
|
else:
|
||||||
|
return self.animate_filters_one_at_a_time()
|
||||||
|
|
||||||
def scale(self, scale_factor, **kwargs):
|
def scale(self, scale_factor, **kwargs):
|
||||||
self.cell_width *= scale_factor
|
self.cell_width *= scale_factor
|
||||||
|
@ -0,0 +1,37 @@
|
|||||||
|
from manim import *
|
||||||
|
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer, ThreeDLayer
|
||||||
|
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||||
|
from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
|
||||||
|
|
||||||
|
class Convolutional3DToFeedForward(ConnectiveLayer, ThreeDLayer):
|
||||||
|
"""Feed Forward to Embedding Layer"""
|
||||||
|
input_class = Convolutional3DLayer
|
||||||
|
output_class = FeedForwardLayer
|
||||||
|
|
||||||
|
def __init__(self, input_layer: Convolutional3DLayer, output_layer: FeedForwardLayer,
|
||||||
|
passing_flash_color=ORANGE, **kwargs):
|
||||||
|
super().__init__(input_layer, output_layer, input_class=Convolutional3DLayer,
|
||||||
|
output_class=Convolutional3DLayer, **kwargs)
|
||||||
|
self.passing_flash_color = passing_flash_color
|
||||||
|
|
||||||
|
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||||
|
"""Forward pass animation from conv2d to conv2d"""
|
||||||
|
animations = []
|
||||||
|
# Get input layer final feature map
|
||||||
|
final_feature_map = self.input_layer.feature_maps[-1]
|
||||||
|
# Get output layer nodes
|
||||||
|
feed_forward_nodes = self.output_layer.node_group
|
||||||
|
# Go through each corner
|
||||||
|
corners = final_feature_map.get_corners_dict().values()
|
||||||
|
for corner in corners:
|
||||||
|
# Go through each node
|
||||||
|
for node in feed_forward_nodes:
|
||||||
|
line = Line(corner, node, stroke_width=1.0)
|
||||||
|
line.set_z_index(self.output_layer.node_group.get_z_index())
|
||||||
|
anim = ShowPassingFlash(
|
||||||
|
line.set_color(self.passing_flash_color),
|
||||||
|
time_width=0.2
|
||||||
|
)
|
||||||
|
animations.append(anim)
|
||||||
|
|
||||||
|
return AnimationGroup(*animations)
|
@ -48,6 +48,14 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
target_feature_map = self.output_layer.feature_maps[0]
|
target_feature_map = self.output_layer.feature_maps[0]
|
||||||
# Map image mobject to feature map
|
# Map image mobject to feature map
|
||||||
# Make rotation of image
|
# Make rotation of image
|
||||||
|
rotation = ApplyMethod(
|
||||||
|
image_mobject.rotate,
|
||||||
|
ThreeDLayer.rotation_angle,
|
||||||
|
ThreeDLayer.rotation_axis,
|
||||||
|
image_mobject.get_center(),
|
||||||
|
run_time=0.5
|
||||||
|
)
|
||||||
|
"""
|
||||||
x_rotation = ApplyMethod(
|
x_rotation = ApplyMethod(
|
||||||
image_mobject.rotate,
|
image_mobject.rotate,
|
||||||
ThreeDLayer.three_d_x_rotation,
|
ThreeDLayer.three_d_x_rotation,
|
||||||
@ -62,6 +70,7 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
image_mobject.get_center(),
|
image_mobject.get_center(),
|
||||||
run_time=0.5
|
run_time=0.5
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
# Set opacity
|
# Set opacity
|
||||||
set_opacity = ApplyMethod(
|
set_opacity = ApplyMethod(
|
||||||
image_mobject.set_opacity,
|
image_mobject.set_opacity,
|
||||||
@ -84,42 +93,12 @@ class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
)
|
)
|
||||||
# Compose the animations
|
# Compose the animations
|
||||||
animation = Succession(
|
animation = Succession(
|
||||||
x_rotation,
|
rotation,
|
||||||
y_rotation,
|
|
||||||
scale_image,
|
scale_image,
|
||||||
set_opacity,
|
set_opacity,
|
||||||
move_image,
|
move_image,
|
||||||
)
|
)
|
||||||
return animation
|
return animation
|
||||||
"""
|
|
||||||
# Make the object 3D by adding it back into camera frame
|
|
||||||
def remove_fixed_func(image_mobject):
|
|
||||||
# self.camera.remove_fixed_orientation_mobjects(image_mobject)
|
|
||||||
# self.camera.remove_fixed_in_frame_mobjects(image_mobject)
|
|
||||||
return image_mobject
|
|
||||||
|
|
||||||
remove_fixed = ApplyFunction(
|
|
||||||
remove_fixed_func,
|
|
||||||
image_mobject
|
|
||||||
)
|
|
||||||
animations.append(remove_fixed)
|
|
||||||
# Make a transformation of the image_mobject to the first feature map
|
|
||||||
input_to_feature_map_transformation = Transform(image_mobject, target_feature_map)
|
|
||||||
animations.append(input_to_feature_map_transformation)
|
|
||||||
# Make the object fixed in 2D again
|
|
||||||
def make_fixed_func(image_mobject):
|
|
||||||
# self.camera.add_fixed_orientation_mobjects(image_mobject)
|
|
||||||
# self.camera.add_fixed_in_frame_mobjects(image_mobject)
|
|
||||||
return image_mobject
|
|
||||||
|
|
||||||
make_fixed = ApplyFunction(
|
|
||||||
make_fixed_func,
|
|
||||||
image_mobject
|
|
||||||
)
|
|
||||||
animations.append(make_fixed)
|
|
||||||
|
|
||||||
return AnimationGroup(*animations)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def scale(self, scale_factor, **kwargs):
|
def scale(self, scale_factor, **kwargs):
|
||||||
super().scale(scale_factor, **kwargs)
|
super().scale(scale_factor, **kwargs)
|
||||||
|
@ -42,8 +42,10 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
|
|||||||
class ThreeDLayer(ABC):
|
class ThreeDLayer(ABC):
|
||||||
"""Abstract class for 3D layers"""
|
"""Abstract class for 3D layers"""
|
||||||
# Angle of ThreeD layers is static context
|
# Angle of ThreeD layers is static context
|
||||||
three_d_x_rotation = 0 * DEGREES #-90 * DEGREES
|
three_d_x_rotation = 90 * DEGREES #-90 * DEGREES
|
||||||
three_d_y_rotation = 75 * DEGREES # -10 * DEGREES
|
three_d_y_rotation = 0 * DEGREES # -10 * DEGREES
|
||||||
|
rotation_angle = 60 * DEGREES
|
||||||
|
rotation_axis = [0.1, 0.9, 0]
|
||||||
|
|
||||||
class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
||||||
"""Forward pass animation for a given pair of layers"""
|
"""Forward pass animation for a given pair of layers"""
|
||||||
|
@ -17,9 +17,8 @@ class CombinedScene(ThreeDScene):
|
|||||||
image = Image.open('../assets/mnist/digit.jpeg')
|
image = Image.open('../assets/mnist/digit.jpeg')
|
||||||
numpy_image = np.asarray(image)
|
numpy_image = np.asarray(image)
|
||||||
# Make nn
|
# Make nn
|
||||||
nn = NeuralNetwork(
|
nn = NeuralNetwork([
|
||||||
[
|
ImageLayer(numpy_image, height=1.5),
|
||||||
ImageLayer(numpy_image, height=2.0),
|
|
||||||
Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
|
Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
|
||||||
Convolutional3DLayer(3, 5, 5, 3, 3, filter_spacing=0.32),
|
Convolutional3DLayer(3, 5, 5, 3, 3, filter_spacing=0.32),
|
||||||
Convolutional3DLayer(5, 3, 3, 1, 1, filter_spacing=0.18),
|
Convolutional3DLayer(5, 3, 3, 1, 1, filter_spacing=0.18),
|
||||||
@ -27,18 +26,19 @@ class CombinedScene(ThreeDScene):
|
|||||||
FeedForwardLayer(3),
|
FeedForwardLayer(3),
|
||||||
],
|
],
|
||||||
layer_spacing=0.25,
|
layer_spacing=0.25,
|
||||||
# camera=self.camera
|
|
||||||
)
|
)
|
||||||
# Center the nn
|
# Center the nn
|
||||||
# self.add(nn)
|
|
||||||
nn.move_to(ORIGIN)
|
nn.move_to(ORIGIN)
|
||||||
|
self.add(nn)
|
||||||
|
"""
|
||||||
self.play(
|
self.play(
|
||||||
FadeIn(nn)
|
FadeIn(nn)
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
# Play animation
|
# Play animation
|
||||||
forward_pass = nn.make_forward_pass_animation(
|
forward_pass = nn.make_forward_pass_animation(
|
||||||
corner_pulses=False,
|
corner_pulses=False,
|
||||||
all_filters_at_once=True
|
all_filters_at_once=False
|
||||||
)
|
)
|
||||||
self.wait(1)
|
self.wait(1)
|
||||||
self.play(
|
self.play(
|
||||||
|
Reference in New Issue
Block a user