diff --git a/manim_ml/neural_network/layers/convolutional.py b/manim_ml/neural_network/layers/convolutional.py index 7434099..b97c72a 100644 --- a/manim_ml/neural_network/layers/convolutional.py +++ b/manim_ml/neural_network/layers/convolutional.py @@ -16,7 +16,7 @@ class ConvolutionalLayer(VGroupNeuralNetworkLayer): """Creates the neural network layer""" pass - def make_forward_pass_animation(self): + def make_forward_pass_animation(self, layer_args={}, **kwargs): # make highlight animation return None diff --git a/manim_ml/neural_network/layers/embedding.py b/manim_ml/neural_network/layers/embedding.py index bcdcf1d..79c191d 100644 --- a/manim_ml/neural_network/layers/embedding.py +++ b/manim_ml/neural_network/layers/embedding.py @@ -6,16 +6,29 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer): """NeuralNetwork embedding object that can show probability distributions""" def __init__(self, point_radius=0.02, mean = np.array([0, 0]), - covariance=np.array([[1.5, 0], [0, 1.5]]), dist_theme="gaussian", **kwargs): + covariance=np.array([[1.0, 0], [0, 1.0]]), dist_theme="gaussian", + paired_query_mode=False, **kwargs): super(VGroupNeuralNetworkLayer, self).__init__(**kwargs) self.point_radius = point_radius self.dist_theme = dist_theme + self.paired_query_mode = paired_query_mode self.axes = Axes( tips=False, x_length=0.8, - y_length=0.8 + y_length=0.8, + x_range=(-2.0, 2.0), + y_range=(-2.0, 2.0), + x_axis_config={ + "include_ticks": False, + "stroke_width": 0.0 + }, + y_axis_config={ + "include_ticks": False, + "stroke_width": 0.0 + } ) self.add(self.axes) + self.axes.move_to(self.get_center()) # Make point cloud self.point_cloud = self.construct_gaussian_point_cloud(mean, covariance) self.add(self.point_cloud) @@ -51,24 +64,44 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer): return point_dots - def make_forward_pass_animation(self, **kwargs): - """Forward pass animation""" - # Make ellipse object corresponding to the latent distribution - self.latent_distribution = GaussianDistribution( - self.axes, - dist_theme=self.dist_theme, - cov=np.array([[0.8, 0], [0.0, 0.8]]) - ) # Use defaults - # Create animation + def make_paired_query_embedding_animation(self): + """Embed paired query""" animations = [] - #create_distribution = Create(self.latent_distribution.construct_gaussian_distribution(self.latent_distribution.mean, self.latent_distribution.cov)) #Create(self.latent_distribution) - create_distribution = Create(self.latent_distribution.ellipses) - animations.append(create_distribution) - - animation_group = AnimationGroup(*animations) - + # Make the animation + + # Animation group + animation_group = AnimationGroup( + *animations, + lag_ratio=1.0 + ) + return animation_group + def make_forward_pass_animation(self, layer_args={}, **kwargs): + """Forward pass animation""" + animations = [] + if not self.paired_query_mode: + # Normal embedding mode + # Make ellipse object corresponding to the latent distribution + self.latent_distribution = GaussianDistribution( + self.axes, + dist_theme=self.dist_theme, + cov=np.array([[0.8, 0], [0.0, 0.8]]) + ) # Use defaults + # Create animation + #create_distribution = Create(self.latent_distribution.construct_gaussian_distribution(self.latent_distribution.mean, self.latent_distribution.cov)) #Create(self.latent_distribution) + create_distribution = Create(self.latent_distribution.ellipses) + animations.append(create_distribution) + + animation_group = AnimationGroup(*animations) + + return animation_group + else: + # Paired Query Mode + # Handle logic for embedding a paired query into the embedding layer + paired_query_embedding_animation = self.make_paired_query_embedding_animation() + return paired_query_embedding_animation + @override_animation(Create) def _create_override(self, **kwargs): # Plot each point at once diff --git a/manim_ml/neural_network/layers/embedding_to_feed_forward.py b/manim_ml/neural_network/layers/embedding_to_feed_forward.py index 9057c3c..9053c78 100644 --- a/manim_ml/neural_network/layers/embedding_to_feed_forward.py +++ b/manim_ml/neural_network/layers/embedding_to_feed_forward.py @@ -17,7 +17,7 @@ class EmbeddingToFeedForward(ConnectiveLayer): self.animation_dot_color = animation_dot_color self.dot_radius = dot_radius - def make_forward_pass_animation(self, run_time=1.5, **kwargs): + def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs): """Makes dots diverge from the given location and move the decoder""" # Find point to converge on by sampling from gaussian distribution location = self.embedding_layer.sample_point_location_from_distribution() diff --git a/manim_ml/neural_network/layers/feed_forward.py b/manim_ml/neural_network/layers/feed_forward.py index 62af147..2e4ef53 100644 --- a/manim_ml/neural_network/layers/feed_forward.py +++ b/manim_ml/neural_network/layers/feed_forward.py @@ -44,7 +44,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer): # Add the objects to the class self.add(self.surrounding_rectangle, self.node_group) - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): # make highlight animation succession = Succession( ApplyMethod(self.node_group.set_color, self.animation_dot_color, run_time=0.25), diff --git a/manim_ml/neural_network/layers/feed_forward_to_embedding.py b/manim_ml/neural_network/layers/feed_forward_to_embedding.py index afb4a9f..8e3da57 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_embedding.py +++ b/manim_ml/neural_network/layers/feed_forward_to_embedding.py @@ -17,7 +17,7 @@ class FeedForwardToEmbedding(ConnectiveLayer): self.animation_dot_color = animation_dot_color self.dot_radius = dot_radius - def make_forward_pass_animation(self, run_time=1.5): + def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs): """Makes dots converge on a specific location""" # Find point to converge on by sampling from gaussian distribution location = self.embedding_layer.sample_point_location_from_distribution() diff --git a/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py b/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py index 007aead..dbeed74 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py +++ b/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py @@ -44,7 +44,7 @@ class FeedForwardToFeedForward(ConnectiveLayer): return animation_group - def make_forward_pass_animation(self, run_time=1, **kwargs): + def make_forward_pass_animation(self, layer_args={}, run_time=1, **kwargs): """Animation for passing information from one FeedForwardLayer to the next""" path_animations = [] dots = [] diff --git a/manim_ml/neural_network/layers/feed_forward_to_image.py b/manim_ml/neural_network/layers/feed_forward_to_image.py index ae3468c..ef6b940 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_image.py +++ b/manim_ml/neural_network/layers/feed_forward_to_image.py @@ -18,7 +18,7 @@ class FeedForwardToImage(ConnectiveLayer): self.feed_forward_layer = input_layer self.image_layer = output_layer - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): """Makes dots diverge from the given location and move to the feed forward nodes decoder""" animations = [] image_mobject = self.image_layer.image_mobject diff --git a/manim_ml/neural_network/layers/feed_forward_to_vector.py b/manim_ml/neural_network/layers/feed_forward_to_vector.py index 9563ca0..a61fb13 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_vector.py +++ b/manim_ml/neural_network/layers/feed_forward_to_vector.py @@ -18,7 +18,7 @@ class FeedForwardToVector(ConnectiveLayer): self.feed_forward_layer = input_layer self.vector_layer = output_layer - def make_forward_pass_animation(self): + def make_forward_pass_animation(self, layer_args={}, **kwargs): """Makes dots diverge from the given location and move to the feed forward nodes decoder""" animations = [] # Move the dots to the centers of each of the nodes in the FeedForwardLayer diff --git a/manim_ml/neural_network/layers/image.py b/manim_ml/neural_network/layers/image.py index b37d1a2..a983f80 100644 --- a/manim_ml/neural_network/layers/image.py +++ b/manim_ml/neural_network/layers/image.py @@ -27,7 +27,7 @@ class ImageLayer(NeuralNetworkLayer): else: return AnimationGroup() - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): return FadeIn(self.image_mobject) # def move_to(self, location): diff --git a/manim_ml/neural_network/layers/image_to_feed_forward.py b/manim_ml/neural_network/layers/image_to_feed_forward.py index 8e58a85..59d838f 100644 --- a/manim_ml/neural_network/layers/image_to_feed_forward.py +++ b/manim_ml/neural_network/layers/image_to_feed_forward.py @@ -18,7 +18,7 @@ class ImageToFeedForward(ConnectiveLayer): self.feed_forward_layer = output_layer self.image_layer = input_layer - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): """Makes dots diverge from the given location and move to the feed forward nodes decoder""" animations = [] dots = [] diff --git a/manim_ml/neural_network/layers/paired_query.py b/manim_ml/neural_network/layers/paired_query.py index f16b436..8867a28 100644 --- a/manim_ml/neural_network/layers/paired_query.py +++ b/manim_ml/neural_network/layers/paired_query.py @@ -60,6 +60,6 @@ class PairedQueryLayer(NeuralNetworkLayer): # TODO make Create animation that is custom return FadeIn(self.assets) - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): """Forward pass for query""" return AnimationGroup() \ No newline at end of file diff --git a/manim_ml/neural_network/layers/paired_query_to_feed_forward.py b/manim_ml/neural_network/layers/paired_query_to_feed_forward.py index df77c5c..f1d8775 100644 --- a/manim_ml/neural_network/layers/paired_query_to_feed_forward.py +++ b/manim_ml/neural_network/layers/paired_query_to_feed_forward.py @@ -17,7 +17,7 @@ class PairedQueryToFeedForward(ConnectiveLayer): self.paired_query_layer = input_layer self.feed_forward_layer = output_layer - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): """Makes dots diverge from the given location and move to the feed forward nodes decoder""" animations = [] # Loop through each image diff --git a/manim_ml/neural_network/layers/parent_layers.py b/manim_ml/neural_network/layers/parent_layers.py index 40feba7..3a5bb82 100644 --- a/manim_ml/neural_network/layers/parent_layers.py +++ b/manim_ml/neural_network/layers/parent_layers.py @@ -12,7 +12,7 @@ class NeuralNetworkLayer(ABC, Group): # self.add(self.title) @abstractmethod - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): pass @override_animation(Create) @@ -51,7 +51,7 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer): assert isinstance(output_layer, self.output_class) @abstractmethod - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): pass @override_animation(Create) diff --git a/manim_ml/neural_network/layers/triplet.py b/manim_ml/neural_network/layers/triplet.py index c3b7ab0..991c1e9 100644 --- a/manim_ml/neural_network/layers/triplet.py +++ b/manim_ml/neural_network/layers/triplet.py @@ -71,6 +71,6 @@ class TripletLayer(NeuralNetworkLayer): # TODO make Create animation that is custom return FadeIn(self.assets) - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): """Forward pass for triplet""" return AnimationGroup() diff --git a/manim_ml/neural_network/layers/triplet_to_feed_forward.py b/manim_ml/neural_network/layers/triplet_to_feed_forward.py index 9f16333..22939d8 100644 --- a/manim_ml/neural_network/layers/triplet_to_feed_forward.py +++ b/manim_ml/neural_network/layers/triplet_to_feed_forward.py @@ -18,7 +18,7 @@ class TripletToFeedForward(ConnectiveLayer): self.feed_forward_layer = output_layer self.triplet_layer = input_layer - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): """Makes dots diverge from the given location and move to the feed forward nodes decoder""" animations = [] # Loop through each image diff --git a/manim_ml/neural_network/layers/vector.py b/manim_ml/neural_network/layers/vector.py index 2dcc45c..c05a73f 100644 --- a/manim_ml/neural_network/layers/vector.py +++ b/manim_ml/neural_network/layers/vector.py @@ -28,7 +28,7 @@ class VectorLayer(VGroupNeuralNetworkLayer): return vector_label - def make_forward_pass_animation(self, **kwargs): + def make_forward_pass_animation(self, layer_args={}, **kwargs): return AnimationGroup() @override_animation(Create) diff --git a/manim_ml/neural_network/neural_network.py b/manim_ml/neural_network/neural_network.py index e98d482..7246892 100644 --- a/manim_ml/neural_network/neural_network.py +++ b/manim_ml/neural_network/neural_network.py @@ -9,254 +9,17 @@ Example: # Create the object with default style settings NeuralNetwork(layer_node_count) """ +from cv2 import AGAST_FEATURE_DETECTOR_NONMAX_SUPPRESSION from manim import * import warnings import textwrap from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer +from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer from manim_ml.neural_network.layers.util import get_connective_layer from manim_ml.list_group import ListGroup +from manim_ml.neural_network.neural_network_transformations import InsertLayer, RemoveLayer -class RemoveLayer(AnimationGroup): - """ - Animation for removing a layer from a neural network. - - Note: I needed to do something strange for creating the new connective layer. - The issue with creating it intially is that the positions of the sides of the - connective layer depend upon the location of the moved layers **after** the - move animations are performed. However, all of these animations are performed - after the animations have been created. This means that the animation depends upon - the state of the neural network layers after previous animations have been run. - To fix this issue I needed to use an UpdateFromFunc. - """ - - def __init__(self, layer, neural_network, layer_spacing=0.2): - self.layer = layer - self.neural_network = neural_network - self.layer_spacing = layer_spacing - # Get the before and after layers - layers_tuple = self.get_connective_layers() - self.before_layer = layers_tuple[0] - self.after_layer = layers_tuple[1] - self.before_connective = layers_tuple[2] - self.after_connective = layers_tuple[3] - # Make the animations - remove_animations = self.make_remove_animation() - move_animations = self.make_move_animation() - new_connective_animation = self.make_new_connective_animation() - # Add all of the animations to the group - animations_list = [ - remove_animations, - move_animations, - new_connective_animation - ] - - super().__init__(*animations_list, lag_ratio=1.0) - - def get_connective_layers(self): - """Gets the connective layers before and after self.layer""" - # Get layer index - layer_index = self.neural_network.all_layers.index_of(self.layer) - if layer_index == -1: - raise Exception("Layer object not found") - # Get the layers before and after - before_layer = None - after_layer = None - before_connective = None - after_connective = None - if layer_index - 2 >= 0: - before_layer = self.neural_network.all_layers[layer_index - 2] - before_connective = self.neural_network.all_layers[layer_index - 1] - if layer_index + 2 < len(self.neural_network.all_layers): - after_layer = self.neural_network.all_layers[layer_index + 2] - after_connective = self.neural_network.all_layers[layer_index + 1] - - return before_layer, after_layer, before_connective, after_connective - - def make_remove_animation(self): - """Removes layer and the surrounding connective layers""" - remove_layer_animation = self.make_remove_layer_animation() - remove_connective_animation = self.make_remove_connective_layers_animation() - # Remove animations - remove_animations = AnimationGroup( - remove_layer_animation, - remove_connective_animation - ) - - return remove_animations - - def make_remove_layer_animation(self): - """Removes the layer""" - # Remove the layer - self.neural_network.all_layers.remove(self.layer) - # Fade out the removed layer - fade_out_removed = FadeOut(self.layer) - return fade_out_removed - - def make_remove_connective_layers_animation(self): - """Removes the connective layers before and after layer if they exist""" - # Fade out the removed connective layers - fade_out_before_connective = AnimationGroup() - if not self.before_connective is None: - self.neural_network.all_layers.remove(self.before_connective) - fade_out_before_connective = FadeOut(self.before_connective) - fade_out_after_connective = AnimationGroup() - if not self.after_connective is None: - self.neural_network.all_layers.remove(self.after_connective) - fade_out_after_connective = FadeOut(self.after_connective) - # Group items - remove_connective_group = AnimationGroup( - fade_out_after_connective, - fade_out_before_connective - ) - - return remove_connective_group - - def make_move_animation(self): - """Collapses layers""" - # Animate the movements - move_before_layers = AnimationGroup() - shift_right_amount = None - if not self.before_layer is None: - # Compute shift amount - layer_dist = np.abs(self.layer.get_center() - self.before_layer.get_right())[0] - shift_right_amount = np.array([layer_dist - self.layer_spacing/2, 0, 0]) - # Shift all layers before forward - before_layer_index = self.neural_network.all_layers.index_of(self.before_layer) - layers_before = Group(*self.neural_network.all_layers[:before_layer_index + 1]) - move_before_layers = layers_before.animate.shift(shift_right_amount) - move_after_layers = AnimationGroup() - shift_left_amount = None - if not self.after_layer is None: - layer_dist = np.abs(self.after_layer.get_left() - self.layer.get_center())[0] - shift_left_amount = np.array([-layer_dist + self.layer_spacing / 2, 0, 0]) - # Shift all layers after backward - after_layer_index = self.neural_network.all_layers.index_of(self.after_layer) - layers_after = Group(*self.neural_network.all_layers[after_layer_index:]) - move_after_layers = layers_after.animate.shift(shift_left_amount) - # Group the move animations - move_group = AnimationGroup( - move_before_layers, - move_after_layers - ) - - return move_group - - def make_new_connective_animation(self): - """Makes new connective layer""" - self.anim_count = 0 - def create_new_connective(neural_network): - """ - Creates new connective layer - - This is a closure that creates a new connective layer and animates it. - """ - self.anim_count += 1 - if self.anim_count == 1: - if not self.before_layer is None and not self.after_layer is None: - print(neural_network) - new_connective = get_connective_layer(self.before_layer, self.after_layer) - before_layer_index = neural_network.all_layers.index_of(self.before_layer) + 1 - neural_network.all_layers.insert(before_layer_index, new_connective) - print(neural_network) - - update_func_anim = UpdateFromFunc(self.neural_network, create_new_connective) - - return update_func_anim - -class InsertLayer(AnimationGroup): - """Animation for inserting layer at given index""" - - def __init__(self, layer, index, neural_network): - self.layer = layer - self.index = index - self.neural_network = neural_network - # Layers before and after - self.layers_before = self.neural_network.all_layers[:self.index] - self.layers_after = self.neural_network.all_layers[self.index:] - - remove_connective_layer = self.remove_connective_layer() - move_layers = self.make_move_layers() - # create_layer = self.make_create_layer() - # create_connective_layers = self.make_create_connective_layers() - animations = [ - remove_connective_layer, - move_layers, - # create_layer, - # create_connective_layers - ] - - super().__init__(*animations, lag_ratio=1.0) - - def remove_connective_layer(self): - """Removes the connective layer before the insertion index""" - # Check if connective layer exists - if len(self.layers_before) > 0: - removed_connective = self.layers_before[-1] - self.neural_network.all_layers.remove(removed_connective) - # Make remove animation - remove_animation = FadeOut(removed_connective) - return remove_animation - - return AnimationGroup() - - def make_move_layers(self): - """Shifts layers before and after""" - # Before layer shift - before_shift_animation = AnimationGroup() - if len(self.layers_before) > 0: - before_shift = np.array([-self.layer.width/2, 0, 0]) - # Shift layers before - before_shift_animation = Group(*self.layers_before).animate.shift(before_shift) - # After layer shift - after_shift_animation = AnimationGroup() - if len(self.layers_after) > 0: - after_shift = np.array([self.layer.width/2, 0, 0]) - # Shift layers after - after_shift_animation = Group(*self.layers_after).animate.shift(after_shift) - # Make animation group - shift_animations = AnimationGroup( - before_shift_animation, - after_shift_animation - ) - - return shift_animations - - def make_create_layer(self): - """Animates the creation of the layer""" - pass - - def make_create_connective_layers(self): - pass - - - # Make connective layers and shift animations - # Before layer - if len(layers_before) > 0: - before_connective = get_connective_layer(layers_before[-1], layer) - before_shift = np.array([-layer.width/2, 0, 0]) - # Shift layers before - before_shift_animation = Group(*layers_before).animate.shift(before_shift) - else: - before_connective = AnimationGroup() - # After layer - if len(layers_after) > 0: - after_connective = get_connective_layer(layer, layers_after[0]) - after_shift = np.array([layer.width/2, 0, 0]) - # Shift layers after - after_shift_animation = Group(*layers_after).animate.shift(after_shift) - else: - after_connective = AnimationGroup - - insert_animation = Create(layer) - animation_group = AnimationGroup( - shift_animations, - insert_animation, - lag_ratio=1.0 - ) - - return animation_group - class NeuralNetwork(Group): def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.2, @@ -347,19 +110,32 @@ class NeuralNetwork(Group): return animation_group - def make_forward_pass_animation(self, run_time=10, passing_flash=True, + def make_forward_pass_animation(self, run_time=10, passing_flash=True, layer_args={}, **kwargs): """Generates an animation for feed forward propagation""" all_animations = [] - for layer_index, layer in enumerate(self.input_layers[:-1]): - layer_forward_pass = layer.make_forward_pass_animation(**kwargs) + for layer_index, layer in enumerate(self.all_layers): + # Get the layer args + if isinstance(layer, ConnectiveLayer): + """ + NOTE: By default a connective layer will get the combined + layer_args of the layers it is connecting. + """ + before_layer_args = {} + after_layer_args = {} + if layer.input_layer in layer_args: + before_layer_args = layer_args[layer.input_layer] + if layer.output_layer in layer_args: + after_layer_args = layer_args[layer.output_layer] + # Merge the two dicts + current_layer_args = {**before_layer_args, **after_layer_args} + else: + current_layer_args = {} + if layer in layer_args: + current_layer_args = layer_args[layer] + # Perform the forward pass of the current layer + layer_forward_pass = layer.make_forward_pass_animation(layer_args=current_layer_args, **kwargs) all_animations.append(layer_forward_pass) - connective_layer = self.connective_layers[layer_index] - connective_forward_pass = connective_layer.make_forward_pass_animation(**kwargs) - all_animations.append(connective_forward_pass) - # Do last layer animation - last_layer_forward_pass = self.input_layers[-1].make_forward_pass_animation(**kwargs) - all_animations.append(last_layer_forward_pass) # Make the animation group animation_group = AnimationGroup(*all_animations, run_time=run_time, lag_ratio=1.0) diff --git a/tests/test_embedding_layer.py b/tests/test_embedding_layer.py new file mode 100644 index 0000000..2f518b2 --- /dev/null +++ b/tests/test_embedding_layer.py @@ -0,0 +1,54 @@ +from manim import * + +from manim_ml.neural_network.layers.embedding import EmbeddingLayer +from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer +from manim_ml.neural_network.neural_network import NeuralNetwork + +config.pixel_height = 720 +config.pixel_width = 1280 +config.frame_height = 5.0 +config.frame_width = 5.0 + +class EmbeddingNNScene(Scene): + + def construct(self): + embedding_layer = EmbeddingLayer() + + neural_network = NeuralNetwork([ + FeedForwardLayer(5), + FeedForwardLayer(3), + embedding_layer, + FeedForwardLayer(3), + FeedForwardLayer(5) + ]) + + self.play(Create(neural_network)) + + self.play(neural_network.make_forward_pass_animation(run_time=5)) + +class QueryEmbeddingNNScene(Scene): + + def construct(self): + embedding_layer = EmbeddingLayer() + embedding_layer.paired_query_mode = True + + neural_network = NeuralNetwork([ + FeedForwardLayer(5), + FeedForwardLayer(3), + embedding_layer, + FeedForwardLayer(3), + FeedForwardLayer(5) + ]) + + self.play(Create(neural_network), run_time=2) + + self.play( + neural_network.make_forward_pass_animation( + run_time=5, + layer_args={ + embedding_layer: { + "query_locations": (np.array([2, 2]), np.array([1, 1])) + } + } + ) + ) \ No newline at end of file