captioning modules are edited

yunjey
2017-03-21 01:05:47 +09:00
parent 247de2da86
commit 4fc2b1fa8a
5 changed files with 128 additions and 109 deletions

View File

@@ -13,12 +13,13 @@ from pycocotools.coco import COCO
 class CocoDataset(data.Dataset):
     """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
     def __init__(self, root, json, vocab, transform=None):
-        """
+        """Set the path for images, captions and vocabulary wrapper.
         Args:
             root: image directory.
             json: coco annotation file path.
             vocab: vocabulary wrapper.
-            transform: transformer for image.
+            transform: image transformer
         """
         self.root = root
         self.coco = COCO(json)
@@ -27,7 +28,7 @@ class CocoDataset(data.Dataset):
         self.transform = transform
     def __getitem__(self, index):
-        """This function should return one data pair(image and caption)."""
+        """Returns one data pair (image and caption)."""
         coco = self.coco
         vocab = self.vocab
         ann_id = self.ids[index]
@@ -53,9 +54,10 @@ class CocoDataset(data.Dataset):
 def collate_fn(data):
-    """Build mini-batch tensors from a list of (image, caption) tuples.
+    """Creates mini-batch tensors from the list of tuples (image, caption).
     Args:
-        data: list of (image, caption) tuple.
+        data: list of tuple (image, caption).
             - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.
@@ -68,10 +70,10 @@ def collate_fn(data):
     data.sort(key=lambda x: len(x[1]), reverse=True)
     images, captions = zip(*data)
-    # Merge images (convert tuple of 3D tensor to 4D tensor)
+    # Merge images (from tuple of 3D tensor to 4D tensor)
     images = torch.stack(images, 0)
-    # Merget captions (convert tuple of 1D tensor to 2D tensor)
+    # Merge captions (from tuple of 1D tensor to 2D tensor)
     lengths = [len(cap) for cap in captions]
     targets = torch.zeros(len(captions), max(lengths)).long()
     for i, cap in enumerate(captions):
@@ -80,18 +82,18 @@ def collate_fn(data):
     return images, targets, lengths
-def get_loader(root, json, vocab, transform, batch_size=100, shuffle=True, num_workers=2):
+def get_data_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
     """Returns torch.utils.data.DataLoader for custom coco dataset."""
-    # COCO custom dataset
+    # COCO dataset
     coco = CocoDataset(root=root,
                        json=json,
                        vocab = vocab,
                        transform=transform)
-    # Data loader
+    # Data loader for COCO dataset
     data_loader = torch.utils.data.DataLoader(dataset=coco,
                                               batch_size=batch_size,
-                                              shuffle=True,
+                                              shuffle=shuffle,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn)
     return data_loader
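For reference, a minimal usage sketch of the renamed get_data_loader. The paths and the 224 crop size below are placeholder assumptions taken from the old hard-coded values, not something introduced by this commit; note that shuffle is now actually passed through instead of being hard-coded to True.

import pickle
import torchvision.transforms as T
from data import get_data_loader

# Placeholder paths; in train.py these now come from the Config object.
image_root = './data/train2014resized'
caption_json = './data/annotations/captions_train2014.json'

transform = T.Compose([
    T.RandomCrop(224),                 # assumed crop size
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

loader = get_data_loader(image_root, caption_json, vocab, transform,
                         batch_size=32, shuffle=True, num_workers=2)
for images, targets, lengths in loader:
    # images: (batch, 3, 224, 224); targets: (batch, max_len), zero-padded
    break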

View File

@@ -7,43 +7,44 @@ from torch.autograd import Variable
 class EncoderCNN(nn.Module):
     def __init__(self, embed_size):
-        """Load pretrained ResNet-152 and replace top fc layer."""
+        """Loads the pretrained ResNet-152 and replace top fc layer."""
         super(EncoderCNN, self).__init__()
         self.resnet = models.resnet152(pretrained=True)
-        # For efficient memory usage.
         for param in self.resnet.parameters():
             param.requires_grad = False
         self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
+        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
         self.init_weights()
     def init_weights(self):
-        self.resnet.fc.weight.data.uniform_(-0.1, 0.1)
+        """Initialize weights."""
+        self.resnet.fc.weight.data.normal_(0.0, 0.02)
         self.resnet.fc.bias.data.fill_(0)
     def forward(self, images):
-        """Extract image feature vectors."""
+        """Extracts the image feature vectors."""
         features = self.resnet(images)
+        features = self.bn(features)
         return features
 class DecoderRNN(nn.Module):
     def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
-        """Set hyper-parameters and build layers."""
+        """Set the hyper-parameters and build the layers."""
         super(DecoderRNN, self).__init__()
-        self.embed_size = embed_size
-        self.hidden_size = hidden_size
-        self.vocab_size = vocab_size
         self.embed = nn.Embedding(vocab_size, embed_size)
-        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
         self.linear = nn.Linear(hidden_size, vocab_size)
+        self.init_weights()
     def init_weights(self):
+        """Initialize weights."""
         self.embed.weight.data.uniform_(-0.1, 0.1)
-        self.linear.weigth.data.uniform_(-0.1, 0.1)
+        self.linear.weight.data.uniform_(-0.1, 0.1)
         self.linear.bias.data.fill_(0)
     def forward(self, features, captions, lengths):
-        """Decode image feature vectors and generate caption."""
+        """Decodes image feature vectors and generates captions."""
         embeddings = self.embed(captions)
         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
         packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
@@ -51,14 +52,15 @@ class DecoderRNN(nn.Module):
         outputs = self.linear(hiddens[0])
         return outputs
-    def sample(self, feature, state):
-        """Sample a caption for given a image feature."""
+    def sample(self, features, states):
+        """Samples captions for given image features."""
         sampled_ids = []
-        input = feature.unsqueeze(1)
+        inputs = features.unsqueeze(1)
         for i in range(20):
-            hidden, state = self.lstm(input, state)          # (1, 1, hidden_size)
-            output = self.linear(hidden.view(-1, self.hidden_size))  # (1, vocab_size)
-            predicted = output.max(1)[1]
+            hiddens, states = self.lstm(inputs, states)       # (batch_size, 1, hidden_size)
+            outputs = self.linear(hiddens.unsqueeze())        # (batch_size, vocab_size)
+            predicted = outputs.max(1)[1]
             sampled_ids.append(predicted)
-            input = self.embed(predicted)
+            inputs = self.embed(predicted)
+        sampled_ids = torch.cat(sampled_ids, 1)               # (batch_size, 20)
         return sampled_ids
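A quick forward-pass sketch of how train.py wires these two modules together for training. The sizes, the fake tensors, and the targets line are illustrative assumptions (the targets construction is not visible in the hunks above); only the encoder/decoder calls themselves come from this commit.

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
from model import EncoderCNN, DecoderRNN

encoder = EncoderCNN(embed_size=256)                       # assumed sizes
decoder = DecoderRNN(embed_size=256, hidden_size=512,
                     vocab_size=5000, num_layers=1)

images = Variable(torch.randn(4, 3, 224, 224))             # fake mini-batch
captions = Variable(torch.ones(4, 12).long())              # fake, already padded
lengths = [12, 10, 9, 7]                                   # sorted, as collate_fn guarantees

features = encoder(images)                                 # (4, 256), batch-normalized
outputs = decoder(features, captions, lengths)             # packed scores over the vocabulary
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
loss = nn.CrossEntropyLoss()(outputs, targets)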

View File

@@ -1,34 +1,34 @@
 from PIL import Image
+from configuration import Config
 import os
 def resize_image(image, size):
-    """Resizes an image to the given size."""
+    """Resizes the image to the given size."""
     return image.resize(size, Image.ANTIALIAS)
 def resize_images(image_dir, output_dir, size):
-    """Resizes the images in the image_dir and save into the output_dir."""
+    """Resizes the images in 'image_dir' and save them in 'output_dir'."""
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
     images = os.listdir(image_dir)
     num_images = len(images)
     for i, image in enumerate(images):
         with open(os.path.join(image_dir, image), 'r+b') as f:
             with Image.open(f) as img:
                 img = resize_image(img, size)
-                img.save(
-                    os.path.join(output_dir, image), img.format)
+                img.save(os.path.join(output_dir, image), img.format)
         if i % 100 == 0:
-            print ('[%d/%d] Resized the images and saved into %s.'
+            print ('[%d/%d] Resized the images and saved them in %s.'
                    %(i, num_images, output_dir))
 def main():
+    config = Config()
     splits = ['train', 'val']
     for split in splits:
-        image_dir = './data/%s2014/' %split
-        output_dir = './data/%s2014resized' %split
-        resize_images(image_dir, output_dir, (256, 256))
+        image_dir = os.path.join(config.image_path, '%s2014/' %split)
+        output_dir = os.path.join(config.image_path, '%s2014resized' %split)
+        resize_images(image_dir, output_dir, (config.image_size, config.image_size))
 if __name__ == '__main__':

View File

@@ -1,55 +1,57 @@
-from data import get_loader
+from data import get_data_loader
 from vocab import Vocabulary
+from configuration import Config
 from model import EncoderCNN, DecoderRNN
 from torch.autograd import Variable
 from torch.nn.utils.rnn import pack_padded_sequence
 import torch
 import torch.nn as nn
-import numpy as np
 import torchvision.transforms as T
+import numpy as np
 import pickle
-# Hyper Parameters
-num_epochs = 1
-batch_size = 32
-embed_size = 256
-hidden_size = 512
-crop_size = 224
-num_layers = 1
-learning_rate = 0.001
-train_image_path = './data/train2014resized/'
-train_json_path = './data/annotations/captions_train2014.json'
-# Image Preprocessing
+import os
+def main():
+    # Configuration for hyper-parameters
+    config = Config()
+    # Image preprocessing
     transform = T.Compose([
-    T.RandomCrop(crop_size),
+        T.Scale(config.image_size),    # no resize
+        T.RandomCrop(config.crop_size),
         T.RandomHorizontalFlip(),
         T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-# Load Vocabulary Wrapper
-with open('./data/vocab.pkl', 'rb') as f:
+    # Load vocabulary wrapper
+    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
         vocab = pickle.load(f)
-# Build Dataset Loader
-train_loader = get_loader(train_image_path, train_json_path, vocab, transform,
-                          batch_size=batch_size, shuffle=True, num_workers=2)
+    # Build data loader
+    image_path = os.path.join(config.image_path, 'train2014')
+    json_path = os.path.join(config.caption_path, 'captions_train2014.json')
+    train_loader = get_data_loader(image_path, json_path, vocab,
+                                   transform, config.batch_size,
+                                   shuffle=True, num_workers=config.num_threads)
     total_step = len(train_loader)
     # Build Models
-encoder = EncoderCNN(embed_size)
-decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers)
+    encoder = EncoderCNN(config.embed_size)
+    decoder = DecoderRNN(config.embed_size, config.hidden_size,
+                         len(vocab), config.num_layers)
     encoder.cuda()
     decoder.cuda()
     # Loss and Optimizer
     criterion = nn.CrossEntropyLoss()
     params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
-optimizer = torch.optim.Adam(params, lr=learning_rate)
-# Train the Decoder
-for epoch in range(num_epochs):
+    optimizer = torch.optim.Adam(params, lr=config.learning_rate)
+    # Train the Models
+    for epoch in range(config.num_epochs):
         for i, (images, captions, lengths) in enumerate(train_loader):
             # Set mini-batch dataset
             images = Variable(images).cuda()
             captions = Variable(captions).cuda()
@@ -57,16 +59,26 @@ for epoch in range(num_epochs):
             # Forward, Backward and Optimize
             decoder.zero_grad()
+            encoder.zero_grad()
             features = encoder(images)
             outputs = decoder(features, captions, lengths)
             loss = criterion(outputs, targets)
             loss.backward()
             optimizer.step()
-        if i % 100 == 0:
-            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
-                  %(epoch, num_epochs, i, total_step, loss.data[0], np.exp(loss.data[0])))
+            # Print log info
+            if i % config.log_step == 0:
+                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
+                      %(epoch, config.num_epochs, i, total_step,
+                        loss.data[0], np.exp(loss.data[0])))
             # Save the Model
-torch.save(decoder, 'decoder.pkl')
-torch.save(encoder, 'encoder.pkl')
+            if (i+1) % config.save_step == 0:
+                torch.save(decoder.state_dict(),
+                           os.path.join(config.model_path,
+                                        'decoder-%d-%d.pkl' %(epoch+1, i+1)))
+                torch.save(encoder.state_dict(),
+                           os.path.join(config.model_path,
+                                        'encoder-%d-%d.pkl' %(epoch+1, i+1)))
+if __name__ == '__main__':
+    main()
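The new scripts all read hyper-parameters from a Config object in configuration.py, which is not among the five files in this commit. The following is a hypothetical sketch of what that class needs to expose, inferred from the attributes the new code accesses; values carried over from the deleted hard-coded hyper-parameters are used where available, and everything marked "assumed" is a guess.

# configuration.py -- hypothetical sketch, not part of this diff.
class Config(object):
    """Holds the paths and hyper-parameters used across the scripts."""
    # Paths
    image_path = './data/'                 # contains train2014/, val2014/
    caption_path = './data/annotations/'
    vocab_path = './data/'
    model_path = './models/'               # assumed
    # Image preprocessing
    image_size = 256                       # assumed resize target
    crop_size = 224
    # Vocabulary
    word_count_threshold = 4
    # Model
    embed_size = 256
    hidden_size = 512
    num_layers = 1
    # Training
    num_epochs = 1
    batch_size = 32
    learning_rate = 0.001
    num_threads = 2                        # assumed: data loader workers
    log_step = 10                          # assumed
    save_step = 1000                       # assumed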

View File

@@ -1,6 +1,7 @@
-# Create a vocabulary wrapper
 import nltk
 import pickle
+import os
+from configuration import Config
 from collections import Counter
 from pycocotools.coco import COCO
@@ -27,7 +28,7 @@
         return len(self.word2idx)
 def build_vocab(json, threshold):
-    """Build a simple vocabulary wrapper."""
+    """Builds a simple vocabulary wrapper."""
     coco = COCO(json)
     counter = Counter()
     ids = coco.anns.keys()
@@ -37,29 +38,31 @@ def build_vocab(json, threshold):
         counter.update(tokens)
         if i % 1000 == 0:
-            print("[%d/%d] tokenized the captions." %(i, len(ids)))
+            print("[%d/%d] Tokenized the captions." %(i, len(ids)))
-    # Discard if the occurrence of the word is less than min_word_cnt.
+    # If the word frequency is less than 'threshold', then the word is discarded.
     words = [word for word, cnt in counter.items() if cnt >= threshold]
-    # Create a vocab wrapper and add some special tokens.
+    # Creates a vocab wrapper and add some special tokens.
     vocab = Vocabulary()
     vocab.add_word('<pad>')
     vocab.add_word('<start>')
     vocab.add_word('<end>')
     vocab.add_word('<unk>')
-    # Add words to the vocabulary.
+    # Adds the words to the vocabulary.
     for i, word in enumerate(words):
         vocab.add_word(word)
     return vocab
 def main():
-    vocab = build_vocab(json='./data/annotations/captions_train2014.json',
-                        threshold=4)
-    with open('./data/vocab.pkl', 'wb') as f:
+    config = Config()
+    vocab = build_vocab(json=os.path.join(config.caption_path, 'captions_train2014.json'),
+                        threshold=config.word_count_threshold)
+    vocab_path = os.path.join(config.vocab_path, 'vocab.pkl')
+    with open(vocab_path, 'wb') as f:
         pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
-    print("Saved vocabulary file to ", './data/vocab.pkl')
+    print("Saved the vocabulary wrapper to ", vocab_path)
 if __name__ == '__main__':
     main()
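Once the vocabulary is pickled, the dataset class turns each caption string into a list of token ids bracketed by the special tokens. That code is not visible in the hunks above, so this is a hypothetical round-trip sketch: it assumes word2idx is a plain dict, that unknown words fall back to '<unk>', and that nltk's word_tokenize is the tokenizer.

import nltk
import pickle

# Load the wrapper saved by vocab.py (path from the sketch Config above).
with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

caption = 'A dog is playing with a ball.'
tokens = nltk.tokenize.word_tokenize(caption.lower())
unk = vocab.word2idx['<unk>']
ids = [vocab.word2idx['<start>']]
ids += [vocab.word2idx.get(token, unk) for token in tokens]
ids.append(vocab.word2idx['<end>'])
print(ids)   # token ids, starting with <start> and ending with <end>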