Mirror of https://github.com/helblazer811/ManimML.git
Added ListGroup class for better management of a group of objects.
manim_ml/list_group.py (new file, +73 lines)
@@ -0,0 +1,73 @@
+from manim import *
+
+class ListGroup(Mobject):
+    """Indexable Group with traditional list operations"""
+
+    def __init__(self, *layers):
+        super().__init__()
+        self.items = [*layers]
+
+    def __getitem__(self, indices):
+        """Traditional list indexing"""
+        return self.items[indices]
+
+    def insert(self, index, item):
+        """Inserts item at index"""
+        self.items.insert(index, item)
+        self.submobjects = self.items
+
+    def remove_at_index(self, index):
+        """Removes item at index"""
+        if index < 0 or index >= len(self.items):
+            raise Exception(f"ListGroup index out of range: {index}")
+        item = self.items[index]
+        del self.items[index]
+        self.submobjects = self.items
+
+        return item
+
+    def remove_at_indices(self, indices):
+        """Removes items at indices"""
+        items = []
+        for index in indices:
+            item = self.remove_at_index(index)
+            items.append(item)
+
+        return items
+
+    def remove(self, item):
+        """Removes first instance of item"""
+        self.items.remove(item)
+        self.submobjects = self.items
+
+        return item
+
+    def get(self, index):
+        """Gets item at index"""
+        return self.items[index]
+
+    def add(self, item):
+        """Adds to end"""
+        self.items.append(item)
+        self.submobjects = self.items
+
+    def replace(self, index, item):
+        """Replaces item at index"""
+        self.items[index] = item
+        self.submobjects = self.items
+
+    def index_of(self, item):
+        """Returns index of item if it exists"""
+        for index, obj in enumerate(self.items):
+            if item is obj:
+                return index
+        return -1
+
+    def __len__(self):
+        """Length of items"""
+        return len(self.items)
+
+    def set_z_index(self, z_index_value):
+        """Sets z index of all values in ListGroup"""
+        for item in self.items:
+            item.set_z_index(z_index_value)
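For orientation, a minimal usage sketch of the new ListGroup (not part of the commit; the Text mobjects are illustrative placeholders and assume a working Manim install):

    from manim import Text
    from manim_ml.list_group import ListGroup

    group = ListGroup(Text("a"), Text("b"), Text("c"))
    group.add(Text("d"))            # append to the end
    first = group.get(0)            # same as group[0]
    group.insert(1, Text("x"))      # insert at position 1
    group.remove(first)             # drop the first occurrence of an object
    print(len(group))               # 4
    print(group.index_of(first))    # -1, since it was removed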
@@ -5,8 +5,8 @@ from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
 class EmbeddingLayer(VGroupNeuralNetworkLayer):
     """NeuralNetwork embedding object that can show probability distributions"""
 
-    def __init__(self, point_radius=0.02):
-        super(EmbeddingLayer, self).__init__()
+    def __init__(self, point_radius=0.02, **kwargs):
+        super(EmbeddingLayer, self).__init__(**kwargs)
         self.point_radius = point_radius
         self.axes = Axes(
             tips=False,
@@ -8,8 +8,10 @@ class EmbeddingToFeedForward(ConnectiveLayer):
     input_class = EmbeddingLayer
     output_class = FeedForwardLayer
 
-    def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03):
-        super().__init__(input_layer, output_layer, input_class=EmbeddingLayer, output_class=FeedForwardLayer)
+    def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03,
+                 **kwargs):
+        super().__init__(input_layer, output_layer, input_class=EmbeddingLayer, output_class=FeedForwardLayer,
+                         **kwargs)
         self.feed_forward_layer = output_layer
         self.embedding_layer = input_layer
         self.animation_dot_color = animation_dot_color
@@ -7,8 +7,8 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
     def __init__(self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08,
                  node_color=BLUE, node_outline_color=WHITE, rectangle_color=WHITE,
                  node_spacing=0.3, rectangle_fill_color=BLACK, node_stroke_width=2.0,
-                 rectangle_stroke_width=2.0, animation_dot_color=RED):
-        super(VGroupNeuralNetworkLayer, self).__init__()
+                 rectangle_stroke_width=2.0, animation_dot_color=RED, **kwargs):
+        super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
         self.num_nodes = num_nodes
         self.layer_buffer = layer_buffer
         self.node_radius = node_radius
@@ -8,8 +8,10 @@ class FeedForwardToEmbedding(ConnectiveLayer):
     input_class = FeedForwardLayer
     output_class = EmbeddingLayer
 
-    def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03):
-        super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=EmbeddingLayer)
+    def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03,
+                 **kwargs):
+        super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=EmbeddingLayer,
+                         **kwargs)
         self.feed_forward_layer = input_layer
         self.embedding_layer = output_layer
         self.animation_dot_color = animation_dot_color
@@ -9,8 +9,9 @@ class FeedForwardToFeedForward(ConnectiveLayer):
 
     def __init__(self, input_layer, output_layer, passing_flash=True,
                  dot_radius=0.05, animation_dot_color=RED, edge_color=WHITE,
-                 edge_width=0.5):
-        super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=FeedForwardLayer)
+                 edge_width=1.5, **kwargs):
+        super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=FeedForwardLayer,
+                         **kwargs)
         self.passing_flash = passing_flash
         self.edge_color = edge_color
         self.dot_radius = dot_radius
@@ -9,8 +9,9 @@ class FeedForwardToImage(ConnectiveLayer):
     output_class = ImageLayer
 
     def __init__(self, input_layer, output_layer, animation_dot_color=RED,
-                 dot_radius=0.05):
-        super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=ImageLayer)
+                 dot_radius=0.05, **kwargs):
+        super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=ImageLayer,
+                         **kwargs)
         self.animation_dot_color = animation_dot_color
         self.dot_radius = dot_radius
 
@@ -5,9 +5,8 @@ from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
 class ImageLayer(NeuralNetworkLayer):
     """Single Image Layer for Neural Network"""
 
-    def __init__(self, numpy_image, height=1.5):
-        super().__init__()
-        self.set_z_index(1)
+    def __init__(self, numpy_image, height=1.5, **kwargs):
+        super().__init__(**kwargs)
         self.numpy_image = numpy_image
         if len(np.shape(self.numpy_image)) == 2:
             # Assumed Grayscale
@@ -9,8 +9,9 @@ class ImageToFeedForward(ConnectiveLayer):
     output_class = FeedForwardLayer
 
    def __init__(self, input_layer, output_layer, animation_dot_color=RED,
-                 dot_radius=0.05):
-        super().__init__(input_layer, output_layer, input_class=ImageLayer, output_class=FeedForwardLayer)
+                 dot_radius=0.05, **kwargs):
+        super().__init__(input_layer, output_layer, input_class=ImageLayer, output_class=FeedForwardLayer,
+                         **kwargs)
         self.animation_dot_color = animation_dot_color
         self.dot_radius = dot_radius
 
@@ -6,8 +6,8 @@ import numpy as np
 class PairedQueryLayer(NeuralNetworkLayer):
     """Paired Query Layer"""
 
-    def __init__(self, positive, negative, stroke_width=5):
-        super().__init__()
+    def __init__(self, positive, negative, stroke_width=5, **kwargs):
+        super().__init__(**kwargs)
         self.positive = positive
         self.negative = negative
 
@@ -8,8 +8,9 @@ class PairedQueryToFeedForward(ConnectiveLayer):
     input_class = PairedQueryLayer
     output_class = FeedForwardLayer
 
-    def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.02):
-        super().__init__(input_layer, output_layer, input_class=PairedQueryLayer, output_class=FeedForwardLayer)
+    def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.02, **kwargs):
+        super().__init__(input_layer, output_layer, input_class=PairedQueryLayer, output_class=FeedForwardLayer,
+                         **kwargs)
         self.animation_dot_color = animation_dot_color
         self.dot_radius = dot_radius
 
@@ -4,9 +4,8 @@ from abc import ABC, abstractmethod
 class NeuralNetworkLayer(ABC, Group):
     """Abstract Neural Network Layer class"""
 
-    def __init__(self, **kwargs):
+    def __init__(self, text=None, **kwargs):
         super(Group, self).__init__()
-        self.set_z_index(1)
 
     @abstractmethod
     def make_forward_pass_animation(self):
@@ -28,8 +27,9 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
         """Forward pass animation for a given pair of layers"""
 
     @abstractmethod
-    def __init__(self, input_layer, output_layer, input_class=None, output_class=None):
-        super(VGroupNeuralNetworkLayer, self).__init__()
+    def __init__(self, input_layer, output_layer, input_class=None, output_class=None,
+                 **kwargs):
+        super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
         self.input_layer = input_layer
         self.output_layer = output_layer
         self.input_class = input_class
@@ -38,8 +38,6 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
         assert isinstance(input_layer, self.input_class)
         assert isinstance(output_layer, self.output_class)
 
-        self.set_z_index(-1)
-
     @abstractmethod
     def make_forward_pass_animation(self):
         pass
@@ -6,8 +6,9 @@ import numpy as np
 class TripletLayer(NeuralNetworkLayer):
     """Shows triplet images"""
 
-    def __init__(self, anchor, positive, negative, stroke_width=5):
-        super().__init__()
+    def __init__(self, anchor, positive, negative, stroke_width=5,
+                 **kwargs):
+        super().__init__(**kwargs)
         self.anchor = anchor
         self.positive = positive
         self.negative = negative
@@ -9,8 +9,9 @@ class TripletToFeedForward(ConnectiveLayer):
     output_class = FeedForwardLayer
 
     def __init__(self, input_layer, output_layer, animation_dot_color=RED,
-                 dot_radius=0.02):
-        super().__init__(input_layer, output_layer, input_class=TripletLayer, output_class=FeedForwardLayer)
+                 dot_radius=0.02, **kwargs):
+        super().__init__(input_layer, output_layer, input_class=TripletLayer, output_class=FeedForwardLayer,
+                         **kwargs)
         self.animation_dot_color = animation_dot_color
         self.dot_radius = dot_radius
 
manim_ml/neural_network/layers/util.py (new file, +20 lines)
@@ -0,0 +1,20 @@
+from manim import *
+from ..layers import connective_layers_list
+
+def get_connective_layer(input_layer, output_layer):
+    """
+    Deduces the relevant connective layer
+    """
+    connective_layer = None
+    for connective_layer_class in connective_layers_list:
+        input_class = connective_layer_class.input_class
+        output_class = connective_layer_class.output_class
+        if isinstance(input_layer, input_class) \
+            and isinstance(output_layer, output_class):
+            connective_layer = connective_layer_class(input_layer, output_layer)
+
+    if connective_layer is None:
+        raise Exception(f"Unrecognized class pair {input_layer.__class__.__name__}" + \
+            f" and {output_layer.__class__.__name__}")
+
+    return connective_layer
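For context, a hypothetical usage sketch of the new helper (layer classes and import paths as used elsewhere in this commit; the expected result assumes FeedForwardToFeedForward is registered in connective_layers_list):

    from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
    from manim_ml.neural_network.layers.util import get_connective_layer

    layer_a = FeedForwardLayer(4)
    layer_b = FeedForwardLayer(6)
    connective = get_connective_layer(layer_a, layer_b)
    print(type(connective).__name__)  # expected: FeedForwardToFeedForward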
@@ -9,24 +9,23 @@ Example:
     # Create the object with default style settings
     NeuralNetwork(layer_node_count)
 """
-from socket import create_connection
-from urllib.parse import non_hierarchical
 from manim import *
 import warnings
 import textwrap
 
-from manim_ml.neural_network.layers import \
-    FeedForwardLayer, FeedForwardToFeedForward, ImageLayer, \
-    ImageToFeedForward, FeedForwardToImage, EmbeddingLayer, \
-    EmbeddingToFeedForward, FeedForwardToEmbedding, TripletLayer, \
-    TripletToFeedForward
 
-from manim_ml.neural_network.layers import connective_layers_list
+from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
+from manim_ml.neural_network.layers.util import get_connective_layer
+from manim_ml.list_group import ListGroup
 
 class NeuralNetwork(Group):
 
     def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
-                 animation_dot_color=RED, edge_width=1.5, dot_radius=0.03):
+                 animation_dot_color=RED, edge_width=2.5, dot_radius=0.03):
         super(Group, self).__init__()
-        self.input_layers = Group(*input_layers)
+        self.input_layers = ListGroup(*input_layers)
         self.edge_width = edge_width
         self.edge_color = edge_color
         self.layer_spacing = layer_spacing
@@ -37,6 +36,9 @@ class NeuralNetwork(Group):
         # and make it have explicit distinct subspaces
         self._place_layers()
         self.connective_layers, self.all_layers = self._construct_connective_layers()
+        # Place layers at correct z index
+        self.connective_layers.set_z_index(2)
+        self.input_layers.set_z_index(3)
         # Center the whole diagram by default
         self.all_layers.move_to(ORIGIN)
         self.add(self.all_layers)
@@ -50,17 +52,14 @@ class NeuralNetwork(Group):
         for layer_index in range(1, len(self.input_layers)):
             previous_layer = self.input_layers[layer_index - 1]
             current_layer = self.input_layers[layer_index]
 
             current_layer.move_to(previous_layer)
             shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
             current_layer.shift(shift_vector)
-        # Handle layering
-        self.input_layers.set_z_index(2)
-
     def _construct_connective_layers(self):
         """Draws connecting lines between layers"""
-        connective_layers = Group()
-        all_layers = Group()
+        connective_layers = ListGroup()
+        all_layers = ListGroup()
         for layer_index in range(len(self.input_layers) - 1):
             current_layer = self.input_layers[layer_index]
             all_layers.add(current_layer)
@@ -72,28 +71,196 @@ class NeuralNetwork(Group):
             if isinstance(next_layer, NeuralNetwork):
                 # First layer of the next layer
                 next_layer = next_layer.all_layers[0]
 
-            # Find connective layer with correct layer pair
-            connective_layer = None
-            for connective_layer_class in connective_layers_list:
-                input_class = connective_layer_class.input_class
-                output_class = connective_layer_class.output_class
-                if isinstance(current_layer, input_class) \
-                    and isinstance(next_layer, output_class):
-                    connective_layer = connective_layer_class(current_layer, next_layer)
-
-            connective_layers.add(connective_layer)
-            all_layers.add(connective_layer)
-
-            if connective_layer is None:
-                raise Exception(f"Unrecognized class pair {current_layer.__class__.__name__} and {next_layer.__class__.__name__}")
+            connective_layer = get_connective_layer(current_layer, next_layer)
+            connective_layers.add(connective_layer)
+            all_layers.add(connective_layer)
         # Add final layer
         all_layers.add(self.input_layers[-1])
         # Handle layering
         return connective_layers, all_layers
 
+    def insert_layer(self, layer, insert_index):
+        """Inserts a layer at the given index"""
+        layers_before = self.all_layers[:insert_index]
+        layers_after = self.all_layers[insert_index:]
+        # Make connective layers and shift animations
+        # Before layer
+        if len(layers_before) > 0:
+            before_connective = get_connective_layer(layers_before[-1], layer)
+            before_shift = np.array([-layer.width/2, 0, 0])
+            # Shift layers before
+            before_shift_animation = Group(*layers_before).animate.shift(before_shift)
+        else:
+            before_connective = AnimationGroup()
+            before_shift_animation = AnimationGroup()
+        # After layer
+        if len(layers_after) > 0:
+            after_connective = get_connective_layer(layer, layers_after[0])
+            after_shift = np.array([layer.width/2, 0, 0])
+            # Shift layers after
+            after_shift_animation = Group(*layers_after).animate.shift(after_shift)
+        else:
+            after_connective = AnimationGroup()
+            after_shift_animation = AnimationGroup()
+
+        # Make animation group
+        shift_animations = AnimationGroup(
+            before_shift_animation,
+            after_shift_animation
+        )
+
+        insert_animation = Create(layer)
+        animation_group = AnimationGroup(
+            shift_animations,
+            insert_animation,
+            lag_ratio=1.0
+        )
+
+        return animation_group
+
+    def remove_layer(self, layer):
+        """Removes layer object if it exists"""
+        # Get layer index
+        layer_index = self.all_layers.index_of(layer)
+        if layer_index == -1:
+            raise Exception("Layer object not found")
+        # Get the layers before and after
+        before_layer = None
+        after_layer = None
+        if layer_index - 2 >= 0:
+            before_layer = self.all_layers[layer_index - 2]
+        if layer_index + 2 < len(self.all_layers):
+            after_layer = self.all_layers[layer_index + 2]
+        # Remove the layer
+        self.all_layers.remove(layer)
+        # Remove surrounding connective layers from self.all_layers
+        before_connective = None
+        after_connective = None
+        if layer_index - 1 >= 0:
+            # There is a layer before
+            before_connective = self.all_layers.remove_at_index(layer_index - 1)
+        if layer_index + 1 < len(self.all_layers):
+            # There is a layer after
+            after_connective = self.all_layers.remove_at_index(layer_index + 1)
+        # Make animations
+        # Fade out the removed layer
+        fade_out_removed = FadeOut(layer)
+        # Fade out the removed connective layers
+        fade_out_before_connective = Animation()
+        if not before_connective is None:
+            fade_out_before_connective = FadeOut(before_connective)
+        fade_out_after_connective = Animation()
+        if not after_connective is None:
+            fade_out_after_connective = FadeOut(after_connective)
+        # Create new connective layer
+        new_connective = None
+        if not before_layer is None and not after_layer is None:
+            new_connective = get_connective_layer(before_layer, after_layer)
+            before_layer_index = self.all_layers.index_of(before_layer)
+            self.all_layers.insert(before_layer_index, new_connective)
+            # Place the new connective
+            new_connective.move_to(layer)
+        # Animate the creation of the new connective layer
+        create_new_connective = Animation()
+        if not new_connective is None:
+            create_new_connective = Create(new_connective)
+        # Collapse the neural network to fill the empty space
+        removed_width = layer.width + before_connective.width + after_connective.width - new_connective.width
+        shift_right_amount = np.array([removed_width / 2, 0, 0])
+        shift_left_amount = np.array([-removed_width / 2, 0, 0])
+        move_before_layer = Animation()
+        if not before_layer is None:
+            move_before_layer = before_layer.animate.shift(shift_right_amount)
+        move_after_layer = Animation()
+        if not after_layer is None:
+            move_after_layer = after_layer.animate.shift(shift_left_amount)
+        # Make the final AnimationGroup
+        fade_out_group = AnimationGroup(
+            fade_out_removed,
+            fade_out_before_connective,
+            fade_out_after_connective
+        )
+        move_group = AnimationGroup(
+            move_before_layer,
+            move_after_layer
+        )
+        animation_group = AnimationGroup(
+            fade_out_group,
+            move_group,
+            create_new_connective,
+            lag_ratio=1.0
+        )
+
+        return animation_group
+
+    """
+    remove_layer = list(self.all_layers)[remove_index]
+    if remove_index > 0:
+        connective_before = list(self.all_layers)[remove_index - 1]
+    else:
+        connective_before = None
+    if remove_index < len(list(self.all_layers)) - 1:
+        connective_after = list(self.all_layers)[remove_index + 1]
+    else:
+        connective_after = None
+    # Collapse the surrounding layer
+    layers_before = list(self.all_layers)[:remove_index]
+    layers_after = list(self.all_layers)[remove_index+1:]
+    before_group = Group(*layers_before)
+    after_group = Group(*layers_after)
+    before_shift_amount = np.array([remove_layer.width/2, 0, 0])
+    after_shift_amount = np.array([-remove_layer.width/2, 0, 0])
+    # Remove the layers from the neural network representation
+    self.all_layers.remove(remove_layer)
+    if not connective_before is None:
+        self.all_layers.remove(connective_before)
+    if not connective_after is None:
+        self.all_layers.remove(connective_after)
+    # Connect the layers before and layers after
+    pre_index = remove_index - 1
+    pre_layer = None
+    if pre_index >= 0:
+        pre_layer = list(self.all_layers)[pre_index]
+    post_index = remove_index
+    post_layer = None
+    if post_index < len(list(self.all_layers)):
+        post_layer = list(self.all_layers)[post_index]
+    if not pre_layer is None and not post_layer is None:
+        connective_layer = get_connective_layer(pre_layer, post_layer)
+        self.all_layers = Group(
+            *self.all_layers[:remove_index],
+            connective_layer,
+            *self.all_layers[remove_index:]
+        )
+    # Make animations
+    fade_out_animation = FadeOut(remove_layer)
+    shift_animations = AnimationGroup(
+        before_group.animate.shift(before_shift_amount),
+        after_group.animate.shift(after_shift_amount)
+    )
+    animation_group = AnimationGroup(
+        fade_out_animation,
+        shift_animations,
+        lag_ratio=1.0
+    )
+
+    return animation_group
+    """
+
+    def replace_layer(self, old_layer, new_layer):
+        """Replaces given layer object"""
+        insert_index = self.all_layers.index_of(old_layer)
+        remove_animation = self.remove_layer(old_layer)
+        insert_animation = self.insert_layer(new_layer, insert_index)
+        # Make the animation
+        animation_group = AnimationGroup(
+            FadeOut(old_layer),
+            FadeIn(new_layer),
+            lag_ratio=1.0
+        )
+
+        return animation_group
+
     def make_forward_pass_animation(self, run_time=10, passing_flash=True):
-        """Generates an animation for feed forward propogation"""
+        """Generates an animation for feed forward propagation"""
         all_animations = []
         for layer_index, layer in enumerate(self.input_layers[:-1]):
             layer_forward_pass = layer.make_forward_pass_animation()
@@ -126,19 +293,11 @@ class NeuralNetwork(Group):
 
         return animation_group
 
-    def remove_layer(self, layer_index):
-        """Removes layer at given index and returns animation for removing the layer"""
-        raise NotImplementedError()
-
-    def add_layer(self, layer):
-        """Adds layer and returns animation for adding action"""
-        raise NotImplementedError()
-
     def __repr__(self):
         """Print string representation of layers"""
         inner_string = ""
         for layer in self.all_layers:
-            inner_string += f"{repr(layer)},\n"
+            inner_string += f"{repr(layer)} {layer.z_index} ,\n"
         inner_string = textwrap.indent(inner_string, " ")
 
         string_repr = "NeuralNetwork([\n" + inner_string + "])"
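As a quick orientation to the NeuralNetwork editing API introduced above, a hypothetical scene sketch (not part of the commit; the scene name, layer sizes, and run_time are arbitrary):

    from manim import Scene, Create
    from manim_ml.neural_network.neural_network import NeuralNetwork
    from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer

    class EditNetworkScene(Scene):
        def construct(self):
            middle_layer = FeedForwardLayer(4)
            nn = NeuralNetwork([
                FeedForwardLayer(3),
                middle_layer,
                FeedForwardLayer(5)
            ])
            self.play(Create(nn))
            # remove_layer returns an AnimationGroup that fades the layer and its
            # surrounding connective layers, then closes the remaining gap
            self.play(nn.remove_layer(middle_layer))
            # animate a forward pass over the remaining layers
            self.play(nn.make_forward_pass_animation(run_time=5))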
@@ -16,7 +16,6 @@ class GaussianDistribution(VGroup):
         self.cov = np.array([[3, 0], [0, 3]])
         # Make the Gaussian
         self.ellipses = self.construct_gaussian_distribution(self.mean, self.cov)
-        self.ellipses.set_z_index(2)
 
     @override_animation(Create)
     def _create_gaussian_distribution(self):
@@ -1,6 +1,10 @@
 from cv2 import exp
 from manim import *
+from manim_ml.neural_network.layers.embedding import EmbeddingLayer
+from manim_ml.neural_network.layers.embedding_to_feed_forward import EmbeddingToFeedForward
 from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
+from manim_ml.neural_network.layers.feed_forward_to_embedding import FeedForwardToEmbedding
+from manim_ml.neural_network.layers.feed_forward_to_feed_forward import FeedForwardToFeedForward
 from manim_ml.neural_network.layers.image import ImageLayer
 from manim_ml.neural_network.neural_network import NeuralNetwork, FeedForwardNeuralNetwork
 from PIL import Image
@@ -11,6 +15,72 @@ config.pixel_width = 1280
 config.frame_height = 6.0
 config.frame_width = 6.0
 
+"""
+    Unit Tests
+"""
+
+def assert_classes_match(all_layers, expected_classes):
+    assert len(list(all_layers)) == len(expected_classes)
+
+    for index, layer in enumerate(all_layers):
+        expected_class = expected_classes[index]
+        assert isinstance(layer, expected_class), f"Wrong layer class {layer.__class__} expected {expected_class}"
+
+def test_embedding_layer():
+    embedding_layer = EmbeddingLayer()
+
+    neural_network = NeuralNetwork([
+        FeedForwardLayer(5),
+        FeedForwardLayer(3),
+        embedding_layer
+    ])
+
+    expected_classes = [
+        FeedForwardLayer,
+        FeedForwardToFeedForward,
+        FeedForwardLayer,
+        FeedForwardToEmbedding,
+        EmbeddingLayer
+    ]
+
+    assert_classes_match(neural_network.all_layers, expected_classes)
+
+
+def test_remove_layer():
+    embedding_layer = EmbeddingLayer()
+
+    neural_network = NeuralNetwork([
+        FeedForwardLayer(5),
+        FeedForwardLayer(3),
+        embedding_layer
+    ])
+
+    expected_classes = [
+        FeedForwardLayer,
+        FeedForwardToFeedForward,
+        FeedForwardLayer,
+        FeedForwardToEmbedding,
+        EmbeddingLayer
+    ]
+
+    assert_classes_match(neural_network.all_layers, expected_classes)
+
+    print("before removal")
+    print(list(neural_network.all_layers))
+    neural_network.remove_layer(embedding_layer)
+    print("after removal")
+    print(list(neural_network.all_layers))
+
+    expected_classes = [
+        FeedForwardLayer,
+        FeedForwardToFeedForward,
+        FeedForwardLayer,
+    ]
+
+    print(list(neural_network.all_layers))
+
+    assert_classes_match(neural_network.all_layers, expected_classes)
+
 class FeedForwardNeuralNetworkScene(Scene):
 
     def construct(self):
@@ -92,6 +162,31 @@ class RecursiveNNScene(Scene):
 
         self.play(Create(nn))
 
+class LayerInsertionScene(Scene):
+
+    def construct(self):
+        pass
+
+class LayerRemovalScene(Scene):
+
+    def construct(self):
+        image = Image.open('images/image.jpeg')
+        numpy_image = np.asarray(image)
+
+        layer = FeedForwardLayer(5)
+        layers = [
+            ImageLayer(numpy_image, height=1.4),
+            FeedForwardLayer(3),
+            layer,
+            FeedForwardLayer(3),
+            FeedForwardLayer(6)
+        ]
+
+        nn = NeuralNetwork(layers)
+
+        self.play(Create(nn))
+        self.play(nn.remove_layer(layer))
+
 if __name__ == "__main__":
     """Render all scenes"""
     # Feed Forward Neural Network