mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-03 05:46:16 +08:00
cleanup cycle gan code
This commit is contained in:
@ -53,7 +53,7 @@ class GeneratorResNet(Module):
|
||||
# `inplace=True` in `ReLU` saves a little bit of memory.
|
||||
out_features = 64
|
||||
layers = [
|
||||
nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflection'),
|
||||
nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
|
||||
nn.InstanceNorm2d(out_features),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
@ -80,14 +80,15 @@ class GeneratorResNet(Module):
|
||||
for _ in range(2):
|
||||
out_features //= 2
|
||||
layers += [
|
||||
nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
|
||||
nn.Upsample(scale_factor=2),
|
||||
nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
|
||||
nn.InstanceNorm2d(out_features),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
in_features = out_features
|
||||
|
||||
# Finally we map the feature map to an RGB image
|
||||
layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflection'), nn.Tanh()]
|
||||
layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
|
||||
|
||||
# Create a sequential module with the layers
|
||||
self.layers = nn.Sequential(*layers)
|
||||
@ -107,10 +108,10 @@ class ResidualBlock(Module):
|
||||
def __init__(self, in_features: int):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflection'),
|
||||
nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
|
||||
nn.InstanceNorm2d(in_features),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflection'),
|
||||
nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
|
||||
nn.InstanceNorm2d(in_features),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
@ -158,6 +159,7 @@ class DiscriminatorBlock(Module):
|
||||
|
||||
It shrinks the height and width of the input feature map by half.
|
||||
"""
|
||||
|
||||
def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
|
||||
super().__init__()
|
||||
layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
|
||||
@ -194,6 +196,7 @@ class ImageDataset(Dataset):
|
||||
"""
|
||||
Dataset to load images
|
||||
"""
|
||||
|
||||
def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
|
||||
root = Path(root)
|
||||
self.transform = transforms.Compose(transforms_)
|
||||
@ -221,6 +224,7 @@ class ReplayBuffer:
|
||||
|
||||
This is done to reduce model oscillation.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 50):
|
||||
self.max_size = max_size
|
||||
self.data = []
|
||||
@ -249,7 +253,7 @@ class Configs(BaseConfigs):
|
||||
dataset_name: str = 'monet2photo'
|
||||
batch_size: int = 1
|
||||
|
||||
data_loader_workers = 8
|
||||
data_loader_workers = 0
|
||||
is_save_models = True
|
||||
|
||||
learning_rate = 0.0002
|
||||
@ -311,6 +315,57 @@ class Configs(BaseConfigs):
|
||||
image_grid = torch.cat((real_a, fake_b, real_b, fake_a), 1)
|
||||
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
||||
|
||||
def optimize_generators(self, real_a: torch.Tensor, real_b: torch.Tensor, true_labels: torch.Tensor):
    """
    Run a single optimization step for both generators.

    The objective is the sum of the GAN loss, the cycle loss scaled by
    `cyclic_loss_coefficient`, and the identity loss scaled by
    `identity_loss_coefficient`.

    :param real_a: batch of images from domain A
    :param real_b: batch of images from domain B
    :param true_labels: target passed to `gan_loss` for the discriminator
        outputs on the generated images
    :return: tuple `(fake_a, fake_b)` with the images produced by
        `generator_ba(real_b)` and `generator_ab(real_a)` respectively
    """
    # Both generators must be in training mode
    for generator in (self.generator_ab, self.generator_ba):
        generator.train()

    # Identity loss: each generator is fed an image of the domain it outputs
    identity_a = self.identity_loss(self.generator_ba(real_a), real_a)
    identity_b = self.identity_loss(self.generator_ab(real_b), real_b)
    loss_identity = identity_a + identity_b

    # Translate the images to the opposite domain
    fake_b = self.generator_ab(real_a)
    fake_a = self.generator_ba(real_b)

    # GAN loss: discriminator outputs on generated images against `true_labels`
    gan_b = self.gan_loss(self.discriminator_b(fake_b), true_labels)
    gan_a = self.gan_loss(self.discriminator_a(fake_a), true_labels)
    loss_gan = gan_b + gan_a

    # Cycle loss: translating back should reproduce the original images
    cycle_a = self.cycle_loss(self.generator_ba(fake_b), real_a)
    cycle_b = self.cycle_loss(self.generator_ab(fake_a), real_b)
    loss_cycle = cycle_a + cycle_b

    # Weighted total
    loss_generator = loss_gan + self.cyclic_loss_coefficient * loss_cycle
    loss_generator = loss_generator + self.identity_loss_coefficient * loss_identity

    # Gradient step on the generator parameters
    self.generator_optimizer.zero_grad()
    loss_generator.backward()
    self.generator_optimizer.step()

    # Log the total and each component
    tracker.add({
        'loss.generator': loss_generator,
        'loss.generator.cycle': loss_cycle,
        'loss.generator.gan': loss_gan,
        'loss.generator.identity': loss_identity,
    })

    return fake_a, fake_b
|
||||
|
||||
def optimize_discriminator(self, real_a: torch.Tensor, real_b: torch.Tensor,
                           fake_a: torch.Tensor, fake_b: torch.Tensor,
                           true_labels: torch.Tensor, false_labels: torch.Tensor):
    """
    Run a single optimization step for both discriminators.

    Each discriminator is scored with `gan_loss` on real images against
    `true_labels` and on generated images against `false_labels`; the four
    terms are summed into one loss.

    :param real_a: batch of real images from domain A
    :param real_b: batch of real images from domain B
    :param fake_a: generated images for domain A
    :param fake_b: generated images for domain B
    :param true_labels: target for real images
    :param false_labels: target for generated images

    NOTE(review): `fake_a` / `fake_b` are presumably already detached from the
    generator graph by the caller (e.g. by the replay buffer) - confirm.
    """
    # Discriminator A on real and generated images
    loss_a_real = self.gan_loss(self.discriminator_a(real_a), true_labels)
    loss_a_fake = self.gan_loss(self.discriminator_a(fake_a), false_labels)

    # Discriminator B on real and generated images
    loss_b_real = self.gan_loss(self.discriminator_b(real_b), true_labels)
    loss_b_fake = self.gan_loss(self.discriminator_b(fake_b), false_labels)

    loss_discriminator = loss_a_real + loss_a_fake + loss_b_real + loss_b_fake

    # Gradient step on the discriminator parameters
    self.discriminator_optimizer.zero_grad()
    loss_discriminator.backward()
    self.discriminator_optimizer.step()

    # Log the discriminator loss
    tracker.add({'loss.discriminator': loss_discriminator})
|
||||
|
||||
def run(self):
|
||||
# Replay buffers to keep generated samples
|
||||
fake_a_buffer = ReplayBuffer()
|
||||
@ -321,56 +376,23 @@ class Configs(BaseConfigs):
|
||||
# Move images to the device
|
||||
real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device)
|
||||
|
||||
# valid labels equal to $1$
|
||||
valid = torch.ones(real_a.size(0), *self.discriminator_a.output_shape,
|
||||
device=self.device, requires_grad=False)
|
||||
# fake labels equal to $0$
|
||||
fake = torch.zeros(real_a.size(0), *self.discriminator_a.output_shape,
|
||||
device=self.device, requires_grad=False)
|
||||
# true labels equal to $1$
|
||||
true_labels = torch.ones(real_a.size(0), *self.discriminator_a.output_shape,
|
||||
device=self.device, requires_grad=False)
|
||||
# false labels equal to $0$
|
||||
false_labels = torch.zeros(real_a.size(0), *self.discriminator_a.output_shape,
|
||||
device=self.device, requires_grad=False)
|
||||
|
||||
# Train generators
|
||||
self.generator_ab.train()
|
||||
self.generator_ba.train()
|
||||
|
||||
# Identity loss
|
||||
loss_identity = self.identity_loss(self.generator_ba(real_a), real_a) + \
|
||||
self.identity_loss(self.generator_ab(real_b), real_b)
|
||||
|
||||
# GAN loss
|
||||
fake_b = self.generator_ab(real_a)
|
||||
fake_a = self.generator_ba(real_b)
|
||||
|
||||
loss_gan = self.gan_loss(self.discriminator_b(fake_b), valid) + \
|
||||
self.gan_loss(self.discriminator_a(fake_a), valid)
|
||||
|
||||
loss_cycle = self.cycle_loss(self.generator_ba(fake_b), real_a) + \
|
||||
self.cycle_loss(self.generator_ab(fake_a), real_b)
|
||||
|
||||
# Total loss
|
||||
loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle
|
||||
+ self.identity_loss_coefficient * loss_identity)
|
||||
|
||||
self.generator_optimizer.zero_grad()
|
||||
loss_generator.backward()
|
||||
self.generator_optimizer.step()
|
||||
# Train the generators
|
||||
fake_a, fake_b = self.optimize_generators(real_a, real_b, true_labels)
|
||||
|
||||
# Train discriminators
|
||||
fake_a_replay = fake_a_buffer.push_and_pop(fake_a)
|
||||
fake_b_replay = fake_b_buffer.push_and_pop(fake_b)
|
||||
loss_discriminator = self.gan_loss(self.discriminator_a(real_a), valid) + \
|
||||
self.gan_loss(self.discriminator_a(fake_a_replay), fake) + \
|
||||
self.gan_loss(self.discriminator_b(real_b), valid) + \
|
||||
self.gan_loss(self.discriminator_b(fake_b_replay), fake)
|
||||
self.optimize_discriminator(real_a, real_b,
|
||||
fake_a_buffer.push_and_pop(fake_a), fake_b_buffer.push_and_pop(fake_b),
|
||||
true_labels, false_labels)
|
||||
|
||||
self.discriminator_optimizer.zero_grad()
|
||||
loss_discriminator.backward()
|
||||
self.discriminator_optimizer.step()
|
||||
|
||||
tracker.save({'loss.generator': loss_generator,
|
||||
'loss.discriminator': loss_discriminator,
|
||||
'loss.generator.cycle': loss_cycle,
|
||||
'loss.generator.gan': loss_gan,
|
||||
'loss.generator.identity': loss_identity})
|
||||
# Save training statistics
|
||||
tracker.save()
|
||||
|
||||
# If at sample interval save image
|
||||
batches_done = epoch * len(self.dataloader) + i
|
||||
|
||||
Reference in New Issue
Block a user