mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-07-07 16:50:09 +08:00
Changed directory structure to accomodate examples as apposed to everything being a part of the core library. May need to rethink this in the future. Added some boilerplate for pip packaging to the .gitignore.
This commit is contained in:

committed by
Alec Helbling

parent
4eb5296c9c
commit
3be5c54d26
@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from variational_autoencoder import VAE
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import datasets
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
# Load model
|
||||
vae = VAE(latent_dim=16)
|
||||
vae.load_state_dict(torch.load("saved_models/model.pth"))
|
||||
# Transforms images to a PyTorch Tensor
|
||||
tensor_transform = transforms.ToTensor()
|
||||
# Download the MNIST Dataset
|
||||
dataset = datasets.MNIST(root = "./data",
|
||||
train = True,
|
||||
download = True,
|
||||
transform = tensor_transform)
|
||||
# Generate reconstructions
|
||||
num_recons = 10
|
||||
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
|
||||
image_pairs = []
|
||||
for i in range(num_recons):
|
||||
base_image, _ = dataset[i]
|
||||
base_image = base_image.reshape(-1, 28*28)
|
||||
_, _, recon_image, _ = vae.forward(base_image)
|
||||
base_image = base_image.detach().numpy()
|
||||
base_image = np.reshape(base_image, (28, 28)) * 255
|
||||
recon_image = recon_image.detach().numpy()
|
||||
recon_image = np.reshape(recon_image, (28, 28)) * 255
|
||||
# Add to plot
|
||||
axs[i][0].imshow(base_image)
|
||||
axs[i][1].imshow(recon_image)
|
||||
# image pairs
|
||||
image_pairs.append((base_image, recon_image))
|
||||
|
||||
with open("image_pairs.pkl", "wb") as f:
|
||||
pickle.dump(image_pairs, f)
|
||||
|
||||
plt.show()
|
Reference in New Issue
Block a user