Added ability to pass layer_args dictionary to each forward pass, which allows

for arguments to be passed through to each neural network layer when running a neural
network forward pass.
This commit is contained in:
Alec Helbling
2022-04-25 16:28:11 -04:00
parent 7d04bf55ec
commit 63427be139
18 changed files with 145 additions and 282 deletions

View File

@ -16,7 +16,7 @@ class ConvolutionalLayer(VGroupNeuralNetworkLayer):
"""Creates the neural network layer"""
pass
def make_forward_pass_animation(self):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
# make highlight animation
return None

View File

@ -6,16 +6,29 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
"""NeuralNetwork embedding object that can show probability distributions"""
def __init__(self, point_radius=0.02, mean = np.array([0, 0]),
covariance=np.array([[1.5, 0], [0, 1.5]]), dist_theme="gaussian", **kwargs):
covariance=np.array([[1.0, 0], [0, 1.0]]), dist_theme="gaussian",
paired_query_mode=False, **kwargs):
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
self.point_radius = point_radius
self.dist_theme = dist_theme
self.paired_query_mode = paired_query_mode
self.axes = Axes(
tips=False,
x_length=0.8,
y_length=0.8
y_length=0.8,
x_range=(-2.0, 2.0),
y_range=(-2.0, 2.0),
x_axis_config={
"include_ticks": False,
"stroke_width": 0.0
},
y_axis_config={
"include_ticks": False,
"stroke_width": 0.0
}
)
self.add(self.axes)
self.axes.move_to(self.get_center())
# Make point cloud
self.point_cloud = self.construct_gaussian_point_cloud(mean, covariance)
self.add(self.point_cloud)
@ -51,24 +64,44 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
return point_dots
def make_forward_pass_animation(self, **kwargs):
"""Forward pass animation"""
# Make ellipse object corresponding to the latent distribution
self.latent_distribution = GaussianDistribution(
self.axes,
dist_theme=self.dist_theme,
cov=np.array([[0.8, 0], [0.0, 0.8]])
) # Use defaults
# Create animation
def make_paired_query_embedding_animation(self):
"""Embed paired query"""
animations = []
#create_distribution = Create(self.latent_distribution.construct_gaussian_distribution(self.latent_distribution.mean, self.latent_distribution.cov)) #Create(self.latent_distribution)
create_distribution = Create(self.latent_distribution.ellipses)
animations.append(create_distribution)
animation_group = AnimationGroup(*animations)
# Make the animation
# Animation group
animation_group = AnimationGroup(
*animations,
lag_ratio=1.0
)
return animation_group
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Forward pass animation"""
animations = []
if not self.paired_query_mode:
# Normal embedding mode
# Make ellipse object corresponding to the latent distribution
self.latent_distribution = GaussianDistribution(
self.axes,
dist_theme=self.dist_theme,
cov=np.array([[0.8, 0], [0.0, 0.8]])
) # Use defaults
# Create animation
#create_distribution = Create(self.latent_distribution.construct_gaussian_distribution(self.latent_distribution.mean, self.latent_distribution.cov)) #Create(self.latent_distribution)
create_distribution = Create(self.latent_distribution.ellipses)
animations.append(create_distribution)
animation_group = AnimationGroup(*animations)
return animation_group
else:
# Paired Query Mode
# Handle logic for embedding a paired query into the embedding layer
paired_query_embedding_animation = self.make_paired_query_embedding_animation()
return paired_query_embedding_animation
@override_animation(Create)
def _create_override(self, **kwargs):
# Plot each point at once

View File

@ -17,7 +17,7 @@ class EmbeddingToFeedForward(ConnectiveLayer):
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
def make_forward_pass_animation(self, run_time=1.5, **kwargs):
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
"""Makes dots diverge from the given location and move the decoder"""
# Find point to converge on by sampling from gaussian distribution
location = self.embedding_layer.sample_point_location_from_distribution()

View File

@ -44,7 +44,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
# Add the objects to the class
self.add(self.surrounding_rectangle, self.node_group)
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
# make highlight animation
succession = Succession(
ApplyMethod(self.node_group.set_color, self.animation_dot_color, run_time=0.25),

View File

@ -17,7 +17,7 @@ class FeedForwardToEmbedding(ConnectiveLayer):
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
def make_forward_pass_animation(self, run_time=1.5):
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
"""Makes dots converge on a specific location"""
# Find point to converge on by sampling from gaussian distribution
location = self.embedding_layer.sample_point_location_from_distribution()

View File

@ -44,7 +44,7 @@ class FeedForwardToFeedForward(ConnectiveLayer):
return animation_group
def make_forward_pass_animation(self, run_time=1, **kwargs):
def make_forward_pass_animation(self, layer_args={}, run_time=1, **kwargs):
"""Animation for passing information from one FeedForwardLayer to the next"""
path_animations = []
dots = []

View File

@ -18,7 +18,7 @@ class FeedForwardToImage(ConnectiveLayer):
self.feed_forward_layer = input_layer
self.image_layer = output_layer
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
animations = []
image_mobject = self.image_layer.image_mobject

View File

@ -18,7 +18,7 @@ class FeedForwardToVector(ConnectiveLayer):
self.feed_forward_layer = input_layer
self.vector_layer = output_layer
def make_forward_pass_animation(self):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
animations = []
# Move the dots to the centers of each of the nodes in the FeedForwardLayer

View File

@ -27,7 +27,7 @@ class ImageLayer(NeuralNetworkLayer):
else:
return AnimationGroup()
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
return FadeIn(self.image_mobject)
# def move_to(self, location):

View File

@ -18,7 +18,7 @@ class ImageToFeedForward(ConnectiveLayer):
self.feed_forward_layer = output_layer
self.image_layer = input_layer
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
animations = []
dots = []

View File

@ -60,6 +60,6 @@ class PairedQueryLayer(NeuralNetworkLayer):
# TODO make Create animation that is custom
return FadeIn(self.assets)
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Forward pass for query"""
return AnimationGroup()

View File

@ -17,7 +17,7 @@ class PairedQueryToFeedForward(ConnectiveLayer):
self.paired_query_layer = input_layer
self.feed_forward_layer = output_layer
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
animations = []
# Loop through each image

View File

@ -12,7 +12,7 @@ class NeuralNetworkLayer(ABC, Group):
# self.add(self.title)
@abstractmethod
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
pass
@override_animation(Create)
@ -51,7 +51,7 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
assert isinstance(output_layer, self.output_class)
@abstractmethod
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
pass
@override_animation(Create)

View File

@ -71,6 +71,6 @@ class TripletLayer(NeuralNetworkLayer):
# TODO make Create animation that is custom
return FadeIn(self.assets)
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Forward pass for triplet"""
return AnimationGroup()

View File

@ -18,7 +18,7 @@ class TripletToFeedForward(ConnectiveLayer):
self.feed_forward_layer = output_layer
self.triplet_layer = input_layer
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
animations = []
# Loop through each image

View File

@ -28,7 +28,7 @@ class VectorLayer(VGroupNeuralNetworkLayer):
return vector_label
def make_forward_pass_animation(self, **kwargs):
def make_forward_pass_animation(self, layer_args={}, **kwargs):
return AnimationGroup()
@override_animation(Create)

View File

@ -9,254 +9,17 @@ Example:
# Create the object with default style settings
NeuralNetwork(layer_node_count)
"""
from cv2 import AGAST_FEATURE_DETECTOR_NONMAX_SUPPRESSION
from manim import *
import warnings
import textwrap
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
from manim_ml.neural_network.layers.util import get_connective_layer
from manim_ml.list_group import ListGroup
from manim_ml.neural_network.neural_network_transformations import InsertLayer, RemoveLayer
class RemoveLayer(AnimationGroup):
"""
Animation for removing a layer from a neural network.
Note: I needed to do something strange for creating the new connective layer.
The issue with creating it intially is that the positions of the sides of the
connective layer depend upon the location of the moved layers **after** the
move animations are performed. However, all of these animations are performed
after the animations have been created. This means that the animation depends upon
the state of the neural network layers after previous animations have been run.
To fix this issue I needed to use an UpdateFromFunc.
"""
def __init__(self, layer, neural_network, layer_spacing=0.2):
self.layer = layer
self.neural_network = neural_network
self.layer_spacing = layer_spacing
# Get the before and after layers
layers_tuple = self.get_connective_layers()
self.before_layer = layers_tuple[0]
self.after_layer = layers_tuple[1]
self.before_connective = layers_tuple[2]
self.after_connective = layers_tuple[3]
# Make the animations
remove_animations = self.make_remove_animation()
move_animations = self.make_move_animation()
new_connective_animation = self.make_new_connective_animation()
# Add all of the animations to the group
animations_list = [
remove_animations,
move_animations,
new_connective_animation
]
super().__init__(*animations_list, lag_ratio=1.0)
def get_connective_layers(self):
"""Gets the connective layers before and after self.layer"""
# Get layer index
layer_index = self.neural_network.all_layers.index_of(self.layer)
if layer_index == -1:
raise Exception("Layer object not found")
# Get the layers before and after
before_layer = None
after_layer = None
before_connective = None
after_connective = None
if layer_index - 2 >= 0:
before_layer = self.neural_network.all_layers[layer_index - 2]
before_connective = self.neural_network.all_layers[layer_index - 1]
if layer_index + 2 < len(self.neural_network.all_layers):
after_layer = self.neural_network.all_layers[layer_index + 2]
after_connective = self.neural_network.all_layers[layer_index + 1]
return before_layer, after_layer, before_connective, after_connective
def make_remove_animation(self):
"""Removes layer and the surrounding connective layers"""
remove_layer_animation = self.make_remove_layer_animation()
remove_connective_animation = self.make_remove_connective_layers_animation()
# Remove animations
remove_animations = AnimationGroup(
remove_layer_animation,
remove_connective_animation
)
return remove_animations
def make_remove_layer_animation(self):
"""Removes the layer"""
# Remove the layer
self.neural_network.all_layers.remove(self.layer)
# Fade out the removed layer
fade_out_removed = FadeOut(self.layer)
return fade_out_removed
def make_remove_connective_layers_animation(self):
"""Removes the connective layers before and after layer if they exist"""
# Fade out the removed connective layers
fade_out_before_connective = AnimationGroup()
if not self.before_connective is None:
self.neural_network.all_layers.remove(self.before_connective)
fade_out_before_connective = FadeOut(self.before_connective)
fade_out_after_connective = AnimationGroup()
if not self.after_connective is None:
self.neural_network.all_layers.remove(self.after_connective)
fade_out_after_connective = FadeOut(self.after_connective)
# Group items
remove_connective_group = AnimationGroup(
fade_out_after_connective,
fade_out_before_connective
)
return remove_connective_group
def make_move_animation(self):
"""Collapses layers"""
# Animate the movements
move_before_layers = AnimationGroup()
shift_right_amount = None
if not self.before_layer is None:
# Compute shift amount
layer_dist = np.abs(self.layer.get_center() - self.before_layer.get_right())[0]
shift_right_amount = np.array([layer_dist - self.layer_spacing/2, 0, 0])
# Shift all layers before forward
before_layer_index = self.neural_network.all_layers.index_of(self.before_layer)
layers_before = Group(*self.neural_network.all_layers[:before_layer_index + 1])
move_before_layers = layers_before.animate.shift(shift_right_amount)
move_after_layers = AnimationGroup()
shift_left_amount = None
if not self.after_layer is None:
layer_dist = np.abs(self.after_layer.get_left() - self.layer.get_center())[0]
shift_left_amount = np.array([-layer_dist + self.layer_spacing / 2, 0, 0])
# Shift all layers after backward
after_layer_index = self.neural_network.all_layers.index_of(self.after_layer)
layers_after = Group(*self.neural_network.all_layers[after_layer_index:])
move_after_layers = layers_after.animate.shift(shift_left_amount)
# Group the move animations
move_group = AnimationGroup(
move_before_layers,
move_after_layers
)
return move_group
def make_new_connective_animation(self):
"""Makes new connective layer"""
self.anim_count = 0
def create_new_connective(neural_network):
"""
Creates new connective layer
This is a closure that creates a new connective layer and animates it.
"""
self.anim_count += 1
if self.anim_count == 1:
if not self.before_layer is None and not self.after_layer is None:
print(neural_network)
new_connective = get_connective_layer(self.before_layer, self.after_layer)
before_layer_index = neural_network.all_layers.index_of(self.before_layer) + 1
neural_network.all_layers.insert(before_layer_index, new_connective)
print(neural_network)
update_func_anim = UpdateFromFunc(self.neural_network, create_new_connective)
return update_func_anim
class InsertLayer(AnimationGroup):
"""Animation for inserting layer at given index"""
def __init__(self, layer, index, neural_network):
self.layer = layer
self.index = index
self.neural_network = neural_network
# Layers before and after
self.layers_before = self.neural_network.all_layers[:self.index]
self.layers_after = self.neural_network.all_layers[self.index:]
remove_connective_layer = self.remove_connective_layer()
move_layers = self.make_move_layers()
# create_layer = self.make_create_layer()
# create_connective_layers = self.make_create_connective_layers()
animations = [
remove_connective_layer,
move_layers,
# create_layer,
# create_connective_layers
]
super().__init__(*animations, lag_ratio=1.0)
def remove_connective_layer(self):
"""Removes the connective layer before the insertion index"""
# Check if connective layer exists
if len(self.layers_before) > 0:
removed_connective = self.layers_before[-1]
self.neural_network.all_layers.remove(removed_connective)
# Make remove animation
remove_animation = FadeOut(removed_connective)
return remove_animation
return AnimationGroup()
def make_move_layers(self):
"""Shifts layers before and after"""
# Before layer shift
before_shift_animation = AnimationGroup()
if len(self.layers_before) > 0:
before_shift = np.array([-self.layer.width/2, 0, 0])
# Shift layers before
before_shift_animation = Group(*self.layers_before).animate.shift(before_shift)
# After layer shift
after_shift_animation = AnimationGroup()
if len(self.layers_after) > 0:
after_shift = np.array([self.layer.width/2, 0, 0])
# Shift layers after
after_shift_animation = Group(*self.layers_after).animate.shift(after_shift)
# Make animation group
shift_animations = AnimationGroup(
before_shift_animation,
after_shift_animation
)
return shift_animations
def make_create_layer(self):
"""Animates the creation of the layer"""
pass
def make_create_connective_layers(self):
pass
# Make connective layers and shift animations
# Before layer
if len(layers_before) > 0:
before_connective = get_connective_layer(layers_before[-1], layer)
before_shift = np.array([-layer.width/2, 0, 0])
# Shift layers before
before_shift_animation = Group(*layers_before).animate.shift(before_shift)
else:
before_connective = AnimationGroup()
# After layer
if len(layers_after) > 0:
after_connective = get_connective_layer(layer, layers_after[0])
after_shift = np.array([layer.width/2, 0, 0])
# Shift layers after
after_shift_animation = Group(*layers_after).animate.shift(after_shift)
else:
after_connective = AnimationGroup
insert_animation = Create(layer)
animation_group = AnimationGroup(
shift_animations,
insert_animation,
lag_ratio=1.0
)
return animation_group
class NeuralNetwork(Group):
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.2,
@ -347,19 +110,32 @@ class NeuralNetwork(Group):
return animation_group
def make_forward_pass_animation(self, run_time=10, passing_flash=True,
def make_forward_pass_animation(self, run_time=10, passing_flash=True, layer_args={},
**kwargs):
"""Generates an animation for feed forward propagation"""
all_animations = []
for layer_index, layer in enumerate(self.input_layers[:-1]):
layer_forward_pass = layer.make_forward_pass_animation(**kwargs)
for layer_index, layer in enumerate(self.all_layers):
# Get the layer args
if isinstance(layer, ConnectiveLayer):
"""
NOTE: By default a connective layer will get the combined
layer_args of the layers it is connecting.
"""
before_layer_args = {}
after_layer_args = {}
if layer.input_layer in layer_args:
before_layer_args = layer_args[layer.input_layer]
if layer.output_layer in layer_args:
after_layer_args = layer_args[layer.output_layer]
# Merge the two dicts
current_layer_args = {**before_layer_args, **after_layer_args}
else:
current_layer_args = {}
if layer in layer_args:
current_layer_args = layer_args[layer]
# Perform the forward pass of the current layer
layer_forward_pass = layer.make_forward_pass_animation(layer_args=current_layer_args, **kwargs)
all_animations.append(layer_forward_pass)
connective_layer = self.connective_layers[layer_index]
connective_forward_pass = connective_layer.make_forward_pass_animation(**kwargs)
all_animations.append(connective_forward_pass)
# Do last layer animation
last_layer_forward_pass = self.input_layers[-1].make_forward_pass_animation(**kwargs)
all_animations.append(last_layer_forward_pass)
# Make the animation group
animation_group = AnimationGroup(*all_animations, run_time=run_time, lag_ratio=1.0)

View File

@ -0,0 +1,54 @@
from manim import *
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 5.0
config.frame_width = 5.0
class EmbeddingNNScene(Scene):
def construct(self):
embedding_layer = EmbeddingLayer()
neural_network = NeuralNetwork([
FeedForwardLayer(5),
FeedForwardLayer(3),
embedding_layer,
FeedForwardLayer(3),
FeedForwardLayer(5)
])
self.play(Create(neural_network))
self.play(neural_network.make_forward_pass_animation(run_time=5))
class QueryEmbeddingNNScene(Scene):
def construct(self):
embedding_layer = EmbeddingLayer()
embedding_layer.paired_query_mode = True
neural_network = NeuralNetwork([
FeedForwardLayer(5),
FeedForwardLayer(3),
embedding_layer,
FeedForwardLayer(3),
FeedForwardLayer(5)
])
self.play(Create(neural_network), run_time=2)
self.play(
neural_network.make_forward_pass_animation(
run_time=5,
layer_args={
embedding_layer: {
"query_locations": (np.array([2, 2]), np.array([1, 1]))
}
}
)
)