Changed directory structure to accomodate examples as apposed to everything being a part of the core library. May need to rethink this in the future. Added some boilerplate for pip packaging to the .gitignore.

This commit is contained in:
Alec Helbling
2022-03-28 14:01:00 -04:00
committed by Alec Helbling
parent 4eb5296c9c
commit 3be5c54d26
40 changed files with 30 additions and 15 deletions

View File

@ -0,0 +1,49 @@
import torch
from variational_autoencoder import VAE, load_dataset
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle
# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
dataset = load_dataset()
# Generate reconstructions
num_images = 50
image_pairs = []
save_object = {"interpolation_path":[], "interpolation_images":[]}
# Make interpolation path
image_a, image_b = dataset[0][0], dataset[1][0]
image_a = image_a.view(32*32)
image_b = image_b.view(32*32)
z_a, _, _, _ = vae.forward(image_a)
z_a = z_a.detach().cpu().numpy()
z_b, _, _, _ = vae.forward(image_b)
z_b = z_b.detach().cpu().numpy()
interpolation_path = np.linspace(z_a, z_b, num=num_images)
# interpolation_path[:, 4] = np.linspace(-3, 3, num=num_images)
save_object["interpolation_path"] = interpolation_path
for i in range(num_images):
# Generate
z = torch.Tensor(interpolation_path[i]).unsqueeze(0)
gen_image = vae.decode(z).detach().numpy()
gen_image = np.reshape(gen_image, (32, 32)) * 255
save_object["interpolation_images"].append(gen_image)
fig, axs = plt.subplots(num_images, 1, figsize=(1, num_images))
image_pairs = []
for i in range(num_images):
recon_image = save_object["interpolation_images"][i]
# Add to plot
axs[i].imshow(recon_image)
# Perform intrpolations
with open("interpolations.pkl", "wb") as f:
pickle.dump(save_object, f)
plt.show()