From 62a173067bb61b9c3df97a00d05c4a3daadb4669 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Wed, 30 Sep 2020 20:57:23 +0530 Subject: [PATCH] comments --- labml_nn/gan/cycle_gan.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py index 8eb40986..99b565d9 100644 --- a/labml_nn/gan/cycle_gan.py +++ b/labml_nn/gan/cycle_gan.py @@ -229,7 +229,8 @@ class ReplayBuffer: self.max_size = max_size self.data = [] - def push_and_pop(self, data): + def push_and_pop(self, data: torch.Tensor): + """Add/retrieve an image""" data = data.detach() res = [] for element in data: @@ -247,14 +248,17 @@ class ReplayBuffer: class Configs(BaseConfigs): - """This is the configurations for the experiment""" + """## Configurations""" + + # `DeviceConfigs` will pick a GPU if available device: torch.device = DeviceConfigs() + + # Hyper-parameters epochs: int = 200 dataset_name: str = 'monet2photo' batch_size: int = 1 data_loader_workers = 8 - is_save_models = True learning_rate = 0.0002 adam_betas = (0.5, 0.999) @@ -268,30 +272,35 @@ class Configs(BaseConfigs): cycle_loss = torch.nn.L1Loss() identity_loss = torch.nn.L1Loss() - batch_step = 'cycle_gan_batch_step' - + # Image dimensions img_height = 256 img_width = 256 img_channels = 3 + # Number of residual blocks in the generator n_residual_blocks = 9 + # Loss coefficients cyclic_loss_coefficient = 10.0 identity_loss_coefficient = 5. sample_interval = 500 + # Models generator_xy: GeneratorResNet generator_yx: GeneratorResNet discriminator_x: Discriminator discriminator_y: Discriminator + # Optimizers generator_optimizer: torch.optim.Adam discriminator_optimizer: torch.optim.Adam + # Learning rate schedules generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR + # Data loaders dataloader: DataLoader valid_dataloader: DataLoader @@ -311,12 +320,16 @@ class Configs(BaseConfigs): gen_x = make_grid(gen_x, nrow=5, normalize=True) gen_y = make_grid(gen_y, nrow=5, normalize=True) - # arange images along y-axis + # Arrange images along y-axis image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1) + + # Save grid save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False) def run(self): """ + ## Training + We aim to solve: $$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$ @@ -447,7 +460,7 @@ class Configs(BaseConfigs): def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor): """ - Optimize the generators with identity, gan and cycle losses. + ### Optimize the generators with identity, gan and cycle losses. """ # Change to training mode @@ -501,7 +514,7 @@ class Configs(BaseConfigs): gen_x: torch.Tensor, gen_y: torch.Tensor, true_labels: torch.Tensor, false_labels: torch.Tensor): """ - Optimize the discriminators with gan loss. + ### Optimize the discriminators with gan loss. """ # GAN Loss # \begin{align} @@ -529,7 +542,7 @@ class Configs(BaseConfigs): Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler]) def setup_models(self: Configs): """ - This method setup the models + ## setup the models """ input_shape = (self.img_channels, self.img_height, self.img_width) @@ -560,7 +573,7 @@ def setup_models(self: Configs): @configs.setup([Configs.dataloader, Configs.valid_dataloader]) def setup_dataloader(self: Configs): """ - This method setup the data loaders + ## setup the data loaders """ # Location of the dataset