mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-06 17:29:45 +08:00
Embedding Neural Network Layer.
This commit is contained in:
@ -10,23 +10,27 @@ Example:
|
||||
NeuralNetwork(layer_node_count)
|
||||
"""
|
||||
from manim import *
|
||||
from matplotlib import animation
|
||||
from numpy import isin
|
||||
import warnings
|
||||
from manim_ml.neural_network.layers import FeedForwardLayer, ImageLayer
|
||||
from manim_ml.neural_network.connective_layers import FeedForwardToFeedForward, ImageToFeedForward
|
||||
import textwrap
|
||||
|
||||
class NeuralNetwork(VGroup):
|
||||
from numpy import string_
|
||||
|
||||
from manim_ml.neural_network.embedding import EmbeddingLayer, EmbeddingToFeedForward, FeedForwardToEmbedding
|
||||
from manim_ml.neural_network.feed_forward import FeedForwardLayer, FeedForwardToFeedForward
|
||||
from manim_ml.neural_network.image import ImageLayer, ImageToFeedForward, FeedForwardToImage
|
||||
|
||||
class NeuralNetwork(Group):
|
||||
|
||||
def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
|
||||
animation_dot_color=RED, edge_width=1.5, dot_radius=0.03):
|
||||
super().__init__()
|
||||
self.input_layers = VGroup(*input_layers)
|
||||
super(Group, self).__init__()
|
||||
self.input_layers = Group(*input_layers)
|
||||
self.edge_width = edge_width
|
||||
self.edge_color = edge_color
|
||||
self.layer_spacing = layer_spacing
|
||||
self.animation_dot_color = animation_dot_color
|
||||
self.dot_radius = dot_radius
|
||||
self.created = False
|
||||
# TODO take layer_node_count [0, (1, 2), 0]
|
||||
# and make it have explicit distinct subspaces
|
||||
self._place_layers()
|
||||
@ -34,62 +38,87 @@ class NeuralNetwork(VGroup):
|
||||
# Center the whole diagram by default
|
||||
self.all_layers.move_to(ORIGIN)
|
||||
self.add(self.all_layers)
|
||||
# print nn
|
||||
print(repr(self))
|
||||
|
||||
def _place_layers(self):
|
||||
"""Creates the neural network"""
|
||||
# TODO implement more sophisticated custom layouts
|
||||
# Default: Linear layout
|
||||
for layer_index in range(1, len(self.input_layers)):
|
||||
previous_layer = self.input_layers[layer_index - 1]
|
||||
current_layer = self.input_layers[layer_index]
|
||||
# Manage spacing
|
||||
# Default: half each width times 2
|
||||
spacing = config.frame_width * 0.05 + (previous_layer.width / 2 + current_layer.width / 2)
|
||||
current_layer.move_to(previous_layer.get_center())
|
||||
current_layer.shift(np.array([spacing, 0, 0]))
|
||||
# Add layer to VGroup
|
||||
|
||||
current_layer.move_to(previous_layer)
|
||||
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + 0.2, 0, 0])
|
||||
current_layer.shift(shift_vector)
|
||||
# Handle layering
|
||||
self.input_layers.set_z_index(2)
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
"""Draws connecting lines between layers"""
|
||||
connective_layers = VGroup()
|
||||
all_layers = VGroup()
|
||||
connective_layers = Group()
|
||||
all_layers = Group()
|
||||
for layer_index in range(len(self.input_layers) - 1):
|
||||
current_layer = self.input_layers[layer_index]
|
||||
all_layers.add(current_layer)
|
||||
next_layer = self.input_layers[layer_index + 1]
|
||||
|
||||
# Check if layer is actually a nested NeuralNetwork
|
||||
if isinstance(current_layer, NeuralNetwork):
|
||||
# Last layer of the current layer
|
||||
current_layer = current_layer.all_layers[-1]
|
||||
if isinstance(next_layer, NeuralNetwork):
|
||||
# First layer of the next layer
|
||||
next_layer = next_layer.all_layers[0]
|
||||
if isinstance(current_layer, FeedForwardLayer) \
|
||||
and isinstance(next_layer, FeedForwardLayer):
|
||||
# FeedForward to Image
|
||||
edge_layer = FeedForwardToFeedForward(current_layer, next_layer,
|
||||
edge_width=self.edge_width)
|
||||
connective_layers.add(edge_layer)
|
||||
all_layers.add(edge_layer)
|
||||
elif isinstance(current_layer, ImageLayer) \
|
||||
and isinstance(next_layer, FeedForwardLayer):
|
||||
# Image to FeedForward
|
||||
image_to_feedforward = ImageToFeedForward(current_layer, next_layer, dot_radius=self.dot_radius)
|
||||
connective_layers.add(image_to_feedforward)
|
||||
all_layers.add(image_to_feedforward)
|
||||
elif isinstance(current_layer, FeedForwardLayer) \
|
||||
and isinstance(next_layer, ImageLayer):
|
||||
# Image to FeedForward
|
||||
feed_forward_to_image = FeedForwardToImage(current_layer, next_layer, dot_radius=self.dot_radius)
|
||||
connective_layers.add(feed_forward_to_image)
|
||||
all_layers.add(feed_forward_to_image)
|
||||
elif isinstance(current_layer, FeedForwardLayer) \
|
||||
and isinstance(next_layer, EmbeddingLayer):
|
||||
# FeedForward to Embedding
|
||||
layer = FeedForwardToEmbedding(current_layer, next_layer,
|
||||
animation_dot_color=self.animation_dot_color, dot_radius=self.dot_radius)
|
||||
connective_layers.add(layer)
|
||||
all_layers.add(layer)
|
||||
elif isinstance(current_layer, EmbeddingLayer) \
|
||||
and isinstance(next_layer, FeedForwardLayer):
|
||||
# Embedding to FeedForward
|
||||
layer = EmbeddingToFeedForward(current_layer, next_layer,
|
||||
animation_dot_color=self.animation_dot_color, dot_radius=self.dot_radius)
|
||||
connective_layers.add(layer)
|
||||
all_layers.add(layer)
|
||||
else:
|
||||
warnings.warn(f"Warning: unimplemented connection for layer types: {type(current_layer)} and {type(next_layer)}")
|
||||
# Add final layer
|
||||
all_layers.add(self.input_layers[-1])
|
||||
# Handle layering
|
||||
connective_layers.set_z_index(0)
|
||||
return connective_layers, all_layers
|
||||
|
||||
def make_forward_pass_animation(self, run_time=10, passing_flash=True):
|
||||
"""Generates an animation for feed forward propogation"""
|
||||
all_animations = []
|
||||
|
||||
for layer_index, layer in enumerate(self.input_layers[:-1]):
|
||||
layer_forward_pass = layer.make_forward_pass_animation()
|
||||
all_animations.append(layer_forward_pass)
|
||||
|
||||
connective_layer = self.connective_layers[layer_index]
|
||||
connective_forward_pass = connective_layer.make_forward_pass_animation()
|
||||
all_animations.append(connective_forward_pass)
|
||||
|
||||
# Do last layer animation
|
||||
last_layer_forward_pass = self.input_layers[-1].make_forward_pass_animation()
|
||||
all_animations.append(last_layer_forward_pass)
|
||||
@ -101,17 +130,39 @@ class NeuralNetwork(VGroup):
|
||||
@override_animation(Create)
|
||||
def _create_override(self, **kwargs):
|
||||
"""Overrides Create animation"""
|
||||
# Stop the neural network from being created twice
|
||||
if self.created:
|
||||
return AnimationGroup()
|
||||
self.created = True
|
||||
# Create each layer one by one
|
||||
animations = []
|
||||
|
||||
for layer in self.all_layers:
|
||||
print(layer)
|
||||
animation = Create(layer)
|
||||
animations.append(animation)
|
||||
|
||||
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
|
||||
|
||||
|
||||
return animation_group
|
||||
|
||||
def remove_layer(self, layer_index):
|
||||
"""Removes layer at given index and returns animation for removing the layer"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def add_layer(self, layer):
|
||||
"""Adds layer and returns animation for adding action"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __repr__(self):
|
||||
"""Print string representation of layers"""
|
||||
inner_string = ""
|
||||
for layer in self.all_layers:
|
||||
inner_string += f"{repr(layer)},\n"
|
||||
inner_string = textwrap.indent(inner_string, " ")
|
||||
|
||||
string_repr = "NeuralNetwork([\n" + inner_string + "])"
|
||||
return string_repr
|
||||
|
||||
class FeedForwardNeuralNetwork(NeuralNetwork):
|
||||
"""NeuralNetwork with just feed forward layers"""
|
||||
|
||||
|
Reference in New Issue
Block a user