Finished VAE Disentanglement video

Alec Helbling
2022-03-07 02:06:17 -05:00
committed by Alec Helbling
parent a7c43d0f8b
commit ca0e93d657
17 changed files with 389 additions and 91 deletions

@@ -1,3 +1,6 @@
setup:
	conda activate manim
	export PROJECT_ROOT=$(pwd)
video:
	manim -pqh src/vae.py VAEScene --media_dir media
	cp media/videos/vae/720p60/VAEScene.mp4 examples
@@ -5,6 +8,11 @@ train:
	cd src/autoencoder_models
	python vanilla_autoencoder.py
	python variational_autoencoder.py
generate_visualizations:
	cd src/autoencoder_models
	python generate_images.py
	python generate_interpolation.py
	python generate_disentanglement.py
checkstyle:
	pycodestyle src
	pydocstyle src

Binary file not shown.


Binary file not shown.

Binary file not shown.


BIN
examples/VAEScene.gif Normal file

Binary file not shown.


BIN
src/autoencoder_models/.DS_Store vendored Normal file

Binary file not shown.

Binary file not shown.

@@ -0,0 +1,99 @@
import pickle
import sys
import os
sys.path.append(os.environ["PROJECT_ROOT"])
from autoencoder_models.variational_autoencoder import VAE, load_dataset, load_vae_from_path
import matplotlib.pyplot as plt
import numpy as np
import torch
import scipy
import scipy.stats
import cv2

def binned_images(model_path, num_x_bins=10, plot=False):
    latent_dim = 2
    model = load_vae_from_path(model_path, latent_dim)
    image_dataset = load_dataset(digit=2)
    # Compute the latent embedding of each image
    num_images = 500
    embedding = []
    images = []
    for i in range(num_images):
        image, _ = image_dataset[i]
        mean, _, recon, _ = model.forward(image)
        mean = mean.detach().numpy()
        recon = recon.detach().numpy()
        recon = recon.reshape(32, 32)
        images.append(recon.squeeze())
        if latent_dim > 2:
            mean = mean[:2]
        embedding.append(mean)
    images = np.stack(images)
    tsne_points = np.array(embedding)
    tsne_points = (tsne_points - tsne_points.mean(axis=0)) / tsne_points.std(axis=0)
    # Make the visualization
    num_points = np.shape(tsne_points)[0]
    x_min = np.amin(tsne_points.T[0])
    y_min = np.amin(tsne_points.T[1])
    y_max = np.amax(tsne_points.T[1])
    x_max = np.amax(tsne_points.T[0])
    # Make the bins from the ranges;
    # to keep it square the same width is used for the x and y dims
    x_bins, step = np.linspace(x_min, x_max, num_x_bins, retstep=True)
    x_bins = x_bins.astype(float)
    num_y_bins = np.absolute(np.ceil((y_max - y_min) / step)).astype(int)
    y_bins = np.linspace(y_min, y_max, num_y_bins)
    # Sort the points into a 2D histogram
    tsne_points = tsne_points.squeeze()
    hist_obj = scipy.stats.binned_statistic_dd(
        tsne_points, np.arange(num_points), statistic='count',
        bins=[x_bins, y_bins], expand_binnumbers=True)
    # Sample one point from each bucket
    binnumbers = hist_obj.binnumber
    num_x_bins = np.amax(binnumbers[0]) + 1
    num_y_bins = np.amax(binnumbers[1]) + 1
    binnumbers = binnumbers.T
    # Some regions of the grid receive no points
    used_mask = np.zeros((num_y_bins, num_x_bins))
    image_bins = np.zeros((num_y_bins, num_x_bins, 3, np.shape(images)[2], np.shape(images)[2]))
    for i, bin_num in enumerate(list(binnumbers)):
        used_mask[bin_num[1], bin_num[0]] = 1
        # Grayscale (32, 32) images broadcast across the three channels
        image_bins[bin_num[1], bin_num[0]] = images[i]
    # Plot a grid of the images
    fig, axs = plt.subplots(nrows=np.shape(y_bins)[0], ncols=np.shape(x_bins)[0],
                            constrained_layout=False, dpi=50)
    images = []
    bin_indices = []
    for y in range(num_y_bins):
        for x in range(num_x_bins):
            if used_mask[y, x] > 0.0:
                image = np.uint8(image_bins[y][x].squeeze() * 255)
                # Channels-first (3, 32, 32) -> channels-last (32, 32, 3)
                image = np.rollaxis(image, 0, 3)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                axs[num_y_bins - 1 - y][x].imshow(image)
                images.append(image)
                bin_indices.append((y, x))
                axs[num_y_bins - 1 - y][x].axis('off')
    if plot:
        plt.axis('off')
        plt.show()
    else:
        return images, bin_indices

def generate_disentanglement(model_path="saved_models/model_dim2.pth"):
    """Generates the disentanglement visualization and serializes it"""
    # Disentanglement object
    disentanglement_object = {}
    # Make the disentanglement grid
    images, bin_indices = binned_images(model_path)
    disentanglement_object["images"] = images
    disentanglement_object["bin_indices"] = bin_indices
    # Serialize the images
    with open("disentanglement.pkl", "wb") as f:
        pickle.dump(disentanglement_object, f)

if __name__ == "__main__":
    plot = False
    if plot:
        model_path = "saved_models/model_dim2.pth"
        # uniform_image_sample(model_path)
        binned_images(model_path, plot=True)
    else:
        generate_disentanglement()
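
For reference, a minimal sketch of reading the serialized object back, assuming generate_disentanglement() has already been run from src/autoencoder_models so disentanglement.pkl sits in the working directory:

import pickle

with open("disentanglement.pkl", "rb") as f:
    disentanglement_object = pickle.load(f)
# Each image is a (32, 32, 3) uint8 array; each bin index is a (row, col) grid cell
images = disentanglement_object["images"]
bin_indices = disentanglement_object["bin_indices"]
print(len(images), images[0].shape, bin_indices[0])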

@@ -1,5 +1,5 @@
import torch
from variational_autoencoder import VAE
from variational_autoencoder import VAE, load_dataset
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
@@ -10,13 +10,7 @@ import pickle
# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset
dataset = datasets.MNIST(root="./data",
                         train=True,
                         download=True,
                         transform=tensor_transform)
dataset = load_dataset()
# Generate reconstructions
num_images = 50
image_pairs = []
@@ -24,8 +18,8 @@ save_object = {"interpolation_path": [], "interpolation_images": []}
# Make interpolation path
image_a, image_b = dataset[0][0], dataset[1][0]
image_a = image_a.view(28*28)
image_b = image_b.view(28*28)
image_a = image_a.view(32*32)
image_b = image_b.view(32*32)
z_a, _, _, _ = vae.forward(image_a)
z_a = z_a.detach().cpu().numpy()
z_b, _, _, _ = vae.forward(image_b)
@@ -38,7 +32,7 @@ for i in range(num_images):
    # Generate
    z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
    gen_image = vae.decode(z).detach().numpy()
    gen_image = np.reshape(gen_image, (28, 28)) * 255
    gen_image = np.reshape(gen_image, (32, 32)) * 255
    save_object["interpolation_images"].append(gen_image)
fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
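
The construction of interpolation_path itself falls outside this hunk; a plausible sketch of that step, assuming a straight line between the two latent codes (the np.linspace form is an assumption, the names match the script):

import numpy as np

# Assumed: num_images evenly spaced latent codes on the segment from z_a to z_b.
num_images = 50
z_a = np.zeros(16)  # stand-in; the script uses the encoding of dataset[0]
z_b = np.ones(16)   # stand-in; the script uses the encoding of dataset[1]
interpolation_path = np.linspace(z_a, z_b, num_images)
print(interpolation_path.shape)  # (50, 16)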

Binary file not shown.

Binary file not shown.

Binary file not shown.

@@ -1,81 +1,173 @@
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

"""
These are utility functions that help to calculate the input and output
sizes of convolutional neural networks
"""

def num2tuple(num):
    return num if isinstance(num, tuple) else (num, num)

def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
    pad = num2tuple(pad[0]), num2tuple(pad[1])
    h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
    w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)
    return h, w

def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
    h_w, kernel_size, stride, pad, dilation, out_pad = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation), num2tuple(out_pad)
    pad = num2tuple(pad[0]), num2tuple(pad[1])
    h = (h_w[0] - 1)*stride[0] - sum(pad[0]) + dilation[0]*(kernel_size[0]-1) + out_pad[0] + 1
    w = (h_w[1] - 1)*stride[1] - sum(pad[1]) + dilation[1]*(kernel_size[1]-1) + out_pad[1] + 1
    return h, w

def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
    h_w_in, h_w_out, kernel_size, stride, dilation = num2tuple(h_w_in), num2tuple(h_w_out), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation)
    p_h = ((h_w_out[0] - 1)*stride[0] - h_w_in[0] + dilation[0]*(kernel_size[0]-1) + 1)
    p_w = ((h_w_out[1] - 1)*stride[1] - h_w_in[1] + dilation[1]*(kernel_size[1]-1) + 1)
    return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))

def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
    h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = num2tuple(h_w_in), num2tuple(h_w_out), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation), num2tuple(out_pad)
    p_h = -(h_w_out[0] - 1 - out_pad[0] - dilation[0]*(kernel_size[0]-1) - (h_w_in[0] - 1)*stride[0]) / 2
    p_w = -(h_w_out[1] - 1 - out_pad[1] - dilation[1]*(kernel_size[1]-1) - (h_w_in[1] - 1)*stride[1]) / 2
    return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))

def load_dataset(train=True, digit=None):
    # Transforms images to a PyTorch Tensor, padding MNIST from 28x28 to 32x32
    tensor_transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor()
    ])
    # Download the MNIST Dataset
    dataset = datasets.MNIST(root="./data",
                             train=train,
                             download=True,
                             transform=tensor_transform)
    # Keep only the images of a specific digit
    if digit is not None:
        idx = dataset.targets == digit
        dataset.targets = dataset.targets[idx]
        dataset.data = dataset.data[idx]
    return dataset

def load_vae_from_path(path, latent_dim):
    model = VAE(latent_dim)
    model.load_state_dict(torch.load(path))
    return model

# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset
dataset = datasets.MNIST(root="./data",
                         train=True,
                         download=True,
                         transform=tensor_transform)
# DataLoader is used to load the dataset
# for training
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=32,
                                     shuffle=True)

# Creating a PyTorch class
# 28*28 ==> 9 ==> 28*28
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=5):
    def __init__(self, latent_dim=5, layer_count=4, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        # Building a linear encoder with Linear
        # layers followed by ReLU activations
        # 784 ==> 9
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
        )
        self.mean_embedding = torch.nn.Linear(18, self.latent_dim)
        self.logvar_embedding = torch.nn.Linear(18, self.latent_dim)
        self.in_shape = 32
        self.layer_count = layer_count
        self.channels = channels
        self.d = 128
        mul = 1
        inputs = self.channels
        out_sizes = [(self.in_shape, self.in_shape)]
        for i in range(self.layer_count):
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
            out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
            inputs = self.d * mul
            mul *= 2
        # Building a linear decoder with Linear
        # layers followed by ReLU activations;
        # the final Sigmoid activation
        # outputs values between 0 and 1
        # 9 ==> 784
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28 * 28),
            torch.nn.Sigmoid()
        )
        self.d_max = inputs
        self.last_size = out_sizes[-1][-1]
        self.num_linear = self.last_size ** 2 * self.d_max
        # Encoder linear layers
        self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
        self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)
        # Decoder linear layer
        self.decoder_linear = nn.Linear(self.latent_dim, self.num_linear)

    def decode(self, z):
        return self.decoder(z)
        mul = inputs // self.d // 2
        for i in range(1, self.layer_count):
            setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            inputs = self.d * mul
            mul //= 2
        setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))

    def encode(self, x):
        # Accept single images as well as batches
        if len(x.shape) < 3:
            x = x.unsqueeze(0)
        if len(x.shape) < 4:
            x = x.unsqueeze(1)
        batch_size = x.shape[0]
        for i in range(self.layer_count):
            x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))
        x = x.view(batch_size, -1)
        mean = self.encoder_mean_linear(x)
        logvar = self.encoder_logvar_linear(x)
        return mean, logvar

    def decode(self, x):
        x = x.view(x.shape[0], self.latent_dim)
        x = self.decoder_linear(x)
        x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
        # x = self.deconv1_bn(x)
        x = F.leaky_relu(x, 0.2)
        for i in range(1, self.layer_count):
            x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)
        x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
        x = torch.sigmoid(x)
        return x

    def forward(self, x):
        encoded = self.encoder(x)
        mean = self.mean_embedding(encoded)
        logvar = self.logvar_embedding(encoded)
        batch_size = x.shape[0]
        mean, logvar = self.encode(x)
        # Reparameterization trick: z = mean + std * eps with eps ~ N(0, I)
        eps = torch.randn(batch_size, self.latent_dim)
        z = mean + torch.exp(logvar / 2) * eps
        reconstructed = self.decoder(z)
        reconstructed = self.decode(z)
        return mean, logvar, reconstructed, x

def train_model():
def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
    dataset = load_dataset(train=True, digit=digit)
    # DataLoader is used to load the dataset
    # for training
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=32,
                                         shuffle=True)
    # Model Initialization
    model = VAE(latent_dim=16)
    model = VAE(latent_dim=latent_dim)

    # Validation using MSE Loss function
    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.001):
    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
        # Closed-form KL divergence between N(mean, var) and the standard normal
        kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1), dim=0)
        recon = torch.nn.functional.mse_loss(reconstructed, original)
        # print(f"KL Error {kl}, Recon Error {recon}")
@@ -83,16 +175,13 @@ def train_model():
    # Using an Adam optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-3,
                                 weight_decay=1e-8)
                                 lr=1e-4,
                                 weight_decay=0e-8)
    epochs = 100
    outputs = []
    losses = []
    for epoch in tqdm(range(epochs)):
        for (image, _) in loader:
            # Reshaping the image to (-1, 784)
            image = image.reshape(-1, 28*28)
            # Output of Autoencoder
            mean, log_var, reconstructed, image = model(image)
            # Calculating the loss function
@@ -109,16 +198,17 @@ def train_model():
            losses.append(loss.detach().cpu())
    outputs.append((epochs, image, reconstructed))
    torch.save(model.state_dict(), "saved_models/model.pth")
    torch.save(model.state_dict(), f"saved_models/model_dim{latent_dim}.pth")
    # Defining the Plot Style
    plt.style.use('fivethirtyeight')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    # Plotting the last 100 values
    plt.plot(losses)
    plt.show()
    if plot:
        # Defining the Plot Style
        plt.style.use('fivethirtyeight')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        # Plotting the last 100 values
        plt.plot(losses)
        plt.show()

if __name__ == "__main__":
    train_model()
    train_model(latent_dim=2, digit=1, epochs=40)
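
To make the shape bookkeeping above concrete, a standalone sketch tracing the encoder's feature maps with conv2d_output_shape (kernel 4, stride 2, padding 1, as in the nn.Conv2d calls): a 32x32 input halves at each of the four layers, ending at 2x2 with d_max = 1024 channels, so num_linear = 2 * 2 * 1024 = 4096.

import math

def num2tuple(num):
    return num if isinstance(num, tuple) else (num, num)

def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    # Same formula as above: floor((h + 2*pad - dilation*(k-1) - 1) / stride + 1)
    h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
    pad = num2tuple(pad[0]), num2tuple(pad[1])
    h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
    w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)
    return h, w

size, d, mul = (32, 32), 128, 1
for i in range(4):  # layer_count = 4
    size = conv2d_output_shape(size, kernel_size=4, stride=2, pad=1)
    print("conv%d -> %s, %d channels" % (i + 1, size, d * mul))
    mul *= 2
# conv1 -> (16, 16), conv2 -> (8, 8), conv3 -> (4, 4), conv4 -> (2, 2) with 1024 channels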

93
src/disentanglement.py Normal file

@@ -0,0 +1,93 @@
"""This module is dedicated to visualizing VAE disentanglement"""
from manim import *
from neural_network import NeuralNetwork
import util
import pickle

class VAEDecoder(VGroup):
    """Just shows the VAE decoder"""

    def __init__(self):
        super().__init__()
        # Setup the Neural Network
        node_counts = [3, 5]
        self.neural_network = NeuralNetwork(node_counts, layer_spacing=0.55)
        self.add(self.neural_network)

    def make_encoding_animation(self):
        pass

class DisentanglementVisualization(VGroup):

    def __init__(self, model_path="autoencoder_models/saved_models/model_dim2.pth", image_height=0.2):
        super().__init__()
        self.model_path = model_path
        self.image_height = image_height
        # Load disentanglement image objects
        with open("autoencoder_models/disentanglement.pkl", "rb") as f:
            self.image_handler = pickle.load(f)

    def make_disentanglement_generation_animation(self):
        animation_list = []
        for image_index, image in enumerate(self.image_handler["images"]):
            image_mobject = util.construct_image_mobject(image, height=self.image_height)
            r, c = self.image_handler["bin_indices"][image_index]
            # Move the image to the correct location
            r_offset = -1.2
            c_offset = 0.2
            image_location = [c_offset + c*self.image_height, r_offset + r*self.image_height, 0]
            image_mobject.move_to(image_location)
            animation_list.append(FadeIn(image_mobject))
        generation_animation = AnimationGroup(*animation_list[::-1], lag_ratio=1.0)
        return generation_animation

config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 5.0
config.frame_width = 5.0

class DisentanglementScene(Scene):
    """Disentanglement Scene Object"""

    def _construct_embedding(self, point_color=BLUE, dot_radius=0.05):
        """Makes a Gaussian-like embedding"""
        embedding = VGroup()
        # Sample points from a Gaussian
        num_points = 200
        standard_deviation = [0.6, 1.0]
        mean = [0, 0]
        points = np.random.normal(mean, standard_deviation, size=(num_points, 2))
        # Make an axes
        embedding.axes = Axes(
            x_range=[-3, 3],
            y_range=[-3, 3],
            x_length=2.2,
            y_length=2.2,
            tips=False,
        )
        # Add each point to the axes
        self.point_dots = VGroup()
        for point in points:
            point_location = embedding.axes.coords_to_point(*point)
            dot = Dot(point_location, color=point_color, radius=dot_radius/2)
            self.point_dots.add(dot)
        embedding.add(self.point_dots)
        return embedding

    def construct(self):
        # Make the VAE decoder
        vae_decoder = VAEDecoder()
        vae_decoder.shift([-0.65, 0, 0])
        self.play(Create(vae_decoder), run_time=1)
        # Make the embedding
        embedding = self._construct_embedding()
        embedding.scale(0.8)
        embedding.move_to(vae_decoder.get_left())
        embedding.shift([-0.7, 0, 0])
        self.play(Create(embedding))
        # Make the disentanglement visualization
        disentanglement = DisentanglementVisualization()
        disentanglement_animation = disentanglement.make_disentanglement_generation_animation()
        self.play(disentanglement_animation, run_time=3)
        self.play(Wait(2))

@@ -1,11 +1,11 @@
"""Neural Network Manim Visualization
This module is responsible for generating a neural network visualization with
manim, specifically a fully connected neural network diagram.
manim, specifically a fully connected neural network diagram.
Example:
    # Specify how many nodes are in each node layer
    layer_node_count = [5, 3, 5]
    layer_node_count = [5, 3, 5]
    # Create the object with default style settings
    NeuralNetwork(layer_node_count)
"""
@@ -15,7 +15,7 @@ class NeuralNetworkLayer(VGroup):
    """Handles rendering a layer for a neural network"""
    def __init__(
            self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08,
            self, num_nodes, layer_buffer=SMALL_BUFF/2, node_radius=0.08,
            node_color=BLUE, node_outline_color=WHITE, rectangle_color=WHITE,
            node_spacing=0.3, rectangle_fill_color=BLACK, node_stroke_width=2.0,
            rectangle_stroke_width=2.0):
@@ -48,7 +48,7 @@ class NeuralNetworkLayer(VGroup):
            node_object.move_to([0, location, 0])
        # Create Surrounding Rectangle
        surrounding_rectangle = SurroundingRectangle(
            self.node_group, color=self.rectangle_color, fill_color=self.rectangle_fill_color,
            self.node_group, color=self.rectangle_color, fill_color=self.rectangle_fill_color,
            fill_opacity=1.0, buff=self.layer_buffer, stroke_width=self.rectangle_stroke_width
        )
        # Add the objects to the class
@@ -57,7 +57,7 @@ class NeuralNetworkLayer(VGroup):
class NeuralNetwork(VGroup):
    def __init__(
            self, layer_node_count, layer_width=0.6, node_radius=1.0,
            self, layer_node_count, layer_width=0.6, node_radius=1.0,
            node_color=BLUE, edge_color=WHITE, layer_spacing=0.8,
            animation_dot_color=RED, edge_width=2.0, dot_radius=0.05):
        super(VGroup, self).__init__()
@@ -73,7 +73,6 @@ class NeuralNetwork(VGroup):
        # TODO take layer_node_count [0, (1, 2), 0]
        # and make it have explicit distinct subspaces
        self.layers = self._construct_layers()
        self.edge_layers = self._construct_edges()
@@ -99,7 +98,7 @@ class NeuralNetwork(VGroup):
        edge_layers = VGroup()
        for layer_index in range(len(self.layer_node_count) - 1):
            current_layer = self.layers[layer_index]
            next_layer = self.layers[layer_index + 1]
            next_layer = self.layers[layer_index + 1]
            edge_layer = VGroup()
            # Go through each node in the two layers and make a connecting line
            for node_i in current_layer.node_group:
@@ -134,8 +133,8 @@ class NeuralNetwork(VGroup):
        return animation_group

config.pixel_height = 720
config.pixel_width = 1280
config.pixel_height = 720
config.pixel_width = 1280
config.frame_height = 6.0
config.frame_width = 6.0

15
src/util.py Normal file

@@ -0,0 +1,15 @@
from manim import *
import numpy as np

def construct_image_mobject(input_image, height=2.3):
    """Constructs an ImageMobject from a numpy grayscale image"""
    # Convert a grayscale (H, W) image to channels-last RGB (H, W, 3)
    if len(input_image.shape) == 2:
        input_image = np.repeat(input_image[None, ...], 3, axis=0)
        input_image = np.rollaxis(input_image, 0, start=3)
    # Make the ImageMobject
    image_mobject = ImageMobject(input_image, image_mode="RGB")
    image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
    image_mobject.height = height
    return image_mobject
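
A usage sketch for the helper above (the random array is a stand-in for a decoded VAE sample; Scene and FadeIn come from manim itself):

from manim import *
import numpy as np
from util import construct_image_mobject

class ImageDemoScene(Scene):
    def construct(self):
        # Stand-in for a decoded (32, 32) grayscale VAE output scaled to [0, 255]
        gray = np.uint8(np.random.rand(32, 32) * 255)
        image_mobject = construct_image_mobject(gray, height=2.3)
        self.play(FadeIn(image_mobject))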