From f159685a61122f61b1a51e4b97e5b3421f7e13e6 Mon Sep 17 00:00:00 2001 From: yunjey Date: Sat, 11 Mar 2017 20:00:02 +0900 Subject: [PATCH] model edited --- tutorials/09 - Image Captioning/model.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tutorials/09 - Image Captioning/model.py b/tutorials/09 - Image Captioning/model.py index 5bbc3cb..448c201 100644 --- a/tutorials/09 - Image Captioning/model.py +++ b/tutorials/09 - Image Captioning/model.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torchvision.models as models -import torch.nn.utils.rnn as rnn_utils +from torch.nn.utils.rnn import pack_padded_sequence from torch.autograd import Variable @@ -31,27 +31,22 @@ class DecoderRNN(nn.Module): self.lstm = nn.LSTM(embed_size, hidden_size, num_layers) self.linear = nn.Linear(hidden_size, vocab_size) - def init_weights(self): - pass - def forward(self, features, captions, lengths): """Decode image feature vectors and generate caption.""" embeddings = self.embed(captions) embeddings = torch.cat((features.unsqueeze(1), embeddings), 1) - packed = rnn_utils.pack_padded_sequence(embeddings, lengths, batch_first=True) # lengths is ok + packed = pack_padded_sequence(embeddings, lengths, batch_first=True) hiddens, _ = self.lstm(packed) outputs = self.linear(hiddens[0]) return outputs def sample(self, feature, state): """Sample a caption for given a image feature.""" - # (batch_size, seq_length, embed_size) - # features: (1, 128) sampled_ids = [] input = feature.unsqueeze(1) for i in range(20): - hidden, state = self.lstm(input, state) # (1, 1, 512) - output = self.linear(hidden.view(-1, self.hidden_size)) # (1, 10000) + hidden, state = self.lstm(input, state) # (1, 1, hidden_size) + output = self.linear(hidden.view(-1, self.hidden_size)) # (1, vocab_size) predicted = output.max(1)[1] sampled_ids.append(predicted) input = self.embed(predicted)