mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-27 20:13:33 +08:00
add denormalization function
This commit is contained in:
@ -12,8 +12,11 @@ transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
||||
|
||||
def denorm(x):
|
||||
return (x + 1) / 2
|
||||
|
||||
# CIFAR-10 Dataset
|
||||
train_dataset = dsets.CIFAR10(root='../data/',
|
||||
train_dataset = dsets.CIFAR10(root='./data/',
|
||||
train=True,
|
||||
transform=transform,
|
||||
download=True)
|
||||
@ -126,9 +129,9 @@ for epoch in range(50):
|
||||
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
|
||||
|
||||
# Save the sampled images
|
||||
torchvision.utils.save_image(fake_images.data,
|
||||
torchvision.utils.save_image(denorm(fake_images.data),
|
||||
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
|
||||
|
||||
# Save the Models
|
||||
torch.save(generator.state_dict(), './generator.pkl')
|
||||
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
||||
torch.save(discriminator.state_dict(), './discriminator.pkl')
|
Reference in New Issue
Block a user