This commit is contained in:
Varuna Jayasiri
2020-09-30 20:57:23 +05:30
parent abe0c6abb1
commit 62a173067b

View File

@ -229,7 +229,8 @@ class ReplayBuffer:
self.max_size = max_size self.max_size = max_size
self.data = [] self.data = []
def push_and_pop(self, data): def push_and_pop(self, data: torch.Tensor):
"""Add/retrieve an image"""
data = data.detach() data = data.detach()
res = [] res = []
for element in data: for element in data:
@ -247,14 +248,17 @@ class ReplayBuffer:
class Configs(BaseConfigs): class Configs(BaseConfigs):
"""This is the configurations for the experiment""" """## Configurations"""
# `DeviceConfigs` will pick a GPU if available
device: torch.device = DeviceConfigs() device: torch.device = DeviceConfigs()
# Hyper-parameters
epochs: int = 200 epochs: int = 200
dataset_name: str = 'monet2photo' dataset_name: str = 'monet2photo'
batch_size: int = 1 batch_size: int = 1
data_loader_workers = 8 data_loader_workers = 8
is_save_models = True
learning_rate = 0.0002 learning_rate = 0.0002
adam_betas = (0.5, 0.999) adam_betas = (0.5, 0.999)
@ -268,30 +272,35 @@ class Configs(BaseConfigs):
cycle_loss = torch.nn.L1Loss() cycle_loss = torch.nn.L1Loss()
identity_loss = torch.nn.L1Loss() identity_loss = torch.nn.L1Loss()
batch_step = 'cycle_gan_batch_step' # Image dimensions
img_height = 256 img_height = 256
img_width = 256 img_width = 256
img_channels = 3 img_channels = 3
# Number of residual blocks in the generator
n_residual_blocks = 9 n_residual_blocks = 9
# Loss coefficients
cyclic_loss_coefficient = 10.0 cyclic_loss_coefficient = 10.0
identity_loss_coefficient = 5. identity_loss_coefficient = 5.
sample_interval = 500 sample_interval = 500
# Models
generator_xy: GeneratorResNet generator_xy: GeneratorResNet
generator_yx: GeneratorResNet generator_yx: GeneratorResNet
discriminator_x: Discriminator discriminator_x: Discriminator
discriminator_y: Discriminator discriminator_y: Discriminator
# Optimizers
generator_optimizer: torch.optim.Adam generator_optimizer: torch.optim.Adam
discriminator_optimizer: torch.optim.Adam discriminator_optimizer: torch.optim.Adam
# Learning rate schedules
generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
# Data loaders
dataloader: DataLoader dataloader: DataLoader
valid_dataloader: DataLoader valid_dataloader: DataLoader
@ -311,12 +320,16 @@ class Configs(BaseConfigs):
gen_x = make_grid(gen_x, nrow=5, normalize=True) gen_x = make_grid(gen_x, nrow=5, normalize=True)
gen_y = make_grid(gen_y, nrow=5, normalize=True) gen_y = make_grid(gen_y, nrow=5, normalize=True)
# arange images along y-axis # Arrange images along y-axis
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1) image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
# Save grid
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False) save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
def run(self): def run(self):
""" """
## Training
We aim to solve: We aim to solve:
$$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$ $$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$
@ -447,7 +460,7 @@ class Configs(BaseConfigs):
def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor): def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
""" """
Optimize the generators with identity, gan and cycle losses. ### Optimize the generators with identity, gan and cycle losses.
""" """
# Change to training mode # Change to training mode
@ -501,7 +514,7 @@ class Configs(BaseConfigs):
gen_x: torch.Tensor, gen_y: torch.Tensor, gen_x: torch.Tensor, gen_y: torch.Tensor,
true_labels: torch.Tensor, false_labels: torch.Tensor): true_labels: torch.Tensor, false_labels: torch.Tensor):
""" """
Optimize the discriminators with gan loss. ### Optimize the discriminators with gan loss.
""" """
# GAN Loss # GAN Loss
# \begin{align} # \begin{align}
@ -529,7 +542,7 @@ class Configs(BaseConfigs):
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler]) Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
def setup_models(self: Configs): def setup_models(self: Configs):
""" """
This method setup the models ## setup the models
""" """
input_shape = (self.img_channels, self.img_height, self.img_width) input_shape = (self.img_channels, self.img_height, self.img_width)
@ -560,7 +573,7 @@ def setup_models(self: Configs):
@configs.setup([Configs.dataloader, Configs.valid_dataloader]) @configs.setup([Configs.dataloader, Configs.valid_dataloader])
def setup_dataloader(self: Configs): def setup_dataloader(self: Configs):
""" """
This method setup the data loaders ## setup the data loaders
""" """
# Location of the dataset # Location of the dataset