From dff635566af27ad4fe347d77a903b307fcadad2c Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Wed, 28 Oct 2020 11:24:28 +0530 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20normalize=20sampling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- labml_nn/gan/cycle_gan.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py index 7f190e4c..dd80c309 100644 --- a/labml_nn/gan/cycle_gan.py +++ b/labml_nn/gan/cycle_gan.py @@ -627,9 +627,24 @@ def train(): conf.run() -def sample(): +def plot_image(img: torch.Tensor): + """ + Plots an image with matplotlib + """ from matplotlib import pyplot as plt + # Get min and max values of the image for normalization + img_min, img_max = img.min(), img.max() + # Scale image values to be [0...1] + img = (img - img_min) / (img_max - img_min + 1e-5) + # We have to change the order of dimensions to HWC. + img = img.permute(1, 2, 0) + # Show image + plt.imshow(img) + plt.show() + + +def sample(): # Set the run uuid from the training run trained_run_uuid = 'f73c1164184711eb9190b74249275441' # Create configs object @@ -676,9 +691,8 @@ def sample(): dataset = ImageDataset(images_path, transforms_, True, 'train') # Get an images from dataset x_image = dataset[0]['x'] - # Display the image. We have to change the order of dimensions to HWC. - plt.imshow(x_image.permute(1, 2, 0)) - plt.show() + # Display the image + plot_image(x_image) # Evaluation mode conf.generator_xy.eval() @@ -694,9 +708,8 @@ def sample(): mm_range = generated_y.min(), generated_y.max() generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5) - # Display the generated image. We have to change the order of dimensions to HWC. - plt.imshow(generated_y[0].cpu().permute(1, 2, 0)) - plt.show() + # Display the generated image. + plot_image(generated_y[0].cpu()) if __name__ == '__main__':