# Mirror of https://github.com/helblazer811/ManimML.git
# Synced 2025-07-06 16:18:17 +08:00
# 41 lines, 1.2 KiB, Python
import torch
|
|
from variational_autoencoder import VAE
|
|
import matplotlib.pyplot as plt
|
|
from torchvision import datasets
|
|
from torchvision import transforms
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
import pickle
|
|
|
|
# Load model
vae = VAE(latent_dim=16)
# map_location="cpu" lets the checkpoint be read on a machine that lacks the
# device it was saved from; the rest of this script converts tensors with
# ``.numpy()``, which needs CPU tensors anyway.
vae.load_state_dict(torch.load("saved_models/model.pth", map_location="cpu"))
# The model is only evaluated in this script, so put it in inference mode.
# NOTE(review): assumes VAE is a torch.nn.Module (it exposes load_state_dict).
vae.eval()
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset (training split, stored under ./data)
dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=tensor_transform
)
# Generate reconstructions
num_recons = 10
# One row per sample: original image on the left, reconstruction on the right.
# squeeze=False keeps ``axs`` two-dimensional even when num_recons == 1, so the
# ``axs[i][col]`` indexing below works for any positive count.
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons), squeeze=False)
image_pairs = []
# Nothing is trained here, so gradient tracking is switched off for the loop.
with torch.no_grad():
    for i in range(num_recons):
        # The dataset yields (image, label); the label is not needed.
        base_image, _ = dataset[i]
        # Flatten the image into rows of 28 * 28 values before feeding the model.
        base_image = base_image.reshape(-1, 28 * 28)
        # Call the module rather than ``forward`` directly so registered hooks
        # run; the third of the four returned values is used as the
        # reconstruction.
        _, _, recon_image, _ = vae(base_image)
        # Convert both images to 28x28 NumPy arrays scaled by 255.
        base_image = np.reshape(base_image.numpy(), (28, 28)) * 255
        recon_image = np.reshape(recon_image.detach().numpy(), (28, 28)) * 255
        # Add to plot
        axs[i][0].imshow(base_image)
        axs[i][1].imshow(recon_image)
        # Keep the (original, reconstruction) pair for serialization
        image_pairs.append((base_image, recon_image))

# Serialize the collected (original, reconstruction) pairs to disk.
with open("image_pairs.pkl", "wb") as pickle_file:
    pickle.dump(image_pairs, pickle_file)

# Display the comparison figure.
plt.show()