add denormalization function

This commit is contained in:
yunjey
2017-05-23 16:34:11 +09:00
parent 47d70f91ac
commit 6d0c3c8b33
4 changed files with 29 additions and 17 deletions

View File

@ -11,8 +11,11 @@ transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
def denorm(x):
    """Map a tensor from the normalized [-1, 1] range back to [0, 1].

    Inverse of Normalize(mean=0.5, std=0.5), used so saved sample
    images display with correct brightness.
    """
    shifted = x + 1
    return shifted / 2
# MNIST Dataset # MNIST Dataset
train_dataset = dsets.MNIST(root='../data/', train_dataset = dsets.MNIST(root='./data/',
train=True, train=True,
transform=transform, transform=transform,
download=True) download=True)
@ -66,8 +69,8 @@ for epoch in range(200):
# Build mini-batch dataset # Build mini-batch dataset
images = images.view(images.size(0), -1) images = images.view(images.size(0), -1)
images = Variable(images.cuda()) images = Variable(images.cuda())
real_labels = Variable(torch.ones(images.size(0))).cuda() real_labels = Variable(torch.ones(images.size(0)).cuda())
fake_labels = Variable(torch.zeros(images.size(0))).cuda() fake_labels = Variable(torch.zeros(images.size(0)).cuda())
# Train the discriminator # Train the discriminator
discriminator.zero_grad() discriminator.zero_grad()
@ -75,7 +78,7 @@ for epoch in range(200):
real_loss = criterion(outputs, real_labels) real_loss = criterion(outputs, real_labels)
real_score = outputs real_score = outputs
noise = Variable(torch.randn(images.size(0), 128)).cuda() noise = Variable(torch.randn(images.size(0), 128).cuda())
fake_images = generator(noise) fake_images = generator(noise)
outputs = discriminator(fake_images.detach()) outputs = discriminator(fake_images.detach())
fake_loss = criterion(outputs, fake_labels) fake_loss = criterion(outputs, fake_labels)
@ -87,7 +90,7 @@ for epoch in range(200):
# Train the generator # Train the generator
generator.zero_grad() generator.zero_grad()
noise = Variable(torch.randn(images.size(0), 128)).cuda() noise = Variable(torch.randn(images.size(0), 128).cuda())
fake_images = generator(noise) fake_images = generator(noise)
outputs = discriminator(fake_images) outputs = discriminator(fake_images)
g_loss = criterion(outputs, real_labels) g_loss = criterion(outputs, real_labels)
@ -98,11 +101,11 @@ for epoch in range(200):
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, ' print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
'D(x): %.2f, D(G(z)): %.2f' 'D(x): %.2f, D(G(z)): %.2f'
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0], %(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
real_score.cpu().data.mean(), fake_score.cpu().data.mean())) real_score.data.mean(), fake_score.cpu().data.mean()))
# Save the sampled images # Save the sampled images
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28) fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
torchvision.utils.save_image(fake_images.data, torchvision.utils.save_image(denorm(fake_images.data),
'./data/fake_samples_%d.png' %(epoch+1)) './data/fake_samples_%d.png' %(epoch+1))
# Save the Models # Save the Models

View File

@ -11,8 +11,11 @@ transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
def denorm(x):
    """Undo the (x - 0.5) / 0.5 normalization, returning values in [0, 1].

    Applied to generated samples before saving them as image files.
    """
    rescaled = x + 1
    return rescaled / 2
# MNIST Dataset # MNIST Dataset
train_dataset = dsets.MNIST(root='../data/', train_dataset = dsets.MNIST(root='./data/',
train=True, train=True,
transform=transform, transform=transform,
download=True) download=True)
@ -102,7 +105,7 @@ for epoch in range(200):
# Save the sampled images # Save the sampled images
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28) fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
torchvision.utils.save_image(fake_images.data, torchvision.utils.save_image(denorm(fake_images.data),
'./data/fake_samples_%d.png' %(epoch+1)) './data/fake_samples_%d.png' %(epoch+1))
# Save the Models # Save the Models

View File

@ -12,8 +12,11 @@ transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
def denorm(x):
    """Shift values from [-1, 1] into [0, 1].

    Inverts the Normalize(mean=0.5, std=0.5) transform so sampled
    images can be written to disk with correct intensities.
    """
    out = x + 1
    out = out / 2
    return out
# CIFAR-10 Dataset # CIFAR-10 Dataset
train_dataset = dsets.CIFAR10(root='../data/', train_dataset = dsets.CIFAR10(root='./data/',
train=True, train=True,
transform=transform, transform=transform,
download=True) download=True)
@ -126,7 +129,7 @@ for epoch in range(50):
real_score.cpu().data.mean(), fake_score.cpu().data.mean())) real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
# Save the sampled images # Save the sampled images
torchvision.utils.save_image(fake_images.data, torchvision.utils.save_image(denorm(fake_images.data),
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1)) './data/fake_samples_%d_%d.png' %(epoch+1, i+1))
# Save the Models # Save the Models

View File

@ -12,8 +12,11 @@ transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
def denorm(x):
    """Return *x* rescaled from [-1, 1] to the [0, 1] image range.

    Counterpart of the Normalize(0.5, 0.5) preprocessing step;
    used when saving generator outputs as PNGs.
    """
    half = (x + 1) / 2
    return half
# CIFAR-10 Dataset # CIFAR-10 Dataset
train_dataset = dsets.CIFAR10(root='../data/', train_dataset = dsets.CIFAR10(root='./data/',
train=True, train=True,
transform=transform, transform=transform,
download=True) download=True)
@ -126,7 +129,7 @@ for epoch in range(50):
real_score.data.mean(), fake_score.data.mean())) real_score.data.mean(), fake_score.data.mean()))
# Save the sampled images # Save the sampled images
torchvision.utils.save_image(fake_images.data, torchvision.utils.save_image(denorm(fake_images.data),
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1)) './data/fake_samples_%d_%d.png' %(epoch+1, i+1))
# Save the Models # Save the Models