mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-05 16:36:44 +08:00
@ -64,16 +64,20 @@ for epoch in range(200):
|
||||
# Build mini-batch dataset
|
||||
batch_size = images.size(0)
|
||||
images = to_var(images.view(batch_size, -1))
|
||||
|
||||
# Create the labels which are later used as input for the BCE loss
|
||||
real_labels = to_var(torch.ones(batch_size))
|
||||
fake_labels = to_var(torch.zeros(batch_size))
|
||||
|
||||
#============= Train the discriminator =============#
|
||||
# Compute loss with real images
|
||||
# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
|
||||
# Second term of the loss is always zero since real_labels == 1
|
||||
outputs = D(images)
|
||||
d_loss_real = criterion(outputs, real_labels)
|
||||
real_score = outputs
|
||||
|
||||
# Compute loss with fake images
|
||||
# Compute BCELoss using fake images
|
||||
# First term of the loss is always zero since fake_labels == 0
|
||||
z = to_var(torch.randn(batch_size, 64))
|
||||
fake_images = G(z)
|
||||
outputs = D(fake_images)
|
||||
@ -91,6 +95,9 @@ for epoch in range(200):
|
||||
z = to_var(torch.randn(batch_size, 64))
|
||||
fake_images = G(z)
|
||||
outputs = D(fake_images)
|
||||
|
||||
# We train G to maximize log(D(G(z))) instead of minimizing log(1-D(G(z)))
|
||||
# For this reason, see the last paragraph of Section 3 of https://arxiv.org/pdf/1406.2661.pdf
|
||||
g_loss = criterion(outputs, real_labels)
|
||||
|
||||
# Backprop + Optimize
|
||||
@ -116,4 +123,4 @@ for epoch in range(200):
|
||||
|
||||
# Save the trained parameters
|
||||
torch.save(G.state_dict(), './generator.pkl')
|
||||
torch.save(D.state_dict(), './discriminator.pkl')
|
||||
torch.save(D.state_dict(), './discriminator.pkl')
|
||||
|
Reference in New Issue
Block a user