rearrange code for cuda()

This commit is contained in:
yunjey
2017-03-22 22:58:29 +09:00
parent 5061df5123
commit 2fe796bb10

View File

@ -55,12 +55,11 @@ def main():
# Set mini-batch dataset
images = Variable(images)
captions = Variable(captions)
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
if torch.cuda.is_available():
images = images.cuda()
captions = captions.cuda()
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
# Forward, Backward and Optimize
decoder.zero_grad()
encoder.zero_grad()