tutorial updated

yunjey
2017-05-28 20:06:40 +09:00
parent c53c48809e
commit c548e2ae9f
71 changed files with 1102 additions and 1123 deletions

README.md

@@ -0,0 +1,41 @@
## Deep Convolutional GAN
A [Generative Adversarial Network](https://arxiv.org/abs/1406.2661) (GAN) is a generative model consisting of a discriminator and a generator. The discriminator is a binary classifier trained to tell real images from fake ones: it learns to assign 1 to real images and 0 to fake images. The generator maps a latent code to an image, and is trained to produce images that cannot be distinguished from real ones, in order to deceive the discriminator.

In the [Deep Convolutional GAN (DCGAN)](https://arxiv.org/abs/1511.06434) paper, the authors introduce architecture guidelines for stable GAN training. They replace pooling layers with strided convolutions (in the discriminator) and fractional-strided convolutions (in the generator), and use batch normalization in both the discriminator and the generator. In addition, they use ReLU activations in the generator and LeakyReLU activations in the discriminator. In our case, however, we use LeakyReLU in both models to avoid sparse gradients.
![alt text](png/dcgan.png)
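
As a concrete illustration, here is a minimal sketch of the least-squares adversarial objective optimized in `solver.py`; the linear `G` and `D` below are illustrative stand-ins for the actual convolutional networks.
```python
import torch

# Minimal sketch of the least-squares GAN objective used in solver.py.
# The linear G and D are illustrative stand-ins, not the real networks.
G = torch.nn.Linear(100, 64 * 64 * 3)   # latent code -> flattened image
D = torch.nn.Linear(64 * 64 * 3, 1)     # flattened image -> realness score

real = torch.randn(8, 64 * 64 * 3)      # a stand-in batch of "real" images
z = torch.randn(8, 100)                 # latent codes

# D is pushed toward 1 on real images and toward 0 on generated ones.
d_loss = torch.mean((D(real) - 1) ** 2) + torch.mean(D(G(z)) ** 2)
# G is pushed to make D output 1 on its generated images.
g_loss = torch.mean((D(G(z)) - 1) ** 2)
```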
## Usage
#### 1. Install dependencies
```bash
$ pip install -r requirements.txt
```
#### 2. Download the dataset
```bash
$ chmod +x download.sh
$ ./download.sh
```
#### 3. Train the model
```bash
$ python main.py --mode='train'
```
#### 4. Sample the images
```bash
$ python main.py --mode='sample'
```
<br>
## Results
The following are sample results on the CelebA dataset.
![alt text](png/sample1.png)
![alt text](png/sample2.png)

data_loader.py

@@ -0,0 +1,43 @@
import os

from torch.utils import data
from torchvision import transforms
from PIL import Image


class ImageFolder(data.Dataset):
    """Custom Dataset compatible with the prebuilt DataLoader.

    This is just for the tutorial; you can use the prebuilt
    torchvision.datasets.ImageFolder instead.
    """
    def __init__(self, root, transform=None):
        """Initializes image paths and preprocessing module."""
        self.image_paths = [os.path.join(root, fname) for fname in os.listdir(root)]
        self.transform = transform

    def __getitem__(self, index):
        """Reads an image from a file, preprocesses it, and returns it."""
        image_path = self.image_paths[index]
        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Returns the total number of image files."""
        return len(self.image_paths)


def get_loader(image_path, image_size, batch_size, num_workers=2):
    """Builds and returns a DataLoader."""
    transform = transforms.Compose([
        transforms.Resize(image_size),   # formerly transforms.Scale, renamed in torchvision
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    dataset = ImageFolder(image_path, transform)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader
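
A quick usage sketch (this assumes the CelebA images from `download.sh` are in the tutorial's default directory):
```python
# Usage sketch: assumes ./CelebA/128_crop exists (see download.sh).
loader = get_loader('./CelebA/128_crop', image_size=64, batch_size=32)
images = next(iter(loader))
print(images.shape)  # torch.Size([32, 3, 64, 64]) for square source images
```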

download.sh

@@ -0,0 +1,2 @@
# -O saves the download as CelebA.zip; with -P, the '?dl=0' suffix
# would end up in the filename and the unzip below would fail.
wget -O CelebA.zip "https://www.dropbox.com/s/e0ig4nf1v94hyj8/CelebA.zip?dl=0"
unzip CelebA.zip -d ./

main.py

@@ -0,0 +1,58 @@
import argparse
import os

from solver import Solver
from data_loader import get_loader
from torch.backends import cudnn


def main(config):
    cudnn.benchmark = True

    data_loader = get_loader(image_path=config.image_path,
                             image_size=config.image_size,
                             batch_size=config.batch_size,
                             num_workers=config.num_workers)

    solver = Solver(config, data_loader)

    # Create output directories if they don't exist.
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)

    # Train the model or sample images from a trained model.
    if config.mode == 'train':
        solver.train()
    elif config.mode == 'sample':
        solver.sample()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model hyperparameters
    parser.add_argument('--image_size', type=int, default=64)
    parser.add_argument('--z_dim', type=int, default=100)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)

    # Training hyperparameters
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--sample_size', type=int, default=100)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)    # Adam beta1 (first-moment decay rate)
    parser.add_argument('--beta2', type=float, default=0.999)  # Adam beta2 (second-moment decay rate)

    # Misc
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--model_path', type=str, default='./models')
    parser.add_argument('--sample_path', type=str, default='./samples')
    parser.add_argument('--image_path', type=str, default='./CelebA/128_crop')
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=500)

    config = parser.parse_args()
    print(config)
    main(config)
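
As a sketch, the parser above can also be driven with an explicit argument list instead of the command line, which is handy for quick experiments (the flag values here are illustrative):
```python
# Sketch: drive the entry point without a shell; values are illustrative.
config = parser.parse_args(['--mode=train', '--batch_size=64', '--num_epochs=5'])
main(config)
```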

model.py

@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Custom transposed-convolution block: ConvTranspose2d, optionally followed by BatchNorm2d."""
    layers = []
    layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


class Generator(nn.Module):
    """Generator with five transposed-convolutional layers."""
    def __init__(self, z_dim=256, image_size=128, conv_dim=64):
        super(Generator, self).__init__()
        self.fc = deconv(z_dim, conv_dim*8, image_size//16, 1, 0, bn=False)
        self.deconv1 = deconv(conv_dim*8, conv_dim*4, 4)
        self.deconv2 = deconv(conv_dim*4, conv_dim*2, 4)
        self.deconv3 = deconv(conv_dim*2, conv_dim, 4)
        self.deconv4 = deconv(conv_dim, 3, 4, bn=False)

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)        # If image_size is 64, output shapes are as below.
        out = self.fc(z)                              # (?, 512, 4, 4)
        out = F.leaky_relu(self.deconv1(out), 0.05)   # (?, 256, 8, 8)
        out = F.leaky_relu(self.deconv2(out), 0.05)   # (?, 128, 16, 16)
        out = F.leaky_relu(self.deconv3(out), 0.05)   # (?, 64, 32, 32)
        out = torch.tanh(self.deconv4(out))           # (?, 3, 64, 64)
        return out


def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Custom convolution block: Conv2d, optionally followed by BatchNorm2d."""
    layers = []
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


class Discriminator(nn.Module):
    """Discriminator with four strided convolutional layers and a convolutional classifier."""
    def __init__(self, image_size=128, conv_dim=64):
        super(Discriminator, self).__init__()
        self.conv1 = conv(3, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim*2, 4)
        self.conv3 = conv(conv_dim*2, conv_dim*4, 4)
        self.conv4 = conv(conv_dim*4, conv_dim*8, 4)
        self.fc = conv(conv_dim*8, 1, image_size//16, 1, 0, bn=False)

    def forward(self, x):                            # If image_size is 64, output shapes are as below.
        out = F.leaky_relu(self.conv1(x), 0.05)      # (?, 64, 32, 32)
        out = F.leaky_relu(self.conv2(out), 0.05)    # (?, 128, 16, 16)
        out = F.leaky_relu(self.conv3(out), 0.05)    # (?, 256, 8, 8)
        out = F.leaky_relu(self.conv4(out), 0.05)    # (?, 512, 4, 4)
        out = self.fc(out).view(-1)                  # (?,) realness scores; view keeps the batch dim when N == 1
        return out
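
A minimal shape check, as a sketch (batch size is arbitrary; assumes the classes above are in scope):
```python
import torch

# Sketch: verify the shape contract of both networks at image_size=64.
G = Generator(z_dim=100, image_size=64)
D = Discriminator(image_size=64)
z = torch.randn(16, 100)
fake = G(z)
print(fake.shape)     # torch.Size([16, 3, 64, 64])
print(D(fake).shape)  # torch.Size([16])
```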

png/dcgan.png: binary file added (264 KiB)
png/sample1.png: binary file added (992 KiB)
png/sample2.png: binary file added (984 KiB)

requirements.txt

@@ -0,0 +1,4 @@
torch
torchvision
Pillow

solver.py

@@ -0,0 +1,147 @@
import os

import torch
import torchvision
from torch import optim
from torch.autograd import Variable

from model import Discriminator
from model import Generator


class Solver(object):
    def __init__(self, config, data_loader):
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.z_dim = config.z_dim
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.num_epochs = config.num_epochs
        self.batch_size = config.batch_size
        self.sample_size = config.sample_size
        self.lr = config.lr
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.build_model()

    def build_model(self):
        """Builds the generator, the discriminator, and their optimizers."""
        self.generator = Generator(z_dim=self.z_dim,
                                   image_size=self.image_size,
                                   conv_dim=self.g_conv_dim)
        self.discriminator = Discriminator(image_size=self.image_size,
                                           conv_dim=self.d_conv_dim)
        self.g_optimizer = optim.Adam(self.generator.parameters(),
                                      self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      self.lr, [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.generator.cuda()
            self.discriminator.cuda()

    def to_variable(self, x):
        """Converts a tensor to a Variable (a legacy pre-0.4 wrapper), moving it to the GPU if available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Converts a Variable to a CPU tensor."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def reset_grad(self):
        """Zeroes the gradient buffers of both networks."""
        self.discriminator.zero_grad()
        self.generator.zero_grad()

    def denorm(self, x):
        """Converts the pixel range (-1, 1) back to (0, 1)."""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def train(self):
        """Trains the generator and the discriminator."""
        fixed_noise = self.to_variable(torch.randn(self.batch_size, self.z_dim))
        total_step = len(self.data_loader)
        for epoch in range(self.num_epochs):
            for i, images in enumerate(self.data_loader):

                #===================== Train D =====================#
                images = self.to_variable(images)
                batch_size = images.size(0)
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train D to recognize real images as real.
                outputs = self.discriminator(images)
                # Least-squares (LSGAN) loss instead of binary cross-entropy;
                # optional, but it tends to stabilize training.
                real_loss = torch.mean((outputs - 1) ** 2)

                # Train D to recognize fake images as fake.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                fake_loss = torch.mean(outputs ** 2)

                # Backprop + optimize
                d_loss = real_loss + fake_loss
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                #===================== Train G =====================#
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train G so that D recognizes G(z) as real.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                g_loss = torch.mean((outputs - 1) ** 2)

                # Backprop + optimize
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Print the log info. (.item() extracts the Python scalar; it replaces the old .data[0].)
                if (i+1) % self.log_step == 0:
                    print('Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, '
                          'd_fake_loss: %.4f, g_loss: %.4f'
                          % (epoch+1, self.num_epochs, i+1, total_step,
                             real_loss.item(), fake_loss.item(), g_loss.item()))

                # Save the sampled images.
                if (i+1) % self.sample_step == 0:
                    fake_images = self.generator(fixed_noise)
                    torchvision.utils.save_image(self.denorm(fake_images.data),
                                                 os.path.join(self.sample_path,
                                                              'fake_samples-%d-%d.png' % (epoch+1, i+1)))

            # Save the model parameters after each epoch.
            g_path = os.path.join(self.model_path, 'generator-%d.pkl' % (epoch+1))
            d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % (epoch+1))
            torch.save(self.generator.state_dict(), g_path)
            torch.save(self.discriminator.state_dict(), d_path)

    def sample(self):
        """Samples images from the generator trained for num_epochs epochs."""
        # Load the trained parameters.
        g_path = os.path.join(self.model_path, 'generator-%d.pkl' % (self.num_epochs))
        d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % (self.num_epochs))
        self.generator.load_state_dict(torch.load(g_path))
        self.discriminator.load_state_dict(torch.load(d_path))
        self.generator.eval()
        self.discriminator.eval()

        # Sample images.
        noise = self.to_variable(torch.randn(self.sample_size, self.z_dim))
        fake_images = self.generator(noise)
        sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
        torchvision.utils.save_image(self.denorm(fake_images.data), sample_path, nrow=12)
        print("Saved sampled images to '%s'" % sample_path)