modify the model

This commit is contained in:
yunjey
2017-03-13 14:35:34 +09:00
parent eadb0f9580
commit a500ce7396
3 changed files with 32 additions and 12 deletions

View File

@ -10,9 +10,15 @@ class EncoderCNN(nn.Module):
"""Load pretrained ResNet-152 and replace top fc layer."""
super(EncoderCNN, self).__init__()
self.resnet = models.resnet152(pretrained=True)
self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
# For efficient memory usage.
for param in self.resnet.parameters():
param.requires_grad = False
self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
self.init_weights()
def init_weights(self):
self.resnet.fc.weight.data.uniform_(-0.1, 0.1)
self.resnet.fc.bias.data.fill_(0)
def forward(self, images):
"""Extract image feature vectors."""
@ -30,6 +36,11 @@ class DecoderRNN(nn.Module):
self.embed = nn.Embedding(vocab_size, embed_size)
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
self.linear = nn.Linear(hidden_size, vocab_size)
def init_weights(self):
self.embed.weight.data.uniform_(-0.1, 0.1)
self.linear.weigth.data.uniform_(-0.1, 0.1)
self.linear.bias.data.fill_(0)
def forward(self, features, captions, lengths):
"""Decode image feature vectors and generate caption."""