mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-17 10:45:54 +08:00
Added residual layer example. Fixed some bugs in the process.
This commit is contained in:
74
examples/basic_neural_network/residual_block.py
Normal file
74
examples/basic_neural_network/residual_block.py
Normal file
@ -0,0 +1,74 @@
|
||||
from manim import *
|
||||
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer
|
||||
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||
|
||||
# Render settings for this example scene.
# Output resolution, in pixels.
config.pixel_height = 1200
config.pixel_width = 1900
# Dimensions of the visible frame, in scene units.
config.frame_height = 6.0
config.frame_width = 6.0
|
||||
|
||||
def make_code_snippet():
    """Build the code listing that is displayed alongside the network.

    Returns
    -------
    Code
        A ``Code`` mobject holding the example source, scaled by 0.38.
    """
    # Text shown in the scene. It is only rendered, never executed, so it
    # is kept left-aligned to avoid stray indentation in the listing.
    code_str = """
nn = NeuralNetwork({
    "feed_forward_1": FeedForwardLayer(3),
    "feed_forward_2": FeedForwardLayer(3, activation_function="ReLU"),
    "feed_forward_3": FeedForwardLayer(3),
    "sum_operation": MathOperationLayer("+", activation_function="ReLU"),
})
nn.add_connection("feed_forward_1", "sum_operation")
self.play(nn.make_forward_pass_animation())
"""

    code = Code(
        code=code_str,
        tab_width=4,
        background_stroke_width=1,
        background_stroke_color=WHITE,
        insert_line_no=False,
        style="monokai",
        # background="window",
        language="py",
    )
    code.scale(0.38)

    return code
|
||||
|
||||
|
||||
class CombinedScene(ThreeDScene):
    """Scene that shows a residual block together with its source listing."""

    def construct(self):
        """Build the network, add its connections, and play a forward pass."""
        # Add the network
        nn = NeuralNetwork(
            {
                "feed_forward_1": FeedForwardLayer(3),
                "feed_forward_2": FeedForwardLayer(3, activation_function="ReLU"),
                "feed_forward_3": FeedForwardLayer(3),
                "sum_operation": MathOperationLayer("+", activation_function="ReLU"),
            },
            layer_spacing=0.38,
        )
        # Make connections
        # Placeholder dot 0.65 units left of the first layer; it gives the
        # incoming arrow a start point.
        input_blank_dot = Dot(
            nn.input_layers_dict["feed_forward_1"].get_left() - np.array([0.65, 0.0, 0.0])
        )
        nn.add_connection(input_blank_dot, "feed_forward_1", arc_direction="straight")
        # Residual (skip) connection from the first layer to the sum node.
        nn.add_connection("feed_forward_1", "sum_operation")
        # Placeholder dot 0.65 units right of the sum node; it gives the
        # outgoing arrow an end point.
        output_blank_dot = Dot(
            nn.input_layers_dict["sum_operation"].get_right() + np.array([0.65, 0.0, 0.0])
        )
        nn.add_connection("sum_operation", output_blank_dot, arc_direction="straight")
        # Center the nn
        nn.move_to(ORIGIN)
        self.add(nn)
        # Make code snippet
        code = make_code_snippet()
        code.next_to(nn, DOWN)
        self.add(code)
        # Group it all so the network and listing are centered as one unit
        group = Group(nn, code)
        group.move_to(ORIGIN)
        # Play animation
        forward_pass = nn.make_forward_pass_animation()
        self.wait(1)
        self.play(forward_pass)
|
@ -39,3 +39,4 @@ from manim_ml.neural_network.layers.paired_query import PairedQueryLayer
|
||||
from manim_ml.neural_network.layers.triplet_to_feed_forward import TripletToFeedForward
|
||||
from manim_ml.neural_network.layers.triplet import TripletLayer
|
||||
from manim_ml.neural_network.layers.vector import VectorLayer
|
||||
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer
|
@ -31,6 +31,7 @@ from .triplet_to_feed_forward import TripletToFeedForward
|
||||
from .paired_query import PairedQueryLayer
|
||||
from .paired_query_to_feed_forward import PairedQueryToFeedForward
|
||||
from .max_pooling_2d import MaxPooling2DLayer
|
||||
from .feed_forward_to_math_operation import FeedForwardToMathOperation
|
||||
|
||||
connective_layers_list = (
|
||||
EmbeddingToFeedForward,
|
||||
@ -48,4 +49,5 @@ connective_layers_list = (
|
||||
Convolutional2DToMaxPooling2D,
|
||||
MaxPooling2DToConvolutional2D,
|
||||
MaxPooling2DToFeedForward,
|
||||
FeedForwardToMathOperation
|
||||
)
|
||||
|
@ -272,6 +272,16 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
|
||||
"""Overrides get height function"""
|
||||
return self.feature_maps.length_over_dim(1)
|
||||
|
||||
def move_to(self, mobject_or_point):
    """Move the layer so its feature maps are centered on the target.

    Parameters
    ----------
    mobject_or_point : Mobject or point
        Either a mobject, whose center is used as the target, or a point
        used as the target directly.

    Returns
    -------
    self
        Returned so calls can be chained.
    """
    # The center of the feature maps is used as the layer's reference point.
    layer_center = self.feature_maps.get_center()
    if isinstance(mobject_or_point, Mobject):
        target_center = mobject_or_point.get_center()
    else:
        target_center = mobject_or_point

    self.shift(target_center - layer_center)
    return self
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
return FadeIn(self.feature_maps)
|
||||
|
@ -153,3 +153,25 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
|
||||
|
||||
animation_group = AnimationGroup(*animations, lag_ratio=0.0)
|
||||
return animation_group
|
||||
|
||||
def get_height(self):
    """Return the height of the layer's surrounding rectangle."""
    return self.surrounding_rectangle.get_height()
|
||||
|
||||
def get_center(self):
    """Return the center of the layer's surrounding rectangle."""
    return self.surrounding_rectangle.get_center()
|
||||
|
||||
def get_left(self):
    """Return the left edge point of the layer's surrounding rectangle."""
    return self.surrounding_rectangle.get_left()
|
||||
|
||||
def get_right(self):
    """Return the right edge point of the layer's surrounding rectangle."""
    return self.surrounding_rectangle.get_right()
|
||||
|
||||
def move_to(self, mobject_or_point):
    """Move the layer so its surrounding rectangle is centered on the target.

    Parameters
    ----------
    mobject_or_point : Mobject or point
        Either a mobject, whose center is used as the target, or a point
        used as the target directly.

    Returns
    -------
    self
        Returned so calls can be chained.
    """
    # The rectangle's center is used as the layer's reference point.
    layer_center = self.surrounding_rectangle.get_center()
    if isinstance(mobject_or_point, Mobject):
        target_center = mobject_or_point.get_center()
    else:
        target_center = mobject_or_point

    self.shift(target_center - layer_center)
    return self
|
@ -90,7 +90,7 @@ class FeedForwardToFeedForward(ConnectiveLayer):
|
||||
if self.passing_flash:
|
||||
copy_edge = edge.copy()
|
||||
anim = ShowPassingFlash(
|
||||
copy_edge.set_color(self.animation_dot_color), time_width=0.2
|
||||
copy_edge.set_color(self.animation_dot_color), time_width=0.3
|
||||
)
|
||||
else:
|
||||
anim = MoveAlongPath(
|
||||
|
@ -0,0 +1,48 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer
|
||||
from manim_ml.utils.mobjects.connections import NetworkConnection
|
||||
|
||||
class FeedForwardToMathOperation(ConnectiveLayer):
    """Connective layer from a FeedForwardLayer to a MathOperationLayer.

    Draws a straight connection between the two layers and flashes it
    during the forward pass.
    """

    input_class = FeedForwardLayer
    output_class = MathOperationLayer

    def __init__(
        self,
        input_layer,
        output_layer,
        active_color=ORANGE,
        **kwargs
    ):
        """Store the highlight color and initialize the connective layer.

        Parameters
        ----------
        input_layer
            Layer the connection starts from.
        output_layer
            Layer the connection ends at.
        active_color : optional
            Color of the flash in the forward pass animation, by default ORANGE.
        **kwargs
            Forwarded to the parent initializer.
        """
        self.active_color = active_color
        super().__init__(input_layer, output_layer, **kwargs)

    def construct_layer(
        self,
        input_layer: "NeuralNetworkLayer",
        output_layer: "NeuralNetworkLayer",
        **kwargs
    ):
        """Create the connection mobject and add it to this layer."""
        # Draw an arrow from the output of the feed forward layer to the
        # input of the math operation layer
        self.connection = NetworkConnection(
            self.input_layer,
            self.output_layer,
            arc_direction="straight",
            buffer=0.05,
        )
        self.add(self.connection)

        return super().construct_layer(input_layer, output_layer, **kwargs)

    def make_forward_pass_animation(self, layer_args={}, **kwargs):
        """Return a passing-flash animation along the connection.

        Parameters
        ----------
        layer_args : dict, optional
            Layer specific arguments, by default {}. Not read by this method.

        Returns
        -------
        ShowPassingFlash
            Flash over a copy of the connection colored with ``active_color``.
        """
        # Flash a recolored copy so the displayed connection keeps its color.
        passing_flash = ShowPassingFlash(
            self.connection.copy().set_color(self.active_color)
        )

        return passing_flash
|
129
manim_ml/neural_network/layers/math_operation_layer.py
Normal file
129
manim_ml/neural_network/layers/math_operation_layer.py
Normal file
@ -0,0 +1,129 @@
|
||||
from manim import *
|
||||
|
||||
from manim_ml.neural_network.activation_functions import get_activation_function_by_name
|
||||
from manim_ml.neural_network.activation_functions.activation_function import (
|
||||
ActivationFunction,
|
||||
)
|
||||
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
|
||||
|
||||
class MathOperationLayer(VGroupNeuralNetworkLayer):
    """Layer that renders a math operation symbol inside a circle.

    The operation must be one of ``valid_operations``. An optional
    activation function can be attached and is animated with the layer.
    """

    valid_operations = ["+", "-", "*", "/"]

    def __init__(
        self,
        operation_type: str,
        node_radius=0.5,
        node_color=BLUE,
        node_stroke_width=2.0,
        active_color=ORANGE,
        activation_function=None,
        font_size=20,
        **kwargs
    ):
        """Store the layer's configuration.

        Parameters
        ----------
        operation_type : str
            Operation symbol; must be one of ``valid_operations``.
        node_radius : float, optional
            Stored on the instance, by default 0.5.
        node_color : optional
            Resting color of the surrounding circle, by default BLUE.
        node_stroke_width : float, optional
            Stroke width of the surrounding circle, by default 2.0.
        active_color : optional
            Highlight color used in the forward pass, by default ORANGE.
        activation_function : str or ActivationFunction, optional
            Activation function name or instance, by default None.
        font_size : int, optional
            Font size of the operation symbol, by default 20.
        **kwargs
            Forwarded to the parent initializer.
        """
        # NOTE(review): the two-argument super() starts the lookup after
        # VGroupNeuralNetworkLayer in the MRO, so that class's own __init__
        # is not run here -- confirm this is intentional.
        super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
        # Ensure operation type is valid
        assert operation_type in MathOperationLayer.valid_operations, (
            f"operation_type must be one of {MathOperationLayer.valid_operations}"
        )
        self.operation_type = operation_type
        # NOTE(review): node_radius is stored but not read anywhere in this
        # class; the circle is sized by surround() instead.
        self.node_radius = node_radius
        self.node_color = node_color
        self.node_stroke_width = node_stroke_width
        self.active_color = active_color
        self.font_size = font_size
        self.activation_function = activation_function

    def construct_layer(
        self,
        input_layer: "NeuralNetworkLayer",
        output_layer: "NeuralNetworkLayer",
        **kwargs
    ):
        """Creates the neural network layer"""
        # Draw the operation
        self.operation_text = Text(
            self.operation_type,
            font_size=self.font_size,
        )
        self.add(self.operation_text)
        # Make the surrounding circle, sized to enclose the operation text
        self.surrounding_circle = Circle(
            color=self.node_color,
            stroke_width=self.node_stroke_width,
        ).surround(self.operation_text)
        self.add(self.surrounding_circle)
        # Make the activation function
        self.construct_activation_function()
        super().construct_layer(input_layer, output_layer, **kwargs)

    def construct_activation_function(self):
        """Resolve the activation function, if any, and add it to the layer."""
        if self.activation_function is None:
            return
        # A string is treated as a name and resolved to an instance
        if isinstance(self.activation_function, str):
            activation_function = get_activation_function_by_name(
                self.activation_function
            )()
        else:
            assert isinstance(self.activation_function, ActivationFunction)
            activation_function = self.activation_function
        # Keep the resolved instance and add it as a submobject. No
        # positioning is done here.
        self.activation_function = activation_function
        self.add(self.activation_function)

    def make_forward_pass_animation(self, layer_args={}, **kwargs):
        """Makes the forward pass animation

        Parameters
        ----------
        layer_args : dict, optional
            layer specific arguments, by default {}

        Returns
        -------
        AnimationGroup
            Forward pass animation
        """
        # Make highlight animation: recolor the circle, hold, then restore
        succession = Succession(
            ApplyMethod(
                self.surrounding_circle.set_color,
                self.active_color,
                run_time=0.25,
            ),
            Wait(1.0),
            ApplyMethod(
                self.surrounding_circle.set_color,
                self.node_color,
                run_time=0.25,
            ),
        )
        if self.activation_function is None:
            return succession
        # Animate the activation function together with the highlight
        animation_group = AnimationGroup(
            succession,
            self.activation_function.make_evaluate_animation(),
            lag_ratio=0.0,
        )
        return animation_group

    def get_center(self):
        """Return the center of the surrounding circle."""
        return self.surrounding_circle.get_center()

    def get_left(self):
        """Return the left edge point of the surrounding circle."""
        return self.surrounding_circle.get_left()

    def get_right(self):
        """Return the right edge point of the surrounding circle."""
        return self.surrounding_circle.get_right()

    def move_to(self, mobject_or_point):
        """Move the layer so its surrounding circle is centered on the target.

        Parameters
        ----------
        mobject_or_point : Mobject or point
            Either a mobject, whose center is used as the target, or a
            point used as the target directly.

        Returns
        -------
        self
            Returned so calls can be chained.
        """
        layer_center = self.surrounding_circle.get_center()
        if isinstance(mobject_or_point, Mobject):
            target_center = mobject_or_point.get_center()
        else:
            target_center = mobject_or_point

        self.shift(target_center - layer_center)
        return self
|
@ -93,20 +93,31 @@ class NeuralNetwork(Group):
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
start_layer_name,
|
||||
end_layer_name,
|
||||
start_mobject_or_name,
|
||||
end_mobject_or_name,
|
||||
connection_style="default",
|
||||
connection_position="bottom",
|
||||
arc_direction="down"
|
||||
):
|
||||
"""Add connection from start layer to end layer"""
|
||||
assert connection_style in ["default"]
|
||||
if connection_style == "default":
|
||||
# Make arrow connection from start layer to end layer
|
||||
# Add the connection
|
||||
if isinstance(start_mobject_or_name, Mobject):
|
||||
input_mobject = start_mobject_or_name
|
||||
else:
|
||||
input_mobject = self.input_layers_dict[start_mobject_or_name]
|
||||
if isinstance(end_mobject_or_name, Mobject):
|
||||
output_mobject = end_mobject_or_name
|
||||
else:
|
||||
output_mobject = self.input_layers_dict[end_mobject_or_name]
|
||||
|
||||
connection = NetworkConnection(
|
||||
self.input_layers_dict[start_layer_name],
|
||||
self.input_layers_dict[end_layer_name],
|
||||
arc_direction="down", # TODO generalize this more
|
||||
input_mobject,
|
||||
output_mobject,
|
||||
arc_direction=arc_direction,
|
||||
buffer=0.05
|
||||
)
|
||||
self.connections.append(connection)
|
||||
self.add(connection)
|
||||
@ -243,7 +254,7 @@ class NeuralNetwork(Group):
|
||||
):
|
||||
"""Generates an animation for feed forward propagation"""
|
||||
all_animations = []
|
||||
per_layer_animations = {}
|
||||
per_layer_animation_map = {}
|
||||
per_layer_runtime = (
|
||||
run_time / len(self.all_layers) if not run_time is None else None
|
||||
)
|
||||
@ -297,11 +308,11 @@ class NeuralNetwork(Group):
|
||||
)
|
||||
all_animations.append(layer_forward_pass)
|
||||
# Add the animation to per layer animation
|
||||
per_layer_animations[layer] = layer_forward_pass
|
||||
per_layer_animation_map[layer] = layer_forward_pass
|
||||
# Make the animation group
|
||||
animation_group = Succession(*all_animations, lag_ratio=1.0)
|
||||
if per_layer_animations:
|
||||
return per_layer_animations
|
||||
return per_layer_animation_map
|
||||
else:
|
||||
return animation_group
|
||||
|
||||
|
@ -15,8 +15,8 @@ class NetworkConnection(VGroup):
|
||||
start_mobject,
|
||||
end_mobject,
|
||||
arc_direction="straight",
|
||||
buffer=0.05,
|
||||
arc_distance=0.3,
|
||||
buffer=0.0,
|
||||
arc_distance=0.2,
|
||||
stroke_width=2.0,
|
||||
color=WHITE,
|
||||
active_color=ORANGE,
|
||||
@ -64,25 +64,51 @@ class NetworkConnection(VGroup):
|
||||
else:
|
||||
right_mobject = self.start_mobject
|
||||
left_mobject = self.end_mobject
|
||||
|
||||
if self.arc_direction == "straight":
|
||||
# Make an arrow
|
||||
arrow_line = Line(
|
||||
left_mobject.get_right() + np.array([self.buffer, 0.0, 0.0]),
|
||||
right_mobject.get_left() + np.array([-1 * self.buffer, 0.0, 0.0]),
|
||||
self.straight_arrow = Arrow(
|
||||
start=left_mobject.get_right() + np.array([self.buffer, 0.0, 0.0]),
|
||||
end=right_mobject.get_left() + np.array([-1 * self.buffer, 0.0, 0.0]),
|
||||
color=WHITE,
|
||||
fill_color=WHITE,
|
||||
stroke_opacity=1.0,
|
||||
buff=0.0,
|
||||
)
|
||||
arrow = Arrow(arrow_line, color=self.color, stroke_width=self.stroke_width)
|
||||
self.straight_arrow = arrow
|
||||
self.add(arrow)
|
||||
self.add(self.straight_arrow)
|
||||
else:
|
||||
# Figure out the direction of the arc
|
||||
direction_vector = NetworkConnection.direction_vector_map[
|
||||
self.arc_direction
|
||||
]
|
||||
# Based on the position of the start and end layer, and direction
|
||||
# figure out how large to make each line
|
||||
# Whichever mobject has a critical point the farthest
|
||||
# distance in the direction_vector direction we will use that end
|
||||
left_mobject_critical_point = left_mobject.get_critical_point(direction_vector)
|
||||
right_mobject_critical_point = right_mobject.get_critical_point(direction_vector)
|
||||
# Take the dot product of each
|
||||
# These dot products correspond to the orthogonal projection
|
||||
# onto the direction vectors
|
||||
left_dot_product = np.dot(
|
||||
left_mobject_critical_point,
|
||||
direction_vector
|
||||
)
|
||||
right_dot_product = np.dot(
|
||||
right_mobject_critical_point,
|
||||
direction_vector
|
||||
)
|
||||
extra_distance = abs(left_dot_product - right_dot_product)
|
||||
# The difference between the dot products
|
||||
if left_dot_product < right_dot_product:
|
||||
right_is_farthest = False
|
||||
else:
|
||||
right_is_farthest = True
|
||||
# Make the start arc piece
|
||||
start_line_start = left_mobject.get_critical_point(direction_vector)
|
||||
start_line_start += direction_vector * self.buffer
|
||||
start_line_end = start_line_start + direction_vector * self.arc_distance
|
||||
if not right_is_farthest:
|
||||
start_line_end = start_line_end + direction_vector * extra_distance
|
||||
self.start_line = Line(
|
||||
start_line_start,
|
||||
start_line_end,
|
||||
@ -93,6 +119,9 @@ class NetworkConnection(VGroup):
|
||||
end_line_end = right_mobject.get_critical_point(direction_vector)
|
||||
end_line_end += direction_vector * self.buffer
|
||||
end_line_start = end_line_end + direction_vector * self.arc_distance
|
||||
if right_is_farthest:
|
||||
end_line_start = end_line_start + direction_vector * extra_distance
|
||||
|
||||
self.end_arrow = Arrow(
|
||||
start=end_line_start,
|
||||
end=end_line_end,
|
||||
|
38
tests/test_ff_residual.py
Normal file
38
tests/test_ff_residual.py
Normal file
@ -0,0 +1,38 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network import NeuralNetwork, FeedForwardLayer, MathOperationLayer
|
||||
|
||||
# Render settings for this test scene.
# Output resolution, in pixels.
config.pixel_height = 1200
config.pixel_width = 1900
# Dimensions of the visible frame, in scene units.
config.frame_height = 6.0
config.frame_width = 6.0
|
||||
|
||||
class TestFeedForwardResidualNetwork(Scene):
    """Scene that renders a feed forward network with a residual connection."""

    def construct(self):
        """Build the network, add its connections, and play a forward pass."""
        # Add the network
        nn = NeuralNetwork(
            {
                "feed_forward_1": FeedForwardLayer(3),
                "feed_forward_2": FeedForwardLayer(3, activation_function="ReLU"),
                "feed_forward_3": FeedForwardLayer(3),
                "sum_operation": MathOperationLayer("+", activation_function="ReLU"),
            },
            layer_spacing=0.38,
        )
        self.add(nn)
        # Make connections
        # Placeholder dot 0.65 units left of the first layer; it gives the
        # incoming arrow a start point.
        input_blank_dot = Dot(
            nn.input_layers_dict["feed_forward_1"].get_left() - np.array([0.65, 0.0, 0.0])
        )
        nn.add_connection(input_blank_dot, "feed_forward_1", arc_direction="straight")
        # Residual (skip) connection from the first layer to the sum node.
        nn.add_connection("feed_forward_1", "sum_operation")
        # Placeholder dot 0.65 units right of the sum node; it gives the
        # outgoing arrow an end point.
        output_blank_dot = Dot(
            nn.input_layers_dict["sum_operation"].get_right() + np.array([0.65, 0.0, 0.0])
        )
        nn.add_connection("sum_operation", output_blank_dot, arc_direction="straight")

        # Make forward pass animation
        self.play(
            nn.make_forward_pass_animation()
        )
|
||||
|
Reference in New Issue
Block a user