mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-27 20:13:33 +08:00
vanilla gan added'
This commit is contained in:
@ -1,79 +1,51 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import torchvision.datasets as dsets
|
import torchvision.datasets as dsets
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
|
|
||||||
# Image Preprocessing
|
# Image Preprocessing
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.Scale(36),
|
|
||||||
transforms.RandomCrop(32),
|
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
# CIFAR-10 Dataset
|
# MNIST Dataset
|
||||||
train_dataset = dsets.CIFAR10(root='../data/',
|
train_dataset = dsets.MNIST(root='../data/',
|
||||||
train=True,
|
train=True,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
download=True)
|
download=True)
|
||||||
|
|
||||||
# Data Loader (Input Pipeline)
|
# Data Loader (Input Pipeline)
|
||||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
||||||
batch_size=100,
|
batch_size=100,
|
||||||
shuffle=True)
|
shuffle=True)
|
||||||
|
|
||||||
# 4x4 Convolution
|
|
||||||
def conv4x4(in_channels, out_channels, stride):
|
|
||||||
return nn.Conv2d(in_channels, out_channels, kernel_size=4,
|
|
||||||
stride=stride, padding=1, bias=False)
|
|
||||||
|
|
||||||
# Discriminator Model
|
# Discriminator Model
|
||||||
class Discriminator(nn.Module):
|
class Discriminator(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Discriminator, self).__init__()
|
super(Discriminator, self).__init__()
|
||||||
self.model = nn.Sequential(
|
self.fc1 = nn.Linear(784, 256)
|
||||||
conv4x4(3, 16, 2),
|
self.fc2 = nn.Linear(256, 1)
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
conv4x4(16, 32, 2),
|
|
||||||
nn.BatchNorm2d(32),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
conv4x4(32, 64, 2),
|
|
||||||
nn.BatchNorm2d(64),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
nn.Conv2d(64, 1, kernel_size=4),
|
|
||||||
nn.Sigmoid())
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out = self.model(x)
|
h = F.relu(self.fc1(x))
|
||||||
out = out.view(out.size(0), -1)
|
out = F.sigmoid(self.fc2(h))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# 4x4 Transpose convolution
|
|
||||||
def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
|
|
||||||
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
|
|
||||||
stride=stride, padding=padding, bias=bias)
|
|
||||||
|
|
||||||
# Generator Model
|
# Generator Model
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Generator, self).__init__()
|
super(Generator, self).__init__()
|
||||||
self.model = nn.Sequential(
|
self.fc1 = nn.Linear(128, 256)
|
||||||
conv_transpose4x4(128, 64, padding=0),
|
self.fc2 = nn.Linear(256, 512)
|
||||||
nn.BatchNorm2d(64),
|
self.fc3 = nn.Linear(512, 784)
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
conv_transpose4x4(64, 32, 2),
|
|
||||||
nn.BatchNorm2d(32),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
conv_transpose4x4(32, 16, 2),
|
|
||||||
nn.BatchNorm2d(16),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
conv_transpose4x4(16, 3, 2, bias=True),
|
|
||||||
nn.Tanh())
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x.view(x.size(0), 128, 1, 1)
|
h = F.leaky_relu(self.fc1(x))
|
||||||
out = self.model(x)
|
h = F.leaky_relu(self.fc2(h))
|
||||||
|
out = F.tanh(self.fc3(h))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
discriminator = Discriminator()
|
discriminator = Discriminator()
|
||||||
@ -83,13 +55,14 @@ generator.cuda()
|
|||||||
|
|
||||||
# Loss and Optimizer
|
# Loss and Optimizer
|
||||||
criterion = nn.BCELoss()
|
criterion = nn.BCELoss()
|
||||||
lr = 0.002
|
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
|
||||||
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
|
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
|
||||||
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
for epoch in range(50):
|
for epoch in range(200):
|
||||||
for i, (images, _) in enumerate(train_loader):
|
for i, (images, _) in enumerate(train_loader):
|
||||||
|
# Build mini-batch dataset
|
||||||
|
images = images.view(images.size(0), -1)
|
||||||
images = Variable(images.cuda())
|
images = Variable(images.cuda())
|
||||||
real_labels = Variable(torch.ones(images.size(0))).cuda()
|
real_labels = Variable(torch.ones(images.size(0))).cuda()
|
||||||
fake_labels = Variable(torch.zeros(images.size(0))).cuda()
|
fake_labels = Variable(torch.zeros(images.size(0))).cuda()
|
||||||
@ -119,15 +92,16 @@ for epoch in range(50):
|
|||||||
g_loss.backward()
|
g_loss.backward()
|
||||||
g_optimizer.step()
|
g_optimizer.step()
|
||||||
|
|
||||||
if (i+1) % 100 == 0:
|
if (i+1) % 300 == 0:
|
||||||
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
|
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
|
||||||
'D(x): %.2f, D(G(z)): %.2f'
|
'D(x): %.2f, D(G(z)): %.2f'
|
||||||
%(epoch, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
|
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
|
||||||
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
|
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
|
||||||
|
|
||||||
# Save the sampled images
|
# Save the sampled images
|
||||||
torchvision.utils.save_image(fake_images.data,
|
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||||
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
|
torchvision.utils.save_image(fake_images.data,
|
||||||
|
'./data2/fake_samples_%d.png' %epoch+1)
|
||||||
|
|
||||||
# Save the Models
|
# Save the Models
|
||||||
torch.save(generator.state_dict(), './generator.pkl')
|
torch.save(generator.state_dict(), './generator.pkl')
|
||||||
|
@ -1,79 +1,51 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import torchvision.datasets as dsets
|
import torchvision.datasets as dsets
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
|
|
||||||
# Image Preprocessing
|
# Image Preprocessing
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.Scale(36),
|
|
||||||
transforms.RandomCrop(32),
|
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
# CIFAR-10 Dataset
|
# MNIST Dataset
|
||||||
train_dataset = dsets.CIFAR10(root='../data/',
|
train_dataset = dsets.MNIST(root='../data/',
|
||||||
train=True,
|
train=True,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
download=True)
|
download=True)
|
||||||
|
|
||||||
# Data Loader (Input Pipeline)
|
# Data Loader (Input Pipeline)
|
||||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
||||||
batch_size=100,
|
batch_size=100,
|
||||||
shuffle=True)
|
shuffle=True)
|
||||||
|
|
||||||
# 4x4 Convolution
|
|
||||||
def conv4x4(in_channels, out_channels, stride):
|
|
||||||
return nn.Conv2d(in_channels, out_channels, kernel_size=4,
|
|
||||||
stride=stride, padding=1, bias=False)
|
|
||||||
|
|
||||||
# Discriminator Model
|
# Discriminator Model
|
||||||
class Discriminator(nn.Module):
|
class Discriminator(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Discriminator, self).__init__()
|
super(Discriminator, self).__init__()
|
||||||
self.model = nn.Sequential(
|
self.fc1 = nn.Linear(784, 256)
|
||||||
conv4x4(3, 16, 2),
|
self.fc2 = nn.Linear(256, 1)
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
conv4x4(16, 32, 2),
|
|
||||||
nn.BatchNorm2d(32),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
conv4x4(32, 64, 2),
|
|
||||||
nn.BatchNorm2d(64),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
nn.Conv2d(64, 1, kernel_size=4),
|
|
||||||
nn.Sigmoid())
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out = self.model(x)
|
h = F.relu(self.fc1(x))
|
||||||
out = out.view(out.size(0), -1)
|
out = F.sigmoid(self.fc2(h))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# 4x4 Transpose convolution
|
|
||||||
def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
|
|
||||||
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
|
|
||||||
stride=stride, padding=padding, bias=bias)
|
|
||||||
|
|
||||||
# Generator Model
|
# Generator Model
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Generator, self).__init__()
|
super(Generator, self).__init__()
|
||||||
self.model = nn.Sequential(
|
self.fc1 = nn.Linear(128, 256)
|
||||||
conv_transpose4x4(128, 64, padding=0),
|
self.fc2 = nn.Linear(256, 512)
|
||||||
nn.BatchNorm2d(64),
|
self.fc3 = nn.Linear(512, 784)
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
conv_transpose4x4(64, 32, 2),
|
|
||||||
nn.BatchNorm2d(32),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
conv_transpose4x4(32, 16, 2),
|
|
||||||
nn.BatchNorm2d(16),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
conv_transpose4x4(16, 3, 2, bias=True),
|
|
||||||
nn.Tanh())
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x.view(x.size(0), 128, 1, 1)
|
h = F.leaky_relu(self.fc1(x))
|
||||||
out = self.model(x)
|
h = F.leaky_relu(self.fc2(h))
|
||||||
|
out = F.tanh(self.fc3(h))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
discriminator = Discriminator()
|
discriminator = Discriminator()
|
||||||
@ -83,13 +55,14 @@ generator = Generator()
|
|||||||
|
|
||||||
# Loss and Optimizer
|
# Loss and Optimizer
|
||||||
criterion = nn.BCELoss()
|
criterion = nn.BCELoss()
|
||||||
lr = 0.0002
|
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
|
||||||
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
|
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
|
||||||
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
for epoch in range(50):
|
for epoch in range(200):
|
||||||
for i, (images, _) in enumerate(train_loader):
|
for i, (images, _) in enumerate(train_loader):
|
||||||
|
# Build mini-batch dataset
|
||||||
|
images = images.view(images.size(0), -1)
|
||||||
images = Variable(images)
|
images = Variable(images)
|
||||||
real_labels = Variable(torch.ones(images.size(0)))
|
real_labels = Variable(torch.ones(images.size(0)))
|
||||||
fake_labels = Variable(torch.zeros(images.size(0)))
|
fake_labels = Variable(torch.zeros(images.size(0)))
|
||||||
@ -119,15 +92,16 @@ for epoch in range(50):
|
|||||||
g_loss.backward()
|
g_loss.backward()
|
||||||
g_optimizer.step()
|
g_optimizer.step()
|
||||||
|
|
||||||
if (i+1) % 100 == 0:
|
if (i+1) % 300 == 0:
|
||||||
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
|
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
|
||||||
'D(x): %.2f, D(G(z)): %.2f'
|
'D(x): %.2f, D(G(z)): %.2f'
|
||||||
%(epoch, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
|
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
|
||||||
real_score.data.mean(), fake_score.data.mean()))
|
real_score.data.mean(), fake_score.cpu().data.mean()))
|
||||||
|
|
||||||
# Save the sampled images
|
# Save the sampled images
|
||||||
torchvision.utils.save_image(fake_images.data,
|
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||||
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
|
torchvision.utils.save_image(fake_images.data,
|
||||||
|
'./data2/fake_samples_%d.png' %epoch+1)
|
||||||
|
|
||||||
# Save the Models
|
# Save the Models
|
||||||
torch.save(generator.state_dict(), './generator.pkl')
|
torch.save(generator.state_dict(), './generator.pkl')
|
||||||
|
Reference in New Issue
Block a user