cycle gan save/load

This commit is contained in:
Varuna Jayasiri
2020-10-27 17:32:52 +05:30
parent 18c96c5692
commit 98b659f439

View File

@ -31,6 +31,7 @@ from torchvision.utils import save_image
from labml import lab, tracker, experiment, monit, configs
from labml.configs import BaseConfigs
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
@ -324,6 +325,11 @@ class Configs(BaseConfigs):
# Arrange images along y-axis
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
# Create folder to store sampled images
images_path = Path(f'images/{self.dataset_name}')
if not images_path.exists():
images_path.mkdir(parents=True)
# Save grid
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
@ -453,6 +459,9 @@ class Configs(BaseConfigs):
# Save images at intervals
batches_done = epoch * len(self.dataloader) + i
if batches_done % self.sample_interval == 0:
# Save models when sampling images
experiment.save_checkpoint()
# Sample images
self.sample_images(batches_done)
# Update learning rates
@ -606,13 +615,88 @@ def setup_dataloader(self: Configs):
)
def main():
def train():
conf = Configs()
experiment.create(name='cycle_gan')
experiment.configs(conf, 'run')
experiment.configs(conf, {
'dataset_name': 'summer2winter_yosemite'
}, 'run')
# Register models for saving and loading.
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
# You can also specify a custom dictionary of models.
experiment.add_pytorch_models(get_modules(conf))
with experiment.start():
conf.run()
def sample():
from matplotlib import pyplot as plt
# Set the run uuid from the training run
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
# Create configs object
conf = Configs()
# Create experiment
experiment.create(name='cycle_gan_inference')
# Load hyper parameters set for training
conf_dict = experiment.load_configs(trained_run_uuid)
# Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'`
# so that it only loads those and their dependencies.
# Configs like `device`, `img_height` and `img_width` will be calculated since these are required by
# `generator_xy` and `generator_yx`.
#
# If you want other parameters like `dataset_name` you should specify them here.
# If you specify nothing all the configurations will be calculated including data loaders.
# Calculation of configurations and their dependencies will happen when you call `experiment.start`
experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
# Register models for saving and loading.
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
# You can also specify a custom dictionary of models.
experiment.add_pytorch_models(get_modules(conf))
# Specify which run to load from.
# Loading will actually happen when you call `experiment.start`
experiment.load(trained_run_uuid)
# Start the experiment
with experiment.start():
# Load your own data, here we try test set.
# I was trying with yosemite photos, they look awesome.
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
# in the call to `experiment.configs`
images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
# Image transformations
transforms_ = [
transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC),
transforms.RandomCrop((conf.img_height, conf.img_width)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Load dataset
dataset = ImageDataset(images_path, transforms_, True, 'train')
# Get an images from dataset
x_image = dataset[0]['x']
# Display the image. We have to change the order of dimensions to HWC.
plt.imshow(x_image.permute(1, 2, 0))
plt.show()
# Evaluation mode
conf.generator_xy.eval()
conf.generator_yx.eval()
# We dont need gradients
with torch.no_grad():
# Add batch dimension and move to the device we use
data = x_image.unsqueeze(0).to(conf.device)
generated_y = conf.generator_xy(data)
# Display the generated image. We have to change the order of dimensions to HWC.
plt.imshow(generated_y[0].cpu().permute(1, 2, 0))
plt.show()
if __name__ == '__main__':
main()
train()
# sample()