For truncated BPTT, use built-in detach

For GANs, detach variables of fake images when training discriminator.
This commit is contained in:
Ke Ding
2017-04-23 05:31:20 -07:00
parent 1aab031bd5
commit d38e95c94d
6 changed files with 12 additions and 12 deletions

View File

@ -61,7 +61,7 @@ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Truncated Backpropagation # Truncated Backpropagation
def detach(states): def detach(states):
return [Variable(state.data) for state in states] return [state.detach() for state in states]
# Training # Training
for epoch in range(num_epochs): for epoch in range(num_epochs):

View File

@ -61,7 +61,7 @@ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Truncated Backpropagation # Truncated Backpropagation
def detach(states): def detach(states):
return [Variable(state.data) for state in states] return [state.detach() for state in states]
# Training # Training
for epoch in range(num_epochs): for epoch in range(num_epochs):

View File

@ -77,7 +77,7 @@ for epoch in range(200):
noise = Variable(torch.randn(images.size(0), 128)).cuda() noise = Variable(torch.randn(images.size(0), 128)).cuda()
fake_images = generator(noise) fake_images = generator(noise)
outputs = discriminator(fake_images) outputs = discriminator(fake_images.detach())
fake_loss = criterion(outputs, fake_labels) fake_loss = criterion(outputs, fake_labels)
fake_score = outputs fake_score = outputs

View File

@ -77,7 +77,7 @@ for epoch in range(200):
noise = Variable(torch.randn(images.size(0), 128)) noise = Variable(torch.randn(images.size(0), 128))
fake_images = generator(noise) fake_images = generator(noise)
outputs = discriminator(fake_images) outputs = discriminator(fake_images.detach())
fake_loss = criterion(outputs, fake_labels) fake_loss = criterion(outputs, fake_labels)
fake_score = outputs fake_score = outputs

View File

@ -102,7 +102,7 @@ for epoch in range(50):
noise = Variable(torch.randn(images.size(0), 128)).cuda() noise = Variable(torch.randn(images.size(0), 128)).cuda()
fake_images = generator(noise) fake_images = generator(noise)
outputs = discriminator(fake_images) outputs = discriminator(fake_images.detach())
fake_loss = criterion(outputs, fake_labels) fake_loss = criterion(outputs, fake_labels)
fake_score = outputs fake_score = outputs

View File

@ -102,7 +102,7 @@ for epoch in range(50):
noise = Variable(torch.randn(images.size(0), 128)) noise = Variable(torch.randn(images.size(0), 128))
fake_images = generator(noise) fake_images = generator(noise)
outputs = discriminator(fake_images) outputs = discriminator(fake_images.detch())
fake_loss = criterion(outputs, fake_labels) fake_loss = criterion(outputs, fake_labels)
fake_score = outputs fake_score = outputs