From b41c7ff5c66fc47d4d887d65129b7067db868ff5 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 14 Nov 2020 09:49:52 +0530
Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20remove=20setup?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 labml_nn/gan/cycle_gan.py       | 132 +++++++++++++++-----------------
 labml_nn/sketch_rnn/__init__.py |  64 ++++++++--------
 2 files changed, 94 insertions(+), 102 deletions(-)

diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py
index fe88fcb0..573a674b 100644
--- a/labml_nn/gan/cycle_gan.py
+++ b/labml_nn/gan/cycle_gan.py
@@ -328,6 +328,62 @@ class Configs(BaseConfigs):
         # Save grid
         save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
 
+    def initialize(self):
+        """
+        ## Initialize models and data loaders
+        """
+        input_shape = (self.img_channels, self.img_height, self.img_width)
+
+        # Create the models
+        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
+        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
+        self.discriminator_x = Discriminator(input_shape).to(self.device)
+        self.discriminator_y = Discriminator(input_shape).to(self.device)
+
+        # Create the optimizers
+        self.generator_optimizer = torch.optim.Adam(
+            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
+            lr=self.learning_rate, betas=self.adam_betas)
+        self.discriminator_optimizer = torch.optim.Adam(
+            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
+            lr=self.learning_rate, betas=self.adam_betas)
+
+        # Create the learning rate schedules.
+        # The learning rate stays flat until `decay_start` epochs,
+        # and then linearly reduces to $0$ at the end of training.
+        decay_epochs = self.epochs - self.decay_start
+        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
+        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
+
+        # Location of the dataset
+        images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
+
+        # Image transformations
+        transforms_ = [
+            transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
+            transforms.RandomCrop((self.img_height, self.img_width)),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ]
+
+        # Training data loader
+        self.dataloader = DataLoader(
+            ImageDataset(images_path, transforms_, True, 'train'),
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.data_loader_workers,
+        )
+
+        # Validation data loader
+        self.valid_dataloader = DataLoader(
+            ImageDataset(images_path, transforms_, True, "test"),
+            batch_size=5,
+            shuffle=True,
+            num_workers=self.data_loader_workers,
+        )
+
     def run(self):
         """
         ## Training
@@ -542,74 +598,6 @@ class Configs(BaseConfigs):
         tracker.add({'loss.discriminator': loss_discriminator})
 
 
-@configs.setup([Configs.generator_xy, Configs.generator_yx, Configs.discriminator_x, Configs.discriminator_y,
-                Configs.generator_optimizer, Configs.discriminator_optimizer,
-                Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
-def setup_models(self: Configs):
-    """
-    ## setup the models
-    """
-    input_shape = (self.img_channels, self.img_height, self.img_width)
-
-    # Create the models
-    self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
-    self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
-    self.discriminator_x = Discriminator(input_shape).to(self.device)
-    self.discriminator_y = Discriminator(input_shape).to(self.device)
-
-    # Create the optmizers
-    self.generator_optimizer = torch.optim.Adam(
-        itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
-        lr=self.learning_rate, betas=self.adam_betas)
-    self.discriminator_optimizer = torch.optim.Adam(
-        itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
-        lr=self.learning_rate, betas=self.adam_betas)
-
-    # Create the learning rate schedules.
-    # The learning rate stars flat until `decay_start` epochs,
-    # and then linearly reduces to $0$ at end of training.
-    decay_epochs = self.epochs - self.decay_start
-    self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-        self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
-    self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-        self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
-
-
-@configs.setup([Configs.dataloader, Configs.valid_dataloader])
-def setup_dataloader(self: Configs):
-    """
-    ## setup the data loaders
-    """
-
-    # Location of the dataset
-    images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
-
-    # Image transformations
-    transforms_ = [
-        transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
-        transforms.RandomCrop((self.img_height, self.img_width)),
-        transforms.RandomHorizontalFlip(),
-        transforms.ToTensor(),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-    ]
-
-    # Training data loader
-    self.dataloader = DataLoader(
-        ImageDataset(images_path, transforms_, True, 'train'),
-        batch_size=self.batch_size,
-        shuffle=True,
-        num_workers=self.data_loader_workers,
-    )
-
-    # Validation data loader
-    self.valid_dataloader = DataLoader(
-        ImageDataset(images_path, transforms_, True, "test"),
-        batch_size=5,
-        shuffle=True,
-        num_workers=self.data_loader_workers,
-    )
-
-
 def train():
     """
     ## Train Cycle GAN
@@ -620,7 +608,9 @@ def train():
     experiment.create(name='cycle_gan')
     # Calculate configurations.
     # It will calculate `conf.run` and all other configs required by it.
-    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run')
+    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
+    conf.initialize()
+
     # Register models for saving and loading.
     # `get_modules` gives a dictionary of `nn.Modules` in `conf`.
     # You can also specify a custom dictionary of models.
@@ -668,7 +658,9 @@ def evaluate():
     # If you want other parameters like `dataset_name` you should specify them here.
     # If you specify nothing, all the configurations will be calculated, including data loaders.
     # Calculation of configurations and their dependencies will happen when you call `experiment.start`
-    experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
+    experiment.configs(conf, conf_dict)
+    conf.initialize()
+
     # Register models for saving and loading.
     # `get_modules` gives a dictionary of `nn.Modules` in `conf`.
     # You can also specify a custom dictionary of models.
diff --git a/labml_nn/sketch_rnn/__init__.py b/labml_nn/sketch_rnn/__init__.py
index 17d8f73c..3b3cf71d 100644
--- a/labml_nn/sketch_rnn/__init__.py
+++ b/labml_nn/sketch_rnn/__init__.py
@@ -131,6 +131,7 @@ class BivariateGaussianMixture:
     This class adjusts temperatures and creates
     the categorical and Gaussian distributions from the parameters.
""" + def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor, sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor): self.pi_logits = pi_logits @@ -595,6 +596,34 @@ class Configs(TrainValidConfigs): epochs = 100 + def initialize(self): + # Initialize encoder & decoder + self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device) + self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device) + + # Set optimizer, things like type of optimizer and learning rate are configurable + optimizer = OptimizerConfigs() + optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) + self.optimizer = optimizer + + # Create sampler + self.sampler = Sampler(self.encoder, self.decoder) + + # `npz` file path is `data/sketch/[DATASET NAME].npz` + path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz' + # Load the numpy file. + dataset = np.load(str(path), encoding='latin1', allow_pickle=True) + + # Create training dataset + self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length) + # Create validation dataset + self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale) + + # Create training data loader + self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True) + # Create validation data loader + self.valid_loader = DataLoader(self.valid_dataset, self.batch_size) + def sample(self): # Randomly pick a sample from validation dataset to encoder data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))] @@ -604,37 +633,6 @@ class Configs(TrainValidConfigs): self.sampler.sample(data, self.temperature) -@setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler, - Configs.train_dataset, Configs.train_loader, - Configs.valid_dataset, Configs.valid_loader]) -def setup_all(self: Configs): - # Initialize encoder & decoder - self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device) - self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device) - - # Set optimizer, things like type of optimizer and learning rate are configurable - self.optimizer = OptimizerConfigs() - self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) - - # Create sampler - self.sampler = Sampler(self.encoder, self.decoder) - - # `npz` file path is `data/sketch/[DATASET NAME].npz` - path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz' - # Load the numpy file. - dataset = np.load(str(path), encoding='latin1', allow_pickle=True) - - # Create training dataset - self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length) - # Create validation dataset - self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale) - - # Create training data loader - self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True) - # Create validation data loader - self.valid_loader = DataLoader(self.valid_dataset, self.batch_size) - - @option(Configs.batch_step) def strokes_batch_step(c: Configs): """Set Strokes Training and Validation module""" @@ -655,7 +653,9 @@ def main(): 'dataset_name': 'bicycle', # Number of inner iterations within an epoch to switch between training, validation and sampling. 'inner_iterations': 10 - }, 'run') + }) + + configs.initialize() with experiment.start(): # Run the experiment