mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-04 14:29:43 +08:00
✨ cycle gan save/load
This commit is contained in:
@ -31,6 +31,7 @@ from torchvision.utils import save_image
|
|||||||
|
|
||||||
from labml import lab, tracker, experiment, monit, configs
|
from labml import lab, tracker, experiment, monit, configs
|
||||||
from labml.configs import BaseConfigs
|
from labml.configs import BaseConfigs
|
||||||
|
from labml.utils.pytorch import get_modules
|
||||||
from labml_helpers.device import DeviceConfigs
|
from labml_helpers.device import DeviceConfigs
|
||||||
from labml_helpers.module import Module
|
from labml_helpers.module import Module
|
||||||
|
|
||||||
@ -324,6 +325,11 @@ class Configs(BaseConfigs):
|
|||||||
# Arrange images along y-axis
|
# Arrange images along y-axis
|
||||||
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
|
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
|
||||||
|
|
||||||
|
# Create folder to store sampled images
|
||||||
|
images_path = Path(f'images/{self.dataset_name}')
|
||||||
|
if not images_path.exists():
|
||||||
|
images_path.mkdir(parents=True)
|
||||||
|
|
||||||
# Save grid
|
# Save grid
|
||||||
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
||||||
|
|
||||||
@ -453,6 +459,9 @@ class Configs(BaseConfigs):
|
|||||||
# Save images at intervals
|
# Save images at intervals
|
||||||
batches_done = epoch * len(self.dataloader) + i
|
batches_done = epoch * len(self.dataloader) + i
|
||||||
if batches_done % self.sample_interval == 0:
|
if batches_done % self.sample_interval == 0:
|
||||||
|
# Save models when sampling images
|
||||||
|
experiment.save_checkpoint()
|
||||||
|
# Sample images
|
||||||
self.sample_images(batches_done)
|
self.sample_images(batches_done)
|
||||||
|
|
||||||
# Update learning rates
|
# Update learning rates
|
||||||
@ -606,13 +615,88 @@ def setup_dataloader(self: Configs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def train():
|
||||||
conf = Configs()
|
conf = Configs()
|
||||||
experiment.create(name='cycle_gan')
|
experiment.create(name='cycle_gan')
|
||||||
experiment.configs(conf, 'run')
|
experiment.configs(conf, {
|
||||||
|
'dataset_name': 'summer2winter_yosemite'
|
||||||
|
}, 'run')
|
||||||
|
# Register models for saving and loading.
|
||||||
|
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||||
|
# You can also specify a custom dictionary of models.
|
||||||
|
experiment.add_pytorch_models(get_modules(conf))
|
||||||
with experiment.start():
|
with experiment.start():
|
||||||
conf.run()
|
conf.run()
|
||||||
|
|
||||||
|
|
||||||
|
def sample():
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
# Set the run uuid from the training run
|
||||||
|
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
|
||||||
|
# Create configs object
|
||||||
|
conf = Configs()
|
||||||
|
# Create experiment
|
||||||
|
experiment.create(name='cycle_gan_inference')
|
||||||
|
# Load hyper parameters set for training
|
||||||
|
conf_dict = experiment.load_configs(trained_run_uuid)
|
||||||
|
# Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'`
|
||||||
|
# so that it only loads those and their dependencies.
|
||||||
|
# Configs like `device`, `img_height` and `img_width` will be calculated since these are required by
|
||||||
|
# `generator_xy` and `generator_yx`.
|
||||||
|
#
|
||||||
|
# If you want other parameters like `dataset_name` you should specify them here.
|
||||||
|
# If you specify nothing all the configurations will be calculated including data loaders.
|
||||||
|
# Calculation of configurations and their dependencies will happen when you call `experiment.start`
|
||||||
|
experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
|
||||||
|
# Register models for saving and loading.
|
||||||
|
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||||
|
# You can also specify a custom dictionary of models.
|
||||||
|
experiment.add_pytorch_models(get_modules(conf))
|
||||||
|
# Specify which run to load from.
|
||||||
|
# Loading will actually happen when you call `experiment.start`
|
||||||
|
experiment.load(trained_run_uuid)
|
||||||
|
|
||||||
|
# Start the experiment
|
||||||
|
with experiment.start():
|
||||||
|
# Load your own data, here we try test set.
|
||||||
|
# I was trying with yosemite photos, they look awesome.
|
||||||
|
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
|
||||||
|
# in the call to `experiment.configs`
|
||||||
|
images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
|
||||||
|
|
||||||
|
# Image transformations
|
||||||
|
transforms_ = [
|
||||||
|
transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC),
|
||||||
|
transforms.RandomCrop((conf.img_height, conf.img_width)),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Load dataset
|
||||||
|
dataset = ImageDataset(images_path, transforms_, True, 'train')
|
||||||
|
# Get an images from dataset
|
||||||
|
x_image = dataset[0]['x']
|
||||||
|
# Display the image. We have to change the order of dimensions to HWC.
|
||||||
|
plt.imshow(x_image.permute(1, 2, 0))
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Evaluation mode
|
||||||
|
conf.generator_xy.eval()
|
||||||
|
conf.generator_yx.eval()
|
||||||
|
|
||||||
|
# We dont need gradients
|
||||||
|
with torch.no_grad():
|
||||||
|
# Add batch dimension and move to the device we use
|
||||||
|
data = x_image.unsqueeze(0).to(conf.device)
|
||||||
|
generated_y = conf.generator_xy(data)
|
||||||
|
|
||||||
|
# Display the generated image. We have to change the order of dimensions to HWC.
|
||||||
|
plt.imshow(generated_y[0].cpu().permute(1, 2, 0))
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
train()
|
||||||
|
# sample()
|
||||||
|
|||||||
Reference in New Issue
Block a user