Varuna Jayasiri
2020-09-30 20:57:23 +05:30
parent abe0c6abb1
commit 62a173067b

@@ -229,7 +229,8 @@ class ReplayBuffer:
self.max_size = max_size
self.data = []
def push_and_pop(self, data):
def push_and_pop(self, data: torch.Tensor):
"""Add/retrieve an image"""
data = data.detach()
res = []
for element in data:
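The loop body is cut off by this hunk. For context, here is a minimal self-contained sketch of the whole buffer following the standard CycleGAN image-pool logic; the `max_size` default of 50 and the 0.5 replacement probability are assumptions, not taken from this diff.

import random

import torch


class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data: torch.Tensor):
        """Add/retrieve an image"""
        data = data.detach()
        res = []
        for element in data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Buffer is not full: store the new image and return it
                self.data.append(element)
                res.append(element)
            elif random.uniform(0, 1) > 0.5:
                # With probability 0.5, return a previously stored image
                # and replace it with the new one
                i = random.randint(0, self.max_size - 1)
                res.append(self.data[i].clone())
                self.data[i] = element
            else:
                # Otherwise return the new image unchanged
                res.append(element)
        return torch.cat(res)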
@@ -247,14 +248,17 @@ class ReplayBuffer:
class Configs(BaseConfigs):
"""This is the configurations for the experiment"""
"""## Configurations"""
# `DeviceConfigs` will pick a GPU if available
device: torch.device = DeviceConfigs()
# Hyper-parameters
epochs: int = 200
dataset_name: str = 'monet2photo'
batch_size: int = 1
data_loader_workers = 8
is_save_models = True
learning_rate = 0.0002
adam_betas = (0.5, 0.999)
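These options are what you would override when launching the experiment. A minimal usage sketch, assuming labml's `experiment.create` / `experiment.configs` / `experiment.start` API; the experiment name and the overridden values are illustrative only.

from labml import experiment


def main():
    conf = Configs()
    experiment.create(name='cycle_gan')
    # Override any of the hyper-parameters declared above
    experiment.configs(conf, {'dataset_name': 'monet2photo',
                              'batch_size': 1})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()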
@@ -268,30 +272,35 @@ class Configs(BaseConfigs):
cycle_loss = torch.nn.L1Loss()
identity_loss = torch.nn.L1Loss()
batch_step = 'cycle_gan_batch_step'
# Image dimensions
img_height = 256
img_width = 256
img_channels = 3
# Number of residual blocks in the generator
n_residual_blocks = 9
# Loss coefficients
cyclic_loss_coefficient = 10.0
identity_loss_coefficient = 5.
sample_interval = 500
# Models
generator_xy: GeneratorResNet
generator_yx: GeneratorResNet
discriminator_x: Discriminator
discriminator_y: Discriminator
# Optimizers
generator_optimizer: torch.optim.Adam
discriminator_optimizer: torch.optim.Adam
# Learning rate schedules
generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
# Data loaders
dataloader: DataLoader
valid_dataloader: DataLoader
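The `LambdaLR` schedules declared above typically implement the linear decay used by CycleGAN: hold the learning rate constant, then ramp it to zero over the remaining epochs. A sketch under that assumption; the `decay_start` parameter is hypothetical and not part of this diff.

def linear_decay_scheduler(optimizer: torch.optim.Adam,
                           epochs: int, decay_start: int) -> torch.optim.lr_scheduler.LambdaLR:
    # Multiplier stays at 1.0 until `decay_start`,
    # then decreases linearly to 0.0 at `epochs`
    def lr_lambda(epoch: int) -> float:
        return 1.0 - max(0, epoch - decay_start) / (epochs - decay_start)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)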
@@ -311,12 +320,16 @@ class Configs(BaseConfigs):
gen_x = make_grid(gen_x, nrow=5, normalize=True)
gen_y = make_grid(gen_y, nrow=5, normalize=True)
# arange images along y-axis
# Arrange images along y-axis
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
# Save grid
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
def run(self):
"""
## Training
We aim to solve:
$$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$
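Following the loss terms and coefficients configured above, the combined objective expands to

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_{cyc} \mathcal{L}_{cyc}(G, F) + \lambda_{id} \mathcal{L}_{id}(G, F)$$

with $\lambda_{cyc} = 10$ (`cyclic_loss_coefficient`) and $\lambda_{id} = 5$ (`identity_loss_coefficient`).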
@@ -447,7 +460,7 @@ class Configs(BaseConfigs):
def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
"""
Optimize the generators with identity, gan and cycle losses.
### Optimize the generators with identity, GAN, and cycle losses.
"""
# Change to training mode
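The rest of the method is outside this hunk. The following is a sketch of the update it describes, built from the losses, coefficients, and model attributes declared in `Configs`; treat it as an illustration of the step, not the exact committed code.

# (continuation of Configs.optimize_generators)
self.generator_xy.train()
self.generator_yx.train()

# Identity loss: G(y) should stay close to y, F(x) close to x
loss_identity = (self.identity_loss(self.generator_xy(data_y), data_y) +
                 self.identity_loss(self.generator_yx(data_x), data_x))

# GAN loss: generated images should fool the discriminators
gen_y = self.generator_xy(data_x)
gen_x = self.generator_yx(data_y)
loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
            self.gan_loss(self.discriminator_x(gen_x), true_labels))

# Cycle consistency loss: F(G(x)) should recover x, G(F(y)) should recover y
loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
              self.cycle_loss(self.generator_xy(gen_x), data_y))

# Weighted total, then one gradient step on both generators
loss_generator = (loss_gan +
                  self.cyclic_loss_coefficient * loss_cycle +
                  self.identity_loss_coefficient * loss_identity)
self.generator_optimizer.zero_grad()
loss_generator.backward()
self.generator_optimizer.step()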
@@ -501,7 +514,7 @@ class Configs(BaseConfigs):
gen_x: torch.Tensor, gen_y: torch.Tensor,
true_labels: torch.Tensor, false_labels: torch.Tensor):
"""
Optimize the discriminators with gan loss.
### Optimize the discriminators with GAN loss.
"""
# GAN Loss
# \begin{align}
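The align block is truncated by the hunk. A sketch of the discriminator update it annotates; in the actual training loop `gen_x` and `gen_y` would typically be detached images drawn from the `ReplayBuffer`.

# (continuation of Configs.optimize_discriminators)
# Real images should be classified as real, generated images as fake
loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                      self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                      self.gan_loss(self.discriminator_y(data_y), true_labels) +
                      self.gan_loss(self.discriminator_y(gen_y), false_labels))

# One gradient step on both discriminators
self.discriminator_optimizer.zero_grad()
loss_discriminator.backward()
self.discriminator_optimizer.step()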
@@ -529,7 +542,7 @@ class Configs(BaseConfigs):
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
def setup_models(self: Configs):
"""
This method setup the models
## Set up the models
"""
input_shape = (self.img_channels, self.img_height, self.img_width)
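The body continues past this hunk. A sketch of what this calculator plausibly builds from `input_shape`; the constructor signatures of `GeneratorResNet` and `Discriminator`, and the use of `itertools.chain`, are assumptions.

# (sketch of the rest of setup_models)
import itertools

self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
self.discriminator_x = Discriminator(input_shape).to(self.device)
self.discriminator_y = Discriminator(input_shape).to(self.device)

# Single optimizer per pair of models, using the configured
# learning rate and Adam betas
self.generator_optimizer = torch.optim.Adam(
    itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
    lr=self.learning_rate, betas=self.adam_betas)
self.discriminator_optimizer = torch.optim.Adam(
    itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
    lr=self.learning_rate, betas=self.adam_betas)
# The LambdaLR schedulers would be built on these optimizers,
# e.g. with the linear-decay sketch shown earlier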
@@ -560,7 +573,7 @@ def setup_models(self: Configs):
@configs.setup([Configs.dataloader, Configs.valid_dataloader])
def setup_dataloader(self: Configs):
"""
This method setup the data loaders
## Set up the data loaders
"""
# Location of the dataset
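The dataset path and the rest of the body are not in this hunk. A sketch of a typical construction, where `ImageDataset`, `images_path`, and `image_transforms` are hypothetical stand-ins for pieces not shown in the diff; a validation batch size of 5 would match the `nrow=5` grids in `sample_images`.

# (sketch of the rest of setup_dataloader)
self.dataloader = DataLoader(
    ImageDataset(images_path, image_transforms, mode='train'),
    batch_size=self.batch_size,
    shuffle=True,
    num_workers=self.data_loader_workers)
self.valid_dataloader = DataLoader(
    ImageDataset(images_path, image_transforms, mode='test'),
    batch_size=5,
    shuffle=True,
    num_workers=self.data_loader_workers)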