vanilla gan added

yunjey
2017-03-26 17:41:51 +09:00
parent 2fe796bb10
commit 8c4dd99de4
2 changed files with 60 additions and 112 deletions

View File

@@ -1,79 +1,51 @@
 import torch
 import torchvision
 import torch.nn as nn
+import torch.nn.functional as F
 import torchvision.datasets as dsets
 import torchvision.transforms as transforms
 from torch.autograd import Variable
 
 # Image Preprocessing
 transform = transforms.Compose([
-    transforms.Scale(36),
-    transforms.RandomCrop(32),
     transforms.ToTensor(),
     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
 
-# CIFAR-10 Dataset
-train_dataset = dsets.CIFAR10(root='../data/',
+# MNIST Dataset
+train_dataset = dsets.MNIST(root='../data/',
                             train=True,
                             transform=transform,
                             download=True)
 
 # Data Loader (Input Pipeline)
 train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=100,
                                            shuffle=True)
 
-# 4x4 Convolution
-def conv4x4(in_channels, out_channels, stride):
-    return nn.Conv2d(in_channels, out_channels, kernel_size=4,
-                     stride=stride, padding=1, bias=False)
-
 # Discriminator Model
 class Discriminator(nn.Module):
     def __init__(self):
         super(Discriminator, self).__init__()
-        self.model = nn.Sequential(
-            conv4x4(3, 16, 2),
-            nn.LeakyReLU(0.2, inplace=True),
-            conv4x4(16, 32, 2),
-            nn.BatchNorm2d(32),
-            nn.LeakyReLU(0.2, inplace=True),
-            conv4x4(32, 64, 2),
-            nn.BatchNorm2d(64),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv2d(64, 1, kernel_size=4),
-            nn.Sigmoid())
+        self.fc1 = nn.Linear(784, 256)
+        self.fc2 = nn.Linear(256, 1)
 
     def forward(self, x):
-        out = self.model(x)
-        out = out.view(out.size(0), -1)
+        h = F.relu(self.fc1(x))
+        out = F.sigmoid(self.fc2(h))
         return out
 
-# 4x4 Transpose convolution
-def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
-    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
-                              stride=stride, padding=padding, bias=bias)
-
 # Generator Model
 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()
-        self.model = nn.Sequential(
-            conv_transpose4x4(128, 64, padding=0),
-            nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True),
-            conv_transpose4x4(64, 32, 2),
-            nn.BatchNorm2d(32),
-            nn.ReLU(inplace=True),
-            conv_transpose4x4(32, 16, 2),
-            nn.BatchNorm2d(16),
-            nn.ReLU(inplace=True),
-            conv_transpose4x4(16, 3, 2, bias=True),
-            nn.Tanh())
+        self.fc1 = nn.Linear(128, 256)
+        self.fc2 = nn.Linear(256, 512)
+        self.fc3 = nn.Linear(512, 784)
 
     def forward(self, x):
-        x = x.view(x.size(0), 128, 1, 1)
-        out = self.model(x)
+        h = F.leaky_relu(self.fc1(x))
+        h = F.leaky_relu(self.fc2(h))
+        out = F.tanh(self.fc3(h))
         return out
 
 discriminator = Discriminator()
@@ -83,13 +55,14 @@ generator.cuda()
 
 # Loss and Optimizer
 criterion = nn.BCELoss()
-lr = 0.002
-d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
-g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
+d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
+g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
 
 # Training
-for epoch in range(50):
+for epoch in range(200):
     for i, (images, _) in enumerate(train_loader):
+        # Build mini-batch dataset
+        images = images.view(images.size(0), -1)
         images = Variable(images.cuda())
         real_labels = Variable(torch.ones(images.size(0))).cuda()
         fake_labels = Variable(torch.zeros(images.size(0))).cuda()
@@ -119,15 +92,16 @@ for epoch in range(50):
         g_loss.backward()
         g_optimizer.step()
 
-        if (i+1) % 100 == 0:
+        if (i+1) % 300 == 0:
             print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                   'D(x): %.2f, D(G(z)): %.2f'
-                  %(epoch, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
+                  %(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
                     real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
 
     # Save the sampled images
-    torchvision.utils.save_image(fake_images.data,
-        './data/fake_samples_%d_%d.png' %(epoch+1, i+1))
+    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
+    torchvision.utils.save_image(fake_images.data,
+        './data2/fake_samples_%d.png' %(epoch+1))
 
 # Save the Models
 torch.save(generator.state_dict(), './generator.pkl')
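Note: the hunks above only show the ends of the training loop; the discriminator/generator updates that produce d_loss, g_loss, real_score, fake_score and fake_images are elided between them. As an orientation aid only, here is a sketch of a standard vanilla-GAN mini-batch update using the names that do appear in the hunks; it is an assumption about the shape of that step, not the literal elided code from this commit.

# Sketch only -- assumes images (flattened to 784), real_labels, fake_labels,
# criterion, d_optimizer and g_optimizer are defined as in the hunks above.

# Train the discriminator on real and generated images
outputs = discriminator(images).view(-1)        # D(x)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs

z = Variable(torch.randn(images.size(0), 128).cuda())   # noise, dim 128 as in Generator
fake_images = generator(z)
outputs = discriminator(fake_images).view(-1)   # D(G(z))
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs

d_loss = d_loss_real + d_loss_fake
discriminator.zero_grad()
d_loss.backward()
d_optimizer.step()

# Train the generator to push D(G(z)) towards 1
z = Variable(torch.randn(images.size(0), 128).cuda())
fake_images = generator(z)
outputs = discriminator(fake_images).view(-1)
g_loss = criterion(outputs, real_labels)

generator.zero_grad()
g_loss.backward()
g_optimizer.step()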

View File

@@ -1,79 +1,51 @@
 import torch
 import torchvision
 import torch.nn as nn
+import torch.nn.functional as F
 import torchvision.datasets as dsets
 import torchvision.transforms as transforms
 from torch.autograd import Variable
 
 # Image Preprocessing
 transform = transforms.Compose([
-    transforms.Scale(36),
-    transforms.RandomCrop(32),
     transforms.ToTensor(),
     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
 
-# CIFAR-10 Dataset
-train_dataset = dsets.CIFAR10(root='../data/',
+# MNIST Dataset
+train_dataset = dsets.MNIST(root='../data/',
                             train=True,
                             transform=transform,
                             download=True)
 
 # Data Loader (Input Pipeline)
 train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=100,
                                            shuffle=True)
 
-# 4x4 Convolution
-def conv4x4(in_channels, out_channels, stride):
-    return nn.Conv2d(in_channels, out_channels, kernel_size=4,
-                     stride=stride, padding=1, bias=False)
-
 # Discriminator Model
 class Discriminator(nn.Module):
     def __init__(self):
         super(Discriminator, self).__init__()
-        self.model = nn.Sequential(
-            conv4x4(3, 16, 2),
-            nn.LeakyReLU(0.2, inplace=True),
-            conv4x4(16, 32, 2),
-            nn.BatchNorm2d(32),
-            nn.LeakyReLU(0.2, inplace=True),
-            conv4x4(32, 64, 2),
-            nn.BatchNorm2d(64),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv2d(64, 1, kernel_size=4),
-            nn.Sigmoid())
+        self.fc1 = nn.Linear(784, 256)
+        self.fc2 = nn.Linear(256, 1)
 
     def forward(self, x):
-        out = self.model(x)
-        out = out.view(out.size(0), -1)
+        h = F.relu(self.fc1(x))
+        out = F.sigmoid(self.fc2(h))
         return out
 
-# 4x4 Transpose convolution
-def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
-    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
-                              stride=stride, padding=padding, bias=bias)
-
 # Generator Model
 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()
-        self.model = nn.Sequential(
-            conv_transpose4x4(128, 64, padding=0),
-            nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True),
-            conv_transpose4x4(64, 32, 2),
-            nn.BatchNorm2d(32),
-            nn.ReLU(inplace=True),
-            conv_transpose4x4(32, 16, 2),
-            nn.BatchNorm2d(16),
-            nn.ReLU(inplace=True),
-            conv_transpose4x4(16, 3, 2, bias=True),
-            nn.Tanh())
+        self.fc1 = nn.Linear(128, 256)
+        self.fc2 = nn.Linear(256, 512)
+        self.fc3 = nn.Linear(512, 784)
 
     def forward(self, x):
-        x = x.view(x.size(0), 128, 1, 1)
-        out = self.model(x)
+        h = F.leaky_relu(self.fc1(x))
+        h = F.leaky_relu(self.fc2(h))
+        out = F.tanh(self.fc3(h))
        return out
 
 discriminator = Discriminator()
@@ -83,13 +55,14 @@ generator = Generator()
 
 # Loss and Optimizer
 criterion = nn.BCELoss()
-lr = 0.0002
-d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
-g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
+d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
+g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
 
 # Training
-for epoch in range(50):
+for epoch in range(200):
     for i, (images, _) in enumerate(train_loader):
+        # Build mini-batch dataset
+        images = images.view(images.size(0), -1)
         images = Variable(images)
         real_labels = Variable(torch.ones(images.size(0)))
         fake_labels = Variable(torch.zeros(images.size(0)))
@@ -119,16 +92,17 @@ for epoch in range(50):
         g_loss.backward()
         g_optimizer.step()
 
-        if (i+1) % 100 == 0:
+        if (i+1) % 300 == 0:
             print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                   'D(x): %.2f, D(G(z)): %.2f'
-                  %(epoch, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
-                    real_score.data.mean(), fake_score.data.mean()))
+                  %(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
+                    real_score.data.mean(), fake_score.cpu().data.mean()))
 
     # Save the sampled images
-    torchvision.utils.save_image(fake_images.data,
-        './data/fake_samples_%d_%d.png' %(epoch+1, i+1))
+    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
+    torchvision.utils.save_image(fake_images.data,
+        './data2/fake_samples_%d.png' %(epoch+1))
 
 # Save the Models
 torch.save(generator.state_dict(), './generator.pkl')
 torch.save(discriminator.state_dict(), './discriminator.pkl')
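Usage note (not part of this commit): once ./generator.pkl has been written, images can be sampled offline by rebuilding the Generator defined above and reloading its weights. A minimal sketch for the CPU-only file, assuming the Generator class is in scope and the file name ./sampled_images.png is arbitrary:

# Sketch only -- assumes the Generator class from the file above is defined
# (128-dim noise in, 784-dim output) and that ./generator.pkl was saved by it.
import torch
import torchvision
from torch.autograd import Variable

generator = Generator()
generator.load_state_dict(torch.load('./generator.pkl'))
generator.eval()

z = Variable(torch.randn(64, 128))                   # 64 noise vectors of dim 128
samples = generator(z)                               # tanh output lies in [-1, 1]
samples = samples.view(samples.size(0), 1, 28, 28)   # reshape to 1x28x28 images
samples = (samples + 1) / 2                          # undo Normalize(mean=0.5, std=0.5)
torchvision.utils.save_image(samples.data, './sampled_images.png')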