♻️ remove setup

Varuna Jayasiri
2020-11-14 09:49:52 +05:30
parent b1a7da2fa1
commit b41c7ff5c6
2 changed files with 94 additions and 102 deletions
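The change replaces labml's `@configs.setup(...)` hooks with a plain `initialize()` method on each `Configs` class, which callers now invoke explicitly after `experiment.configs(...)`. A minimal sketch of the new calling pattern, assembled from the diff below rather than a complete program:

conf = Configs()
experiment.create(name='cycle_gan')
# Compute configs from the overrides; no target config such as 'run' is passed any more
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
# Build the models, optimizers, schedulers and data loaders explicitly
conf.initialize()
with experiment.start():
    conf.run()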


@@ -328,6 +328,62 @@ class Configs(BaseConfigs):
# Save grid
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
def initialize(self):
"""
## Initialize models and data loaders
"""
input_shape = (self.img_channels, self.img_height, self.img_width)
# Create the models
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.discriminator_x = Discriminator(input_shape).to(self.device)
self.discriminator_y = Discriminator(input_shape).to(self.device)
# Create the optimizers
self.generator_optimizer = torch.optim.Adam(
itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
lr=self.learning_rate, betas=self.adam_betas)
self.discriminator_optimizer = torch.optim.Adam(
itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
lr=self.learning_rate, betas=self.adam_betas)
# Create the learning rate schedules.
# The learning rate stays flat until `decay_start` epochs,
# and then linearly decays to $0$ at the end of training.
decay_epochs = self.epochs - self.decay_start
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
# Location of the dataset
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
# Image transformations
transforms_ = [
transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
transforms.RandomCrop((self.img_height, self.img_width)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Training data loader
self.dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, 'train'),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.data_loader_workers,
)
# Validation data loader
self.valid_dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, "test"),
batch_size=5,
shuffle=True,
num_workers=self.data_loader_workers,
)
def run(self):
"""
## Training
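The `LambdaLR` factor added above stays at 1.0 until `decay_start` and then falls linearly to 0 at `epochs`. A quick standalone check of that multiplier, using hypothetical values for `epochs` and `decay_start` (both are configs, not fixed by this diff):

epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start
factor = lambda e: 1.0 - max(0, e - decay_start) / decay_epochs
print(factor(0))    # 1.0: flat before decay_start
print(factor(150))  # 0.5: halfway through the decay
print(factor(200))  # 0.0: reaches zero at the end of training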
@@ -542,74 +598,6 @@ class Configs(BaseConfigs):
tracker.add({'loss.discriminator': loss_discriminator})
@configs.setup([Configs.generator_xy, Configs.generator_yx, Configs.discriminator_x, Configs.discriminator_y,
Configs.generator_optimizer, Configs.discriminator_optimizer,
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
def setup_models(self: Configs):
"""
## setup the models
"""
input_shape = (self.img_channels, self.img_height, self.img_width)
# Create the models
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.discriminator_x = Discriminator(input_shape).to(self.device)
self.discriminator_y = Discriminator(input_shape).to(self.device)
# Create the optimizers
self.generator_optimizer = torch.optim.Adam(
itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
lr=self.learning_rate, betas=self.adam_betas)
self.discriminator_optimizer = torch.optim.Adam(
itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
lr=self.learning_rate, betas=self.adam_betas)
# Create the learning rate schedules.
# The learning rate stays flat until `decay_start` epochs,
# and then linearly decays to $0$ at the end of training.
decay_epochs = self.epochs - self.decay_start
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
@configs.setup([Configs.dataloader, Configs.valid_dataloader])
def setup_dataloader(self: Configs):
"""
## setup the data loaders
"""
# Location of the dataset
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
# Image transformations
transforms_ = [
transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
transforms.RandomCrop((self.img_height, self.img_width)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Training data loader
self.dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, 'train'),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.data_loader_workers,
)
# Validation data loader
self.valid_dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, "test"),
batch_size=5,
shuffle=True,
num_workers=self.data_loader_workers,
)
def train():
"""
## Train Cycle GAN
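Both the new `initialize` and the removed `setup_dataloader` build the same transform pipeline. `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` maps `ToTensor()`'s [0, 1] range to [-1, 1]; matching a tanh-activated generator output is the usual reason, though the generator definition is not part of this diff. A tiny check of the mapping:

import torch
x = torch.tensor([0.0, 0.5, 1.0])  # ToTensor() yields values in [0, 1]
print((x - 0.5) / 0.5)             # tensor([-1., 0., 1.]): Normalize computes (x - mean) / std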
@@ -620,7 +608,9 @@ def train():
experiment.create(name='cycle_gan')
# Calculate configurations.
# It will calculate `conf.run` and all other configs required by it.
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run')
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
conf.initialize()
# Register models for saving and loading.
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
# You can also specify a custom dictionary of models.
@@ -668,7 +658,9 @@ def evaluate():
# If you want other parameters like `dataset_name`, you should specify them here.
# If you specify nothing, all the configurations will be calculated, including data loaders.
# Calculation of configurations and their dependencies will happen when you call `experiment.start`
experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
experiment.configs(conf, conf_dict)
conf.initialize()
# Register models for saving and loading.
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
# You can also specify a custom dictionary of models.
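One behavioural consequence of dropping the target configs in `evaluate()`: the old call computed only the two generators lazily, while `conf.initialize()` now builds everything, data loaders included, as the surrounding comment notes. Side by side (a reading of the diff, not a new API):

# Before: compute just the generators on demand
experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
# After: compute the configs, then build all models and loaders eagerly
experiment.configs(conf, conf_dict)
conf.initialize()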


@@ -131,6 +131,7 @@ class BivariateGaussianMixture:
This class adjusts temperatures and creates the categorical and Gaussian
distributions from the parameters.
"""
def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
self.pi_logits = pi_logits
@@ -595,26 +596,15 @@ class Configs(TrainValidConfigs):
epochs = 100
def sample(self):
# Randomly pick a sample from the validation dataset to feed to the encoder
data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
# Add batch dimension and move it to device
data = data.unsqueeze(1).to(self.device)
# Sample
self.sampler.sample(data, self.temperature)
@setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler,
Configs.train_dataset, Configs.train_loader,
Configs.valid_dataset, Configs.valid_loader])
def setup_all(self: Configs):
def initialize(self):
# Initialize encoder & decoder
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
# Set up the optimizer; things like the optimizer type and learning rate are configurable
self.optimizer = OptimizerConfigs()
optimizer = OptimizerConfigs()
self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
self.optimizer = optimizer
# Create sampler
self.sampler = Sampler(self.encoder, self.decoder)
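A small tweak inside the sketch-RNN `initialize`: the `OptimizerConfigs()` aggregate is now assembled in a local variable and assigned to `self.optimizer` only once its parameters are set, presumably so the config system receives a fully populated object; that rationale is my reading of the diff, the commit does not state it. The new pattern in isolation:

# Build the optimizer config locally, then assign it once complete
optimizer = OptimizerConfigs()
optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
self.optimizer = optimizer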
@@ -634,6 +624,14 @@ def setup_all(self: Configs):
# Create validation data loader
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
def sample(self):
# Randomly pick a sample from the validation dataset to feed to the encoder
data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
# Add batch dimension and move it to device
data = data.unsqueeze(1).to(self.device)
# Sample
self.sampler.sample(data, self.temperature)
@option(Configs.batch_step)
def strokes_batch_step(c: Configs):
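In the relocated `sample` method, `data.unsqueeze(1)` inserts the batch dimension at position 1, which fits a sequence-first `(seq_len, batch, features)` layout as PyTorch RNNs expect by default; the concrete sizes below are hypothetical:

import torch
data = torch.zeros(96, 5)  # one sketch: (seq_len, features)
data = data.unsqueeze(1)   # now (seq_len, 1, features), batch dimension at dim 1
print(data.shape)          # torch.Size([96, 1, 5])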
@@ -655,7 +653,9 @@ def main():
'dataset_name': 'bicycle',
# Number of inner iterations within an epoch to switch between training, validation and sampling.
'inner_iterations': 10
}, 'run')
})
configs.initialize()
with experiment.start():
# Run the experiment