normalize sampling

This commit is contained in:
Varuna Jayasiri
2020-10-28 11:24:28 +05:30
parent 4081db3fb3
commit dff635566a

View File

@ -627,9 +627,24 @@ def train():
conf.run()
def sample():
def plot_image(img: torch.Tensor):
"""
Plots an image with matplotlib
"""
from matplotlib import pyplot as plt
# Get min and max values of the image for normalization
img_min, img_max = img.min(), img.max()
# Scale image values to be [0...1]
img = (img - img_min) / (img_max - img_min + 1e-5)
# We have to change the order of dimensions to HWC.
img = img.permute(1, 2, 0)
# Show image
plt.imshow(img)
plt.show()
def sample():
# Set the run uuid from the training run
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
# Create configs object
@ -676,9 +691,8 @@ def sample():
dataset = ImageDataset(images_path, transforms_, True, 'train')
# Get an images from dataset
x_image = dataset[0]['x']
# Display the image. We have to change the order of dimensions to HWC.
plt.imshow(x_image.permute(1, 2, 0))
plt.show()
# Display the image
plot_image(x_image)
# Evaluation mode
conf.generator_xy.eval()
@ -694,9 +708,8 @@ def sample():
mm_range = generated_y.min(), generated_y.max()
generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5)
# Display the generated image. We have to change the order of dimensions to HWC.
plt.imshow(generated_y[0].cpu().permute(1, 2, 0))
plt.show()
# Display the generated image.
plot_image(generated_y[0].cpu())
if __name__ == '__main__':