mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-05 00:24:02 +08:00
Update main-gpu.py
ref: https://discuss.pytorch.org/t/how-to-get-cuda-variable-gradient/1386
This commit is contained in:
@ -57,8 +57,8 @@ optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
|
||||
for epoch in range(num_epochs):
|
||||
for i, (images, labels) in enumerate(train_loader):
|
||||
# Convert torch tensor to Variable
|
||||
images = Variable(images.view(-1, 28*28)).cuda()
|
||||
labels = Variable(labels).cuda()
|
||||
images = Variable(images.view(-1, 28*28).cuda(), requires_grad=True)
|
||||
labels = Variable(labels.cuda())
|
||||
|
||||
# Forward + Backward + Optimize
|
||||
optimizer.zero_grad() # zero the gradient buffer
|
||||
@ -84,4 +84,4 @@ for images, labels in test_loader:
|
||||
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
|
||||
|
||||
# Save the Model
|
||||
torch.save(net.state_dict(), 'model.pkl')
|
||||
torch.save(net.state_dict(), 'model.pkl')
|
||||
|
Reference in New Issue
Block a user