mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-22 13:06:46 +08:00
Used Black to reformat the code in the repository.
This commit is contained in:
@ -10,12 +10,19 @@ from PIL import Image
|
||||
from manim_ml.neural_network.layers import FeedForwardLayer, EmbeddingLayer, ImageLayer
|
||||
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||
|
||||
|
||||
class VariationalAutoencoder(VGroup):
|
||||
"""Variational Autoencoder Manim Visualization"""
|
||||
|
||||
def __init__(self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5],
|
||||
point_color=BLUE, dot_radius=0.05, ellipse_stroke_width=1.0,
|
||||
layer_spacing=0.5):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_nodes_per_layer=[5, 3],
|
||||
decoder_nodes_per_layer=[3, 5],
|
||||
point_color=BLUE,
|
||||
dot_radius=0.05,
|
||||
ellipse_stroke_width=1.0,
|
||||
layer_spacing=0.5,
|
||||
):
|
||||
super(VGroup, self).__init__()
|
||||
self.encoder_nodes_per_layer = encoder_nodes_per_layer
|
||||
self.decoder_nodes_per_layer = decoder_nodes_per_layer
|
||||
@ -30,13 +37,15 @@ class VariationalAutoencoder(VGroup):
|
||||
"""Makes the VAE encoder, embedding layer, and decoder"""
|
||||
embedding_layer = EmbeddingLayer()
|
||||
|
||||
neural_network = NeuralNetwork([
|
||||
FeedForwardLayer(5),
|
||||
FeedForwardLayer(3),
|
||||
embedding_layer,
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(5)
|
||||
])
|
||||
neural_network = NeuralNetwork(
|
||||
[
|
||||
FeedForwardLayer(5),
|
||||
FeedForwardLayer(3),
|
||||
embedding_layer,
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(5),
|
||||
]
|
||||
)
|
||||
|
||||
return neural_network, embedding_layer
|
||||
|
||||
@ -46,20 +55,18 @@ class VariationalAutoencoder(VGroup):
|
||||
|
||||
def make_triplet_forward_pass(self, triplet):
|
||||
pass
|
||||
|
||||
|
||||
def make_image_forward_pass(self, input_image, output_image, run_time=1.5):
|
||||
"""Override forward pass animation specific to a VAE"""
|
||||
# Make a wrapper NN with images
|
||||
wrapper_neural_network = NeuralNetwork([
|
||||
ImageLayer(input_image),
|
||||
self.neural_network,
|
||||
ImageLayer(output_image)
|
||||
])
|
||||
# Make animation
|
||||
wrapper_neural_network = NeuralNetwork(
|
||||
[ImageLayer(input_image), self.neural_network, ImageLayer(output_image)]
|
||||
)
|
||||
# Make animation
|
||||
animation_group = AnimationGroup(
|
||||
Create(wrapper_neural_network),
|
||||
wrapper_neural_network.make_forward_pass_animation(),
|
||||
lag_ratio=1.0
|
||||
lag_ratio=1.0,
|
||||
)
|
||||
|
||||
return animation_group
|
||||
return animation_group
|
||||
|
Reference in New Issue
Block a user