mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-31 02:39:16 +08:00
♻️ remove setup
This commit is contained in:
@ -328,6 +328,62 @@ class Configs(BaseConfigs):
|
|||||||
# Save grid
|
# Save grid
|
||||||
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""
|
||||||
|
## Initialize models and data loaders
|
||||||
|
"""
|
||||||
|
input_shape = (self.img_channels, self.img_height, self.img_width)
|
||||||
|
|
||||||
|
# Create the models
|
||||||
|
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
||||||
|
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
||||||
|
self.discriminator_x = Discriminator(input_shape).to(self.device)
|
||||||
|
self.discriminator_y = Discriminator(input_shape).to(self.device)
|
||||||
|
|
||||||
|
# Create the optmizers
|
||||||
|
self.generator_optimizer = torch.optim.Adam(
|
||||||
|
itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
|
||||||
|
lr=self.learning_rate, betas=self.adam_betas)
|
||||||
|
self.discriminator_optimizer = torch.optim.Adam(
|
||||||
|
itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
|
||||||
|
lr=self.learning_rate, betas=self.adam_betas)
|
||||||
|
|
||||||
|
# Create the learning rate schedules.
|
||||||
|
# The learning rate stars flat until `decay_start` epochs,
|
||||||
|
# and then linearly reduces to $0$ at end of training.
|
||||||
|
decay_epochs = self.epochs - self.decay_start
|
||||||
|
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||||
|
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
||||||
|
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||||
|
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
||||||
|
# Location of the dataset
|
||||||
|
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
|
||||||
|
|
||||||
|
# Image transformations
|
||||||
|
transforms_ = [
|
||||||
|
transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
|
||||||
|
transforms.RandomCrop((self.img_height, self.img_width)),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Training data loader
|
||||||
|
self.dataloader = DataLoader(
|
||||||
|
ImageDataset(images_path, transforms_, True, 'train'),
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=self.data_loader_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validation data loader
|
||||||
|
self.valid_dataloader = DataLoader(
|
||||||
|
ImageDataset(images_path, transforms_, True, "test"),
|
||||||
|
batch_size=5,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=self.data_loader_workers,
|
||||||
|
)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
## Training
|
## Training
|
||||||
@ -542,74 +598,6 @@ class Configs(BaseConfigs):
|
|||||||
tracker.add({'loss.discriminator': loss_discriminator})
|
tracker.add({'loss.discriminator': loss_discriminator})
|
||||||
|
|
||||||
|
|
||||||
@configs.setup([Configs.generator_xy, Configs.generator_yx, Configs.discriminator_x, Configs.discriminator_y,
|
|
||||||
Configs.generator_optimizer, Configs.discriminator_optimizer,
|
|
||||||
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
|
|
||||||
def setup_models(self: Configs):
|
|
||||||
"""
|
|
||||||
## setup the models
|
|
||||||
"""
|
|
||||||
input_shape = (self.img_channels, self.img_height, self.img_width)
|
|
||||||
|
|
||||||
# Create the models
|
|
||||||
self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
|
||||||
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
|
|
||||||
self.discriminator_x = Discriminator(input_shape).to(self.device)
|
|
||||||
self.discriminator_y = Discriminator(input_shape).to(self.device)
|
|
||||||
|
|
||||||
# Create the optmizers
|
|
||||||
self.generator_optimizer = torch.optim.Adam(
|
|
||||||
itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
|
|
||||||
lr=self.learning_rate, betas=self.adam_betas)
|
|
||||||
self.discriminator_optimizer = torch.optim.Adam(
|
|
||||||
itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
|
|
||||||
lr=self.learning_rate, betas=self.adam_betas)
|
|
||||||
|
|
||||||
# Create the learning rate schedules.
|
|
||||||
# The learning rate stars flat until `decay_start` epochs,
|
|
||||||
# and then linearly reduces to $0$ at end of training.
|
|
||||||
decay_epochs = self.epochs - self.decay_start
|
|
||||||
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
|
||||||
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
|
||||||
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
|
||||||
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
|
||||||
|
|
||||||
|
|
||||||
@configs.setup([Configs.dataloader, Configs.valid_dataloader])
|
|
||||||
def setup_dataloader(self: Configs):
|
|
||||||
"""
|
|
||||||
## setup the data loaders
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Location of the dataset
|
|
||||||
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
|
|
||||||
|
|
||||||
# Image transformations
|
|
||||||
transforms_ = [
|
|
||||||
transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
|
|
||||||
transforms.RandomCrop((self.img_height, self.img_width)),
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Training data loader
|
|
||||||
self.dataloader = DataLoader(
|
|
||||||
ImageDataset(images_path, transforms_, True, 'train'),
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=self.data_loader_workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validation data loader
|
|
||||||
self.valid_dataloader = DataLoader(
|
|
||||||
ImageDataset(images_path, transforms_, True, "test"),
|
|
||||||
batch_size=5,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=self.data_loader_workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
"""
|
"""
|
||||||
## Train Cycle GAN
|
## Train Cycle GAN
|
||||||
@ -620,7 +608,9 @@ def train():
|
|||||||
experiment.create(name='cycle_gan')
|
experiment.create(name='cycle_gan')
|
||||||
# Calculate configurations.
|
# Calculate configurations.
|
||||||
# It will calculate `conf.run` and all other configs required by it.
|
# It will calculate `conf.run` and all other configs required by it.
|
||||||
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run')
|
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
|
||||||
|
conf.initialize()
|
||||||
|
|
||||||
# Register models for saving and loading.
|
# Register models for saving and loading.
|
||||||
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||||
# You can also specify a custom dictionary of models.
|
# You can also specify a custom dictionary of models.
|
||||||
@ -668,7 +658,9 @@ def evaluate():
|
|||||||
# If you want other parameters like `dataset_name` you should specify them here.
|
# If you want other parameters like `dataset_name` you should specify them here.
|
||||||
# If you specify nothing all the configurations will be calculated including data loaders.
|
# If you specify nothing all the configurations will be calculated including data loaders.
|
||||||
# Calculation of configurations and their dependencies will happen when you call `experiment.start`
|
# Calculation of configurations and their dependencies will happen when you call `experiment.start`
|
||||||
experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx')
|
experiment.configs(conf, conf_dict)
|
||||||
|
conf.initialize()
|
||||||
|
|
||||||
# Register models for saving and loading.
|
# Register models for saving and loading.
|
||||||
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||||
# You can also specify a custom dictionary of models.
|
# You can also specify a custom dictionary of models.
|
||||||
|
|||||||
@ -131,6 +131,7 @@ class BivariateGaussianMixture:
|
|||||||
This class adjust temperatures and creates the categorical and gaussian
|
This class adjust temperatures and creates the categorical and gaussian
|
||||||
distributions from the parameters.
|
distributions from the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
|
def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
|
||||||
sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
|
sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
|
||||||
self.pi_logits = pi_logits
|
self.pi_logits = pi_logits
|
||||||
@ -595,26 +596,15 @@ class Configs(TrainValidConfigs):
|
|||||||
|
|
||||||
epochs = 100
|
epochs = 100
|
||||||
|
|
||||||
def sample(self):
|
def initialize(self):
|
||||||
# Randomly pick a sample from validation dataset to encoder
|
|
||||||
data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
|
|
||||||
# Add batch dimension and move it to device
|
|
||||||
data = data.unsqueeze(1).to(self.device)
|
|
||||||
# Sample
|
|
||||||
self.sampler.sample(data, self.temperature)
|
|
||||||
|
|
||||||
|
|
||||||
@setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler,
|
|
||||||
Configs.train_dataset, Configs.train_loader,
|
|
||||||
Configs.valid_dataset, Configs.valid_loader])
|
|
||||||
def setup_all(self: Configs):
|
|
||||||
# Initialize encoder & decoder
|
# Initialize encoder & decoder
|
||||||
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
|
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
|
||||||
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
|
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
|
||||||
|
|
||||||
# Set optimizer, things like type of optimizer and learning rate are configurable
|
# Set optimizer, things like type of optimizer and learning rate are configurable
|
||||||
self.optimizer = OptimizerConfigs()
|
optimizer = OptimizerConfigs()
|
||||||
self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
||||||
|
self.optimizer = optimizer
|
||||||
|
|
||||||
# Create sampler
|
# Create sampler
|
||||||
self.sampler = Sampler(self.encoder, self.decoder)
|
self.sampler = Sampler(self.encoder, self.decoder)
|
||||||
@ -634,6 +624,14 @@ def setup_all(self: Configs):
|
|||||||
# Create validation data loader
|
# Create validation data loader
|
||||||
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
|
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
# Randomly pick a sample from validation dataset to encoder
|
||||||
|
data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
|
||||||
|
# Add batch dimension and move it to device
|
||||||
|
data = data.unsqueeze(1).to(self.device)
|
||||||
|
# Sample
|
||||||
|
self.sampler.sample(data, self.temperature)
|
||||||
|
|
||||||
|
|
||||||
@option(Configs.batch_step)
|
@option(Configs.batch_step)
|
||||||
def strokes_batch_step(c: Configs):
|
def strokes_batch_step(c: Configs):
|
||||||
@ -655,7 +653,9 @@ def main():
|
|||||||
'dataset_name': 'bicycle',
|
'dataset_name': 'bicycle',
|
||||||
# Number of inner iterations within an epoch to switch between training, validation and sampling.
|
# Number of inner iterations within an epoch to switch between training, validation and sampling.
|
||||||
'inner_iterations': 10
|
'inner_iterations': 10
|
||||||
}, 'run')
|
})
|
||||||
|
|
||||||
|
configs.initialize()
|
||||||
|
|
||||||
with experiment.start():
|
with experiment.start():
|
||||||
# Run the experiment
|
# Run the experiment
|
||||||
|
|||||||
Reference in New Issue
Block a user