mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-26 13:22:08 +08:00
Added the ability to make residual connections.
Note: still need to add the residual plus icon.
This commit is contained in:
66
examples/cnn/resnet_block.py
Normal file
66
examples/cnn/resnet_block.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from manim import *
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from manim_ml.neural_network import Convolutional2DLayer, NeuralNetwork
|
||||||
|
|
||||||
|
# Make the specific scene
|
||||||
|
config.pixel_height = 1200
|
||||||
|
config.pixel_width = 1900
|
||||||
|
config.frame_height = 6.0
|
||||||
|
config.frame_width = 6.0
|
||||||
|
|
||||||
|
def make_code_snippet():
|
||||||
|
code_str = """
|
||||||
|
# Make the neural network
|
||||||
|
nn = NeuralNetwork({
|
||||||
|
"layer1": Convolutional2DLayer(1, 5, padding=1),
|
||||||
|
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||||
|
"layer3": Convolutional2DLayer(1, 5, 3, padding=1)
|
||||||
|
})
|
||||||
|
# Add the residual connection
|
||||||
|
nn.add_connection("layer1", "layer3")
|
||||||
|
# Make the animation
|
||||||
|
self.play(nn.make_forward_pass_animation())
|
||||||
|
"""
|
||||||
|
|
||||||
|
code = Code(
|
||||||
|
code=code_str,
|
||||||
|
tab_width=4,
|
||||||
|
background_stroke_width=1,
|
||||||
|
background_stroke_color=WHITE,
|
||||||
|
insert_line_no=False,
|
||||||
|
style="monokai",
|
||||||
|
# background="window",
|
||||||
|
language="py",
|
||||||
|
)
|
||||||
|
code.scale(0.38)
|
||||||
|
|
||||||
|
return code
|
||||||
|
|
||||||
|
class ConvScene(ThreeDScene):
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
image = Image.open("../../assets/mnist/digit.jpeg")
|
||||||
|
numpy_image = np.asarray(image)
|
||||||
|
|
||||||
|
nn = NeuralNetwork({
|
||||||
|
"layer1": Convolutional2DLayer(1, 5, padding=1),
|
||||||
|
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||||
|
"layer3": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||||
|
},
|
||||||
|
layer_spacing=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
nn.add_connection("layer1", "layer3")
|
||||||
|
|
||||||
|
self.add(nn)
|
||||||
|
|
||||||
|
code = make_code_snippet()
|
||||||
|
code.next_to(nn, DOWN)
|
||||||
|
self.add(code)
|
||||||
|
Group(code, nn).move_to(ORIGIN)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
nn.make_forward_pass_animation(),
|
||||||
|
run_time=8
|
||||||
|
)
|
@ -11,6 +11,7 @@ Example:
|
|||||||
"""
|
"""
|
||||||
import textwrap
|
import textwrap
|
||||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||||
|
from manim_ml.utils.mobjects.connections import NetworkConnection
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from manim import *
|
from manim import *
|
||||||
|
|
||||||
@ -38,7 +39,8 @@ class NeuralNetwork(Group):
|
|||||||
layout_direction="left_to_right",
|
layout_direction="left_to_right",
|
||||||
):
|
):
|
||||||
super(Group, self).__init__()
|
super(Group, self).__init__()
|
||||||
self.input_layers = ListGroup(*input_layers)
|
self.input_layers_dict = self.make_input_layers_dict(input_layers)
|
||||||
|
self.input_layers = ListGroup(*self.input_layers_dict.values())
|
||||||
self.edge_width = edge_width
|
self.edge_width = edge_width
|
||||||
self.edge_color = edge_color
|
self.edge_color = edge_color
|
||||||
self.layer_spacing = layer_spacing
|
self.layer_spacing = layer_spacing
|
||||||
@ -69,9 +71,46 @@ class NeuralNetwork(Group):
|
|||||||
# Center the whole diagram by default
|
# Center the whole diagram by default
|
||||||
self.all_layers.move_to(ORIGIN)
|
self.all_layers.move_to(ORIGIN)
|
||||||
self.add(self.all_layers)
|
self.add(self.all_layers)
|
||||||
|
# Make container for connections
|
||||||
|
self.connections = []
|
||||||
# Print neural network
|
# Print neural network
|
||||||
print(repr(self))
|
print(repr(self))
|
||||||
|
|
||||||
|
def make_input_layers_dict(self, input_layers):
|
||||||
|
"""Make dictionary of input layers"""
|
||||||
|
if isinstance(input_layers, dict):
|
||||||
|
# If input layers is dictionary then return it
|
||||||
|
return input_layers
|
||||||
|
elif isinstance(input_layers, list):
|
||||||
|
# If input layers is a list then make a dictionary with default
|
||||||
|
return_dict = {}
|
||||||
|
for layer_index, input_layer in enumerate(input_layers):
|
||||||
|
return_dict[f"layer{layer_index}"] = input_layer
|
||||||
|
|
||||||
|
return return_dict
|
||||||
|
else:
|
||||||
|
raise Exception(f"Uncrecognized input layers type: {type(input_layers)}")
|
||||||
|
|
||||||
|
def add_connection(
|
||||||
|
self,
|
||||||
|
start_layer_name,
|
||||||
|
end_layer_name,
|
||||||
|
connection_style="default",
|
||||||
|
connection_position="bottom"
|
||||||
|
):
|
||||||
|
"""Add connection from start layer to end layer"""
|
||||||
|
assert connection_style in ["default"]
|
||||||
|
if connection_style == "default":
|
||||||
|
# Make arrow connection from start layer to end layer
|
||||||
|
# Add the connection
|
||||||
|
connection = NetworkConnection(
|
||||||
|
self.input_layers_dict[start_layer_name],
|
||||||
|
self.input_layers_dict[end_layer_name],
|
||||||
|
arc_direction="down" # TODO generalize this more
|
||||||
|
)
|
||||||
|
self.connections.append(connection)
|
||||||
|
self.add(connection)
|
||||||
|
|
||||||
def _construct_input_layers(self):
|
def _construct_input_layers(self):
|
||||||
"""Constructs each of the input layers in context
|
"""Constructs each of the input layers in context
|
||||||
of their adjacent layers"""
|
of their adjacent layers"""
|
||||||
@ -220,7 +259,27 @@ class NeuralNetwork(Group):
|
|||||||
current_layer_args = layer_args[layer]
|
current_layer_args = layer_args[layer]
|
||||||
# Perform the forward pass of the current layer
|
# Perform the forward pass of the current layer
|
||||||
layer_forward_pass = layer.make_forward_pass_animation(
|
layer_forward_pass = layer.make_forward_pass_animation(
|
||||||
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
|
layer_args=current_layer_args,
|
||||||
|
run_time=per_layer_runtime,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
# Animate a forward pass for incoming connections
|
||||||
|
connection_input_pass = AnimationGroup()
|
||||||
|
for connection in self.connections:
|
||||||
|
if isinstance(layer, ConnectiveLayer):
|
||||||
|
output_layer = layer.output_layer
|
||||||
|
if connection.end_mobject == output_layer:
|
||||||
|
connection_input_pass = ShowPassingFlash(
|
||||||
|
connection,
|
||||||
|
run_time=layer_forward_pass.run_time,
|
||||||
|
time_width=0.2
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
layer_forward_pass = AnimationGroup(
|
||||||
|
layer_forward_pass,
|
||||||
|
connection_input_pass,
|
||||||
|
lag_ratio=0.0
|
||||||
)
|
)
|
||||||
all_animations.append(layer_forward_pass)
|
all_animations.append(layer_forward_pass)
|
||||||
# Make the animation group
|
# Make the animation group
|
||||||
|
0
manim_ml/utils/__init__.py
Normal file
0
manim_ml/utils/__init__.py
Normal file
0
manim_ml/utils/mobjects/__init__.py
Normal file
0
manim_ml/utils/mobjects/__init__.py
Normal file
158
manim_ml/utils/mobjects/connections.py
Normal file
158
manim_ml/utils/mobjects/connections.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
import numpy as np
|
||||||
|
from manim import *
|
||||||
|
|
||||||
|
class NetworkConnection(VGroup):
|
||||||
|
"""
|
||||||
|
This class allows for creating connections
|
||||||
|
between locations in a network
|
||||||
|
"""
|
||||||
|
direction_vector_map = {
|
||||||
|
"up": UP,
|
||||||
|
"down": DOWN,
|
||||||
|
"left": LEFT,
|
||||||
|
"right": RIGHT
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
start_mobject,
|
||||||
|
end_mobject,
|
||||||
|
arc_direction="straight",
|
||||||
|
buffer=0.05,
|
||||||
|
arc_distance=0.3,
|
||||||
|
stroke_width=2.0,
|
||||||
|
color=WHITE,
|
||||||
|
active_color=ORANGE
|
||||||
|
):
|
||||||
|
"""Creates an arrow with right angles in it connecting
|
||||||
|
two mobjects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
start_mobject : Mobject
|
||||||
|
Mobject where the start of the connection is from
|
||||||
|
end_mobject : Mobject
|
||||||
|
Mobject where the end of the connection goes to
|
||||||
|
arc_direction : str, optional
|
||||||
|
direction that the connection arcs, by default "straight"
|
||||||
|
buffer : float, optional
|
||||||
|
amount of space between the connection and mobjects at the end
|
||||||
|
arc_distance : float, optional
|
||||||
|
Distance from start and end mobject that the arc bends
|
||||||
|
stroke_width : float, optional
|
||||||
|
Stroke width of the connection
|
||||||
|
color : [float], optional
|
||||||
|
Color of the connection
|
||||||
|
active_color : [float], optional
|
||||||
|
Color of active animations for this mobject
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert arc_direction in ["straight", "up", "down", "left", "right"]
|
||||||
|
self.start_mobject = start_mobject
|
||||||
|
self.end_mobject = end_mobject
|
||||||
|
self.arc_direction = arc_direction
|
||||||
|
self.buffer = buffer
|
||||||
|
self.arc_distance = arc_distance
|
||||||
|
self.stroke_width = stroke_width
|
||||||
|
self.color = color
|
||||||
|
self.active_color = active_color
|
||||||
|
|
||||||
|
self.make_mobjects()
|
||||||
|
|
||||||
|
def make_mobjects(self):
|
||||||
|
"""Makes the submobjects"""
|
||||||
|
if self.start_mobject.get_center()[0] < self.end_mobject.get_center()[0]:
|
||||||
|
left_mobject = self.start_mobject
|
||||||
|
right_mobject = self.end_mobject
|
||||||
|
else:
|
||||||
|
right_mobject = self.start_mobject
|
||||||
|
left_mobject = self.end_mobject
|
||||||
|
|
||||||
|
if self.arc_direction == "straight":
|
||||||
|
# Make an arrow
|
||||||
|
arrow_line = Line(
|
||||||
|
left_mobject.get_right() + np.array([self.buffer, 0.0, 0.0]),
|
||||||
|
right_mobject.get_left() + np.array([-1 * self.buffer, 0.0, 0.0])
|
||||||
|
)
|
||||||
|
arrow = Arrow(
|
||||||
|
arrow_line,
|
||||||
|
color=self.color,
|
||||||
|
stroke_width=self.stroke_width
|
||||||
|
)
|
||||||
|
self.straight_arrow = arrow
|
||||||
|
self.add(arrow)
|
||||||
|
else:
|
||||||
|
# Figure out the direction of the arc
|
||||||
|
direction_vector = NetworkConnection.direction_vector_map[self.arc_direction]
|
||||||
|
# Make the start arc piece
|
||||||
|
start_line_start = left_mobject.get_critical_point(
|
||||||
|
direction_vector
|
||||||
|
)
|
||||||
|
start_line_start += direction_vector * self.buffer
|
||||||
|
start_line_end = start_line_start + direction_vector * self.arc_distance
|
||||||
|
self.start_line = Line(
|
||||||
|
start_line_start,
|
||||||
|
start_line_end,
|
||||||
|
color=self.color,
|
||||||
|
stroke_width=self.stroke_width
|
||||||
|
)
|
||||||
|
# Make the end arc piece with an arrow
|
||||||
|
end_line_end = right_mobject.get_critical_point(
|
||||||
|
direction_vector
|
||||||
|
)
|
||||||
|
end_line_end += direction_vector * self.buffer
|
||||||
|
end_line_start = end_line_end + direction_vector * self.arc_distance
|
||||||
|
self.end_arrow = Arrow(
|
||||||
|
start=end_line_start,
|
||||||
|
end=end_line_end,
|
||||||
|
color=WHITE,
|
||||||
|
fill_color=WHITE,
|
||||||
|
stroke_opacity=1.0,
|
||||||
|
buff=0.0
|
||||||
|
)
|
||||||
|
# Make the middle arc piece
|
||||||
|
self.middle_line = Line(
|
||||||
|
start_line_end,
|
||||||
|
end_line_start,
|
||||||
|
color=self.color,
|
||||||
|
stroke_width=self.stroke_width
|
||||||
|
)
|
||||||
|
# Add the mobjects
|
||||||
|
self.add(
|
||||||
|
self.start_line,
|
||||||
|
self.middle_line,
|
||||||
|
self.end_arrow,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_animation(ShowPassingFlash)
|
||||||
|
def _override_passing_flash(self, run_time=1.0, time_width=0.2):
|
||||||
|
"""Passing flash animation"""
|
||||||
|
if self.arc_direction == "straight":
|
||||||
|
return ShowPassingFlash(
|
||||||
|
self.straight_arrow.copy().set_color(self.active_color),
|
||||||
|
time_width=time_width
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Animate the start line
|
||||||
|
start_line_animation = ShowPassingFlash(
|
||||||
|
self.start_line.copy().set_color(self.active_color),
|
||||||
|
time_width=time_width
|
||||||
|
)
|
||||||
|
# Animate the middle line
|
||||||
|
middle_line_animation = ShowPassingFlash(
|
||||||
|
self.middle_line.copy().set_color(self.active_color),
|
||||||
|
time_width=time_width
|
||||||
|
)
|
||||||
|
# Animate the end line
|
||||||
|
end_line_animation = ShowPassingFlash(
|
||||||
|
self.end_arrow.copy().set_color(self.active_color),
|
||||||
|
time_width=time_width
|
||||||
|
)
|
||||||
|
|
||||||
|
return AnimationGroup(
|
||||||
|
start_line_animation,
|
||||||
|
middle_line_animation,
|
||||||
|
end_line_animation,
|
||||||
|
lag_ratio=1.0,
|
||||||
|
run_time=run_time
|
||||||
|
)
|
@ -26,3 +26,5 @@ class FeedForwardScene(Scene):
|
|||||||
])
|
])
|
||||||
|
|
||||||
self.add(nn)
|
self.add(nn)
|
||||||
|
|
||||||
|
self.play(nn.make_forward_pass_animation())
|
69
tests/test_residual_connection.py
Normal file
69
tests/test_residual_connection.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from manim import *
|
||||||
|
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||||
|
from manim_ml.utils.testing.frames_comparison import frames_comparison
|
||||||
|
|
||||||
|
from manim_ml.neural_network import NeuralNetwork, FeedForwardLayer, ImageLayer
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
__module_test__ = "residual"
|
||||||
|
|
||||||
|
@frames_comparison
|
||||||
|
def test_ResidualConnectionScene(scene):
|
||||||
|
"""Tests the appearance of a residual connection"""
|
||||||
|
nn = NeuralNetwork({
|
||||||
|
"layer1": FeedForwardLayer(3),
|
||||||
|
"layer2": FeedForwardLayer(5),
|
||||||
|
"layer3": FeedForwardLayer(3)
|
||||||
|
})
|
||||||
|
|
||||||
|
scene.add(nn)
|
||||||
|
|
||||||
|
# Make the specific scene
|
||||||
|
config.pixel_height = 1200
|
||||||
|
config.pixel_width = 1900
|
||||||
|
config.frame_height = 6.0
|
||||||
|
config.frame_width = 6.0
|
||||||
|
|
||||||
|
class FeedForwardScene(Scene):
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
nn = NeuralNetwork({
|
||||||
|
"layer1": FeedForwardLayer(4),
|
||||||
|
"layer2": FeedForwardLayer(4),
|
||||||
|
"layer3": FeedForwardLayer(4)
|
||||||
|
},
|
||||||
|
layer_spacing=0.45)
|
||||||
|
|
||||||
|
nn.add_connection("layer1", "layer3")
|
||||||
|
|
||||||
|
self.add(nn)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
nn.make_forward_pass_animation(),
|
||||||
|
run_time=8
|
||||||
|
)
|
||||||
|
|
||||||
|
class ConvScene(ThreeDScene):
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
image = Image.open("../assets/mnist/digit.jpeg")
|
||||||
|
numpy_image = np.asarray(image)
|
||||||
|
|
||||||
|
nn = NeuralNetwork({
|
||||||
|
"layer1": Convolutional2DLayer(1, 5, padding=1),
|
||||||
|
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||||
|
"layer3": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||||
|
},
|
||||||
|
layer_spacing=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
nn.add_connection("layer1", "layer3")
|
||||||
|
|
||||||
|
self.add(nn)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
nn.make_forward_pass_animation(),
|
||||||
|
run_time=8
|
||||||
|
)
|
Reference in New Issue
Block a user