mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 10:18:50 +08:00
✨ normalize sampling
This commit is contained in:
@ -627,9 +627,24 @@ def train():
|
|||||||
conf.run()
|
conf.run()
|
||||||
|
|
||||||
|
|
||||||
def sample():
|
def plot_image(img: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Plots an image with matplotlib
|
||||||
|
"""
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
# Get min and max values of the image for normalization
|
||||||
|
img_min, img_max = img.min(), img.max()
|
||||||
|
# Scale image values to be [0...1]
|
||||||
|
img = (img - img_min) / (img_max - img_min + 1e-5)
|
||||||
|
# We have to change the order of dimensions to HWC.
|
||||||
|
img = img.permute(1, 2, 0)
|
||||||
|
# Show image
|
||||||
|
plt.imshow(img)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def sample():
|
||||||
# Set the run uuid from the training run
|
# Set the run uuid from the training run
|
||||||
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
|
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
|
||||||
# Create configs object
|
# Create configs object
|
||||||
@ -676,9 +691,8 @@ def sample():
|
|||||||
dataset = ImageDataset(images_path, transforms_, True, 'train')
|
dataset = ImageDataset(images_path, transforms_, True, 'train')
|
||||||
# Get an images from dataset
|
# Get an images from dataset
|
||||||
x_image = dataset[0]['x']
|
x_image = dataset[0]['x']
|
||||||
# Display the image. We have to change the order of dimensions to HWC.
|
# Display the image
|
||||||
plt.imshow(x_image.permute(1, 2, 0))
|
plot_image(x_image)
|
||||||
plt.show()
|
|
||||||
|
|
||||||
# Evaluation mode
|
# Evaluation mode
|
||||||
conf.generator_xy.eval()
|
conf.generator_xy.eval()
|
||||||
@ -694,9 +708,8 @@ def sample():
|
|||||||
mm_range = generated_y.min(), generated_y.max()
|
mm_range = generated_y.min(), generated_y.max()
|
||||||
generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5)
|
generated_y = (generated_y - mm_range[0]) / (mm_range[1] - mm_range[0] + 1e-5)
|
||||||
|
|
||||||
# Display the generated image. We have to change the order of dimensions to HWC.
|
# Display the generated image.
|
||||||
plt.imshow(generated_y[0].cpu().permute(1, 2, 0))
|
plot_image(generated_y[0].cpu())
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user