mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-21 13:57:37 +08:00
Added support for Neural Network overhead title and per-layer title.
This commit is contained in:
@ -23,7 +23,8 @@ from manim_ml.list_group import ListGroup
|
||||
class NeuralNetwork(Group):
|
||||
|
||||
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
|
||||
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03):
|
||||
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03,
|
||||
title="Overhead Title"):
|
||||
super(Group, self).__init__()
|
||||
self.input_layers = ListGroup(*input_layers)
|
||||
self.edge_width = edge_width
|
||||
@ -31,11 +32,19 @@ class NeuralNetwork(Group):
|
||||
self.layer_spacing = layer_spacing
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
self.title = title
|
||||
self.created = False
|
||||
# TODO take layer_node_count [0, (1, 2), 0]
|
||||
# and make it have explicit distinct subspaces
|
||||
self._place_layers()
|
||||
self.connective_layers, self.all_layers = self._construct_connective_layers()
|
||||
# Make layer titles
|
||||
self.layer_titles = self._make_layer_titles()
|
||||
self.add(self.layer_titles)
|
||||
# Make overhead title
|
||||
self.overhead_title = Text(self.title, font_size=DEFAULT_FONT_SIZE/2)
|
||||
self.overhead_title.next_to(self, UP, 0.2)
|
||||
self.add(self.overhead_title)
|
||||
# Place layers at correct z index
|
||||
self.connective_layers.set_z_index(2)
|
||||
self.input_layers.set_z_index(3)
|
||||
@ -56,6 +65,15 @@ class NeuralNetwork(Group):
|
||||
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
|
||||
current_layer.shift(shift_vector)
|
||||
|
||||
def _make_layer_titles(self):
|
||||
"""Makes titles"""
|
||||
titles = VGroup()
|
||||
for layer in self.all_layers:
|
||||
title = layer.title
|
||||
title.next_to(layer, UP, 0.2)
|
||||
titles.add(title)
|
||||
return titles
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
"""Draws connecting lines between layers"""
|
||||
connective_layers = ListGroup()
|
||||
@ -283,21 +301,37 @@ class NeuralNetwork(Group):
|
||||
if self.created:
|
||||
return AnimationGroup()
|
||||
self.created = True
|
||||
# Create each layer one by one
|
||||
|
||||
animations = []
|
||||
# Create the overhead title
|
||||
animations.append(Write(self.overhead_title))
|
||||
# Create each layer one by one
|
||||
for layer in self.all_layers:
|
||||
animation = Create(layer)
|
||||
animations.append(animation)
|
||||
layer_animation = Create(layer)
|
||||
# Make titles
|
||||
create_title = Create(layer.title)
|
||||
# Create layer animation group
|
||||
animation_group = AnimationGroup(
|
||||
layer_animation,
|
||||
create_title
|
||||
)
|
||||
animations.append(animation_group)
|
||||
|
||||
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
|
||||
print(animation_group)
|
||||
|
||||
return animation_group
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self, metadata=["z_index", "title"]):
|
||||
"""Print string representation of layers"""
|
||||
inner_string = ""
|
||||
for layer in self.all_layers:
|
||||
inner_string += f"{repr(layer)} {layer.z_index} ,\n"
|
||||
inner_string += f"{repr(layer)} ("
|
||||
for key in metadata:
|
||||
value = getattr(layer, key)
|
||||
if not value is "":
|
||||
inner_string += f"{key}={value}, "
|
||||
inner_string += "),\n"
|
||||
inner_string = textwrap.indent(inner_string, " ")
|
||||
|
||||
string_repr = "NeuralNetwork([\n" + inner_string + "])"
|
||||
|
Reference in New Issue
Block a user