Added the ability to make residual connections.

Note: still need to add the residual plus icon.
This commit is contained in:
Alec Helbling
2023-02-01 12:40:43 -05:00
parent 4b06ce1622
commit 27d235de25
8 changed files with 358 additions and 4 deletions

View File

@ -0,0 +1,66 @@
from manim import *
from PIL import Image
import numpy as np
from manim_ml.neural_network import Convolutional2DLayer, NeuralNetwork
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1900
config.frame_height = 6.0
config.frame_width = 6.0
def make_code_snippet():
code_str = """
# Make the neural network
nn = NeuralNetwork({
"layer1": Convolutional2DLayer(1, 5, padding=1),
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
"layer3": Convolutional2DLayer(1, 5, 3, padding=1)
})
# Add the residual connection
nn.add_connection("layer1", "layer3")
# Make the animation
self.play(nn.make_forward_pass_animation())
"""
code = Code(
code=code_str,
tab_width=4,
background_stroke_width=1,
background_stroke_color=WHITE,
insert_line_no=False,
style="monokai",
# background="window",
language="py",
)
code.scale(0.38)
return code
class ConvScene(ThreeDScene):
def construct(self):
image = Image.open("../../assets/mnist/digit.jpeg")
numpy_image = np.asarray(image)
nn = NeuralNetwork({
"layer1": Convolutional2DLayer(1, 5, padding=1),
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
"layer3": Convolutional2DLayer(1, 5, 3, padding=1),
},
layer_spacing=0.25,
)
nn.add_connection("layer1", "layer3")
self.add(nn)
code = make_code_snippet()
code.next_to(nn, DOWN)
self.add(code)
Group(code, nn).move_to(ORIGIN)
self.play(
nn.make_forward_pass_animation(),
run_time=8
)

View File

@ -11,6 +11,7 @@ Example:
"""
import textwrap
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
from manim_ml.utils.mobjects.connections import NetworkConnection
import numpy as np
from manim import *
@ -38,7 +39,8 @@ class NeuralNetwork(Group):
layout_direction="left_to_right",
):
super(Group, self).__init__()
self.input_layers = ListGroup(*input_layers)
self.input_layers_dict = self.make_input_layers_dict(input_layers)
self.input_layers = ListGroup(*self.input_layers_dict.values())
self.edge_width = edge_width
self.edge_color = edge_color
self.layer_spacing = layer_spacing
@ -69,9 +71,46 @@ class NeuralNetwork(Group):
# Center the whole diagram by default
self.all_layers.move_to(ORIGIN)
self.add(self.all_layers)
# Make container for connections
self.connections = []
# Print neural network
print(repr(self))
def make_input_layers_dict(self, input_layers):
"""Make dictionary of input layers"""
if isinstance(input_layers, dict):
# If input layers is dictionary then return it
return input_layers
elif isinstance(input_layers, list):
# If input layers is a list then make a dictionary with default
return_dict = {}
for layer_index, input_layer in enumerate(input_layers):
return_dict[f"layer{layer_index}"] = input_layer
return return_dict
else:
raise Exception(f"Uncrecognized input layers type: {type(input_layers)}")
def add_connection(
self,
start_layer_name,
end_layer_name,
connection_style="default",
connection_position="bottom"
):
"""Add connection from start layer to end layer"""
assert connection_style in ["default"]
if connection_style == "default":
# Make arrow connection from start layer to end layer
# Add the connection
connection = NetworkConnection(
self.input_layers_dict[start_layer_name],
self.input_layers_dict[end_layer_name],
arc_direction="down" # TODO generalize this more
)
self.connections.append(connection)
self.add(connection)
def _construct_input_layers(self):
"""Constructs each of the input layers in context
of their adjacent layers"""
@ -220,7 +259,27 @@ class NeuralNetwork(Group):
current_layer_args = layer_args[layer]
# Perform the forward pass of the current layer
layer_forward_pass = layer.make_forward_pass_animation(
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
layer_args=current_layer_args,
run_time=per_layer_runtime,
**kwargs
)
# Animate a forward pass for incoming connections
connection_input_pass = AnimationGroup()
for connection in self.connections:
if isinstance(layer, ConnectiveLayer):
output_layer = layer.output_layer
if connection.end_mobject == output_layer:
connection_input_pass = ShowPassingFlash(
connection,
run_time=layer_forward_pass.run_time,
time_width=0.2
)
break
layer_forward_pass = AnimationGroup(
layer_forward_pass,
connection_input_pass,
lag_ratio=0.0
)
all_animations.append(layer_forward_pass)
# Make the animation group

View File

View File

View File

@ -0,0 +1,158 @@
import numpy as np
from manim import *
class NetworkConnection(VGroup):
"""
This class allows for creating connections
between locations in a network
"""
direction_vector_map = {
"up": UP,
"down": DOWN,
"left": LEFT,
"right": RIGHT
}
def __init__(
self,
start_mobject,
end_mobject,
arc_direction="straight",
buffer=0.05,
arc_distance=0.3,
stroke_width=2.0,
color=WHITE,
active_color=ORANGE
):
"""Creates an arrow with right angles in it connecting
two mobjects.
Parameters
----------
start_mobject : Mobject
Mobject where the start of the connection is from
end_mobject : Mobject
Mobject where the end of the connection goes to
arc_direction : str, optional
direction that the connection arcs, by default "straight"
buffer : float, optional
amount of space between the connection and mobjects at the end
arc_distance : float, optional
Distance from start and end mobject that the arc bends
stroke_width : float, optional
Stroke width of the connection
color : [float], optional
Color of the connection
active_color : [float], optional
Color of active animations for this mobject
"""
super().__init__()
assert arc_direction in ["straight", "up", "down", "left", "right"]
self.start_mobject = start_mobject
self.end_mobject = end_mobject
self.arc_direction = arc_direction
self.buffer = buffer
self.arc_distance = arc_distance
self.stroke_width = stroke_width
self.color = color
self.active_color = active_color
self.make_mobjects()
def make_mobjects(self):
"""Makes the submobjects"""
if self.start_mobject.get_center()[0] < self.end_mobject.get_center()[0]:
left_mobject = self.start_mobject
right_mobject = self.end_mobject
else:
right_mobject = self.start_mobject
left_mobject = self.end_mobject
if self.arc_direction == "straight":
# Make an arrow
arrow_line = Line(
left_mobject.get_right() + np.array([self.buffer, 0.0, 0.0]),
right_mobject.get_left() + np.array([-1 * self.buffer, 0.0, 0.0])
)
arrow = Arrow(
arrow_line,
color=self.color,
stroke_width=self.stroke_width
)
self.straight_arrow = arrow
self.add(arrow)
else:
# Figure out the direction of the arc
direction_vector = NetworkConnection.direction_vector_map[self.arc_direction]
# Make the start arc piece
start_line_start = left_mobject.get_critical_point(
direction_vector
)
start_line_start += direction_vector * self.buffer
start_line_end = start_line_start + direction_vector * self.arc_distance
self.start_line = Line(
start_line_start,
start_line_end,
color=self.color,
stroke_width=self.stroke_width
)
# Make the end arc piece with an arrow
end_line_end = right_mobject.get_critical_point(
direction_vector
)
end_line_end += direction_vector * self.buffer
end_line_start = end_line_end + direction_vector * self.arc_distance
self.end_arrow = Arrow(
start=end_line_start,
end=end_line_end,
color=WHITE,
fill_color=WHITE,
stroke_opacity=1.0,
buff=0.0
)
# Make the middle arc piece
self.middle_line = Line(
start_line_end,
end_line_start,
color=self.color,
stroke_width=self.stroke_width
)
# Add the mobjects
self.add(
self.start_line,
self.middle_line,
self.end_arrow,
)
@override_animation(ShowPassingFlash)
def _override_passing_flash(self, run_time=1.0, time_width=0.2):
"""Passing flash animation"""
if self.arc_direction == "straight":
return ShowPassingFlash(
self.straight_arrow.copy().set_color(self.active_color),
time_width=time_width
)
else:
# Animate the start line
start_line_animation = ShowPassingFlash(
self.start_line.copy().set_color(self.active_color),
time_width=time_width
)
# Animate the middle line
middle_line_animation = ShowPassingFlash(
self.middle_line.copy().set_color(self.active_color),
time_width=time_width
)
# Animate the end line
end_line_animation = ShowPassingFlash(
self.end_arrow.copy().set_color(self.active_color),
time_width=time_width
)
return AnimationGroup(
start_line_animation,
middle_line_animation,
end_line_animation,
lag_ratio=1.0,
run_time=run_time
)

View File

@ -26,3 +26,5 @@ class FeedForwardScene(Scene):
])
self.add(nn)
self.play(nn.make_forward_pass_animation())

View File

@ -0,0 +1,69 @@
from manim import *
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
from manim_ml.utils.testing.frames_comparison import frames_comparison
from manim_ml.neural_network import NeuralNetwork, FeedForwardLayer, ImageLayer
from PIL import Image
import numpy as np
__module_test__ = "residual"
@frames_comparison
def test_ResidualConnectionScene(scene):
"""Tests the appearance of a residual connection"""
nn = NeuralNetwork({
"layer1": FeedForwardLayer(3),
"layer2": FeedForwardLayer(5),
"layer3": FeedForwardLayer(3)
})
scene.add(nn)
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1900
config.frame_height = 6.0
config.frame_width = 6.0
class FeedForwardScene(Scene):
def construct(self):
nn = NeuralNetwork({
"layer1": FeedForwardLayer(4),
"layer2": FeedForwardLayer(4),
"layer3": FeedForwardLayer(4)
},
layer_spacing=0.45)
nn.add_connection("layer1", "layer3")
self.add(nn)
self.play(
nn.make_forward_pass_animation(),
run_time=8
)
class ConvScene(ThreeDScene):
def construct(self):
image = Image.open("../assets/mnist/digit.jpeg")
numpy_image = np.asarray(image)
nn = NeuralNetwork({
"layer1": Convolutional2DLayer(1, 5, padding=1),
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
"layer3": Convolutional2DLayer(1, 5, 3, padding=1),
},
layer_spacing=0.25,
)
nn.add_connection("layer1", "layer3")
self.add(nn)
self.play(
nn.make_forward_pass_animation(),
run_time=8
)