import os
import pickle

import nltk
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from pycocotools.coco import COCO

from vocab import Vocabulary


class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""

    def __init__(self, root, json, vocab, transform=None):
        """Set the paths for images, captions, and the vocabulary wrapper.

        Args:
            root: image directory.
            json: COCO annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        coco = self.coco
        vocab = self.vocab
        ann_id = self.ids[index]
        caption = coco.anns[ann_id]['caption']
        img_id = coco.anns[ann_id]['image_id']
        path = coco.loadImgs(img_id)[0]['file_name']

        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert the caption (string) to a list of word ids,
        # wrapped in <start>/<end> markers.
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.tensor(caption, dtype=torch.long)
        return image, target

    def __len__(self):
        return len(self.ids)


def collate_fn(batch):
    """Creates mini-batch tensors from a list of tuples (image, caption).

    Captions have variable length, so the default collate_fn (which simply
    stacks tensors) does not apply; instead, captions are zero-padded to the
    length of the longest caption in the batch.

    Args:
        batch: list of tuples (image, caption).
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort the batch by caption length in descending order, as expected
    # by torch.nn.utils.rnn.pack_padded_sequence.
    batch.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*batch)

    # Merge images (from a tuple of 3D tensors to a 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from a tuple of 1D tensors to a zero-padded 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths


def get_data_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Returns a torch.utils.data.DataLoader for the custom COCO dataset."""
    # COCO caption dataset.
    coco = CocoDataset(root=root,
                       json=json,
                       vocab=vocab,
                       transform=transform)

    # Data loader for the COCO dataset; each iteration yields
    # (images, captions, lengths) via the custom collate_fn above.
    data_loader = torch.utils.data.DataLoader(dataset=coco,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader
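

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): the file paths,
    # vocab pickle name, and image size below are assumptions for illustration;
    # adjust them to your local COCO layout and vocabulary.
    with open('data/vocab.pkl', 'rb') as f:  # assumed location of a pickled Vocabulary
        vocab = pickle.load(f)

    # Transform producing the (3, 256, 256) tensors that collate_fn documents,
    # normalized with the standard ImageNet statistics.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    loader = get_data_loader(root='data/train2014',  # assumed image directory
                             json='data/annotations/captions_train2014.json',  # assumed annotations
                             vocab=vocab,
                             transform=transform,
                             batch_size=32,
                             shuffle=True,
                             num_workers=2)

    # One batch: images are stacked, captions are zero-padded and sorted
    # by descending length.
    images, targets, lengths = next(iter(loader))
    print(images.shape)   # torch.Size([32, 3, 256, 256])
    print(targets.shape)  # (32, longest caption length in this batch)
    print(lengths[:5])    # caption lengths in descending order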