import torch import torch.nn as nn import torchvision.models as models from torch.nn.utils.rnn import pack_padded_sequence from torch.autograd import Variable class EncoderCNN(nn.Module): def __init__(self, embed_size): """Loads the pretrained ResNet-152 and replace top fc layer.""" super(EncoderCNN, self).__init__() self.resnet = models.resnet152(pretrained=True) for param in self.resnet.parameters(): param.requires_grad = False self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size) self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) self.init_weights() def init_weights(self): """Initialize weights.""" self.resnet.fc.weight.data.normal_(0.0, 0.02) self.resnet.fc.bias.data.fill_(0) def forward(self, images): """Extracts the image feature vectors.""" features = self.resnet(images) features = self.bn(features) return features class DecoderRNN(nn.Module): def __init__(self, embed_size, hidden_size, vocab_size, num_layers): """Set the hyper-parameters and build the layers.""" super(DecoderRNN, self).__init__() self.embed = nn.Embedding(vocab_size, embed_size) self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True) self.linear = nn.Linear(hidden_size, vocab_size) self.init_weights() def init_weights(self): """Initialize weights.""" self.embed.weight.data.uniform_(-0.1, 0.1) self.linear.weight.data.uniform_(-0.1, 0.1) self.linear.bias.data.fill_(0) def forward(self, features, captions, lengths): """Decodes image feature vectors and generates captions.""" embeddings = self.embed(captions) embeddings = torch.cat((features.unsqueeze(1), embeddings), 1) packed = pack_padded_sequence(embeddings, lengths, batch_first=True) hiddens, _ = self.lstm(packed) outputs = self.linear(hiddens[0]) return outputs def sample(self, features, states): """Samples captions for given image features.""" sampled_ids = [] inputs = features.unsqueeze(1) for i in range(20): hiddens, states = self.lstm(inputs, states) # (batch_size, 1, hidden_size) outputs = self.linear(hiddens.unsqueeze()) # (batch_size, vocab_size) predicted = outputs.max(1)[1] sampled_ids.append(predicted) inputs = self.embed(predicted) sampled_ids = torch.cat(sampled_ids, 1) # (batch_size, 20) return sampled_ids