mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-07-03 23:02:00 +08:00
Added support for Neural Network overhead title and per-layer title.
This commit is contained in:
@ -65,7 +65,7 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
return animation_group
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_embedding_layer(self, **kwargs):
|
||||
def _create_override(self, **kwargs):
|
||||
# Plot each point at once
|
||||
point_animations = []
|
||||
for point in self.point_cloud:
|
||||
|
@ -44,6 +44,6 @@ class EmbeddingToFeedForward(ConnectiveLayer):
|
||||
return animation_group
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_embedding_layer(self, **kwargs):
|
||||
def _create_override(self, **kwargs):
|
||||
return AnimationGroup()
|
||||
|
@ -55,7 +55,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
|
||||
return succession
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_animation(self, **kwargs):
|
||||
def _create_override(self, **kwargs):
|
||||
animations = []
|
||||
|
||||
animations.append(Create(self.surrounding_rectangle))
|
||||
|
@ -48,6 +48,6 @@ class FeedForwardToEmbedding(ConnectiveLayer):
|
||||
return animation_group
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_embedding_layer(self, **kwargs):
|
||||
def _create_override(self, **kwargs):
|
||||
return AnimationGroup()
|
||||
|
||||
|
@ -57,7 +57,7 @@ class FeedForwardToFeedForward(ConnectiveLayer):
|
||||
return path_animations
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_animation(self, **kwargs):
|
||||
def _create_override(self, **kwargs):
|
||||
animations = []
|
||||
|
||||
for edge in self.edges:
|
||||
|
@ -17,7 +17,7 @@ class ImageLayer(NeuralNetworkLayer):
|
||||
self.add(self.image_mobject)
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_animation(self, **kwargs):
|
||||
def _create_override(self, **kwargs):
|
||||
return FadeIn(self.image_mobject)
|
||||
|
||||
def make_forward_pass_animation(self):
|
||||
|
@ -56,7 +56,7 @@ class PairedQueryLayer(NeuralNetworkLayer):
|
||||
return assets
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_layer(self):
|
||||
def _create_override(self):
|
||||
# TODO make Create animation that is custom
|
||||
return FadeIn(self.assets)
|
||||
|
||||
|
@ -6,11 +6,17 @@ class NeuralNetworkLayer(ABC, Group):
|
||||
|
||||
def __init__(self, text=None, **kwargs):
|
||||
super(Group, self).__init__()
|
||||
self.title_text = kwargs["title"] if "title" in kwargs else " "
|
||||
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE/3)
|
||||
|
||||
@abstractmethod
|
||||
def make_forward_pass_animation(self):
|
||||
pass
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return f"{type(self).__name__}"
|
||||
|
||||
@ -23,6 +29,10 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
|
||||
def make_forward_pass_animation(self):
|
||||
pass
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self):
|
||||
return super()._create_override()
|
||||
|
||||
class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
||||
"""Forward pass animation for a given pair of layers"""
|
||||
|
||||
@ -40,4 +50,8 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
||||
|
||||
@abstractmethod
|
||||
def make_forward_pass_animation(self):
|
||||
pass
|
||||
pass
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self):
|
||||
return super()._create_override()
|
@ -67,7 +67,7 @@ class TripletLayer(NeuralNetworkLayer):
|
||||
return assets
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_layer(self):
|
||||
def _create_override(self):
|
||||
# TODO make Create animation that is custom
|
||||
return FadeIn(self.assets)
|
||||
|
||||
|
@ -23,7 +23,8 @@ from manim_ml.list_group import ListGroup
|
||||
class NeuralNetwork(Group):
|
||||
|
||||
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
|
||||
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03):
|
||||
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03,
|
||||
title="Overhead Title"):
|
||||
super(Group, self).__init__()
|
||||
self.input_layers = ListGroup(*input_layers)
|
||||
self.edge_width = edge_width
|
||||
@ -31,11 +32,19 @@ class NeuralNetwork(Group):
|
||||
self.layer_spacing = layer_spacing
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
self.title = title
|
||||
self.created = False
|
||||
# TODO take layer_node_count [0, (1, 2), 0]
|
||||
# and make it have explicit distinct subspaces
|
||||
self._place_layers()
|
||||
self.connective_layers, self.all_layers = self._construct_connective_layers()
|
||||
# Make layer titles
|
||||
self.layer_titles = self._make_layer_titles()
|
||||
self.add(self.layer_titles)
|
||||
# Make overhead title
|
||||
self.overhead_title = Text(self.title, font_size=DEFAULT_FONT_SIZE/2)
|
||||
self.overhead_title.next_to(self, UP, 0.2)
|
||||
self.add(self.overhead_title)
|
||||
# Place layers at correct z index
|
||||
self.connective_layers.set_z_index(2)
|
||||
self.input_layers.set_z_index(3)
|
||||
@ -56,6 +65,15 @@ class NeuralNetwork(Group):
|
||||
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
|
||||
current_layer.shift(shift_vector)
|
||||
|
||||
def _make_layer_titles(self):
|
||||
"""Makes titles"""
|
||||
titles = VGroup()
|
||||
for layer in self.all_layers:
|
||||
title = layer.title
|
||||
title.next_to(layer, UP, 0.2)
|
||||
titles.add(title)
|
||||
return titles
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
"""Draws connecting lines between layers"""
|
||||
connective_layers = ListGroup()
|
||||
@ -283,21 +301,37 @@ class NeuralNetwork(Group):
|
||||
if self.created:
|
||||
return AnimationGroup()
|
||||
self.created = True
|
||||
# Create each layer one by one
|
||||
|
||||
animations = []
|
||||
# Create the overhead title
|
||||
animations.append(Write(self.overhead_title))
|
||||
# Create each layer one by one
|
||||
for layer in self.all_layers:
|
||||
animation = Create(layer)
|
||||
animations.append(animation)
|
||||
layer_animation = Create(layer)
|
||||
# Make titles
|
||||
create_title = Create(layer.title)
|
||||
# Create layer animation group
|
||||
animation_group = AnimationGroup(
|
||||
layer_animation,
|
||||
create_title
|
||||
)
|
||||
animations.append(animation_group)
|
||||
|
||||
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
|
||||
print(animation_group)
|
||||
|
||||
return animation_group
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self, metadata=["z_index", "title"]):
|
||||
"""Print string representation of layers"""
|
||||
inner_string = ""
|
||||
for layer in self.all_layers:
|
||||
inner_string += f"{repr(layer)} {layer.z_index} ,\n"
|
||||
inner_string += f"{repr(layer)} ("
|
||||
for key in metadata:
|
||||
value = getattr(layer, key)
|
||||
if not value is "":
|
||||
inner_string += f"{key}={value}, "
|
||||
inner_string += "),\n"
|
||||
inner_string = textwrap.indent(inner_string, " ")
|
||||
|
||||
string_repr = "NeuralNetwork([\n" + inner_string + "])"
|
||||
|
12
tests/test_layers.py
Normal file
12
tests/test_layers.py
Normal file
@ -0,0 +1,12 @@
|
||||
from manim import *
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.feed_forward_to_feed_forward import FeedForwardToFeedForward
|
||||
from manim_ml.neural_network.layers.util import get_connective_layer
|
||||
|
||||
def test_get_connective_layer():
|
||||
"""Tests get connective layer"""
|
||||
input_layer = FeedForwardLayer(3)
|
||||
output_layer = FeedForwardLayer(5)
|
||||
connective_layer = get_connective_layer(input_layer, output_layer)
|
||||
|
||||
assert isinstance(connective_layer, FeedForwardToFeedForward)
|
@ -94,7 +94,7 @@ class NeuralNetworkScene(Scene):
|
||||
def construct(self):
|
||||
# Make the Layer object
|
||||
layers = [
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(3, title="Title Test"),
|
||||
FeedForwardLayer(5),
|
||||
FeedForwardLayer(3)
|
||||
]
|
||||
@ -102,6 +102,7 @@ class NeuralNetworkScene(Scene):
|
||||
nn.move_to(ORIGIN)
|
||||
# Make Animation
|
||||
self.add(nn)
|
||||
#self.play(Create(nn))
|
||||
forward_propagation_animation = nn.make_forward_pass_animation(run_time=5, passing_flash=True)
|
||||
|
||||
self.play(forward_propagation_animation)
|
||||
|
Reference in New Issue
Block a user