Update main-gpu.py

This commit is contained in:
Kongsea
2017-07-11 11:43:18 +08:00
committed by GitHub
parent 4bf140b407
commit 9cce6c6df6

View File

@ -57,7 +57,7 @@ optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
for epoch in range(num_epochs): for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader): for i, (images, labels) in enumerate(train_loader):
# Convert torch tensor to Variable # Convert torch tensor to Variable
images = Variable(images.view(-1, 28*28).cuda(), requires_grad=True) images = Variable(images.view(-1, 28*28).cuda())
labels = Variable(labels.cuda()) labels = Variable(labels.cuda())
# Forward + Backward + Optimize # Forward + Backward + Optimize