From 61b47798f38e6099182554b11b31913878a6f12d Mon Sep 17 00:00:00 2001 From: Alec Helbling Date: Sat, 2 Apr 2022 19:20:30 -0400 Subject: [PATCH] Working neural network test with refactor --- manim_ml/neural_network/feed_forward.py | 3 +- manim_ml/neural_network/layers.py | 22 +++++++++---- manim_ml/neural_network/neural_network.py | 11 ++++--- tests/test_neural_network.py | 39 +++++++++++++++++++++++ 4 files changed, 63 insertions(+), 12 deletions(-) create mode 100644 tests/test_neural_network.py diff --git a/manim_ml/neural_network/feed_forward.py b/manim_ml/neural_network/feed_forward.py index a24ce46..72fe4a5 100644 --- a/manim_ml/neural_network/feed_forward.py +++ b/manim_ml/neural_network/feed_forward.py @@ -5,8 +5,9 @@ from manim_ml.neural_network.neural_network import NeuralNetwork class FeedForwardNeuralNetwork(NeuralNetwork): """NeuralNetwork with just feed forward layers""" - def __init__(self, layer_node_count, node_radius=1.0, + def __init__(self, layer_node_count, node_radius=0.08, node_color=BLUE, **kwargs): + # construct layers layers = [] for num_nodes in layer_node_count: diff --git a/manim_ml/neural_network/layers.py b/manim_ml/neural_network/layers.py index 33d36cc..7a1352b 100644 --- a/manim_ml/neural_network/layers.py +++ b/manim_ml/neural_network/layers.py @@ -24,16 +24,17 @@ class ConnectiveLayer(NeuralNetworkLayer): class FeedForwardToFeedForward(ConnectiveLayer): def __init__(self, input_layer, output_layer, passing_flash=True, - dot_radius=0.05, animation_dot_count=RED, edge_color=WHITE, + dot_radius=0.05, animation_dot_color=RED, edge_color=WHITE, edge_width=0.5): - super(FeedForwardToFeedForward, self).__init__(input_layer, output_layer) + super().__init__(input_layer, output_layer) self.passing_flash = passing_flash self.edge_color = edge_color self.dot_radius = dot_radius - self.animation_dot_count = animation_dot_count + self.animation_dot_color = animation_dot_color self.edge_width = edge_width - self.construct_edges() + self.edges = self.construct_edges() + self.add(self.edges) def construct_edges(self): # Go through each node in the two layers and make a connecting line @@ -42,26 +43,33 @@ class FeedForwardToFeedForward(ConnectiveLayer): for node_j in self.output_layer.node_group: line = Line(node_i.get_center(), node_j.get_center(), color=self.edge_color, stroke_width=self.edge_width) - self.add(line) + edges.append(line) - self.edges = VGroup(*edges) + edges = VGroup(*edges) + return edges def make_forward_pass_animation(self, run_time=1): """Animation for passing information from one FeedForwardLayer to the next""" path_animations = [] + dots = [] for edge in self.edges: dot = Dot(color=self.animation_dot_color, fill_opacity=1.0, radius=self.dot_radius) # Handle layering dot.set_z_index(1) # Add to dots group - self.dots.add(dot) + dots.append(dot) # Make the animation if self.passing_flash: + print("passing flash") anim = ShowPassingFlash(edge.copy().set_color(self.animation_dot_color), time_width=0.2, run_time=3) else: anim = MoveAlongPath(dot, edge, run_time=run_time, rate_function=sigmoid) path_animations.append(anim) + if not self.passing_flash: + dots = VGroup(*dots) + self.add(dots) + path_animations = AnimationGroup(*path_animations) return path_animations diff --git a/manim_ml/neural_network/neural_network.py b/manim_ml/neural_network/neural_network.py index 8110b20..6ce1c77 100644 --- a/manim_ml/neural_network/neural_network.py +++ b/manim_ml/neural_network/neural_network.py @@ -16,7 +16,7 @@ from manim_ml.neural_network.layers import FeedForwardToFeedForward, FeedForward class NeuralNetwork(VGroup): def __init__(self, layers, edge_color=WHITE, layer_spacing=0.8, - animation_dot_color=RED, edge_width=2.0, dot_radius=0.05): + animation_dot_color=RED, edge_width=1.5, dot_radius=0.05): super().__init__() self.layers = layers self.edge_width = edge_width @@ -55,7 +55,9 @@ class NeuralNetwork(VGroup): if isinstance(current_layer, FeedForwardLayer) \ and isinstance(next_layer, FeedForwardLayer): - edge_layer = FeedForwardToFeedForward(current_layer, next_layer) + edge_layer = FeedForwardToFeedForward(current_layer, next_layer, + edge_width=self.edge_width) + connective_layers.add(edge_layer) else: raise Exception(f"Unimplemented connection for layer types: {type(current_layer)} and {type(next_layer)}") @@ -64,14 +66,15 @@ class NeuralNetwork(VGroup): connective_layers.set_z_index(0) return connective_layers - def make_forward_propagation_animation(self, run_time=2, passing_flash=True): + def make_forward_pass_animation(self, run_time=2, passing_flash=True): """Generates an animation for feed forward propogation""" all_animations = [] for layer_index, layer in enumerate(self.layers[:-1]): - connective_layer = self.connective_layers[layer_index] layer_forward_pass = layer.make_forward_pass_animation() all_animations.append(layer_forward_pass) + + connective_layer = self.connective_layers[layer_index] connective_forward_pass = connective_layer.make_forward_pass_animation() all_animations.append(connective_forward_pass) diff --git a/tests/test_neural_network.py b/tests/test_neural_network.py new file mode 100644 index 0000000..1ddd18c --- /dev/null +++ b/tests/test_neural_network.py @@ -0,0 +1,39 @@ +from manim import * +from manim_ml.neural_network.layers import FeedForwardLayer +from manim_ml.neural_network.neural_network import NeuralNetwork +from manim_ml.neural_network.feed_forward import FeedForwardNeuralNetwork + +config.pixel_height = 720 +config.pixel_width = 1280 +config.frame_height = 6.0 +config.frame_width = 6.0 + +class FeedForwardNeuralNetworkScene(Scene): + + def construct(self): + nn = FeedForwardNeuralNetwork([3, 5, 3]) + self.play(Create(nn)) + self.play(Wait(3)) + +class NeuralNetworkScene(Scene): + """Test Scene for the Neural Network""" + + def construct(self): + # Make the Layer object + layers = [FeedForwardLayer(3), FeedForwardLayer(5), FeedForwardLayer(3)] + nn = NeuralNetwork(layers) + nn.move_to(ORIGIN) + # Make Animation + self.add(nn) + forward_propagation_animation = nn.make_forward_pass_animation(run_time=5, passing_flash=True) + + self.play(forward_propagation_animation) + +if __name__ == "__main__": + """Render all scenes""" + # Feed Forward Neural Network + ffnn_scene = FeedForwardNeuralNetworkScene() + ffnn_scene.render() + # Neural Network + nn_scene = NeuralNetworkScene() + nn_scene.render()