From 1ec90fad7260e7871a4912fce552cad90f6c2f4a Mon Sep 17 00:00:00 2001 From: yunjey Date: Thu, 28 Sep 2017 15:33:17 +0900 Subject: [PATCH] Update model.py --- tutorials/03-advanced/image_captioning/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tutorials/03-advanced/image_captioning/model.py b/tutorials/03-advanced/image_captioning/model.py index e2d4fa6..03ae8e4 100644 --- a/tutorials/03-advanced/image_captioning/model.py +++ b/tutorials/03-advanced/image_captioning/model.py @@ -64,5 +64,6 @@ class DecoderRNN(nn.Module): predicted = outputs.max(1)[1] sampled_ids.append(predicted) inputs = self.embed(predicted) + inputs = inputs.unsqueeze(1) # (batch_size, 1, embed_size) sampled_ids = torch.cat(sampled_ids, 1) # (batch_size, 20) - return sampled_ids.squeeze() \ No newline at end of file + return sampled_ids.squeeze()