Mirror of https://github.com/helblazer811/ManimML.git
Changed the directory structure to accommodate examples, as opposed to everything being part of the core library. May need to rethink this in the future. Added some boilerplate for pip packaging to the .gitignore.
committed by Alec Helbling
parent 4eb5296c9c
commit 3be5c54d26
BIN examples/variational_autoencoder/autoencoder_models/.DS_Store (vendored, new file)
Binary files not shown.
@@ -0,0 +1,99 @@
import pickle
import sys
import os
sys.path.append(os.environ["PROJECT_ROOT"])
from autoencoder_models.variational_autoencoder import VAE, load_dataset, load_vae_from_path
import matplotlib.pyplot as plt
import numpy as np
import torch
import scipy
import scipy.stats
import cv2

def binned_images(model_path, num_x_bins=6, plot=False):
    latent_dim = 2
    model = load_vae_from_path(model_path, latent_dim)
    image_dataset = load_dataset(digit=2)
    # Compute embedding
    num_images = 500
    embedding = []
    images = []
    for i in range(num_images):
        image, _ = image_dataset[i]
        mean, _, recon, _ = model.forward(image)
        mean = mean.detach().numpy()
        recon = recon.detach().numpy()
        recon = recon.reshape(32, 32)
        images.append(recon.squeeze())
        if latent_dim > 2:
            mean = mean[:2]
        embedding.append(mean)
    images = np.stack(images)
    tsne_points = np.array(embedding)
    tsne_points = (tsne_points - tsne_points.mean(axis=0)) / (tsne_points.std(axis=0))
    # make vis
    num_points = np.shape(tsne_points)[0]
    x_min = np.amin(tsne_points.T[0])
    y_min = np.amin(tsne_points.T[1])
    y_max = np.amax(tsne_points.T[1])
    x_max = np.amax(tsne_points.T[0])
    # make the bins from the ranges
    # to keep it square the same width is used for x and y dim
    x_bins, step = np.linspace(x_min, x_max, num_x_bins, retstep=True)
    x_bins = x_bins.astype(float)
    num_y_bins = np.absolute(np.ceil((y_max - y_min) / step)).astype(int)
    y_bins = np.linspace(y_min, y_max, num_y_bins)
    # sort the tsne_points into a 2d histogram
    tsne_points = tsne_points.squeeze()
    hist_obj = scipy.stats.binned_statistic_dd(tsne_points, np.arange(num_points), statistic='count', bins=[x_bins, y_bins], expand_binnumbers=True)
    # sample one point from each bucket
    binnumbers = hist_obj.binnumber
    num_x_bins = np.amax(binnumbers[0]) + 1
    num_y_bins = np.amax(binnumbers[1]) + 1
    binnumbers = binnumbers.T
    # some places have no value in a region
    used_mask = np.zeros((num_y_bins, num_x_bins))
    image_bins = np.zeros((num_y_bins, num_x_bins, 3, np.shape(images)[2], np.shape(images)[2]))
    for i, bin_num in enumerate(list(binnumbers)):
        used_mask[bin_num[1], bin_num[0]] = 1
        image_bins[bin_num[1], bin_num[0]] = images[i]
    # plot a grid of the images
    fig, axs = plt.subplots(nrows=np.shape(y_bins)[0], ncols=np.shape(x_bins)[0], constrained_layout=False, dpi=50)
    images = []
    bin_indices = []
    for y in range(num_y_bins):
        for x in range(num_x_bins):
            if used_mask[y, x] > 0.0:
                image = np.uint8(image_bins[y][x].squeeze() * 255)
                image = np.rollaxis(image, 0, 3)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                axs[num_y_bins - 1 - y][x].imshow(image)
                images.append(image)
                bin_indices.append((y, x))
            axs[y, x].axis('off')
    if plot:
        plt.axis('off')
        plt.show()
    else:
        return images, bin_indices

def generate_disentanglement(model_path="saved_models/model_dim2.pth"):
    """Generates disentanglement visualization and serializes it"""
    # Disentanglement object
    disentanglement_object = {}
    # Make Disentanglement
    images, bin_indices = binned_images(model_path)
    disentanglement_object["images"] = images
    disentanglement_object["bin_indices"] = bin_indices
    # Serialize Images
    with open("disentanglement.pkl", "wb") as f:
        pickle.dump(disentanglement_object, f)

if __name__ == "__main__":
    plot = False
    if plot:
        model_path = "saved_models/model_dim2.pth"
        # uniform_image_sample(model_path)
        binned_images(model_path)
    else:
        generate_disentanglement()
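
For reference, a minimal sketch of how the serialized disentanglement.pkl could be consumed downstream. The file name and dictionary keys come from generate_disentanglement above; the loop itself is only an illustration, not part of this commit:

import pickle

# Load the object written by generate_disentanglement()
with open("disentanglement.pkl", "rb") as f:
    disentanglement = pickle.load(f)

# Each decoded image is paired with its (y, x) bin in the latent grid
for image, (y, x) in zip(disentanglement["images"], disentanglement["bin_indices"]):
    print(f"bin ({y}, {x}) -> image of shape {image.shape}")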
@@ -0,0 +1,41 @@
import torch
from variational_autoencoder import VAE
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle

# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset
dataset = datasets.MNIST(root="./data",
                         train=True,
                         download=True,
                         transform=tensor_transform)
# Generate reconstructions
num_recons = 10
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
image_pairs = []
for i in range(num_recons):
    base_image, _ = dataset[i]
    base_image = base_image.reshape(-1, 28*28)
    _, _, recon_image, _ = vae.forward(base_image)
    base_image = base_image.detach().numpy()
    base_image = np.reshape(base_image, (28, 28)) * 255
    recon_image = recon_image.detach().numpy()
    recon_image = np.reshape(recon_image, (28, 28)) * 255
    # Add to plot
    axs[i][0].imshow(base_image)
    axs[i][1].imshow(recon_image)
    # image pairs
    image_pairs.append((base_image, recon_image))

with open("image_pairs.pkl", "wb") as f:
    pickle.dump(image_pairs, f)

plt.show()
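
A small companion sketch (illustrative, not part of this commit) that reads image_pairs.pkl back and reports a per-pair reconstruction error. It assumes only the (base_image, recon_image) tuples serialized above, which are 28x28 arrays scaled to [0, 255]:

import pickle
import numpy as np

# Load the (base_image, recon_image) pairs written by the script above
with open("image_pairs.pkl", "rb") as f:
    image_pairs = pickle.load(f)

# Report a simple mean squared error for each reconstruction
for i, (base_image, recon_image) in enumerate(image_pairs):
    mse = np.mean((base_image - recon_image) ** 2)
    print(f"pair {i}: MSE = {mse:.2f}")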
@@ -0,0 +1,49 @@
import torch
from variational_autoencoder import VAE, load_dataset
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle

# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
dataset = load_dataset()
# Generate reconstructions
num_images = 50
image_pairs = []
save_object = {"interpolation_path": [], "interpolation_images": []}

# Make interpolation path
image_a, image_b = dataset[0][0], dataset[1][0]
image_a = image_a.view(32*32)
image_b = image_b.view(32*32)
z_a, _, _, _ = vae.forward(image_a)
z_a = z_a.detach().cpu().numpy()
z_b, _, _, _ = vae.forward(image_b)
z_b = z_b.detach().cpu().numpy()
interpolation_path = np.linspace(z_a, z_b, num=num_images)
# interpolation_path[:, 4] = np.linspace(-3, 3, num=num_images)
save_object["interpolation_path"] = interpolation_path

for i in range(num_images):
    # Generate
    z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
    gen_image = vae.decode(z).detach().numpy()
    gen_image = np.reshape(gen_image, (32, 32)) * 255
    save_object["interpolation_images"].append(gen_image)

fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
image_pairs = []
for i in range(num_images):
    recon_image = save_object["interpolation_images"][i]
    # Add to plot
    axs[i].imshow(recon_image)

# Serialize the interpolations
with open("interpolations.pkl", "wb") as f:
    pickle.dump(save_object, f)

plt.show()
Binary files not shown.
@@ -0,0 +1,219 @@
import os  # used below when saving the trained model
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

"""
These are utility functions that help to calculate the input and output
sizes of convolutional neural networks
"""

def num2tuple(num):
    return num if isinstance(num, tuple) else (num, num)

def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
    w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)

    return h, w

def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
    h_w, kernel_size, stride, pad, dilation, out_pad = num2tuple(h_w), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation), num2tuple(out_pad)
    pad = num2tuple(pad[0]), num2tuple(pad[1])

    h = (h_w[0] - 1)*stride[0] - sum(pad[0]) + dilation[0]*(kernel_size[0]-1) + out_pad[0] + 1
    w = (h_w[1] - 1)*stride[1] - sum(pad[1]) + dilation[1]*(kernel_size[1]-1) + out_pad[1] + 1

    return h, w

def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
    h_w_in, h_w_out, kernel_size, stride, dilation = num2tuple(h_w_in), num2tuple(h_w_out), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation)

    p_h = ((h_w_out[0] - 1)*stride[0] - h_w_in[0] + dilation[0]*(kernel_size[0]-1) + 1)
    p_w = ((h_w_out[1] - 1)*stride[1] - h_w_in[1] + dilation[1]*(kernel_size[1]-1) + 1)

    return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))

def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
    h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = num2tuple(h_w_in), num2tuple(h_w_out), \
        num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation), num2tuple(out_pad)

    p_h = -(h_w_out[0] - 1 - out_pad[0] - dilation[0]*(kernel_size[0]-1) - (h_w_in[0] - 1)*stride[0]) / 2
    p_w = -(h_w_out[1] - 1 - out_pad[1] - dilation[1]*(kernel_size[1]-1) - (h_w_in[1] - 1)*stride[1]) / 2

    return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
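
As a quick sanity check (illustrative, assuming the helper functions above are in scope), the kernel_size=4, stride=2, pad=1 configuration used by the VAE below halves the spatial size at each convolution and doubles it at each transposed convolution:

# 32x32 -> 16x16 -> 8x8 under the conv settings used by the VAE below
print(conv2d_output_shape((32, 32), kernel_size=4, stride=2, pad=1))      # (16, 16)
print(conv2d_output_shape((16, 16), kernel_size=4, stride=2, pad=1))      # (8, 8)
# The matching transposed convolution inverts the reduction: 8x8 -> 16x16
print(convtransp2d_output_shape((8, 8), kernel_size=4, stride=2, pad=1))  # (16, 16)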

def load_dataset(train=True, digit=None):
    # Transforms images to a PyTorch Tensor
    tensor_transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor()
    ])

    # Download the MNIST Dataset
    dataset = datasets.MNIST(root="./data",
                             train=train,
                             download=True,
                             transform=tensor_transform)
    # Keep only images of a specific digit
    if digit is not None:
        idx = dataset.targets == digit
        dataset.targets = dataset.targets[idx]
        dataset.data = dataset.data[idx]

    return dataset

def load_vae_from_path(path, latent_dim):
    model = VAE(latent_dim)
    model.load_state_dict(torch.load(path))

    return model

# Creating a PyTorch class
# 32*32 ==> latent_dim ==> 32*32
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=5, layer_count=4, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.in_shape = 32
        self.layer_count = layer_count
        self.channels = channels
        self.d = 128
        mul = 1
        inputs = self.channels
        out_sizes = [(self.in_shape, self.in_shape)]
        for i in range(self.layer_count):
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            h_w = (out_sizes[-1][-1], out_sizes[-1][-1])
            out_sizes.append(conv2d_output_shape(h_w, kernel_size=4, stride=2, pad=1, dilation=1))
            inputs = self.d * mul
            mul *= 2

        self.d_max = inputs
        self.last_size = out_sizes[-1][-1]
        self.num_linear = self.last_size ** 2 * self.d_max
        # Encoder linear layers
        self.encoder_mean_linear = nn.Linear(self.num_linear, self.latent_dim)
        self.encoder_logvar_linear = nn.Linear(self.num_linear, self.latent_dim)
        # Decoder linear layer
        self.decoder_linear = nn.Linear(self.latent_dim, self.num_linear)

        mul = inputs // self.d // 2

        for i in range(1, self.layer_count):
            setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, self.d * mul, 4, 2, 1))
            setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(self.d * mul))
            inputs = self.d * mul
            mul //= 2

        setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, self.channels, 4, 2, 1))

    def encode(self, x):
        if len(x.shape) < 3:
            x = x.unsqueeze(0)
        if len(x.shape) < 4:
            x = x.unsqueeze(1)
        batch_size = x.shape[0]

        for i in range(self.layer_count):
            x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))

        x = x.view(batch_size, -1)

        mean = self.encoder_mean_linear(x)
        logvar = self.encoder_logvar_linear(x)

        return mean, logvar

    def decode(self, x):
        x = x.view(x.shape[0], self.latent_dim)
        x = self.decoder_linear(x)
        x = x.view(x.shape[0], self.d_max, self.last_size, self.last_size)
        # x = self.deconv1_bn(x)
        x = F.leaky_relu(x, 0.2)

        for i in range(1, self.layer_count):
            x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)
        x = getattr(self, "deconv%d" % (self.layer_count + 1))(x)
        x = torch.sigmoid(x)
        return x

    def forward(self, x):
        batch_size = x.shape[0]
        mean, logvar = self.encode(x)
        # Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I)
        eps = torch.randn(batch_size, self.latent_dim)
        z = mean + torch.exp(logvar / 2) * eps
        reconstructed = self.decode(z)
        return mean, logvar, reconstructed, x

def train_model(latent_dim=16, plot=True, digit=1, epochs=200):
    dataset = load_dataset(train=True, digit=digit)
    # DataLoader is used to load the dataset
    # for training
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=32,
                                         shuffle=True)
    # Model Initialization
    model = VAE(latent_dim=latent_dim)
    # Loss combines a KL divergence term with an MSE reconstruction term
    def loss_function(mean, log_var, reconstructed, original, kl_beta=0.0001):
        kl = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1), dim=0)
        recon = torch.nn.functional.mse_loss(reconstructed, original)
        # print(f"KL Error {kl}, Recon Error {recon}")
        return kl_beta * kl + recon

    # Using an Adam Optimizer with lr = 1e-4
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=0e-8)

    outputs = []
    losses = []
    for epoch in tqdm(range(epochs)):
        for (image, _) in loader:
            # Output of Autoencoder
            mean, log_var, reconstructed, image = model(image)
            # Calculating the loss function
            loss = loss_function(mean, log_var, reconstructed, image)
            # The gradients are set to zero,
            # then the gradient is computed and stored.
            # .step() performs parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Storing the losses in a list for plotting
            if torch.isnan(loss):
                raise Exception()
            losses.append(loss.detach().cpu())
        outputs.append((epochs, image, reconstructed))

    torch.save(model.state_dict(),
               os.path.join(
                   os.environ["PROJECT_ROOT"],
                   f"examples/variational_autoencoder/autoencoder_model/saved_models/model_dim{latent_dim}.pth"
               ))

    if plot:
        # Defining the Plot Style
        plt.style.use('fivethirtyeight')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')

        # Plotting the loss values
        plt.plot(losses)
        plt.show()

if __name__ == "__main__":
    train_model(latent_dim=2, digit=2, epochs=40)
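
Finally, a minimal usage sketch tying the pieces together (illustrative only; it assumes a trained checkpoint at saved_models/model_dim2.pth and the autoencoder_models.variational_autoencoder import path used in the disentanglement script above):

import torch
from autoencoder_models.variational_autoencoder import VAE, load_dataset, load_vae_from_path

# Load a trained model with a 2-dimensional latent space and the padded 32x32 MNIST dataset
model = load_vae_from_path("saved_models/model_dim2.pth", latent_dim=2)
dataset = load_dataset(digit=2)

# Encode one image and reconstruct it (forward applies the reparameterization trick)
image, _ = dataset[0]
mean, logvar, reconstructed, _ = model.forward(image)
print(mean.shape, logvar.shape, reconstructed.shape)

# Decode a point sampled from the standard normal prior
z = torch.randn(1, 2)
generated = model.decode(z)
print(generated.shape)  # expected: (1, 1, 32, 32)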