image captioning completed'

This commit is contained in:
yunjey
2017-03-21 20:01:47 +09:00
parent ba7d5467be
commit 6f5fda14f0
7 changed files with 297 additions and 68 deletions

View File

@ -53,14 +53,14 @@ class DecoderRNN(nn.Module):
return outputs
def sample(self, features, states):
"""Samples captions for given image features."""
"""Samples captions for given image features (Greedy search)."""
sampled_ids = []
inputs = features.unsqueeze(1)
for i in range(20):
hiddens, states = self.lstm(inputs, states) # (batch_size, 1, hidden_size)
outputs = self.linear(hiddens.unsqueeze()) # (batch_size, vocab_size)
outputs = self.linear(hiddens.squeeze(1)) # (batch_size, vocab_size)
predicted = outputs.max(1)[1]
sampled_ids.append(predicted)
inputs = self.embed(predicted)
sampled_ids = torch.cat(sampled_ids, 1) # (batch_size, 20)
return sampled_ids
sampled_ids = torch.cat(sampled_ids, 1) # (batch_size, 20)
return sampled_ids.squeeze()