mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-27 20:13:33 +08:00
add denormalization function
This commit is contained in:
@ -11,8 +11,11 @@ transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
||||
|
||||
def denorm(x):
|
||||
return (x + 1) / 2
|
||||
|
||||
# MNIST Dataset
|
||||
train_dataset = dsets.MNIST(root='../data/',
|
||||
train_dataset = dsets.MNIST(root='./data/',
|
||||
train=True,
|
||||
transform=transform,
|
||||
download=True)
|
||||
@ -66,8 +69,8 @@ for epoch in range(200):
|
||||
# Build mini-batch dataset
|
||||
images = images.view(images.size(0), -1)
|
||||
images = Variable(images.cuda())
|
||||
real_labels = Variable(torch.ones(images.size(0))).cuda()
|
||||
fake_labels = Variable(torch.zeros(images.size(0))).cuda()
|
||||
real_labels = Variable(torch.ones(images.size(0)).cuda())
|
||||
fake_labels = Variable(torch.zeros(images.size(0)).cuda())
|
||||
|
||||
# Train the discriminator
|
||||
discriminator.zero_grad()
|
||||
@ -75,7 +78,7 @@ for epoch in range(200):
|
||||
real_loss = criterion(outputs, real_labels)
|
||||
real_score = outputs
|
||||
|
||||
noise = Variable(torch.randn(images.size(0), 128)).cuda()
|
||||
noise = Variable(torch.randn(images.size(0), 128).cuda())
|
||||
fake_images = generator(noise)
|
||||
outputs = discriminator(fake_images.detach())
|
||||
fake_loss = criterion(outputs, fake_labels)
|
||||
@ -87,7 +90,7 @@ for epoch in range(200):
|
||||
|
||||
# Train the generator
|
||||
generator.zero_grad()
|
||||
noise = Variable(torch.randn(images.size(0), 128)).cuda()
|
||||
noise = Variable(torch.randn(images.size(0), 128).cuda())
|
||||
fake_images = generator(noise)
|
||||
outputs = discriminator(fake_images)
|
||||
g_loss = criterion(outputs, real_labels)
|
||||
@ -98,13 +101,13 @@ for epoch in range(200):
|
||||
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
|
||||
'D(x): %.2f, D(G(z)): %.2f'
|
||||
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
|
||||
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
|
||||
real_score.data.mean(), fake_score.cpu().data.mean()))
|
||||
|
||||
# Save the sampled images
|
||||
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||
torchvision.utils.save_image(fake_images.data,
|
||||
torchvision.utils.save_image(denorm(fake_images.data),
|
||||
'./data/fake_samples_%d.png' %(epoch+1))
|
||||
|
||||
# Save the Models
|
||||
torch.save(generator.state_dict(), './generator.pkl')
|
||||
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
||||
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
Reference in New Issue
Block a user