Files
2023-01-01 23:24:59 -05:00

41 lines
1.2 KiB
Python

import torch
from variational_autoencoder import VAE
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import pickle
# Load the trained model. map_location="cpu" lets a checkpoint that was saved
# on a GPU be restored on a CPU-only machine; the reconstructions below are
# converted with .numpy(), which only works on CPU tensors anyway.
vae = VAE(latent_dim=16)
vae.load_state_dict(torch.load("saved_models/model.pth", map_location="cpu"))
# Inference only: switch any train-time-only layers (e.g. dropout, batch
# norm), if the model has them, to their evaluation behaviour.
vae.eval()

# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()

# Download the MNIST Dataset
dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=tensor_transform
)

# Generate reconstructions: one row per sample, original on the left and the
# model's reconstruction on the right.
num_recons = 10
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
image_pairs = []

# Nothing is back-propagated here, so skip building autograd graphs.
with torch.no_grad():
    for i in range(num_recons):
        base_image, _ = dataset[i]
        # Flatten the image to a single row of 28 * 28 values for the model.
        base_image = base_image.reshape(-1, 28 * 28)
        # Call the module rather than .forward() so registered hooks run;
        # the third returned value is used as the reconstruction.
        _, _, recon_image, _ = vae(base_image)

        # Back to 28x28 NumPy arrays, scaled by 255.
        base_image = base_image.detach().numpy()
        base_image = np.reshape(base_image, (28, 28)) * 255
        recon_image = recon_image.detach().numpy()
        recon_image = np.reshape(recon_image, (28, 28)) * 255

        # Add to plot
        axs[i][0].imshow(base_image)
        axs[i][1].imshow(recon_image)

        # image pairs
        image_pairs.append((base_image, recon_image))

# Persist all (original, reconstruction) pairs once the list is complete.
with open("image_pairs.pkl", "wb") as f:
    pickle.dump(image_pairs, f)

plt.show()