cleanup cycle gan code

This commit is contained in:
Varuna Jayasiri
2020-09-30 15:07:17 +05:30
parent 096c21fe0a
commit 39dbc9e6b1

View File

@ -53,7 +53,7 @@ class GeneratorResNet(Module):
# `inplace=True` in `ReLU` saves a little bit of memory.
out_features = 64
layers = [
nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflection'),
nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
@ -80,14 +80,15 @@ class GeneratorResNet(Module):
for _ in range(2):
out_features //= 2
layers += [
nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
nn.Upsample(scale_factor=2),
nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Finally we map the feature map to an RGB image
layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflection'), nn.Tanh()]
layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
# Create a sequential module with the layers
self.layers = nn.Sequential(*layers)
@ -107,10 +108,10 @@ class ResidualBlock(Module):
def __init__(self, in_features: int):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflection'),
nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
nn.InstanceNorm2d(in_features),
nn.ReLU(inplace=True),
nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflection'),
nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
nn.InstanceNorm2d(in_features),
nn.ReLU(inplace=True),
)
@ -158,6 +159,7 @@ class DiscriminatorBlock(Module):
It shrinks the height and width of the input feature map by half.
"""
def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
super().__init__()
layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
@ -194,6 +196,7 @@ class ImageDataset(Dataset):
"""
Dataset to load images
"""
def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
root = Path(root)
self.transform = transforms.Compose(transforms_)
@ -221,6 +224,7 @@ class ReplayBuffer:
This is done to reduce model oscillation.
"""
def __init__(self, max_size: int = 50):
self.max_size = max_size
self.data = []
@ -249,7 +253,7 @@ class Configs(BaseConfigs):
dataset_name: str = 'monet2photo'
batch_size: int = 1
data_loader_workers = 8
data_loader_workers = 0
is_save_models = True
learning_rate = 0.0002
@ -311,6 +315,57 @@ class Configs(BaseConfigs):
image_grid = torch.cat((real_a, fake_b, real_b, fake_a), 1)
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
def optimize_generators(self, real_a: torch.Tensor, real_b: torch.Tensor, true_labels: torch.Tensor):
# Change to training mode
self.generator_ab.train()
self.generator_ba.train()
# Identity loss
loss_identity = (self.identity_loss(self.generator_ba(real_a), real_a) +
self.identity_loss(self.generator_ab(real_b), real_b))
# Generate images
fake_b = self.generator_ab(real_a)
fake_a = self.generator_ba(real_b)
# GAN loss
loss_gan = (self.gan_loss(self.discriminator_b(fake_b), true_labels) +
self.gan_loss(self.discriminator_a(fake_a), true_labels))
# Cycle loss
loss_cycle = (self.cycle_loss(self.generator_ba(fake_b), real_a) +
self.cycle_loss(self.generator_ab(fake_a), real_b))
# Total loss
loss_generator = (loss_gan +
self.cyclic_loss_coefficient * loss_cycle +
self.identity_loss_coefficient * loss_identity)
self.generator_optimizer.zero_grad()
loss_generator.backward()
self.generator_optimizer.step()
tracker.add({'loss.generator': loss_generator,
'loss.generator.cycle': loss_cycle,
'loss.generator.gan': loss_gan,
'loss.generator.identity': loss_identity})
return fake_a, fake_b
def optimize_discriminator(self, real_a: torch.Tensor, real_b: torch.Tensor,
fake_a: torch.Tensor, fake_b: torch.Tensor,
true_labels: torch.Tensor, false_labels: torch.Tensor):
loss_discriminator = (self.gan_loss(self.discriminator_a(real_a), true_labels) +
self.gan_loss(self.discriminator_a(fake_a), false_labels) +
self.gan_loss(self.discriminator_b(real_b), true_labels) +
self.gan_loss(self.discriminator_b(fake_b), false_labels))
self.discriminator_optimizer.zero_grad()
loss_discriminator.backward()
self.discriminator_optimizer.step()
tracker.add({'loss.discriminator': loss_discriminator})
def run(self):
# Replay buffers to keep generated samples
fake_a_buffer = ReplayBuffer()
@ -321,56 +376,23 @@ class Configs(BaseConfigs):
# Move images to the device
real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device)
# valid labels equal to $1$
valid = torch.ones(real_a.size(0), *self.discriminator_a.output_shape,
device=self.device, requires_grad=False)
# fake labels equal to $0$
fake = torch.zeros(real_a.size(0), *self.discriminator_a.output_shape,
device=self.device, requires_grad=False)
# true labels equal to $1$
true_labels = torch.ones(real_a.size(0), *self.discriminator_a.output_shape,
device=self.device, requires_grad=False)
# false labels equal to $0$
false_labels = torch.zeros(real_a.size(0), *self.discriminator_a.output_shape,
device=self.device, requires_grad=False)
# Train generators
self.generator_ab.train()
self.generator_ba.train()
# Identity loss
loss_identity = self.identity_loss(self.generator_ba(real_a), real_a) + \
self.identity_loss(self.generator_ab(real_b), real_b)
# GAN loss
fake_b = self.generator_ab(real_a)
fake_a = self.generator_ba(real_b)
loss_gan = self.gan_loss(self.discriminator_b(fake_b), valid) + \
self.gan_loss(self.discriminator_a(fake_a), valid)
loss_cycle = self.cycle_loss(self.generator_ba(fake_b), real_a) + \
self.cycle_loss(self.generator_ab(fake_a), real_b)
# Total loss
loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle
+ self.identity_loss_coefficient * loss_identity)
self.generator_optimizer.zero_grad()
loss_generator.backward()
self.generator_optimizer.step()
# Train the generators
fake_a, fake_b = self.optimize_generators(real_a, real_b, true_labels)
# Train discriminators
fake_a_replay = fake_a_buffer.push_and_pop(fake_a)
fake_b_replay = fake_b_buffer.push_and_pop(fake_b)
loss_discriminator = self.gan_loss(self.discriminator_a(real_a), valid) + \
self.gan_loss(self.discriminator_a(fake_a_replay), fake) + \
self.gan_loss(self.discriminator_b(real_b), valid) + \
self.gan_loss(self.discriminator_b(fake_b_replay), fake)
self.optimize_discriminator(real_a, real_b,
fake_a_buffer.push_and_pop(fake_a), fake_b_buffer.push_and_pop(fake_b),
true_labels, false_labels)
self.discriminator_optimizer.zero_grad()
loss_discriminator.backward()
self.discriminator_optimizer.step()
tracker.save({'loss.generator': loss_generator,
'loss.discriminator': loss_discriminator,
'loss.generator.cycle': loss_cycle,
'loss.generator.gan': loss_gan,
'loss.generator.identity': loss_identity})
# Save training statistics
tracker.save()
# If at sample interval save image
batches_done = epoch * len(self.dataloader) + i