Added residual layer example. Fixed some bugs in the process.

This commit is contained in:
Alec Helbling
2023-02-05 12:26:36 -05:00
parent 2b21261db7
commit b1c838a45f
11 changed files with 382 additions and 18 deletions

View File

@ -0,0 +1,74 @@
from manim import *
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1900
config.frame_height = 6.0
config.frame_width = 6.0
def make_code_snippet():
code_str = """
nn = NeuralNetwork({
"feed_forward_1": FeedForwardLayer(3),
"feed_forward_2": FeedForwardLayer(3, activation_function="ReLU"),
"feed_forward_3": FeedForwardLayer(3),
"sum_operation": MathOperationLayer("+", activation_function="ReLU"),
})
nn.add_connection("feed_forward_1", "sum_operation")
self.play(nn.make_forward_pass_animation())
"""
code = Code(
code=code_str,
tab_width=4,
background_stroke_width=1,
background_stroke_color=WHITE,
insert_line_no=False,
style="monokai",
# background="window",
language="py",
)
code.scale(0.38)
return code
class CombinedScene(ThreeDScene):
def construct(self):
# Add the network
nn = NeuralNetwork({
"feed_forward_1": FeedForwardLayer(3),
"feed_forward_2": FeedForwardLayer(3, activation_function="ReLU"),
"feed_forward_3": FeedForwardLayer(3),
"sum_operation": MathOperationLayer("+", activation_function="ReLU"),
},
layer_spacing=0.38
)
# Make connections
input_blank_dot = Dot(
nn.input_layers_dict["feed_forward_1"].get_left() - np.array([0.65, 0.0, 0.0])
)
nn.add_connection(input_blank_dot, "feed_forward_1", arc_direction="straight")
nn.add_connection("feed_forward_1", "sum_operation")
output_blank_dot = Dot(
nn.input_layers_dict["sum_operation"].get_right() + np.array([0.65, 0.0, 0.0])
)
nn.add_connection("sum_operation", output_blank_dot, arc_direction="straight")
# Center the nn
nn.move_to(ORIGIN)
self.add(nn)
# Make code snippet
code = make_code_snippet()
code.next_to(nn, DOWN)
self.add(code)
# Group it all
group = Group(nn, code)
group.move_to(ORIGIN)
# Play animation
forward_pass = nn.make_forward_pass_animation()
self.wait(1)
self.play(forward_pass)

View File

@ -39,3 +39,4 @@ from manim_ml.neural_network.layers.paired_query import PairedQueryLayer
from manim_ml.neural_network.layers.triplet_to_feed_forward import TripletToFeedForward
from manim_ml.neural_network.layers.triplet import TripletLayer
from manim_ml.neural_network.layers.vector import VectorLayer
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer

View File

@ -31,6 +31,7 @@ from .triplet_to_feed_forward import TripletToFeedForward
from .paired_query import PairedQueryLayer
from .paired_query_to_feed_forward import PairedQueryToFeedForward
from .max_pooling_2d import MaxPooling2DLayer
from .feed_forward_to_math_operation import FeedForwardToMathOperation
connective_layers_list = (
EmbeddingToFeedForward,
@ -48,4 +49,5 @@ connective_layers_list = (
Convolutional2DToMaxPooling2D,
MaxPooling2DToConvolutional2D,
MaxPooling2DToFeedForward,
FeedForwardToMathOperation
)

View File

@ -272,6 +272,16 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
"""Overrides get height function"""
return self.feature_maps.length_over_dim(1)
def move_to(self, mobject_or_point):
"""Moves the center of the layer to the given mobject or point"""
layer_center = self.feature_maps.get_center()
if isinstance(mobject_or_point, Mobject):
target_center = mobject_or_point.get_center()
else:
target_center = mobject_or_point
self.shift(target_center - layer_center)
@override_animation(Create)
def _create_override(self, **kwargs):
return FadeIn(self.feature_maps)

View File

@ -153,3 +153,25 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
animation_group = AnimationGroup(*animations, lag_ratio=0.0)
return animation_group
def get_height(self):
return self.surrounding_rectangle.get_height()
def get_center(self):
return self.surrounding_rectangle.get_center()
def get_left(self):
return self.surrounding_rectangle.get_left()
def get_right(self):
return self.surrounding_rectangle.get_right()
def move_to(self, mobject_or_point):
"""Moves the center of the layer to the given mobject or point"""
layer_center = self.surrounding_rectangle.get_center()
if isinstance(mobject_or_point, Mobject):
target_center = mobject_or_point.get_center()
else:
target_center = mobject_or_point
self.shift(target_center - layer_center)

View File

@ -90,7 +90,7 @@ class FeedForwardToFeedForward(ConnectiveLayer):
if self.passing_flash:
copy_edge = edge.copy()
anim = ShowPassingFlash(
copy_edge.set_color(self.animation_dot_color), time_width=0.2
copy_edge.set_color(self.animation_dot_color), time_width=0.3
)
else:
anim = MoveAlongPath(

View File

@ -0,0 +1,48 @@
from manim import *
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer
from manim_ml.utils.mobjects.connections import NetworkConnection
class FeedForwardToMathOperation(ConnectiveLayer):
"""Image Layer to FeedForward layer"""
input_class = FeedForwardLayer
output_class = MathOperationLayer
def __init__(
self,
input_layer,
output_layer,
active_color=ORANGE,
**kwargs
):
self.active_color = active_color
super().__init__(input_layer, output_layer, **kwargs)
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
# Draw an arrow from the output of the feed forward layer to the
# input of the math operation layer
self.connection = NetworkConnection(
self.input_layer,
self.output_layer,
arc_direction="straight",
buffer=0.05
)
self.add(self.connection)
return super().construct_layer(input_layer, output_layer, **kwargs)
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
# Make flashing pass animation on arrow
passing_flash = ShowPassingFlash(
self.connection.copy().set_color(self.active_color)
)
return passing_flash

View File

@ -0,0 +1,129 @@
from manim import *
from manim_ml.neural_network.activation_functions import get_activation_function_by_name
from manim_ml.neural_network.activation_functions.activation_function import (
ActivationFunction,
)
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
class MathOperationLayer(VGroupNeuralNetworkLayer):
"""Handles rendering a layer for a neural network"""
valid_operations = ["+", "-", "*", "/"]
def __init__(
self,
operation_type: str,
node_radius=0.5,
node_color=BLUE,
node_stroke_width=2.0,
active_color=ORANGE,
activation_function=None,
font_size=20,
**kwargs
):
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
# Ensure operation type is valid
assert operation_type in MathOperationLayer.valid_operations
self.operation_type = operation_type
self.node_radius = node_radius
self.node_color = node_color
self.node_stroke_width = node_stroke_width
self.active_color = active_color
self.font_size = font_size
self.activation_function = activation_function
def construct_layer(
self,
input_layer: "NeuralNetworkLayer",
output_layer: "NeuralNetworkLayer",
**kwargs
):
"""Creates the neural network layer"""
# Draw the operation
self.operation_text = Text(
self.operation_type,
font_size=self.font_size
)
self.add(self.operation_text)
# Make the surrounding circle
self.surrounding_circle = Circle(
color=self.node_color,
stroke_width=self.node_stroke_width
).surround(self.operation_text)
self.add(self.surrounding_circle)
# Make the activation function
self.construct_activation_function()
super().construct_layer(input_layer, output_layer, **kwargs)
def construct_activation_function(self):
"""Construct the activation function"""
# Add the activation function
if not self.activation_function is None:
# Check if it is a string
if isinstance(self.activation_function, str):
activation_function = get_activation_function_by_name(
self.activation_function
)()
else:
assert isinstance(self.activation_function, ActivationFunction)
activation_function = self.activation_function
# Plot the function above the rest of the layer
self.activation_function = activation_function
self.add(self.activation_function)
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Makes the forward pass animation
Parameters
----------
layer_args : dict, optional
layer specific arguments, by default {}
Returns
-------
AnimationGroup
Forward pass animation
"""
# Make highlight animation
succession = Succession(
ApplyMethod(
self.surrounding_circle.set_color,
self.active_color,
run_time=0.25
),
Wait(1.0),
ApplyMethod(
self.surrounding_circle.set_color,
self.node_color,
run_time=0.25
),
)
# Animate the activation function
if not self.activation_function is None:
animation_group = AnimationGroup(
succession,
self.activation_function.make_evaluate_animation(),
lag_ratio=0.0,
)
return animation_group
else:
return succession
def get_center(self):
return self.surrounding_circle.get_center()
def get_left(self):
return self.surrounding_circle.get_left()
def get_right(self):
return self.surrounding_circle.get_right()
def move_to(self, mobject_or_point):
"""Moves the center of the layer to the given mobject or point"""
layer_center = self.surrounding_circle.get_center()
if isinstance(mobject_or_point, Mobject):
target_center = mobject_or_point.get_center()
else:
target_center = mobject_or_point
self.shift(target_center - layer_center)

View File

@ -93,20 +93,31 @@ class NeuralNetwork(Group):
def add_connection(
self,
start_layer_name,
end_layer_name,
start_mobject_or_name,
end_mobject_or_name,
connection_style="default",
connection_position="bottom",
arc_direction="down"
):
"""Add connection from start layer to end layer"""
assert connection_style in ["default"]
if connection_style == "default":
# Make arrow connection from start layer to end layer
# Add the connection
if isinstance(start_mobject_or_name, Mobject):
input_mobject = start_mobject_or_name
else:
input_mobject = self.input_layers_dict[start_mobject_or_name]
if isinstance(end_mobject_or_name, Mobject):
output_mobject = end_mobject_or_name
else:
output_mobject = self.input_layers_dict[end_mobject_or_name]
connection = NetworkConnection(
self.input_layers_dict[start_layer_name],
self.input_layers_dict[end_layer_name],
arc_direction="down", # TODO generalize this more
input_mobject,
output_mobject,
arc_direction=arc_direction,
buffer=0.05
)
self.connections.append(connection)
self.add(connection)
@ -243,7 +254,7 @@ class NeuralNetwork(Group):
):
"""Generates an animation for feed forward propagation"""
all_animations = []
per_layer_animations = {}
per_layer_animation_map = {}
per_layer_runtime = (
run_time / len(self.all_layers) if not run_time is None else None
)
@ -297,11 +308,11 @@ class NeuralNetwork(Group):
)
all_animations.append(layer_forward_pass)
# Add the animation to per layer animation
per_layer_animations[layer] = layer_forward_pass
per_layer_animation_map[layer] = layer_forward_pass
# Make the animation group
animation_group = Succession(*all_animations, lag_ratio=1.0)
if per_layer_animations:
return per_layer_animations
return per_layer_animation_map
else:
return animation_group

View File

@ -15,8 +15,8 @@ class NetworkConnection(VGroup):
start_mobject,
end_mobject,
arc_direction="straight",
buffer=0.05,
arc_distance=0.3,
buffer=0.0,
arc_distance=0.2,
stroke_width=2.0,
color=WHITE,
active_color=ORANGE,
@ -64,25 +64,51 @@ class NetworkConnection(VGroup):
else:
right_mobject = self.start_mobject
left_mobject = self.end_mobject
if self.arc_direction == "straight":
# Make an arrow
arrow_line = Line(
left_mobject.get_right() + np.array([self.buffer, 0.0, 0.0]),
right_mobject.get_left() + np.array([-1 * self.buffer, 0.0, 0.0]),
self.straight_arrow = Arrow(
start=left_mobject.get_right() + np.array([self.buffer, 0.0, 0.0]),
end=right_mobject.get_left() + np.array([-1 * self.buffer, 0.0, 0.0]),
color=WHITE,
fill_color=WHITE,
stroke_opacity=1.0,
buff=0.0,
)
arrow = Arrow(arrow_line, color=self.color, stroke_width=self.stroke_width)
self.straight_arrow = arrow
self.add(arrow)
self.add(self.straight_arrow)
else:
# Figure out the direction of the arc
direction_vector = NetworkConnection.direction_vector_map[
self.arc_direction
]
# Based on the position of the start and end layer, and direction
# figure out how large to make each line
# Whichever mobject has a critical point the farthest
# distance in the direction_vector direction we will use that end
left_mobject_critical_point = left_mobject.get_critical_point(direction_vector)
right_mobject_critical_point = right_mobject.get_critical_point(direction_vector)
# Take the dot product of each
# These dot products correspond to the orthogonal projection
# onto the direction vectors
left_dot_product = np.dot(
left_mobject_critical_point,
direction_vector
)
right_dot_product = np.dot(
right_mobject_critical_point,
direction_vector
)
extra_distance = abs(left_dot_product - right_dot_product)
# The difference between the dot products
if left_dot_product < right_dot_product:
right_is_farthest = False
else:
right_is_farthest = True
# Make the start arc piece
start_line_start = left_mobject.get_critical_point(direction_vector)
start_line_start += direction_vector * self.buffer
start_line_end = start_line_start + direction_vector * self.arc_distance
if not right_is_farthest:
start_line_end = start_line_end + direction_vector * extra_distance
self.start_line = Line(
start_line_start,
start_line_end,
@ -93,6 +119,9 @@ class NetworkConnection(VGroup):
end_line_end = right_mobject.get_critical_point(direction_vector)
end_line_end += direction_vector * self.buffer
end_line_start = end_line_end + direction_vector * self.arc_distance
if right_is_farthest:
end_line_start = end_line_start + direction_vector * extra_distance
self.end_arrow = Arrow(
start=end_line_start,
end=end_line_end,

38
tests/test_ff_residual.py Normal file
View File

@ -0,0 +1,38 @@
from manim import *
from manim_ml.neural_network import NeuralNetwork, FeedForwardLayer, MathOperationLayer
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1900
config.frame_height = 6.0
config.frame_width = 6.0
class TestFeedForwardResidualNetwork(Scene):
def construct(self):
# Add the network
nn = NeuralNetwork({
"feed_forward_1": FeedForwardLayer(3),
"feed_forward_2": FeedForwardLayer(3, activation_function="ReLU"),
"feed_forward_3": FeedForwardLayer(3),
"sum_operation": MathOperationLayer("+", activation_function="ReLU"),
},
layer_spacing=0.38
)
self.add(nn)
# Make connections
input_blank_dot = Dot(
nn.input_layers_dict["feed_forward_1"].get_left() - np.array([0.65, 0.0, 0.0])
)
nn.add_connection(input_blank_dot, "feed_forward_1", arc_direction="straight")
nn.add_connection("feed_forward_1", "sum_operation")
output_blank_dot = Dot(
nn.input_layers_dict["sum_operation"].get_right() + np.array([0.65, 0.0, 0.0])
)
nn.add_connection("sum_operation", output_blank_dot, arc_direction="straight")
# Make forward pass animation
self.play(
nn.make_forward_pass_animation()
)