Used Black to reformat the code in the repository.

This commit is contained in:
Alec Helbling
2023-01-01 23:24:59 -05:00
parent 334662e8c8
commit 3d6e8072e1
71 changed files with 1701 additions and 1135 deletions

View File

@ -13,17 +13,16 @@ vae.load_state_dict(torch.load("saved_models/model.pth"))
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()
# Download the MNIST Dataset
dataset = datasets.MNIST(root = "./data",
train = True,
download = True,
transform = tensor_transform)
dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=tensor_transform
)
# Generate reconstructions
num_recons = 10
fig, axs = plt.subplots(num_recons, 2, figsize=(2, num_recons))
image_pairs = []
for i in range(num_recons):
base_image, _ = dataset[i]
base_image = base_image.reshape(-1, 28*28)
base_image = base_image.reshape(-1, 28 * 28)
_, _, recon_image, _ = vae.forward(base_image)
base_image = base_image.detach().numpy()
base_image = np.reshape(base_image, (28, 28)) * 255