diff --git a/manim_ml/list_group.py b/manim_ml/list_group.py new file mode 100644 index 0000000..2dff075 --- /dev/null +++ b/manim_ml/list_group.py @@ -0,0 +1,73 @@ +from manim import * + +class ListGroup(Mobject): + """Indexable Group with traditional list operations""" + + def __init__(self, *layers): + super().__init__() + self.items = [*layers] + + def __getitem__(self, indices): + """Traditional list indexing""" + return self.items[indices] + + def insert(self, index, item): + """Inserts item at index""" + self.items.insert(index, item) + self.submobjects = self.items + + def remove_at_index(self, index): + """Removes item at index""" + if index < 0 or index > len(self.items): + raise Exception(f"ListGroup index out of range: {index}") + item = self.items[index] + del self.items[index] + self.submobjects = self.items + + return item + + def remove_at_indices(self, indices): + """Removes items at indices""" + items = [] + for index in indices: + item = self.remove_at_index(index) + items.append(item) + + return items + + def remove(self, item): + """Removes first instance of item""" + self.items.remove(item) + self.submobjects = self.items + + return item + + def get(self, index): + """Gets item at index""" + return self.items[index] + + def add(self, item): + """Adds to end""" + self.items.append(item) + self.submobjects = self.items + + def replace(self, index, item): + """Replaces item at index""" + self.items[index] = item + self.submobjects = self.items + + def index_of(self, item): + """Returns index of item if it exists""" + for index, obj in enumerate(self.items): + if item is obj: + return index + return -1 + + def __len__(self): + """Length of items""" + return len(self.items) + + def set_z_index(self, z_index_value): + """Sets z index of all values in ListGroup""" + for item in self.items: + item.set_z_index(z_index_value) \ No newline at end of file diff --git a/manim_ml/neural_network/layers/__init__.py b/manim_ml/neural_network/layers/__init__.py index d4fd61f..3e2f6d5 100644 --- a/manim_ml/neural_network/layers/__init__.py +++ b/manim_ml/neural_network/layers/__init__.py @@ -23,4 +23,4 @@ connective_layers_list = ( PairedQueryToFeedForward, TripletToFeedForward, PairedQueryToFeedForward, -) \ No newline at end of file +) diff --git a/manim_ml/neural_network/layers/embedding.py b/manim_ml/neural_network/layers/embedding.py index f79fdc9..edbf277 100644 --- a/manim_ml/neural_network/layers/embedding.py +++ b/manim_ml/neural_network/layers/embedding.py @@ -5,8 +5,8 @@ from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLaye class EmbeddingLayer(VGroupNeuralNetworkLayer): """NeuralNetwork embedding object that can show probability distributions""" - def __init__(self, point_radius=0.02): - super(EmbeddingLayer, self).__init__() + def __init__(self, point_radius=0.02, **kwargs): + super(EmbeddingLayer, self).__init__(**kwargs) self.point_radius = point_radius self.axes = Axes( tips=False, diff --git a/manim_ml/neural_network/layers/embedding_to_feed_forward.py b/manim_ml/neural_network/layers/embedding_to_feed_forward.py index a01cf2d..d3fa461 100644 --- a/manim_ml/neural_network/layers/embedding_to_feed_forward.py +++ b/manim_ml/neural_network/layers/embedding_to_feed_forward.py @@ -8,8 +8,10 @@ class EmbeddingToFeedForward(ConnectiveLayer): input_class = EmbeddingLayer output_class = FeedForwardLayer - def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03): - super().__init__(input_layer, output_layer, input_class=EmbeddingLayer, output_class=FeedForwardLayer) + def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03, + **kwargs): + super().__init__(input_layer, output_layer, input_class=EmbeddingLayer, output_class=FeedForwardLayer, + **kwargs) self.feed_forward_layer = output_layer self.embedding_layer = input_layer self.animation_dot_color = animation_dot_color diff --git a/manim_ml/neural_network/layers/feed_forward.py b/manim_ml/neural_network/layers/feed_forward.py index 5a23d1b..8576f2c 100644 --- a/manim_ml/neural_network/layers/feed_forward.py +++ b/manim_ml/neural_network/layers/feed_forward.py @@ -7,8 +7,8 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer): def __init__(self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08, node_color=BLUE, node_outline_color=WHITE, rectangle_color=WHITE, node_spacing=0.3, rectangle_fill_color=BLACK, node_stroke_width=2.0, - rectangle_stroke_width=2.0, animation_dot_color=RED): - super(VGroupNeuralNetworkLayer, self).__init__() + rectangle_stroke_width=2.0, animation_dot_color=RED, **kwargs): + super(VGroupNeuralNetworkLayer, self).__init__(**kwargs) self.num_nodes = num_nodes self.layer_buffer = layer_buffer self.node_radius = node_radius diff --git a/manim_ml/neural_network/layers/feed_forward_to_embedding.py b/manim_ml/neural_network/layers/feed_forward_to_embedding.py index d05af4d..39fb6d1 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_embedding.py +++ b/manim_ml/neural_network/layers/feed_forward_to_embedding.py @@ -8,8 +8,10 @@ class FeedForwardToEmbedding(ConnectiveLayer): input_class = FeedForwardLayer output_class = EmbeddingLayer - def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03): - super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=EmbeddingLayer) + def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03, + **kwargs): + super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=EmbeddingLayer, + **kwargs) self.feed_forward_layer = input_layer self.embedding_layer = output_layer self.animation_dot_color = animation_dot_color diff --git a/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py b/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py index 0a05ae6..0e72226 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py +++ b/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py @@ -9,8 +9,9 @@ class FeedForwardToFeedForward(ConnectiveLayer): def __init__(self, input_layer, output_layer, passing_flash=True, dot_radius=0.05, animation_dot_color=RED, edge_color=WHITE, - edge_width=0.5): - super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=FeedForwardLayer) + edge_width=1.5, **kwargs): + super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=FeedForwardLayer, + **kwargs) self.passing_flash = passing_flash self.edge_color = edge_color self.dot_radius = dot_radius diff --git a/manim_ml/neural_network/layers/feed_forward_to_image.py b/manim_ml/neural_network/layers/feed_forward_to_image.py index 7dc24da..5dfe07c 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_image.py +++ b/manim_ml/neural_network/layers/feed_forward_to_image.py @@ -9,8 +9,9 @@ class FeedForwardToImage(ConnectiveLayer): output_class = ImageLayer def __init__(self, input_layer, output_layer, animation_dot_color=RED, - dot_radius=0.05): - super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=ImageLayer) + dot_radius=0.05, **kwargs): + super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=ImageLayer + **kwargs) self.animation_dot_color = animation_dot_color self.dot_radius = dot_radius diff --git a/manim_ml/neural_network/layers/image.py b/manim_ml/neural_network/layers/image.py index 6a7d324..8c07aa3 100644 --- a/manim_ml/neural_network/layers/image.py +++ b/manim_ml/neural_network/layers/image.py @@ -5,9 +5,8 @@ from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer class ImageLayer(NeuralNetworkLayer): """Single Image Layer for Neural Network""" - def __init__(self, numpy_image, height=1.5): - super().__init__() - self.set_z_index(1) + def __init__(self, numpy_image, height=1.5, **kwargs): + super().__init__(**kwargs) self.numpy_image = numpy_image if len(np.shape(self.numpy_image)) == 2: # Assumed Grayscale diff --git a/manim_ml/neural_network/layers/image_to_feed_forward.py b/manim_ml/neural_network/layers/image_to_feed_forward.py index 4e5e096..86568f9 100644 --- a/manim_ml/neural_network/layers/image_to_feed_forward.py +++ b/manim_ml/neural_network/layers/image_to_feed_forward.py @@ -9,8 +9,9 @@ class ImageToFeedForward(ConnectiveLayer): output_class = FeedForwardLayer def __init__(self, input_layer, output_layer, animation_dot_color=RED, - dot_radius=0.05): - super().__init__(input_layer, output_layer, input_class=ImageLayer, output_class=FeedForwardLayer) + dot_radius=0.05, **kwargs): + super().__init__(input_layer, output_layer, input_class=ImageLayer, output_class=FeedForwardLayer + **kwargs) self.animation_dot_color = animation_dot_color self.dot_radius = dot_radius diff --git a/manim_ml/neural_network/layers/paired_query.py b/manim_ml/neural_network/layers/paired_query.py index add053e..5e3dcaa 100644 --- a/manim_ml/neural_network/layers/paired_query.py +++ b/manim_ml/neural_network/layers/paired_query.py @@ -6,8 +6,8 @@ import numpy as np class PairedQueryLayer(NeuralNetworkLayer): """Paired Query Layer""" - def __init__(self, positive, negative, stroke_width=5): - super().__init__() + def __init__(self, positive, negative, stroke_width=5, **kwargs): + super().__init__(**kwargs) self.positive = positive self.negative = negative diff --git a/manim_ml/neural_network/layers/paired_query_to_feed_forward.py b/manim_ml/neural_network/layers/paired_query_to_feed_forward.py index 41bfd3d..c1f6d97 100644 --- a/manim_ml/neural_network/layers/paired_query_to_feed_forward.py +++ b/manim_ml/neural_network/layers/paired_query_to_feed_forward.py @@ -8,8 +8,9 @@ class PairedQueryToFeedForward(ConnectiveLayer): input_class = PairedQueryLayer output_class = FeedForwardLayer - def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.02): - super().__init__(input_layer, output_layer, input_class=PairedQueryLayer, output_class=FeedForwardLayer) + def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.02, **kwargs): + super().__init__(input_layer, output_layer, input_class=PairedQueryLayer, output_class=FeedForwardLayer + **kwargs) self.animation_dot_color = animation_dot_color self.dot_radius = dot_radius diff --git a/manim_ml/neural_network/layers/parent_layers.py b/manim_ml/neural_network/layers/parent_layers.py index dd65a1e..790f3c7 100644 --- a/manim_ml/neural_network/layers/parent_layers.py +++ b/manim_ml/neural_network/layers/parent_layers.py @@ -4,9 +4,8 @@ from abc import ABC, abstractmethod class NeuralNetworkLayer(ABC, Group): """Abstract Neural Network Layer class""" - def __init__(self, **kwargs): + def __init__(self, text=None, **kwargs): super(Group, self).__init__() - self.set_z_index(1) @abstractmethod def make_forward_pass_animation(self): @@ -28,8 +27,9 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer): """Forward pass animation for a given pair of layers""" @abstractmethod - def __init__(self, input_layer, output_layer, input_class=None, output_class=None): - super(VGroupNeuralNetworkLayer, self).__init__() + def __init__(self, input_layer, output_layer, input_class=None, output_class=None, + **kwargs): + super(VGroupNeuralNetworkLayer, self).__init__(**kwargs) self.input_layer = input_layer self.output_layer = output_layer self.input_class = input_class @@ -38,8 +38,6 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer): assert isinstance(input_layer, self.input_class) assert isinstance(output_layer, self.output_class) - self.set_z_index(-1) - @abstractmethod def make_forward_pass_animation(self): pass \ No newline at end of file diff --git a/manim_ml/neural_network/layers/triplet.py b/manim_ml/neural_network/layers/triplet.py index 1415bad..936c42e 100644 --- a/manim_ml/neural_network/layers/triplet.py +++ b/manim_ml/neural_network/layers/triplet.py @@ -6,8 +6,9 @@ import numpy as np class TripletLayer(NeuralNetworkLayer): """Shows triplet images""" - def __init__(self, anchor, positive, negative, stroke_width=5): - super().__init__() + def __init__(self, anchor, positive, negative, stroke_width=5, + **kwargs): + super().__init__(**kwargs) self.anchor = anchor self.positive = positive self.negative = negative diff --git a/manim_ml/neural_network/layers/triplet_to_feed_forward.py b/manim_ml/neural_network/layers/triplet_to_feed_forward.py index 75e5f6e..c5206d1 100644 --- a/manim_ml/neural_network/layers/triplet_to_feed_forward.py +++ b/manim_ml/neural_network/layers/triplet_to_feed_forward.py @@ -9,8 +9,9 @@ class TripletToFeedForward(ConnectiveLayer): output_class = FeedForwardLayer def __init__(self, input_layer, output_layer, animation_dot_color=RED, - dot_radius=0.02): - super().__init__(input_layer, output_layer, input_class=TripletLayer, output_class=FeedForwardLayer) + dot_radius=0.02, **kwargs): + super().__init__(input_layer, output_layer, input_class=TripletLayer, output_class=FeedForwardLayer, + **kwargs) self.animation_dot_color = animation_dot_color self.dot_radius = dot_radius diff --git a/manim_ml/neural_network/layers/util.py b/manim_ml/neural_network/layers/util.py new file mode 100644 index 0000000..65654ba --- /dev/null +++ b/manim_ml/neural_network/layers/util.py @@ -0,0 +1,20 @@ +from manim import * +from ..layers import connective_layers_list + +def get_connective_layer(input_layer, output_layer): + """ + Deduces the relevant connective layer + """ + connective_layer = None + for connective_layer_class in connective_layers_list: + input_class = connective_layer_class.input_class + output_class = connective_layer_class.output_class + if isinstance(input_layer, input_class) \ + and isinstance(output_layer, output_class): + connective_layer = connective_layer_class(input_layer, output_layer) + + if connective_layer is None: + raise Exception(f"Unrecognized class pair {input_layer.__class__.__name__}" + \ + " and {output_layer.__class__.__name__}") + + return connective_layer diff --git a/manim_ml/neural_network/neural_network.py b/manim_ml/neural_network/neural_network.py index 510f8aa..463378d 100644 --- a/manim_ml/neural_network/neural_network.py +++ b/manim_ml/neural_network/neural_network.py @@ -9,24 +9,23 @@ Example: # Create the object with default style settings NeuralNetwork(layer_node_count) """ +from socket import create_connection +from urllib.parse import non_hierarchical from manim import * import warnings import textwrap -from manim_ml.neural_network.layers import \ - FeedForwardLayer, FeedForwardToFeedForward, ImageLayer, \ - ImageToFeedForward, FeedForwardToImage, EmbeddingLayer, \ - EmbeddingToFeedForward, FeedForwardToEmbedding, TripletLayer, \ - TripletToFeedForward - from manim_ml.neural_network.layers import connective_layers_list +from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer +from manim_ml.neural_network.layers.util import get_connective_layer +from manim_ml.list_group import ListGroup class NeuralNetwork(Group): def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8, - animation_dot_color=RED, edge_width=1.5, dot_radius=0.03): + animation_dot_color=RED, edge_width=2.5, dot_radius=0.03): super(Group, self).__init__() - self.input_layers = Group(*input_layers) + self.input_layers = ListGroup(*input_layers) self.edge_width = edge_width self.edge_color = edge_color self.layer_spacing = layer_spacing @@ -37,6 +36,9 @@ class NeuralNetwork(Group): # and make it have explicit distinct subspaces self._place_layers() self.connective_layers, self.all_layers = self._construct_connective_layers() + # Place layers at correct z index + self.connective_layers.set_z_index(2) + self.input_layers.set_z_index(3) # Center the whole diagram by default self.all_layers.move_to(ORIGIN) self.add(self.all_layers) @@ -50,17 +52,14 @@ class NeuralNetwork(Group): for layer_index in range(1, len(self.input_layers)): previous_layer = self.input_layers[layer_index - 1] current_layer = self.input_layers[layer_index] - current_layer.move_to(previous_layer) shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0]) current_layer.shift(shift_vector) - # Handle layering - self.input_layers.set_z_index(2) def _construct_connective_layers(self): """Draws connecting lines between layers""" - connective_layers = Group() - all_layers = Group() + connective_layers = ListGroup() + all_layers = ListGroup() for layer_index in range(len(self.input_layers) - 1): current_layer = self.input_layers[layer_index] all_layers.add(current_layer) @@ -72,28 +71,196 @@ class NeuralNetwork(Group): if isinstance(next_layer, NeuralNetwork): # First layer of the next layer next_layer = next_layer.all_layers[0] - # Find connective layer with correct layer pair - connective_layer = None - for connective_layer_class in connective_layers_list: - input_class = connective_layer_class.input_class - output_class = connective_layer_class.output_class - if isinstance(current_layer, input_class) \ - and isinstance(next_layer, output_class): - connective_layer = connective_layer_class(current_layer, next_layer) - - connective_layers.add(connective_layer) - all_layers.add(connective_layer) - - if connective_layer is None: - raise Exception(f"Unrecognized class pair {current_layer.__class__.__name__} and {next_layer.__class__.__name__}") + connective_layer = get_connective_layer(current_layer, next_layer) + connective_layers.add(connective_layer) + all_layers.add(connective_layer) # Add final layer all_layers.add(self.input_layers[-1]) # Handle layering return connective_layers, all_layers + def insert_layer(self, layer, insert_index): + """Inserts a layer at the given index""" + layers_before = self.all_layers[:insert_index] + layers_after = self.all_layers[insert_index:] + # Make connective layers and shift animations + # Before layer + if len(layers_before) > 0: + before_connective = get_connective_layer(layers_before[-1], layer) + before_shift = np.array([-layer.width/2, 0, 0]) + # Shift layers before + before_shift_animation = Group(*layers_before).animate.shift(before_shift) + else: + before_connective = AnimationGroup() + # After layer + if len(layers_after) > 0: + after_connective = get_connective_layer(layer, layers_after[0]) + after_shift = np.array([layer.width/2, 0, 0]) + # Shift layers after + after_shift_animation = Group(*layers_after).animate.shift(after_shift) + else: + after_connective = AnimationGroup + + # Make animation group + shift_animations = AnimationGroup( + before_shift_animation, + after_shift_animation + ) + + insert_animation = Create(layer) + animation_group = AnimationGroup( + shift_animations, + insert_animation, + lag_ratio=1.0 + ) + + return animation_group + + def remove_layer(self, layer): + """Removes layer object if it exists""" + # Get layer index + layer_index = self.all_layers.index_of(layer) + if layer_index == -1: + raise Exception("Layer object not found") + # Get the layers before and after + before_layer = None + after_layer = None + if layer_index - 2 >= 0: + before_layer = self.all_layers[layer_index - 2] + if layer_index + 2 < len(self.all_layers): + after_layer = self.all_layers[layer_index + 2] + # Remove the layer + self.all_layers.remove(layer) + # Remove surrounding connective layers from self.all_layers + before_connective = None + after_connective = None + if layer_index - 1 >= 0: + # There is a layer before + before_connective = self.all_layers.remove_at_index(layer_index - 1) + if layer_index + 1 < len(self.all_layers): + # There is a layer after + after_connective = self.all_layers.remove_at_index(layer_index + 1) + # Make animations + # Fade out the removed layer + fade_out_removed = FadeOut(layer) + # Fade out the removed connective layers + fade_out_before_connective = Animation() + if not before_connective is None: + fade_out_before_connective = FadeOut(before_connective) + fade_out_after_connective = Animation() + if not after_connective is None: + fade_out_after_connective = FadeOut(after_connective) + # Create new connective layer + new_connective = None + if not before_layer is None and not after_layer is None: + new_connective = get_connective_layer(before_layer, after_layer) + before_layer_index = self.all_layers.index_of(before_layer) + self.all_layers.insert(before_layer_index, new_connective) + # Place the new connective + new_connective.move_to(layer) + # Animate the creation of the new connective layer + create_new_connective = Animation() + if not new_connective is None: + create_new_connective = Create(new_connective) + # Collapse the neural network to fill the empty space + removed_width = layer.width + before_connective.width + after_connective.width - new_connective.width + shift_right_amount = np.array([removed_width / 2, 0, 0]) + shift_left_amount = np.array([-removed_width / 2, 0, 0]) + move_before_layer = Animation() + if not before_layer is None: + move_before_layer = before_layer.animate.shift(shift_right_amount) + move_after_layer = Animation() + if not after_layer is None: + move_after_layer = after_layer.animate.shift(shift_left_amount) + # Make the final AnimationGroup + fade_out_group = AnimationGroup( + fade_out_removed, + fade_out_before_connective, + fade_out_after_connective + ) + move_group = AnimationGroup( + move_before_layer, + move_after_layer + ) + animation_group = AnimationGroup( + fade_out_group, + move_group, + create_new_connective, + lag_ratio=1.0 + ) + + return animation_group + + """ + remove_layer = list(self.all_layers)[remove_index] + if remove_index > 0: + connective_before = list(self.all_layers)[remove_index - 1] + else: + connective_before = None + if remove_index < len(list(self.all_layers)) - 1: + connective_after = list(self.all_layers)[remove_index + 1] + else: + connective_after = None + # Collapse the surrounding layer + layers_before = list(self.all_layers)[:remove_index] + layers_after = list(self.all_layers)[remove_index+1:] + before_group = Group(*layers_before) + after_group = Group(*layers_after) + before_shift_amount = np.array([remove_layer.width/2, 0, 0]) + after_shift_amount = np.array([-remove_layer.width/2, 0, 0]) + # Remove the layers from the neural network representation + self.all_layers.remove(remove_layer) + if not connective_before is None: + self.all_layers.remove(connective_before) + if not connective_after is None: + self.all_layers.remove(connective_after) + # Connect the layers before and layers after + pre_index = remove_index - 1 + pre_layer = None + if pre_index >= 0: + pre_layer = list(self.all_layers)[pre_index] + post_index = remove_index + post_layer = None + if post_index < len(list(self.all_layers)): + post_layer = list(self.all_layers)[post_index] + if not pre_layer is None and not post_layer is None: + connective_layer = get_connective_layer(pre_layer, post_layer) + self.all_layers = Group( + *self.all_layers[:remove_index], + connective_layer, + *self.all_layers[remove_index:] + ) + # Make animations + fade_out_animation = FadeOut(remove_layer) + shift_animations = AnimationGroup( + before_group.animate.shift(before_shift_amount), + after_group.animate.shift(after_shift_amount) + ) + animation_group = AnimationGroup( + fade_out_animation, + shift_animations, + lag_ratio=1.0 + ) + + return animation_group + """ + + def replace_layer(self, old_layer, new_layer): + """Replaces given layer object""" + remove_animation = self.remove_layer(insert_index) + insert_animation = self.insert_layer(layer, insert_index) + # Make the animation + animation_group = AnimationGroup( + FadeOut(self.all_layers[insert_index]), + FadeIn(layer), + lag_ratio=1.0 + ) + + return animation_group + def make_forward_pass_animation(self, run_time=10, passing_flash=True): - """Generates an animation for feed forward propogation""" + """Generates an animation for feed forward propagation""" all_animations = [] for layer_index, layer in enumerate(self.input_layers[:-1]): layer_forward_pass = layer.make_forward_pass_animation() @@ -126,19 +293,11 @@ class NeuralNetwork(Group): return animation_group - def remove_layer(self, layer_index): - """Removes layer at given index and returns animation for removing the layer""" - raise NotImplementedError() - - def add_layer(self, layer): - """Adds layer and returns animation for adding action""" - raise NotImplementedError() - def __repr__(self): """Print string representation of layers""" inner_string = "" for layer in self.all_layers: - inner_string += f"{repr(layer)},\n" + inner_string += f"{repr(layer)} {layer.z_index} ,\n" inner_string = textwrap.indent(inner_string, " ") string_repr = "NeuralNetwork([\n" + inner_string + "])" diff --git a/manim_ml/probability.py b/manim_ml/probability.py index 2793ec5..c765853 100644 --- a/manim_ml/probability.py +++ b/manim_ml/probability.py @@ -16,8 +16,7 @@ class GaussianDistribution(VGroup): self.cov = np.array([[3, 0], [0, 3]]) # Make the Gaussian self.ellipses = self.construct_gaussian_distribution(self.mean, self.cov) - self.ellipses.set_z_index(2) - + @override_animation(Create) def _create_gaussian_distribution(self): return Create(self.ellipses) diff --git a/tests/test_neural_network.py b/tests/test_neural_network.py index d83721d..f47d83b 100644 --- a/tests/test_neural_network.py +++ b/tests/test_neural_network.py @@ -1,6 +1,10 @@ +from cv2 import exp from manim import * from manim_ml.neural_network.layers.embedding import EmbeddingLayer +from manim_ml.neural_network.layers.embedding_to_feed_forward import EmbeddingToFeedForward from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer +from manim_ml.neural_network.layers.feed_forward_to_embedding import FeedForwardToEmbedding +from manim_ml.neural_network.layers.feed_forward_to_feed_forward import FeedForwardToFeedForward from manim_ml.neural_network.layers.image import ImageLayer from manim_ml.neural_network.neural_network import NeuralNetwork, FeedForwardNeuralNetwork from PIL import Image @@ -11,6 +15,72 @@ config.pixel_width = 1280 config.frame_height = 6.0 config.frame_width = 6.0 +""" + Unit Tests +""" + +def assert_classes_match(all_layers, expected_classes): + assert len(list(all_layers)) == 5 + + for index, layer in enumerate(all_layers): + expected_class = expected_classes[index] + assert isinstance(layer, expected_class), f"Wrong layer class {layer.__class__} expected {expected_class}" + +def test_embedding_layer(): + embedding_layer = EmbeddingLayer() + + neural_network = NeuralNetwork([ + FeedForwardLayer(5), + FeedForwardLayer(3), + embedding_layer + ]) + + expected_classes = [ + FeedForwardLayer, + FeedForwardToFeedForward, + FeedForwardLayer, + FeedForwardToEmbedding, + EmbeddingLayer + ] + + assert_classes_match(neural_network.all_layers, expected_classes) + + +def test_remove_layer(): + embedding_layer = EmbeddingLayer() + + neural_network = NeuralNetwork([ + FeedForwardLayer(5), + FeedForwardLayer(3), + embedding_layer + ]) + + expected_classes = [ + FeedForwardLayer, + FeedForwardToFeedForward, + FeedForwardLayer, + FeedForwardToEmbedding, + EmbeddingLayer + ] + + assert_classes_match(neural_network.all_layers, expected_classes) + + print("before removal") + print(list(neural_network.all_layers)) + neural_network.remove_layer(embedding_layer) + print("after removal") + print(list(neural_network.all_layers)) + + expected_classes = [ + FeedForwardLayer, + FeedForwardToFeedForward, + FeedForwardLayer, + ] + + print(list(neural_network.all_layers)) + + assert_classes_match(neural_network.all_layers, expected_classes) + class FeedForwardNeuralNetworkScene(Scene): def construct(self): @@ -92,6 +162,31 @@ class RecursiveNNScene(Scene): self.play(Create(nn)) +class LayerInsertionScene(Scene): + + def construct(self): + pass + +class LayerRemovalScene(Scene): + + def construct(self): + image = Image.open('images/image.jpeg') + numpy_image = np.asarray(image) + + layer = FeedForwardLayer(5), + layers = [ + ImageLayer(numpy_image, height=1.4), + FeedForwardLayer(3), + layer, + FeedForwardLayer(3), + FeedForwardLayer(6) + ] + + nn = NeuralNetwork(layers) + + self.play(Create(nn)) + self.play(nn.remove_layer(layer)) + if __name__ == "__main__": """Render all scenes""" # Feed Forward Neural Network