Changed examples

Alec Helbling
2022-02-10 00:45:40 -05:00
committed by Alec Helbling
parent fe7089abbf
commit 92d8a3d59a
6 changed files with 103 additions and 38 deletions

View File

@@ -1,6 +1,6 @@
 video:
 	manim -pqh src/vae.py VAEScene --media_dir media
-	cp media/videos/vae/1080p60/VAEScene.mp4 final_videos
+	cp media/videos/vae/720p60/VAEScene.mp4 examples
 train:
 	cd src/autoencoder_models
 	python vanilla_autoencoder.py

Binary file not shown.

BIN examples/VAEScene.mp4 (new file)

Binary file not shown.

View File

@@ -37,15 +37,14 @@ save_object["interpolation_path"] = interpolation_path
 for i in range(num_images):
     # Generate
     z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
-    gen_image = vae.decode(z)
+    gen_image = vae.decode(z).detach().numpy()
+    gen_image = np.reshape(gen_image, (28, 28)) * 255
     save_object["interpolation_images"].append(gen_image)

 fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
 image_pairs = []
 for i in range(num_images):
-    im = save_object["interpolation_images"][i]
-    im = im.detach().numpy()
-    recon_image = np.reshape(im, (28, 28)) * 255
+    recon_image = save_object["interpolation_images"][i]
     # Add to plot
     axs[i].imshow(recon_image)
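
For context, the updated loop is equivalent to this self-contained sketch; the two-layer decoder is a hypothetical stand-in for the project's trained VAE, which this diff does not show.

import torch
import torch.nn as nn
import numpy as np

# Hypothetical stand-in for the trained VAE decoder used by the script;
# it maps a 2-d latent vector to a flattened 28x28 image in [0, 1].
decoder = nn.Sequential(nn.Linear(2, 28 * 28), nn.Sigmoid())

# Latent path between two made-up points, as in the script above.
interpolation_path = np.linspace([-0.5, -1.0], [1.0, 1.3], num=10)

interpolation_images = []
for point in interpolation_path:
    z = torch.Tensor(point).unsqueeze(0)               # shape [1, 2]
    gen_image = decoder(z).detach().numpy()            # leave the autograd graph
    gen_image = np.reshape(gen_image, (28, 28)) * 255  # 28x28 grayscale in [0, 255]
    interpolation_images.append(gen_image)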

View File

@@ -5,11 +5,13 @@ and Traditional Autoencoders.
 """
 from configparser import Interpolation
+from random import sample
 from typing_extensions import runtime
 from manim import *
 import pickle
 import numpy as np
 import neural_network
+from scipy.interpolate import make_interp_spline

 class VariationalAutoencoder(Group):
     """Variational Autoencoder Manim Visualization"""
@@ -76,24 +78,27 @@ class VariationalAutoencoder(Group):
         embedding.add(self.point_dots)
         return embedding

+    def _construct_image_mobject(self, input_image, height=2):
+        """Constructs an ImageMobject from a numpy grayscale image"""
+        # Convert image to rgb
+        input_image = np.repeat(input_image, 3, axis=0)
+        input_image = np.rollaxis(input_image, 0, start=3)
+        # Make the ImageMobject
+        image_mobject = ImageMobject(input_image, image_mode="RGB")
+        image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
+        image_mobject.height = height
+        return image_mobject
+
     def _construct_input_output_images(self, image_pair):
         """Places the input and output images for the AE"""
         # image_pair is assumed to be [2, x, y]
         input_image = image_pair[0][None, :, :]
         recon_image = image_pair[1][None, :, :]
-        # Convert images to rgb
-        input_image = np.repeat(input_image, 3, axis=0)
-        input_image = np.rollaxis(input_image, 0, start=3)
-        recon_image = np.repeat(recon_image, 3, axis=0)
-        recon_image = np.rollaxis(recon_image, 0, start=3)
-        # Make the image objects
-        input_image_object = ImageMobject(input_image, image_mode="RGB")
-        input_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
-        input_image_object.height = 2
-        recon_image_object = ImageMobject(recon_image, image_mode="RGB")
-        recon_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
-        recon_image_object.height = 2
+        # Make the image mobjects
+        input_image_object = self._construct_image_mobject(input_image)
+        recon_image_object = self._construct_image_mobject(recon_image)
         return input_image_object, recon_image_object
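
The channel handling inside the new `_construct_image_mobject` is easiest to follow as shapes; a quick numpy sketch of what the `repeat`/`rollaxis` pair does:

import numpy as np

gray = np.random.rand(1, 28, 28)        # [1, H, W] grayscale input
rgb = np.repeat(gray, 3, axis=0)        # [3, H, W]: duplicate the single channel
rgb = np.rollaxis(rgb, 0, start=3)      # [H, W, 3]: channels-last, as ImageMobject expects
print(rgb.shape)                        # (28, 28, 3)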
@@ -129,6 +134,17 @@ class VariationalAutoencoder(Group):
         animation_group = AnimationGroup(*animations)
         return animation_group

+    def make_reset_vae_animation(self):
+        """Resets the VAE to just the neural network"""
+        animation_group = AnimationGroup(
+            FadeOut(self.input_image),
+            FadeOut(self.output_image),
+            FadeOut(self.distribution_objects),
+            run_time=1.0
+        )
+        return animation_group
+
     def make_forward_pass_animation(self, image_pair, run_time=1.5):
         """Overridden forward pass animation specific to a VAE"""
         per_unit_runtime = run_time
@@ -151,15 +167,22 @@ class VariationalAutoencoder(Group):
             dot_convergence_animation
         )
         # Make an ellipse centered at mean_point with std outline
-        center_dot = Dot(mean_point, radius=self.dot_radius, color=GREEN)
-        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.5, stroke_width=self.ellipse_stroke_width)
+        center_dot = Dot(mean_point, radius=self.dot_radius, color=RED)
+        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.3, stroke_width=self.ellipse_stroke_width)
         ellipse.move_to(mean_point)
+        self.distribution_objects = VGroup(
+            center_dot,
+            ellipse
+        )
+        # Make ellipse animation
         ellipse_animation = AnimationGroup(
             GrowFromCenter(center_dot),
             GrowFromCenter(ellipse),
         )
         # Make the dot divergence animation
-        dot_divergence_animation = self.make_dot_divergence_animation(mean_point, run_time=per_unit_runtime)
+        sampled_point = [0.51, 1.0]
+        divergence_point = self.embedding.axes.coords_to_point(*sampled_point)
+        dot_divergence_animation = self.make_dot_divergence_animation(divergence_point, run_time=per_unit_runtime)
         # Make decoder forward pass
         decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime)
         # Add the animations to the group
@@ -175,34 +198,77 @@ class VariationalAutoencoder(Group):
         return animation_group

+    def make_interpolation_animation(self, interpolation_images, frame_rate=5):
+        """Makes an interpolation animation"""
+        num_images = len(interpolation_images)
+        # Make a made-up latent path
+        interpolation_latent_path = np.linspace([-0.5, -1], [1, 1.3], num=num_images)
+        # Make the path animation
+        first_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[0])
+        moving_dot = Dot(first_dot_location, radius=self.dot_radius, color=RED)
+        animation_list = [GrowFromCenter(moving_dot)]
+        for image_index in range(num_images - 1):
+            next_index = image_index + 1
+            # Get the next point on the path
+            next_point = interpolation_latent_path[next_index]
+            next_position = self.embedding.axes.coords_to_point(*next_point)
+            # Move the dot from the current point to the next point
+            move_animation = moving_dot.animate.move_to(next_position)
+            animation_list.append(move_animation)
+        interpolation_animation = Succession(*animation_list, run_time=0.1*num_images)
+        # Make the images animation
+        animation_list = []
+        for numpy_image in interpolation_images:
+            numpy_image = numpy_image[None, :, :]
+            manim_image = self._construct_image_mobject(numpy_image)
+            # Move the image to the correct location
+            manim_image.move_to(self.output_image)
+            # Add the image
+            animation_list.append(FadeIn(manim_image, run_time=0.1))
+            # Wait
+            animation_list.append(Wait(1 / frame_rate))
+            # Remove the image
+            animation_list.append(FadeOut(manim_image, run_time=0.1))
+        images_animation = Succession(*animation_list)
+        # Combine the two into an AnimationGroup
+        animation_group = AnimationGroup(
+            interpolation_animation,
+            images_animation
+        )
+        return animation_group
+
 class MNISTImageHandler():
     """Deals with loading serialized VAE mnist images from "autoencoder_models" """

     def __init__(
         self,
-        image_pairs_path="src/autoencoder_models/image_pairs.pkl",
-        interpolations_path="src/autoencoder_models/interpolations.pkl"
+        image_pairs_file_path="src/autoencoder_models/image_pairs.pkl",
+        interpolations_file_path="src/autoencoder_models/interpolations.pkl"
     ):
-        self.image_pairs_path = image_pairs_path
-        self.interpolations_path = interpolations_path
+        self.image_pairs_file_path = image_pairs_file_path
+        self.interpolations_file_path = interpolations_file_path
         self.image_pairs = []
-        self.interpolations = []
+        self.interpolation_images = []
+        self.interpolation_latent_path = []
         self.load_serialized_data()

     def load_serialized_data(self):
-        with open(self.image_pairs_path, "rb") as f:
+        with open(self.image_pairs_file_path, "rb") as f:
             self.image_pairs = pickle.load(f)
-        with open(self.interpolations_path, "rb") as f:
-            self.interpolations_path = pickle.load(f)
+        with open(self.interpolations_file_path, "rb") as f:
+            self.interpolation_dict = pickle.load(f)
+        self.interpolation_images = self.interpolation_dict["interpolation_images"]
+        self.interpolation_latent_path = self.interpolation_dict["interpolation_path"]

 """
 The VAE Scene for the twitter video.
 """
 config.pixel_height = 720
 config.pixel_width = 1280
 config.frame_height = 10.0
@@ -217,17 +283,17 @@ class VAEScene(Scene):
         # Set Scene config
         vae = VariationalAutoencoder()
         mnist_image_handler = MNISTImageHandler()
-        image_pair = mnist_image_handler.image_pairs[2]
+        image_pair = mnist_image_handler.image_pairs[3]
         vae.move_to(ORIGIN)
         vae.scale(1.2)
         self.add(vae)
+        # Make a forward pass animation
         forward_pass_animation = vae.make_forward_pass_animation(image_pair)
         self.play(forward_pass_animation)
-        """
-        autoencoder = Autoencoder()
-        autoencoder.move_to(ORIGIN)
-        # Make a forward pass animation
-        self.add(autoencoder)
-        forward_pass_animation = autoencoder.make_forward_pass_animation(run_time=1.5)
-        self.play(forward_pass_animation)
-        """
+        # Remove the input and output images
+        reset_animation = vae.make_reset_vae_animation()
+        self.play(reset_animation)
+        # Interpolation animation
+        interpolation_images = mnist_image_handler.interpolation_images
+        interpolation_animation = vae.make_interpolation_animation(interpolation_images)
+        self.play(interpolation_animation)