mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-26 13:04:07 +08:00
Added the ability to make residual connections.
Note: still need to add the residual plus icon.
This commit is contained in:
66
examples/cnn/resnet_block.py
Normal file
66
examples/cnn/resnet_block.py
Normal file
@ -0,0 +1,66 @@
|
||||
from manim import *
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from manim_ml.neural_network import Convolutional2DLayer, NeuralNetwork
|
||||
|
||||
# Make the specific scene
|
||||
config.pixel_height = 1200
|
||||
config.pixel_width = 1900
|
||||
config.frame_height = 6.0
|
||||
config.frame_width = 6.0
|
||||
|
||||
def make_code_snippet():
|
||||
code_str = """
|
||||
# Make the neural network
|
||||
nn = NeuralNetwork({
|
||||
"layer1": Convolutional2DLayer(1, 5, padding=1),
|
||||
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||
"layer3": Convolutional2DLayer(1, 5, 3, padding=1)
|
||||
})
|
||||
# Add the residual connection
|
||||
nn.add_connection("layer1", "layer3")
|
||||
# Make the animation
|
||||
self.play(nn.make_forward_pass_animation())
|
||||
"""
|
||||
|
||||
code = Code(
|
||||
code=code_str,
|
||||
tab_width=4,
|
||||
background_stroke_width=1,
|
||||
background_stroke_color=WHITE,
|
||||
insert_line_no=False,
|
||||
style="monokai",
|
||||
# background="window",
|
||||
language="py",
|
||||
)
|
||||
code.scale(0.38)
|
||||
|
||||
return code
|
||||
|
||||
class ConvScene(ThreeDScene):
|
||||
|
||||
def construct(self):
|
||||
image = Image.open("../../assets/mnist/digit.jpeg")
|
||||
numpy_image = np.asarray(image)
|
||||
|
||||
nn = NeuralNetwork({
|
||||
"layer1": Convolutional2DLayer(1, 5, padding=1),
|
||||
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||
"layer3": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||
},
|
||||
layer_spacing=0.25,
|
||||
)
|
||||
|
||||
nn.add_connection("layer1", "layer3")
|
||||
|
||||
self.add(nn)
|
||||
|
||||
code = make_code_snippet()
|
||||
code.next_to(nn, DOWN)
|
||||
self.add(code)
|
||||
Group(code, nn).move_to(ORIGIN)
|
||||
|
||||
self.play(
|
||||
nn.make_forward_pass_animation(),
|
||||
run_time=8
|
||||
)
|
@ -11,6 +11,7 @@ Example:
|
||||
"""
|
||||
import textwrap
|
||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||
from manim_ml.utils.mobjects.connections import NetworkConnection
|
||||
import numpy as np
|
||||
from manim import *
|
||||
|
||||
@ -38,7 +39,8 @@ class NeuralNetwork(Group):
|
||||
layout_direction="left_to_right",
|
||||
):
|
||||
super(Group, self).__init__()
|
||||
self.input_layers = ListGroup(*input_layers)
|
||||
self.input_layers_dict = self.make_input_layers_dict(input_layers)
|
||||
self.input_layers = ListGroup(*self.input_layers_dict.values())
|
||||
self.edge_width = edge_width
|
||||
self.edge_color = edge_color
|
||||
self.layer_spacing = layer_spacing
|
||||
@ -69,9 +71,46 @@ class NeuralNetwork(Group):
|
||||
# Center the whole diagram by default
|
||||
self.all_layers.move_to(ORIGIN)
|
||||
self.add(self.all_layers)
|
||||
# Make container for connections
|
||||
self.connections = []
|
||||
# Print neural network
|
||||
print(repr(self))
|
||||
|
||||
def make_input_layers_dict(self, input_layers):
|
||||
"""Make dictionary of input layers"""
|
||||
if isinstance(input_layers, dict):
|
||||
# If input layers is dictionary then return it
|
||||
return input_layers
|
||||
elif isinstance(input_layers, list):
|
||||
# If input layers is a list then make a dictionary with default
|
||||
return_dict = {}
|
||||
for layer_index, input_layer in enumerate(input_layers):
|
||||
return_dict[f"layer{layer_index}"] = input_layer
|
||||
|
||||
return return_dict
|
||||
else:
|
||||
raise Exception(f"Uncrecognized input layers type: {type(input_layers)}")
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
start_layer_name,
|
||||
end_layer_name,
|
||||
connection_style="default",
|
||||
connection_position="bottom"
|
||||
):
|
||||
"""Add connection from start layer to end layer"""
|
||||
assert connection_style in ["default"]
|
||||
if connection_style == "default":
|
||||
# Make arrow connection from start layer to end layer
|
||||
# Add the connection
|
||||
connection = NetworkConnection(
|
||||
self.input_layers_dict[start_layer_name],
|
||||
self.input_layers_dict[end_layer_name],
|
||||
arc_direction="down" # TODO generalize this more
|
||||
)
|
||||
self.connections.append(connection)
|
||||
self.add(connection)
|
||||
|
||||
def _construct_input_layers(self):
|
||||
"""Constructs each of the input layers in context
|
||||
of their adjacent layers"""
|
||||
@ -220,7 +259,27 @@ class NeuralNetwork(Group):
|
||||
current_layer_args = layer_args[layer]
|
||||
# Perform the forward pass of the current layer
|
||||
layer_forward_pass = layer.make_forward_pass_animation(
|
||||
layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs
|
||||
layer_args=current_layer_args,
|
||||
run_time=per_layer_runtime,
|
||||
**kwargs
|
||||
)
|
||||
# Animate a forward pass for incoming connections
|
||||
connection_input_pass = AnimationGroup()
|
||||
for connection in self.connections:
|
||||
if isinstance(layer, ConnectiveLayer):
|
||||
output_layer = layer.output_layer
|
||||
if connection.end_mobject == output_layer:
|
||||
connection_input_pass = ShowPassingFlash(
|
||||
connection,
|
||||
run_time=layer_forward_pass.run_time,
|
||||
time_width=0.2
|
||||
)
|
||||
break
|
||||
|
||||
layer_forward_pass = AnimationGroup(
|
||||
layer_forward_pass,
|
||||
connection_input_pass,
|
||||
lag_ratio=0.0
|
||||
)
|
||||
all_animations.append(layer_forward_pass)
|
||||
# Make the animation group
|
||||
|
0
manim_ml/utils/__init__.py
Normal file
0
manim_ml/utils/__init__.py
Normal file
0
manim_ml/utils/mobjects/__init__.py
Normal file
0
manim_ml/utils/mobjects/__init__.py
Normal file
158
manim_ml/utils/mobjects/connections.py
Normal file
158
manim_ml/utils/mobjects/connections.py
Normal file
@ -0,0 +1,158 @@
|
||||
import numpy as np
|
||||
from manim import *
|
||||
|
||||
class NetworkConnection(VGroup):
|
||||
"""
|
||||
This class allows for creating connections
|
||||
between locations in a network
|
||||
"""
|
||||
direction_vector_map = {
|
||||
"up": UP,
|
||||
"down": DOWN,
|
||||
"left": LEFT,
|
||||
"right": RIGHT
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_mobject,
|
||||
end_mobject,
|
||||
arc_direction="straight",
|
||||
buffer=0.05,
|
||||
arc_distance=0.3,
|
||||
stroke_width=2.0,
|
||||
color=WHITE,
|
||||
active_color=ORANGE
|
||||
):
|
||||
"""Creates an arrow with right angles in it connecting
|
||||
two mobjects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_mobject : Mobject
|
||||
Mobject where the start of the connection is from
|
||||
end_mobject : Mobject
|
||||
Mobject where the end of the connection goes to
|
||||
arc_direction : str, optional
|
||||
direction that the connection arcs, by default "straight"
|
||||
buffer : float, optional
|
||||
amount of space between the connection and mobjects at the end
|
||||
arc_distance : float, optional
|
||||
Distance from start and end mobject that the arc bends
|
||||
stroke_width : float, optional
|
||||
Stroke width of the connection
|
||||
color : [float], optional
|
||||
Color of the connection
|
||||
active_color : [float], optional
|
||||
Color of active animations for this mobject
|
||||
"""
|
||||
super().__init__()
|
||||
assert arc_direction in ["straight", "up", "down", "left", "right"]
|
||||
self.start_mobject = start_mobject
|
||||
self.end_mobject = end_mobject
|
||||
self.arc_direction = arc_direction
|
||||
self.buffer = buffer
|
||||
self.arc_distance = arc_distance
|
||||
self.stroke_width = stroke_width
|
||||
self.color = color
|
||||
self.active_color = active_color
|
||||
|
||||
self.make_mobjects()
|
||||
|
||||
def make_mobjects(self):
|
||||
"""Makes the submobjects"""
|
||||
if self.start_mobject.get_center()[0] < self.end_mobject.get_center()[0]:
|
||||
left_mobject = self.start_mobject
|
||||
right_mobject = self.end_mobject
|
||||
else:
|
||||
right_mobject = self.start_mobject
|
||||
left_mobject = self.end_mobject
|
||||
|
||||
if self.arc_direction == "straight":
|
||||
# Make an arrow
|
||||
arrow_line = Line(
|
||||
left_mobject.get_right() + np.array([self.buffer, 0.0, 0.0]),
|
||||
right_mobject.get_left() + np.array([-1 * self.buffer, 0.0, 0.0])
|
||||
)
|
||||
arrow = Arrow(
|
||||
arrow_line,
|
||||
color=self.color,
|
||||
stroke_width=self.stroke_width
|
||||
)
|
||||
self.straight_arrow = arrow
|
||||
self.add(arrow)
|
||||
else:
|
||||
# Figure out the direction of the arc
|
||||
direction_vector = NetworkConnection.direction_vector_map[self.arc_direction]
|
||||
# Make the start arc piece
|
||||
start_line_start = left_mobject.get_critical_point(
|
||||
direction_vector
|
||||
)
|
||||
start_line_start += direction_vector * self.buffer
|
||||
start_line_end = start_line_start + direction_vector * self.arc_distance
|
||||
self.start_line = Line(
|
||||
start_line_start,
|
||||
start_line_end,
|
||||
color=self.color,
|
||||
stroke_width=self.stroke_width
|
||||
)
|
||||
# Make the end arc piece with an arrow
|
||||
end_line_end = right_mobject.get_critical_point(
|
||||
direction_vector
|
||||
)
|
||||
end_line_end += direction_vector * self.buffer
|
||||
end_line_start = end_line_end + direction_vector * self.arc_distance
|
||||
self.end_arrow = Arrow(
|
||||
start=end_line_start,
|
||||
end=end_line_end,
|
||||
color=WHITE,
|
||||
fill_color=WHITE,
|
||||
stroke_opacity=1.0,
|
||||
buff=0.0
|
||||
)
|
||||
# Make the middle arc piece
|
||||
self.middle_line = Line(
|
||||
start_line_end,
|
||||
end_line_start,
|
||||
color=self.color,
|
||||
stroke_width=self.stroke_width
|
||||
)
|
||||
# Add the mobjects
|
||||
self.add(
|
||||
self.start_line,
|
||||
self.middle_line,
|
||||
self.end_arrow,
|
||||
)
|
||||
|
||||
@override_animation(ShowPassingFlash)
|
||||
def _override_passing_flash(self, run_time=1.0, time_width=0.2):
|
||||
"""Passing flash animation"""
|
||||
if self.arc_direction == "straight":
|
||||
return ShowPassingFlash(
|
||||
self.straight_arrow.copy().set_color(self.active_color),
|
||||
time_width=time_width
|
||||
)
|
||||
else:
|
||||
# Animate the start line
|
||||
start_line_animation = ShowPassingFlash(
|
||||
self.start_line.copy().set_color(self.active_color),
|
||||
time_width=time_width
|
||||
)
|
||||
# Animate the middle line
|
||||
middle_line_animation = ShowPassingFlash(
|
||||
self.middle_line.copy().set_color(self.active_color),
|
||||
time_width=time_width
|
||||
)
|
||||
# Animate the end line
|
||||
end_line_animation = ShowPassingFlash(
|
||||
self.end_arrow.copy().set_color(self.active_color),
|
||||
time_width=time_width
|
||||
)
|
||||
|
||||
return AnimationGroup(
|
||||
start_line_animation,
|
||||
middle_line_animation,
|
||||
end_line_animation,
|
||||
lag_ratio=1.0,
|
||||
run_time=run_time
|
||||
)
|
@ -26,3 +26,5 @@ class FeedForwardScene(Scene):
|
||||
])
|
||||
|
||||
self.add(nn)
|
||||
|
||||
self.play(nn.make_forward_pass_animation())
|
69
tests/test_residual_connection.py
Normal file
69
tests/test_residual_connection.py
Normal file
@ -0,0 +1,69 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer
|
||||
from manim_ml.utils.testing.frames_comparison import frames_comparison
|
||||
|
||||
from manim_ml.neural_network import NeuralNetwork, FeedForwardLayer, ImageLayer
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
__module_test__ = "residual"
|
||||
|
||||
@frames_comparison
|
||||
def test_ResidualConnectionScene(scene):
|
||||
"""Tests the appearance of a residual connection"""
|
||||
nn = NeuralNetwork({
|
||||
"layer1": FeedForwardLayer(3),
|
||||
"layer2": FeedForwardLayer(5),
|
||||
"layer3": FeedForwardLayer(3)
|
||||
})
|
||||
|
||||
scene.add(nn)
|
||||
|
||||
# Make the specific scene
|
||||
config.pixel_height = 1200
|
||||
config.pixel_width = 1900
|
||||
config.frame_height = 6.0
|
||||
config.frame_width = 6.0
|
||||
|
||||
class FeedForwardScene(Scene):
|
||||
|
||||
def construct(self):
|
||||
nn = NeuralNetwork({
|
||||
"layer1": FeedForwardLayer(4),
|
||||
"layer2": FeedForwardLayer(4),
|
||||
"layer3": FeedForwardLayer(4)
|
||||
},
|
||||
layer_spacing=0.45)
|
||||
|
||||
nn.add_connection("layer1", "layer3")
|
||||
|
||||
self.add(nn)
|
||||
|
||||
self.play(
|
||||
nn.make_forward_pass_animation(),
|
||||
run_time=8
|
||||
)
|
||||
|
||||
class ConvScene(ThreeDScene):
|
||||
|
||||
def construct(self):
|
||||
image = Image.open("../assets/mnist/digit.jpeg")
|
||||
numpy_image = np.asarray(image)
|
||||
|
||||
nn = NeuralNetwork({
|
||||
"layer1": Convolutional2DLayer(1, 5, padding=1),
|
||||
"layer2": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||
"layer3": Convolutional2DLayer(1, 5, 3, padding=1),
|
||||
},
|
||||
layer_spacing=0.25,
|
||||
)
|
||||
|
||||
nn.add_connection("layer1", "layer3")
|
||||
|
||||
self.add(nn)
|
||||
|
||||
self.play(
|
||||
nn.make_forward_pass_animation(),
|
||||
run_time=8
|
||||
)
|
Reference in New Issue
Block a user