mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-05 16:36:44 +08:00
rearrange code for cuda()
This commit is contained in:
@ -55,12 +55,11 @@ def main():
|
||||
# Set mini-batch dataset
|
||||
images = Variable(images)
|
||||
captions = Variable(captions)
|
||||
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
|
||||
|
||||
if torch.cuda.is_available():
|
||||
images = images.cuda()
|
||||
captions = captions.cuda()
|
||||
|
||||
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
|
||||
|
||||
# Forward, Backward and Optimize
|
||||
decoder.zero_grad()
|
||||
encoder.zero_grad()
|
||||
|
Reference in New Issue
Block a user