mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-19 04:41:57 +08:00
Made a workaround to make sure the filters in the CNN 2 CNN layers
don't appear at the beggining of a forward pass animation. I think this has to do with a bug in Succession in the core ManimCommunity library.
This commit is contained in:
@ -8,7 +8,7 @@ from manim_ml.one_to_one_sync import OneToOneSync
|
|||||||
class LeafNode(VGroup):
|
class LeafNode(VGroup):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class SplitNode(VGroup):
|
class NonLeafNode(VGroup):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class DecisionTreeDiagram(Graph):
|
class DecisionTreeDiagram(Graph):
|
||||||
@ -22,5 +22,17 @@ class DecisionTreeEmbedding():
|
|||||||
class DecisionTreeContainer(OneToOneSync):
|
class DecisionTreeContainer(OneToOneSync):
|
||||||
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, sklearn_tree, points, classes):
|
||||||
|
self.sklearn_tree = sklearn_tree
|
||||||
|
self.points = points
|
||||||
|
self.classes = classes
|
||||||
|
|
||||||
|
def make_unfold_tree_animation(self):
|
||||||
|
"""Unfolds the tree through an in order traversal
|
||||||
|
|
||||||
|
This animations unfolds the tree diagram as well as showing the splitting
|
||||||
|
of a shaded region in the Decision Tree embedding.
|
||||||
|
"""
|
||||||
|
# Draw points in the embedding
|
||||||
|
# Start the tree splitting animation
|
||||||
pass
|
pass
|
||||||
|
0
manim_ml/flow/__init__.py
Normal file
0
manim_ml/flow/__init__.py
Normal file
17
manim_ml/flow/flow.py
Normal file
17
manim_ml/flow/flow.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
Animated flow charts.
|
||||||
|
"""
|
||||||
|
from manim import *
|
||||||
|
|
||||||
|
class FlowGraph(VGroup):
|
||||||
|
"""Graph container"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FlowNode(VGroup):
|
||||||
|
"""Node in the FlowGraph"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DataNode(FlowNode):
|
||||||
|
"""Node that outputs data"""
|
||||||
|
pass
|
||||||
|
|
@ -40,7 +40,7 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
def construct_feature_maps(self):
|
def construct_feature_maps(self):
|
||||||
"""Creates the neural network layer"""
|
"""Creates the neural network layer"""
|
||||||
# Draw rectangles that are filled in with opacity
|
# Draw rectangles that are filled in with opacity
|
||||||
feature_maps = VGroup()
|
feature_maps = []
|
||||||
for filter_index in range(self.num_feature_maps):
|
for filter_index in range(self.num_feature_maps):
|
||||||
rectangle = GriddedRectangle(
|
rectangle = GriddedRectangle(
|
||||||
color=self.color,
|
color=self.color,
|
||||||
@ -54,19 +54,13 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
# grid_ystep=self.cell_width,
|
# grid_ystep=self.cell_width,
|
||||||
# grid_stroke_width=DEFAULT_STROKE_WIDTH/2
|
# grid_stroke_width=DEFAULT_STROKE_WIDTH/2
|
||||||
)
|
)
|
||||||
|
# Move the feature map
|
||||||
rectangle.move_to(
|
rectangle.move_to(
|
||||||
[0, 0, filter_index * self.filter_spacing]
|
[0, 0, filter_index * self.filter_spacing]
|
||||||
)
|
)
|
||||||
# Rotate about z axis
|
feature_maps.append(rectangle)
|
||||||
"""
|
|
||||||
rectangle.rotate_about_origin(
|
return VGroup(*feature_maps)
|
||||||
90 * DEGREES,
|
|
||||||
np.array([0, 1, 0])
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
feature_maps.add(rectangle)
|
|
||||||
|
|
||||||
return feature_maps
|
|
||||||
|
|
||||||
def make_forward_pass_animation(
|
def make_forward_pass_animation(
|
||||||
self,
|
self,
|
||||||
@ -79,6 +73,7 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
|||||||
# Note: most of this animation is done in the Convolution3DToConvolution3D layer
|
# Note: most of this animation is done in the Convolution3DToConvolution3D layer
|
||||||
print(f"Corner pulses: {corner_pulses}")
|
print(f"Corner pulses: {corner_pulses}")
|
||||||
if corner_pulses:
|
if corner_pulses:
|
||||||
|
raise NotImplementedError()
|
||||||
passing_flashes = []
|
passing_flashes = []
|
||||||
for line in self.corner_lines:
|
for line in self.corner_lines:
|
||||||
pulse = ShowPassingFlash(
|
pulse = ShowPassingFlash(
|
||||||
|
@ -22,11 +22,11 @@ class Filters(VGroup):
|
|||||||
self.stroke_width = stroke_width
|
self.stroke_width = stroke_width
|
||||||
# Make the filter
|
# Make the filter
|
||||||
self.input_rectangles = self.make_input_feature_map_rectangles()
|
self.input_rectangles = self.make_input_feature_map_rectangles()
|
||||||
self.add(self.input_rectangles)
|
# self.add(self.input_rectangles)
|
||||||
self.output_rectangles = self.make_output_feature_map_rectangles()
|
self.output_rectangles = self.make_output_feature_map_rectangles()
|
||||||
self.add(self.output_rectangles)
|
# self.add(self.output_rectangles)
|
||||||
self.connective_lines = self.make_connective_lines()
|
self.connective_lines = self.make_connective_lines()
|
||||||
self.add(self.connective_lines)
|
# self.add(self.connective_lines)
|
||||||
|
|
||||||
def make_input_feature_map_rectangles(self):
|
def make_input_feature_map_rectangles(self):
|
||||||
rectangles = []
|
rectangles = []
|
||||||
@ -185,6 +185,34 @@ class Filters(VGroup):
|
|||||||
|
|
||||||
return connective_lines
|
return connective_lines
|
||||||
|
|
||||||
|
@override_animation(Create)
|
||||||
|
def _create_override(self, **kwargs):
|
||||||
|
"""
|
||||||
|
NOTE This create override animation
|
||||||
|
is a workaround to make sure that the filter
|
||||||
|
does not show up in the scene before the create animation.
|
||||||
|
|
||||||
|
Without this override the filters were shown at the beginning
|
||||||
|
of the neural network forward pass animimation
|
||||||
|
instead of just when the filters were supposed to appear.
|
||||||
|
I think this is a bug with Succession in the core
|
||||||
|
Manim Community Library.
|
||||||
|
|
||||||
|
TODO Fix this
|
||||||
|
"""
|
||||||
|
def add_content(object):
|
||||||
|
object.add(self.input_rectangles)
|
||||||
|
object.add(self.connective_lines)
|
||||||
|
object.add(self.output_rectangles)
|
||||||
|
|
||||||
|
return object
|
||||||
|
|
||||||
|
return ApplyFunction(
|
||||||
|
add_content,
|
||||||
|
self
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||||
"""Feed Forward to Embedding Layer"""
|
"""Feed Forward to Embedding Layer"""
|
||||||
input_class = Convolutional3DLayer
|
input_class = Convolutional3DLayer
|
||||||
@ -208,9 +236,6 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
self.filter_opacity = filter_opacity
|
self.filter_opacity = filter_opacity
|
||||||
self.line_color = line_color
|
self.line_color = line_color
|
||||||
self.pulse_color = pulse_color
|
self.pulse_color = pulse_color
|
||||||
# Make filters
|
|
||||||
self.filters = Filters(self.input_layer, self.output_layer)
|
|
||||||
self.add(self.filters)
|
|
||||||
|
|
||||||
def make_filter_propagation_animation(self):
|
def make_filter_propagation_animation(self):
|
||||||
"""Make filter propagation animation"""
|
"""Make filter propagation animation"""
|
||||||
@ -251,26 +276,17 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
|
|
||||||
return right_shift, down_shift
|
return right_shift, down_shift
|
||||||
|
|
||||||
def make_forward_pass_animation(self, layer_args={}, run_time=10.5, **kwargs):
|
def make_forward_pass_animation(self, layer_args={},
|
||||||
|
all_filters_at_once=False, run_time=10.5, **kwargs):
|
||||||
"""Forward pass animation from conv2d to conv2d"""
|
"""Forward pass animation from conv2d to conv2d"""
|
||||||
|
|
||||||
animations = []
|
animations = []
|
||||||
# Rotate given three_d_phi and three_d_theta
|
# Make filters
|
||||||
# Rotate about center
|
filters = Filters(self.input_layer, self.output_layer)
|
||||||
# filters.rotate(110 * DEGREES, about_point=filters.get_center(), axis=[0, 0, 1])
|
filters.set_z_index(self.input_layer.feature_maps[0].get_z_index() + 1)
|
||||||
"""
|
# self.add(filters)
|
||||||
self.filters.rotate(
|
animations.append(
|
||||||
ThreeDLayer.three_d_x_rotation,
|
Create(filters)
|
||||||
about_point=self.filters.get_center(),
|
|
||||||
axis=[1, 0, 0]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.filters.rotate(
|
|
||||||
ThreeDLayer.three_d_y_rotation,
|
|
||||||
about_point=self.filters.get_center(),
|
|
||||||
axis=[0, 1, 0]
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
# Get shift vectors
|
# Get shift vectors
|
||||||
right_shift, down_shift = self.get_rotated_shift_vectors()
|
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||||
left_shift = -1 * right_shift
|
left_shift = -1 * right_shift
|
||||||
@ -279,9 +295,6 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
# Make animations for creating the filters, output_nodes, and filter_lines
|
# Make animations for creating the filters, output_nodes, and filter_lines
|
||||||
# TODO decide if I want to create the filters at the start of a conv
|
# TODO decide if I want to create the filters at the start of a conv
|
||||||
# animation or have them there by default
|
# animation or have them there by default
|
||||||
# animations.append(
|
|
||||||
# Create(filters)
|
|
||||||
# )
|
|
||||||
# Rotate the base shift vectors
|
# Rotate the base shift vectors
|
||||||
# Make filter shifting animations
|
# Make filter shifting animations
|
||||||
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
|
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
|
||||||
@ -289,10 +302,9 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
for y_move in range(num_y_moves):
|
for y_move in range(num_y_moves):
|
||||||
# Go right num_x_moves
|
# Go right num_x_moves
|
||||||
for x_move in range(num_x_moves):
|
for x_move in range(num_x_moves):
|
||||||
print(right_shift)
|
|
||||||
# Shift right
|
# Shift right
|
||||||
shift_animation = ApplyMethod(
|
shift_animation = ApplyMethod(
|
||||||
self.filters.shift,
|
filters.shift,
|
||||||
self.stride * right_shift
|
self.stride * right_shift
|
||||||
)
|
)
|
||||||
# shift_animation = self.animate.shift(right_shift)
|
# shift_animation = self.animate.shift(right_shift)
|
||||||
@ -302,7 +314,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
shift_amount = self.stride * num_x_moves * left_shift + self.stride * down_shift
|
shift_amount = self.stride * num_x_moves * left_shift + self.stride * down_shift
|
||||||
# Make the animation
|
# Make the animation
|
||||||
shift_animation = ApplyMethod(
|
shift_animation = ApplyMethod(
|
||||||
self.filters.shift,
|
filters.shift,
|
||||||
shift_amount
|
shift_amount
|
||||||
)
|
)
|
||||||
animations.append(shift_animation)
|
animations.append(shift_animation)
|
||||||
@ -310,12 +322,15 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
for x_move in range(num_x_moves):
|
for x_move in range(num_x_moves):
|
||||||
# Shift right
|
# Shift right
|
||||||
shift_animation = ApplyMethod(
|
shift_animation = ApplyMethod(
|
||||||
self.filters.shift,
|
filters.shift,
|
||||||
self.stride * right_shift
|
self.stride * right_shift
|
||||||
)
|
)
|
||||||
# shift_animation = self.animate.shift(right_shift)
|
# shift_animation = self.animate.shift(right_shift)
|
||||||
animations.append(shift_animation)
|
animations.append(shift_animation)
|
||||||
|
# Remove the filters
|
||||||
|
animations.append(
|
||||||
|
FadeOut(filters)
|
||||||
|
)
|
||||||
# Remove filters
|
# Remove filters
|
||||||
return Succession(
|
return Succession(
|
||||||
*animations,
|
*animations,
|
||||||
@ -332,4 +347,4 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
|||||||
|
|
||||||
@override_animation(Create)
|
@override_animation(Create)
|
||||||
def _create_override(self, **kwargs):
|
def _create_override(self, **kwargs):
|
||||||
return AnimationGroup()
|
return Succession()
|
||||||
|
@ -43,7 +43,7 @@ class ImageLayer(NeuralNetworkLayer):
|
|||||||
return AnimationGroup()
|
return AnimationGroup()
|
||||||
|
|
||||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||||
return FadeIn(self.image_mobject)
|
return AnimationGroup()
|
||||||
|
|
||||||
# def move_to(self, location):
|
# def move_to(self, location):
|
||||||
# """Override of move to"""
|
# """Override of move to"""
|
||||||
|
@ -19,26 +19,28 @@ class CombinedScene(ThreeDScene):
|
|||||||
# Make nn
|
# Make nn
|
||||||
nn = NeuralNetwork(
|
nn = NeuralNetwork(
|
||||||
[
|
[
|
||||||
ImageLayer(numpy_image, height=1.4),
|
ImageLayer(numpy_image, height=2.0),
|
||||||
Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.2),
|
Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
|
||||||
Convolutional3DLayer(3, 5, 5, 3, 3, filter_spacing=0.2),
|
Convolutional3DLayer(3, 5, 5, 3, 3, filter_spacing=0.32),
|
||||||
Convolutional3DLayer(5, 3, 3, 1, 1, filter_spacing=0.2),
|
Convolutional3DLayer(5, 3, 3, 1, 1, filter_spacing=0.18),
|
||||||
FeedForwardLayer(3, rectangle_stroke_width=4, node_stroke_width=4),
|
FeedForwardLayer(3),
|
||||||
FeedForwardLayer(3, rectangle_stroke_width=4, node_stroke_width=4),
|
FeedForwardLayer(3),
|
||||||
],
|
],
|
||||||
layer_spacing=0.5,
|
layer_spacing=0.25,
|
||||||
# camera=self.camera
|
# camera=self.camera
|
||||||
)
|
)
|
||||||
# Center the nn
|
# Center the nn
|
||||||
self.add(nn)
|
# self.add(nn)
|
||||||
nn.move_to(ORIGIN)
|
nn.move_to(ORIGIN)
|
||||||
|
self.play(
|
||||||
|
FadeIn(nn)
|
||||||
|
)
|
||||||
# Play animation
|
# Play animation
|
||||||
forward_pass = nn.make_forward_pass_animation(
|
forward_pass = nn.make_forward_pass_animation(
|
||||||
corner_pulses=False,
|
corner_pulses=False,
|
||||||
layer_args={
|
all_filters_at_once=True
|
||||||
"all_filters_at_once": True
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
self.wait(1)
|
||||||
self.play(
|
self.play(
|
||||||
forward_pass
|
forward_pass
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user