Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-29 09:38:56 +08:00)
✨ normalize sampling
@@ -627,9 +627,24 @@ def train():
     conf.run()


-def sample():
+def plot_image(img: torch.Tensor):
+    """
+    Plots an image with matplotlib
+    """
+    from matplotlib import pyplot as plt
+
+    # Get min and max values of the image for normalization
+    img_min, img_max = img.min(), img.max()
+    # Scale image values to be [0...1]
+    img = (img - img_min) / (img_max - img_min + 1e-5)
+    # We have to change the order of dimensions to HWC.
+    img = img.permute(1, 2, 0)
+    # Show image
+    plt.imshow(img)
+    plt.show()
+
+
+def sample():
     # Set the run uuid from the training run
     trained_run_uuid = 'f73c1164184711eb9190b74249275441'
     # Create configs object
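Why the new helper scales and reorders: matplotlib's imshow expects an HWC array with float values in [0, 1] (float values outside that range are clipped), while PyTorch image tensors are CHW and typically hold values outside [0, 1]. Below is a minimal standalone sketch of the same min-max-scale-and-permute logic; the random tensor is only an illustrative stand-in, not data from the repository:

import torch
from matplotlib import pyplot as plt

# Dummy 3x256x256 image in roughly [-1, 1], standing in for a real image tensor
img = torch.rand(3, 256, 256) * 2 - 1

# Min-max scale to [0, 1]; the small epsilon avoids division by zero for constant images
img_min, img_max = img.min(), img.max()
img = (img - img_min) / (img_max - img_min + 1e-5)

# Reorder CHW -> HWC for matplotlib and display
plt.imshow(img.permute(1, 2, 0))
plt.show()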
@@ -676,9 +691,8 @@ def sample():
     dataset = ImageDataset(images_path, transforms_, True, 'train')
     # Get an image from the dataset
     x_image = dataset[0]['x']
-    # Display the image. We have to change the order of dimensions to HWC.
-    plt.imshow(x_image.permute(1, 2, 0))
-    plt.show()
+    # Display the image
+    plot_image(x_image)

     # Evaluation mode
     conf.generator_xy.eval()
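The access pattern dataset[0]['x'] works because each dataset item is a dictionary of image tensors; only the 'x' key appears in this hunk, and a paired 'y' image from the other domain is an assumption based on the usual CycleGAN setup. A small illustrative sketch of that pattern, with random tensors standing in for the real ImageDataset and plot_image taken from the first hunk:

import torch

# Toy stand-in for ImageDataset: each item is a dict of image tensors.
# Only the 'x' key is visible in the diff; 'y' is assumed for illustration.
dataset = [{'x': torch.rand(3, 256, 256), 'y': torch.rand(3, 256, 256)}]

# Same access pattern as in sample(): take the 'x' image of the first item
x_image = dataset[0]['x']

# plot_image (defined in the first hunk) rescales to [0, 1], reorders to HWC and shows it
plot_image(x_image)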
@@ -694,9 +708,8 @@ def sample():
     mm_range = generated_y.min(), generated_y.max()
     generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5)

-    # Display the generated image. We have to change the order of dimensions to HWC.
-    plt.imshow(generated_y[0].cpu().permute(1, 2, 0))
-    plt.show()
+    # Display the generated image.
+    plot_image(generated_y[0].cpu())


 if __name__ == '__main__':
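Routing both display sites through plot_image collapses the duplicated scale/permute/imshow code into one helper, which appears to be what the "✨ normalize sampling" commit title describes; the generated tensor now only has to be moved to the CPU before plotting. A hedged, self-contained sketch of the refactored display path follows; the tanh layer is a stand-in for an image-to-image generator head, not the CycleGAN generator from this file:

import torch
import torch.nn as nn

# Stand-in "generator": a tanh layer, so outputs land in [-1, 1] the way a
# typical image-to-image generator head does (assumption for illustration only)
generator = nn.Tanh()

with torch.no_grad():
    generated = generator(torch.randn(1, 3, 64, 64))

# Take the first image in the batch and move it to the CPU before plotting;
# plot_image (defined in the first hunk) handles the [0, 1] scaling and CHW -> HWC reordering
plot_image(generated[0].cpu())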