Changed directory structure to accomodate examples as apposed to everything being a part of the core library. May need to rethink this in the future. Added some boilerplate for pip packaging to the .gitignore.

This commit is contained in:
Alec Helbling
2022-03-28 14:01:00 -04:00
committed by Alec Helbling
parent 4eb5296c9c
commit 3be5c54d26
40 changed files with 30 additions and 15 deletions

View File

@ -0,0 +1,41 @@
import torch
from variational_autoencoder import VAE
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle
# Load model
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth"))
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset
dataset = datasets.MNIST(root = "./data",
train = True,
download = True,
transform = tensor_transform)
# Generate reconstructions
num_recons = 10
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
image_pairs = []
for i in range(num_recons):
base_image, _ = dataset[i]
base_image = base_image.reshape(-1, 28*28)
_, _, recon_image, _ = vae.forward(base_image)
base_image = base_image.detach().numpy()
base_image = np.reshape(base_image, (28, 28)) * 255
recon_image = recon_image.detach().numpy()
recon_image = np.reshape(recon_image, (28, 28)) * 255
# Add to plot
axs[i][0].imshow(base_image)
axs[i][1].imshow(recon_image)
# image pairs
image_pairs.append((base_image, recon_image))
with open("image_pairs.pkl", "wb") as f:
pickle.dump(image_pairs, f)
plt.show()