mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-07-15 07:57:41 +08:00
Added ListGroup class for better management of a group of objects.
This commit is contained in:
73
manim_ml/list_group.py
Normal file
73
manim_ml/list_group.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from manim import *
|
||||||
|
|
||||||
|
class ListGroup(Mobject):
|
||||||
|
"""Indexable Group with traditional list operations"""
|
||||||
|
|
||||||
|
def __init__(self, *layers):
|
||||||
|
super().__init__()
|
||||||
|
self.items = [*layers]
|
||||||
|
|
||||||
|
def __getitem__(self, indices):
|
||||||
|
"""Traditional list indexing"""
|
||||||
|
return self.items[indices]
|
||||||
|
|
||||||
|
def insert(self, index, item):
|
||||||
|
"""Inserts item at index"""
|
||||||
|
self.items.insert(index, item)
|
||||||
|
self.submobjects = self.items
|
||||||
|
|
||||||
|
def remove_at_index(self, index):
|
||||||
|
"""Removes item at index"""
|
||||||
|
if index < 0 or index > len(self.items):
|
||||||
|
raise Exception(f"ListGroup index out of range: {index}")
|
||||||
|
item = self.items[index]
|
||||||
|
del self.items[index]
|
||||||
|
self.submobjects = self.items
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
def remove_at_indices(self, indices):
|
||||||
|
"""Removes items at indices"""
|
||||||
|
items = []
|
||||||
|
for index in indices:
|
||||||
|
item = self.remove_at_index(index)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
def remove(self, item):
|
||||||
|
"""Removes first instance of item"""
|
||||||
|
self.items.remove(item)
|
||||||
|
self.submobjects = self.items
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
def get(self, index):
|
||||||
|
"""Gets item at index"""
|
||||||
|
return self.items[index]
|
||||||
|
|
||||||
|
def add(self, item):
|
||||||
|
"""Adds to end"""
|
||||||
|
self.items.append(item)
|
||||||
|
self.submobjects = self.items
|
||||||
|
|
||||||
|
def replace(self, index, item):
|
||||||
|
"""Replaces item at index"""
|
||||||
|
self.items[index] = item
|
||||||
|
self.submobjects = self.items
|
||||||
|
|
||||||
|
def index_of(self, item):
|
||||||
|
"""Returns index of item if it exists"""
|
||||||
|
for index, obj in enumerate(self.items):
|
||||||
|
if item is obj:
|
||||||
|
return index
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""Length of items"""
|
||||||
|
return len(self.items)
|
||||||
|
|
||||||
|
def set_z_index(self, z_index_value):
|
||||||
|
"""Sets z index of all values in ListGroup"""
|
||||||
|
for item in self.items:
|
||||||
|
item.set_z_index(z_index_value)
|
@ -5,8 +5,8 @@ from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLaye
|
|||||||
class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
class EmbeddingLayer(VGroupNeuralNetworkLayer):
|
||||||
"""NeuralNetwork embedding object that can show probability distributions"""
|
"""NeuralNetwork embedding object that can show probability distributions"""
|
||||||
|
|
||||||
def __init__(self, point_radius=0.02):
|
def __init__(self, point_radius=0.02, **kwargs):
|
||||||
super(EmbeddingLayer, self).__init__()
|
super(EmbeddingLayer, self).__init__(**kwargs)
|
||||||
self.point_radius = point_radius
|
self.point_radius = point_radius
|
||||||
self.axes = Axes(
|
self.axes = Axes(
|
||||||
tips=False,
|
tips=False,
|
||||||
|
@ -8,8 +8,10 @@ class EmbeddingToFeedForward(ConnectiveLayer):
|
|||||||
input_class = EmbeddingLayer
|
input_class = EmbeddingLayer
|
||||||
output_class = FeedForwardLayer
|
output_class = FeedForwardLayer
|
||||||
|
|
||||||
def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03):
|
def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03,
|
||||||
super().__init__(input_layer, output_layer, input_class=EmbeddingLayer, output_class=FeedForwardLayer)
|
**kwargs):
|
||||||
|
super().__init__(input_layer, output_layer, input_class=EmbeddingLayer, output_class=FeedForwardLayer,
|
||||||
|
**kwargs)
|
||||||
self.feed_forward_layer = output_layer
|
self.feed_forward_layer = output_layer
|
||||||
self.embedding_layer = input_layer
|
self.embedding_layer = input_layer
|
||||||
self.animation_dot_color = animation_dot_color
|
self.animation_dot_color = animation_dot_color
|
||||||
|
@ -7,8 +7,8 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
|
|||||||
def __init__(self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08,
|
def __init__(self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08,
|
||||||
node_color=BLUE, node_outline_color=WHITE, rectangle_color=WHITE,
|
node_color=BLUE, node_outline_color=WHITE, rectangle_color=WHITE,
|
||||||
node_spacing=0.3, rectangle_fill_color=BLACK, node_stroke_width=2.0,
|
node_spacing=0.3, rectangle_fill_color=BLACK, node_stroke_width=2.0,
|
||||||
rectangle_stroke_width=2.0, animation_dot_color=RED):
|
rectangle_stroke_width=2.0, animation_dot_color=RED, **kwargs):
|
||||||
super(VGroupNeuralNetworkLayer, self).__init__()
|
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
|
||||||
self.num_nodes = num_nodes
|
self.num_nodes = num_nodes
|
||||||
self.layer_buffer = layer_buffer
|
self.layer_buffer = layer_buffer
|
||||||
self.node_radius = node_radius
|
self.node_radius = node_radius
|
||||||
|
@ -8,8 +8,10 @@ class FeedForwardToEmbedding(ConnectiveLayer):
|
|||||||
input_class = FeedForwardLayer
|
input_class = FeedForwardLayer
|
||||||
output_class = EmbeddingLayer
|
output_class = EmbeddingLayer
|
||||||
|
|
||||||
def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03):
|
def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.03,
|
||||||
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=EmbeddingLayer)
|
**kwargs):
|
||||||
|
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=EmbeddingLayer,
|
||||||
|
**kwargs)
|
||||||
self.feed_forward_layer = input_layer
|
self.feed_forward_layer = input_layer
|
||||||
self.embedding_layer = output_layer
|
self.embedding_layer = output_layer
|
||||||
self.animation_dot_color = animation_dot_color
|
self.animation_dot_color = animation_dot_color
|
||||||
|
@ -9,8 +9,9 @@ class FeedForwardToFeedForward(ConnectiveLayer):
|
|||||||
|
|
||||||
def __init__(self, input_layer, output_layer, passing_flash=True,
|
def __init__(self, input_layer, output_layer, passing_flash=True,
|
||||||
dot_radius=0.05, animation_dot_color=RED, edge_color=WHITE,
|
dot_radius=0.05, animation_dot_color=RED, edge_color=WHITE,
|
||||||
edge_width=0.5):
|
edge_width=1.5, **kwargs):
|
||||||
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=FeedForwardLayer)
|
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=FeedForwardLayer,
|
||||||
|
**kwargs)
|
||||||
self.passing_flash = passing_flash
|
self.passing_flash = passing_flash
|
||||||
self.edge_color = edge_color
|
self.edge_color = edge_color
|
||||||
self.dot_radius = dot_radius
|
self.dot_radius = dot_radius
|
||||||
|
@ -9,8 +9,9 @@ class FeedForwardToImage(ConnectiveLayer):
|
|||||||
output_class = ImageLayer
|
output_class = ImageLayer
|
||||||
|
|
||||||
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
|
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
|
||||||
dot_radius=0.05):
|
dot_radius=0.05, **kwargs):
|
||||||
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=ImageLayer)
|
super().__init__(input_layer, output_layer, input_class=FeedForwardLayer, output_class=ImageLayer
|
||||||
|
**kwargs)
|
||||||
self.animation_dot_color = animation_dot_color
|
self.animation_dot_color = animation_dot_color
|
||||||
self.dot_radius = dot_radius
|
self.dot_radius = dot_radius
|
||||||
|
|
||||||
|
@ -5,9 +5,8 @@ from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
|
|||||||
class ImageLayer(NeuralNetworkLayer):
|
class ImageLayer(NeuralNetworkLayer):
|
||||||
"""Single Image Layer for Neural Network"""
|
"""Single Image Layer for Neural Network"""
|
||||||
|
|
||||||
def __init__(self, numpy_image, height=1.5):
|
def __init__(self, numpy_image, height=1.5, **kwargs):
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
self.set_z_index(1)
|
|
||||||
self.numpy_image = numpy_image
|
self.numpy_image = numpy_image
|
||||||
if len(np.shape(self.numpy_image)) == 2:
|
if len(np.shape(self.numpy_image)) == 2:
|
||||||
# Assumed Grayscale
|
# Assumed Grayscale
|
||||||
|
@ -9,8 +9,9 @@ class ImageToFeedForward(ConnectiveLayer):
|
|||||||
output_class = FeedForwardLayer
|
output_class = FeedForwardLayer
|
||||||
|
|
||||||
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
|
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
|
||||||
dot_radius=0.05):
|
dot_radius=0.05, **kwargs):
|
||||||
super().__init__(input_layer, output_layer, input_class=ImageLayer, output_class=FeedForwardLayer)
|
super().__init__(input_layer, output_layer, input_class=ImageLayer, output_class=FeedForwardLayer
|
||||||
|
**kwargs)
|
||||||
self.animation_dot_color = animation_dot_color
|
self.animation_dot_color = animation_dot_color
|
||||||
self.dot_radius = dot_radius
|
self.dot_radius = dot_radius
|
||||||
|
|
||||||
|
@ -6,8 +6,8 @@ import numpy as np
|
|||||||
class PairedQueryLayer(NeuralNetworkLayer):
|
class PairedQueryLayer(NeuralNetworkLayer):
|
||||||
"""Paired Query Layer"""
|
"""Paired Query Layer"""
|
||||||
|
|
||||||
def __init__(self, positive, negative, stroke_width=5):
|
def __init__(self, positive, negative, stroke_width=5, **kwargs):
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
self.positive = positive
|
self.positive = positive
|
||||||
self.negative = negative
|
self.negative = negative
|
||||||
|
|
||||||
|
@ -8,8 +8,9 @@ class PairedQueryToFeedForward(ConnectiveLayer):
|
|||||||
input_class = PairedQueryLayer
|
input_class = PairedQueryLayer
|
||||||
output_class = FeedForwardLayer
|
output_class = FeedForwardLayer
|
||||||
|
|
||||||
def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.02):
|
def __init__(self, input_layer, output_layer, animation_dot_color=RED, dot_radius=0.02, **kwargs):
|
||||||
super().__init__(input_layer, output_layer, input_class=PairedQueryLayer, output_class=FeedForwardLayer)
|
super().__init__(input_layer, output_layer, input_class=PairedQueryLayer, output_class=FeedForwardLayer
|
||||||
|
**kwargs)
|
||||||
self.animation_dot_color = animation_dot_color
|
self.animation_dot_color = animation_dot_color
|
||||||
self.dot_radius = dot_radius
|
self.dot_radius = dot_radius
|
||||||
|
|
||||||
|
@ -4,9 +4,8 @@ from abc import ABC, abstractmethod
|
|||||||
class NeuralNetworkLayer(ABC, Group):
|
class NeuralNetworkLayer(ABC, Group):
|
||||||
"""Abstract Neural Network Layer class"""
|
"""Abstract Neural Network Layer class"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, text=None, **kwargs):
|
||||||
super(Group, self).__init__()
|
super(Group, self).__init__()
|
||||||
self.set_z_index(1)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def make_forward_pass_animation(self):
|
def make_forward_pass_animation(self):
|
||||||
@ -28,8 +27,9 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
|||||||
"""Forward pass animation for a given pair of layers"""
|
"""Forward pass animation for a given pair of layers"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, input_layer, output_layer, input_class=None, output_class=None):
|
def __init__(self, input_layer, output_layer, input_class=None, output_class=None,
|
||||||
super(VGroupNeuralNetworkLayer, self).__init__()
|
**kwargs):
|
||||||
|
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
|
||||||
self.input_layer = input_layer
|
self.input_layer = input_layer
|
||||||
self.output_layer = output_layer
|
self.output_layer = output_layer
|
||||||
self.input_class = input_class
|
self.input_class = input_class
|
||||||
@ -38,8 +38,6 @@ class ConnectiveLayer(VGroupNeuralNetworkLayer):
|
|||||||
assert isinstance(input_layer, self.input_class)
|
assert isinstance(input_layer, self.input_class)
|
||||||
assert isinstance(output_layer, self.output_class)
|
assert isinstance(output_layer, self.output_class)
|
||||||
|
|
||||||
self.set_z_index(-1)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def make_forward_pass_animation(self):
|
def make_forward_pass_animation(self):
|
||||||
pass
|
pass
|
@ -6,8 +6,9 @@ import numpy as np
|
|||||||
class TripletLayer(NeuralNetworkLayer):
|
class TripletLayer(NeuralNetworkLayer):
|
||||||
"""Shows triplet images"""
|
"""Shows triplet images"""
|
||||||
|
|
||||||
def __init__(self, anchor, positive, negative, stroke_width=5):
|
def __init__(self, anchor, positive, negative, stroke_width=5,
|
||||||
super().__init__()
|
**kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
self.anchor = anchor
|
self.anchor = anchor
|
||||||
self.positive = positive
|
self.positive = positive
|
||||||
self.negative = negative
|
self.negative = negative
|
||||||
|
@ -9,8 +9,9 @@ class TripletToFeedForward(ConnectiveLayer):
|
|||||||
output_class = FeedForwardLayer
|
output_class = FeedForwardLayer
|
||||||
|
|
||||||
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
|
def __init__(self, input_layer, output_layer, animation_dot_color=RED,
|
||||||
dot_radius=0.02):
|
dot_radius=0.02, **kwargs):
|
||||||
super().__init__(input_layer, output_layer, input_class=TripletLayer, output_class=FeedForwardLayer)
|
super().__init__(input_layer, output_layer, input_class=TripletLayer, output_class=FeedForwardLayer,
|
||||||
|
**kwargs)
|
||||||
self.animation_dot_color = animation_dot_color
|
self.animation_dot_color = animation_dot_color
|
||||||
self.dot_radius = dot_radius
|
self.dot_radius = dot_radius
|
||||||
|
|
||||||
|
20
manim_ml/neural_network/layers/util.py
Normal file
20
manim_ml/neural_network/layers/util.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from manim import *
|
||||||
|
from ..layers import connective_layers_list
|
||||||
|
|
||||||
|
def get_connective_layer(input_layer, output_layer):
|
||||||
|
"""
|
||||||
|
Deduces the relevant connective layer
|
||||||
|
"""
|
||||||
|
connective_layer = None
|
||||||
|
for connective_layer_class in connective_layers_list:
|
||||||
|
input_class = connective_layer_class.input_class
|
||||||
|
output_class = connective_layer_class.output_class
|
||||||
|
if isinstance(input_layer, input_class) \
|
||||||
|
and isinstance(output_layer, output_class):
|
||||||
|
connective_layer = connective_layer_class(input_layer, output_layer)
|
||||||
|
|
||||||
|
if connective_layer is None:
|
||||||
|
raise Exception(f"Unrecognized class pair {input_layer.__class__.__name__}" + \
|
||||||
|
" and {output_layer.__class__.__name__}")
|
||||||
|
|
||||||
|
return connective_layer
|
@ -9,24 +9,23 @@ Example:
|
|||||||
# Create the object with default style settings
|
# Create the object with default style settings
|
||||||
NeuralNetwork(layer_node_count)
|
NeuralNetwork(layer_node_count)
|
||||||
"""
|
"""
|
||||||
|
from socket import create_connection
|
||||||
|
from urllib.parse import non_hierarchical
|
||||||
from manim import *
|
from manim import *
|
||||||
import warnings
|
import warnings
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
from manim_ml.neural_network.layers import \
|
|
||||||
FeedForwardLayer, FeedForwardToFeedForward, ImageLayer, \
|
|
||||||
ImageToFeedForward, FeedForwardToImage, EmbeddingLayer, \
|
|
||||||
EmbeddingToFeedForward, FeedForwardToEmbedding, TripletLayer, \
|
|
||||||
TripletToFeedForward
|
|
||||||
|
|
||||||
from manim_ml.neural_network.layers import connective_layers_list
|
from manim_ml.neural_network.layers import connective_layers_list
|
||||||
|
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||||
|
from manim_ml.neural_network.layers.util import get_connective_layer
|
||||||
|
from manim_ml.list_group import ListGroup
|
||||||
|
|
||||||
class NeuralNetwork(Group):
|
class NeuralNetwork(Group):
|
||||||
|
|
||||||
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
|
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
|
||||||
animation_dot_color=RED, edge_width=1.5, dot_radius=0.03):
|
animation_dot_color=RED, edge_width=2.5, dot_radius=0.03):
|
||||||
super(Group, self).__init__()
|
super(Group, self).__init__()
|
||||||
self.input_layers = Group(*input_layers)
|
self.input_layers = ListGroup(*input_layers)
|
||||||
self.edge_width = edge_width
|
self.edge_width = edge_width
|
||||||
self.edge_color = edge_color
|
self.edge_color = edge_color
|
||||||
self.layer_spacing = layer_spacing
|
self.layer_spacing = layer_spacing
|
||||||
@ -37,6 +36,9 @@ class NeuralNetwork(Group):
|
|||||||
# and make it have explicit distinct subspaces
|
# and make it have explicit distinct subspaces
|
||||||
self._place_layers()
|
self._place_layers()
|
||||||
self.connective_layers, self.all_layers = self._construct_connective_layers()
|
self.connective_layers, self.all_layers = self._construct_connective_layers()
|
||||||
|
# Place layers at correct z index
|
||||||
|
self.connective_layers.set_z_index(2)
|
||||||
|
self.input_layers.set_z_index(3)
|
||||||
# Center the whole diagram by default
|
# Center the whole diagram by default
|
||||||
self.all_layers.move_to(ORIGIN)
|
self.all_layers.move_to(ORIGIN)
|
||||||
self.add(self.all_layers)
|
self.add(self.all_layers)
|
||||||
@ -50,17 +52,14 @@ class NeuralNetwork(Group):
|
|||||||
for layer_index in range(1, len(self.input_layers)):
|
for layer_index in range(1, len(self.input_layers)):
|
||||||
previous_layer = self.input_layers[layer_index - 1]
|
previous_layer = self.input_layers[layer_index - 1]
|
||||||
current_layer = self.input_layers[layer_index]
|
current_layer = self.input_layers[layer_index]
|
||||||
|
|
||||||
current_layer.move_to(previous_layer)
|
current_layer.move_to(previous_layer)
|
||||||
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
|
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
|
||||||
current_layer.shift(shift_vector)
|
current_layer.shift(shift_vector)
|
||||||
# Handle layering
|
|
||||||
self.input_layers.set_z_index(2)
|
|
||||||
|
|
||||||
def _construct_connective_layers(self):
|
def _construct_connective_layers(self):
|
||||||
"""Draws connecting lines between layers"""
|
"""Draws connecting lines between layers"""
|
||||||
connective_layers = Group()
|
connective_layers = ListGroup()
|
||||||
all_layers = Group()
|
all_layers = ListGroup()
|
||||||
for layer_index in range(len(self.input_layers) - 1):
|
for layer_index in range(len(self.input_layers) - 1):
|
||||||
current_layer = self.input_layers[layer_index]
|
current_layer = self.input_layers[layer_index]
|
||||||
all_layers.add(current_layer)
|
all_layers.add(current_layer)
|
||||||
@ -72,28 +71,196 @@ class NeuralNetwork(Group):
|
|||||||
if isinstance(next_layer, NeuralNetwork):
|
if isinstance(next_layer, NeuralNetwork):
|
||||||
# First layer of the next layer
|
# First layer of the next layer
|
||||||
next_layer = next_layer.all_layers[0]
|
next_layer = next_layer.all_layers[0]
|
||||||
|
|
||||||
# Find connective layer with correct layer pair
|
# Find connective layer with correct layer pair
|
||||||
connective_layer = None
|
connective_layer = get_connective_layer(current_layer, next_layer)
|
||||||
for connective_layer_class in connective_layers_list:
|
connective_layers.add(connective_layer)
|
||||||
input_class = connective_layer_class.input_class
|
all_layers.add(connective_layer)
|
||||||
output_class = connective_layer_class.output_class
|
|
||||||
if isinstance(current_layer, input_class) \
|
|
||||||
and isinstance(next_layer, output_class):
|
|
||||||
connective_layer = connective_layer_class(current_layer, next_layer)
|
|
||||||
|
|
||||||
connective_layers.add(connective_layer)
|
|
||||||
all_layers.add(connective_layer)
|
|
||||||
|
|
||||||
if connective_layer is None:
|
|
||||||
raise Exception(f"Unrecognized class pair {current_layer.__class__.__name__} and {next_layer.__class__.__name__}")
|
|
||||||
# Add final layer
|
# Add final layer
|
||||||
all_layers.add(self.input_layers[-1])
|
all_layers.add(self.input_layers[-1])
|
||||||
# Handle layering
|
# Handle layering
|
||||||
return connective_layers, all_layers
|
return connective_layers, all_layers
|
||||||
|
|
||||||
|
def insert_layer(self, layer, insert_index):
|
||||||
|
"""Inserts a layer at the given index"""
|
||||||
|
layers_before = self.all_layers[:insert_index]
|
||||||
|
layers_after = self.all_layers[insert_index:]
|
||||||
|
# Make connective layers and shift animations
|
||||||
|
# Before layer
|
||||||
|
if len(layers_before) > 0:
|
||||||
|
before_connective = get_connective_layer(layers_before[-1], layer)
|
||||||
|
before_shift = np.array([-layer.width/2, 0, 0])
|
||||||
|
# Shift layers before
|
||||||
|
before_shift_animation = Group(*layers_before).animate.shift(before_shift)
|
||||||
|
else:
|
||||||
|
before_connective = AnimationGroup()
|
||||||
|
# After layer
|
||||||
|
if len(layers_after) > 0:
|
||||||
|
after_connective = get_connective_layer(layer, layers_after[0])
|
||||||
|
after_shift = np.array([layer.width/2, 0, 0])
|
||||||
|
# Shift layers after
|
||||||
|
after_shift_animation = Group(*layers_after).animate.shift(after_shift)
|
||||||
|
else:
|
||||||
|
after_connective = AnimationGroup
|
||||||
|
|
||||||
|
# Make animation group
|
||||||
|
shift_animations = AnimationGroup(
|
||||||
|
before_shift_animation,
|
||||||
|
after_shift_animation
|
||||||
|
)
|
||||||
|
|
||||||
|
insert_animation = Create(layer)
|
||||||
|
animation_group = AnimationGroup(
|
||||||
|
shift_animations,
|
||||||
|
insert_animation,
|
||||||
|
lag_ratio=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
return animation_group
|
||||||
|
|
||||||
|
def remove_layer(self, layer):
|
||||||
|
"""Removes layer object if it exists"""
|
||||||
|
# Get layer index
|
||||||
|
layer_index = self.all_layers.index_of(layer)
|
||||||
|
if layer_index == -1:
|
||||||
|
raise Exception("Layer object not found")
|
||||||
|
# Get the layers before and after
|
||||||
|
before_layer = None
|
||||||
|
after_layer = None
|
||||||
|
if layer_index - 2 >= 0:
|
||||||
|
before_layer = self.all_layers[layer_index - 2]
|
||||||
|
if layer_index + 2 < len(self.all_layers):
|
||||||
|
after_layer = self.all_layers[layer_index + 2]
|
||||||
|
# Remove the layer
|
||||||
|
self.all_layers.remove(layer)
|
||||||
|
# Remove surrounding connective layers from self.all_layers
|
||||||
|
before_connective = None
|
||||||
|
after_connective = None
|
||||||
|
if layer_index - 1 >= 0:
|
||||||
|
# There is a layer before
|
||||||
|
before_connective = self.all_layers.remove_at_index(layer_index - 1)
|
||||||
|
if layer_index + 1 < len(self.all_layers):
|
||||||
|
# There is a layer after
|
||||||
|
after_connective = self.all_layers.remove_at_index(layer_index + 1)
|
||||||
|
# Make animations
|
||||||
|
# Fade out the removed layer
|
||||||
|
fade_out_removed = FadeOut(layer)
|
||||||
|
# Fade out the removed connective layers
|
||||||
|
fade_out_before_connective = Animation()
|
||||||
|
if not before_connective is None:
|
||||||
|
fade_out_before_connective = FadeOut(before_connective)
|
||||||
|
fade_out_after_connective = Animation()
|
||||||
|
if not after_connective is None:
|
||||||
|
fade_out_after_connective = FadeOut(after_connective)
|
||||||
|
# Create new connective layer
|
||||||
|
new_connective = None
|
||||||
|
if not before_layer is None and not after_layer is None:
|
||||||
|
new_connective = get_connective_layer(before_layer, after_layer)
|
||||||
|
before_layer_index = self.all_layers.index_of(before_layer)
|
||||||
|
self.all_layers.insert(before_layer_index, new_connective)
|
||||||
|
# Place the new connective
|
||||||
|
new_connective.move_to(layer)
|
||||||
|
# Animate the creation of the new connective layer
|
||||||
|
create_new_connective = Animation()
|
||||||
|
if not new_connective is None:
|
||||||
|
create_new_connective = Create(new_connective)
|
||||||
|
# Collapse the neural network to fill the empty space
|
||||||
|
removed_width = layer.width + before_connective.width + after_connective.width - new_connective.width
|
||||||
|
shift_right_amount = np.array([removed_width / 2, 0, 0])
|
||||||
|
shift_left_amount = np.array([-removed_width / 2, 0, 0])
|
||||||
|
move_before_layer = Animation()
|
||||||
|
if not before_layer is None:
|
||||||
|
move_before_layer = before_layer.animate.shift(shift_right_amount)
|
||||||
|
move_after_layer = Animation()
|
||||||
|
if not after_layer is None:
|
||||||
|
move_after_layer = after_layer.animate.shift(shift_left_amount)
|
||||||
|
# Make the final AnimationGroup
|
||||||
|
fade_out_group = AnimationGroup(
|
||||||
|
fade_out_removed,
|
||||||
|
fade_out_before_connective,
|
||||||
|
fade_out_after_connective
|
||||||
|
)
|
||||||
|
move_group = AnimationGroup(
|
||||||
|
move_before_layer,
|
||||||
|
move_after_layer
|
||||||
|
)
|
||||||
|
animation_group = AnimationGroup(
|
||||||
|
fade_out_group,
|
||||||
|
move_group,
|
||||||
|
create_new_connective,
|
||||||
|
lag_ratio=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
return animation_group
|
||||||
|
|
||||||
|
"""
|
||||||
|
remove_layer = list(self.all_layers)[remove_index]
|
||||||
|
if remove_index > 0:
|
||||||
|
connective_before = list(self.all_layers)[remove_index - 1]
|
||||||
|
else:
|
||||||
|
connective_before = None
|
||||||
|
if remove_index < len(list(self.all_layers)) - 1:
|
||||||
|
connective_after = list(self.all_layers)[remove_index + 1]
|
||||||
|
else:
|
||||||
|
connective_after = None
|
||||||
|
# Collapse the surrounding layer
|
||||||
|
layers_before = list(self.all_layers)[:remove_index]
|
||||||
|
layers_after = list(self.all_layers)[remove_index+1:]
|
||||||
|
before_group = Group(*layers_before)
|
||||||
|
after_group = Group(*layers_after)
|
||||||
|
before_shift_amount = np.array([remove_layer.width/2, 0, 0])
|
||||||
|
after_shift_amount = np.array([-remove_layer.width/2, 0, 0])
|
||||||
|
# Remove the layers from the neural network representation
|
||||||
|
self.all_layers.remove(remove_layer)
|
||||||
|
if not connective_before is None:
|
||||||
|
self.all_layers.remove(connective_before)
|
||||||
|
if not connective_after is None:
|
||||||
|
self.all_layers.remove(connective_after)
|
||||||
|
# Connect the layers before and layers after
|
||||||
|
pre_index = remove_index - 1
|
||||||
|
pre_layer = None
|
||||||
|
if pre_index >= 0:
|
||||||
|
pre_layer = list(self.all_layers)[pre_index]
|
||||||
|
post_index = remove_index
|
||||||
|
post_layer = None
|
||||||
|
if post_index < len(list(self.all_layers)):
|
||||||
|
post_layer = list(self.all_layers)[post_index]
|
||||||
|
if not pre_layer is None and not post_layer is None:
|
||||||
|
connective_layer = get_connective_layer(pre_layer, post_layer)
|
||||||
|
self.all_layers = Group(
|
||||||
|
*self.all_layers[:remove_index],
|
||||||
|
connective_layer,
|
||||||
|
*self.all_layers[remove_index:]
|
||||||
|
)
|
||||||
|
# Make animations
|
||||||
|
fade_out_animation = FadeOut(remove_layer)
|
||||||
|
shift_animations = AnimationGroup(
|
||||||
|
before_group.animate.shift(before_shift_amount),
|
||||||
|
after_group.animate.shift(after_shift_amount)
|
||||||
|
)
|
||||||
|
animation_group = AnimationGroup(
|
||||||
|
fade_out_animation,
|
||||||
|
shift_animations,
|
||||||
|
lag_ratio=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
return animation_group
|
||||||
|
"""
|
||||||
|
|
||||||
|
def replace_layer(self, old_layer, new_layer):
|
||||||
|
"""Replaces given layer object"""
|
||||||
|
remove_animation = self.remove_layer(insert_index)
|
||||||
|
insert_animation = self.insert_layer(layer, insert_index)
|
||||||
|
# Make the animation
|
||||||
|
animation_group = AnimationGroup(
|
||||||
|
FadeOut(self.all_layers[insert_index]),
|
||||||
|
FadeIn(layer),
|
||||||
|
lag_ratio=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
return animation_group
|
||||||
|
|
||||||
def make_forward_pass_animation(self, run_time=10, passing_flash=True):
|
def make_forward_pass_animation(self, run_time=10, passing_flash=True):
|
||||||
"""Generates an animation for feed forward propogation"""
|
"""Generates an animation for feed forward propagation"""
|
||||||
all_animations = []
|
all_animations = []
|
||||||
for layer_index, layer in enumerate(self.input_layers[:-1]):
|
for layer_index, layer in enumerate(self.input_layers[:-1]):
|
||||||
layer_forward_pass = layer.make_forward_pass_animation()
|
layer_forward_pass = layer.make_forward_pass_animation()
|
||||||
@ -126,19 +293,11 @@ class NeuralNetwork(Group):
|
|||||||
|
|
||||||
return animation_group
|
return animation_group
|
||||||
|
|
||||||
def remove_layer(self, layer_index):
|
|
||||||
"""Removes layer at given index and returns animation for removing the layer"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def add_layer(self, layer):
|
|
||||||
"""Adds layer and returns animation for adding action"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Print string representation of layers"""
|
"""Print string representation of layers"""
|
||||||
inner_string = ""
|
inner_string = ""
|
||||||
for layer in self.all_layers:
|
for layer in self.all_layers:
|
||||||
inner_string += f"{repr(layer)},\n"
|
inner_string += f"{repr(layer)} {layer.z_index} ,\n"
|
||||||
inner_string = textwrap.indent(inner_string, " ")
|
inner_string = textwrap.indent(inner_string, " ")
|
||||||
|
|
||||||
string_repr = "NeuralNetwork([\n" + inner_string + "])"
|
string_repr = "NeuralNetwork([\n" + inner_string + "])"
|
||||||
|
@ -16,7 +16,6 @@ class GaussianDistribution(VGroup):
|
|||||||
self.cov = np.array([[3, 0], [0, 3]])
|
self.cov = np.array([[3, 0], [0, 3]])
|
||||||
# Make the Gaussian
|
# Make the Gaussian
|
||||||
self.ellipses = self.construct_gaussian_distribution(self.mean, self.cov)
|
self.ellipses = self.construct_gaussian_distribution(self.mean, self.cov)
|
||||||
self.ellipses.set_z_index(2)
|
|
||||||
|
|
||||||
@override_animation(Create)
|
@override_animation(Create)
|
||||||
def _create_gaussian_distribution(self):
|
def _create_gaussian_distribution(self):
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
|
from cv2 import exp
|
||||||
from manim import *
|
from manim import *
|
||||||
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
||||||
|
from manim_ml.neural_network.layers.embedding_to_feed_forward import EmbeddingToFeedForward
|
||||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||||
|
from manim_ml.neural_network.layers.feed_forward_to_embedding import FeedForwardToEmbedding
|
||||||
|
from manim_ml.neural_network.layers.feed_forward_to_feed_forward import FeedForwardToFeedForward
|
||||||
from manim_ml.neural_network.layers.image import ImageLayer
|
from manim_ml.neural_network.layers.image import ImageLayer
|
||||||
from manim_ml.neural_network.neural_network import NeuralNetwork, FeedForwardNeuralNetwork
|
from manim_ml.neural_network.neural_network import NeuralNetwork, FeedForwardNeuralNetwork
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -11,6 +15,72 @@ config.pixel_width = 1280
|
|||||||
config.frame_height = 6.0
|
config.frame_height = 6.0
|
||||||
config.frame_width = 6.0
|
config.frame_width = 6.0
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit Tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
def assert_classes_match(all_layers, expected_classes):
|
||||||
|
assert len(list(all_layers)) == 5
|
||||||
|
|
||||||
|
for index, layer in enumerate(all_layers):
|
||||||
|
expected_class = expected_classes[index]
|
||||||
|
assert isinstance(layer, expected_class), f"Wrong layer class {layer.__class__} expected {expected_class}"
|
||||||
|
|
||||||
|
def test_embedding_layer():
|
||||||
|
embedding_layer = EmbeddingLayer()
|
||||||
|
|
||||||
|
neural_network = NeuralNetwork([
|
||||||
|
FeedForwardLayer(5),
|
||||||
|
FeedForwardLayer(3),
|
||||||
|
embedding_layer
|
||||||
|
])
|
||||||
|
|
||||||
|
expected_classes = [
|
||||||
|
FeedForwardLayer,
|
||||||
|
FeedForwardToFeedForward,
|
||||||
|
FeedForwardLayer,
|
||||||
|
FeedForwardToEmbedding,
|
||||||
|
EmbeddingLayer
|
||||||
|
]
|
||||||
|
|
||||||
|
assert_classes_match(neural_network.all_layers, expected_classes)
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_layer():
|
||||||
|
embedding_layer = EmbeddingLayer()
|
||||||
|
|
||||||
|
neural_network = NeuralNetwork([
|
||||||
|
FeedForwardLayer(5),
|
||||||
|
FeedForwardLayer(3),
|
||||||
|
embedding_layer
|
||||||
|
])
|
||||||
|
|
||||||
|
expected_classes = [
|
||||||
|
FeedForwardLayer,
|
||||||
|
FeedForwardToFeedForward,
|
||||||
|
FeedForwardLayer,
|
||||||
|
FeedForwardToEmbedding,
|
||||||
|
EmbeddingLayer
|
||||||
|
]
|
||||||
|
|
||||||
|
assert_classes_match(neural_network.all_layers, expected_classes)
|
||||||
|
|
||||||
|
print("before removal")
|
||||||
|
print(list(neural_network.all_layers))
|
||||||
|
neural_network.remove_layer(embedding_layer)
|
||||||
|
print("after removal")
|
||||||
|
print(list(neural_network.all_layers))
|
||||||
|
|
||||||
|
expected_classes = [
|
||||||
|
FeedForwardLayer,
|
||||||
|
FeedForwardToFeedForward,
|
||||||
|
FeedForwardLayer,
|
||||||
|
]
|
||||||
|
|
||||||
|
print(list(neural_network.all_layers))
|
||||||
|
|
||||||
|
assert_classes_match(neural_network.all_layers, expected_classes)
|
||||||
|
|
||||||
class FeedForwardNeuralNetworkScene(Scene):
|
class FeedForwardNeuralNetworkScene(Scene):
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
@ -92,6 +162,31 @@ class RecursiveNNScene(Scene):
|
|||||||
|
|
||||||
self.play(Create(nn))
|
self.play(Create(nn))
|
||||||
|
|
||||||
|
class LayerInsertionScene(Scene):
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class LayerRemovalScene(Scene):
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
image = Image.open('images/image.jpeg')
|
||||||
|
numpy_image = np.asarray(image)
|
||||||
|
|
||||||
|
layer = FeedForwardLayer(5),
|
||||||
|
layers = [
|
||||||
|
ImageLayer(numpy_image, height=1.4),
|
||||||
|
FeedForwardLayer(3),
|
||||||
|
layer,
|
||||||
|
FeedForwardLayer(3),
|
||||||
|
FeedForwardLayer(6)
|
||||||
|
]
|
||||||
|
|
||||||
|
nn = NeuralNetwork(layers)
|
||||||
|
|
||||||
|
self.play(Create(nn))
|
||||||
|
self.play(nn.remove_layer(layer))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""Render all scenes"""
|
"""Render all scenes"""
|
||||||
# Feed Forward Neural Network
|
# Feed Forward Neural Network
|
||||||
|
Reference in New Issue
Block a user