Added support for Neural Network overhead title and per-layer title.

This commit is contained in:
Alec Helbling
2022-04-19 00:36:15 -04:00
parent 05f512f185
commit 229c27fa3f
12 changed files with 77 additions and 16 deletions

View File

@ -65,7 +65,7 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
return animation_group
@override_animation(Create)
def _create_embedding_layer(self, **kwargs):
def _create_override(self, **kwargs):
# Plot each point at once
point_animations = []
for point in self.point_cloud:

View File

@ -44,6 +44,6 @@ class EmbeddingToFeedForward(ConnectiveLayer):
return animation_group
@override_animation(Create)
def _create_embedding_layer(self, **kwargs):
def _create_override(self, **kwargs):
return AnimationGroup()

View File

@ -55,7 +55,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
return succession
@override_animation(Create)
def _create_animation(self, **kwargs):
def _create_override(self, **kwargs):
animations = []
animations.append(Create(self.surrounding_rectangle))

View File

@ -48,6 +48,6 @@ class FeedForwardToEmbedding(ConnectiveLayer):
return animation_group
@override_animation(Create)
def _create_embedding_layer(self, **kwargs):
def _create_override(self, **kwargs):
return AnimationGroup()

View File

@ -57,7 +57,7 @@ class FeedForwardToFeedForward(ConnectiveLayer):
return path_animations
@override_animation(Create)
def _create_animation(self, **kwargs):
def _create_override(self, **kwargs):
animations = []
for edge in self.edges:

View File

@ -17,7 +17,7 @@ class ImageLayer(NeuralNetworkLayer):
self.add(self.image_mobject)
@override_animation(Create)
def _create_animation(self, **kwargs):
def _create_override(self, **kwargs):
return FadeIn(self.image_mobject)
def make_forward_pass_animation(self):

View File

@ -56,7 +56,7 @@ class PairedQueryLayer(NeuralNetworkLayer):
return assets
@override_animation(Create)
def _create_layer(self):
def _create_override(self):
# TODO make Create animation that is custom
return FadeIn(self.assets)

View File

@ -6,11 +6,17 @@ class NeuralNetworkLayer(ABC, Group):
def __init__(self, text=None, **kwargs):
super(Group, self).__init__()
self.title_text = kwargs["title"] if "title" in kwargs else " "
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE/3)
@abstractmethod
def make_forward_pass_animation(self):
pass
@override_animation(Create)
def _create_override(self):
pass
def __repr__(self):
return f"{type(self).__name__}"
@ -23,6 +29,10 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
def make_forward_pass_animation(self):
pass
@override_animation(Create)
def _create_override(self):
return super()._create_override()
class ConnectiveLayer(VGroupNeuralNetworkLayer):
"""Forward pass animation for a given pair of layers"""
@ -40,4 +50,8 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
@abstractmethod
def make_forward_pass_animation(self):
pass
pass
@override_animation(Create)
def _create_override(self):
return super()._create_override()

View File

@ -67,7 +67,7 @@ class TripletLayer(NeuralNetworkLayer):
return assets
@override_animation(Create)
def _create_layer(self):
def _create_override(self):
# TODO make Create animation that is custom
return FadeIn(self.assets)

View File

@ -23,7 +23,8 @@ from manim_ml.list_group import ListGroup
class NeuralNetwork(Group):
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03):
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03,
title="Overhead Title"):
super(Group, self).__init__()
self.input_layers = ListGroup(*input_layers)
self.edge_width = edge_width
@ -31,11 +32,19 @@ class NeuralNetwork(Group):
self.layer_spacing = layer_spacing
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
self.title = title
self.created = False
# TODO take layer_node_count [0, (1, 2), 0]
# and make it have explicit distinct subspaces
self._place_layers()
self.connective_layers, self.all_layers = self._construct_connective_layers()
# Make layer titles
self.layer_titles = self._make_layer_titles()
self.add(self.layer_titles)
# Make overhead title
self.overhead_title = Text(self.title, font_size=DEFAULT_FONT_SIZE/2)
self.overhead_title.next_to(self, UP, 0.2)
self.add(self.overhead_title)
# Place layers at correct z index
self.connective_layers.set_z_index(2)
self.input_layers.set_z_index(3)
@ -56,6 +65,15 @@ class NeuralNetwork(Group):
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
current_layer.shift(shift_vector)
def _make_layer_titles(self):
"""Makes titles"""
titles = VGroup()
for layer in self.all_layers:
title = layer.title
title.next_to(layer, UP, 0.2)
titles.add(title)
return titles
def _construct_connective_layers(self):
"""Draws connecting lines between layers"""
connective_layers = ListGroup()
@ -283,21 +301,37 @@ class NeuralNetwork(Group):
if self.created:
return AnimationGroup()
self.created = True
# Create each layer one by one
animations = []
# Create the overhead title
animations.append(Write(self.overhead_title))
# Create each layer one by one
for layer in self.all_layers:
animation = Create(layer)
animations.append(animation)
layer_animation = Create(layer)
# Make titles
create_title = Create(layer.title)
# Create layer animation group
animation_group = AnimationGroup(
layer_animation,
create_title
)
animations.append(animation_group)
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
print(animation_group)
return animation_group
def __repr__(self):
def __repr__(self, metadata=["z_index", "title"]):
"""Print string representation of layers"""
inner_string = ""
for layer in self.all_layers:
inner_string += f"{repr(layer)} {layer.z_index} ,\n"
inner_string += f"{repr(layer)} ("
for key in metadata:
value = getattr(layer, key)
if not value is "":
inner_string += f"{key}={value}, "
inner_string += "),\n"
inner_string = textwrap.indent(inner_string, " ")
string_repr = "NeuralNetwork([\n" + inner_string + "])"

12
tests/test_layers.py Normal file
View File

@ -0,0 +1,12 @@
from manim import *
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.feed_forward_to_feed_forward import FeedForwardToFeedForward
from manim_ml.neural_network.layers.util import get_connective_layer
def test_get_connective_layer():
"""Tests get connective layer"""
input_layer = FeedForwardLayer(3)
output_layer = FeedForwardLayer(5)
connective_layer = get_connective_layer(input_layer, output_layer)
assert isinstance(connective_layer, FeedForwardToFeedForward)

View File

@ -94,7 +94,7 @@ class NeuralNetworkScene(Scene):
def construct(self):
# Make the Layer object
layers = [
FeedForwardLayer(3),
FeedForwardLayer(3, title="Title Test"),
FeedForwardLayer(5),
FeedForwardLayer(3)
]
@ -102,6 +102,7 @@ class NeuralNetworkScene(Scene):
nn.move_to(ORIGIN)
# Make Animation
self.add(nn)
#self.play(Create(nn))
forward_propagation_animation = nn.make_forward_pass_animation(run_time=5, passing_flash=True)
self.play(forward_propagation_animation)