diff --git a/Makefile b/Makefile
index 4715613..9220a8e 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 video:
 	manim -pqh src/vae.py VAEScene --media_dir media
-	cp media/videos/vae/1080p60/VAEScene.mp4 final_videos
+	cp media/videos/vae/720p60/VAEScene.mp4 examples
 train:
 	cd src/autoencoder_models
 	python vanilla_autoencoder.py
diff --git a/examples/TestNeuralNetworkScene.mp4 b/examples/TestNeuralNetworkScene.mp4
new file mode 100644
index 0000000..31bc870
Binary files /dev/null and b/examples/TestNeuralNetworkScene.mp4 differ
diff --git a/examples/VAEScene.mp4 b/examples/VAEScene.mp4
new file mode 100644
index 0000000..b4e0d74
Binary files /dev/null and b/examples/VAEScene.mp4 differ
diff --git a/src/autoencoder_models/generate_interpolation.py b/src/autoencoder_models/generate_interpolation.py
index 1752f9d..7fa4de1 100644
--- a/src/autoencoder_models/generate_interpolation.py
+++ b/src/autoencoder_models/generate_interpolation.py
@@ -37,15 +37,14 @@ save_object["interpolation_path"] = interpolation_path
 for i in range(num_images):
     # Generate
     z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
-    gen_image = vae.decode(z)
+    gen_image = vae.decode(z).detach().numpy()
+    gen_image = np.reshape(gen_image, (28, 28)) * 255
     save_object["interpolation_images"].append(gen_image)

 fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
 image_pairs = []
 for i in range(num_images):
-    im = save_object["interpolation_images"][i]
-    im = im.detach().numpy()
-    recon_image = np.reshape(im, (28, 28)) * 255
+    recon_image = save_object["interpolation_images"][i]

     # Add to plot
     axs[i].imshow(recon_image)
diff --git a/src/autoencoder_models/interpolations.pkl b/src/autoencoder_models/interpolations.pkl
index 105d856..a8a5d2b 100644
Binary files a/src/autoencoder_models/interpolations.pkl and b/src/autoencoder_models/interpolations.pkl differ
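Since the detach and reshape now happen at generation time, the arrays pickled into interpolations.pkl are plain (28, 28) numpy images scaled to [0, 255] rather than torch tensors. A minimal sketch of reading the file back, assuming only the two keys visible in this script (interpolation_images and interpolation_path):

import pickle
import numpy as np

# Sketch: consume what generate_interpolation.py now serializes. Images
# arrive already detached, reshaped to (28, 28), and scaled to [0, 255],
# so the consumer side no longer needs torch.
with open("src/autoencoder_models/interpolations.pkl", "rb") as f:
    save_object = pickle.load(f)

for image in save_object["interpolation_images"]:
    assert np.asarray(image).shape == (28, 28)
latent_path = save_object["interpolation_path"]  # one latent point per image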
""" from configparser import Interpolation +from random import sample from typing_extensions import runtime from manim import * import pickle import numpy as np import neural_network +from scipy.interpolate import make_interp_spline class VariationalAutoencoder(Group): """Variational Autoencoder Manim Visualization""" @@ -76,24 +78,27 @@ class VariationalAutoencoder(Group): embedding.add(self.point_dots) return embedding + def _construct_image_mobject(self, input_image, height=2): + """Constructs an ImageMobject from a numpy grayscale image""" + # Convert image to rgb + input_image = np.repeat(input_image, 3, axis=0) + input_image = np.rollaxis(input_image, 0, start=3) + # Make the ImageMobject + image_mobject = ImageMobject(input_image, image_mode="RGB") + image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"]) + image_mobject.height = height + + return image_mobject + def _construct_input_output_images(self, image_pair): """Places the input and output images for the AE""" # Takes the image pair # image_pair is assumed to be [2, x, y] input_image = image_pair[0][None, :, :] recon_image = image_pair[1][None, :, :] - # Convert images to rgb - input_image = np.repeat(input_image, 3, axis=0) - input_image = np.rollaxis(input_image, 0, start=3) - recon_image = np.repeat(recon_image, 3, axis=0) - recon_image = np.rollaxis(recon_image, 0, start=3) - # Make an image objects - input_image_object = ImageMobject(input_image, image_mode="RGB") - input_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"]) - input_image_object.height = 2 - recon_image_object = ImageMobject(recon_image, image_mode="RGB") - recon_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"]) - recon_image_object.height = 2 + # Make the image mobjects + input_image_object = self._construct_image_mobject(input_image) + recon_image_object = self._construct_image_mobject(recon_image) return input_image_object, recon_image_object @@ -129,6 +134,17 @@ class VariationalAutoencoder(Group): animation_group = AnimationGroup(*animations) return animation_group + def make_reset_vae_animation(self): + """Resets the VAE to just the neural network""" + animation_group = AnimationGroup( + FadeOut(self.input_image), + FadeOut(self.output_image), + FadeOut(self.distribution_objects), + run_time=1.0 + ) + + return animation_group + def make_forward_pass_animation(self, image_pair, run_time=1.5): """Overriden forward pass animation specific to a VAE""" per_unit_runtime = run_time @@ -151,15 +167,22 @@ class VariationalAutoencoder(Group): dot_convergence_animation ) # Make an ellipse centered at mean_point witAnimationGraph std outline - center_dot = Dot(mean_point, radius=self.dot_radius, color=GREEN) - ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.5, stroke_width=self.ellipse_stroke_width) + center_dot = Dot(mean_point, radius=self.dot_radius, color=RED) + ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.3, stroke_width=self.ellipse_stroke_width) ellipse.move_to(mean_point) + self.distribution_objects = VGroup( + center_dot, + ellipse + ) + # Make ellipse animation ellipse_animation = AnimationGroup( GrowFromCenter(center_dot), GrowFromCenter(ellipse), ) # Make the dot divergence animation - dot_divergence_animation = self.make_dot_divergence_animation(mean_point, run_time=per_unit_runtime) + sampled_point = [0.51, 1.0] + divergence_point = self.embedding.axes.coords_to_point(*sampled_point) + dot_divergence_animation = 
@@ -175,34 +198,77 @@
         return animation_group

+    def make_interpolation_animation(self, interpolation_images, frame_rate=5):
+        """Makes an animation interpolation"""
+        num_images = len(interpolation_images)
+        # Make madeup path
+        interpolation_latent_path = np.linspace([-0.5, -1], [1, 1.3], num=num_images)
+        # Make the path animation
+        first_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[0])
+        moving_dot = Dot(first_dot_location, radius=self.dot_radius, color=RED)
+        animation_list = [GrowFromCenter(moving_dot)]
+        for image_index in range(num_images - 1):
+            next_index = image_index + 1
+            # Get path
+            next_point = interpolation_latent_path[next_index]
+            next_position = self.embedding.axes.coords_to_point(*next_point)
+            # Draw path from current point to next point
+            move_animation = moving_dot.animate.move_to(next_position)
+            animation_list.append(move_animation)
+
+        interpolation_animation = Succession(*animation_list, run_time=0.1*num_images)
+        # Make the images animation
+        animation_list = []
+        for numpy_image in interpolation_images:
+            numpy_image = numpy_image[None, :, :]
+            manim_image = self._construct_image_mobject(numpy_image)
+            # Move the image to the correct location
+            manim_image.move_to(self.output_image)
+            # Add the image
+            animation_list.append(FadeIn(manim_image, run_time=0.1))
+            # Wait
+            animation_list.append(Wait(1 / frame_rate))
+            # Remove the image
+            animation_list.append(FadeOut(manim_image, run_time=0.1))
+        images_animation = Succession(*animation_list)
+        # Combine the two into an AnimationGroup
+        animation_group = AnimationGroup(
+            interpolation_animation,
+            images_animation
+        )
+
+        return animation_group

 class MNISTImageHandler():
     """Deals with loading serialized VAE mnist images from "autoencoder_models" """

     def __init__(
         self,
-        image_pairs_path="src/autoencoder_models/image_pairs.pkl",
-        interpolations_path="src/autoencoder_models/interpolations.pkl"
+        image_pairs_file_path="src/autoencoder_models/image_pairs.pkl",
+        interpolations_file_path="src/autoencoder_models/interpolations.pkl"
     ):
-        self.image_pairs_path = image_pairs_path
-        self.interpolations_path = interpolations_path
+        self.image_pairs_file_path = image_pairs_file_path
+        self.interpolations_file_path = interpolations_file_path
         self.image_pairs = []
-        self.interpolations = []
+        self.interpolation_images = []
+        self.interpolation_latent_path = []

         self.load_serialized_data()

     def load_serialized_data(self):
-        with open(self.image_pairs_path, "rb") as f:
+        with open(self.image_pairs_file_path, "rb") as f:
             self.image_pairs = pickle.load(f)
-        with open(self.interpolations_path, "rb") as f:
-            self.interpolations_path = pickle.load(f)
+        with open(self.interpolations_file_path, "rb") as f:
+            self.interpolation_dict = pickle.load(f)
+        self.interpolation_images = self.interpolation_dict["interpolation_images"]
+        self.interpolation_latent_path = self.interpolation_dict["interpolation_path"]

 """
 The VAE Scene for the twitter video.
 """
-
 config.pixel_height = 720
 config.pixel_width = 1280
 config.frame_height = 10.0
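make_interpolation_animation above walks a straight np.linspace path through latent space, while the newly added (and currently unused) make_interp_spline import suggests a smoothed path was intended. A self-contained sketch of that idea, with hypothetical waypoints:

import numpy as np
from scipy.interpolate import make_interp_spline

# Sketch: a smooth curve through latent waypoints in place of the
# straight np.linspace path. The waypoints below are hypothetical;
# the endpoints match the madeup path in the method above.
num_images = 10
waypoints = np.array([[-0.5, -1.0], [0.2, 0.4], [1.0, 1.3]])
t = np.linspace(0, 1, len(waypoints))
spline = make_interp_spline(t, waypoints, k=2)  # k must be < number of waypoints
interpolation_latent_path = spline(np.linspace(0, 1, num_images))  # shape (num_images, 2)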
""" - config.pixel_height = 720 config.pixel_width = 1280 config.frame_height = 10.0 @@ -217,17 +283,17 @@ class VAEScene(Scene): # Set Scene config vae = VariationalAutoencoder() mnist_image_handler = MNISTImageHandler() - image_pair = mnist_image_handler.image_pairs[2] + image_pair = mnist_image_handler.image_pairs[3] vae.move_to(ORIGIN) vae.scale(1.2) self.add(vae) + # Make a forward pass animation forward_pass_animation = vae.make_forward_pass_animation(image_pair) self.play(forward_pass_animation) - """ - autoencoder = Autoencoder() - autoencoder.move_to(ORIGIN) - # Make a forward pass animation - self.add(autoencoder) - forward_pass_animation = autoencoder.make_forward_pass_animation(run_time=1.5) - self.play(forward_pass_animation) - """ \ No newline at end of file + # Remove the input and output images + reset_animation = vae.make_reset_vae_animation() + self.play(reset_animation) + # Interpolation animation + interpolation_images = mnist_image_handler.interpolation_images + interpolation_animation = vae.make_interpolation_animation(interpolation_images) + self.play(interpolation_animation)