Files
ManimML/manim_ml/neural_network/neural_network.py

317 lines
13 KiB
Python

"""Neural Network Manim Visualization

This module generates neural network visualizations with manim, for example
a fully connected (feed forward) neural network diagram.

Example:
    # Specify how many nodes are in each node layer
    layer_node_count = [5, 3, 5]
    # Create the object with default style settings
    FeedForwardNeuralNetwork(layer_node_count)
"""
from socket import create_connection
from urllib.parse import non_hierarchical
from manim import *
import warnings
import textwrap
from manim_ml.neural_network.layers import connective_layers_list
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.util import get_connective_layer
from manim_ml.list_group import ListGroup
class NeuralNetwork(Group):
    """Manim group that lays out a sequence of layers and joins them.

    Consecutive input layers are placed left to right and linked by the
    connective layer that ``get_connective_layer`` returns for each pair.
    ``all_layers`` alternates ``[layer, connective, layer, ..., layer]``.
    ``input_layers`` holds the layers and ``connective_layers`` holds the
    connectives, each as a separate ListGroup. A nested ``NeuralNetwork`` may
    be used as a layer. It is joined through its last layer when outgoing and
    through its first layer when incoming.
    """

    def __init__(self, input_layers, edge_color=WHITE, layer_spacing=0.8,
                 animation_dot_color=RED, edge_width=2.5, dot_radius=0.03):
        """Place ``input_layers``, build connectives and centre the diagram.

        Args:
            input_layers: iterable of layer mobjects in left-to-right order.
            edge_color, animation_dot_color, edge_width, dot_radius: stored
                as attributes; not read anywhere else in this class.
            layer_spacing: stored as an attribute. NOTE(review):
                ``_place_layers`` uses a fixed 0.2 gap instead of this value.
        """
        super().__init__()
        self.input_layers = ListGroup(*input_layers)
        self.edge_width = edge_width
        self.edge_color = edge_color
        self.layer_spacing = layer_spacing
        self.animation_dot_color = animation_dot_color
        self.dot_radius = dot_radius
        # Guards _create_override against building the Create animation twice.
        self.created = False
        # TODO take layer_node_count [0, (1, 2), 0]
        # and make it have explicit distinct subspaces
        self._place_layers()
        self.connective_layers, self.all_layers = self._construct_connective_layers()
        # Draw input layers above the connective layers that join them.
        self.connective_layers.set_z_index(2)
        self.input_layers.set_z_index(3)
        # Center the whole diagram by default
        self.all_layers.move_to(ORIGIN)
        self.add(self.all_layers)
        # Debug output: prints the layer structure on every construction.
        print(repr(self))

    @staticmethod
    def _find_index(group, item):
        """Return the position of ``item`` in ``group`` (by identity), or -1."""
        for index, candidate in enumerate(group):
            if candidate is item:
                return index
        return -1

    def _remove_if_present(self, group, item):
        """Remove ``item`` from the ListGroup ``group`` if it is in it."""
        index = self._find_index(group, item)
        if index != -1:
            group.remove_at_index(index)

    @staticmethod
    def _make_connective(previous_layer, next_layer):
        """Build the connective between two layers, looking inside nested networks.

        A nested ``NeuralNetwork`` connects through its last layer when it
        comes first and through its first layer when it comes second.
        """
        if isinstance(previous_layer, NeuralNetwork):
            previous_layer = previous_layer.all_layers[-1]
        if isinstance(next_layer, NeuralNetwork):
            next_layer = next_layer.all_layers[0]
        return get_connective_layer(previous_layer, next_layer)

    def _place_layers(self):
        """Lay the input layers out left to right (default linear layout).

        Each layer is moved onto its predecessor and then shifted right by
        half of both widths plus a fixed 0.2 gap.
        """
        # TODO implement more sophisticated custom layouts
        layers = list(self.input_layers)
        for previous_layer, current_layer in zip(layers, layers[1:]):
            current_layer.move_to(previous_layer)
            gap = previous_layer.get_width() / 2 + current_layer.get_width() / 2 + 0.2
            current_layer.shift(np.array([gap, 0, 0]))

    def _construct_connective_layers(self):
        """Create the connective layers between consecutive input layers.

        Returns:
            ``(connective_layers, all_layers)``, two ListGroups.
            ``connective_layers`` holds only the connectives.
            ``all_layers`` interleaves input layers and connectives in order.
            Both are empty when there are no input layers.
        """
        connective_layers = ListGroup()
        all_layers = ListGroup()
        layers = list(self.input_layers)
        if not layers:
            return connective_layers, all_layers
        for current_layer, next_layer in zip(layers, layers[1:]):
            all_layers.add(current_layer)
            connective_layer = self._make_connective(current_layer, next_layer)
            connective_layers.add(connective_layer)
            all_layers.add(connective_layer)
        # The final layer has no outgoing connective.
        all_layers.add(layers[-1])
        return connective_layers, all_layers

    def insert_layer(self, layer, insert_index):
        """Build an animation that opens a gap at ``insert_index`` and draws ``layer``.

        Items of ``all_layers`` before ``insert_index`` slide left. Items from
        ``insert_index`` onward slide right. Each side moves by half of
        ``layer``'s width. Once the shifts finish, ``layer`` is drawn with
        ``Create``.

        NOTE(review): this only builds the animation. ``layer`` is not added
        to ``all_layers``/``input_layers``, no connective layers are created
        for it, and it is drawn wherever the caller positioned it.

        Args:
            layer: the mobject to show.
            insert_index: position in ``all_layers`` (connectives included)
                at which the gap is opened.

        Returns:
            An ``AnimationGroup`` that runs the shifts, then creates ``layer``.
        """
        layers_before = self.all_layers[:insert_index]
        layers_after = self.all_layers[insert_index:]
        shift_animations = []
        if len(layers_before) > 0:
            shift_animations.append(
                Group(*layers_before).animate.shift(np.array([-layer.width / 2, 0, 0]))
            )
        if len(layers_after) > 0:
            shift_animations.append(
                Group(*layers_after).animate.shift(np.array([layer.width / 2, 0, 0]))
            )
        return AnimationGroup(
            AnimationGroup(*shift_animations),
            Create(layer),
            lag_ratio=1.0,
        )

    def remove_layer(self, layer):
        """Remove ``layer`` from the network and return the collapse animation.

        Because ``all_layers`` alternates layers and connectives, the
        connectives adjacent to ``layer`` sit at offsets -1/+1 and the
        neighbouring layers at -2/+2. The adjacent connectives are dropped.
        If ``layer`` had neighbours on both sides, a new connective joining
        them is created in its place. ``all_layers``, ``input_layers`` and
        ``connective_layers`` are all updated.

        Args:
            layer: a (non-connective) layer object currently in ``all_layers``.

        Returns:
            An ``AnimationGroup`` that runs three steps in sequence. First it
            fades out the removed objects. Then it slides the remaining
            layers together. Finally it creates the new connective, if any.

        Raises:
            Exception: if ``layer`` is not in ``all_layers``.
        """
        layer_index = self._find_index(self.all_layers, layer)
        if layer_index == -1:
            raise Exception("Layer object not found")
        # Resolve every neighbour BEFORE mutating anything: each removal
        # shifts the indices of everything that follows it.
        items = list(self.all_layers)

        def neighbour(offset):
            index = layer_index + offset
            return items[index] if 0 <= index < len(items) else None

        before_layer = neighbour(-2)
        before_connective = neighbour(-1)
        after_connective = neighbour(+1)
        after_layer = neighbour(+2)
        removed_connectives = [
            connective
            for connective in (before_connective, after_connective)
            if connective is not None
        ]
        # Slot in connective_layers that the replacement connective takes over.
        connective_slot = -1
        if before_connective is not None:
            connective_slot = self._find_index(self.connective_layers, before_connective)
        # Detach the layer and its connectives from every bookkeeping group.
        self._remove_if_present(self.all_layers, layer)
        self._remove_if_present(self.input_layers, layer)
        for connective in removed_connectives:
            self._remove_if_present(self.all_layers, connective)
            self._remove_if_present(self.connective_layers, connective)
        # Join the two neighbours with a new connective placed where the
        # removed layer was, directly after before_layer.
        new_connective = None
        if before_layer is not None and after_layer is not None:
            new_connective = self._make_connective(before_layer, after_layer)
            new_connective.move_to(layer)
            new_connective.set_z_index(2)  # same z as the other connectives
            insert_at = self._find_index(self.all_layers, before_layer) + 1
            self.all_layers.insert(insert_at, new_connective)
            if connective_slot != -1:
                self.connective_layers.insert(connective_slot, new_connective)
        # Net horizontal space freed by the removal.
        removed_width = sum(mob.width for mob in [layer, *removed_connectives])
        if new_connective is not None:
            removed_width -= new_connective.width
        # Collapse: everything up to before_layer slides right, everything
        # from after_layer on slides left, each by half the freed width.
        remaining = list(self.all_layers)
        move_animations = []
        if before_layer is not None:
            cut = self._find_index(self.all_layers, before_layer) + 1
            move_animations.append(
                Group(*remaining[:cut]).animate.shift(np.array([removed_width / 2, 0, 0]))
            )
        if after_layer is not None:
            start = self._find_index(self.all_layers, after_layer)
            move_animations.append(
                Group(*remaining[start:]).animate.shift(np.array([-removed_width / 2, 0, 0]))
            )
        fade_out_group = AnimationGroup(
            *[FadeOut(mob) for mob in [layer, *removed_connectives]]
        )
        move_group = AnimationGroup(*move_animations)
        # An empty AnimationGroup is the no-op used elsewhere in this class.
        if new_connective is not None:
            create_new_connective = Create(new_connective)
        else:
            create_new_connective = AnimationGroup()
        return AnimationGroup(
            fade_out_group,
            move_group,
            create_new_connective,
            lag_ratio=1.0,
        )

    def replace_layer(self, old_layer, new_layer):
        """Swap ``old_layer`` for ``new_layer`` in place and animate the swap.

        ``new_layer`` is moved onto ``old_layer`` and given its z-index. It
        takes ``old_layer``'s slot in ``all_layers``, and also in
        ``input_layers`` when ``old_layer`` is one of the input layers.

        NOTE(review): the connective layers on either side are kept as-is,
        so this is only correct when they also fit ``new_layer``.

        Returns:
            An ``AnimationGroup`` that fades ``old_layer`` out, then fades
            ``new_layer`` in.

        Raises:
            Exception: if ``old_layer`` is not in ``all_layers``.
        """
        layer_index = self._find_index(self.all_layers, old_layer)
        if layer_index == -1:
            raise Exception("Layer object not found")
        new_layer.move_to(old_layer)
        new_layer.set_z_index(old_layer.z_index)
        self.all_layers.remove_at_index(layer_index)
        self.all_layers.insert(layer_index, new_layer)
        input_index = self._find_index(self.input_layers, old_layer)
        if input_index != -1:
            self.input_layers.remove_at_index(input_index)
            self.input_layers.insert(input_index, new_layer)
        return AnimationGroup(
            FadeOut(old_layer),
            FadeIn(new_layer),
            lag_ratio=1.0,
        )

    def make_forward_pass_animation(self, run_time=10, passing_flash=True):
        """Generate an animation for feed forward propagation.

        Each input layer's forward-pass animation is followed by the one of
        its outgoing connective layer, ending with the last input layer. All
        run sequentially (``lag_ratio=1.0``) over ``run_time``.
        ``passing_flash`` is accepted but not used here.
        """
        layers = list(self.input_layers)
        if not layers:
            return AnimationGroup()
        all_animations = []
        # Pairs each of the first n-1 layers with its outgoing connective.
        for layer, connective_layer in zip(layers, self.connective_layers):
            all_animations.append(layer.make_forward_pass_animation())
            all_animations.append(connective_layer.make_forward_pass_animation())
        # Do last layer animation
        all_animations.append(layers[-1].make_forward_pass_animation())
        return AnimationGroup(*all_animations, run_time=run_time, lag_ratio=1.0)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Create every entry of ``all_layers`` one after another.

        Only the first call builds the real animation. Later calls return an
        empty ``AnimationGroup`` so the network is not created twice.
        """
        if self.created:
            return AnimationGroup()
        self.created = True
        return AnimationGroup(
            *[Create(layer) for layer in self.all_layers],
            lag_ratio=1.0,
        )

    def __repr__(self):
        """List each entry of ``all_layers`` with its z-index, one per line."""
        inner_string = "".join(
            f"{repr(layer)} {layer.z_index} ,\n" for layer in self.all_layers
        )
        inner_string = textwrap.indent(inner_string, " ")
        return "NeuralNetwork([\n" + inner_string + "])"
class FeedForwardNeuralNetwork(NeuralNetwork):
    """NeuralNetwork whose layers are all FeedForwardLayer instances.

    Args:
        layer_node_count: node count for each layer, in left-to-right order.
        node_radius: radius given to the nodes of every layer.
        node_color: color given to the nodes of every layer.
        **kwargs: forwarded unchanged to ``NeuralNetwork.__init__``.
    """

    def __init__(self, layer_node_count, node_radius=0.08,
                 node_color=BLUE, **kwargs):
        feed_forward_layers = [
            FeedForwardLayer(
                node_count,
                node_color=node_color,
                node_radius=node_radius,
            )
            for node_count in layer_node_count
        ]
        super().__init__(feed_forward_layers, **kwargs)