mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-06 01:15:59 +08:00
rearrange code for cuda()
This commit is contained in:
@ -55,11 +55,10 @@ def main():
|
|||||||
# Set mini-batch dataset
|
# Set mini-batch dataset
|
||||||
images = Variable(images)
|
images = Variable(images)
|
||||||
captions = Variable(captions)
|
captions = Variable(captions)
|
||||||
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
images = images.cuda()
|
images = images.cuda()
|
||||||
captions = captions.cuda()
|
captions = captions.cuda()
|
||||||
|
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
|
||||||
|
|
||||||
# Forward, Backward and Optimize
|
# Forward, Backward and Optimize
|
||||||
decoder.zero_grad()
|
decoder.zero_grad()
|
||||||
|
Reference in New Issue
Block a user