From 11bbd59bb62e6892495002f8051df3970f450c33 Mon Sep 17 00:00:00 2001 From: Alec Helbling Date: Tue, 19 Apr 2022 02:10:53 -0400 Subject: [PATCH] Vector layer (Still work in progress) --- manim_ml/neural_network/layers/__init__.py | 4 +- .../layers/feed_forward_to_vector.py | 39 +++++++++++++++++++ manim_ml/neural_network/layers/vector.py | 37 ++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 manim_ml/neural_network/layers/feed_forward_to_vector.py create mode 100644 manim_ml/neural_network/layers/vector.py diff --git a/manim_ml/neural_network/layers/__init__.py b/manim_ml/neural_network/layers/__init__.py index 3e2f6d5..cd69c15 100644 --- a/manim_ml/neural_network/layers/__init__.py +++ b/manim_ml/neural_network/layers/__init__.py @@ -1,5 +1,6 @@ from tempfile import _TemporaryFileWrapper -from manim_ml.neural_network.layers.paired_query_to_feed_forward import PairedQueryToFeedForward +from .feed_forward_to_vector import FeedForwardToVector +from .paired_query_to_feed_forward import PairedQueryToFeedForward from .embedding_to_feed_forward import EmbeddingToFeedForward from .embedding import EmbeddingLayer from .feed_forward_to_embedding import FeedForwardToEmbedding @@ -23,4 +24,5 @@ connective_layers_list = ( PairedQueryToFeedForward, TripletToFeedForward, PairedQueryToFeedForward, + FeedForwardToVector, ) diff --git a/manim_ml/neural_network/layers/feed_forward_to_vector.py b/manim_ml/neural_network/layers/feed_forward_to_vector.py new file mode 100644 index 0000000..9563ca0 --- /dev/null +++ b/manim_ml/neural_network/layers/feed_forward_to_vector.py @@ -0,0 +1,39 @@ +from manim import * +from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer +from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer +from manim_ml.neural_network.layers.vector import VectorLayer + +class FeedForwardToVector(ConnectiveLayer): + """Image Layer to FeedForward layer""" + input_class = FeedForwardLayer + output_class = VectorLayer + + def __init__(self, input_layer, output_layer, animation_dot_color=RED, + dot_radius=0.05, **kwargs): + super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=VectorLayer, + **kwargs) + self.animation_dot_color = animation_dot_color + self.dot_radius = dot_radius + + self.feed_forward_layer = input_layer + self.vector_layer = output_layer + + def make_forward_pass_animation(self): + """Makes dots diverge from the given location and move to the feed forward nodes decoder""" + animations = [] + # Move the dots to the centers of each of the nodes in the FeedForwardLayer + destination = self.vector_layer.get_center() + for node in self.feed_forward_layer.node_group: + new_dot = Dot(node.get_center(), radius=self.dot_radius, color=self.animation_dot_color) + per_node_succession = Succession( + Create(new_dot), + new_dot.animate.move_to(destination), + ) + animations.append(per_node_succession) + + animation_group = AnimationGroup(*animations) + return animation_group + + @override_animation(Create) + def _create_override(self): + return AnimationGroup() \ No newline at end of file diff --git a/manim_ml/neural_network/layers/vector.py b/manim_ml/neural_network/layers/vector.py new file mode 100644 index 0000000..585e946 --- /dev/null +++ b/manim_ml/neural_network/layers/vector.py @@ -0,0 +1,37 @@ +from manim import * +import random + +from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer + +class VectorLayer(VGroupNeuralNetworkLayer): + """Shows a vector""" + + def __init__(self, num_values, value_func=lambda: random.uniform(0, 1), + **kwargs): + print("vector layer") + super().__init__(**kwargs) + print("after init") + self.num_values = num_values + self.value_func = value_func + # Make the vector + self.vector_label = self.make_vector() + + def make_vector(self): + """Makes the vector""" + if False: + # TODO install Latex + values = np.array([self.value_func() for i in range(self.num_values)]) + values = values[None, :].T + vector = Matrix(values) + + vector_label = Text(f"[{self.value_func()}]") + + return vector_label + + def make_forward_pass_animation(self): + return AnimationGroup() + + @override_animation(Create) + def _create_override(self): + """Create animation""" + return Create(self.vector_label) \ No newline at end of file