diff --git a/Makefile b/Makefile index e70d25d..4715613 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ video: manim -pqh src/vae.py VAEScene --media_dir media - cp media/videos/vae/720p60/VAEScene.mp4 final_videos + cp media/videos/vae/1080p60/VAEScene.mp4 final_videos train: cd src/autoencoder_models python vanilla_autoencoder.py diff --git a/final_videos/VAEScene.mp4 b/final_videos/VAEScene.mp4 index 0916659..6c77a4c 100644 Binary files a/final_videos/VAEScene.mp4 and b/final_videos/VAEScene.mp4 differ diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/autoencoder_models/__init__.py b/src/autoencoder_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/autoencoder_models/generate_images.py b/src/autoencoder_models/generate_images.py new file mode 100644 index 0000000..876eec7 --- /dev/null +++ b/src/autoencoder_models/generate_images.py @@ -0,0 +1,41 @@ +import torch +from variational_autoencoder import VAE +import matplotlib.pyplot as plt +from torchvision import datasets +from torchvision import transforms +from tqdm import tqdm +import numpy as np +import pickle + +# Load model +vae = VAE(latent_dim=16) +vae.load_state_dict(torch.load("saved_models/model.pth")) +# Transforms images to a PyTorch Tensor +tensor_transform = transforms.ToTensor() +# Download the MNIST Dataset +dataset = datasets.MNIST(root = "./data", + train = True, + download = True, + transform = tensor_transform) +# Generate reconstructions +num_recons = 10 +fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons)) +image_pairs = [] +for i in range(num_recons): + base_image, _ = dataset[i] + base_image = base_image.reshape(-1, 28*28) + _, _, recon_image, _ = vae.forward(base_image) + base_image = base_image.detach().numpy() + base_image = np.reshape(base_image, (28, 28)) * 255 + recon_image = recon_image.detach().numpy() + recon_image = np.reshape(recon_image, (28, 28)) * 255 + # Add to plot + axs[i][0].imshow(base_image) + axs[i][1].imshow(recon_image) + # image pairs + image_pairs.append((base_image, recon_image)) + +with open("image_pairs.pkl", "wb") as f: + pickle.dump(image_pairs, f) + +plt.show() diff --git a/src/autoencoder_models/generate_interpolation.py b/src/autoencoder_models/generate_interpolation.py new file mode 100644 index 0000000..1752f9d --- /dev/null +++ b/src/autoencoder_models/generate_interpolation.py @@ -0,0 +1,56 @@ +import torch +from variational_autoencoder import VAE +import matplotlib.pyplot as plt +from torchvision import datasets +from torchvision import transforms +from tqdm import tqdm +import numpy as np +import pickle + +# Load model +vae = VAE(latent_dim=16) +vae.load_state_dict(torch.load("saved_models/model.pth")) +# Transforms images to a PyTorch Tensor +tensor_transform = transforms.ToTensor() +# Download the MNIST Dataset +dataset = datasets.MNIST(root = "./data", + train = True, + download = True, + transform = tensor_transform) +# Generate reconstructions +num_images = 50 +image_pairs = [] +save_object = {"interpolation_path":[], "interpolation_images":[]} + +# Make interpolation path +image_a, image_b = dataset[0][0], dataset[1][0] +image_a = image_a.view(28*28) +image_b = image_b.view(28*28) +z_a, _, _, _ = vae.forward(image_a) +z_a = z_a.detach().cpu().numpy() +z_b, _, _, _ = vae.forward(image_b) +z_b = z_b.detach().cpu().numpy() +interpolation_path = np.linspace(z_a, z_b, num=num_images) +# interpolation_path[:, 4] = np.linspace(-3, 3, num=num_images) 
+save_object["interpolation_path"] = interpolation_path + +for i in range(num_images): + # Generate + z = torch.Tensor(interpolation_path[i]).unsqueeze(0) + gen_image = vae.decode(z) + save_object["interpolation_images"].append(gen_image) + +fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images)) +image_pairs = [] +for i in range(num_images): + im = save_object["interpolation_images"][i] + im = im.detach().numpy() + recon_image = np.reshape(im, (28, 28)) * 255 + # Add to plot + axs[i].imshow(recon_image) + +# Perform intrpolations +with open("interpolations.pkl", "wb") as f: + pickle.dump(save_object, f) + +plt.show() \ No newline at end of file diff --git a/src/autoencoder_models/image_pairs.pkl b/src/autoencoder_models/image_pairs.pkl new file mode 100644 index 0000000..7a9325c Binary files /dev/null and b/src/autoencoder_models/image_pairs.pkl differ diff --git a/src/autoencoder_models/interpolations.pkl b/src/autoencoder_models/interpolations.pkl new file mode 100644 index 0000000..105d856 Binary files /dev/null and b/src/autoencoder_models/interpolations.pkl differ diff --git a/src/autoencoder_models/saved_models/model.pth b/src/autoencoder_models/saved_models/model.pth new file mode 100644 index 0000000..a5d215c Binary files /dev/null and b/src/autoencoder_models/saved_models/model.pth differ diff --git a/src/autoencoder_models/vanilla_autoencoder.py b/src/autoencoder_models/vanilla_autoencoder.py deleted file mode 100644 index 892f823..0000000 --- a/src/autoencoder_models/vanilla_autoencoder.py +++ /dev/null @@ -1,112 +0,0 @@ -import torch -from torchvision import datasets -from torchvision import transforms -import matplotlib.pyplot as plt -from tqdm import tqdm - -# Transforms images to a PyTorch Tensor -tensor_transform = transforms.ToTensor() - -# Download the MNIST Dataset -dataset = datasets.MNIST(root = "./data", - train = True, - download = True, - transform = tensor_transform) - -# DataLoader is used to load the dataset -# for training -loader = torch.utils.data.DataLoader(dataset = dataset, - batch_size = 32, - shuffle = True) - - # Creating a PyTorch class -# 28*28 ==> 9 ==> 28*28 -class VAE(torch.nn.Module): - def __init__(self): - super().__init__() - # Building an linear encoder with Linear - # layer followed by Relu activation function - # 784 ==> 9 - self.encoder = torch.nn.Sequential( - torch.nn.Linear(28 * 28, 128), - torch.nn.ReLU(), - torch.nn.Linear(128, 64), - torch.nn.ReLU(), - torch.nn.Linear(64, 36), - torch.nn.ReLU(), - torch.nn.Linear(36, 18), - torch.nn.ReLU(), - ) - self.mean_embedding = torch.nn.Linear(18, 9) - self.logvar_embedding = torch.nn.Linear(18, 9) - - # Building an linear decoder with Linear - # layer followed by Relu activation function - # The Sigmoid activation function - # outputs the value between 0 and 1 - # 9 ==> 784 - self.decoder = torch.nn.Sequential( - torch.nn.Linear(9, 18), - torch.nn.ReLU(), - torch.nn.Linear(18, 36), - torch.nn.ReLU(), - torch.nn.Linear(36, 64), - torch.nn.ReLU(), - torch.nn.Linear(64, 128), - torch.nn.ReLU(), - torch.nn.Linear(128, 28 * 28), - torch.nn.Sigmoid() - ) - - def forward(self, x): - encoded = self.encoder(x) - mean = self.mean_embedding(encoded) - logvar = self.logvar_embedding(encoded) - combined = torch.cat((mean, logvar), dim=1) - reconstructed = self.decoder(combined) - return mean, logvar, reconstructed, x - -# Model Initialization -model = VAE() -# Validation using MSE Loss function -def loss_function(mean, log_var, reconstructed, original): - kl = torch.mean(-0.5 * torch.sum(1 + 
log_var - mean ** 2 - log_var.exp(), dim = 1), dim = 0) - recon = torch.nn.functional.mse_loss(reconstructed, original) - - return kl + recon - -# Using an Adam Optimizer with lr = 0.1 -optimizer = torch.optim.Adam(model.parameters(), - lr = 1e-1, - weight_decay = 1e-8) - -epochs = 10 -outputs = [] -losses = [] -for epoch in tqdm(range(epochs)): - for (image, _) in loader: - # Reshaping the image to (-1, 784) - image = image.reshape(-1, 28*28) - # Output of Autoencoder - mean, log_var, reconstructed, image = model(image) - # Calculating the loss function - loss = loss_function(mean, log_var, reconstructed, image) - # The gradients are set to zero, - # the the gradient is computed and stored. - # .step() performs parameter update - optimizer.zero_grad() - loss.backward() - optimizer.step() - # Storing the losses in a list for plotting - losses.append(loss.detach().cpu()) - outputs.append((epochs, image, reconstructed)) - -# Defining the Plot Style -plt.style.use('fivethirtyeight') -plt.xlabel('Iterations') -plt.ylabel('Loss') - -# Plotting the last 100 values -print(losses) -plt.plot(losses[-100:]) -plt.show() \ No newline at end of file diff --git a/src/autoencoder_models/variational_autoencoder.py b/src/autoencoder_models/variational_autoencoder.py index 338a64e..cf9d46a 100644 --- a/src/autoencoder_models/variational_autoencoder.py +++ b/src/autoencoder_models/variational_autoencoder.py @@ -18,13 +18,12 @@ dataset = datasets.MNIST(root = "./data", loader = torch.utils.data.DataLoader(dataset = dataset, batch_size = 32, shuffle = True) - - # Creating a PyTorch class +# Creating a PyTorch class # 28*28 ==> 9 ==> 28*28 -class AE(torch.nn.Module): - def __init__(self): +class VAE(torch.nn.Module): + def __init__(self, latent_dim=5): super().__init__() - + self.latent_dim = latent_dim # Building an linear encoder with Linear # layer followed by Relu activation function # 784 ==> 9 @@ -37,16 +36,17 @@ class AE(torch.nn.Module): torch.nn.ReLU(), torch.nn.Linear(36, 18), torch.nn.ReLU(), - torch.nn.Linear(18, 9) ) - + self.mean_embedding = torch.nn.Linear(18, self.latent_dim) + self.logvar_embedding = torch.nn.Linear(18, self.latent_dim) + # Building an linear decoder with Linear # layer followed by Relu activation function # The Sigmoid activation function # outputs the value between 0 and 1 # 9 ==> 784 self.decoder = torch.nn.Sequential( - torch.nn.Linear(9, 18), + torch.nn.Linear(self.latent_dim, 18), torch.nn.ReLU(), torch.nn.Linear(18, 36), torch.nn.ReLU(), @@ -57,46 +57,68 @@ class AE(torch.nn.Module): torch.nn.Linear(128, 28 * 28), torch.nn.Sigmoid() ) + + def decode(self, z): + return self.decoder(z) def forward(self, x): encoded = self.encoder(x) - decoded = self.decoder(encoded) - return decoded + mean = self.mean_embedding(encoded) + logvar = self.logvar_embedding(encoded) + batch_size = x.shape[0] + eps = torch.randn(batch_size, self.latent_dim) + z = mean + torch.exp(logvar / 2) * eps + reconstructed = self.decoder(z) + return mean, logvar, reconstructed, x -# Model Initialization -model = AE() -# Validation using MSE Loss function -loss_function = torch.nn.MSELoss() -# Using an Adam Optimizer with lr = 0.1 -optimizer = torch.optim.Adam(model.parameters(), - lr = 1e-1, - weight_decay = 1e-8) +def train_model(): + # Model Initialization + model = VAE(latent_dim=16) + # Validation using MSE Loss function + def loss_function(mean, log_var, reconstructed, original, kl_beta=0.001): + kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim = 1), dim = 0) + recon = 
torch.nn.functional.mse_loss(reconstructed, original) + # print(f"KL Error {kl}, Recon Error {recon}") + return kl_beta * kl + recon -epochs = 10 -outputs = [] -losses = [] -for epoch in tqdm(range(epochs)): - for (image, _) in loader: - # Reshaping the image to (-1, 784) - image = image.reshape(-1, 28*28) - # Output of Autoencoder - reconstructed = model(image) - # Calculating the loss function - loss = loss_function(reconstructed, image) - # The gradients are set to zero, - # the the gradient is computed and stored. - # .step() performs parameter update - optimizer.zero_grad() - loss.backward() - optimizer.step() - # Storing the losses in a list for plotting - losses.append(loss.detach().cpu()) - outputs.append((epochs, image, reconstructed)) - -# Defining the Plot Style -plt.style.use('fivethirtyeight') -plt.xlabel('Iterations') -plt.ylabel('Loss') - -# Plotting the last 100 values -plt.plot(losses[-100:]) \ No newline at end of file + # Using an Adam Optimizer with lr = 0.1 + optimizer = torch.optim.Adam(model.parameters(), + lr = 1e-3, + weight_decay = 1e-8) + + epochs = 100 + outputs = [] + losses = [] + for epoch in tqdm(range(epochs)): + for (image, _) in loader: + # Reshaping the image to (-1, 784) + image = image.reshape(-1, 28*28) + # Output of Autoencoder + mean, log_var, reconstructed, image = model(image) + # Calculating the loss function + loss = loss_function(mean, log_var, reconstructed, image) + # The gradients are set to zero, + # the the gradient is computed and stored. + # .step() performs parameter update + optimizer.zero_grad() + loss.backward() + optimizer.step() + # Storing the losses in a list for plotting + if torch.isnan(loss): + raise Exception() + losses.append(loss.detach().cpu()) + outputs.append((epochs, image, reconstructed)) + + torch.save(model.state_dict(), "saved_models/model.pth") + + # Defining the Plot Style + plt.style.use('fivethirtyeight') + plt.xlabel('Iterations') + plt.ylabel('Loss') + + # Plotting the last 100 values + plt.plot(losses) + plt.show() + +if __name__ == "__main__": + train_model() \ No newline at end of file diff --git a/src/neural_network.py b/src/neural_network.py index 4103c42..e90df75 100644 --- a/src/neural_network.py +++ b/src/neural_network.py @@ -15,15 +15,18 @@ class NeuralNetworkLayer(VGroup): """Handles rendering a layer for a neural network""" def __init__( - self, num_nodes, layer_width=0.2, node_radius=0.12, + self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08, node_color=BLUE, node_outline_color=WHITE, rectangle_color=WHITE, - node_spacing=0.4, rectangle_fill_color=BLACK): + node_spacing=0.3, rectangle_fill_color=BLACK, node_stroke_width=2.0, + rectangle_stroke_width=2.0): super(VGroup, self).__init__() self.num_nodes = num_nodes - self.layer_width = layer_width + self.layer_buffer = layer_buffer self.node_radius = node_radius self.node_color = node_color + self.node_stroke_width = node_stroke_width self.node_outline_color = node_outline_color + self.rectangle_stroke_width = rectangle_stroke_width self.rectangle_color = rectangle_color self.node_spacing = node_spacing self.rectangle_fill_color = rectangle_fill_color @@ -36,7 +39,7 @@ class NeuralNetworkLayer(VGroup): """Creates the neural network layer""" # Add Nodes for node_number in range(self.num_nodes): - node_object = Circle(radius=self.node_radius, color=self.node_color) + node_object = Circle(radius=self.node_radius, color=self.node_color, stroke_width=self.node_stroke_width) self.node_group.add(node_object) # Space the nodes # Assumes Vertical 
orientation @@ -45,24 +48,28 @@ class NeuralNetworkLayer(VGroup): node_object.move_to([0, location, 0]) # Create Surrounding Rectangle surrounding_rectangle = SurroundingRectangle( - self.node_group, color=self.rectangle_color, fill_color=self.rectangle_fill_color, fill_opacity=1.0) + self.node_group, color=self.rectangle_color, fill_color=self.rectangle_fill_color, + fill_opacity=1.0, buff=self.layer_buffer, stroke_width=self.rectangle_stroke_width + ) # Add the objects to the class self.add(surrounding_rectangle, self.node_group) class NeuralNetwork(VGroup): def __init__( - self, layer_node_count, layer_width=1.0, node_radius=1.0, - node_color=BLUE, edge_color=WHITE, layer_spacing=1.2, - animation_dot_color=RED): + self, layer_node_count, layer_width=0.6, node_radius=1.0, + node_color=BLUE, edge_color=WHITE, layer_spacing=0.8, + animation_dot_color=RED, edge_width=2.0, dot_radius=0.05): super(VGroup, self).__init__() self.layer_node_count = layer_node_count self.layer_width = layer_width self.node_radius = node_radius + self.edge_width = edge_width self.node_color = node_color self.edge_color = edge_color self.layer_spacing = layer_spacing self.animation_dot_color = animation_dot_color + self.dot_radius = dot_radius # TODO take layer_node_count [0, (1, 2), 0] # and make it have explicit distinct subspaces @@ -97,7 +104,7 @@ class NeuralNetwork(VGroup): # Go through each node in the two layers and make a connecting line for node_i in current_layer.node_group: for node_j in next_layer.node_group: - line = Line(node_i.get_center(), node_j.get_center(), color=self.edge_color) + line = Line(node_i.get_center(), node_j.get_center(), color=self.edge_color, stroke_width=self.edge_width) edge_layer.add(line) edge_layers.add(edge_layer) # Handle layering @@ -112,7 +119,7 @@ class NeuralNetwork(VGroup): for edge_layer in self.edge_layers: path_animations = [] for edge in edge_layer: - dot = Dot(color=self.animation_dot_color, fill_opacity=1.0, radius=0.06) + dot = Dot(color=self.animation_dot_color, fill_opacity=1.0, radius=self.dot_radius) # Handle layering dot.set_z_index(1) # Add to dots group diff --git a/src/vae.py b/src/vae.py index d137e8f..279e643 100644 --- a/src/vae.py +++ b/src/vae.py @@ -4,46 +4,49 @@ In this module I define Manim visualizations for Variational Autoencoders and Traditional Autoencoders. 
""" +from configparser import Interpolation from typing_extensions import runtime from manim import * +import pickle import numpy as np import neural_network -class Autoencoder(VGroup): - """Traditional Autoencoder Manim Visualization""" - - def __init__(self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5], point_color=BLUE, dot_radius=0.06): - super(VGroup, self).__init__() +class VariationalAutoencoder(Group): + """Variational Autoencoder Manim Visualization""" + + def __init__( + self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5], point_color=BLUE, + dot_radius=0.05, ellipse_stroke_width=2.0 + ): + super(Group, self).__init__() self.encoder_nodes_per_layer = encoder_nodes_per_layer self.decoder_nodes_per_layer = decoder_nodes_per_layer self.point_color = point_color self.dot_radius = dot_radius + self.ellipse_stroke_width = ellipse_stroke_width # Make the VMobjects self.encoder, self.decoder = self._construct_encoder_and_decoder() self.embedding = self._construct_embedding() - # self.input_image, self.output_image = self._construct_input_output_images() # Setup the relative locations self.embedding.move_to(self.encoder) self.embedding.shift([1.1 * self.encoder.width, 0, 0]) self.decoder.move_to(self.embedding) self.decoder.shift([self.decoder.width * 1.1, 0, 0]) - # self.embedding.shift(self.encoder.width * 1.5) - # self.decoder.move_to(self.embedding.get_center()) # Add the objects to the VAE object self.add(self.encoder) self.add(self.decoder) self.add(self.embedding) - # self.add(self.input_image) - # self.add(self.output_image) def _construct_encoder_and_decoder(self): """Makes the VAE encoder and decoder""" # Make the encoder layer_node_count = self.encoder_nodes_per_layer - encoder = neural_network.NeuralNetwork(layer_node_count) + encoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius) + encoder.scale(1.2) # Make the decoder layer_node_count = self.decoder_nodes_per_layer - decoder = neural_network.NeuralNetwork(layer_node_count) + decoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius) + decoder.scale(1.2) return encoder, decoder @@ -59,55 +62,40 @@ class Autoencoder(VGroup): embedding.axes = Axes( x_range=[-3, 3], y_range=[-3, 3], - x_length=2.5, - y_length=2.5, + x_length=2.2, + y_length=2.2, tips=False, ) # Add each point to the axes self.point_dots = VGroup() for point in points: point_location = embedding.axes.coords_to_point(*point) - dot = Dot(point_location, color=self.point_color, radius=self.dot_radius / 2) + dot = Dot(point_location, color=self.point_color, radius=self.dot_radius/2) self.point_dots.add(dot) embedding.add(self.point_dots) return embedding - def _construct_input_output_images(self, input_output_image_pairs): + def _construct_input_output_images(self, image_pair): """Places the input and output images for the AE""" - pass + # Takes the image pair + # image_pair is assumed to be [2, x, y] + input_image = image_pair[0][None, :, :] + recon_image = image_pair[1][None, :, :] + # Convert images to rgb + input_image = np.repeat(input_image, 3, axis=0) + input_image = np.rollaxis(input_image, 0, start=3) + recon_image = np.repeat(recon_image, 3, axis=0) + recon_image = np.rollaxis(recon_image, 0, start=3) + # Make an image objects + input_image_object = ImageMobject(input_image, image_mode="RGB") + input_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"]) + input_image_object.height = 2 + recon_image_object = ImageMobject(recon_image, image_mode="RGB") 
+ recon_image_object.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"]) + recon_image_object.height = 2 - def make_forward_pass_animation(self, run_time=2): - """Makes an animation of a forward pass throgh the VAE""" - per_unit_runtime = run_time // 3 - # Make encoder forward pass - encoder_forward_pass = self.encoder.make_forward_propagation_animation(run_time=per_unit_runtime) - # Make red dot in embedding - location = [1.0, 1.5] - location_point = self.embedding.axes.coords_to_point(*location) - # dot = Dot(location_point, color=RED) - # create_dot_animation = Create(dot, run_time=per_unit_runtime) - # Make decoder foward pass - decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime) - # Add the animations to the group - animation_group = Succession( - encoder_forward_pass, - create_dot_animation, - decoder_forward_pass, - lag_ratio=1, - ) - - return animation_group - - def make_interpolation_animation(self): - """Makes animation of interpolating in the latent space""" - pass - -class VariationalAutoencoder(Autoencoder): - """Variational Autoencoder Manim Visualization""" - - def __init__(self): - super().__init__() + return input_image_object, recon_image_object def make_dot_convergence_animation(self, location, run_time=1.5): """Makes dots converge on a specific location""" @@ -141,9 +129,15 @@ class VariationalAutoencoder(Autoencoder): animation_group = AnimationGroup(*animations) return animation_group - def make_forward_pass_animation(self, run_time=1.5): + def make_forward_pass_animation(self, image_pair, run_time=1.5): """Overriden forward pass animation specific to a VAE""" per_unit_runtime = run_time + # Setup images + self.input_image, self.output_image = self._construct_input_output_images(image_pair) + self.input_image.move_to(self.encoder.get_left()) + self.input_image.shift(LEFT) + self.output_image.move_to(self.decoder.get_right()) + self.output_image.shift(RIGHT * 1.2) # Make encoder forward pass encoder_forward_pass = self.encoder.make_forward_propagation_animation(run_time=per_unit_runtime) # Make red dot in embedding @@ -158,7 +152,7 @@ class VariationalAutoencoder(Autoencoder): ) # Make an ellipse centered at mean_point witAnimationGraph std outline center_dot = Dot(mean_point, radius=self.dot_radius, color=GREEN) - ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.5) + ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.5, stroke_width=self.ellipse_stroke_width) ellipse.move_to(mean_point) ellipse_animation = AnimationGroup( GrowFromCenter(center_dot), @@ -170,21 +164,47 @@ class VariationalAutoencoder(Autoencoder): decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime) # Add the animations to the group animation_group = AnimationGroup( + FadeIn(self.input_image), encoding_succesion, ellipse_animation, dot_divergence_animation, decoder_forward_pass, + FadeIn(self.output_image), lag_ratio=1, ) return animation_group + +class MNISTImageHandler(): + """Deals with loading serialized VAE mnist images from "autoencoder_models" """ + + def __init__( + self, + image_pairs_path="src/autoencoder_models/image_pairs.pkl", + interpolations_path="src/autoencoder_models/interpolations.pkl" + ): + self.image_pairs_path = image_pairs_path + self.interpolations_path = interpolations_path + + self.image_pairs = [] + self.interpolations = [] + + self.load_serialized_data() + + def load_serialized_data(self): + with open(self.image_pairs_path, "rb") as 
f: + self.image_pairs = pickle.load(f) + + with open(self.interpolations_path, "rb") as f: + self.interpolations = pickle.load(f) + """ The VAE Scene for the twitter video. """ config.pixel_height = 720 -config.pixel_width = 720 +config.pixel_width = 1280 config.frame_height = 10.0 config.frame_width = 10.0 # Set random seed so point distribution is constant @@ -196,10 +216,12 @@ class VAEScene(Scene): def construct(self): # Set Scene config vae = VariationalAutoencoder() + mnist_image_handler = MNISTImageHandler() + image_pair = mnist_image_handler.image_pairs[2] vae.move_to(ORIGIN) vae.scale(1.2) self.add(vae) - forward_pass_animation = vae.make_forward_pass_animation() + forward_pass_animation = vae.make_forward_pass_animation(image_pair) self.play(forward_pass_animation) """ autoencoder = Autoencoder()
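
For readers following the model changes above, here is a minimal self-contained sketch of the reparameterization step and the beta-weighted loss that the new train_model() in variational_autoencoder.py uses. The latent_dim=16 and kl_beta=0.001 values mirror the diff; the function names (reparameterize, vae_loss) and the toy tensor shapes below are illustrative only, not identifiers from the repository.

import torch

# Reparameterization: z = mean + sigma * eps with eps ~ N(0, I), matching the
# forward() added to variational_autoencoder.py.
def reparameterize(mean, logvar):
    eps = torch.randn_like(mean)
    return mean + torch.exp(logvar / 2) * eps

# Beta-weighted VAE loss: reconstruction MSE plus a down-weighted KL term, as
# in the loss_function defined inside train_model() (kl_beta=0.001).
def vae_loss(mean, logvar, reconstructed, original, kl_beta=0.001):
    kl = torch.mean(-0.5 * torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim=1), dim=0)
    recon = torch.nn.functional.mse_loss(reconstructed, original)
    return kl_beta * kl + recon

# Toy usage with illustrative shapes: a batch of 32 flattened 28x28 images and
# a 16-dimensional latent space.
mean, logvar = torch.zeros(32, 16), torch.zeros(32, 16)
z = reparameterize(mean, logvar)
x, x_hat = torch.rand(32, 28 * 28), torch.rand(32, 28 * 28)
print(vae_loss(mean, logvar, x_hat, x))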
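
Similarly, a short sketch of the latent-space walk that generate_interpolation.py pickles into interpolations.pkl: encode two digits, take a straight-line path between their latent means, and decode each step. The real script loads the trained VAE(latent_dim=16) from saved_models/model.pth; here untrained Linear layers and random images stand in so the snippet runs on its own, and interpolate_latents is a hypothetical helper name.

import numpy as np
import torch

# Stand-ins for the trained encoder/decoder (the real script uses the VAE
# checkpoint from saved_models/model.pth); shapes match the diff.
encode = torch.nn.Linear(28 * 28, 16)   # image -> latent mean
decode = torch.nn.Linear(16, 28 * 28)   # latent -> flattened image

def interpolate_latents(image_a, image_b, num_steps=50):
    with torch.no_grad():
        z_a = encode(image_a.view(1, 28 * 28)).numpy()
        z_b = encode(image_b.view(1, 28 * 28)).numpy()
        # Straight-line path between the two latent codes, decoded step by
        # step, as generate_interpolation.py does before pickling.
        path = np.linspace(z_a, z_b, num=num_steps)
        frames = [decode(torch.Tensor(step)) for step in path]
    return path, frames

image_a, image_b = torch.rand(28, 28), torch.rand(28, 28)  # stand-ins for two MNIST digits
path, frames = interpolate_latents(image_a, image_b)
print(path.shape, frames[0].shape)  # (50, 1, 16) and torch.Size([1, 784])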