model serialization code changed

This commit is contained in:
yunjey
2017-03-18 16:28:19 +09:00
parent 6c74dfab60
commit 3a7a9cf07b
16 changed files with 41 additions and 37 deletions

View File

@ -23,8 +23,8 @@ train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=100,
shuffle=True)
# 5x5 Convolution
def conv5x5(in_channels, out_channels, stride):
# 4x4 Convolution
def conv4x4(in_channels, out_channels, stride):
return nn.Conv2d(in_channels, out_channels, kernel_size=4,
stride=stride, padding=1, bias=False)
@ -33,12 +33,12 @@ class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
conv5x5(3, 16, 2),
conv4x4(3, 16, 2),
nn.LeakyReLU(0.2, inplace=True),
conv5x5(16, 32, 2),
conv4x4(16, 32, 2),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2, inplace=True),
conv5x5(32, 64, 2),
conv4x4(32, 64, 2),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 1, kernel_size=4),
@ -91,8 +91,8 @@ g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
for epoch in range(50):
for i, (images, _) in enumerate(train_loader):
images = Variable(images.cuda())
real_labels = Variable(torch.ones(images.size(0)).cuda())
fake_labels = Variable(torch.zeros(images.size(0)).cuda())
real_labels = Variable(torch.ones(images.size(0))).cuda()
fake_labels = Variable(torch.zeros(images.size(0))).cuda()
# Train the discriminator
discriminator.zero_grad()
@ -100,7 +100,7 @@ for epoch in range(50):
real_loss = criterion(outputs, real_labels)
real_score = outputs
noise = Variable(torch.randn(images.size(0), 128).cuda())
noise = Variable(torch.randn(images.size(0), 128)).cuda()
fake_images = generator(noise)
outputs = discriminator(fake_images)
fake_loss = criterion(outputs, fake_labels)
@ -112,7 +112,7 @@ for epoch in range(50):
# Train the generator
generator.zero_grad()
noise = Variable(torch.randn(images.size(0), 128).cuda())
noise = Variable(torch.randn(images.size(0), 128)).cuda()
fake_images = generator(noise)
outputs = discriminator(fake_images)
g_loss = criterion(outputs, real_labels)
@ -130,5 +130,5 @@ for epoch in range(50):
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
# Save the Models
torch.save(generator, './generator.pkl')
torch.save(discriminator, './discriminator.pkl')
torch.save(generator.state_dict(), './generator.pkl')
torch.save(discriminator.state_dict(), './discriminator.pkl')