mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 09:38:56 +08:00
♻️ remove setup
This commit is contained in:
@ -328,6 +328,62 @@ class Configs(BaseConfigs):
|
||||
# Save grid
|
||||
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
||||
|
||||
def initialize(self):
|
||||
"""
|
||||
## Initialize models and data loaders
|
||||
"""
|
||||
input_shape = (self.img_channels, self.img_height, self.img_width)
|
||||
|
||||
# Create the models
|
||||
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
||||
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
||||
self.discriminator_x = Discriminator(input_shape).to(self.device)
|
||||
self.discriminator_y = Discriminator(input_shape).to(self.device)
|
||||
|
||||
# Create the optmizers
|
||||
self.generator_optimizer = torch.optim.Adam(
|
||||
itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
|
||||
lr=self.learning_rate, betas=self.adam_betas)
|
||||
self.discriminator_optimizer = torch.optim.Adam(
|
||||
itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
|
||||
lr=self.learning_rate, betas=self.adam_betas)
|
||||
|
||||
# Create the learning rate schedules.
|
||||
# The learning rate stars flat until `decay_start` epochs,
|
||||
# and then linearly reduces to $0$ at end of training.
|
||||
decay_epochs = self.epochs - self.decay_start
|
||||
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
||||
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
||||
# Location of the dataset
|
||||
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
|
||||
|
||||
# Image transformations
|
||||
transforms_ = [
|
||||
transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
|
||||
transforms.RandomCrop((self.img_height, self.img_width)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
|
||||
# Training data loader
|
||||
self.dataloader = DataLoader(
|
||||
ImageDataset(images_path, transforms_, True, 'train'),
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.data_loader_workers,
|
||||
)
|
||||
|
||||
# Validation data loader
|
||||
self.valid_dataloader = DataLoader(
|
||||
ImageDataset(images_path, transforms_, True, "test"),
|
||||
batch_size=5,
|
||||
shuffle=True,
|
||||
num_workers=self.data_loader_workers,
|
||||
)
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
## Training
|
||||
@ -542,74 +598,6 @@ class Configs(BaseConfigs):
|
||||
tracker.add({'loss.discriminator': loss_discriminator})
|
||||
|
||||
|
||||
@configs.setup([Configs.generator_xy, Configs.generator_yx, Configs.discriminator_x, Configs.discriminator_y,
|
||||
Configs.generator_optimizer, Configs.discriminator_optimizer,
|
||||
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
|
||||
def setup_models(self: Configs):
|
||||
"""
|
||||
## setup the models
|
||||
"""
|
||||
input_shape = (self.img_channels, self.img_height, self.img_width)
|
||||
|
||||
# Create the models
|
||||
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
||||
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
||||
self.discriminator_x = Discriminator(input_shape).to(self.device)
|
||||
self.discriminator_y = Discriminator(input_shape).to(self.device)
|
||||
|
||||
# Create the optmizers
|
||||
self.generator_optimizer = torch.optim.Adam(
|
||||
itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
|
||||
lr=self.learning_rate, betas=self.adam_betas)
|
||||
self.discriminator_optimizer = torch.optim.Adam(
|
||||
itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
|
||||
lr=self.learning_rate, betas=self.adam_betas)
|
||||
|
||||
# Create the learning rate schedules.
|
||||
# The learning rate stars flat until `decay_start` epochs,
|
||||
# and then linearly reduces to $0$ at end of training.
|
||||
decay_epochs = self.epochs - self.decay_start
|
||||
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
||||
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
||||
|
||||
|
||||
@configs.setup([Configs.dataloader, Configs.valid_dataloader])
|
||||
def setup_dataloader(self: Configs):
|
||||
"""
|
||||
## setup the data loaders
|
||||
"""
|
||||
|
||||
# Location of the dataset
|
||||
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
|
||||
|
||||
# Image transformations
|
||||
transforms_ = [
|
||||
transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
|
||||
transforms.RandomCrop((self.img_height, self.img_width)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
|
||||
# Training data loader
|
||||
self.dataloader = DataLoader(
|
||||
ImageDataset(images_path, transforms_, True, 'train'),
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.data_loader_workers,
|
||||
)
|
||||
|
||||
# Validation data loader
|
||||
self.valid_dataloader = DataLoader(
|
||||
ImageDataset(images_path, transforms_, True, "test"),
|
||||
batch_size=5,
|
||||
shuffle=True,
|
||||
num_workers=self.data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
def train():
|
||||
"""
|
||||
## Train Cycle GAN
|
||||
@ -620,7 +608,9 @@ def train():
|
||||
experiment.create(name='cycle_gan')
|
||||
# Calculate configurations.
|
||||
# It will calculate `conf.run` and all other configs required by it.
|
||||
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run')
|
||||
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
|
||||
conf.initialize()
|
||||
|
||||
# Register models for saving and loading.
|
||||
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||
# You can also specify a custom dictionary of models.
|
||||
@ -668,7 +658,9 @@ def evaluate():
|
||||
# If you want other parameters like `dataset_name` you should specify them here.
|
||||
# If you specify nothing all the configurations will be calculated including data loaders.
|
||||
# Calculation of configurations and their dependencies will happen when you call `experiment.start`
|
||||
experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
|
||||
experiment.configs(conf, conf_dict)
|
||||
conf.initialize()
|
||||
|
||||
# Register models for saving and loading.
|
||||
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||
# You can also specify a custom dictionary of models.
|
||||
|
||||
@ -131,6 +131,7 @@ class BivariateGaussianMixture:
|
||||
This class adjust temperatures and creates the categorical and gaussian
|
||||
distributions from the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
|
||||
sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
|
||||
self.pi_logits = pi_logits
|
||||
@ -595,6 +596,34 @@ class Configs(TrainValidConfigs):
|
||||
|
||||
epochs = 100
|
||||
|
||||
def initialize(self):
|
||||
# Initialize encoder & decoder
|
||||
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
|
||||
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
|
||||
|
||||
# Set optimizer, things like type of optimizer and learning rate are configurable
|
||||
optimizer = OptimizerConfigs()
|
||||
optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
||||
self.optimizer = optimizer
|
||||
|
||||
# Create sampler
|
||||
self.sampler = Sampler(self.encoder, self.decoder)
|
||||
|
||||
# `npz` file path is `data/sketch/[DATASET NAME].npz`
|
||||
path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
|
||||
# Load the numpy file.
|
||||
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
|
||||
|
||||
# Create training dataset
|
||||
self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
|
||||
# Create validation dataset
|
||||
self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
|
||||
|
||||
# Create training data loader
|
||||
self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
|
||||
# Create validation data loader
|
||||
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
|
||||
|
||||
def sample(self):
|
||||
# Randomly pick a sample from validation dataset to encoder
|
||||
data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
|
||||
@ -604,37 +633,6 @@ class Configs(TrainValidConfigs):
|
||||
self.sampler.sample(data, self.temperature)
|
||||
|
||||
|
||||
@setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler,
|
||||
Configs.train_dataset, Configs.train_loader,
|
||||
Configs.valid_dataset, Configs.valid_loader])
|
||||
def setup_all(self: Configs):
|
||||
# Initialize encoder & decoder
|
||||
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
|
||||
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
|
||||
|
||||
# Set optimizer, things like type of optimizer and learning rate are configurable
|
||||
self.optimizer = OptimizerConfigs()
|
||||
self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
||||
|
||||
# Create sampler
|
||||
self.sampler = Sampler(self.encoder, self.decoder)
|
||||
|
||||
# `npz` file path is `data/sketch/[DATASET NAME].npz`
|
||||
path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
|
||||
# Load the numpy file.
|
||||
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
|
||||
|
||||
# Create training dataset
|
||||
self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
|
||||
# Create validation dataset
|
||||
self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
|
||||
|
||||
# Create training data loader
|
||||
self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
|
||||
# Create validation data loader
|
||||
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
|
||||
|
||||
|
||||
@option(Configs.batch_step)
|
||||
def strokes_batch_step(c: Configs):
|
||||
"""Set Strokes Training and Validation module"""
|
||||
@ -655,7 +653,9 @@ def main():
|
||||
'dataset_name': 'bicycle',
|
||||
# Number of inner iterations within an epoch to switch between training, validation and sampling.
|
||||
'inner_iterations': 10
|
||||
}, 'run')
|
||||
})
|
||||
|
||||
configs.initialize()
|
||||
|
||||
with experiment.start():
|
||||
# Run the experiment
|
||||
|
||||
Reference in New Issue
Block a user