normalize sampling

This commit is contained in:
Varuna Jayasiri
2020-10-28 11:24:28 +05:30
parent 4081db3fb3
commit dff635566a

View File

@ -627,9 +627,24 @@ def train():
conf.run() conf.run()
def sample(): def plot_image(img: torch.Tensor):
"""
Plots an image with matplotlib
"""
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
# Get min and max values of the image for normalization
img_min, img_max = img.min(), img.max()
# Scale image values to be [0...1]
img = (img - img_min) / (img_max - img_min + 1e-5)
# We have to change the order of dimensions to HWC.
img = img.permute(1, 2, 0)
# Show image
plt.imshow(img)
plt.show()
def sample():
# Set the run uuid from the training run # Set the run uuid from the training run
trained_run_uuid = 'f73c1164184711eb9190b74249275441' trained_run_uuid = 'f73c1164184711eb9190b74249275441'
# Create configs object # Create configs object
@ -676,9 +691,8 @@ def sample():
dataset = ImageDataset(images_path, transforms_, True, 'train') dataset = ImageDataset(images_path, transforms_, True, 'train')
# Get an images from dataset # Get an images from dataset
x_image = dataset[0]['x'] x_image = dataset[0]['x']
# Display the image. We have to change the order of dimensions to HWC. # Display the image
plt.imshow(x_image.permute(1, 2, 0)) plot_image(x_image)
plt.show()
# Evaluation mode # Evaluation mode
conf.generator_xy.eval() conf.generator_xy.eval()
@ -694,9 +708,8 @@ def sample():
mm_range = generated_y.min(), generated_y.max() mm_range = generated_y.min(), generated_y.max()
generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5) generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5)
# Display the generated image. We have to change the order of dimensions to HWC. # Display the generated image.
plt.imshow(generated_y[0].cpu().permute(1, 2, 0)) plot_image(generated_y[0].cpu())
plt.show()
if __name__ == '__main__': if __name__ == '__main__':