mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-07 01:54:41 +08:00
Update main-gpu.py
This commit is contained in:
@ -57,7 +57,7 @@ optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
|
||||
for epoch in range(num_epochs):
|
||||
for i, (images, labels) in enumerate(train_loader):
|
||||
# Convert torch tensor to Variable
|
||||
images = Variable(images.view(-1, 28*28).cuda(), requires_grad=True)
|
||||
images = Variable(images.view(-1, 28*28).cuda())
|
||||
labels = Variable(labels.cuda())
|
||||
|
||||
# Forward + Backward + Optimize
|
||||
|
Reference in New Issue
Block a user