mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-01 23:44:52 +08:00
Added GAN visualization.
This commit is contained in:
@ -9,35 +9,15 @@ Example:
|
||||
# Create the object with default style settings
|
||||
NeuralNetwork(layer_node_count)
|
||||
"""
|
||||
from socket import create_connection
|
||||
from manim import *
|
||||
import warnings
|
||||
import textwrap
|
||||
|
||||
from manim_ml.neural_network.layers import connective_layers_list
|
||||
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.layers.util import get_connective_layer
|
||||
from manim_ml.list_group import ListGroup
|
||||
|
||||
class LazyAnimation(Animation):
|
||||
"""
|
||||
Lazily creates animation when the animation is called.
|
||||
|
||||
This is helpful when creating the animation depends upon the internal
|
||||
state of some set of objects.
|
||||
"""
|
||||
|
||||
def __init__(self, animation_func):
|
||||
self.animation_func = animation_func
|
||||
super().__init__(None)
|
||||
|
||||
def begin(self):
|
||||
"""Begins animation"""
|
||||
self.mobject, animation = self.animation_func()
|
||||
animation = Create(self.mobject)
|
||||
animation.begin()
|
||||
|
||||
class RemoveLayer(Succession):
|
||||
class RemoveLayer(AnimationGroup):
|
||||
"""
|
||||
Animation for removing a layer from a neural network.
|
||||
|
||||
@ -183,6 +163,99 @@ class RemoveLayer(Succession):
|
||||
update_func_anim = UpdateFromFunc(self.neural_network, create_new_connective)
|
||||
|
||||
return update_func_anim
|
||||
|
||||
class InsertLayer(AnimationGroup):
|
||||
"""Animation for inserting layer at given index"""
|
||||
|
||||
def __init__(self, layer, index, neural_network):
|
||||
self.layer = layer
|
||||
self.index = index
|
||||
self.neural_network = neural_network
|
||||
# Layers before and after
|
||||
self.layers_before = self.neural_network.all_layers[:self.index]
|
||||
self.layers_after = self.neural_network.all_layers[self.index:]
|
||||
|
||||
remove_connective_layer = self.remove_connective_layer()
|
||||
move_layers = self.make_move_layers()
|
||||
# create_layer = self.make_create_layer()
|
||||
# create_connective_layers = self.make_create_connective_layers()
|
||||
animations = [
|
||||
remove_connective_layer,
|
||||
move_layers,
|
||||
# create_layer,
|
||||
# create_connective_layers
|
||||
]
|
||||
|
||||
super().__init__(*animations, lag_ratio=1.0)
|
||||
|
||||
def remove_connective_layer(self):
|
||||
"""Removes the connective layer before the insertion index"""
|
||||
# Check if connective layer exists
|
||||
if len(self.layers_before) > 0:
|
||||
removed_connective = self.layers_before[-1]
|
||||
self.neural_network.all_layers.remove(removed_connective)
|
||||
# Make remove animation
|
||||
remove_animation = FadeOut(removed_connective)
|
||||
return remove_animation
|
||||
|
||||
return AnimationGroup()
|
||||
|
||||
def make_move_layers(self):
|
||||
"""Shifts layers before and after"""
|
||||
# Before layer shift
|
||||
before_shift_animation = AnimationGroup()
|
||||
if len(self.layers_before) > 0:
|
||||
before_shift = np.array([-self.layer.width/2, 0, 0])
|
||||
# Shift layers before
|
||||
before_shift_animation = Group(*self.layers_before).animate.shift(before_shift)
|
||||
# After layer shift
|
||||
after_shift_animation = AnimationGroup()
|
||||
if len(self.layers_after) > 0:
|
||||
after_shift = np.array([self.layer.width/2, 0, 0])
|
||||
# Shift layers after
|
||||
after_shift_animation = Group(*self.layers_after).animate.shift(after_shift)
|
||||
# Make animation group
|
||||
shift_animations = AnimationGroup(
|
||||
before_shift_animation,
|
||||
after_shift_animation
|
||||
)
|
||||
|
||||
return shift_animations
|
||||
|
||||
def make_create_layer(self):
|
||||
"""Animates the creation of the layer"""
|
||||
pass
|
||||
|
||||
def make_create_connective_layers(self):
|
||||
pass
|
||||
|
||||
|
||||
# Make connective layers and shift animations
|
||||
# Before layer
|
||||
if len(layers_before) > 0:
|
||||
before_connective = get_connective_layer(layers_before[-1], layer)
|
||||
before_shift = np.array([-layer.width/2, 0, 0])
|
||||
# Shift layers before
|
||||
before_shift_animation = Group(*layers_before).animate.shift(before_shift)
|
||||
else:
|
||||
before_connective = AnimationGroup()
|
||||
# After layer
|
||||
if len(layers_after) > 0:
|
||||
after_connective = get_connective_layer(layer, layers_after[0])
|
||||
after_shift = np.array([layer.width/2, 0, 0])
|
||||
# Shift layers after
|
||||
after_shift_animation = Group(*layers_after).animate.shift(after_shift)
|
||||
else:
|
||||
after_connective = AnimationGroup
|
||||
|
||||
insert_animation = Create(layer)
|
||||
animation_group = AnimationGroup(
|
||||
shift_animations,
|
||||
insert_animation,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
||||
|
||||
class NeuralNetwork(Group):
|
||||
|
||||
@ -202,9 +275,6 @@ class NeuralNetwork(Group):
|
||||
# and make it have explicit distinct subspaces
|
||||
self._place_layers()
|
||||
self.connective_layers, self.all_layers = self._construct_connective_layers()
|
||||
# Make layer titles
|
||||
self.layer_titles = self._make_layer_titles()
|
||||
self.add(self.layer_titles)
|
||||
# Make overhead title
|
||||
self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE/2)
|
||||
self.title.next_to(self, UP, 1.0)
|
||||
@ -229,15 +299,6 @@ class NeuralNetwork(Group):
|
||||
shift_vector = np.array([(previous_layer.get_width()/2 + current_layer.get_width()/2) + self.layer_spacing, 0, 0])
|
||||
current_layer.shift(shift_vector)
|
||||
|
||||
def _make_layer_titles(self):
|
||||
"""Makes titles"""
|
||||
titles = VGroup()
|
||||
for layer in self.all_layers:
|
||||
title = layer.title
|
||||
title.next_to(layer, UP, 0.2)
|
||||
titles.add(title)
|
||||
return titles
|
||||
|
||||
def _construct_connective_layers(self):
|
||||
"""Draws connecting lines between layers"""
|
||||
connective_layers = ListGroup()
|
||||
@ -264,40 +325,9 @@ class NeuralNetwork(Group):
|
||||
|
||||
def insert_layer(self, layer, insert_index):
|
||||
"""Inserts a layer at the given index"""
|
||||
layers_before = self.all_layers[:insert_index]
|
||||
layers_after = self.all_layers[insert_index:]
|
||||
# Make connective layers and shift animations
|
||||
# Before layer
|
||||
if len(layers_before) > 0:
|
||||
before_connective = get_connective_layer(layers_before[-1], layer)
|
||||
before_shift = np.array([-layer.width/2, 0, 0])
|
||||
# Shift layers before
|
||||
before_shift_animation = Group(*layers_before).animate.shift(before_shift)
|
||||
else:
|
||||
before_connective = AnimationGroup()
|
||||
# After layer
|
||||
if len(layers_after) > 0:
|
||||
after_connective = get_connective_layer(layer, layers_after[0])
|
||||
after_shift = np.array([layer.width/2, 0, 0])
|
||||
# Shift layers after
|
||||
after_shift_animation = Group(*layers_after).animate.shift(after_shift)
|
||||
else:
|
||||
after_connective = AnimationGroup
|
||||
|
||||
# Make animation group
|
||||
shift_animations = AnimationGroup(
|
||||
before_shift_animation,
|
||||
after_shift_animation
|
||||
)
|
||||
|
||||
insert_animation = Create(layer)
|
||||
animation_group = AnimationGroup(
|
||||
shift_animations,
|
||||
insert_animation,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group
|
||||
neural_network = self
|
||||
insert_animation = InsertLayer(layer, insert_index, neural_network)
|
||||
return insert_animation
|
||||
|
||||
def remove_layer(self, layer):
|
||||
"""Removes layer object if it exists"""
|
||||
@ -317,17 +347,18 @@ class NeuralNetwork(Group):
|
||||
|
||||
return animation_group
|
||||
|
||||
def make_forward_pass_animation(self, run_time=10, passing_flash=True):
|
||||
def make_forward_pass_animation(self, run_time=10, passing_flash=True,
|
||||
**kwargs):
|
||||
"""Generates an animation for feed forward propagation"""
|
||||
all_animations = []
|
||||
for layer_index, layer in enumerate(self.input_layers[:-1]):
|
||||
layer_forward_pass = layer.make_forward_pass_animation()
|
||||
layer_forward_pass = layer.make_forward_pass_animation(**kwargs)
|
||||
all_animations.append(layer_forward_pass)
|
||||
connective_layer = self.connective_layers[layer_index]
|
||||
connective_forward_pass = connective_layer.make_forward_pass_animation()
|
||||
connective_forward_pass = connective_layer.make_forward_pass_animation(**kwargs)
|
||||
all_animations.append(connective_forward_pass)
|
||||
# Do last layer animation
|
||||
last_layer_forward_pass = self.input_layers[-1].make_forward_pass_animation()
|
||||
last_layer_forward_pass = self.input_layers[-1].make_forward_pass_animation(**kwargs)
|
||||
all_animations.append(last_layer_forward_pass)
|
||||
# Make the animation group
|
||||
animation_group = AnimationGroup(*all_animations, run_time=run_time, lag_ratio=1.0)
|
||||
|
Reference in New Issue
Block a user