mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-02 04:37:46 +08:00
✨ cycle gan save/load
This commit is contained in:
@ -31,6 +31,7 @@ from torchvision.utils import save_image
|
||||
|
||||
from labml import lab, tracker, experiment, monit, configs
|
||||
from labml.configs import BaseConfigs
|
||||
from labml.utils.pytorch import get_modules
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_helpers.module import Module
|
||||
|
||||
@ -324,6 +325,11 @@ class Configs(BaseConfigs):
|
||||
# Arrange images along y-axis
|
||||
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
|
||||
|
||||
# Create folder to store sampled images
|
||||
images_path = Path(f'images/{self.dataset_name}')
|
||||
if not images_path.exists():
|
||||
images_path.mkdir(parents=True)
|
||||
|
||||
# Save grid
|
||||
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
||||
|
||||
@ -453,6 +459,9 @@ class Configs(BaseConfigs):
|
||||
# Save images at intervals
|
||||
batches_done = epoch * len(self.dataloader) + i
|
||||
if batches_done % self.sample_interval == 0:
|
||||
# Save models when sampling images
|
||||
experiment.save_checkpoint()
|
||||
# Sample images
|
||||
self.sample_images(batches_done)
|
||||
|
||||
# Update learning rates
|
||||
@ -606,13 +615,88 @@ def setup_dataloader(self: Configs):
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
def train():
|
||||
conf = Configs()
|
||||
experiment.create(name='cycle_gan')
|
||||
experiment.configs(conf, 'run')
|
||||
experiment.configs(conf, {
|
||||
'dataset_name': 'summer2winter_yosemite'
|
||||
}, 'run')
|
||||
# Register models for saving and loading.
|
||||
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||
# You can also specify a custom dictionary of models.
|
||||
experiment.add_pytorch_models(get_modules(conf))
|
||||
with experiment.start():
|
||||
conf.run()
|
||||
|
||||
|
||||
def sample():
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
# Set the run uuid from the training run
|
||||
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
|
||||
# Create configs object
|
||||
conf = Configs()
|
||||
# Create experiment
|
||||
experiment.create(name='cycle_gan_inference')
|
||||
# Load hyper parameters set for training
|
||||
conf_dict = experiment.load_configs(trained_run_uuid)
|
||||
# Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'`
|
||||
# so that it only loads those and their dependencies.
|
||||
# Configs like `device`, `img_height` and `img_width` will be calculated since these are required by
|
||||
# `generator_xy` and `generator_yx`.
|
||||
#
|
||||
# If you want other parameters like `dataset_name` you should specify them here.
|
||||
# If you specify nothing all the configurations will be calculated including data loaders.
|
||||
# Calculation of configurations and their dependencies will happen when you call `experiment.start`
|
||||
experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
|
||||
# Register models for saving and loading.
|
||||
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||
# You can also specify a custom dictionary of models.
|
||||
experiment.add_pytorch_models(get_modules(conf))
|
||||
# Specify which run to load from.
|
||||
# Loading will actually happen when you call `experiment.start`
|
||||
experiment.load(trained_run_uuid)
|
||||
|
||||
# Start the experiment
|
||||
with experiment.start():
|
||||
# Load your own data, here we try test set.
|
||||
# I was trying with yosemite photos, they look awesome.
|
||||
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
|
||||
# in the call to `experiment.configs`
|
||||
images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
|
||||
|
||||
# Image transformations
|
||||
transforms_ = [
|
||||
transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC),
|
||||
transforms.RandomCrop((conf.img_height, conf.img_width)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
|
||||
# Load dataset
|
||||
dataset = ImageDataset(images_path, transforms_, True, 'train')
|
||||
# Get an images from dataset
|
||||
x_image = dataset[0]['x']
|
||||
# Display the image. We have to change the order of dimensions to HWC.
|
||||
plt.imshow(x_image.permute(1, 2, 0))
|
||||
plt.show()
|
||||
|
||||
# Evaluation mode
|
||||
conf.generator_xy.eval()
|
||||
conf.generator_yx.eval()
|
||||
|
||||
# We dont need gradients
|
||||
with torch.no_grad():
|
||||
# Add batch dimension and move to the device we use
|
||||
data = x_image.unsqueeze(0).to(conf.device)
|
||||
generated_y = conf.generator_xy(data)
|
||||
|
||||
# Display the generated image. We have to change the order of dimensions to HWC.
|
||||
plt.imshow(generated_y[0].cpu().permute(1, 2, 0))
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
train()
|
||||
# sample()
|
||||
|
||||
Reference in New Issue
Block a user