Cycle GAN implementation

This commit is contained in:
Varuna Jayasiri
2020-09-28 18:10:36 +05:30
parent 7ef213f89c
commit d3871d42a6

362
labml_nn/gan/cycle_gan.py Normal file
View File

@ -0,0 +1,362 @@
"""
Download datasets from https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip
and extract them into labml_nn/data/cycle_gan/[DATASET NAME]
I've taken pieces of code from https://github.com/eriklindernoren/PyTorch-GAN
"""
import itertools
import random
from pathlib import PurePath, Path
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.utils import make_grid
from torchvision.utils import save_image
from labml import lab, tracker, experiment
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.training_loop import TrainingLoopConfigs
class ReplayBuffer:
def __init__(self, max_size: int = 50):
self.max_size = max_size
self.data = []
def push_and_pop(self, data):
to_return = []
for element in data:
if len(self.data) < self.max_size:
self.data.append(element)
to_return.append(element)
else:
if random.uniform(0, 1) > 0.5:
i = random.randint(0, self.max_size - 1)
to_return.append(self.data[i].clone())
self.data[i] = element
else:
to_return.append(element)
return torch.stack(to_return)
def load_image(path: str):
    """Load the image at ``path`` and ensure it has three RGB channels.

    Grayscale / palette images are converted so every dataset sample has the
    same channel count.
    """
    image = Image.open(path)
    if image.mode != 'RGB':
        # The original called `Image.new("RGB", ...).pase(image)` — `pase` is
        # a typo (AttributeError), and `paste` mutates in place returning
        # None anyway.  `convert` is the standard, correct conversion.
        image = image.convert('RGB')
    return image
class ImageDataset(Dataset):
    """Unpaired image dataset reading from ``[mode]A`` and ``[mode]B`` folders."""

    def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
        root_path = Path(root)
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        # Sorted file paths for the two (unpaired) image collections
        self.files_A = sorted(str(p) for p in (root_path / f'{mode}A').iterdir())
        self.files_B = sorted(str(p) for p in (root_path / f'{mode}B').iterdir())

    def __getitem__(self, index):
        # Wrap the index with modulo so the shorter collection cycles
        path_a = self.files_A[index % len(self.files_A)]
        path_b = self.files_B[index % len(self.files_B)]
        return {"a": self.transform(load_image(path_a)),
                "b": self.transform(load_image(path_b))}

    def __len__(self):
        # Length is governed by the larger of the two collections
        return max(len(self.files_A), len(self.files_B))
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
class ResidualBlock(Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip connection."""

    def __init__(self, in_features: int):
        super().__init__()
        transform = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        ]
        self.block = nn.Sequential(*transform)

    def __call__(self, x: torch.Tensor):
        # Skip connection: add the input back onto the transformed output
        return x + self.block(x)
class GeneratorResNet(Module):
    """ResNet-style CycleGAN generator: encoder, residual bottleneck, decoder."""

    def __init__(self, input_shape, num_residual_blocks):
        super().__init__()
        channels = input_shape[0]
        features = 64
        # Initial 7x7 convolution block with reflection padding
        layers = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, features, 7),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
        ]
        # Two stride-2 downsampling blocks: 64 -> 128 -> 256 channels
        for _ in range(2):
            layers += [
                nn.Conv2d(features, features * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(features * 2),
                nn.ReLU(inplace=True),
            ]
            features *= 2
        # Residual blocks at the bottleneck resolution
        layers += [ResidualBlock(features) for _ in range(num_residual_blocks)]
        # Two upsampling blocks: 256 -> 128 -> 64 channels
        for _ in range(2):
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(features, features // 2, 3, stride=1, padding=1),
                nn.InstanceNorm2d(features // 2),
                nn.ReLU(inplace=True),
            ]
            features //= 2
        # Final 7x7 convolution back to the input channel count, squashed to [-1, 1]
        layers += [nn.ReflectionPad2d(channels), nn.Conv2d(features, channels, 7), nn.Tanh()]
        self.layers = nn.Sequential(*layers)
        self.apply(weights_init_normal)

    def __call__(self, x):
        return self.layers(x)
class DiscriminatorBlock(Module):
    """Stride-2 4x4 convolution, optional instance normalisation, then LeakyReLU."""

    def __init__(self, in_filters, out_filters, normalize=True):
        super().__init__()
        modules = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
        if normalize:
            modules += [nn.InstanceNorm2d(out_filters)]
        modules += [nn.LeakyReLU(0.2, inplace=True)]
        self.layers = nn.Sequential(*modules)

    def __call__(self, x: torch.Tensor):
        return self.layers(x)
class Discriminator(Module):
    """PatchGAN discriminator producing a grid of real/fake scores per image."""

    def __init__(self, input_shape):
        super().__init__()
        channels, height, width = input_shape
        # Four stride-2 blocks shrink spatial dims by a factor of 16 (PatchGAN output)
        self.output_shape = (1, height // 16, width // 16)
        stages = [
            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
            # Asymmetric zero padding before the final 4x4 convolution
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1),
        ]
        self.model = nn.Sequential(*stages)
        self.apply(weights_init_normal)

    def forward(self, img):
        return self.model(img)
def sample_images(n, dataset_name, valid_dataloader, generator_ab, generator_ba):
    """Save a grid of validation images and their translations.

    Rows top-to-bottom: real A, fake B, real B, fake A.  ``n`` names the
    output file ``images/<dataset_name>/<n>.png``.
    """
    batch = next(iter(valid_dataloader))
    generator_ab.eval()
    generator_ba.eval()
    with torch.no_grad():
        # `nn.Module` defines no `.device` attribute, so the original
        # `generator_ab.device` would raise AttributeError; derive the
        # device from the model parameters instead.
        device_ab = next(generator_ab.parameters()).device
        device_ba = next(generator_ba.parameters()).device
        real_a = batch['a'].to(device_ab)
        real_b = batch['b'].to(device_ba)
        fake_b = generator_ab(real_a)
        fake_a = generator_ba(real_b)
        # Arrange each batch of images along the x-axis
        real_a = make_grid(real_a, nrow=5, normalize=True)
        real_b = make_grid(real_b, nrow=5, normalize=True)
        fake_a = make_grid(fake_a, nrow=5, normalize=True)
        fake_b = make_grid(fake_b, nrow=5, normalize=True)
        # Stack the four grids along the y-axis
        image_grid = torch.cat((real_a, fake_b, real_b, fake_a), 1)
    # Create the output directory on first use so `save_image` doesn't fail
    Path("images") .joinpath(dataset_name).mkdir(parents=True, exist_ok=True)
    save_image(image_grid, "images/%s/%s.png" % (dataset_name, n), normalize=False)
class Configs(TrainingLoopConfigs):
    """Configurations for CycleGAN training.

    Builds two generators (A->B and B->A) and two discriminators, then
    trains them with adversarial, cycle-consistency and identity losses.
    """
    # Training device, selected by labml's DeviceConfigs
    device: torch.device = DeviceConfigs()
    # Total number of epochs
    loop_count: int = 200
    # Dataset folder name under `data/cycle_gan/`
    dataset_name: str = 'monet2photo'
    batch_size: int = 1
    data_loader_workers = 8
    is_save_models = True
    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    # Epoch at which the learning rate starts decaying linearly to zero
    decay_start = 100
    # L1 losses for identity and cycle terms; MSE (LSGAN) for the adversarial term
    identity_loss = torch.nn.L1Loss()
    cycle_loss = torch.nn.L1Loss()
    gan_loss = torch.nn.MSELoss()
    batch_step = 'cycle_gan_batch_step'
    img_height = 256
    img_width = 256
    img_channels = 3
    # Number of residual blocks in each generator
    n_residual_blocks = 9
    # Weights for the cycle-consistency and identity loss terms
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.
    # Save sample images every `sample_interval` batches
    sample_interval = 100

    def run(self):
        """Build models, optimizers and data loaders, then run the training loop."""
        images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
        input_shape = (self.img_channels, self.img_height, self.img_width)
        # Two generators (A->B, B->A) and one discriminator per domain
        generator_ab = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
        generator_ba = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
        discriminator_a = Discriminator(input_shape).to(self.device)
        discriminator_b = Discriminator(input_shape).to(self.device)
        # One optimizer over both generators, one over both discriminators
        generator_optimizer = torch.optim.Adam(
            itertools.chain(generator_ab.parameters(), generator_ba.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        discriminator_optimizer = torch.optim.Adam(
            itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        # Linear LR decay from `decay_start` down to zero at `loop_count`
        decay_epochs = self.loop_count - self.decay_start
        generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        # Image transformations: resize, random crop and flip, normalize to [-1, 1]
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        # Training data loader
        dataloader = DataLoader(
            ImageDataset(images_path, transforms_, True, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
        # Test data loader (batch of 5 for the sample-image grid)
        valid_dataloader = DataLoader(
            ImageDataset(images_path, transforms_, True, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
        # Buffers of previously generated samples for discriminator training
        fake_a_buffer = ReplayBuffer()
        fake_b_buffer = ReplayBuffer()
        # NOTE(review): `self.training_loop` comes from TrainingLoopConfigs —
        # presumably it iterates epoch indices up to `loop_count`; confirm.
        for epoch in self.training_loop:
            for i, batch in enumerate(dataloader):
                # Set model input
                real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device)
                # Adversarial ground truths: a PatchGAN grid of ones / zeros
                valid = torch.ones(real_a.size(0), *discriminator_a.output_shape,
                                   device=self.device, requires_grad=False)
                fake = torch.zeros(real_a.size(0), *discriminator_a.output_shape,
                                   device=self.device, requires_grad=False)
                # Train generators
                generator_ab.train()
                generator_ba.train()
                # Identity loss: each generator should leave its target domain unchanged
                loss_identity = self.identity_loss(generator_ba(real_a), real_a) + \
                                self.identity_loss(generator_ab(real_b), real_b)
                # GAN loss: generated images should fool the discriminators
                fake_b = generator_ab(real_a)
                fake_a = generator_ba(real_b)
                loss_gan = self.gan_loss(discriminator_b(fake_b), valid) + \
                           self.gan_loss(discriminator_a(fake_a), valid)
                # Cycle loss: translating there and back should reconstruct the input
                loss_cycle = self.cycle_loss(generator_ba(fake_b), real_a) + \
                             self.cycle_loss(generator_ab(fake_a), real_b)
                # Total weighted generator loss
                loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle
                                  + self.identity_loss_coefficient * loss_identity)
                generator_optimizer.zero_grad()
                loss_generator.backward()
                generator_optimizer.step()
                # Train discriminators on real images and replayed fakes
                fake_a_replay = fake_a_buffer.push_and_pop(fake_a)
                fake_b_replay = fake_b_buffer.push_and_pop(fake_b)
                # `.detach()` stops gradients flowing back into the generators
                loss_discriminator = self.gan_loss(discriminator_a(real_a), valid) + \
                                     self.gan_loss(discriminator_a(fake_a_replay.detach()), fake) + \
                                     self.gan_loss(discriminator_b(real_b), valid) + \
                                     self.gan_loss(discriminator_b(fake_b_replay.detach()), fake)
                discriminator_optimizer.zero_grad()
                loss_discriminator.backward()
                discriminator_optimizer.step()
                # Log the losses
                tracker.save({'loss.generator': loss_generator,
                              'loss.discriminator': loss_discriminator,
                              'loss.generator.cycle': loss_cycle,
                              'loss.generator.gan': loss_gan,
                              'loss.generator.identity': loss_identity})
                # If at sample interval, save a grid of sample images
                batches_done = epoch * len(dataloader) + i
                if batches_done % self.sample_interval == 0:
                    sample_images(batches_done, self.dataset_name, valid_dataloader, generator_ab, generator_ba)
                tracker.add_global_step(max(len(real_a), len(real_b)))
            # Update learning rates once per epoch
            generator_lr_scheduler.step()
            discriminator_lr_scheduler.step()
def main():
    """Entry point: set up the labml experiment and run CycleGAN training."""
    configurations = Configs()
    experiment.create(name='cycle_gan')
    experiment.configs(configurations, 'run')
    with experiment.start():
        configurations.run()


if __name__ == '__main__':
    main()