mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-06-08 21:08:05 +08:00
Added GAN visualization.
This commit is contained in:
@ -8,7 +8,12 @@ from manim import *
|
||||
import pickle
|
||||
import numpy as np
|
||||
import os
|
||||
from PIL import Image
|
||||
import manim_ml.neural_network as neural_network
|
||||
from manim_ml.neural_network.embedding import EmbeddingLayer
|
||||
from manim_ml.neural_network.feed_forward import FeedForwardLayer
|
||||
from manim_ml.neural_network.image import ImageLayer
|
||||
from manim_ml.neural_network.neural_network import NeuralNetwork
|
||||
|
||||
class VariationalAutoencoder(VGroup):
|
||||
"""Variational Autoencoder Manim Visualization"""
|
||||
@ -239,6 +244,29 @@ class VariationalAutoencoder(VGroup):
|
||||
|
||||
return animation_group
|
||||
|
||||
class VariationalAutoencoder(VGroup):
|
||||
|
||||
def __init__(self):
|
||||
embedding_layer = EmbeddingLayer()
|
||||
|
||||
image = Image.open('images/image.jpeg')
|
||||
numpy_image = np.asarray(image)
|
||||
# Make nn
|
||||
neural_network = NeuralNetwork([
|
||||
ImageLayer(numpy_image, height=1.4),
|
||||
FeedForwardLayer(5),
|
||||
FeedForwardLayer(3),
|
||||
embedding_layer,
|
||||
FeedForwardLayer(3),
|
||||
FeedForwardLayer(5),
|
||||
ImageLayer(numpy_image, height=1.4),
|
||||
])
|
||||
|
||||
neural_network.scale(1.3)
|
||||
|
||||
self.play(Create(neural_network))
|
||||
self.play(neural_network.make_forward_pass_animation(run_time=15))
|
||||
|
||||
class MNISTImageHandler():
|
||||
"""Deals with loading serialized VAE mnist images from "autoencoder_models" """
|
||||
|
||||
@ -295,19 +323,4 @@ class VAEScene(Scene):
|
||||
# Interpolation animation
|
||||
interpolation_images = mnist_image_handler.interpolation_images
|
||||
interpolation_animation = vae.make_interpolation_animation(interpolation_images)
|
||||
self.play(interpolation_animation)
|
||||
|
||||
class VAEImage(Scene):
|
||||
|
||||
def construct(self):
|
||||
# Set Scene config
|
||||
vae = VariationalAutoencoder()
|
||||
mnist_image_handler = MNISTImageHandler()
|
||||
image_pair = mnist_image_handler.image_pairs[3]
|
||||
vae.move_to(ORIGIN)
|
||||
vae.scale(1.3)
|
||||
self.play(Create(vae), run_time=3)
|
||||
# Make a forward pass animation
|
||||
forward_pass_animation = vae.make_forward_pass_animation(image_pair)
|
||||
self.play(forward_pass_animation)
|
||||
|
||||
self.play(interpolation_animation)
|
Reference in New Issue
Block a user