mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-09-18 12:54:25 +08:00
Changed directory structure to accommodate examples as opposed to everything being a part of the core library. May need to rethink this in the future. Added some boilerplate for pip packaging to the .gitignore.
committed by Alec Helbling
parent 4eb5296c9c
commit 3be5c54d26
BIN
examples/variational_autoencoder/autoencoder_models/.DS_Store
vendored
Normal file
Binary file not shown.
@@ -0,0 +1,99 @@
import pickle
import sys
import os
sys.path.append(os.environ["PROJECT_ROOT"])
from autoencoder_models.variational_autoencoder import VAE, load_dataset, load_vae_from_path
import matplotlib.pyplot as plt
import numpy as np
import torch
import scipy
import scipy.stats
import cv2

def binned_images(model_path, num_x_bins=6, plot=False):
    latent_dim = 2
    model = load_vae_from_path(model_path, latent_dim)
    image_dataset = load_dataset(digit=2)
    # Compute embedding
    num_images = 500
    embedding = []
    images = []
    for i in range(num_images):
        image, _ = image_dataset[i]
        mean, _, recon, _ = model.forward(image)
        mean = mean.detach().numpy()
        recon = recon.detach().numpy()
        recon = recon.reshape(32, 32)
        images.append(recon.squeeze())
        if latent_dim > 2:
            mean = mean[:2]
        embedding.append(mean)
    images = np.stack(images)
    tsne_points = np.array(embedding)
    tsne_points = (tsne_points - tsne_points.mean(axis=0)) / (tsne_points.std(axis=0))
    # make vis
    num_points = np.shape(tsne_points)[0]
    x_min = np.amin(tsne_points.T[0])
    y_min = np.amin(tsne_points.T[1])
    y_max = np.amax(tsne_points.T[1])
    x_max = np.amax(tsne_points.T[0])
    # make the bins from the ranges
    # to keep it square the same width is used for x and y dim
    x_bins, step = np.linspace(x_min, x_max, num_x_bins, retstep=True)
    x_bins = x_bins.astype(float)
    num_y_bins = np.absolute(np.ceil((y_max - y_min) / step)).astype(int)
    y_bins = np.linspace(y_min, y_max, num_y_bins)
    # sort the tsne_points into a 2d histogram
    tsne_points = tsne_points.squeeze()
    hist_obj = scipy.stats.binned_statistic_dd(tsne_points, np.arange(num_points), statistic='count', bins=[x_bins, y_bins], expand_binnumbers=True)
    # sample one point from each bucket
    binnumbers = hist_obj.binnumber
    num_x_bins = np.amax(binnumbers[0]) + 1
    num_y_bins = np.amax(binnumbers[1]) + 1
    binnumbers = binnumbers.T
    # some places have no value in a region
    used_mask = np.zeros((num_y_bins, num_x_bins))
    image_bins = np.zeros((num_y_bins, num_x_bins, 3, np.shape(images)[2], np.shape(images)[2]))
    for i, bin_num in enumerate(list(binnumbers)):
        used_mask[bin_num[1], bin_num[0]] = 1
        image_bins[bin_num[1], bin_num[0]] = images[i]
    # plot a grid of the images
    fig, axs = plt.subplots(nrows=np.shape(y_bins)[0], ncols=np.shape(x_bins)[0], constrained_layout=False, dpi=50)
    images = []
    bin_indices = []
    for y in range(num_y_bins):
        for x in range(num_x_bins):
            if used_mask[y, x] > 0.0:
                image = np.uint8(image_bins[y][x].squeeze() * 255)
                image = np.rollaxis(image, 0, 3)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                axs[num_y_bins - 1 - y][x].imshow(image)
                images.append(image)
                bin_indices.append((y, x))
            axs[y, x].axis('off')
    if plot:
        plt.axis('off')
        plt.show()
    else:
        return images, bin_indices

def generate_disentanglement(model_path="saved_models/model_dim2.pth"):
    """Generates disentanglement visualization and serializes it"""
    # Disentanglement object
    disentanglement_object = {}
    # Make Disentanglement
    images, bin_indices = binned_images(model_path)
    disentanglement_object["images"] = images
    disentanglement_object["bin_indices"] = bin_indices
    # Serialize Images
    with open("disentanglement.pkl", "wb") as f:
        pickle.dump(disentanglement_object, f)

if __name__ == "__main__":
    plot = False
    if plot:
        model_path = "saved_models/model_dim2.pth"
        #uniform_image_sample(model_path)
        binned_images(model_path)
    else:
        generate_disentanglement()
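A minimal sketch (not part of this commit) of reading the serialized output back for inspection, assuming the script above has already written disentanglement.pkl to the working directory:

import pickle
import matplotlib.pyplot as plt

# Load the object written by generate_disentanglement() above.
with open("disentanglement.pkl", "rb") as f:
    disentanglement = pickle.load(f)

# Each reconstructed image is paired with the (y, x) latent-grid bin it was drawn from.
for image, (y, x) in zip(disentanglement["images"], disentanglement["bin_indices"]):
    print(f"bin ({y}, {x}): image shape {image.shape}")

# Show the first reconstruction.
plt.imshow(disentanglement["images"][0])
plt.axis("off")
plt.show()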
@@ -0,0 +1,41 @@
import torch
from variational_autoencoder import VAE
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle

# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset
dataset = datasets.MNIST(root="./data",
                         train=True,
                         download=True,
                         transform=tensor_transform)
# Generate reconstructions
num_recons = 10
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
image_pairs = []
for i in range(num_recons):
    base_image, _ = dataset[i]
    base_image = base_image.reshape(-1, 28*28)
    _, _, recon_image, _ = vae.forward(base_image)
    base_image = base_image.detach().numpy()
    base_image = np.reshape(base_image, (28, 28)) * 255
    recon_image = recon_image.detach().numpy()
    recon_image = np.reshape(recon_image, (28, 28)) * 255
    # Add to plot
    axs[i][0].imshow(base_image)
    axs[i][1].imshow(recon_image)
    # image pairs
    image_pairs.append((base_image, recon_image))

with open("image_pairs.pkl", "wb") as f:
    pickle.dump(image_pairs, f)

plt.show()
@@ -0,0 +1,49 @@
import torch
from variational_autoencoder import VAE, load_dataset
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle

# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
dataset = load_dataset()
# Generate reconstructions
num_images = 50
image_pairs = []
save_object = {"interpolation_path": [], "interpolation_images": []}

# Make interpolation path
image_a, image_b = dataset[0][0], dataset[1][0]
image_a = image_a.view(32*32)
image_b = image_b.view(32*32)
z_a, _, _, _ = vae.forward(image_a)
z_a = z_a.detach().cpu().numpy()
z_b, _, _, _ = vae.forward(image_b)
z_b = z_b.detach().cpu().numpy()
interpolation_path = np.linspace(z_a, z_b, num=num_images)
# interpolation_path[:, 4] = np.linspace(-3, 3, num=num_images)
save_object["interpolation_path"] = interpolation_path

for i in range(num_images):
    # Generate
    z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
    gen_image = vae.decode(z).detach().numpy()
    gen_image = np.reshape(gen_image, (32, 32)) * 255
    save_object["interpolation_images"].append(gen_image)

fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
image_pairs = []
for i in range(num_images):
    recon_image = save_object["interpolation_images"][i]
    # Add to plot
    axs[i].imshow(recon_image)

# Serialize the interpolations
with open("interpolations.pkl", "wb") as f:
    pickle.dump(save_object, f)

plt.show()
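As an illustrative aside (not part of this commit), the serialized interpolation could also be consumed frame by frame, assuming interpolations.pkl has been produced by the script above:

import pickle
import numpy as np
import matplotlib.pyplot as plt

with open("interpolations.pkl", "rb") as f:
    interpolations = pickle.load(f)

# Each frame is a 32x32 array scaled to [0, 255]; write one grayscale PNG per frame.
for i, frame in enumerate(interpolations["interpolation_images"]):
    plt.imsave(f"interpolation_{i:03d}.png", np.uint8(frame), cmap="gray")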
Binary file not shown.
@@ -0,0 +1,219 @@
import os  # used below when saving the model checkpoint
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

"""
These are utility functions that help to calculate the input and output
sizes of convolutional neural networks
"""

def num2tuple(num):
    return num if isinstance(num, tuple) else (num, num)

def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
    w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)

    return h, w

def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
    h_w, kernel_size, stride, pad, dilation, out_pad = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation), num2tuple(out_pad)
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = (h_w[0] - 1)*stride[0] - sum(pad[0]) + dilation[0]*(kernel_size[0]-1) + out_pad[0] + 1
    w = (h_w[1] - 1)*stride[1] - sum(pad[1]) + dilation[1]*(kernel_size[1]-1) + out_pad[1] + 1

    return h, w

def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
    h_w_in, h_w_out, kernel_size, stride, dilation = num2tuple(h_w_in), num2tuple(h_w_out), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation)

    p_h = ((h_w_out[0] - 1)*stride[0] - h_w_in[0] + dilation[0]*(kernel_size[0]-1) + 1)
    p_w = ((h_w_out[1] - 1)*stride[1] - h_w_in[1] + dilation[1]*(kernel_size[1]-1) + 1)

    return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))

def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
    h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = num2tuple(h_w_in), num2tuple(h_w_out), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation), num2tuple(out_pad)

    p_h = -(h_w_out[0] - 1 - out_pad[0] - dilation[0]*(kernel_size[0]-1) - (h_w_in[0] - 1)*stride[0]) / 2
    p_w = -(h_w_out[1] - 1 - out_pad[1] - dilation[1]*(kernel_size[1]-1) - (h_w_in[1] - 1)*stride[1]) / 2

    return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))

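# Worked example (editor's illustration, not part of the original file): with the
# 4x4, stride-2, padding-1 convolutions used by the VAE below, each layer halves the
# spatial size, e.g. conv2d_output_shape((32, 32), kernel_size=4, stride=2, pad=1)
# returns (16, 16):
#     h = floor((32 + 2*1 - 1*(4 - 1) - 1) / 2 + 1) = floor(30 / 2 + 1) = 16
# Stacking layer_count=4 such layers gives 32 -> 16 -> 8 -> 4 -> 2, which is why
# self.last_size ends up as 2 for the default 32x32 inputs.
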
def load_dataset(train=True, digit=None):
    # Transforms images to a PyTorch Tensor
    tensor_transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor()
    ])

    # Download the MNIST Dataset
    dataset = datasets.MNIST(root="./data",
                             train=train,
                             download=True,
                             transform=tensor_transform)
    # Keep only images of a specific digit
    if digit is not None:
        idx = dataset.targets == digit
        dataset.targets = dataset.targets[idx]
        dataset.data = dataset.data[idx]

    return dataset

def load_vae_from_path(path, latent_dim):
    model = VAE(latent_dim)
    model.load_state_dict(torch.load(path))

    return model

# Convolutional VAE: 32x32 input ==> latent_dim ==> 32x32 reconstruction
class VAE(torch.nn.Module):

    def __init__(self, latent_dim=5, layer_count=4, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.in_shape = 32
        self.layer_count = layer_count
        self.channels = channels
        self.d = 128
        mul = 1
        inputs = self.channels
        out_sizes = [(self.in_shape, self.in_shape)]
        for i in range(self.layer_count):
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
            out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
            inputs = self.d * mul
            mul *= 2

        self.d_max = inputs
        self.last_size = out_sizes[-1][-1]
        self.num_linear = self.last_size ** 2 * self.d_max
        # Encoder linear layers
        self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
        self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)
        # Decoder linear layer
        self.decoder_linear = nn.Linear(self.latent_dim, self.num_linear)

        mul = inputs // self.d // 2

        for i in range(1, self.layer_count):
            setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            inputs = self.d * mul
            mul //= 2

        setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))

    def encode(self, x):
        if len(x.shape) < 3:
            x = x.unsqueeze(0)
        if len(x.shape) < 4:
            x = x.unsqueeze(1)
        batch_size = x.shape[0]

        for i in range(self.layer_count):
            x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))

        x = x.view(batch_size, -1)

        mean = self.encoder_mean_linear(x)
        logvar = self.encoder_logvar_linear(x)

        return mean, logvar

    def decode(self, x):
        x = x.view(x.shape[0], self.latent_dim)
        x = self.decoder_linear(x)
        x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
        #x = self.deconv1_bn(x)
        x = F.leaky_relu(x, 0.2)

        for i in range(1, self.layer_count):
            x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)
        x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
        x = torch.sigmoid(x)
        return x

    def forward(self, x):
        batch_size = x.shape[0]
        mean, logvar = self.encode(x)
        # Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I)
        eps = torch.randn(batch_size, self.latent_dim)
        z = mean + torch.exp(logvar / 2) * eps
        reconstructed = self.decode(z)
        return mean, logvar, reconstructed, x

def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
    dataset = load_dataset(train=True, digit=digit)
    # DataLoader is used to load the dataset
    # for training
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=32,
                                         shuffle=True)
    # Model Initialization
    model = VAE(latent_dim=latent_dim)
    # Loss: MSE reconstruction term plus a weighted KL term
    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
        # Closed-form KL divergence between N(mean, var) and the standard normal prior
        kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1), dim=0)
        recon = torch.nn.functional.mse_loss(reconstructed, original)
        # print(f"KL Error {kl}, Recon Error {recon}")
        return kl_beta * kl + recon

    # Using an Adam Optimizer with lr = 1e-4
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=0e-8)

    outputs = []
    losses = []
    for epoch in tqdm(range(epochs)):
        for (image, _) in loader:
            # Output of Autoencoder
            mean, log_var, reconstructed, image = model(image)
            # Calculating the loss function
            loss = loss_function(mean, log_var, reconstructed, image)
            # The gradients are set to zero,
            # then the gradient is computed and stored.
            # .step() performs parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Storing the losses in a list for plotting
            if torch.isnan(loss):
                raise Exception()
            losses.append(loss.detach().cpu())
        outputs.append((epochs, image, reconstructed))

    torch.save(model.state_dict(),
               os.path.join(
                   os.environ["PROJECT_ROOT"],
                   f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth"
               )
               )

    if plot:
        # Defining the Plot Style
        plt.style.use('fivethirtyeight')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')

        # Plotting the last 100 values
        plt.plot(losses)
        plt.show()

if __name__ == "__main__":
    train_model(latent_dim=2, digit=2, epochs=40)
313
examples/variational_autoencoder/variational_autoencoder.py
Normal file
@@ -0,0 +1,313 @@
"""Autoencoder Manim Visualizations

In this module I define Manim visualizations for Variational Autoencoders
and Traditional Autoencoders.

"""
from manim import *
import pickle
import numpy as np
import os
import manim_ml.neural_network as neural_network

class VariationalAutoencoder(VGroup):
    """Variational Autoencoder Manim Visualization"""

    def __init__(
        self, encoder_nodes_per_layer=[5, 3], decoder_nodes_per_layer=[3, 5], point_color=BLUE,
        dot_radius=0.05, ellipse_stroke_width=2.0, layer_spacing=0.5
    ):
        super(VGroup, self).__init__()
        self.encoder_nodes_per_layer = encoder_nodes_per_layer
        self.decoder_nodes_per_layer = decoder_nodes_per_layer
        self.point_color = point_color
        self.dot_radius = dot_radius
        self.layer_spacing = layer_spacing
        self.ellipse_stroke_width = ellipse_stroke_width
        # Make the VMobjects
        self.encoder, self.decoder = self._construct_encoder_and_decoder()
        self.embedding = self._construct_embedding()
        # Setup the relative locations
        self.embedding.move_to(self.encoder)
        self.embedding.shift([1.4 * self.encoder.width, 0, 0])
        self.decoder.move_to(self.embedding)
        self.decoder.shift([self.decoder.width * 1.4, 0, 0])
        # Add the objects to the VAE object
        self.add(self.encoder)
        self.add(self.decoder)
        self.add(self.embedding)

    def _construct_encoder_and_decoder(self):
        """Makes the VAE encoder and decoder"""
        # Make the encoder
        layer_node_count = self.encoder_nodes_per_layer
        encoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius, layer_spacing=self.layer_spacing)
        encoder.scale(1.2)
        # Make the decoder
        layer_node_count = self.decoder_nodes_per_layer
        decoder = neural_network.NeuralNetwork(layer_node_count, dot_radius=self.dot_radius, layer_spacing=self.layer_spacing)
        decoder.scale(1.2)

        return encoder, decoder

    def _construct_embedding(self):
        """Makes a Gaussian-like embedding"""
        embedding = VGroup()
        # Sample points from a Gaussian
        num_points = 200
        standard_deviation = [0.9, 0.9]
        mean = [0, 0]
        points = np.random.normal(mean, standard_deviation, size=(num_points, 2))
        # Make an axes
        embedding.axes = Axes(
            x_range=[-3, 3],
            y_range=[-3, 3],
            x_length=2.2,
            y_length=2.2,
            tips=False,
        )
        # Add each point to the axes
        self.point_dots = VGroup()
        for point in points:
            point_location = embedding.axes.coords_to_point(*point)
            dot = Dot(point_location, color=self.point_color, radius=self.dot_radius/2)
            self.point_dots.add(dot)

        embedding.add(self.point_dots)
        return embedding

    def _construct_image_mobject(self, input_image, height=2.3):
        """Constructs an ImageMobject from a numpy grayscale image"""
        # Convert image to rgb
        input_image = np.repeat(input_image, 3, axis=0)
        input_image = np.rollaxis(input_image, 0, start=3)
        # Make the ImageMobject
        image_mobject = ImageMobject(input_image, image_mode="RGB")
        image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
        image_mobject.height = height

        return image_mobject

    def _construct_input_output_images(self, image_pair):
        """Places the input and output images for the AE"""
        # Takes the image pair
        # image_pair is assumed to be [2, x, y]
        input_image = image_pair[0][None, :, :]
        recon_image = image_pair[1][None, :, :]
        # Make the image mobjects
        input_image_object = self._construct_image_mobject(input_image)
        recon_image_object = self._construct_image_mobject(recon_image)

        return input_image_object, recon_image_object

    def make_dot_convergence_animation(self, location, run_time=1.5):
        """Makes dots converge on a specific location"""
        # Move to location
        animations = []
        for dot in self.encoder.dots:
            coords = self.embedding.axes.coords_to_point(*location)
            animations.append(dot.animate.move_to(coords))
        move_animations = AnimationGroup(*animations, run_time=1.5)
        # Follow up with remove animations
        remove_animations = []
        for dot in self.encoder.dots:
            remove_animations.append(FadeOut(dot))
        remove_animations = AnimationGroup(*remove_animations, run_time=0.2)

        animation_group = Succession(move_animations, remove_animations, lag_ratio=1.0)

        return animation_group

    def make_dot_divergence_animation(self, location, run_time=3.0):
        """Makes dots diverge from the given location and move to the decoder"""
        animations = []
        for node in self.decoder.layers[0].node_group:
            new_dot = Dot(location, radius=self.dot_radius, color=RED)
            per_node_succession = Succession(
                Create(new_dot),
                new_dot.animate.move_to(node.get_center()),
            )
            animations.append(per_node_succession)

        animation_group = AnimationGroup(*animations)
        return animation_group

    def make_reset_vae_animation(self):
        """Resets the VAE to just the neural network"""
        animation_group = AnimationGroup(
            FadeOut(self.input_image),
            FadeOut(self.output_image),
            FadeOut(self.distribution_objects),
            run_time=1.0
        )

        return animation_group

    def make_forward_pass_animation(self, image_pair, run_time=1.5):
        """Overridden forward pass animation specific to a VAE"""
        per_unit_runtime = run_time
        # Setup images
        self.input_image, self.output_image = self._construct_input_output_images(image_pair)
        self.input_image.move_to(self.encoder.get_left())
        self.input_image.shift(LEFT)
        self.output_image.move_to(self.decoder.get_right())
        self.output_image.shift(RIGHT*1.3)
        # Make encoder forward pass
        encoder_forward_pass = self.encoder.make_forward_propagation_animation(run_time=per_unit_runtime)
        # Make red dot in embedding
        mean = [1.0, 1.5]
        mean_point = self.embedding.axes.coords_to_point(*mean)
        std = [0.8, 1.2]
        # Make the dot convergence animation
        dot_convergence_animation = self.make_dot_convergence_animation(mean, run_time=per_unit_runtime)
        encoding_succession = Succession(
            encoder_forward_pass,
            dot_convergence_animation
        )
        # Make an ellipse centered at mean_point with std outline
        center_dot = Dot(mean_point, radius=self.dot_radius, color=RED)
        ellipse = Ellipse(width=std[0], height=std[1], color=RED, fill_opacity=0.3, stroke_width=self.ellipse_stroke_width)
        ellipse.move_to(mean_point)
        self.distribution_objects = VGroup(
            center_dot,
            ellipse
        )
        # Make ellipse animation
        ellipse_animation = AnimationGroup(
            GrowFromCenter(center_dot),
            GrowFromCenter(ellipse),
        )
        # Make the dot divergence animation
        sampled_point = [0.51, 1.0]
        divergence_point = self.embedding.axes.coords_to_point(*sampled_point)
        dot_divergence_animation = self.make_dot_divergence_animation(divergence_point, run_time=per_unit_runtime)
        # Make decoder forward pass
        decoder_forward_pass = self.decoder.make_forward_propagation_animation(run_time=per_unit_runtime)
        # Add the animations to the group
        animation_group = AnimationGroup(
            FadeIn(self.input_image),
            encoding_succession,
            ellipse_animation,
            dot_divergence_animation,
            decoder_forward_pass,
            FadeIn(self.output_image),
            lag_ratio=1,
        )

        return animation_group

    def make_interpolation_animation(self, interpolation_images, frame_rate=5):
        """Makes an animation interpolation"""
        num_images = len(interpolation_images)
        # Make a made-up path
        interpolation_latent_path = np.linspace([-0.7, -1.2], [1.2, 1.5], num=num_images)
        # Make the path animation
        first_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[0])
        last_dot_location = self.embedding.axes.coords_to_point(*interpolation_latent_path[-1])
        moving_dot = Dot(first_dot_location, radius=self.dot_radius, color=RED)
        self.add(moving_dot)
        animation_list = [Create(Line(first_dot_location, last_dot_location, color=RED), run_time=0.1*num_images)]
        for image_index in range(num_images - 1):
            next_index = image_index + 1
            # Get path
            next_point = interpolation_latent_path[next_index]
            next_position = self.embedding.axes.coords_to_point(*next_point)
            # Draw path from current point to next point
            move_animation = moving_dot.animate(run_time=0.1*num_images).move_to(next_position)
            animation_list.append(move_animation)

        interpolation_animation = AnimationGroup(*animation_list)
        # Make the images animation
        animation_list = [Wait(0.5)]
        for numpy_image in interpolation_images:
            numpy_image = numpy_image[None, :, :]
            manim_image = self._construct_image_mobject(numpy_image)
            # Move the image to the correct location
            manim_image.move_to(self.output_image)
            # Add the image
            animation_list.append(FadeIn(manim_image, run_time=0.1))
            # Wait
            # animation_list.append(Wait(1 / frame_rate))
            # Remove the image
            # animation_list.append(FadeOut(manim_image, run_time=0.1))
        images_animation = AnimationGroup(*animation_list, lag_ratio=1.0)
        # Combine the two into an AnimationGroup
        animation_group = AnimationGroup(
            interpolation_animation,
            images_animation
        )

        return animation_group

class MNISTImageHandler():
    """Deals with loading serialized VAE MNIST images from "autoencoder_models" """

    def __init__(
        self,
        image_pairs_file_path=os.path.join(os.environ["PROJECT_ROOT"], "examples/variational_autoencoder/autoencoder_models/image_pairs.pkl"),
        interpolations_file_path=os.path.join(os.environ["PROJECT_ROOT"], "examples/variational_autoencoder/autoencoder_models/interpolations.pkl")
    ):
        self.image_pairs_file_path = image_pairs_file_path
        self.interpolations_file_path = interpolations_file_path

        self.image_pairs = []
        self.interpolation_images = []
        self.interpolation_latent_path = []

        self.load_serialized_data()

    def load_serialized_data(self):
        with open(self.image_pairs_file_path, "rb") as f:
            self.image_pairs = pickle.load(f)

        with open(self.interpolations_file_path, "rb") as f:
            self.interpolation_dict = pickle.load(f)
        self.interpolation_images = self.interpolation_dict["interpolation_images"]
        self.interpolation_latent_path = self.interpolation_dict["interpolation_path"]

"""
The VAE Scene for the twitter video.
"""
config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 5.0
config.frame_width = 5.0
# Set random seed so point distribution is constant
np.random.seed(1)

class VAEScene(Scene):
    """Scene object for a Variational Autoencoder and Autoencoder"""

    def construct(self):
        # Set Scene config
        vae = VariationalAutoencoder()
        mnist_image_handler = MNISTImageHandler()
        image_pair = mnist_image_handler.image_pairs[3]
        vae.move_to(ORIGIN)
        vae.scale(1.3)
        self.play(Create(vae), run_time=3)
        # Make a forward pass animation
        forward_pass_animation = vae.make_forward_pass_animation(image_pair)
        self.play(forward_pass_animation)
        # Remove the input and output images
        reset_animation = vae.make_reset_vae_animation()
        self.play(reset_animation)
        # Interpolation animation
        interpolation_images = mnist_image_handler.interpolation_images
        interpolation_animation = vae.make_interpolation_animation(interpolation_images)
        self.play(interpolation_animation)

class VAEImage(Scene):

    def construct(self):
        # Set Scene config
        vae = VariationalAutoencoder()
        mnist_image_handler = MNISTImageHandler()
        image_pair = mnist_image_handler.image_pairs[3]
        vae.move_to(ORIGIN)
        vae.scale(1.3)
        self.play(Create(vae), run_time=3)
        # Make a forward pass animation
        forward_pass_animation = vae.make_forward_pass_animation(image_pair)
        self.play(forward_pass_animation)
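A minimal usage sketch (not part of this commit) for rendering the scene above, assuming the Manim Community API, a PROJECT_ROOT environment variable pointing at the repository root, the pickled artifacts produced by the autoencoder_models scripts, and that it is run from the examples/variational_autoencoder directory; a CLI call along the lines of "manim -pql variational_autoencoder.py VAEScene" should be roughly equivalent:

import os

# PROJECT_ROOT must be set before the import, since MNISTImageHandler's default
# paths read it at class-definition time. The path here is hypothetical.
os.environ.setdefault("PROJECT_ROOT", "/path/to/ManimML")

from manim import tempconfig
from variational_autoencoder import VAEScene

# Render a quick low-quality preview of the VAE scene.
with tempconfig({"quality": "low_quality", "preview": True}):
    VAEScene().render()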