mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-05 16:36:44 +08:00
Update main.py
I got confused by the use of the binary cross entropy. In particular it wasn't clear to me why the variable real_labels is used in the training of the generator. I have added some comments. I am not sure if they are correct, so you might want to double-check them.
This commit is contained in:
@ -64,12 +64,14 @@ for epoch in range(200):
|
|||||||
# Build mini-batch dataset
|
# Build mini-batch dataset
|
||||||
batch_size = images.size(0)
|
batch_size = images.size(0)
|
||||||
images = to_var(images.view(batch_size, -1))
|
images = to_var(images.view(batch_size, -1))
|
||||||
|
# Create the labels which are later used as input for the BCE loss
|
||||||
real_labels = to_var(torch.ones(batch_size))
|
real_labels = to_var(torch.ones(batch_size))
|
||||||
fake_labels = to_var(torch.zeros(batch_size))
|
fake_labels = to_var(torch.zeros(batch_size))
|
||||||
|
|
||||||
#============= Train the discriminator =============#
|
#============= Train the discriminator =============#
|
||||||
# Compute loss with real images
|
# Compute loss with real images
|
||||||
outputs = D(images)
|
outputs = D(images)
|
||||||
|
# Apply BCE loss. Second term is always zero since real_labels == 1
|
||||||
d_loss_real = criterion(outputs, real_labels)
|
d_loss_real = criterion(outputs, real_labels)
|
||||||
real_score = outputs
|
real_score = outputs
|
||||||
|
|
||||||
@ -77,6 +79,7 @@ for epoch in range(200):
|
|||||||
z = to_var(torch.randn(batch_size, 64))
|
z = to_var(torch.randn(batch_size, 64))
|
||||||
fake_images = G(z)
|
fake_images = G(z)
|
||||||
outputs = D(fake_images)
|
outputs = D(fake_images)
|
||||||
|
# Apply BCE loss. First term is always zero since fake_labels == 0
|
||||||
d_loss_fake = criterion(outputs, fake_labels)
|
d_loss_fake = criterion(outputs, fake_labels)
|
||||||
fake_score = outputs
|
fake_score = outputs
|
||||||
|
|
||||||
@ -91,6 +94,11 @@ for epoch in range(200):
|
|||||||
z = to_var(torch.randn(batch_size, 64))
|
z = to_var(torch.randn(batch_size, 64))
|
||||||
fake_images = G(z)
|
fake_images = G(z)
|
||||||
outputs = D(fake_images)
|
outputs = D(fake_images)
|
||||||
|
# remember that min log(1-D(G(z))) has the same fixed point as max log(D(G(z)))
|
||||||
|
# Here we maximize log(D(G(z))), which is exactly the first term in the BCE loss
|
||||||
|
# with t=1. (see definition of BCE for info on t)
|
||||||
|
# t==1 is valid for real_labels, thus we use them as input for the BCE loss.
|
||||||
|
# Don't get yourself confused by this. It is just convenient to use the BCE loss.
|
||||||
g_loss = criterion(outputs, real_labels)
|
g_loss = criterion(outputs, real_labels)
|
||||||
|
|
||||||
# Backprop + Optimize
|
# Backprop + Optimize
|
||||||
|
Reference in New Issue
Block a user