diff --git a/tutorials/09 - Image Captioning/data.py b/tutorials/09 - Image Captioning/data.py
new file mode 100644
index 0000000..a244fa2
--- /dev/null
+++ b/tutorials/09 - Image Captioning/data.py
@@ -0,0 +1,97 @@
+import torch
+import torchvision.transforms as transforms
+import torch.utils.data as data
+import os
+import pickle
+import numpy as np
+import nltk
+from PIL import Image
+from vocab import Vocabulary
+from pycocotools.coco import COCO
+
+
+class CocoDataset(data.Dataset):
+    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
+    def __init__(self, root, json, vocab, transform=None):
+        """
+        Args:
+            root: image directory.
+            json: coco annotation file path.
+            vocab: vocabulary wrapper.
+            transform: image transformer.
+        """
+        self.root = root
+        self.coco = COCO(json)
+        self.ids = list(self.coco.anns.keys())
+        self.vocab = vocab
+        self.transform = transform
+
+    def __getitem__(self, index):
+        """Returns one data pair (image and caption)."""
+        coco = self.coco
+        vocab = self.vocab
+        ann_id = self.ids[index]
+        caption = coco.anns[ann_id]['caption']
+        img_id = coco.anns[ann_id]['image_id']
+        path = coco.loadImgs(img_id)[0]['file_name']
+
+        image = Image.open(os.path.join(self.root, path)).convert('RGB')
+        if self.transform is not None:
+            image = self.transform(image)
+
+        # Convert caption (string) to a list of word ids.
+        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
+        caption = []
+        caption.append(vocab('<start>'))
+        caption.extend([vocab(token) for token in tokens])
+        caption.append(vocab('<end>'))
+        target = torch.Tensor(caption)
+        return image, target
+
+    def __len__(self):
+        return len(self.ids)
+
+
+def collate_fn(data):
+    """Build mini-batch tensors from a list of (image, caption) tuples.
+    Args:
+        data: list of (image, caption) tuples.
+            - image: torch tensor of shape (3, 256, 256).
+            - caption: torch tensor of shape (?); variable length.
+
+    Returns:
+        images: torch tensor of shape (batch_size, 3, 256, 256).
+        targets: torch tensor of shape (batch_size, padded_length).
+        lengths: list; valid length for each padded caption.
+ """ + # Sort a data list by caption length + data.sort(key=lambda x: len(x[1]), reverse=True) + images, captions = zip(*data) + + # Merge images (convert tuple of 3D tensor to 4D tensor) + images = torch.stack(images, 0) + + # Merget captions (convert tuple of 1D tensor to 2D tensor) + lengths = [len(cap) for cap in captions] + targets = torch.zeros(len(captions), max(lengths)).long() + for i, cap in enumerate(captions): + end = lengths[i] + targets[i, :end] = cap[:end] + return images, targets, lengths + + +def get_loader(root, json, vocab, transform, batch_size=100, shuffle=True, num_workers=2): + """Returns torch.utils.data.DataLoader for custom coco dataset.""" + # COCO custom dataset + coco = CocoDataset(root=root, + json=json, + vocab = vocab, + transform=transform) + + # Data loader + data_loader = torch.utils.data.DataLoader(dataset=coco, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + collate_fn=collate_fn) + return data_loader \ No newline at end of file diff --git a/tutorials/09 - Image Captioning/resize.py b/tutorials/09 - Image Captioning/resize.py new file mode 100644 index 0000000..da1cd99 --- /dev/null +++ b/tutorials/09 - Image Captioning/resize.py @@ -0,0 +1,35 @@ +from PIL import Image +import os + + +def resize_image(image, size): + """Resizes an image to the given size.""" + return image.resize(size, Image.ANTIALIAS) + +def resize_images(image_dir, output_dir, size): + """Resizes the images in the image_dir and save into the output_dir.""" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + images = os.listdir(image_dir) + num_images = len(images) + for i, image in enumerate(images): + with open(os.path.join(image_dir, image), 'r+b') as f: + with Image.open(f) as img: + img = resize_image(img, size) + img.save( + os.path.join(output_dir, image), img.format) + if i % 100 == 0: + print ('[%d/%d] Resized the images and saved into %s.' 
+                  %(i, num_images, output_dir))
+
+def main():
+    splits = ['train', 'val']
+    for split in splits:
+        image_dir = './data/%s2014/' %split
+        output_dir = './data/%s2014resized' %split
+        resize_images(image_dir, output_dir, (256, 256))
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/tutorials/09 - Image Captioning/train.py b/tutorials/09 - Image Captioning/train.py
new file mode 100644
index 0000000..7998616
--- /dev/null
+++ b/tutorials/09 - Image Captioning/train.py
@@ -0,0 +1,69 @@
+from data import get_loader
+from vocab import Vocabulary
+from models import EncoderCNN, DecoderRNN
+from torch.autograd import Variable
+from torch.nn.utils.rnn import pack_padded_sequence
+import torch
+import torch.nn as nn
+import numpy as np
+import torchvision.transforms as T
+import pickle
+
+# Hyper Parameters
+num_epochs = 5
+batch_size = 100
+embed_size = 128
+hidden_size = 512
+num_layers = 1
+learning_rate = 0.001
+train_image_path = './data/train2014resized/'
+train_json_path = './data/annotations/captions_train2014.json'
+
+# Image Preprocessing
+transform = T.Compose([
+    T.RandomHorizontalFlip(),
+    T.ToTensor(),
+    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+
+# Load Vocabulary Wrapper
+with open('./data/vocab.pkl', 'rb') as f:
+    vocab = pickle.load(f)
+
+# Build Dataset Loader
+train_loader = get_loader(train_image_path, train_json_path, vocab, transform,
+                          batch_size=batch_size, shuffle=True, num_workers=2)
+total_step = len(train_loader)
+
+# Build Models
+encoder = EncoderCNN(embed_size)
+decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers)
+encoder.cuda()
+decoder.cuda()
+
+# Loss and Optimizer
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
+
+# Train the Decoder
+for epoch in range(num_epochs):
+    for i, (images, captions, lengths) in enumerate(train_loader):
+        # Set mini-batch dataset
+        images = Variable(images).cuda()
+        captions = Variable(captions).cuda()
+        # Pack the padded captions so the loss only sees valid (non-padded) targets.
+        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
+
+        # Forward, Backward and Optimize
+        decoder.zero_grad()
+        features = encoder(images)
+        outputs = decoder(features, captions, lengths)
+        loss = criterion(outputs, targets)
+        loss.backward()
+        optimizer.step()
+
+        if i % 100 == 0:
+            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
+                  %(epoch, num_epochs, i, total_step, loss.data[0], np.exp(loss.data[0])))
+
+# Save the Model
+torch.save(decoder, 'decoder.pkl')
+torch.save(encoder, 'encoder.pkl')
\ No newline at end of file
diff --git a/tutorials/09 - Image Captioning/vocab.py b/tutorials/09 - Image Captioning/vocab.py
new file mode 100644
index 0000000..2301354
--- /dev/null
+++ b/tutorials/09 - Image Captioning/vocab.py
@@ -0,0 +1,65 @@
+# Create a vocabulary wrapper
+import nltk
+import pickle
+from collections import Counter
+from pycocotools.coco import COCO
+
+
+class Vocabulary(object):
+    """Simple vocabulary wrapper."""
+    def __init__(self):
+        self.word2idx = {}
+        self.idx2word = {}
+        self.idx = 0
+
+    def add_word(self, word):
+        if word not in self.word2idx:
+            self.word2idx[word] = self.idx
+            self.idx2word[self.idx] = word
+            self.idx += 1
+
+    def __call__(self, word):
+        if word not in self.word2idx:
+            return self.word2idx['<unk>']
+        return self.word2idx[word]
+
+    def __len__(self):
+        return len(self.word2idx)
+
+def build_vocab(json, threshold):
+    """Build a simple vocabulary wrapper."""
+    coco = COCO(json)
+    counter = Counter()
+    ids = coco.anns.keys()
+    for i, id in enumerate(ids):
+        caption = str(coco.anns[id]['caption'])
+        tokens = nltk.tokenize.word_tokenize(caption.lower())
+        counter.update(tokens)
+
+        if i % 1000 == 0:
+            print("[%d/%d] tokenized the captions." %(i, len(ids)))
+
+    # Keep only the words that occur at least `threshold` times.
+    words = [word for word, cnt in counter.items() if cnt >= threshold]
+
+    # Create a vocab wrapper and add some special tokens.
+    vocab = Vocabulary()
+    vocab.add_word('<pad>')
+    vocab.add_word('<start>')
+    vocab.add_word('<end>')
+    vocab.add_word('<unk>')
+
+    # Add the words to the vocabulary.
+    for i, word in enumerate(words):
+        vocab.add_word(word)
+    return vocab
+
+def main():
+    vocab = build_vocab(json='./data/annotations/captions_train2014.json',
+                        threshold=4)
+    with open('./data/vocab.pkl', 'wb') as f:
+        pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
+    print("Saved vocabulary file to ", './data/vocab.pkl')
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
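Note (not part of the diff): below is a minimal usage sketch of how the new pieces fit together, assuming resize.py and then vocab.py have already been run so that ./data/train2014resized/ and ./data/vocab.pkl exist. It only exercises code introduced in this change (Vocabulary, get_loader, collate_fn); the paths, batch size, and the sample token 'giraffe' are illustrative.

    import pickle
    import torchvision.transforms as T
    from vocab import Vocabulary
    from data import get_loader

    # Load the vocabulary wrapper pickled by vocab.py.
    with open('./data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    # The wrapper maps tokens to integer ids; unknown tokens fall back to '<unk>'.
    print(vocab('giraffe'), vocab('not-a-real-token'))

    # Same normalization as train.py; images were already resized to 256x256 by resize.py.
    transform = T.Compose([T.ToTensor(),
                           T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

    loader = get_loader('./data/train2014resized/',
                        './data/annotations/captions_train2014.json',
                        vocab, transform, batch_size=4, shuffle=False, num_workers=0)

    # collate_fn sorts each batch by caption length and zero-pads, so every batch
    # is (images, padded caption ids, valid lengths).
    images, captions, lengths = next(iter(loader))
    print(images.size(), captions.size(), lengths)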