diff --git a/tutorials/09 - Image Captioning/train.py b/tutorials/09 - Image Captioning/train.py index eeb41e5..d167fd1 100644 --- a/tutorials/09 - Image Captioning/train.py +++ b/tutorials/09 - Image Captioning/train.py @@ -55,12 +55,11 @@ def main(): # Set mini-batch dataset images = Variable(images) captions = Variable(captions) - targets = pack_padded_sequence(captions, lengths, batch_first=True)[0] - if torch.cuda.is_available(): images = images.cuda() captions = captions.cuda() - + targets = pack_padded_sequence(captions, lengths, batch_first=True)[0] + # Forward, Backward and Optimize decoder.zero_grad() encoder.zero_grad()