diff --git a/tutorials/02-intermediate/generative_adversarial_network/main.py b/tutorials/02-intermediate/generative_adversarial_network/main.py
index 29b7f35..23964ff 100644
--- a/tutorials/02-intermediate/generative_adversarial_network/main.py
+++ b/tutorials/02-intermediate/generative_adversarial_network/main.py
@@ -64,12 +64,14 @@ for epoch in range(200):
         # Build mini-batch dataset
         batch_size = images.size(0)
         images = to_var(images.view(batch_size, -1))
+        # Create the labels which are later used as input for the BCE loss
         real_labels = to_var(torch.ones(batch_size))
         fake_labels = to_var(torch.zeros(batch_size))
 
         #============= Train the discriminator =============#
         # Compute loss with real images
         outputs = D(images)
+        # Apply BCE loss. The second term is always zero since real_labels == 1
         d_loss_real = criterion(outputs, real_labels)
         real_score = outputs
 
@@ -77,6 +79,7 @@ for epoch in range(200):
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
+        # Apply BCE loss. The first term is always zero since fake_labels == 0
         d_loss_fake = criterion(outputs, fake_labels)
         fake_score = outputs
 
@@ -91,6 +94,11 @@ for epoch in range(200):
         z = to_var(torch.randn(batch_size, 64))
         fake_images = G(z)
         outputs = D(fake_images)
+        # Remember that min log(1 - D(G(z))) has the same fixed point as max log(D(G(z))).
+        # Here we maximize log(D(G(z))), which is exactly the first term of the BCE loss
+        # with target t = 1 (see the definition of BCE for the role of t).
+        # t == 1 holds for real_labels, so we pass them to the BCE loss.
+        # Don't let this confuse you: it is simply convenient to reuse the BCE loss here.
         g_loss = criterion(outputs, real_labels)
 
         # Backprop + Optimize
@@ -116,4 +124,4 @@ for epoch in range(200):
 
 # Save the trained parameters
 torch.save(G.state_dict(), './generator.pkl')
-torch.save(D.state_dict(), './discriminator.pkl')
\ No newline at end of file
+torch.save(D.state_dict(), './discriminator.pkl')
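
For reviewers: a small standalone sketch (not part of the patch itself) that checks the identity the new generator-loss comments rely on, namely that BCE with target 1 reduces to -log(D(G(z))). The tensor values here are made up purely for illustration.

import torch
import torch.nn as nn

# BCE(x, t) = -[t*log(x) + (1 - t)*log(1 - x)], averaged over the batch.
# With t == 1 (real_labels) only the first term survives: -log(x).
criterion = nn.BCELoss()

outputs = torch.tensor([0.2, 0.7, 0.9])  # hypothetical D(G(z)) scores in (0, 1)
real_labels = torch.ones(3)

bce_loss = criterion(outputs, real_labels)
manual_loss = -torch.log(outputs).mean()

print(bce_loss.item(), manual_loss.item())  # both print the same value

This is why the patch passes real_labels to criterion when training G: minimizing -log(D(G(z))) is the non-saturating form of the generator objective, which gives stronger gradients early in training than minimizing log(1 - D(G(z))).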