♻️ remove setup

Varuna Jayasiri
2020-11-14 09:49:52 +05:30
parent b1a7da2fa1
commit b41c7ff5c6
2 changed files with 94 additions and 102 deletions


@@ -328,6 +328,62 @@ class Configs(BaseConfigs):
# Save grid
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
def initialize(self):
"""
## Initialize models and data loaders
"""
input_shape = (self.img_channels, self.img_height, self.img_width)
# Create the models
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.discriminator_x = Discriminator(input_shape).to(self.device)
self.discriminator_y = Discriminator(input_shape).to(self.device)
# Create the optimizers
self.generator_optimizer = torch.optim.Adam(
itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
lr=self.learning_rate, betas=self.adam_betas)
self.discriminator_optimizer = torch.optim.Adam(
itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
lr=self.learning_rate, betas=self.adam_betas)
# Create the learning rate schedules.
# The learning rate stays flat until `decay_start` epochs,
# and then linearly reduces to $0$ at the end of training.
decay_epochs = self.epochs - self.decay_start
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
# Location of the dataset
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
# Image transformations
transforms_ = [
transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
transforms.RandomCrop((self.img_height, self.img_width)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Training data loader
self.dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, 'train'),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.data_loader_workers,
)
# Validation data loader
self.valid_dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, "test"),
batch_size=5,
shuffle=True,
num_workers=self.data_loader_workers,
)
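A minimal sketch, not part of the diff, of the linear decay rule used by both `LambdaLR` schedulers above; the `epochs` and `decay_start` values here are illustrative, not the repository defaults:

epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start

def lr_factor(e: int) -> float:
    # Stays at 1.0 until `decay_start`, then falls linearly to 0.0 at `epochs`.
    return 1.0 - max(0, e - decay_start) / decay_epochs

print([lr_factor(e) for e in (0, 100, 150, 200)])  # [1.0, 1.0, 0.5, 0.0]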
def run(self):
"""
## Training
@@ -542,74 +598,6 @@ class Configs(BaseConfigs):
tracker.add({'loss.discriminator': loss_discriminator})
@configs.setup([Configs.generator_xy, Configs.generator_yx, Configs.discriminator_x, Configs.discriminator_y,
Configs.generator_optimizer, Configs.discriminator_optimizer,
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
def setup_models(self: Configs):
"""
## setup the models
"""
input_shape = (self.img_channels, self.img_height, self.img_width)
# Create the models
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.discriminator_x = Discriminator(input_shape).to(self.device)
self.discriminator_y = Discriminator(input_shape).to(self.device)
# Create the optimizers
self.generator_optimizer = torch.optim.Adam(
itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
lr=self.learning_rate, betas=self.adam_betas)
self.discriminator_optimizer = torch.optim.Adam(
itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
lr=self.learning_rate, betas=self.adam_betas)
# Create the learning rate schedules.
# The learning rate stays flat until `decay_start` epochs,
# and then linearly reduces to $0$ at the end of training.
decay_epochs = self.epochs - self.decay_start
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
@configs.setup([Configs.dataloader, Configs.valid_dataloader])
def setup_dataloader(self: Configs):
"""
## setup the data loaders
"""
# Location of the dataset
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
# Image transformations
transforms_ = [
transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
transforms.RandomCrop((self.img_height, self.img_width)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Training data loader
self.dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, 'train'),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.data_loader_workers,
)
# Validation data loader
self.valid_dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, "test"),
batch_size=5,
shuffle=True,
num_workers=self.data_loader_workers,
)
def train():
"""
## Train Cycle GAN
@@ -620,7 +608,9 @@ def train():
experiment.create(name='cycle_gan')
# Calculate configurations.
# It will calculate `conf.run` and all other configs required by it.
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run')
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
conf.initialize()
# Register models for saving and loading.
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
# You can also specify a custom dictionary of models.
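For orientation, a short sketch of the calling pattern this commit moves to; it assumes the `Configs`, `experiment` and `get_modules` names from the surrounding code, and only hints at the model-registration call in a comment:

conf = Configs()
experiment.create(name='cycle_gan')
# Calculate configurations; `run` is no longer passed as a target.
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
# Build models, optimizers, schedulers and data loaders explicitly.
conf.initialize()
# Register the `nn.Module` dictionary from `get_modules(conf)` for saving and loading here.
with experiment.start():
    conf.run()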
@@ -668,7 +658,9 @@ def evaluate():
# If you want other parameters like `dataset_name`, you should specify them here.
# If you specify nothing, all the configurations will be calculated, including the data loaders.
# The calculation of configurations and their dependencies will happen when you call `experiment.start`.
experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
experiment.configs(conf, conf_dict)
conf.initialize()
# Register models for saving and loading.
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
# You can also specify a custom dictionary of models.


@@ -131,6 +131,7 @@ class BivariateGaussianMixture:
This class adjusts temperatures and creates the categorical and Gaussian
distributions from the parameters.
"""
def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
self.pi_logits = pi_logits
@@ -595,6 +596,34 @@ class Configs(TrainValidConfigs):
epochs = 100
def initialize(self):
# Initialize encoder & decoder
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
# Set up the optimizer; things like the type of optimizer and the learning rate are configurable
optimizer = OptimizerConfigs()
optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
self.optimizer = optimizer
# Create sampler
self.sampler = Sampler(self.encoder, self.decoder)
# `npz` file path is `data/sketch/[DATASET NAME].npz`
path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
# Load the numpy file.
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
# Create training dataset
self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
# Create validation dataset
self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
# Create training data loader
self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
# Create validation data loader
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
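A quick sketch, not part of the diff, of how the `.npz` sketch file is read, using only the keys the code above relies on; the `bicycle` dataset name is the illustrative value used later in `main()`:

import numpy as np

# Variable-length stroke sequences are stored under the 'train' and 'valid' keys.
dataset = np.load('data/sketch/bicycle.npz', encoding='latin1', allow_pickle=True)
print(len(dataset['train']), len(dataset['valid']))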
def sample(self):
# Randomly pick a sample from the validation dataset to encode
data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
@@ -604,37 +633,6 @@ class Configs(TrainValidConfigs):
self.sampler.sample(data, self.temperature)
@setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler,
Configs.train_dataset, Configs.train_loader,
Configs.valid_dataset, Configs.valid_loader])
def setup_all(self: Configs):
# Initialize encoder & decoder
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
# Set up the optimizer; things like the type of optimizer and the learning rate are configurable
self.optimizer = OptimizerConfigs()
self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
# Create sampler
self.sampler = Sampler(self.encoder, self.decoder)
# `npz` file path is `data/sketch/[DATASET NAME].npz`
path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
# Load the numpy file.
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
# Create training dataset
self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
# Create validation dataset
self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
# Create training data loader
self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
# Create validation data loader
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
@option(Configs.batch_step)
def strokes_batch_step(c: Configs):
"""Set Strokes Training and Validation module"""
@@ -655,7 +653,9 @@ def main():
'dataset_name': 'bicycle',
# Number of inner iterations within an epoch to switch between training, validation and sampling.
'inner_iterations': 10
}, 'run')
})
configs.initialize()
with experiment.start():
# Run the experiment