From 3afd026f93fdaccf719f3380e6a5caf5f9c2a2b5 Mon Sep 17 00:00:00 2001
From: yunjey
Date: Sat, 11 Mar 2017 17:16:36 +0900
Subject: [PATCH] model code added

---
 tutorials/09 - Image Captioning/model.py | 65 ++++++++++++++++++++++++++++++
 tutorials/09 - Image Captioning/vocab.py |  2 +-
 2 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 tutorials/09 - Image Captioning/model.py

diff --git a/tutorials/09 - Image Captioning/model.py b/tutorials/09 - Image Captioning/model.py
new file mode 100644
index 0000000..5bbc3cb
--- /dev/null
+++ b/tutorials/09 - Image Captioning/model.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+import torch.nn.utils.rnn as rnn_utils
+from torch.autograd import Variable
+
+
+class EncoderCNN(nn.Module):
+    def __init__(self, embed_size):
+        """Load the pretrained ResNet-152 and replace the top fc layer."""
+        super(EncoderCNN, self).__init__()
+        self.resnet = models.resnet152(pretrained=True)
+        # Freeze the pretrained weights first, then replace the top fc layer
+        # so that only the new embedding layer is trained.
+        for param in self.resnet.parameters():
+            param.requires_grad = False
+        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
+
+    def forward(self, images):
+        """Extract image feature vectors."""
+        features = self.resnet(images)
+        return features
+
+
+class DecoderRNN(nn.Module):
+    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
+        """Set hyper-parameters and build the layers."""
+        super(DecoderRNN, self).__init__()
+        self.embed_size = embed_size
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.embed = nn.Embedding(vocab_size, embed_size)
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+        self.linear = nn.Linear(hidden_size, vocab_size)
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize weights with a small uniform distribution."""
+        self.embed.weight.data.uniform_(-0.1, 0.1)
+        self.linear.weight.data.uniform_(-0.1, 0.1)
+        self.linear.bias.data.fill_(0)
+
+    def forward(self, features, captions, lengths):
+        """Decode image feature vectors and generate captions."""
+        embeddings = self.embed(captions)
+        # Prepend the image feature as the first input of each sequence.
+        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
+        # lengths: list of valid caption lengths, sorted in decreasing order.
+        packed = rnn_utils.pack_padded_sequence(embeddings, lengths, batch_first=True)
+        hiddens, _ = self.lstm(packed)
+        outputs = self.linear(hiddens[0])
+        return outputs
+
+    def sample(self, feature, state):
+        """Sample a caption for a given image feature."""
+        # feature: (1, embed_size)
+        sampled_ids = []
+        inputs = feature.unsqueeze(1)                                 # (1, 1, embed_size)
+        for i in range(20):                                           # maximum caption length
+            hidden, state = self.lstm(inputs, state)                  # (1, 1, hidden_size)
+            output = self.linear(hidden.view(-1, self.hidden_size))   # (1, vocab_size)
+            predicted = output.max(1)[1]
+            sampled_ids.append(predicted)
+            inputs = self.embed(predicted).unsqueeze(1)               # (1, 1, embed_size)
+        return sampled_ids
diff --git a/tutorials/09 - Image Captioning/vocab.py b/tutorials/09 - Image Captioning/vocab.py
index 2301354..78c16f2 100644
--- a/tutorials/09 - Image Captioning/vocab.py
+++ b/tutorials/09 - Image Captioning/vocab.py
@@ -55,7 +55,7 @@ def build_vocab(json, threshold):
     return vocab
 
 def main():
-    vocab = create_vocab(json='./data/annotations/captions_train2014.json',
+    vocab = build_vocab(json='./data/annotations/captions_train2014.json',
                         threshold=4)
     with open('./data/vocab.pkl', 'wb') as f:
         pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
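
For context, below is a minimal sketch of how the EncoderCNN and DecoderRNN added by this
patch could be wired together for a single training step. It is an illustration rather than
code from the repository: the hyper-parameter values, the dummy batch, and the Adam optimizer
are assumptions, and a real run would draw (images, captions, lengths) from a COCO data loader
that pads the captions and sorts them by decreasing length. It also assumes the script is run
from the tutorial directory so that model.py is importable.

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from torch.nn.utils.rnn import pack_padded_sequence
    from model import EncoderCNN, DecoderRNN

    # Hypothetical hyper-parameters, for illustration only.
    embed_size, hidden_size, vocab_size, num_layers = 256, 512, 5000, 1

    encoder = EncoderCNN(embed_size)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
    criterion = nn.CrossEntropyLoss()
    # Only the decoder and the encoder's new fc layer are trained;
    # the rest of the ResNet is frozen in EncoderCNN.__init__.
    params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
    optimizer = torch.optim.Adam(params, lr=0.001)

    # Dummy batch standing in for what the data loader would yield:
    # images, padded captions, and caption lengths sorted in decreasing order.
    images = Variable(torch.randn(4, 3, 224, 224))
    captions = Variable(torch.LongTensor(4, 15).random_(0, vocab_size))
    lengths = [15, 12, 10, 7]

    features = encoder(images)                                           # (4, embed_size)
    outputs = decoder(features, captions, lengths)                       # packed word scores
    targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

The targets are packed with the same lengths as the decoder input, so each packed score in
outputs lines up with the word id it is supposed to predict.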