add denormalization function

This commit is contained in:
yunjey
2017-05-23 16:34:11 +09:00
parent 47d70f91ac
commit 6d0c3c8b33
4 changed files with 29 additions and 17 deletions

View File

@ -12,8 +12,11 @@ transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
def denorm(x):
return (x + 1) / 2
# CIFAR-10 Dataset
train_dataset = dsets.CIFAR10(root='../data/',
train_dataset = dsets.CIFAR10(root='./data/',
train=True,
transform=transform,
download=True)
@ -126,9 +129,9 @@ for epoch in range(50):
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
# Save the sampled images
torchvision.utils.save_image(fake_images.data,
torchvision.utils.save_image(denorm(fake_images.data),
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
# Save the Models
torch.save(generator.state_dict(), './generator.pkl')
torch.save(discriminator.state_dict(), './discriminator.pkl')
torch.save(discriminator.state_dict(), './discriminator.pkl')