mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-26 19:48:34 +08:00
For truncated BPTT, use built-in detach
For GANs, detach variables of fake images when training discriminator.
This commit is contained in:
@ -61,7 +61,7 @@ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
# Truncated Backpropagation
|
||||
def detach(states):
|
||||
return [Variable(state.data) for state in states]
|
||||
return [state.detach() for state in states]
|
||||
|
||||
# Training
|
||||
for epoch in range(num_epochs):
|
||||
@ -119,4 +119,4 @@ with open(sample_path, 'w') as f:
|
||||
print('Sampled [%d/%d] words and save to %s'%(i+1, num_samples, sample_path))
|
||||
|
||||
# Save the Trained Model
|
||||
torch.save(model.state_dict(), 'model.pkl')
|
||||
torch.save(model.state_dict(), 'model.pkl')
|
||||
|
Reference in New Issue
Block a user