mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-27 20:13:33 +08:00
add denormalization function
This commit is contained in:
@ -11,8 +11,11 @@ transform = transforms.Compose([
|
|||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
|
def denorm(x):
|
||||||
|
return (x + 1) / 2
|
||||||
|
|
||||||
# MNIST Dataset
|
# MNIST Dataset
|
||||||
train_dataset = dsets.MNIST(root='../data/',
|
train_dataset = dsets.MNIST(root='./data/',
|
||||||
train=True,
|
train=True,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
download=True)
|
download=True)
|
||||||
@ -66,8 +69,8 @@ for epoch in range(200):
|
|||||||
# Build mini-batch dataset
|
# Build mini-batch dataset
|
||||||
images = images.view(images.size(0), -1)
|
images = images.view(images.size(0), -1)
|
||||||
images = Variable(images.cuda())
|
images = Variable(images.cuda())
|
||||||
real_labels = Variable(torch.ones(images.size(0))).cuda()
|
real_labels = Variable(torch.ones(images.size(0)).cuda())
|
||||||
fake_labels = Variable(torch.zeros(images.size(0))).cuda()
|
fake_labels = Variable(torch.zeros(images.size(0)).cuda())
|
||||||
|
|
||||||
# Train the discriminator
|
# Train the discriminator
|
||||||
discriminator.zero_grad()
|
discriminator.zero_grad()
|
||||||
@ -75,7 +78,7 @@ for epoch in range(200):
|
|||||||
real_loss = criterion(outputs, real_labels)
|
real_loss = criterion(outputs, real_labels)
|
||||||
real_score = outputs
|
real_score = outputs
|
||||||
|
|
||||||
noise = Variable(torch.randn(images.size(0), 128)).cuda()
|
noise = Variable(torch.randn(images.size(0), 128).cuda())
|
||||||
fake_images = generator(noise)
|
fake_images = generator(noise)
|
||||||
outputs = discriminator(fake_images.detach())
|
outputs = discriminator(fake_images.detach())
|
||||||
fake_loss = criterion(outputs, fake_labels)
|
fake_loss = criterion(outputs, fake_labels)
|
||||||
@ -87,7 +90,7 @@ for epoch in range(200):
|
|||||||
|
|
||||||
# Train the generator
|
# Train the generator
|
||||||
generator.zero_grad()
|
generator.zero_grad()
|
||||||
noise = Variable(torch.randn(images.size(0), 128)).cuda()
|
noise = Variable(torch.randn(images.size(0), 128).cuda())
|
||||||
fake_images = generator(noise)
|
fake_images = generator(noise)
|
||||||
outputs = discriminator(fake_images)
|
outputs = discriminator(fake_images)
|
||||||
g_loss = criterion(outputs, real_labels)
|
g_loss = criterion(outputs, real_labels)
|
||||||
@ -98,13 +101,13 @@ for epoch in range(200):
|
|||||||
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
|
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
|
||||||
'D(x): %.2f, D(G(z)): %.2f'
|
'D(x): %.2f, D(G(z)): %.2f'
|
||||||
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
|
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
|
||||||
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
|
real_score.data.mean(), fake_score.cpu().data.mean()))
|
||||||
|
|
||||||
# Save the sampled images
|
# Save the sampled images
|
||||||
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||||
torchvision.utils.save_image(fake_images.data,
|
torchvision.utils.save_image(denorm(fake_images.data),
|
||||||
'./data/fake_samples_%d.png' %(epoch+1))
|
'./data/fake_samples_%d.png' %(epoch+1))
|
||||||
|
|
||||||
# Save the Models
|
# Save the Models
|
||||||
torch.save(generator.state_dict(), './generator.pkl')
|
torch.save(generator.state_dict(), './generator.pkl')
|
||||||
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
@ -11,8 +11,11 @@ transform = transforms.Compose([
|
|||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
|
def denorm(x):
|
||||||
|
return (x + 1) / 2
|
||||||
|
|
||||||
# MNIST Dataset
|
# MNIST Dataset
|
||||||
train_dataset = dsets.MNIST(root='../data/',
|
train_dataset = dsets.MNIST(root='./data/',
|
||||||
train=True,
|
train=True,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
download=True)
|
download=True)
|
||||||
@ -102,9 +105,9 @@ for epoch in range(200):
|
|||||||
|
|
||||||
# Save the sampled images
|
# Save the sampled images
|
||||||
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||||
torchvision.utils.save_image(fake_images.data,
|
torchvision.utils.save_image(denorm(fake_images.data),
|
||||||
'./data/fake_samples_%d.png' %(epoch+1))
|
'./data/fake_samples_%d.png' %(epoch+1))
|
||||||
|
|
||||||
# Save the Models
|
# Save the Models
|
||||||
torch.save(generator.state_dict(), './generator.pkl')
|
torch.save(generator.state_dict(), './generator.pkl')
|
||||||
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
@ -12,8 +12,11 @@ transform = transforms.Compose([
|
|||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
|
def denorm(x):
|
||||||
|
return (x + 1) / 2
|
||||||
|
|
||||||
# CIFAR-10 Dataset
|
# CIFAR-10 Dataset
|
||||||
train_dataset = dsets.CIFAR10(root='../data/',
|
train_dataset = dsets.CIFAR10(root='./data/',
|
||||||
train=True,
|
train=True,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
download=True)
|
download=True)
|
||||||
@ -126,9 +129,9 @@ for epoch in range(50):
|
|||||||
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
|
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
|
||||||
|
|
||||||
# Save the sampled images
|
# Save the sampled images
|
||||||
torchvision.utils.save_image(fake_images.data,
|
torchvision.utils.save_image(denorm(fake_images.data),
|
||||||
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
|
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
|
||||||
|
|
||||||
# Save the Models
|
# Save the Models
|
||||||
torch.save(generator.state_dict(), './generator.pkl')
|
torch.save(generator.state_dict(), './generator.pkl')
|
||||||
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
@ -12,8 +12,11 @@ transform = transforms.Compose([
|
|||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
|
def denorm(x):
|
||||||
|
return (x + 1) / 2
|
||||||
|
|
||||||
# CIFAR-10 Dataset
|
# CIFAR-10 Dataset
|
||||||
train_dataset = dsets.CIFAR10(root='../data/',
|
train_dataset = dsets.CIFAR10(root='./data/',
|
||||||
train=True,
|
train=True,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
download=True)
|
download=True)
|
||||||
@ -126,9 +129,9 @@ for epoch in range(50):
|
|||||||
real_score.data.mean(), fake_score.data.mean()))
|
real_score.data.mean(), fake_score.data.mean()))
|
||||||
|
|
||||||
# Save the sampled images
|
# Save the sampled images
|
||||||
torchvision.utils.save_image(fake_images.data,
|
torchvision.utils.save_image(denorm(fake_images.data),
|
||||||
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
|
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
|
||||||
|
|
||||||
# Save the Models
|
# Save the Models
|
||||||
torch.save(generator.state_dict(), './generator.pkl')
|
torch.save(generator.state_dict(), './generator.pkl')
|
||||||
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
Reference in New Issue
Block a user