Update model.py

This commit is contained in:
yunjey
2017-09-28 15:33:17 +09:00
committed by GitHub
parent 17ef894aec
commit 1ec90fad72

View File

@ -64,5 +64,6 @@ class DecoderRNN(nn.Module):
predicted = outputs.max(1)[1]
sampled_ids.append(predicted)
inputs = self.embed(predicted)
inputs = inputs.unsqueeze(1) # (batch_size, 1, embed_size)
sampled_ids = torch.cat(sampled_ids, 1) # (batch_size, 20)
return sampled_ids.squeeze()
return sampled_ids.squeeze()