From 27d235de257a439c5a24f3c9fd410eba0abd8f1e Mon Sep 17 00:00:00 2001 From: Alec Helbling Date: Wed, 1 Feb 2023 12:40:43 -0500 Subject: [PATCH] Added the ability to make residual connections. Note: still need to add the residual plus icon. --- examples/cnn/padding_example.py | 2 +- examples/cnn/resnet_block.py | 66 +++++++++ manim_ml/neural_network/neural_network.py | 63 ++++++++- manim_ml/utils/__init__.py | 0 manim_ml/utils/mobjects/__init__.py | 0 manim_ml/utils/mobjects/connections.py | 158 ++++++++++++++++++++++ tests/test_feed_forward.py | 4 +- tests/test_residual_connection.py | 69 ++++++++++ 8 files changed, 358 insertions(+), 4 deletions(-) create mode 100644 examples/cnn/resnet_block.py create mode 100644 manim_ml/utils/__init__.py create mode 100644 manim_ml/utils/mobjects/__init__.py create mode 100644 manim_ml/utils/mobjects/connections.py create mode 100644 tests/test_residual_connection.py diff --git a/examples/cnn/padding_example.py b/examples/cnn/padding_example.py index 6b65cbf..d8cb0a4 100644 --- a/examples/cnn/padding_example.py +++ b/examples/cnn/padding_example.py @@ -60,7 +60,7 @@ class CombinedScene(ThreeDScene): ), Convolutional2DLayer( num_feature_maps=3, - feature_map_size=6, + feature_map_size=6, filter_size=3, padding=0, padding_dashed=False diff --git a/examples/cnn/resnet_block.py b/examples/cnn/resnet_block.py new file mode 100644 index 0000000..b784256 --- /dev/null +++ b/examples/cnn/resnet_block.py @@ -0,0 +1,66 @@ +from manim import * +from PIL import Image +import numpy as np +from manim_ml.neural_network import Convolutional2DLayer, NeuralNetwork + +# Make the specific scene +config.pixel_height = 1200 +config.pixel_width = 1900 +config.frame_height = 6.0 +config.frame_width = 6.0 + +def make_code_snippet(): + code_str = """ + # Make the neural network + nn = NeuralNetwork({ + "layer1": Convolutional2DLayer(1, 5, padding=1), + "layer2": Convolutional2DLayer(1, 5, 3, padding=1), + "layer3": Convolutional2DLayer(1, 5, 3, padding=1) + }) + # Add the residual connection + nn.add_connection("layer1", "layer3") + # Make the animation + self.play(nn.make_forward_pass_animation()) + """ + + code = Code( + code=code_str, + tab_width=4, + background_stroke_width=1, + background_stroke_color=WHITE, + insert_line_no=False, + style="monokai", + # background="window", + language="py", + ) + code.scale(0.38) + + return code + +class ConvScene(ThreeDScene): + + def construct(self): + image = Image.open("../../assets/mnist/digit.jpeg") + numpy_image = np.asarray(image) + + nn = NeuralNetwork({ + "layer1": Convolutional2DLayer(1, 5, padding=1), + "layer2": Convolutional2DLayer(1, 5, 3, padding=1), + "layer3": Convolutional2DLayer(1, 5, 3, padding=1), + }, + layer_spacing=0.25, + ) + + nn.add_connection("layer1", "layer3") + + self.add(nn) + + code = make_code_snippet() + code.next_to(nn, DOWN) + self.add(code) + Group(code, nn).move_to(ORIGIN) + + self.play( + nn.make_forward_pass_animation(), + run_time=8 + ) \ No newline at end of file diff --git a/manim_ml/neural_network/neural_network.py b/manim_ml/neural_network/neural_network.py index 5b55ad8..2843c26 100644 --- a/manim_ml/neural_network/neural_network.py +++ b/manim_ml/neural_network/neural_network.py @@ -11,6 +11,7 @@ Example: """ import textwrap from manim_ml.neural_network.layers.embedding import EmbeddingLayer +from manim_ml.utils.mobjects.connections import NetworkConnection import numpy as np from manim import * @@ -38,7 +39,8 @@ class NeuralNetwork(Group): layout_direction="left_to_right", ): super(Group, self).__init__() - self.input_layers = ListGroup(*input_layers) + self.input_layers_dict = self.make_input_layers_dict(input_layers) + self.input_layers = ListGroup(*self.input_layers_dict.values()) self.edge_width = edge_width self.edge_color = edge_color self.layer_spacing = layer_spacing @@ -69,9 +71,46 @@ class NeuralNetwork(Group): # Center the whole diagram by default self.all_layers.move_to(ORIGIN) self.add(self.all_layers) + # Make container for connections + self.connections = [] # Print neural network print(repr(self)) + def make_input_layers_dict(self, input_layers): + """Make dictionary of input layers""" + if isinstance(input_layers, dict): + # If input layers is dictionary then return it + return input_layers + elif isinstance(input_layers, list): + # If input layers is a list then make a dictionary with default + return_dict = {} + for layer_index, input_layer in enumerate(input_layers): + return_dict[f"layer{layer_index}"] = input_layer + + return return_dict + else: + raise Exception(f"Uncrecognized input layers type: {type(input_layers)}") + + def add_connection( + self, + start_layer_name, + end_layer_name, + connection_style="default", + connection_position="bottom" + ): + """Add connection from start layer to end layer""" + assert connection_style in ["default"] + if connection_style == "default": + # Make arrow connection from start layer to end layer + # Add the connection + connection = NetworkConnection( + self.input_layers_dict[start_layer_name], + self.input_layers_dict[end_layer_name], + arc_direction="down" # TODO generalize this more + ) + self.connections.append(connection) + self.add(connection) + def _construct_input_layers(self): """Constructs each of the input layers in context of their adjacent layers""" @@ -220,7 +259,27 @@ class NeuralNetwork(Group): current_layer_args = layer_args[layer] # Perform the forward pass of the current layer layer_forward_pass = layer.make_forward_pass_animation( - layer_args=current_layer_args, run_time=per_layer_runtime, **kwargs + layer_args=current_layer_args, + run_time=per_layer_runtime, + **kwargs + ) + # Animate a forward pass for incoming connections + connection_input_pass = AnimationGroup() + for connection in self.connections: + if isinstance(layer, ConnectiveLayer): + output_layer = layer.output_layer + if connection.end_mobject == output_layer: + connection_input_pass = ShowPassingFlash( + connection, + run_time=layer_forward_pass.run_time, + time_width=0.2 + ) + break + + layer_forward_pass = AnimationGroup( + layer_forward_pass, + connection_input_pass, + lag_ratio=0.0 ) all_animations.append(layer_forward_pass) # Make the animation group diff --git a/manim_ml/utils/__init__.py b/manim_ml/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/manim_ml/utils/mobjects/__init__.py b/manim_ml/utils/mobjects/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/manim_ml/utils/mobjects/connections.py b/manim_ml/utils/mobjects/connections.py new file mode 100644 index 0000000..89b15f9 --- /dev/null +++ b/manim_ml/utils/mobjects/connections.py @@ -0,0 +1,158 @@ +import numpy as np +from manim import * + +class NetworkConnection(VGroup): + """ + This class allows for creating connections + between locations in a network + """ + direction_vector_map = { + "up": UP, + "down": DOWN, + "left": LEFT, + "right": RIGHT + } + + def __init__( + self, + start_mobject, + end_mobject, + arc_direction="straight", + buffer=0.05, + arc_distance=0.3, + stroke_width=2.0, + color=WHITE, + active_color=ORANGE + ): + """Creates an arrow with right angles in it connecting + two mobjects. + + Parameters + ---------- + start_mobject : Mobject + Mobject where the start of the connection is from + end_mobject : Mobject + Mobject where the end of the connection goes to + arc_direction : str, optional + direction that the connection arcs, by default "straight" + buffer : float, optional + amount of space between the connection and mobjects at the end + arc_distance : float, optional + Distance from start and end mobject that the arc bends + stroke_width : float, optional + Stroke width of the connection + color : [float], optional + Color of the connection + active_color : [float], optional + Color of active animations for this mobject + """ + super().__init__() + assert arc_direction in ["straight", "up", "down", "left", "right"] + self.start_mobject = start_mobject + self.end_mobject = end_mobject + self.arc_direction = arc_direction + self.buffer = buffer + self.arc_distance = arc_distance + self.stroke_width = stroke_width + self.color = color + self.active_color = active_color + + self.make_mobjects() + + def make_mobjects(self): + """Makes the submobjects""" + if self.start_mobject.get_center()[0] < self.end_mobject.get_center()[0]: + left_mobject = self.start_mobject + right_mobject = self.end_mobject + else: + right_mobject = self.start_mobject + left_mobject = self.end_mobject + + if self.arc_direction == "straight": + # Make an arrow + arrow_line = Line( + left_mobject.get_right() + np.array([self.buffer, 0.0, 0.0]), + right_mobject.get_left() + np.array([-1 * self.buffer, 0.0, 0.0]) + ) + arrow = Arrow( + arrow_line, + color=self.color, + stroke_width=self.stroke_width + ) + self.straight_arrow = arrow + self.add(arrow) + else: + # Figure out the direction of the arc + direction_vector = NetworkConnection.direction_vector_map[self.arc_direction] + # Make the start arc piece + start_line_start = left_mobject.get_critical_point( + direction_vector + ) + start_line_start += direction_vector * self.buffer + start_line_end = start_line_start + direction_vector * self.arc_distance + self.start_line = Line( + start_line_start, + start_line_end, + color=self.color, + stroke_width=self.stroke_width + ) + # Make the end arc piece with an arrow + end_line_end = right_mobject.get_critical_point( + direction_vector + ) + end_line_end += direction_vector * self.buffer + end_line_start = end_line_end + direction_vector * self.arc_distance + self.end_arrow = Arrow( + start=end_line_start, + end=end_line_end, + color=WHITE, + fill_color=WHITE, + stroke_opacity=1.0, + buff=0.0 + ) + # Make the middle arc piece + self.middle_line = Line( + start_line_end, + end_line_start, + color=self.color, + stroke_width=self.stroke_width + ) + # Add the mobjects + self.add( + self.start_line, + self.middle_line, + self.end_arrow, + ) + + @override_animation(ShowPassingFlash) + def _override_passing_flash(self, run_time=1.0, time_width=0.2): + """Passing flash animation""" + if self.arc_direction == "straight": + return ShowPassingFlash( + self.straight_arrow.copy().set_color(self.active_color), + time_width=time_width + ) + else: + # Animate the start line + start_line_animation = ShowPassingFlash( + self.start_line.copy().set_color(self.active_color), + time_width=time_width + ) + # Animate the middle line + middle_line_animation = ShowPassingFlash( + self.middle_line.copy().set_color(self.active_color), + time_width=time_width + ) + # Animate the end line + end_line_animation = ShowPassingFlash( + self.end_arrow.copy().set_color(self.active_color), + time_width=time_width + ) + + return AnimationGroup( + start_line_animation, + middle_line_animation, + end_line_animation, + lag_ratio=1.0, + run_time=run_time + ) \ No newline at end of file diff --git a/tests/test_feed_forward.py b/tests/test_feed_forward.py index da09bd7..29dd24c 100644 --- a/tests/test_feed_forward.py +++ b/tests/test_feed_forward.py @@ -25,4 +25,6 @@ class FeedForwardScene(Scene): FeedForwardLayer(3) ]) - self.add(nn) \ No newline at end of file + self.add(nn) + + self.play(nn.make_forward_pass_animation()) \ No newline at end of file diff --git a/tests/test_residual_connection.py b/tests/test_residual_connection.py new file mode 100644 index 0000000..b64100e --- /dev/null +++ b/tests/test_residual_connection.py @@ -0,0 +1,69 @@ +from manim import * +from manim_ml.neural_network.layers.convolutional_2d import Convolutional2DLayer +from manim_ml.utils.testing.frames_comparison import frames_comparison + +from manim_ml.neural_network import NeuralNetwork, FeedForwardLayer, ImageLayer + +from PIL import Image +import numpy as np + +__module_test__ = "residual" + +@frames_comparison +def test_ResidualConnectionScene(scene): + """Tests the appearance of a residual connection""" + nn = NeuralNetwork({ + "layer1": FeedForwardLayer(3), + "layer2": FeedForwardLayer(5), + "layer3": FeedForwardLayer(3) + }) + + scene.add(nn) + +# Make the specific scene +config.pixel_height = 1200 +config.pixel_width = 1900 +config.frame_height = 6.0 +config.frame_width = 6.0 + +class FeedForwardScene(Scene): + + def construct(self): + nn = NeuralNetwork({ + "layer1": FeedForwardLayer(4), + "layer2": FeedForwardLayer(4), + "layer3": FeedForwardLayer(4) + }, + layer_spacing=0.45) + + nn.add_connection("layer1", "layer3") + + self.add(nn) + + self.play( + nn.make_forward_pass_animation(), + run_time=8 + ) + +class ConvScene(ThreeDScene): + + def construct(self): + image = Image.open("../assets/mnist/digit.jpeg") + numpy_image = np.asarray(image) + + nn = NeuralNetwork({ + "layer1": Convolutional2DLayer(1, 5, padding=1), + "layer2": Convolutional2DLayer(1, 5, 3, padding=1), + "layer3": Convolutional2DLayer(1, 5, 3, padding=1), + }, + layer_spacing=0.25, + ) + + nn.add_connection("layer1", "layer3") + + self.add(nn) + + self.play( + nn.make_forward_pass_animation(), + run_time=8 + ) \ No newline at end of file