diff --git a/tutorials/09 - Image Captioning/data.py b/tutorials/09 - Image Captioning/data.py
index a244fa2..dccaf84 100644
--- a/tutorials/09 - Image Captioning/data.py
+++ b/tutorials/09 - Image Captioning/data.py
@@ -13,12 +13,13 @@ from pycocotools.coco import COCO
 class CocoDataset(data.Dataset):
     """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
     def __init__(self, root, json, vocab, transform=None):
-        """
+        """Set the path for images, captions and vocabulary wrapper.
+
         Args:
             root: image directory.
             json: coco annotation file path.
             vocab: vocabulary wrapper.
-            transform: transformer for image.
+            transform: image transformer.
         """
         self.root = root
         self.coco = COCO(json)
@@ -27,7 +28,7 @@ class CocoDataset(data.Dataset):
         self.transform = transform
 
     def __getitem__(self, index):
-        """This function should return one data pair(image and caption)."""
+        """Returns one data pair (image and caption)."""
         coco = self.coco
         vocab = self.vocab
         ann_id = self.ids[index]
@@ -53,12 +54,13 @@ class CocoDataset(data.Dataset):
 
 
 def collate_fn(data):
-    """Build mini-batch tensors from a list of (image, caption) tuples.
+    """Creates mini-batch tensors from the list of tuples (image, caption).
+
     Args:
-        data: list of (image, caption) tuple.
+        data: list of tuple (image, caption).
             - image: torch tensor of shape (3, 256, 256).
             - caption: torch tensor of shape (?); variable length.
-
+
     Returns:
         images: torch tensor of shape (batch_size, 3, 256, 256).
         targets: torch tensor of shape (batch_size, padded_length).
@@ -68,10 +70,10 @@ def collate_fn(data):
     data.sort(key=lambda x: len(x[1]), reverse=True)
     images, captions = zip(*data)
 
-    # Merge images (convert tuple of 3D tensor to 4D tensor)
+    # Merge images (from tuple of 3D tensor to 4D tensor)
     images = torch.stack(images, 0)
 
-    # Merget captions (convert tuple of 1D tensor to 2D tensor)
+    # Merge captions (from tuple of 1D tensor to 2D tensor)
     lengths = [len(cap) for cap in captions]
     targets = torch.zeros(len(captions), max(lengths)).long()
     for i, cap in enumerate(captions):
@@ -80,18 +82,18 @@ def collate_fn(data):
     return images, targets, lengths
 
 
-def get_loader(root, json, vocab, transform, batch_size=100, shuffle=True, num_workers=2):
+def get_data_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
     """Returns torch.utils.data.DataLoader for custom coco dataset."""
-    # COCO custom dataset
+    # COCO dataset
     coco = CocoDataset(root=root,
                        json=json,
                        vocab = vocab,
                        transform=transform)
 
-    # Data loader
+    # Data loader for COCO dataset
    data_loader = torch.utils.data.DataLoader(dataset=coco,
                                              batch_size=batch_size,
-                                              shuffle=True,
+                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader
\ No newline at end of file
diff --git a/tutorials/09 - Image Captioning/model.py b/tutorials/09 - Image Captioning/model.py
index fcef106..6e31514 100644
--- a/tutorials/09 - Image Captioning/model.py
+++ b/tutorials/09 - Image Captioning/model.py
@@ -7,43 +7,44 @@ from torch.autograd import Variable
 
 class EncoderCNN(nn.Module):
     def __init__(self, embed_size):
-        """Load pretrained ResNet-152 and replace top fc layer."""
+        """Loads the pretrained ResNet-152 and replaces the top fc layer."""
         super(EncoderCNN, self).__init__()
         self.resnet = models.resnet152(pretrained=True)
-        # For efficient memory usage.
         for param in self.resnet.parameters():
             param.requires_grad = False
         self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
+        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
         self.init_weights()
-
+
     def init_weights(self):
-        self.resnet.fc.weight.data.uniform_(-0.1, 0.1)
+        """Initialize weights."""
+        self.resnet.fc.weight.data.normal_(0.0, 0.02)
         self.resnet.fc.bias.data.fill_(0)
 
     def forward(self, images):
-        """Extract image feature vectors."""
+        """Extracts the image feature vectors."""
         features = self.resnet(images)
+        features = self.bn(features)
         return features
 
 
 class DecoderRNN(nn.Module):
     def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
-        """Set hyper-parameters and build layers."""
+        """Set the hyper-parameters and build the layers."""
         super(DecoderRNN, self).__init__()
-        self.embed_size = embed_size
-        self.hidden_size = hidden_size
-        self.vocab_size = vocab_size
         self.embed = nn.Embedding(vocab_size, embed_size)
-        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
         self.linear = nn.Linear(hidden_size, vocab_size)
+        self.init_weights()
 
     def init_weights(self):
+        """Initialize weights."""
         self.embed.weight.data.uniform_(-0.1, 0.1)
-        self.linear.weigth.data.uniform_(-0.1, 0.1)
+        self.linear.weight.data.uniform_(-0.1, 0.1)
         self.linear.bias.data.fill_(0)
 
     def forward(self, features, captions, lengths):
-        """Decode image feature vectors and generate caption."""
+        """Decodes image feature vectors and generates captions."""
         embeddings = self.embed(captions)
         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
         packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
@@ -51,14 +52,15 @@ class DecoderRNN(nn.Module):
         outputs = self.linear(hiddens[0])
         return outputs
 
-    def sample(self, feature, state):
-        """Sample a caption for given a image feature."""
+    def sample(self, features, states):
+        """Samples captions for given image features."""
         sampled_ids = []
-        input = feature.unsqueeze(1)
+        inputs = features.unsqueeze(1)
         for i in range(20):
-            hidden, state = self.lstm(input, state)          # (1, 1, hidden_size)
-            output = self.linear(hidden.view(-1, self.hidden_size))  # (1, vocab_size)
-            predicted = output.max(1)[1]
+            hiddens, states = self.lstm(inputs, states)      # (batch_size, 1, hidden_size)
+            outputs = self.linear(hiddens.squeeze(1))        # (batch_size, vocab_size)
+            predicted = outputs.max(1)[1]
             sampled_ids.append(predicted)
-            input = self.embed(predicted)
+            inputs = self.embed(predicted)
+        sampled_ids = torch.cat(sampled_ids, 1)              # (batch_size, 20)
         return sampled_ids
\ No newline at end of file
diff --git a/tutorials/09 - Image Captioning/resize.py b/tutorials/09 - Image Captioning/resize.py
index da1cd99..02a0974 100644
--- a/tutorials/09 - Image Captioning/resize.py
+++ b/tutorials/09 - Image Captioning/resize.py
@@ -1,34 +1,34 @@
 from PIL import Image
+from configuration import Config
 import os
 
 def resize_image(image, size):
-    """Resizes an image to the given size."""
+    """Resizes the image to the given size."""
     return image.resize(size, Image.ANTIALIAS)
 
 def resize_images(image_dir, output_dir, size):
-    """Resizes the images in the image_dir and save into the output_dir."""
+    """Resizes the images in 'image_dir' and saves them in 'output_dir'."""
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-
     images = os.listdir(image_dir)
     num_images = len(images)
     for i, image in enumerate(images):
         with open(os.path.join(image_dir, image), 'r+b') as f:
             with Image.open(f) as img:
                 img = resize_image(img, size)
-                img.save(
-                    os.path.join(output_dir, image), img.format)
+                img.save(os.path.join(output_dir, image), img.format)
         if i % 100 == 0:
-            print ('[%d/%d] Resized the images and saved into %s.'
+            print ('[%d/%d] Resized the images and saved them in %s.'
                    %(i, num_images, output_dir))
 
 def main():
+    config = Config()
     splits = ['train', 'val']
     for split in splits:
-        image_dir = './data/%s2014/' %split
-        output_dir = './data/%s2014resized' %split
-        resize_images(image_dir, output_dir, (256, 256))
+        image_dir = os.path.join(config.image_path, '%s2014/' %split)
+        output_dir = os.path.join(config.image_path, '%s2014resized' %split)
+        resize_images(image_dir, output_dir, (config.image_size, config.image_size))
 
 if __name__ == '__main__':
     main()
diff --git a/tutorials/09 - Image Captioning/train.py b/tutorials/09 - Image Captioning/train.py
index d97976d..611ac92 100644
--- a/tutorials/09 - Image Captioning/train.py
+++ b/tutorials/09 - Image Captioning/train.py
@@ -1,72 +1,84 @@
-from data import get_loader
+from data import get_data_loader
 from vocab import Vocabulary
+from configuration import Config
 from model import EncoderCNN, DecoderRNN
 from torch.autograd import Variable
 from torch.nn.utils.rnn import pack_padded_sequence
 import torch
-import torch.nn as nn
-import numpy as np
+import torch.nn as nn
 import torchvision.transforms as T
+import numpy as np
 import pickle
+import os
 
 
-# Hyper Parameters
-num_epochs = 1
-batch_size = 32
-embed_size = 256
-hidden_size = 512
-crop_size = 224
-num_layers = 1
-learning_rate = 0.001
-train_image_path = './data/train2014resized/'
-train_json_path = './data/annotations/captions_train2014.json'
 
-# Image Preprocessing
-transform = T.Compose([
-    T.RandomCrop(crop_size),
-    T.RandomHorizontalFlip(),
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+def main():
+    # Configuration for hyper-parameters
+    config = Config()
+
+    # Image preprocessing
+    transform = T.Compose([
+        T.Scale(config.image_size),    # no resize
+        T.RandomCrop(config.crop_size),
+        T.RandomHorizontalFlip(),
+        T.ToTensor(),
+        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
-# Load Vocabulary Wrapper
-with open('./data/vocab.pkl', 'rb') as f:
+    # Load vocabulary wrapper
+    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
         vocab = pickle.load(f)
-
-# Build Dataset Loader
-train_loader = get_loader(train_image_path, train_json_path, vocab, transform,
-                          batch_size=batch_size, shuffle=True, num_workers=2)
-total_step = len(train_loader)
-# Build Models
-encoder = EncoderCNN(embed_size)
-decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers)
-encoder.cuda()
-decoder.cuda()
-
-# Loss and Optimizer
-criterion = nn.CrossEntropyLoss()
-params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
-optimizer = torch.optim.Adam(params, lr=learning_rate)
+    # Build data loader
+    image_path = os.path.join(config.image_path, 'train2014')
+    json_path = os.path.join(config.caption_path, 'captions_train2014.json')
+    train_loader = get_data_loader(image_path, json_path, vocab,
+                                   transform, config.batch_size,
+                                   shuffle=True, num_workers=config.num_threads)
+    total_step = len(train_loader)
 
-# Train the Decoder
-for epoch in range(num_epochs):
-    for i, (images, captions, lengths) in enumerate(train_loader):
-        # Set mini-batch dataset
-        images = Variable(images).cuda()
-        captions = Variable(captions).cuda()
-        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
-
-        # Forward, Backward and Optimize
-        decoder.zero_grad()
-        features = encoder(images)
-        outputs = decoder(features, captions, lengths)
-        loss = criterion(outputs, targets)
-        loss.backward()
-        optimizer.step()
-
-        if i % 100 == 0:
-            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
-                  %(epoch, num_epochs, i, total_step, loss.data[0], np.exp(loss.data[0])))
+    # Build Models
+    encoder = EncoderCNN(config.embed_size)
+    decoder = DecoderRNN(config.embed_size, config.hidden_size,
+                         len(vocab), config.num_layers)
+    encoder.cuda()
+    decoder.cuda()
 
-# Save the Model
-torch.save(decoder, 'decoder.pkl')
-torch.save(encoder, 'encoder.pkl')
\ No newline at end of file
+    # Loss and Optimizer
+    criterion = nn.CrossEntropyLoss()
+    params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
+    optimizer = torch.optim.Adam(params, lr=config.learning_rate)
+
+    # Train the Models
+    for epoch in range(config.num_epochs):
+        for i, (images, captions, lengths) in enumerate(train_loader):
+
+            # Set mini-batch dataset
+            images = Variable(images).cuda()
+            captions = Variable(captions).cuda()
+            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
+
+            # Forward, Backward and Optimize
+            decoder.zero_grad()
+            encoder.zero_grad()
+            features = encoder(images)
+            outputs = decoder(features, captions, lengths)
+            loss = criterion(outputs, targets)
+            loss.backward()
+            optimizer.step()
+
+            # Print log info
+            if i % config.log_step == 0:
+                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
+                      %(epoch, config.num_epochs, i, total_step,
+                        loss.data[0], np.exp(loss.data[0])))
+
+            # Save the Model
+            if (i+1) % config.save_step == 0:
+                torch.save(decoder.state_dict(),
+                           os.path.join(config.model_path,
+                                        'decoder-%d-%d.pkl' %(epoch+1, i+1)))
+                torch.save(encoder.state_dict(),
+                           os.path.join(config.model_path,
+                                        'encoder-%d-%d.pkl' %(epoch+1, i+1)))
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/tutorials/09 - Image Captioning/vocab.py b/tutorials/09 - Image Captioning/vocab.py
index 78c16f2..198d511 100644
--- a/tutorials/09 - Image Captioning/vocab.py
+++ b/tutorials/09 - Image Captioning/vocab.py
@@ -1,6 +1,7 @@
-# Create a vocabulary wrapper
 import nltk
 import pickle
+import os
+from configuration import Config
 from collections import Counter
 from pycocotools.coco import COCO
 
@@ -27,7 +28,7 @@ class Vocabulary(object):
         return len(self.word2idx)
 
 def build_vocab(json, threshold):
-    """Build a simple vocabulary wrapper."""
+    """Builds a simple vocabulary wrapper."""
     coco = COCO(json)
     counter = Counter()
     ids = coco.anns.keys()
@@ -37,29 +38,31 @@
         counter.update(tokens)
 
         if i % 1000 == 0:
-            print("[%d/%d] tokenized the captions." %(i, len(ids)))
-
-    # Discard if the occurrence of the word is less than min_word_cnt.
+            print("[%d/%d] Tokenized the captions." %(i, len(ids)))
+
+    # If the word frequency is less than 'threshold', then the word is discarded.
     words = [word for word, cnt in counter.items() if cnt >= threshold]
 
-    # Create a vocab wrapper and add some special tokens.
+    # Creates a vocab wrapper and adds some special tokens.
     vocab = Vocabulary()
     vocab.add_word('<pad>')
     vocab.add_word('<start>')
     vocab.add_word('<end>')
     vocab.add_word('<unk>')
 
-    # Add words to the vocabulary.
+    # Adds the words to the vocabulary.
     for i, word in enumerate(words):
         vocab.add_word(word)
     return vocab
 
 def main():
-    vocab = build_vocab(json='./data/annotations/captions_train2014.json',
-                        threshold=4)
-    with open('./data/vocab.pkl', 'wb') as f:
+    config = Config()
+    vocab = build_vocab(json=os.path.join(config.caption_path, 'captions_train2014.json'),
+                        threshold=config.word_count_threshold)
+    vocab_path = os.path.join(config.vocab_path, 'vocab.pkl')
+    with open(vocab_path, 'wb') as f:
         pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
-    print("Saved vocabulary file to ", './data/vocab.pkl')
+    print("Saved the vocabulary wrapper to ", vocab_path)
 
 if __name__ == '__main__':
     main()
\ No newline at end of file
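
The snippet below is a minimal usage sketch, not part of the patch above: it checks the padding behavior of collate_fn from the updated data.py using made-up tensors in place of real COCO images and captions. The token ids and the name dummy_batch are arbitrary, and the import assumes the patched data.py is on the Python path.

import torch
from data import collate_fn  # assumes the patched data.py shown above is importable

# Three dummy (image, caption) pairs; captions have different lengths.
dummy_batch = [(torch.zeros(3, 256, 256), torch.LongTensor([1, 4, 5, 2])),
               (torch.zeros(3, 256, 256), torch.LongTensor([1, 7, 2])),
               (torch.zeros(3, 256, 256), torch.LongTensor([1, 9, 8, 6, 2]))]

images, targets, lengths = collate_fn(dummy_batch)
print(images.size())    # torch.Size([3, 3, 256, 256])
print(lengths)          # [5, 4, 3]; pairs are sorted by caption length, longest first
print(targets)          # 3 x 5 LongTensor; shorter captions are zero-padded on the right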