mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-04 06:16:05 +08:00 
			
		
		
		
	✨ cycle gan save/load
This commit is contained in:
		@ -31,6 +31,7 @@ from torchvision.utils import save_image
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from labml import lab, tracker, experiment, monit, configs
 | 
					from labml import lab, tracker, experiment, monit, configs
 | 
				
			||||||
from labml.configs import BaseConfigs
 | 
					from labml.configs import BaseConfigs
 | 
				
			||||||
 | 
					from labml.utils.pytorch import get_modules
 | 
				
			||||||
from labml_helpers.device import DeviceConfigs
 | 
					from labml_helpers.device import DeviceConfigs
 | 
				
			||||||
from labml_helpers.module import Module
 | 
					from labml_helpers.module import Module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -324,6 +325,11 @@ class Configs(BaseConfigs):
 | 
				
			|||||||
            # Arrange images along y-axis
 | 
					            # Arrange images along y-axis
 | 
				
			||||||
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
 | 
					            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Create folder to store sampled images
 | 
				
			||||||
 | 
					        images_path = Path(f'images/{self.dataset_name}')
 | 
				
			||||||
 | 
					        if not images_path.exists():
 | 
				
			||||||
 | 
					            images_path.mkdir(parents=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Save grid
 | 
					        # Save grid
 | 
				
			||||||
        save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
 | 
					        save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -453,6 +459,9 @@ class Configs(BaseConfigs):
 | 
				
			|||||||
                # Save images at intervals
 | 
					                # Save images at intervals
 | 
				
			||||||
                batches_done = epoch * len(self.dataloader) + i
 | 
					                batches_done = epoch * len(self.dataloader) + i
 | 
				
			||||||
                if batches_done % self.sample_interval == 0:
 | 
					                if batches_done % self.sample_interval == 0:
 | 
				
			||||||
 | 
					                    # Save models when sampling images
 | 
				
			||||||
 | 
					                    experiment.save_checkpoint()
 | 
				
			||||||
 | 
					                    # Sample images
 | 
				
			||||||
                    self.sample_images(batches_done)
 | 
					                    self.sample_images(batches_done)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Update learning rates
 | 
					            # Update learning rates
 | 
				
			||||||
@ -606,13 +615,88 @@ def setup_dataloader(self: Configs):
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def train():
 | 
				
			||||||
    conf = Configs()
 | 
					    conf = Configs()
 | 
				
			||||||
    experiment.create(name='cycle_gan')
 | 
					    experiment.create(name='cycle_gan')
 | 
				
			||||||
    experiment.configs(conf, 'run')
 | 
					    experiment.configs(conf, {
 | 
				
			||||||
 | 
					        'dataset_name': 'summer2winter_yosemite'
 | 
				
			||||||
 | 
					    }, 'run')
 | 
				
			||||||
 | 
					    # Register models for saving and loading.
 | 
				
			||||||
 | 
					    # `get_modules` gives a dictionary of `nn.Modules` in `conf`.
 | 
				
			||||||
 | 
					    # You can also specify a custom dictionary of models.
 | 
				
			||||||
 | 
					    experiment.add_pytorch_models(get_modules(conf))
 | 
				
			||||||
    with experiment.start():
 | 
					    with experiment.start():
 | 
				
			||||||
        conf.run()
 | 
					        conf.run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def sample():
 | 
				
			||||||
 | 
					    from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Set the run uuid from the training run
 | 
				
			||||||
 | 
					    trained_run_uuid = 'f73c1164184711eb9190b74249275441'
 | 
				
			||||||
 | 
					    # Create configs object
 | 
				
			||||||
 | 
					    conf = Configs()
 | 
				
			||||||
 | 
					    # Create experiment
 | 
				
			||||||
 | 
					    experiment.create(name='cycle_gan_inference')
 | 
				
			||||||
 | 
					    # Load hyper parameters set for training
 | 
				
			||||||
 | 
					    conf_dict = experiment.load_configs(trained_run_uuid)
 | 
				
			||||||
 | 
					    # Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'`
 | 
				
			||||||
 | 
					    # so that it only loads those and their dependencies.
 | 
				
			||||||
 | 
					    # Configs like `device`, `img_height` and `img_width` will be calculated since these are required by
 | 
				
			||||||
 | 
					    # `generator_xy` and `generator_yx`.
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # If you want other parameters like `dataset_name` you should specify them here.
 | 
				
			||||||
 | 
					    # If you specify nothing all the configurations will be calculated including data loaders.
 | 
				
			||||||
 | 
					    # Calculation of configurations and their dependencies will happen when you call `experiment.start`
 | 
				
			||||||
 | 
					    experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
 | 
				
			||||||
 | 
					    # Register models for saving and loading.
 | 
				
			||||||
 | 
					    # `get_modules` gives a dictionary of `nn.Modules` in `conf`.
 | 
				
			||||||
 | 
					    # You can also specify a custom dictionary of models.
 | 
				
			||||||
 | 
					    experiment.add_pytorch_models(get_modules(conf))
 | 
				
			||||||
 | 
					    # Specify which run to load from.
 | 
				
			||||||
 | 
					    # Loading will actually happen when you call `experiment.start`
 | 
				
			||||||
 | 
					    experiment.load(trained_run_uuid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Start the experiment
 | 
				
			||||||
 | 
					    with experiment.start():
 | 
				
			||||||
 | 
					        # Load your own data, here we try test set.
 | 
				
			||||||
 | 
					        # I was trying with yosemite photos, they look awesome.
 | 
				
			||||||
 | 
					        # You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
 | 
				
			||||||
 | 
					        # in the call to `experiment.configs`
 | 
				
			||||||
 | 
					        images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Image transformations
 | 
				
			||||||
 | 
					        transforms_ = [
 | 
				
			||||||
 | 
					            transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC),
 | 
				
			||||||
 | 
					            transforms.RandomCrop((conf.img_height, conf.img_width)),
 | 
				
			||||||
 | 
					            transforms.RandomHorizontalFlip(),
 | 
				
			||||||
 | 
					            transforms.ToTensor(),
 | 
				
			||||||
 | 
					            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Load dataset
 | 
				
			||||||
 | 
					        dataset = ImageDataset(images_path, transforms_, True, 'train')
 | 
				
			||||||
 | 
					        # Get an images from dataset
 | 
				
			||||||
 | 
					        x_image = dataset[0]['x']
 | 
				
			||||||
 | 
					        # Display the image. We have to change the order of dimensions to HWC.
 | 
				
			||||||
 | 
					        plt.imshow(x_image.permute(1, 2, 0))
 | 
				
			||||||
 | 
					        plt.show()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Evaluation mode
 | 
				
			||||||
 | 
					        conf.generator_xy.eval()
 | 
				
			||||||
 | 
					        conf.generator_yx.eval()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # We dont need gradients
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            # Add batch dimension and move to the device we use
 | 
				
			||||||
 | 
					            data = x_image.unsqueeze(0).to(conf.device)
 | 
				
			||||||
 | 
					            generated_y = conf.generator_xy(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Display the generated image. We have to change the order of dimensions to HWC.
 | 
				
			||||||
 | 
					        plt.imshow(generated_y[0].cpu().permute(1, 2, 0))
 | 
				
			||||||
 | 
					        plt.show()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
    main()
 | 
					    train()
 | 
				
			||||||
 | 
					    # sample()
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user