Changed directory structure to accommodate examples as opposed to everything being a part of the core library. May need to rethink this in the future. Added some boilerplate for pip packaging to the .gitignore.
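For reference, the pip packaging boilerplate mentioned above usually amounts to a handful of setuptools/build artifacts. A minimal sketch of the kind of entries involved (illustrative only; the actual lines added to the committed .gitignore may differ):

# illustrative pip/setuptools packaging ignores (assumed, not the committed entries)
build/
dist/
*.egg-info/
.eggs/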

This commit is contained in:
Alec Helbling
2022-03-28 14:01:00 -04:00
committed by Alec Helbling
parent 4eb5296c9c
commit 3be5c54d26
40 changed files with 30 additions and 15 deletions

Binary file not shown.

View File

@ -0,0 +1,99 @@
import pickle
import sys
import os
sys.path.append(os.environ["PROJECT_ROOT"])
from autoencoder_models.variational_autoencoder import VAE, load_dataset, load_vae_from_path
import matplotlib.pyplot as plt
import numpy as np
import torch
import scipy
import scipy.stats
import cv2
def binned_images(model_path, num_x_bins=6, plot=False):
latent_dim = 2
model = load_vae_from_path(model_path, latent_dim)
image_dataset = load_dataset(digit=2)
# Compute embedding
num_images = 500
embedding = []
images = []
for i in range(num_images):
image, _ = image_dataset[i]
mean, _, recon, _ = model.forward(image)
mean = mean.detach().numpy()
recon = recon.detach().numpy()
recon = recon.reshape(32, 32)
images.append(recon.squeeze())
if latent_dim > 2:
mean = mean[:2]
embedding.append(mean)
images = np.stack(images)
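# Note: despite the variable name, these are the raw VAE latent means (no t-SNE is applied)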
tsne_points = np.array(embedding)
tsne_points = (tsne_points - tsne_points.mean(axis=0))/(tsne_points.std(axis=0))
# make vis
num_points = np.shape(tsne_points)[0]
x_min = np.amin(tsne_points.T[0])
y_min = np.amin(tsne_points.T[1])
y_max = np.amax(tsne_points.T[1])
x_max = np.amax(tsne_points.T[0])
# make the bins from the ranges
# to keep it square the same width is used for x and y dim
x_bins, step = np.linspace(x_min, x_max, num_x_bins, retstep=True)
x_bins = x_bins.astype(float)
num_y_bins = np.absolute(np.ceil((y_max - y_min)/step)).astype(int)
y_bins = np.linspace(y_min, y_max, num_y_bins)
# sort the tsne_points into a 2d histogram
tsne_points = tsne_points.squeeze()
hist_obj = scipy.stats.binned_statistic_dd(tsne_points, np.arange(num_points), statistic='count', bins=[x_bins, y_bins], expand_binnumbers=True)
# sample one point from each bucket
binnumbers = hist_obj.binnumber
num_x_bins = np.amax(binnumbers[0]) + 1
num_y_bins = np.amax(binnumbers[1]) + 1
binnumbers = binnumbers.T
# some places have no value in a region
used_mask = np.zeros((num_y_bins, num_x_bins))
image_bins = np.zeros((num_y_bins, num_x_bins, 3, np.shape(images)[2], np.shape(images)[2]))
for i, bin_num in enumerate(list(binnumbers)):
used_mask[bin_num[1], bin_num[0]] = 1
image_bins[bin_num[1], bin_num[0]] = images[i]
# plot a grid of the images
fig, axs = plt.subplots(nrows=np.shape(y_bins)[0], ncols=np.shape(x_bins)[0], constrained_layout=False, dpi=50)
images = []
bin_indices = []
for y in range(num_y_bins):
for x in range(num_x_bins):
if used_mask[y, x] > 0.0:
image = np.uint8(image_bins[y][x].squeeze()*255)
image = np.rollaxis(image, 0, 3)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
axs[num_y_bins - 1 - y][x].imshow(image)
images.append(image)
bin_indices.append((y, x))
axs[num_y_bins - 1 - y][x].axis('off')
if plot:
plt.axis('off')
plt.show()
else:
return images, bin_indices
def generate_disentanglement(model_path="saved_models/model_dim2.pth"):
"""Generates disentanglement visualization and serializes it"""
# Disentanglement object
disentanglement_object = {}
# Make Disentanglement
images, bin_indices = binned_images(model_path)
disentanglement_object["images"] = images
disentanglement_object["bin_indices"] = bin_indices
# Serialize Images
with open("disentanglement.pkl", "wb") as f:
pickle.dump(disentanglement_object, f)
if __name__ == "__main__":
plot = False
if plot:
model_path = "saved_models/model_dim2.pth"
#uniform_image_sample(model_path)
binned_images(model_path)
else:
generate_disentanglement()

View File

@ -0,0 +1,41 @@
import torch
from variational_autoencoder import VAE
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle
# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
# Pad the 28x28 MNIST images to 32x32 and convert them to PyTorch Tensors
# so they match the input size expected by the convolutional VAE
tensor_transform = transforms.Compose([
transforms.Pad(2),
transforms.ToTensor()
])
# Download the MNIST Dataset
dataset = datasets.MNIST(root = "./data",
train = True,
download = True,
transform = tensor_transform)
# Generate reconstructions
num_recons = 10
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
image_pairs = []
for i in range(num_recons):
base_image, _ = dataset[i]
_, _, recon_image, _ = vae.forward(base_image)
base_image = base_image.detach().numpy()
base_image = np.reshape(base_image, (32, 32)) * 255
recon_image = recon_image.detach().numpy()
recon_image = np.reshape(recon_image, (32, 32)) * 255
# Add to plot
axs[i][0].imshow(base_image)
axs[i][1].imshow(recon_image)
# image pairs
image_pairs.append((base_image, recon_image))
with open("image_pairs.pkl", "wb") as f:
pickle.dump(image_pairs, f)
plt.show()

View File

@ -0,0 +1,49 @@
import torch
from variational_autoencoder import VAE, load_dataset
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle
# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
dataset = load_dataset()
# Generate reconstructions
num_images = 50
image_pairs = []
save_object = {"interpolation_path":[], "interpolation_images":[]}
# Make interpolation path
image_a, image_b = dataset[0][0], dataset[1][0]
# The encoder takes image-shaped tensors of shape (1, 32, 32), so no flattening is needed
z_a, _, _, _ = vae.forward(image_a)
z_a = z_a.detach().cpu().numpy()
z_b, _, _, _ = vae.forward(image_b)
z_b = z_b.detach().cpu().numpy()
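# Linearly interpolate between the two latent means to form the path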
interpolation_path = np.linspace(z_a, z_b, num=num_images)
# interpolation_path[:, 4] = np.linspace(-3, 3, num=num_images)
save_object["interpolation_path"] = interpolation_path
for i in range(num_images):
# Generate
z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
gen_image = vae.decode(z).detach().numpy()
gen_image = np.reshape(gen_image, (32, 32)) * 255
save_object["interpolation_images"].append(gen_image)
fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
image_pairs = []
for i in range(num_images):
recon_image = save_object["interpolation_images"][i]
# Add to plot
axs[i].imshow(recon_image)
# Serialize the interpolations
with open("interpolations.pkl", "wb") as f:
pickle.dump(save_object, f)
plt.show()

View File

@ -0,0 +1,219 @@
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import os
"""
These are utility functions that help to calculate the input and output
sizes of convolutional neural networks
"""
def num2tuple(num):
return num if isinstance(num, tuple) else (num, num)
def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
pad = num2tuple(pad[0]), num2tuple(pad[1])
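# Standard Conv2d output size: floor((size + total_padding - dilation * (kernel - 1) - 1) / stride + 1)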
h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)
return h, w
def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
h_w, kernel_size, stride, pad, dilation, out_pad = num2tuple(h_w), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation), num2tuple(out_pad)
pad = num2tuple(pad[0]), num2tuple(pad[1])
h = (h_w[0] - 1)*stride[0] - sum(pad[0]) + dilation[0]*(kernel_size[0]-1) + out_pad[0] + 1
w = (h_w[1] - 1)*stride[1] - sum(pad[1]) + dilation[1]*(kernel_size[1]-1) + out_pad[1] + 1
return h, w
def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
h_w_in, h_w_out, kernel_size, stride, dilation = num2tuple(h_w_in), num2tuple(h_w_out), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation)
p_h = ((h_w_out[0] - 1)*stride[0] - h_w_in[0] + dilation[0]*(kernel_size[0]-1) + 1)
p_w = ((h_w_out[1] - 1)*stride[1] - h_w_in[1] + dilation[1]*(kernel_size[1]-1) + 1)
return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = num2tuple(h_w_in), num2tuple(h_w_out), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation), num2tuple(out_pad)
p_h = -(h_w_out[0] - 1 - out_pad[0] - dilation[0]*(kernel_size[0]-1) - (h_w_in[0] - 1)*stride[0])
p_w = -(h_w_out[1] - 1 - out_pad[1] - dilation[1]*(kernel_size[1]-1) - (h_w_in[1] - 1)*stride[1])
return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
def load_dataset(train=True, digit=None):
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.Compose([
transforms.Pad(2),
transforms.ToTensor()
])
# Download the MNIST Dataset
dataset = datasets.MNIST(root = "./data",
train = train,
download = True,
transform = tensor_transform)
# Load specific image
if digit is not None:
idx = dataset.targets == digit
dataset.targets = dataset.targets[idx]
dataset.data = dataset.data[idx]
return dataset
def load_vae_from_path(path, latent_dim):
model = VAE(latent_dim)
model.load_state_dict(torch.load(path))
return model
# Creating a PyTorch class
# Convolutional VAE: 32*32 image ==> latent_dim-dimensional code ==> 32*32 reconstruction
class VAE(torch.nn.Module):
def __init__(self, latent_dim=5, layer_count=4, channels=1):
super().__init__()
self.latent_dim = latent_dim
self.in_shape = 32
self.layer_count = layer_count
self.channels = channels
self.d = 128
mul = 1
inputs = self.channels
out_sizes = [(self.in_shape, self.in_shape)]
for i in range(self.layer_count):
setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
inputs = self.d * mul
mul *= 2
self.d_max = inputs
self.last_size = out_sizes[-1][-1]
self.num_linear = self.last_size ** 2 * self.d_max
# Encoder linear layers
self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)
# Decoder linear layer
self.decoder_linear = nn.Linear(self.latent_dim, self.num_linear)
mul = inputs // self.d // 2
for i in range(1, self.layer_count):
setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
inputs = self.d * mul
mul //= 2
setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))
def encode(self, x):
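# Promote single images to the 4D (batch, channel, height, width) layout expected by the conv layers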
if len(x.shape) < 3:
x = x.unsqueeze(0)
if len(x.shape) < 4:
x = x.unsqueeze(1)
batch_size = x.shape[0]
for i in range(self.layer_count):
x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))
x = x.view(batch_size, -1)
mean = self.encoder_mean_linear(x)
logvar = self.encoder_logvar_linear(x)
return mean, logvar
def decode(self, x):
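# Project the latent code back up to a feature map, then upsample it with the transposed-conv stack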
x = x.view(x.shape[0], self.latent_dim)
x = self.decoder_linear(x)
x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
#x = self.deconv1_bn(x)
x = F.leaky_relu(x, 0.2)
for i in range(1, self.layer_count):
x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)
x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
x = torch.sigmoid(x)
return x
def forward(self, x):
batch_size = x.shape[0]
mean, logvar = self.encode(x)
eps = torch.randn(batch_size, self.latent_dim)
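# Reparameterization trick: z = mean + sigma * eps, where sigma = exp(logvar / 2)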
z = mean + torch.exp(logvar / 2) * eps
reconstructed = self.decode(z)
return mean, logvar, reconstructed, x
def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
dataset = load_dataset(train=True, digit=digit)
# DataLoader is used to load the dataset
# for training
loader = torch.utils.data.DataLoader(dataset = dataset,
batch_size = 32,
shuffle = True)
# Model Initialization
model = VAE(latent_dim=latent_dim)
# VAE loss: MSE reconstruction error plus beta-weighted KL divergence
def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
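# KL divergence between the approximate posterior N(mean, exp(log_var)) and the standard normal prior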
kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim = 1), dim = 0)
recon = torch.nn.functional.mse_loss(reconstructed, original)
# print(f"KL Error {kl}, Recon Error {recon}")
return kl_beta * kl + recon
# Using an Adam Optimizer with lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(),
lr = 1e-4,
weight_decay = 0e-8)
outputs = []
losses = []
for epoch in tqdm(range(epochs)):
for (image, _) in loader:
# Output of Autoencoder
mean, log_var, reconstructed, image = model(image)
# Calculating the loss function
loss = loss_function(mean, log_var, reconstructed, image)
# The gradients are set to zero,
# then the gradient is computed and stored.
# .step() performs parameter update
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Storing the losses in a list for plotting
if torch.isnan(loss):
raise Exception("Loss became NaN during training")
losses.append(loss.detach().cpu())
outputs.append((epoch, image, reconstructed))
torch.save(model.state_dict(),
os.path.join(
os.environ["PROJECT_ROOT"],
f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth"
)
)
if plot:
# Defining the Plot Style
plt.style.use('fivethirtyeight')
plt.xlabel('Iterations')
plt.ylabel('Loss')
# Plot the training loss curve
plt.plot(losses)
plt.show()
if __name__ == "__main__":
train_model(latent_dim=2, digit=2, epochs=40)

View File

@ -0,0 +1,313 @@
"""Autoencoder Manim Visualizations
In this module I define Manim visualizations for Variational Autoencoders
and Traditional Autoencoders.
"""
from manim import *
import pickle
import numpy as np
import os
import manim_ml.neural_network as neural_network
class VariationalAutoencoder(VGroup):
"""Variational Autoencoder Manim Visualization"""
def __init__(
self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5], point_color=BLUE,
dot_radius=0.05, ellipse_stroke_width=2.0, layer_spacing=0.5
):
super().__init__()
self.encoder_nodes_per_layer = encoder_nodes_per_layer
self.decoder_nodes_per_layer = decoder_nodes_per_layer
self.point_color = point_color
self.dot_radius = dot_radius
self.layer_spacing = layer_spacing
self.ellipse_stroke_width = ellipse_stroke_width
# Make the VMobjects
self.encoder, self.decoder = self._construct_encoder_and_decoder()
self.embedding = self._construct_embedding()
# Setup the relative locations
self.embedding.move_to(self.encoder)
self.embedding.shift([1.4 * self.encoder.width, 0, 0])
self.decoder.move_to(self.embedding)
self.decoder.shift([self.decoder.width * 1.4, 0, 0])
# Add the objects to the VAE object
self.add(self.encoder)
self.add(self.decoder)
self.add(self.embedding)
def _construct_encoder_and_decoder(self):
"""Makes the VAE encoder and decoder"""
# Make the encoder
layer_node_count = self.encoder_nodes_per_layer
encoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius, layer_spacing=self.layer_spacing)
encoder.scale(1.2)
# Make the decoder
layer_node_count = self.decoder_nodes_per_layer
decoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius, layer_spacing=self.layer_spacing)
decoder.scale(1.2)
return encoder, decoder
def _construct_embedding(self):
"""Makes a Gaussian-like embedding"""
embedding = VGroup()
# Sample points from a Gaussian
num_points = 200
standard_deviation = [0.9, 0.9]
mean = [0, 0]
points = np.random.normal(mean, standard_deviation, size=(num_points, 2))
# Make an axes
embedding.axes = Axes(
x_range=[-3, 3],
y_range=[-3, 3],
x_length=2.2,
y_length=2.2,
tips=False,
)
# Add each point to the axes
self.point_dots = VGroup()
for point in points:
point_location = embedding.axes.coords_to_point(*point)
dot = Dot(point_location, color=self.point_color, radius=self.dot_radius/2)
self.point_dots.add(dot)
embedding.add(self.point_dots)
return embedding
def _construct_image_mobject(self, input_image, height=2.3):
"""Constructs an ImageMobject from a numpy grayscale image"""
# Convert image to rgb
input_image = np.repeat(input_image, 3, axis=0)
input_image = np.rollaxis(input_image, 0, start=3)
# Make the ImageMobject
image_mobject = ImageMobject(input_image, image_mode="RGB")
image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
image_mobject.height = height
return image_mobject
def _construct_input_output_images(self, image_pair):
"""Places the input and output images for the AE"""
# Takes the image pair
# image_pair is assumed to be [2, x, y]
input_image = image_pair[0][None, :, :]
recon_image = image_pair[1][None, :, :]
# Make the image mobjects
input_image_object = self._construct_image_mobject(input_image)
recon_image_object = self._construct_image_mobject(recon_image)
return input_image_object, recon_image_object
def make_dot_convergence_animation(self, location, run_time=1.5):
"""Makes dots converge on a specific location"""
# Move to location
animations = []
for dot in self.encoder.dots:
coords = self.embedding.axes.coords_to_point(*location)
animations.append(dot.animate.move_to(coords))
move_animations = AnimationGroup(*animations, run_time=run_time)
# Follow up with remove animations
remove_animations = []
for dot in self.encoder.dots:
remove_animations.append(FadeOut(dot))
remove_animations = AnimationGroup(*remove_animations, run_time=0.2)
animation_group = Succession(move_animations, remove_animations, lag_ratio=1.0)
return animation_group
def make_dot_divergence_animation(self, location, run_time=3.0):
"""Makes dots diverge from the given location and move the decoder"""
animations = []
for node in self.decoder.layers[0].node_group:
new_dot = Dot(location, radius=self.dot_radius, color=RED)
per_node_succession = Succession(
Create(new_dot),
new_dot.animate.move_to(node.get_center()),
)
animations.append(per_node_succession)
animation_group = AnimationGroup(*animations, run_time=run_time)
return animation_group
def make_reset_vae_animation(self):
"""Resets the VAE to just the neural network"""
animation_group = AnimationGroup(
FadeOut(self.input_image),
FadeOut(self.output_image),
FadeOut(self.distribution_objects),
run_time=1.0
)
return animation_group
def make_forward_pass_animation(self, image_pair, run_time=1.5):
"""Overriden forward pass animation specific to a VAE"""
per_unit_runtime = run_time
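# Animation stages: fade in the input image, run the encoder forward pass, converge dots onto the
# latent mean, grow the distribution ellipse, diverge dots toward the decoder, run the decoder
# forward pass, then fade in the reconstruction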
# Setup images
self.input_image, self.output_image = self._construct_input_output_images(image_pair)
self.input_image.move_to(self.encoder.get_left())
self.input_image.shift(LEFT)
self.output_image.move_to(self.decoder.get_right())
self.output_image.shift(RIGHT*1.3)
# Make encoder forward pass
encoder_forward_pass = self.encoder.make_forward_propagation_animation(run_time=per_unit_runtime)
# Make red dot in embedding
mean = [1.0, 1.5]
mean_point = self.embedding.axes.coords_to_point(*mean)
std = [0.8, 1.2]
# Make the dot convergence animation
dot_convergence_animation = self.make_dot_convergence_animation(mean, run_time=per_unit_runtime)
encoding_succession = Succession(
encoder_forward_pass,
dot_convergence_animation
)
# Make an ellipse centered at mean_point with std outline
center_dot = Dot(mean_point, radius=self.dot_radius, color=RED)
ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.3, stroke_width=self.ellipse_stroke_width)
ellipse.move_to(mean_point)
self.distribution_objects = VGroup(
center_dot,
ellipse
)
# Make ellipse animation
ellipse_animation = AnimationGroup(
GrowFromCenter(center_dot),
GrowFromCenter(ellipse),
)
# Make the dot divergence animation
sampled_point = [0.51, 1.0]
divergence_point = self.embedding.axes.coords_to_point(*sampled_point)
dot_divergence_animation = self.make_dot_divergence_animation(divergence_point, run_time=per_unit_runtime)
# Make decoder forward pass
decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime)
# Add the animations to the group
animation_group = AnimationGroup(
FadeIn(self.input_image),
encoding_succession,
ellipse_animation,
dot_divergence_animation,
decoder_forward_pass,
FadeIn(self.output_image),
lag_ratio=1,
)
return animation_group
def make_interpolation_animation(self, interpolation_images, frame_rate=5):
"""Makes an animation interpolation"""
num_images = len(interpolation_images)
# Make a made-up latent-space path for the visualization
interpolation_latent_path = np.linspace([-0.7, -1.2], [1.2, 1.5], num=num_images)
# Make the path animation
first_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[0])
last_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[-1])
moving_dot = Dot(first_dot_location, radius=self.dot_radius, color=RED)
self.add(moving_dot)
animation_list = [Create(Line(first_dot_location, last_dot_location, color=RED), run_time=0.1*num_images)]
for image_index in range(num_images - 1):
next_index = image_index + 1
# Get path
next_point = interpolation_latent_path[next_index]
next_position = self.embedding.axes.coords_to_point(*next_point)
# Draw path from current point to next point
move_animation = moving_dot.animate(run_time=0.1*num_images).move_to(next_position)
animation_list.append(move_animation)
interpolation_animation = AnimationGroup(*animation_list)
# Make the images animation
animation_list = [Wait(0.5)]
for numpy_image in interpolation_images:
numpy_image = numpy_image[None, :, :]
manim_image = self._construct_image_mobject(numpy_image)
# Move the image to the correct location
manim_image.move_to(self.output_image)
# Add the image
animation_list.append(FadeIn(manim_image, run_time=0.1))
# Wait
# animation_list.append(Wait(1 / frame_rate))
# Remove the image
# animation_list.append(FadeOut(manim_image, run_time=0.1))
images_animation = AnimationGroup(*animation_list, lag_ratio=1.0)
# Combine the two into an AnimationGroup
animation_group = AnimationGroup(
interpolation_animation,
images_animation
)
return animation_group
class MNISTImageHandler():
"""Deals with loading serialized VAE mnist images from "autoencoder_models" """
def __init__(
self,
image_pairs_file_path=os.path.join(os.environ["PROJECT_ROOT"], "examples/variational_autoencoder/autoencoder_models/image_pairs.pkl"),
interpolations_file_path=os.path.join(os.environ["PROJECT_ROOT"], "examples/variational_autoencoder/autoencoder_models/interpolations.pkl")
):
self.image_pairs_file_path = image_pairs_file_path
self.interpolations_file_path = interpolations_file_path
self.image_pairs = []
self.interpolation_images = []
self.interpolation_latent_path = []
self.load_serialized_data()
def load_serialized_data(self):
with open(self.image_pairs_file_path, "rb") as f:
self.image_pairs = pickle.load(f)
with open(self.interpolations_file_path, "rb") as f:
self.interpolation_dict = pickle.load(f)
self.interpolation_images = self.interpolation_dict["interpolation_images"]
self.interpolation_latent_path = self.interpolation_dict["interpolation_path"]
"""
The VAE Scene for the Twitter video.
"""
config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 5.0
config.frame_width = 5.0
# Set random seed so point distribution is constant
np.random.seed(1)
class VAEScene(Scene):
"""Scene object for a Variational Autoencoder and Autoencoder"""
def construct(self):
# Set Scene config
vae = VariationalAutoencoder()
mnist_image_handler = MNISTImageHandler()
image_pair = mnist_image_handler.image_pairs[3]
vae.move_to(ORIGIN)
vae.scale(1.3)
self.play(Create(vae), run_time=3)
# Make a forward pass animation
forward_pass_animation = vae.make_forward_pass_animation(image_pair)
self.play(forward_pass_animation)
# Remove the input and output images
reset_animation = vae.make_reset_vae_animation()
self.play(reset_animation)
# Interpolation animation
interpolation_images = mnist_image_handler.interpolation_images
interpolation_animation = vae.make_interpolation_animation(interpolation_images)
self.play(interpolation_animation)
class VAEImage(Scene):
def construct(self):
# Set Scene config
vae = VariationalAutoencoder()
mnist_image_handler = MNISTImageHandler()
image_pair = mnist_image_handler.image_pairs[3]
vae.move_to(ORIGIN)
vae.scale(1.3)
self.play(Create(vae), run_time=3)
# Make a forward pass animation
forward_pass_animation = vae.make_forward_pass_animation(image_pair)
self.play(forward_pass_animation)