mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-07-01 20:44:56 +08:00
Added ability to pass layer_args dictionary to each forward pass, which allows
for arguments to be passed through to each neural network layer when running a neural network forward pass.
This commit is contained in:
@ -16,7 +16,7 @@ class ConvolutionalLayer(VGroupNeuralNetworkLayer):
|
||||
"""Creates the neural network layer"""
|
||||
pass
|
||||
|
||||
def make_forward_pass_animation(self):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
# make highlight animation
|
||||
return None
|
||||
|
||||
|
@ -6,16 +6,29 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
"""NeuralNetwork embedding object that can show probability distributions"""
|
||||
|
||||
def __init__(self, point_radius=0.02, mean = np.array([0, 0]),
|
||||
covariance=np.array([[1.5, 0], [0, 1.5]]), dist_theme="gaussian", **kwargs):
|
||||
covariance=np.array([[1.0, 0], [0, 1.0]]), dist_theme="gaussian",
|
||||
paired_query_mode=False, **kwargs):
|
||||
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
|
||||
self.point_radius = point_radius
|
||||
self.dist_theme = dist_theme
|
||||
self.paired_query_mode = paired_query_mode
|
||||
self.axes = Axes(
|
||||
tips=False,
|
||||
x_length=0.8,
|
||||
y_length=0.8
|
||||
y_length=0.8,
|
||||
x_range=(-2.0, 2.0),
|
||||
y_range=(-2.0, 2.0),
|
||||
x_axis_config={
|
||||
"include_ticks": False,
|
||||
"stroke_width": 0.0
|
||||
},
|
||||
y_axis_config={
|
||||
"include_ticks": False,
|
||||
"stroke_width": 0.0
|
||||
}
|
||||
)
|
||||
self.add(self.axes)
|
||||
self.axes.move_to(self.get_center())
|
||||
# Make point cloud
|
||||
self.point_cloud = self.construct_gaussian_point_cloud(mean, covariance)
|
||||
self.add(self.point_cloud)
|
||||
@ -51,24 +64,44 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||
|
||||
return point_dots
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
"""Forward pass animation"""
|
||||
# Make ellipse object corresponding to the latent distribution
|
||||
self.latent_distribution = GaussianDistribution(
|
||||
self.axes,
|
||||
dist_theme=self.dist_theme,
|
||||
cov=np.array([[0.8, 0], [0.0, 0.8]])
|
||||
) # Use defaults
|
||||
# Create animation
|
||||
def make_paired_query_embedding_animation(self):
|
||||
"""Embed paired query"""
|
||||
animations = []
|
||||
#create_distribution = Create(self.latent_distribution.construct_gaussian_distribution(self.latent_distribution.mean, self.latent_distribution.cov)) #Create(self.latent_distribution)
|
||||
create_distribution = Create(self.latent_distribution.ellipses)
|
||||
animations.append(create_distribution)
|
||||
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
# Make the animation
|
||||
|
||||
# Animation group
|
||||
animation_group = AnimationGroup(
|
||||
*animations,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Forward pass animation"""
|
||||
animations = []
|
||||
if not self.paired_query_mode:
|
||||
# Normal embedding mode
|
||||
# Make ellipse object corresponding to the latent distribution
|
||||
self.latent_distribution = GaussianDistribution(
|
||||
self.axes,
|
||||
dist_theme=self.dist_theme,
|
||||
cov=np.array([[0.8, 0], [0.0, 0.8]])
|
||||
) # Use defaults
|
||||
# Create animation
|
||||
#create_distribution = Create(self.latent_distribution.construct_gaussian_distribution(self.latent_distribution.mean, self.latent_distribution.cov)) #Create(self.latent_distribution)
|
||||
create_distribution = Create(self.latent_distribution.ellipses)
|
||||
animations.append(create_distribution)
|
||||
|
||||
animation_group = AnimationGroup(*animations)
|
||||
|
||||
return animation_group
|
||||
else:
|
||||
# Paired Query Mode
|
||||
# Handle logic for embedding a paired query into the embedding layer
|
||||
paired_query_embedding_animation = self.make_paired_query_embedding_animation()
|
||||
return paired_query_embedding_animation
|
||||
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
# Plot each point at once
|
||||
|
@ -17,7 +17,7 @@ class EmbeddingToFeedForward(ConnectiveLayer):
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
def make_forward_pass_animation(self, run_time=1.5, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||
"""Makes dots diverge from the given location and move the decoder"""
|
||||
# Find point to converge on by sampling from gaussian distribution
|
||||
location = self.embedding_layer.sample_point_location_from_distribution()
|
||||
|
@ -44,7 +44,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
|
||||
# Add the objects to the class
|
||||
self.add(self.surrounding_rectangle, self.node_group)
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
# make highlight animation
|
||||
succession = Succession(
|
||||
ApplyMethod(self.node_group.set_color, self.animation_dot_color, run_time=0.25),
|
||||
|
@ -17,7 +17,7 @@ class FeedForwardToEmbedding(ConnectiveLayer):
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
|
||||
def make_forward_pass_animation(self, run_time=1.5):
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1.5, **kwargs):
|
||||
"""Makes dots converge on a specific location"""
|
||||
# Find point to converge on by sampling from gaussian distribution
|
||||
location = self.embedding_layer.sample_point_location_from_distribution()
|
||||
|
@ -44,7 +44,7 @@ class FeedForwardToFeedForward(ConnectiveLayer):
|
||||
|
||||
return animation_group
|
||||
|
||||
def make_forward_pass_animation(self, run_time=1, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, run_time=1, **kwargs):
|
||||
"""Animation for passing information from one FeedForwardLayer to the next"""
|
||||
path_animations = []
|
||||
dots = []
|
||||
|
@ -18,7 +18,7 @@ class FeedForwardToImage(ConnectiveLayer):
|
||||
self.feed_forward_layer = input_layer
|
||||
self.image_layer = output_layer
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
|
||||
animations = []
|
||||
image_mobject = self.image_layer.image_mobject
|
||||
|
@ -18,7 +18,7 @@ class FeedForwardToVector(ConnectiveLayer):
|
||||
self.feed_forward_layer = input_layer
|
||||
self.vector_layer = output_layer
|
||||
|
||||
def make_forward_pass_animation(self):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
|
||||
animations = []
|
||||
# Move the dots to the centers of each of the nodes in the FeedForwardLayer
|
||||
|
@ -27,7 +27,7 @@ class ImageLayer(NeuralNetworkLayer):
|
||||
else:
|
||||
return AnimationGroup()
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
return FadeIn(self.image_mobject)
|
||||
|
||||
# def move_to(self, location):
|
||||
|
@ -18,7 +18,7 @@ class ImageToFeedForward(ConnectiveLayer):
|
||||
self.feed_forward_layer = output_layer
|
||||
self.image_layer = input_layer
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
|
||||
animations = []
|
||||
dots = []
|
||||
|
@ -60,6 +60,6 @@ class PairedQueryLayer(NeuralNetworkLayer):
|
||||
# TODO make Create animation that is custom
|
||||
return FadeIn(self.assets)
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Forward pass for query"""
|
||||
return AnimationGroup()
|
@ -17,7 +17,7 @@ class PairedQueryToFeedForward(ConnectiveLayer):
|
||||
self.paired_query_layer = input_layer
|
||||
self.feed_forward_layer = output_layer
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
|
||||
animations = []
|
||||
# Loop through each image
|
||||
|
@ -12,7 +12,7 @@ class NeuralNetworkLayer(ABC, Group):
|
||||
# self.add(self.title)
|
||||
|
||||
@abstractmethod
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
pass
|
||||
|
||||
@override_animation(Create)
|
||||
@ -51,7 +51,7 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
||||
assert isinstance(output_layer, self.output_class)
|
||||
|
||||
@abstractmethod
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
pass
|
||||
|
||||
@override_animation(Create)
|
||||
|
@ -71,6 +71,6 @@ class TripletLayer(NeuralNetworkLayer):
|
||||
# TODO make Create animation that is custom
|
||||
return FadeIn(self.assets)
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Forward pass for triplet"""
|
||||
return AnimationGroup()
|
||||
|
@ -18,7 +18,7 @@ class TripletToFeedForward(ConnectiveLayer):
|
||||
self.feed_forward_layer = output_layer
|
||||
self.triplet_layer = input_layer
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
|
||||
animations = []
|
||||
# Loop through each image
|
||||
|
@ -28,7 +28,7 @@ class VectorLayer(VGroupNeuralNetworkLayer):
|
||||
|
||||
return vector_label
|
||||
|
||||
def make_forward_pass_animation(self, **kwargs):
|
||||
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
||||
return AnimationGroup()
|
||||
|
||||
@override_animation(Create)
|
||||
|
@ -9,254 +9,17 @@ Example:
|
||||
# Create the object with default style settings
|
||||
NeuralNetwork(layer_node_count)
|
||||
"""
|
||||
from cv2 import AGAST_FEATURE_DETECTOR_NONMAX_SUPPRESSION
|
||||
from manim import *
|
||||
import warnings
|
||||
import textwrap
|
||||
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
|
||||
from manim_ml.neural_network.layers.util import get_connective_layer
|
||||
from manim_ml.list_group import ListGroup
|
||||
from manim_ml.neural_network.neural_network_transformations import InsertLayer, RemoveLayer
|
||||
|
||||
class RemoveLayer(AnimationGroup):
|
||||
"""
|
||||
Animation for removing a layer from a neural network.
|
||||
|
||||
Note: I needed to do something strange for creating the new connective layer.
|
||||
The issue with creating it intially is that the positions of the sides of the
|
||||
connective layer depend upon the location of the moved layers **after** the
|
||||
move animations are performed. However, all of these animations are performed
|
||||
after the animations have been created. This means that the animation depends upon
|
||||
the state of the neural network layers after previous animations have been run.
|
||||
To fix this issue I needed to use an UpdateFromFunc.
|
||||
"""
|
||||
|
||||
def __init__(self, layer, neural_network, layer_spacing=0.2):
|
||||
self.layer = layer
|
||||
self.neural_network = neural_network
|
||||
self.layer_spacing = layer_spacing
|
||||
# Get the before and after layers
|
||||
layers_tuple = self.get_connective_layers()
|
||||
self.before_layer = layers_tuple[0]
|
||||
self.after_layer = layers_tuple[1]
|
||||
self.before_connective = layers_tuple[2]
|
||||
self.after_connective = layers_tuple[3]
|
||||
# Make the animations
|
||||
remove_animations = self.make_remove_animation()
|
||||
move_animations = self.make_move_animation()
|
||||
new_connective_animation = self.make_new_connective_animation()
|
||||
# Add all of the animations to the group
|
||||
animations_list = [
|
||||
remove_animations,
|
||||
move_animations,
|
||||
new_connective_animation
|
||||
]
|
||||
|
||||
super().__init__(*animations_list, lag_ratio=1.0)
|
||||
|
||||
def get_connective_layers(self):
|
||||
"""Gets the connective layers before and after self.layer"""
|
||||
# Get layer index
|
||||
layer_index = self.neural_network.all_layers.index_of(self.layer)
|
||||
if layer_index == -1:
|
||||
raise Exception("Layer object not found")
|
||||
# Get the layers before and after
|
||||
before_layer = None
|
||||
after_layer = None
|
||||
before_connective = None
|
||||
after_connective = None
|
||||
if layer_index - 2 >= 0:
|
||||
before_layer = self.neural_network.all_layers[layer_index - 2]
|
||||
before_connective = self.neural_network.all_layers[layer_index - 1]
|
||||
if layer_index + 2 < len(self.neural_network.all_layers):
|
||||
after_layer = self.neural_network.all_layers[layer_index + 2]
|
||||
after_connective = self.neural_network.all_layers[layer_index + 1]
|
||||
|
||||
return before_layer, after_layer, before_connective, after_connective
|
||||
|
||||
def make_remove_animation(self):
|
||||
"""Removes layer and the surrounding connective layers"""
|
||||
remove_layer_animation = self.make_remove_layer_animation()
|
||||
remove_connective_animation = self.make_remove_connective_layers_animation()
|
||||
# Remove animations
|
||||
remove_animations = AnimationGroup(
|
||||
remove_layer_animation,
|
||||
remove_connective_animation
|
||||
)
|
||||
|
||||
return remove_animations
|
||||
|
||||
def make_remove_layer_animation(self):
|
||||
"""Removes the layer"""
|
||||
# Remove the layer
|
||||
self.neural_network.all_layers.remove(self.layer)
|
||||
# Fade out the removed layer
|
||||
fade_out_removed = FadeOut(self.layer)
|
||||
return fade_out_removed
|
||||
|
||||
def make_remove_connective_layers_animation(self):
|
||||
"""Removes the connective layers before and after layer if they exist"""
|
||||
# Fade out the removed connective layers
|
||||
fade_out_before_connective = AnimationGroup()
|
||||
if not self.before_connective is None:
|
||||
self.neural_network.all_layers.remove(self.before_connective)
|
||||
fade_out_before_connective = FadeOut(self.before_connective)
|
||||
fade_out_after_connective = AnimationGroup()
|
||||
if not self.after_connective is None:
|
||||
self.neural_network.all_layers.remove(self.after_connective)
|
||||
fade_out_after_connective = FadeOut(self.after_connective)
|
||||
# Group items
|
||||
remove_connective_group = AnimationGroup(
|
||||
fade_out_after_connective,
|
||||
fade_out_before_connective
|
||||
)
|
||||
|
||||
return remove_connective_group
|
||||
|
||||
def make_move_animation(self):
|
||||
"""Collapses layers"""
|
||||
# Animate the movements
|
||||
move_before_layers = AnimationGroup()
|
||||
shift_right_amount = None
|
||||
if not self.before_layer is None:
|
||||
# Compute shift amount
|
||||
layer_dist = np.abs(self.layer.get_center() - self.before_layer.get_right())[0]
|
||||
shift_right_amount = np.array([layer_dist - self.layer_spacing/2, 0, 0])
|
||||
# Shift all layers before forward
|
||||
before_layer_index = self.neural_network.all_layers.index_of(self.before_layer)
|
||||
layers_before = Group(*self.neural_network.all_layers[:before_layer_index + 1])
|
||||
move_before_layers = layers_before.animate.shift(shift_right_amount)
|
||||
move_after_layers = AnimationGroup()
|
||||
shift_left_amount = None
|
||||
if not self.after_layer is None:
|
||||
layer_dist = np.abs(self.after_layer.get_left() - self.layer.get_center())[0]
|
||||
shift_left_amount = np.array([-layer_dist + self.layer_spacing / 2, 0, 0])
|
||||
# Shift all layers after backward
|
||||
after_layer_index = self.neural_network.all_layers.index_of(self.after_layer)
|
||||
layers_after = Group(*self.neural_network.all_layers[after_layer_index:])
|
||||
move_after_layers = layers_after.animate.shift(shift_left_amount)
|
||||
# Group the move animations
|
||||
move_group = AnimationGroup(
|
||||
move_before_layers,
|
||||
move_after_layers
|
||||
)
|
||||
|
||||
return move_group
|
||||
|
||||
def make_new_connective_animation(self):
|
||||
"""Makes new connective layer"""
|
||||
self.anim_count = 0
|
||||
def create_new_connective(neural_network):
|
||||
"""
|
||||
Creates new connective layer
|
||||
|
||||
This is a closure that creates a new connective layer and animates it.
|
||||
"""
|
||||
self.anim_count += 1
|
||||
if self.anim_count == 1:
|
||||
if not self.before_layer is None and not self.after_layer is None:
|
||||
print(neural_network)
|
||||
new_connective = get_connective_layer(self.before_layer, self.after_layer)
|
||||
before_layer_index = neural_network.all_layers.index_of(self.before_layer) + 1
|
||||
neural_network.all_layers.insert(before_layer_index, new_connective)
|
||||
print(neural_network)
|
||||
|
||||
update_func_anim = UpdateFromFunc(self.neural_network, create_new_connective)
|
||||
|
||||
return update_func_anim
|
||||
|
||||
class InsertLayer(AnimationGroup):
|
||||
"""Animation for inserting layer at given index"""
|
||||
|
||||
def __init__(self, layer, index, neural_network):
|
||||
self.layer = layer
|
||||
self.index = index
|
||||
self.neural_network = neural_network
|
||||
# Layers before and after
|
||||
self.layers_before = self.neural_network.all_layers[:self.index]
|
||||
self.layers_after = self.neural_network.all_layers[self.index:]
|
||||
|
||||
remove_connective_layer = self.remove_connective_layer()
|
||||
move_layers = self.make_move_layers()
|
||||
# create_layer = self.make_create_layer()
|
||||
# create_connective_layers = self.make_create_connective_layers()
|
||||
animations = [
|
||||
remove_connective_layer,
|
||||
move_layers,
|
||||
# create_layer,
|
||||
# create_connective_layers
|
||||
]
|
||||
|
||||
super().__init__(*animations, lag_ratio=1.0)
|
||||
|
||||
def remove_connective_layer(self):
|
||||
"""Removes the connective layer before the insertion index"""
|
||||
# Check if connective layer exists
|
||||
if len(self.layers_before) > 0:
|
||||
removed_connective = self.layers_before[-1]
|
||||
self.neural_network.all_layers.remove(removed_connective)
|
||||
# Make remove animation
|
||||
remove_animation = FadeOut(removed_connective)
|
||||
return remove_animation
|
||||
|
||||
return AnimationGroup()
|
||||
|
||||
def make_move_layers(self):
|
||||
"""Shifts layers before and after"""
|
||||
# Before layer shift
|
||||
before_shift_animation = AnimationGroup()
|
||||
if len(self.layers_before) > 0:
|
||||
before_shift = np.array([-self.layer.width/2, 0, 0])
|
||||
# Shift layers before
|
||||
before_shift_animation = Group(*self.layers_before).animate.shift(before_shift)
|
||||
# After layer shift
|
||||
after_shift_animation = AnimationGroup()
|
||||
if len(self.layers_after) > 0:
|
||||
after_shift = np.array([self.layer.width/2, 0, 0])
|
||||
# Shift layers after
|
||||
after_shift_animation = Group(*self.layers_after).animate.shift(after_shift)
|
||||
# Make animation group
|
||||
shift_animations = AnimationGroup(
|
||||
before_shift_animation,
|
||||
after_shift_animation
|
||||
)
|
||||
|
||||
return shift_animations
|
||||
|
||||
def make_create_layer(self):
|
||||
"""Animates the creation of the layer"""
|
||||
pass
|
||||
|
||||
def make_create_connective_layers(self):
|
||||
pass
|
||||
|
||||
|
||||
# Make connective layers and shift animations
|
||||
# Before layer
|
||||
if len(layers_before) > 0:
|
||||
before_connective = get_connective_layer(layers_before[-1], layer)
|
||||
before_shift = np.array([-layer.width/2, 0, 0])
|
||||
# Shift layers before
|
||||
before_shift_animation = Group(*layers_before).animate.shift(before_shift)
|
||||
else:
|
||||
before_connective = AnimationGroup()
|
||||
# After layer
|
||||
if len(layers_after) > 0:
|
||||
after_connective = get_connective_layer(layer, layers_after[0])
|
||||
after_shift = np.array([layer.width/2, 0, 0])
|
||||
# Shift layers after
|
||||
after_shift_animation = Group(*layers_after).animate.shift(after_shift)
|
||||
else:
|
||||
after_connective = AnimationGroup
|
||||
|
||||
insert_animation = Create(layer)
|
||||
animation_group = AnimationGroup(
|
||||
shift_animations,
|
||||
insert_animation,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
class NeuralNetwork(Group):
|
||||
|
||||
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.2,
|
||||
@ -347,19 +110,32 @@ class NeuralNetwork(Group):
|
||||
|
||||
return animation_group
|
||||
|
||||
def make_forward_pass_animation(self, run_time=10, passing_flash=True,
|
||||
def make_forward_pass_animation(self, run_time=10, passing_flash=True, layer_args={},
|
||||
**kwargs):
|
||||
"""Generates an animation for feed forward propagation"""
|
||||
all_animations = []
|
||||
for layer_index, layer in enumerate(self.input_layers[:-1]):
|
||||
layer_forward_pass = layer.make_forward_pass_animation(**kwargs)
|
||||
for layer_index, layer in enumerate(self.all_layers):
|
||||
# Get the layer args
|
||||
if isinstance(layer, ConnectiveLayer):
|
||||
"""
|
||||
NOTE: By default a connective layer will get the combined
|
||||
layer_args of the layers it is connecting.
|
||||
"""
|
||||
before_layer_args = {}
|
||||
after_layer_args = {}
|
||||
if layer.input_layer in layer_args:
|
||||
before_layer_args = layer_args[layer.input_layer]
|
||||
if layer.output_layer in layer_args:
|
||||
after_layer_args = layer_args[layer.output_layer]
|
||||
# Merge the two dicts
|
||||
current_layer_args = {**before_layer_args, **after_layer_args}
|
||||
else:
|
||||
current_layer_args = {}
|
||||
if layer in layer_args:
|
||||
current_layer_args = layer_args[layer]
|
||||
# Perform the forward pass of the current layer
|
||||
layer_forward_pass = layer.make_forward_pass_animation(layer_args=current_layer_args, **kwargs)
|
||||
all_animations.append(layer_forward_pass)
|
||||
connective_layer = self.connective_layers[layer_index]
|
||||
connective_forward_pass = connective_layer.make_forward_pass_animation(**kwargs)
|
||||
all_animations.append(connective_forward_pass)
|
||||
# Do last layer animation
|
||||
last_layer_forward_pass = self.input_layers[-1].make_forward_pass_animation(**kwargs)
|
||||
all_animations.append(last_layer_forward_pass)
|
||||
# Make the animation group
|
||||
animation_group = AnimationGroup(*all_animations, run_time=run_time, lag_ratio=1.0)
|
||||
|
||||
|
54
tests/test_embedding_layer.py
Normal file
54
tests/test_embedding_layer.py
Normal file
@ -0,0 +1,54 @@
|
||||
from manim import *
|
||||
|
||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||
|
||||
config.pixel_height = 720
|
||||
config.pixel_width = 1280
|
||||
config.frame_height = 5.0
|
||||
config.frame_width = 5.0
|
||||
|
||||
class EmbeddingNNScene(Scene):
|
||||
|
||||
def construct(self):
|
||||
embedding_layer = EmbeddingLayer()
|
||||
|
||||
neural_network = NeuralNetwork([
|
||||
FeedForwardLayer(5),
|
||||
FeedForwardLayer(3),
|
||||
embedding_layer,
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(5)
|
||||
])
|
||||
|
||||
self.play(Create(neural_network))
|
||||
|
||||
self.play(neural_network.make_forward_pass_animation(run_time=5))
|
||||
|
||||
class QueryEmbeddingNNScene(Scene):
|
||||
|
||||
def construct(self):
|
||||
embedding_layer = EmbeddingLayer()
|
||||
embedding_layer.paired_query_mode = True
|
||||
|
||||
neural_network = NeuralNetwork([
|
||||
FeedForwardLayer(5),
|
||||
FeedForwardLayer(3),
|
||||
embedding_layer,
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(5)
|
||||
])
|
||||
|
||||
self.play(Create(neural_network), run_time=2)
|
||||
|
||||
self.play(
|
||||
neural_network.make_forward_pass_animation(
|
||||
run_time=5,
|
||||
layer_args={
|
||||
embedding_layer: {
|
||||
"query_locations": (np.array([2, 2]), np.array([1, 1]))
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
Reference in New Issue
Block a user