Used Black to reformat the code in the repository.

This commit is contained in:
Alec Helbling
2023-01-01 23:24:59 -05:00
parent 334662e8c8
commit 3d6e8072e1
71 changed files with 1701 additions and 1135 deletions

View File

@ -10,12 +10,19 @@ from PIL import Image
from manim_ml.neural_network.layers import FeedForwardLayer, EmbeddingLayer, ImageLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
class VariationalAutoencoder(VGroup):
"""Variational Autoencoder Manim Visualization"""
def __init__(self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5],
point_color=BLUE, dot_radius=0.05, ellipse_stroke_width=1.0,
layer_spacing=0.5):
def __init__(
self,
encoder_nodes_per_layer=[5, 3],
decoder_nodes_per_layer=[3, 5],
point_color=BLUE,
dot_radius=0.05,
ellipse_stroke_width=1.0,
layer_spacing=0.5,
):
super(VGroup, self).__init__()
self.encoder_nodes_per_layer = encoder_nodes_per_layer
self.decoder_nodes_per_layer = decoder_nodes_per_layer
@ -30,13 +37,15 @@ class VariationalAutoencoder(VGroup):
"""Makes the VAE encoder, embedding layer, and decoder"""
embedding_layer = EmbeddingLayer()
neural_network = NeuralNetwork([
FeedForwardLayer(5),
FeedForwardLayer(3),
embedding_layer,
FeedForwardLayer(3),
FeedForwardLayer(5)
])
neural_network = NeuralNetwork(
[
FeedForwardLayer(5),
FeedForwardLayer(3),
embedding_layer,
FeedForwardLayer(3),
FeedForwardLayer(5),
]
)
return neural_network, embedding_layer
@ -46,20 +55,18 @@ class VariationalAutoencoder(VGroup):
def make_triplet_forward_pass(self, triplet):
pass
def make_image_forward_pass(self, input_image, output_image, run_time=1.5):
"""Override forward pass animation specific to a VAE"""
# Make a wrapper NN with images
wrapper_neural_network = NeuralNetwork([
ImageLayer(input_image),
self.neural_network,
ImageLayer(output_image)
])
# Make animation
wrapper_neural_network = NeuralNetwork(
[ImageLayer(input_image), self.neural_network, ImageLayer(output_image)]
)
# Make animation
animation_group = AnimationGroup(
Create(wrapper_neural_network),
wrapper_neural_network.make_forward_pass_animation(),
lag_ratio=1.0
lag_ratio=1.0,
)
return animation_group
return animation_group