diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py
index b2832ef2..437aa6e6 100644
--- a/labml_nn/gan/cycle_gan.py
+++ b/labml_nn/gan/cycle_gan.py
@@ -31,6 +31,7 @@ from torchvision.utils import save_image
 
 from labml import lab, tracker, experiment, monit, configs
 from labml.configs import BaseConfigs
+from labml.utils.pytorch import get_modules
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 
@@ -324,6 +325,11 @@ class Configs(BaseConfigs):
 
             # Arrange images along y-axis
             image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
+            # Create folder to store sampled images
+            images_path = Path(f'images/{self.dataset_name}')
+            if not images_path.exists():
+                images_path.mkdir(parents=True)
+
             # Save grid
             save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
 
@@ -453,6 +459,9 @@ class Configs(BaseConfigs):
                 # Save images at intervals
                 batches_done = epoch * len(self.dataloader) + i
                 if batches_done % self.sample_interval == 0:
+                    # Save models when sampling images
+                    experiment.save_checkpoint()
+                    # Sample images
                     self.sample_images(batches_done)
 
             # Update learning rates
@@ -606,13 +615,88 @@ def setup_dataloader(self: Configs):
     )
 
 
-def main():
+def train():
     conf = Configs()
     experiment.create(name='cycle_gan')
-    experiment.configs(conf, 'run')
+    experiment.configs(conf, {
+        'dataset_name': 'summer2winter_yosemite'
+    }, 'run')
+    # Register models for saving and loading.
+    # `get_modules` gives a dictionary of `nn.Modules` in `conf`.
+    # You can also specify a custom dictionary of models.
+    experiment.add_pytorch_models(get_modules(conf))
     with experiment.start():
         conf.run()
 
 
+def sample():
+    from matplotlib import pyplot as plt
+
+    # Set the run UUID of the training run
+    trained_run_uuid = 'f73c1164184711eb9190b74249275441'
+    # Create the configs object
+    conf = Configs()
+    # Create the experiment
+    experiment.create(name='cycle_gan_inference')
+    # Load the hyperparameters that were set for training
+    conf_dict = experiment.load_configs(trained_run_uuid)
+    # Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'`
+    # so that it only loads those and their dependencies.
+    # Configs like `device`, `img_height` and `img_width` will be calculated, since these are required by
+    # `generator_xy` and `generator_yx`.
+    #
+    # If you want other parameters like `dataset_name`, you should specify them here.
+    # If you specify nothing, all the configurations will be calculated, including the data loaders.
+    # Calculation of configurations and their dependencies will happen when you call `experiment.start`
+    experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
+    # Register models for saving and loading.
+    # `get_modules` gives a dictionary of `nn.Modules` in `conf`.
+    # You can also specify a custom dictionary of models.
+    experiment.add_pytorch_models(get_modules(conf))
+    # Specify which run to load from.
+    # Loading will actually happen when you call `experiment.start`
+    experiment.load(trained_run_uuid)
+
+    # Start the experiment
+    with experiment.start():
+        # Load your own data. Here we use images from the dataset's `train` split.
+        # I was trying with Yosemite photos; they look awesome.
+        # You can use `conf.dataset_name` here, if you specified `dataset_name` as something to be calculated
+        # in the call to `experiment.configs`
+        images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
+
+        # Image transformations
+        transforms_ = [
+            transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC),
+            transforms.RandomCrop((conf.img_height, conf.img_width)),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ]
+
+        # Load the dataset
+        dataset = ImageDataset(images_path, transforms_, True, 'train')
+        # Get an image from the dataset
+        x_image = dataset[0]['x']
+        # Display the image. We have to change the order of dimensions to HWC.
+        plt.imshow(x_image.permute(1, 2, 0))
+        plt.show()
+
+        # Evaluation mode
+        conf.generator_xy.eval()
+        conf.generator_yx.eval()
+
+        # We don't need gradients
+        with torch.no_grad():
+            # Add a batch dimension and move it to the device we use
+            data = x_image.unsqueeze(0).to(conf.device)
+            generated_y = conf.generator_xy(data)
+
+            # Display the generated image. We have to change the order of dimensions to HWC.
+            plt.imshow(generated_y[0].cpu().permute(1, 2, 0))
+            plt.show()
+
+
 if __name__ == '__main__':
-    main()
+    train()
+    # sample()