Added GAN visualization.

This commit is contained in:
Alec Helbling
2022-04-22 19:08:28 -04:00
parent ffd31701bf
commit 0152be64b0
24 changed files with 530 additions and 154 deletions

View File

@ -8,7 +8,12 @@ from manim import *
import pickle
import numpy as np
import os
from PIL import Image
import manim_ml.neural_network as neural_network
from manim_ml.neural_network.embedding import EmbeddingLayer
from manim_ml.neural_network.feed_forward import FeedForwardLayer
from manim_ml.neural_network.image import ImageLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
class VariationalAutoencoder(VGroup):
"""Variational Autoencoder Manim Visualization"""
@ -239,6 +244,29 @@ class VariationalAutoencoder(VGroup):
return animation_group
class VariationalAutoencoder(VGroup):
def __init__(self):
embedding_layer = EmbeddingLayer()
image = Image.open('images/image.jpeg')
numpy_image = np.asarray(image)
# Make nn
neural_network = NeuralNetwork([
ImageLayer(numpy_image, height=1.4),
FeedForwardLayer(5),
FeedForwardLayer(3),
embedding_layer,
FeedForwardLayer(3),
FeedForwardLayer(5),
ImageLayer(numpy_image, height=1.4),
])
neural_network.scale(1.3)
self.play(Create(neural_network))
self.play(neural_network.make_forward_pass_animation(run_time=15))
class MNISTImageHandler():
"""Deals with loading serialized VAE mnist images from "autoencoder_models" """
@ -295,19 +323,4 @@ class VAEScene(Scene):
# Interpolation animation
interpolation_images = mnist_image_handler.interpolation_images
interpolation_animation = vae.make_interpolation_animation(interpolation_images)
self.play(interpolation_animation)
class VAEImage(Scene):
def construct(self):
# Set Scene config
vae = VariationalAutoencoder()
mnist_image_handler = MNISTImageHandler()
image_pair = mnist_image_handler.image_pairs[3]
vae.move_to(ORIGIN)
vae.scale(1.3)
self.play(Create(vae), run_time=3)
# Make a forward pass animation
forward_pass_animation = vae.make_forward_pass_animation(image_pair)
self.play(forward_pass_animation)
self.play(interpolation_animation)