Changed directory structure to accommodate examples as opposed to everything being a part of the core library. May need to rethink this in the future. Added some boilerplate for pip packaging to the .gitignore.

This commit is contained in:
Alec Helbling
2022-03-28 14:01:00 -04:00
committed by Alec Helbling
parent 4eb5296c9c
commit 3be5c54d26
40 changed files with 30 additions and 15 deletions

Binary file not shown.

View File

@@ -0,0 +1,99 @@
import pickle
import sys
import os
sys.path.append(os.environ["PROJECT_ROOT"])
from autoencoder_models.variational_autoencoder import VAE, load_dataset, load_vae_from_path
import matplotlib.pyplot as plt
import numpy as np
import torch
import scipy
import scipy.stats
import cv2
def binned_images(model_path, num_x_bins=6, plot=False):
    latent_dim = 2
    model = load_vae_from_path(model_path, latent_dim)
    image_dataset = load_dataset(digit=2)
    # Compute embedding
    num_images = 500
    embedding = []
    images = []
    for i in range(num_images):
        image, _ = image_dataset[i]
        mean, _, recon, _ = model.forward(image)
        mean = mean.detach().numpy()
        recon = recon.detach().numpy()
        recon = recon.reshape(32, 32)
        images.append(recon.squeeze())
        if latent_dim > 2:
            mean = mean[:2]
        embedding.append(mean)
    images = np.stack(images)
    tsne_points = np.array(embedding)
    tsne_points = (tsne_points - tsne_points.mean(axis=0))/(tsne_points.std(axis=0))
    # make vis
    num_points = np.shape(tsne_points)[0]
    x_min = np.amin(tsne_points.T[0])
    y_min = np.amin(tsne_points.T[1])
    y_max = np.amax(tsne_points.T[1])
    x_max = np.amax(tsne_points.T[0])
    # make the bins from the ranges
    # to keep it square the same width is used for x and y dim
    x_bins, step = np.linspace(x_min, x_max, num_x_bins, retstep=True)
    x_bins = x_bins.astype(float)
    num_y_bins = np.absolute(np.ceil((y_max - y_min)/step)).astype(int)
    y_bins = np.linspace(y_min, y_max, num_y_bins)
    # sort the tsne_points into a 2d histogram
    tsne_points = tsne_points.squeeze()
    hist_obj = scipy.stats.binned_statistic_dd(tsne_points, np.arange(num_points), statistic='count', bins=[x_bins, y_bins], expand_binnumbers=True)
    # sample one point from each bucket
    binnumbers = hist_obj.binnumber
    num_x_bins = np.amax(binnumbers[0]) + 1
    num_y_bins = np.amax(binnumbers[1]) + 1
    binnumbers = binnumbers.T
    # some places have no value in a region
    used_mask = np.zeros((num_y_bins, num_x_bins))
    image_bins = np.zeros((num_y_bins, num_x_bins, 3, np.shape(images)[2], np.shape(images)[2]))
    for i, bin_num in enumerate(list(binnumbers)):
        used_mask[bin_num[1], bin_num[0]] = 1
        image_bins[bin_num[1], bin_num[0]] = images[i]
    # plot a grid of the images
    fig, axs = plt.subplots(nrows=np.shape(y_bins)[0], ncols=np.shape(x_bins)[0], constrained_layout=False, dpi=50)
    images = []
    bin_indices = []
    for y in range(num_y_bins):
        for x in range(num_x_bins):
            if used_mask[y, x] > 0.0:
                image = np.uint8(image_bins[y][x].squeeze()*255)
                image = np.rollaxis(image, 0, 3)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                axs[num_y_bins - 1 - y][x].imshow(image)
                images.append(image)
                bin_indices.append((y, x))
            axs[y, x].axis('off')
    if plot:
        plt.axis('off')
        plt.show()
    else:
        return images, bin_indices
def generate_disentanglement(model_path="saved_models/model_dim2.pth"):
"""Generates disentanglement visualization and serializes it"""
# Disentanglement object
disentanglement_object = {}
# Make Disentanglement
images, bin_indices = binned_images(model_path)
disentanglement_object["images"] = images
disentanglement_object["bin_indices"] = bin_indices
# Serialize Images
with open("disentanglement.pkl", "wb") as f:
pickle.dump(disentanglement_object, f)
if __name__ == "__main__":
    plot = False
    if plot:
        model_path = "saved_models/model_dim2.pth"
        #uniform_image_sample(model_path)
        binned_images(model_path)
    else:
        generate_disentanglement()
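
# Example (minimal sketch): a consumer of the serialized file can reload it with
# pickle, assuming it runs from the directory where this script was executed, e.g.
#
#   with open("disentanglement.pkl", "rb") as f:
#       disentanglement = pickle.load(f)
#   images, bin_indices = disentanglement["images"], disentanglement["bin_indices"]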

View File

@@ -0,0 +1,41 @@
import torch
from variational_autoencoder import VAE
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle
# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset
dataset = datasets.MNIST(root="./data",
                         train=True,
                         download=True,
                         transform=tensor_transform)
# Generate reconstructions
num_recons = 10
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
image_pairs = []
for i in range(num_recons):
    base_image, _ = dataset[i]
    base_image = base_image.reshape(-1, 28*28)
    _, _, recon_image, _ = vae.forward(base_image)
    base_image = base_image.detach().numpy()
    base_image = np.reshape(base_image, (28, 28)) * 255
    recon_image = recon_image.detach().numpy()
    recon_image = np.reshape(recon_image, (28, 28)) * 255
    # Add to plot
    axs[i][0].imshow(base_image)
    axs[i][1].imshow(recon_image)
    # image pairs
    image_pairs.append((base_image, recon_image))

with open("image_pairs.pkl", "wb") as f:
    pickle.dump(image_pairs, f)
plt.show()

View File

@@ -0,0 +1,49 @@
import torch
from variational_autoencoder import VAE, load_dataset
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle
# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
dataset = load_dataset()
# Generate reconstructions
num_images = 50
image_pairs = []
save_object = {"interpolation_path":[], "interpolation_images":[]}
# Make interpolation path
image_a, image_b = dataset[0][0], dataset[1][0]
image_a = image_a.view(32*32)
image_b = image_b.view(32*32)
z_a, _, _, _ = vae.forward(image_a)
z_a = z_a.detach().cpu().numpy()
z_b, _, _, _ = vae.forward(image_b)
z_b = z_b.detach().cpu().numpy()
interpolation_path = np.linspace(z_a, z_b, num=num_images)
# interpolation_path[:, 4] = np.linspace(-3, 3, num=num_images)
save_object["interpolation_path"] = interpolation_path
for i in range(num_images):
    # Generate
    z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
    gen_image = vae.decode(z).detach().numpy()
    gen_image = np.reshape(gen_image, (32, 32)) * 255
    save_object["interpolation_images"].append(gen_image)
fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
image_pairs = []
for i in range(num_images):
    recon_image = save_object["interpolation_images"][i]
    # Add to plot
    axs[i].imshow(recon_image)

# Serialize the interpolations
with open("interpolations.pkl", "wb") as f:
    pickle.dump(save_object, f)
plt.show()

View File

@@ -0,0 +1,219 @@
import os
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
"""
These are utility functions that help to calculate the input and output
sizes of convolutional neural networks
"""
def num2tuple(num):
    return num if isinstance(num, tuple) else (num, num)

def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
    pad = num2tuple(pad[0]), num2tuple(pad[1])
    h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
    w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)
    return h, w

def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
    h_w, kernel_size, stride, pad, dilation, out_pad = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation), num2tuple(out_pad)
    pad = num2tuple(pad[0]), num2tuple(pad[1])
    h = (h_w[0] - 1)*stride[0] - sum(pad[0]) + dilation[0]*(kernel_size[0]-1) + out_pad[0] + 1
    w = (h_w[1] - 1)*stride[1] - sum(pad[1]) + dilation[1]*(kernel_size[1]-1) + out_pad[1] + 1
    return h, w
def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
    h_w_in, h_w_out, kernel_size, stride, dilation = num2tuple(h_w_in), num2tuple(h_w_out), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation)
    p_h = ((h_w_out[0] - 1)*stride[0] - h_w_in[0] + dilation[0]*(kernel_size[0]-1) + 1)
    p_w = ((h_w_out[1] - 1)*stride[1] - h_w_in[1] + dilation[1]*(kernel_size[1]-1) + 1)
    return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))

def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
    h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = num2tuple(h_w_in), num2tuple(h_w_out), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation), num2tuple(out_pad)
    p_h = -(h_w_out[0] - 1 - out_pad[0] - dilation[0]*(kernel_size[0]-1) - (h_w_in[0] - 1)*stride[0]) / 2
    p_w = -(h_w_out[1] - 1 - out_pad[1] - dilation[1]*(kernel_size[1]-1) - (h_w_in[1] - 1)*stride[1]) / 2
    return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
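
# For example, the VAE below uses 4x4 kernels with stride 2 and padding 1, so each
# conv layer halves the spatial size and each transposed conv layer doubles it:
#
#   conv2d_output_shape((32, 32), kernel_size=4, stride=2, pad=1)        # (16, 16)
#   convtransp2d_output_shape((16, 16), kernel_size=4, stride=2, pad=1)  # (32, 32)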
def load_dataset(train=True, digit=None):
    # Transforms images to a PyTorch Tensor
    tensor_transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor()
    ])
    # Download the MNIST Dataset
    dataset = datasets.MNIST(root="./data",
                             train=train,
                             download=True,
                             transform=tensor_transform)
    # Keep only images of a specific digit
    if digit is not None:
        idx = dataset.targets == digit
        dataset.targets = dataset.targets[idx]
        dataset.data = dataset.data[idx]
    return dataset

def load_vae_from_path(path, latent_dim):
    model = VAE(latent_dim)
    model.load_state_dict(torch.load(path))
    return model
# Convolutional VAE as a PyTorch module
# 32*32 image ==> latent vector of size latent_dim ==> 32*32 reconstruction
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=5, layer_count=4, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.in_shape = 32
        self.layer_count = layer_count
        self.channels = channels
        self.d = 128
        mul = 1
        inputs = self.channels
        out_sizes = [(self.in_shape, self.in_shape)]
        for i in range(self.layer_count):
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
            out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
            inputs = self.d * mul
            mul *= 2
        self.d_max = inputs
        self.last_size = out_sizes[-1][-1]
        self.num_linear = self.last_size ** 2 * self.d_max
        # Encoder linear layers
        self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
        self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)
        # Decoder linear layer
        self.decoder_linear = nn.Linear(self.latent_dim, self.num_linear)
        mul = inputs // self.d // 2
        for i in range(1, self.layer_count):
            setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            inputs = self.d * mul
            mul //= 2
        setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))
    def encode(self, x):
        if len(x.shape) < 3:
            x = x.unsqueeze(0)
        if len(x.shape) < 4:
            x = x.unsqueeze(1)
        batch_size = x.shape[0]
        for i in range(self.layer_count):
            x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))
        x = x.view(batch_size, -1)
        mean = self.encoder_mean_linear(x)
        logvar = self.encoder_logvar_linear(x)
        return mean, logvar
    def decode(self, x):
        x = x.view(x.shape[0], self.latent_dim)
        x = self.decoder_linear(x)
        x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
        #x = self.deconv1_bn(x)
        x = F.leaky_relu(x, 0.2)
        for i in range(1, self.layer_count):
            x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)
        x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
        x = torch.sigmoid(x)
        return x
    def forward(self, x):
        batch_size = x.shape[0]
        mean, logvar = self.encode(x)
        # Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I)
        eps = torch.randn(batch_size, self.latent_dim)
        z = mean + torch.exp(logvar / 2) * eps
        reconstructed = self.decode(z)
        return mean, logvar, reconstructed, x
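
# Example (minimal sketch): a padded 32x32 MNIST image from load_dataset() can be
# pushed through the model end to end, e.g.
#
#   image, _ = load_dataset(digit=2)[0]               # tensor of shape (1, 32, 32)
#   mean, logvar, recon, _ = VAE(latent_dim=2)(image)
#   recon = recon.reshape(32, 32)                     # reconstructed image tensor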
def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
    dataset = load_dataset(train=True, digit=digit)
    # DataLoader is used to load the dataset
    # for training
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=32,
                                         shuffle=True)
    # Model Initialization
    model = VAE(latent_dim=latent_dim)

    # Training loss: reconstruction (MSE) term plus a KL divergence penalty
    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
        kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1), dim=0)
        recon = torch.nn.functional.mse_loss(reconstructed, original)
        # print(f"KL Error {kl}, Recon Error {recon}")
        return kl_beta * kl + recon
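
    # With q(z|x) = N(mean, exp(log_var)) and p(z) = N(0, I), the kl term above is the
    # closed-form KL(q || p) = -0.5 * sum(1 + log_var - mean**2 - exp(log_var)),
    # averaged over the batch and down-weighted by kl_beta relative to the MSE term.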
    # Using an Adam optimizer with lr = 1e-4
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=0e-8)
    outputs = []
    losses = []
    for epoch in tqdm(range(epochs)):
        for (image, _) in loader:
            # Output of Autoencoder
            mean, log_var, reconstructed, image = model(image)
            # Calculating the loss function
            loss = loss_function(mean, log_var, reconstructed, image)
            # The gradients are set to zero,
            # then the gradient is computed and stored.
            # .step() performs parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Storing the losses in a list for plotting
            if torch.isnan(loss):
                raise Exception("Loss became NaN during training")
            losses.append(loss.detach().cpu())
        outputs.append((epoch, image, reconstructed))
    torch.save(model.state_dict(),
               os.path.join(
                   os.environ["PROJECT_ROOT"],
                   f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth"
               ))
    if plot:
        # Defining the Plot Style
        plt.style.use('fivethirtyeight')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        # Plot the loss values
        plt.plot(losses)
        plt.show()

if __name__ == "__main__":
    train_model(latent_dim=2, digit=2, epochs=40)
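
# Example (minimal sketch): once a checkpoint exists, new digits can be sampled by
# decoding random latent vectors, assuming the relative path resolves to the file
# written by train_model above, e.g.
#
#   model = load_vae_from_path("saved_models/model_dim2.pth", latent_dim=2)
#   z = torch.randn(1, 2)
#   sample = model.decode(z).detach().numpy().reshape(32, 32)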