mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-19 04:41:57 +08:00
Made a workaround to make sure the filters in the CNN 2 CNN layers
don't appear at the beggining of a forward pass animation. I think this has to do with a bug in Succession in the core ManimCommunity library.
This commit is contained in:
@ -8,7 +8,7 @@ from manim_ml.one_to_one_sync import OneToOneSync
|
||||
class LeafNode(VGroup):
|
||||
pass
|
||||
|
||||
class SplitNode(VGroup):
|
||||
class NonLeafNode(VGroup):
|
||||
pass
|
||||
|
||||
class DecisionTreeDiagram(Graph):
|
||||
@ -22,5 +22,17 @@ class DecisionTreeEmbedding():
|
||||
class DecisionTreeContainer(OneToOneSync):
|
||||
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, sklearn_tree, points, classes):
|
||||
self.sklearn_tree = sklearn_tree
|
||||
self.points = points
|
||||
self.classes = classes
|
||||
|
||||
def make_unfold_tree_animation(self):
|
||||
"""Unfolds the tree through an in order traversal
|
||||
|
||||
This animations unfolds the tree diagram as well as showing the splitting
|
||||
of a shaded region in the Decision Tree embedding.
|
||||
"""
|
||||
# Draw points in the embedding
|
||||
# Start the tree splitting animation
|
||||
pass
|
||||
|
0
manim_ml/flow/__init__.py
Normal file
0
manim_ml/flow/__init__.py
Normal file
17
manim_ml/flow/flow.py
Normal file
17
manim_ml/flow/flow.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""
|
||||
Animated flow charts.
|
||||
"""
|
||||
from manim import *
|
||||
|
||||
class FlowGraph(VGroup):
|
||||
"""Graph container"""
|
||||
pass
|
||||
|
||||
class FlowNode(VGroup):
|
||||
"""Node in the FlowGraph"""
|
||||
pass
|
||||
|
||||
class DataNode(FlowNode):
|
||||
"""Node that outputs data"""
|
||||
pass
|
||||
|
@ -40,7 +40,7 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
def construct_feature_maps(self):
|
||||
"""Creates the neural network layer"""
|
||||
# Draw rectangles that are filled in with opacity
|
||||
feature_maps = VGroup()
|
||||
feature_maps = []
|
||||
for filter_index in range(self.num_feature_maps):
|
||||
rectangle = GriddedRectangle(
|
||||
color=self.color,
|
||||
@ -54,19 +54,13 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
# grid_ystep=self.cell_width,
|
||||
# grid_stroke_width=DEFAULT_STROKE_WIDTH/2
|
||||
)
|
||||
# Move the feature map
|
||||
rectangle.move_to(
|
||||
[0, 0, filter_index * self.filter_spacing]
|
||||
)
|
||||
# Rotate about z axis
|
||||
"""
|
||||
rectangle.rotate_about_origin(
|
||||
90 * DEGREES,
|
||||
np.array([0, 1, 0])
|
||||
)
|
||||
"""
|
||||
feature_maps.add(rectangle)
|
||||
feature_maps.append(rectangle)
|
||||
|
||||
return feature_maps
|
||||
return VGroup(*feature_maps)
|
||||
|
||||
def make_forward_pass_animation(
|
||||
self,
|
||||
@ -79,6 +73,7 @@ class Convolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
# Note: most of this animation is done in the Convolution3DToConvolution3D layer
|
||||
print(f"Corner pulses: {corner_pulses}")
|
||||
if corner_pulses:
|
||||
raise NotImplementedError()
|
||||
passing_flashes = []
|
||||
for line in self.corner_lines:
|
||||
pulse = ShowPassingFlash(
|
||||
|
@ -22,11 +22,11 @@ class Filters(VGroup):
|
||||
self.stroke_width = stroke_width
|
||||
# Make the filter
|
||||
self.input_rectangles = self.make_input_feature_map_rectangles()
|
||||
self.add(self.input_rectangles)
|
||||
# self.add(self.input_rectangles)
|
||||
self.output_rectangles = self.make_output_feature_map_rectangles()
|
||||
self.add(self.output_rectangles)
|
||||
# self.add(self.output_rectangles)
|
||||
self.connective_lines = self.make_connective_lines()
|
||||
self.add(self.connective_lines)
|
||||
# self.add(self.connective_lines)
|
||||
|
||||
def make_input_feature_map_rectangles(self):
|
||||
rectangles = []
|
||||
@ -185,6 +185,34 @@ class Filters(VGroup):
|
||||
|
||||
return connective_lines
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
"""
|
||||
NOTE This create override animation
|
||||
is a workaround to make sure that the filter
|
||||
does not show up in the scene before the create animation.
|
||||
|
||||
Without this override the filters were shown at the beginning
|
||||
of the neural network forward pass animimation
|
||||
instead of just when the filters were supposed to appear.
|
||||
I think this is a bug with Succession in the core
|
||||
Manim Community Library.
|
||||
|
||||
TODO Fix this
|
||||
"""
|
||||
def add_content(object):
|
||||
object.add(self.input_rectangles)
|
||||
object.add(self.connective_lines)
|
||||
object.add(self.output_rectangles)
|
||||
|
||||
return object
|
||||
|
||||
return ApplyFunction(
|
||||
add_content,
|
||||
self
|
||||
)
|
||||
|
||||
|
||||
class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
"""Feed Forward to Embedding Layer"""
|
||||
input_class = Convolutional3DLayer
|
||||
@ -208,9 +236,6 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
self.filter_opacity = filter_opacity
|
||||
self.line_color = line_color
|
||||
self.pulse_color = pulse_color
|
||||
# Make filters
|
||||
self.filters = Filters(self.input_layer, self.output_layer)
|
||||
self.add(self.filters)
|
||||
|
||||
def make_filter_propagation_animation(self):
|
||||
"""Make filter propagation animation"""
|
||||
@ -251,26 +276,17 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
|
||||
return right_shift, down_shift
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=10.5, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={},
|
||||
all_filters_at_once=False, run_time=10.5, **kwargs):
|
||||
"""Forward pass animation from conv2d to conv2d"""
|
||||
|
||||
animations = []
|
||||
# Rotate given three_d_phi and three_d_theta
|
||||
# Rotate about center
|
||||
# filters.rotate(110 * DEGREES, about_point=filters.get_center(), axis=[0, 0, 1])
|
||||
"""
|
||||
self.filters.rotate(
|
||||
ThreeDLayer.three_d_x_rotation,
|
||||
about_point=self.filters.get_center(),
|
||||
axis=[1, 0, 0]
|
||||
# Make filters
|
||||
filters = Filters(self.input_layer, self.output_layer)
|
||||
filters.set_z_index(self.input_layer.feature_maps[0].get_z_index() + 1)
|
||||
# self.add(filters)
|
||||
animations.append(
|
||||
Create(filters)
|
||||
)
|
||||
|
||||
self.filters.rotate(
|
||||
ThreeDLayer.three_d_y_rotation,
|
||||
about_point=self.filters.get_center(),
|
||||
axis=[0, 1, 0]
|
||||
)
|
||||
"""
|
||||
# Get shift vectors
|
||||
right_shift, down_shift = self.get_rotated_shift_vectors()
|
||||
left_shift = -1 * right_shift
|
||||
@ -279,9 +295,6 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
# Make animations for creating the filters, output_nodes, and filter_lines
|
||||
# TODO decide if I want to create the filters at the start of a conv
|
||||
# animation or have them there by default
|
||||
# animations.append(
|
||||
# Create(filters)
|
||||
# )
|
||||
# Rotate the base shift vectors
|
||||
# Make filter shifting animations
|
||||
num_y_moves = int((self.feature_map_height - self.filter_height) / self.stride)
|
||||
@ -289,10 +302,9 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
for y_move in range(num_y_moves):
|
||||
# Go right num_x_moves
|
||||
for x_move in range(num_x_moves):
|
||||
print(right_shift)
|
||||
# Shift right
|
||||
shift_animation = ApplyMethod(
|
||||
self.filters.shift,
|
||||
filters.shift,
|
||||
self.stride * right_shift
|
||||
)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
@ -302,7 +314,7 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
shift_amount = self.stride * num_x_moves * left_shift + self.stride * down_shift
|
||||
# Make the animation
|
||||
shift_animation = ApplyMethod(
|
||||
self.filters.shift,
|
||||
filters.shift,
|
||||
shift_amount
|
||||
)
|
||||
animations.append(shift_animation)
|
||||
@ -310,12 +322,15 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
for x_move in range(num_x_moves):
|
||||
# Shift right
|
||||
shift_animation = ApplyMethod(
|
||||
self.filters.shift,
|
||||
filters.shift,
|
||||
self.stride * right_shift
|
||||
)
|
||||
# shift_animation = self.animate.shift(right_shift)
|
||||
animations.append(shift_animation)
|
||||
|
||||
# Remove the filters
|
||||
animations.append(
|
||||
FadeOut(filters)
|
||||
)
|
||||
# Remove filters
|
||||
return Succession(
|
||||
*animations,
|
||||
@ -332,4 +347,4 @@ class Convolutional3DToConvolutional3D(ConnectiveLayer, ThreeDLayer):
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
return AnimationGroup()
|
||||
return Succession()
|
||||
|
@ -43,7 +43,7 @@ class ImageLayer(NeuralNetworkLayer):
|
||||
return AnimationGroup()
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
return FadeIn(self.image_mobject)
|
||||
return AnimationGroup()
|
||||
|
||||
# def move_to(self, location):
|
||||
# """Override of move to"""
|
||||
|
@ -19,26 +19,28 @@ class CombinedScene(ThreeDScene):
|
||||
# Make nn
|
||||
nn = NeuralNetwork(
|
||||
[
|
||||
ImageLayer(numpy_image, height=1.4),
|
||||
Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.2),
|
||||
Convolutional3DLayer(3, 5, 5, 3, 3, filter_spacing=0.2),
|
||||
Convolutional3DLayer(5, 3, 3, 1, 1, filter_spacing=0.2),
|
||||
FeedForwardLayer(3, rectangle_stroke_width=4, node_stroke_width=4),
|
||||
FeedForwardLayer(3, rectangle_stroke_width=4, node_stroke_width=4),
|
||||
ImageLayer(numpy_image, height=2.0),
|
||||
Convolutional3DLayer(1, 7, 7, 3, 3, filter_spacing=0.32),
|
||||
Convolutional3DLayer(3, 5, 5, 3, 3, filter_spacing=0.32),
|
||||
Convolutional3DLayer(5, 3, 3, 1, 1, filter_spacing=0.18),
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(3),
|
||||
],
|
||||
layer_spacing=0.5,
|
||||
layer_spacing=0.25,
|
||||
# camera=self.camera
|
||||
)
|
||||
# Center the nn
|
||||
self.add(nn)
|
||||
# self.add(nn)
|
||||
nn.move_to(ORIGIN)
|
||||
self.play(
|
||||
FadeIn(nn)
|
||||
)
|
||||
# Play animation
|
||||
forward_pass = nn.make_forward_pass_animation(
|
||||
corner_pulses=False,
|
||||
layer_args={
|
||||
"all_filters_at_once": True
|
||||
}
|
||||
all_filters_at_once=True
|
||||
)
|
||||
self.wait(1)
|
||||
self.play(
|
||||
forward_pass
|
||||
)
|
||||
|
Reference in New Issue
Block a user