diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py index 437aa6e6..51095879 100644 --- a/labml_nn/gan/cycle_gan.py +++ b/labml_nn/gan/cycle_gan.py @@ -692,6 +692,10 @@ def sample(): data = x_image.unsqueeze(0).to(conf.device) generated_y = conf.generator_xy(data) + # Normalize the image + mm_range = generated_y.min(), generated_y.max() + generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5) + # Display the generated image. We have to change the order of dimensions to HWC. plt.imshow(generated_y[0].cpu().permute(1, 2, 0)) plt.show()