Added support for Neural Network overhead title and per-layer title.

This commit is contained in:
Alec Helbling
2022-04-19 00:36:15 -04:00
parent 05f512f185
commit 229c27fa3f
12 changed files with 77 additions and 16 deletions

View File

@ -23,7 +23,8 @@ from manim_ml.list_group import ListGroup
class NeuralNetwork(Group):
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03):
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03,
title="Overhead Title"):
super(Group, self).__init__()
self.input_layers = ListGroup(*input_layers)
self.edge_width = edge_width
@ -31,11 +32,19 @@ class NeuralNetwork(Group):
self.layer_spacing = layer_spacing
self.animation_dot_color = animation_dot_color
self.dot_radius = dot_radius
self.title = title
self.created = False
# TODO take layer_node_count [0, (1, 2), 0]
# and make it have explicit distinct subspaces
self._place_layers()
self.connective_layers, self.all_layers = self._construct_connective_layers()
# Make layer titles
self.layer_titles = self._make_layer_titles()
self.add(self.layer_titles)
# Make overhead title
self.overhead_title = Text(self.title, font_size=DEFAULT_FONT_SIZE/2)
self.overhead_title.next_to(self, UP, 0.2)
self.add(self.overhead_title)
# Place layers at correct z index
self.connective_layers.set_z_index(2)
self.input_layers.set_z_index(3)
@ -56,6 +65,15 @@ class NeuralNetwork(Group):
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
current_layer.shift(shift_vector)
def _make_layer_titles(self):
"""Makes titles"""
titles = VGroup()
for layer in self.all_layers:
title = layer.title
title.next_to(layer, UP, 0.2)
titles.add(title)
return titles
def _construct_connective_layers(self):
"""Draws connecting lines between layers"""
connective_layers = ListGroup()
@ -283,21 +301,37 @@ class NeuralNetwork(Group):
if self.created:
return AnimationGroup()
self.created = True
# Create each layer one by one
animations = []
# Create the overhead title
animations.append(Write(self.overhead_title))
# Create each layer one by one
for layer in self.all_layers:
animation = Create(layer)
animations.append(animation)
layer_animation = Create(layer)
# Make titles
create_title = Create(layer.title)
# Create layer animation group
animation_group = AnimationGroup(
layer_animation,
create_title
)
animations.append(animation_group)
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
print(animation_group)
return animation_group
def __repr__(self):
def __repr__(self, metadata=["z_index", "title"]):
"""Print string representation of layers"""
inner_string = ""
for layer in self.all_layers:
inner_string += f"{repr(layer)} {layer.z_index} ,\n"
inner_string += f"{repr(layer)} ("
for key in metadata:
value = getattr(layer, key)
if not value is "":
inner_string += f"{key}={value}, "
inner_string += "),\n"
inner_string = textwrap.indent(inner_string, " ")
string_repr = "NeuralNetwork([\n" + inner_string + "])"