Working remove neural network animation

This commit is contained in:
Alec Helbling
2022-04-21 23:18:58 -04:00
parent 11bbd59bb6
commit ffd31701bf
13 changed files with 230 additions and 158 deletions

View File

@ -67,7 +67,7 @@ class ListGroup(Mobject):
"""Length of items"""
return len(self.items)
def set_z_index(self, z_index_value):
def set_z_index(self, z_index_value, family=True):
"""Sets z index of all values in ListGroup"""
for item in self.items:
item.set_z_index(z_index_value)
item.set_z_index(z_index_value, family=True)

View File

@ -0,0 +1,25 @@
from manim import *
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
class ConvolutionalLayer(VGroupNeuralNetworkLayer):
"""Handles rendering a convolutional layer for a nn"""
def __init__(self, num_filters, filter_width, **kwargs):
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
self.num_filters = num_filters
self.filter_width = filter_width
self._construct_neural_network_layer()
def _construct_neural_network_layer(self):
"""Creates the neural network layer"""
pass
def make_forward_pass_animation(self):
# make highlight animation
return None
@override_animation(Create)
def _create_override(self, **kwargs):
pass

View File

@ -6,7 +6,7 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
"""NeuralNetwork embedding object that can show probability distributions"""
def __init__(self, point_radius=0.02, **kwargs):
super(EmbeddingLayer, self).__init__(**kwargs)
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
self.point_radius = point_radius
self.axes = Axes(
tips=False,

View File

@ -10,7 +10,7 @@ class FeedForwardToImage(ConnectiveLayer):
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
dot_radius=0.05, **kwargs):
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=ImageLayer
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=ImageLayer,
**kwargs)
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius

View File

@ -13,19 +13,22 @@ class ImageLayer(NeuralNetworkLayer):
self.image_mobject = GrayscaleImageMobject(self.numpy_image, height=height)
elif len(np.shape(self.numpy_image)) == 3:
# Assumed RGB
self.image_mobject = ImageMobject(self.numpy_image)
self.image_mobject = ImageMobject(self.numpy_image).scale_to_fit_height(height)
self.add(self.image_mobject)
@override_animation(Create)
def _create_override(self, **kwargs):
debug_mode = False
if debug_mode:
return FadeIn(SurroundingRectangle(self.image_mobject))
return FadeIn(self.image_mobject)
def make_forward_pass_animation(self):
return Create(self.image_mobject)
return FadeIn(self.image_mobject)
def move_to(self, location):
"""Override of move to"""
self.image_mobject.move_to(location)
# def move_to(self, location):
# """Override of move to"""
# self.image_mobject.move_to(location)
def get_right(self):
"""Override get right"""

View File

@ -10,7 +10,7 @@ class ImageToFeedForward(ConnectiveLayer):
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
dot_radius=0.05, **kwargs):
super().__init__(input_layer, output_layer, input_class=ImageLayer, output_class=FeedForwardLayer
super().__init__(input_layer, output_layer, input_class=ImageLayer, output_class=FeedForwardLayer,
**kwargs)
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius

View File

@ -9,7 +9,7 @@ class PairedQueryToFeedForward(ConnectiveLayer):
output_class = FeedForwardLayer
def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.02, **kwargs):
super().__init__(input_layer, output_layer, input_class=PairedQueryLayer, output_class=FeedForwardLayer
super().__init__(input_layer, output_layer, input_class=PairedQueryLayer, output_class=FeedForwardLayer,
**kwargs)
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius

View File

@ -23,7 +23,7 @@ class NeuralNetworkLayer(ABC, Group):
class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
def __init__(self, **kwargs):
super(NeuralNetworkLayer, self).__init__(**kwargs)
super().__init__(**kwargs)
@abstractmethod
def make_forward_pass_animation(self):

View File

@ -15,6 +15,6 @@ def get_connective_layer(input_layer, output_layer):
if connective_layer is None:
raise Exception(f"Unrecognized class pair {input_layer.__class__.__name__}" + \
" and {output_layer.__class__.__name__}")
f" and {output_layer.__class__.__name__}")
return connective_layer

View File

@ -8,13 +8,12 @@ class VectorLayer(VGroupNeuralNetworkLayer):
def __init__(self, num_values, value_func=lambda: random.uniform(0, 1),
**kwargs):
print("vector layer")
super().__init__(**kwargs)
print("after init")
self.num_values = num_values
self.value_func = value_func
# Make the vector
self.vector_label = self.make_vector()
self.add(self.vector_label)
def make_vector(self):
"""Makes the vector"""
@ -24,7 +23,8 @@ class VectorLayer(VGroupNeuralNetworkLayer):
values = values[None, :].T
vector = Matrix(values)
vector_label = Text(f"[{self.value_func()}]")
vector_label = Text(f"[{self.value_func():.2}]")
vector_label.scale(0.5)
return vector_label
@ -34,4 +34,4 @@ class VectorLayer(VGroupNeuralNetworkLayer):
@override_animation(Create)
def _create_override(self):
"""Create animation"""
return Create(self.vector_label)
return Write(self.vector_label)

View File

@ -10,7 +10,6 @@ Example:
NeuralNetwork(layer_node_count)
"""
from socket import create_connection
from urllib.parse import non_hierarchical
from manim import *
import warnings
import textwrap
@ -20,11 +19,176 @@ from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.util import get_connective_layer
from manim_ml.list_group import ListGroup
class LazyAnimation(Animation):
"""
Lazily creates animation when the animation is called.
This is helpful when creating the animation depends upon the internal
state of some set of objects.
"""
def __init__(self, animation_func):
self.animation_func = animation_func
super().__init__(None)
def begin(self):
"""Begins animation"""
self.mobject, animation = self.animation_func()
animation = Create(self.mobject)
animation.begin()
class RemoveLayer(Succession):
"""
Animation for removing a layer from a neural network.
Note: I needed to do something strange for creating the new connective layer.
The issue with creating it intially is that the positions of the sides of the
connective layer depend upon the location of the moved layers **after** the
move animations are performed. However, all of these animations are performed
after the animations have been created. This means that the animation depends upon
the state of the neural network layers after previous animations have been run.
To fix this issue I needed to use an UpdateFromFunc.
"""
def __init__(self, layer, neural_network, layer_spacing=0.2):
self.layer = layer
self.neural_network = neural_network
self.layer_spacing = layer_spacing
# Get the before and after layers
layers_tuple = self.get_connective_layers()
self.before_layer = layers_tuple[0]
self.after_layer = layers_tuple[1]
self.before_connective = layers_tuple[2]
self.after_connective = layers_tuple[3]
# Make the animations
remove_animations = self.make_remove_animation()
move_animations = self.make_move_animation()
new_connective_animation = self.make_new_connective_animation()
# Add all of the animations to the group
animations_list = [
remove_animations,
move_animations,
new_connective_animation
]
super().__init__(*animations_list, lag_ratio=1.0)
def get_connective_layers(self):
"""Gets the connective layers before and after self.layer"""
# Get layer index
layer_index = self.neural_network.all_layers.index_of(self.layer)
if layer_index == -1:
raise Exception("Layer object not found")
# Get the layers before and after
before_layer = None
after_layer = None
before_connective = None
after_connective = None
if layer_index - 2 >= 0:
before_layer = self.neural_network.all_layers[layer_index - 2]
before_connective = self.neural_network.all_layers[layer_index - 1]
if layer_index + 2 < len(self.neural_network.all_layers):
after_layer = self.neural_network.all_layers[layer_index + 2]
after_connective = self.neural_network.all_layers[layer_index + 1]
return before_layer, after_layer, before_connective, after_connective
def make_remove_animation(self):
"""Removes layer and the surrounding connective layers"""
remove_layer_animation = self.make_remove_layer_animation()
remove_connective_animation = self.make_remove_connective_layers_animation()
# Remove animations
remove_animations = AnimationGroup(
remove_layer_animation,
remove_connective_animation
)
return remove_animations
def make_remove_layer_animation(self):
"""Removes the layer"""
# Remove the layer
self.neural_network.all_layers.remove(self.layer)
# Fade out the removed layer
fade_out_removed = FadeOut(self.layer)
return fade_out_removed
def make_remove_connective_layers_animation(self):
"""Removes the connective layers before and after layer if they exist"""
# Fade out the removed connective layers
fade_out_before_connective = AnimationGroup()
if not self.before_connective is None:
self.neural_network.all_layers.remove(self.before_connective)
fade_out_before_connective = FadeOut(self.before_connective)
fade_out_after_connective = AnimationGroup()
if not self.after_connective is None:
self.neural_network.all_layers.remove(self.after_connective)
fade_out_after_connective = FadeOut(self.after_connective)
# Group items
remove_connective_group = AnimationGroup(
fade_out_after_connective,
fade_out_before_connective
)
return remove_connective_group
def make_move_animation(self):
"""Collapses layers"""
# Animate the movements
move_before_layers = AnimationGroup()
shift_right_amount = None
if not self.before_layer is None:
# Compute shift amount
layer_dist = np.abs(self.layer.get_center() - self.before_layer.get_right())[0]
shift_right_amount = np.array([layer_dist - self.layer_spacing/2, 0, 0])
# Shift all layers before forward
before_layer_index = self.neural_network.all_layers.index_of(self.before_layer)
layers_before = Group(*self.neural_network.all_layers[:before_layer_index + 1])
move_before_layers = layers_before.animate.shift(shift_right_amount)
move_after_layers = AnimationGroup()
shift_left_amount = None
if not self.after_layer is None:
layer_dist = np.abs(self.after_layer.get_left() - self.layer.get_center())[0]
shift_left_amount = np.array([-layer_dist + self.layer_spacing / 2, 0, 0])
# Shift all layers after backward
after_layer_index = self.neural_network.all_layers.index_of(self.after_layer)
layers_after = Group(*self.neural_network.all_layers[after_layer_index:])
move_after_layers = layers_after.animate.shift(shift_left_amount)
# Group the move animations
move_group = AnimationGroup(
move_before_layers,
move_after_layers
)
return move_group
def make_new_connective_animation(self):
"""Makes new connective layer"""
self.anim_count = 0
def create_new_connective(neural_network):
"""
Creates new connective layer
This is a closure that creates a new connective layer and animates it.
"""
self.anim_count += 1
if self.anim_count == 1:
if not self.before_layer is None and not self.after_layer is None:
print(neural_network)
new_connective = get_connective_layer(self.before_layer, self.after_layer)
before_layer_index = neural_network.all_layers.index_of(self.before_layer) + 1
neural_network.all_layers.insert(before_layer_index, new_connective)
print(neural_network)
update_func_anim = UpdateFromFunc(self.neural_network, create_new_connective)
return update_func_anim
class NeuralNetwork(Group):
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.2,
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03,
title="Overhead Title"):
title=" "):
super(Group, self).__init__()
self.input_layers = ListGroup(*input_layers)
self.edge_width = edge_width
@ -32,7 +196,7 @@ class NeuralNetwork(Group):
self.layer_spacing = layer_spacing
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
self.title = title
self.title_text = title
self.created = False
# TODO take layer_node_count [0, (1, 2), 0]
# and make it have explicit distinct subspaces
@ -42,9 +206,9 @@ class NeuralNetwork(Group):
self.layer_titles = self._make_layer_titles()
self.add(self.layer_titles)
# Make overhead title
self.overhead_title = Text(self.title, font_size=DEFAULT_FONT_SIZE/2)
self.overhead_title.next_to(self, UP, 0.2)
self.add(self.overhead_title)
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE/2)
self.title.next_to(self, UP, 1.0)
self.add(self.title)
# Place layers at correct z index
self.connective_layers.set_z_index(2)
self.input_layers.set_z_index(3)
@ -62,7 +226,7 @@ class NeuralNetwork(Group):
previous_layer = self.input_layers[layer_index - 1]
current_layer = self.input_layers[layer_index]
current_layer.move_to(previous_layer)
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + self.layer_spacing, 0, 0])
current_layer.shift(shift_vector)
def _make_layer_titles(self):
@ -137,132 +301,8 @@ class NeuralNetwork(Group):
def remove_layer(self, layer):
"""Removes layer object if it exists"""
# Get layer index
layer_index = self.all_layers.index_of(layer)
if layer_index == -1:
raise Exception("Layer object not found")
# Get the layers before and after
before_layer = None
after_layer = None
if layer_index - 2 >= 0:
before_layer = self.all_layers[layer_index - 2]
if layer_index + 2 < len(self.all_layers):
after_layer = self.all_layers[layer_index + 2]
# Remove the layer
self.all_layers.remove(layer)
# Remove surrounding connective layers from self.all_layers
before_connective = None
after_connective = None
if layer_index - 1 >= 0:
# There is a layer before
before_connective = self.all_layers.remove_at_index(layer_index - 1)
if layer_index + 1 < len(self.all_layers):
# There is a layer after
after_connective = self.all_layers.remove_at_index(layer_index + 1)
# Make animations
# Fade out the removed layer
fade_out_removed = FadeOut(layer)
# Fade out the removed connective layers
fade_out_before_connective = Animation()
if not before_connective is None:
fade_out_before_connective = FadeOut(before_connective)
fade_out_after_connective = Animation()
if not after_connective is None:
fade_out_after_connective = FadeOut(after_connective)
# Create new connective layer
new_connective = None
if not before_layer is None and not after_layer is None:
new_connective = get_connective_layer(before_layer, after_layer)
before_layer_index = self.all_layers.index_of(before_layer)
self.all_layers.insert(before_layer_index, new_connective)
# Place the new connective
new_connective.move_to(layer)
# Animate the creation of the new connective layer
create_new_connective = Animation()
if not new_connective is None:
create_new_connective = Create(new_connective)
# Collapse the neural network to fill the empty space
removed_width = layer.width + before_connective.width + after_connective.width - new_connective.width
shift_right_amount = np.array([removed_width / 2, 0, 0])
shift_left_amount = np.array([-removed_width / 2, 0, 0])
move_before_layer = Animation()
if not before_layer is None:
move_before_layer = before_layer.animate.shift(shift_right_amount)
move_after_layer = Animation()
if not after_layer is None:
move_after_layer = after_layer.animate.shift(shift_left_amount)
# Make the final AnimationGroup
fade_out_group = AnimationGroup(
fade_out_removed,
fade_out_before_connective,
fade_out_after_connective
)
move_group = AnimationGroup(
move_before_layer,
move_after_layer
)
animation_group = AnimationGroup(
fade_out_group,
move_group,
create_new_connective,
lag_ratio=1.0
)
return animation_group
"""
remove_layer = list(self.all_layers)[remove_index]
if remove_index > 0:
connective_before = list(self.all_layers)[remove_index - 1]
else:
connective_before = None
if remove_index < len(list(self.all_layers)) - 1:
connective_after = list(self.all_layers)[remove_index + 1]
else:
connective_after = None
# Collapse the surrounding layer
layers_before = list(self.all_layers)[:remove_index]
layers_after = list(self.all_layers)[remove_index+1:]
before_group = Group(*layers_before)
after_group = Group(*layers_after)
before_shift_amount = np.array([remove_layer.width/2, 0, 0])
after_shift_amount = np.array([-remove_layer.width/2, 0, 0])
# Remove the layers from the neural network representation
self.all_layers.remove(remove_layer)
if not connective_before is None:
self.all_layers.remove(connective_before)
if not connective_after is None:
self.all_layers.remove(connective_after)
# Connect the layers before and layers after
pre_index = remove_index - 1
pre_layer = None
if pre_index >= 0:
pre_layer = list(self.all_layers)[pre_index]
post_index = remove_index
post_layer = None
if post_index < len(list(self.all_layers)):
post_layer = list(self.all_layers)[post_index]
if not pre_layer is None and not post_layer is None:
connective_layer = get_connective_layer(pre_layer, post_layer)
self.all_layers = Group(
*self.all_layers[:remove_index],
connective_layer,
*self.all_layers[remove_index:]
)
# Make animations
fade_out_animation = FadeOut(remove_layer)
shift_animations = AnimationGroup(
before_group.animate.shift(before_shift_amount),
after_group.animate.shift(after_shift_amount)
)
animation_group = AnimationGroup(
fade_out_animation,
shift_animations,
lag_ratio=1.0
)
return animation_group
"""
neural_network = self
return RemoveLayer(layer, neural_network, layer_spacing=self.layer_spacing)
def replace_layer(self, old_layer, new_layer):
"""Replaces given layer object"""
@ -304,7 +344,7 @@ class NeuralNetwork(Group):
animations = []
# Create the overhead title
animations.append(Write(self.overhead_title))
animations.append(Create(self.title))
# Create each layer one by one
for layer in self.all_layers:
layer_animation = Create(layer)
@ -318,11 +358,15 @@ class NeuralNetwork(Group):
animations.append(animation_group)
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
print(animation_group)
return animation_group
def __repr__(self, metadata=["z_index", "title"]):
def set_z_index(self, z_index_value: float, family=False):
"""Overriden set_z_index"""
# Setting family=False stops sub-neural networks from inheriting parent z_index
return super().set_z_index(z_index_value, family=False)
def __repr__(self, metadata=["z_index", "title_text"]):
"""Print string representation of layers"""
inner_string = ""
for layer in self.all_layers: