Files
Varuna Jayasiri d3871d42a6 cycle gan
2020-09-28 18:10:36 +05:30

363 lines
13 KiB
Python

"""
Download datasets from https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip
and extract them into labml_nn/data/cycle_gan/[DATASET NAME]
I've taken pieces of code from https://github.com/eriklindernoren/PyTorch-GAN
"""
import itertools
import random
from pathlib import PurePath, Path
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.utils import make_grid
from torchvision.utils import save_image
from labml import lab, tracker, experiment
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.training_loop import TrainingLoopConfigs
class ReplayBuffer:
    """
    Buffer of previously generated samples.

    Feeding the discriminator a mix of current and past generator
    outputs stabilizes adversarial training.
    """

    def __init__(self, max_size: int = 50):
        # Maximum number of past samples to keep
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """
        Store each sample in `data` and return a batch mixing fresh
        samples with randomly recalled older ones.
        """
        samples = []
        for element in data:
            if len(self.data) < self.max_size:
                # Buffer not yet full: store and return the new sample
                self.data.append(element)
                samples.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Half the time, swap the new sample for a stored one
                idx = random.randint(0, self.max_size - 1)
                samples.append(self.data[idx].clone())
                self.data[idx] = element
            else:
                # Otherwise return the new sample unchanged
                samples.append(element)
        return torch.stack(samples)
def load_image(path: str):
    """
    Load the image at `path` as an RGB `PIL.Image`.

    Non-RGB images (grayscale, palette, RGBA, ...) are pasted onto a
    fresh RGB canvas of the same size.
    """
    image = Image.open(path)
    if image.mode != 'RGB':
        # Fix: the original called `.pase(image)` (a misspelling of
        # `paste`), and `Image.paste` mutates in place returning None —
        # so non-RGB inputs either crashed or produced None. Build the
        # RGB canvas, paste into it, then return the canvas.
        rgb_image = Image.new("RGB", image.size)
        rgb_image.paste(image)
        image = rgb_image
    return image
class ImageDataset(Dataset):
    """
    Dataset of two image folders, `<root>/<mode>A` and `<root>/<mode>B`.
    """

    def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
        root_path = Path(root)
        # Compose the torchvision transforms into one callable
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        # Sorted file paths for each image domain
        self.files_A = sorted(str(f) for f in (root_path / f'{mode}A').iterdir())
        self.files_B = sorted(str(f) for f in (root_path / f'{mode}B').iterdir())

    def __getitem__(self, index):
        # Wrap the index around each domain independently, since the
        # two folders may hold different numbers of images
        image_a = load_image(self.files_A[index % len(self.files_A)])
        image_b = load_image(self.files_B[index % len(self.files_B)])
        return {"a": self.transform(image_a), "b": self.transform(image_b)}

    def __len__(self):
        # Length of the larger domain, so every image gets visited
        return max(len(self.files_A), len(self.files_B))
def weights_init_normal(m):
    """
    DCGAN-style weight initialization, for use with `nn.Module.apply`:
    convolution weights drawn from N(0, 0.02), batch-norm scale from
    N(1, 0.02). Other module types are left untouched.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
class ResidualBlock(Module):
    """
    Residual block: two reflection-padded `3x3` convolutions with
    instance normalization, added back to the input (skip connection).
    """

    def __init__(self, in_features: int):
        super().__init__()
        conv_path = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        ]
        self.block = nn.Sequential(*conv_path)

    def __call__(self, x: torch.Tensor):
        # Skip connection: add the input to the transformed output
        return x + self.block(x)
class GeneratorResNet(Module):
    """
    CycleGAN generator: an encoder (one `7x7` convolution followed by
    two stride-2 down-sampling convolutions), a chain of residual
    blocks, and a decoder (two up-sampling stages plus a final `7x7`
    convolution with `tanh` output).
    """

    def __init__(self, input_shape, num_residual_blocks):
        super().__init__()
        channels = input_shape[0]

        # Initial 7x7 convolution block, reflection-padded so the
        # spatial size is preserved
        features = 64
        blocks = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, features, 7),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
        ]

        # Two down-sampling stages; each halves the resolution and
        # doubles the channel count
        for _ in range(2):
            blocks += [
                nn.Conv2d(features, features * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(features * 2),
                nn.ReLU(inplace=True),
            ]
            features *= 2

        # Residual blocks at the bottleneck resolution
        blocks += [ResidualBlock(features) for _ in range(num_residual_blocks)]

        # Two up-sampling stages; each doubles the resolution and
        # halves the channel count
        for _ in range(2):
            blocks += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(features, features // 2, 3, stride=1, padding=1),
                nn.InstanceNorm2d(features // 2),
                nn.ReLU(inplace=True),
            ]
            features //= 2

        # Final 7x7 convolution back to the image channel count,
        # squashed to [-1, 1] with tanh
        blocks += [nn.ReflectionPad2d(channels), nn.Conv2d(features, channels, 7), nn.Tanh()]
        self.layers = nn.Sequential(*blocks)

        # DCGAN-style weight initialization over all sub-modules
        self.apply(weights_init_normal)

    def __call__(self, x):
        return self.layers(x)
class DiscriminatorBlock(Module):
    """
    One down-sampling stage of the discriminator: a stride-2 `4x4`
    convolution, optional instance normalization, and a leaky ReLU.
    """

    def __init__(self, in_filters, out_filters, normalize=True):
        super().__init__()
        stage = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
        if normalize:
            stage.append(nn.InstanceNorm2d(out_filters))
        stage.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*stage)

    def __call__(self, x: torch.Tensor):
        return self.layers(x)
class Discriminator(Module):
    """
    PatchGAN discriminator: four down-sampling blocks followed by a
    final convolution that scores overlapping image patches rather
    than the whole image.
    """

    def __init__(self, input_shape):
        super().__init__()
        channels, height, width = input_shape

        # Each of the four blocks halves the resolution, so the output
        # patch map is 1/16th of the input in each spatial dimension
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        stages = [
            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
            # Asymmetric zero padding keeps the final map at the
            # expected patch resolution
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1),
        ]
        self.model = nn.Sequential(*stages)

        # DCGAN-style weight initialization over all sub-modules
        self.apply(weights_init_normal)

    def forward(self, img):
        return self.model(img)
def sample_images(n, dataset_name, valid_dataloader, generator_ab, generator_ba):
    """
    Save a grid of validation images: real A, its translation to B,
    real B, and its translation to A, stacked vertically.
    """
    batch = next(iter(valid_dataloader))
    generator_ab.eval()
    generator_ba.eval()
    with torch.no_grad():
        # NOTE(review): `.device` comes from the labml `Module` wrapper,
        # not plain `nn.Module` — confirm if the base class changes
        real_a = batch['a'].to(generator_ab.device)
        real_b = batch['b'].to(generator_ba.device)
        fake_b = generator_ab(real_a)
        fake_a = generator_ba(real_b)
        # Lay each batch out along the x-axis
        rows = [make_grid(images, nrow=5, normalize=True)
                for images in (real_a, fake_b, real_b, fake_a)]
        # Stack the four rows along the y-axis
        image_grid = torch.cat(rows, 1)
    save_image(image_grid, "images/%s/%s.png" % (dataset_name, n), normalize=False)
class Configs(TrainingLoopConfigs):
    """
    CycleGAN training configuration.

    Trains two generators (A->B and B->A) and two discriminators with
    adversarial (least-squares GAN), cycle-consistency, and identity
    losses, following the CycleGAN paper.
    """

    # Device to train on (selected by labml's DeviceConfigs)
    device: torch.device = DeviceConfigs()
    # Total number of training epochs
    loop_count: int = 200
    # Dataset directory name under <lab data path>/cycle_gan/
    dataset_name: str = 'monet2photo'
    batch_size: int = 1
    # Worker processes for the data loaders
    data_loader_workers = 8
    is_save_models = True
    # Adam optimizer settings
    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    # Epoch at which the learning rate starts decaying linearly to zero
    decay_start = 100
    # L1 for identity and cycle-consistency; MSE for the adversarial
    # (least-squares GAN) loss
    identity_loss = torch.nn.L1Loss()
    cycle_loss = torch.nn.L1Loss()
    gan_loss = torch.nn.MSELoss()
    batch_step = 'cycle_gan_batch_step'
    # Image dimensions and channels
    img_height = 256
    img_width = 256
    img_channels = 3
    # Residual blocks per generator
    n_residual_blocks = 9
    # Loss weights (lambda_cyc and lambda_id in the CycleGAN paper)
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.
    # Save sample images every this many batches
    sample_interval = 100

    def run(self):
        """
        Build models, optimizers, schedulers and data loaders, then run
        the CycleGAN training loop.
        """
        # Dataset location: <lab data path>/cycle_gan/<dataset name>
        images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
        input_shape = (self.img_channels, self.img_height, self.img_width)

        # Two generators (A->B, B->A) and one discriminator per domain
        generator_ab = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
        generator_ba = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
        discriminator_a = Discriminator(input_shape).to(self.device)
        discriminator_b = Discriminator(input_shape).to(self.device)

        # One optimizer for both generators, one for both discriminators
        generator_optimizer = torch.optim.Adam(
            itertools.chain(generator_ab.parameters(), generator_ba.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        discriminator_optimizer = torch.optim.Adam(
            itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

        # Linear learning-rate decay: constant until `decay_start`, then
        # falls linearly to zero at `loop_count`
        decay_epochs = self.loop_count - self.decay_start
        generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)

        # Image transformations: upscale slightly, random-crop back to
        # the target size, random flip, and normalize to [-1, 1]
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Training data loader
        dataloader = DataLoader(
            ImageDataset(images_path, transforms_, True, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
        # Test data loader (used for sample images)
        valid_dataloader = DataLoader(
            ImageDataset(images_path, transforms_, True, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

        # Buffers of previously generated samples, fed to the
        # discriminators instead of only the latest fakes
        fake_a_buffer = ReplayBuffer()
        fake_b_buffer = ReplayBuffer()

        for epoch in self.training_loop:
            for i, batch in enumerate(dataloader):
                # Set model input: current image pair on the device
                real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device)

                # Adversarial ground truths, shaped like the PatchGAN
                # output map
                valid = torch.ones(real_a.size(0), *discriminator_a.output_shape,
                                   device=self.device, requires_grad=False)
                fake = torch.zeros(real_a.size(0), *discriminator_a.output_shape,
                                   device=self.device, requires_grad=False)

                # Train generators
                generator_ab.train()
                generator_ba.train()

                # Identity loss: a generator fed an image already in its
                # target domain should return it unchanged
                loss_identity = self.identity_loss(generator_ba(real_a), real_a) + \
                                self.identity_loss(generator_ab(real_b), real_b)

                # GAN loss: generated images should fool the discriminators
                fake_b = generator_ab(real_a)
                fake_a = generator_ba(real_b)
                loss_gan = self.gan_loss(discriminator_b(fake_b), valid) + \
                           self.gan_loss(discriminator_a(fake_a), valid)

                # Cycle loss: translating to the other domain and back
                # should reproduce the original image
                loss_cycle = self.cycle_loss(generator_ba(fake_b), real_a) + \
                             self.cycle_loss(generator_ab(fake_a), real_b)

                # Total weighted generator loss
                loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle
                                  + self.identity_loss_coefficient * loss_identity)
                generator_optimizer.zero_grad()
                loss_generator.backward()
                generator_optimizer.step()

                # Train discriminators on real images and on replayed
                # (detached) fakes from the buffers
                fake_a_replay = fake_a_buffer.push_and_pop(fake_a)
                fake_b_replay = fake_b_buffer.push_and_pop(fake_b)
                loss_discriminator = self.gan_loss(discriminator_a(real_a), valid) + \
                                     self.gan_loss(discriminator_a(fake_a_replay.detach()), fake) + \
                                     self.gan_loss(discriminator_b(real_b), valid) + \
                                     self.gan_loss(discriminator_b(fake_b_replay.detach()), fake)
                discriminator_optimizer.zero_grad()
                loss_discriminator.backward()
                discriminator_optimizer.step()

                # Log all loss components
                tracker.save({'loss.generator': loss_generator,
                              'loss.discriminator': loss_discriminator,
                              'loss.generator.cycle': loss_cycle,
                              'loss.generator.gan': loss_gan,
                              'loss.generator.identity': loss_identity})

                # If at sample interval, save sample images
                batches_done = epoch * len(dataloader) + i
                if batches_done % self.sample_interval == 0:
                    sample_images(batches_done, self.dataset_name, valid_dataloader, generator_ab, generator_ba)

                tracker.add_global_step(max(len(real_a), len(real_b)))

            # Update learning rates once per epoch
            generator_lr_scheduler.step()
            discriminator_lr_scheduler.step()
def main():
    """Create the experiment, attach the configuration, and train."""
    configs = Configs()
    experiment.create(name='cycle_gan')
    experiment.configs(configs, 'run')
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()