image captioning completed

yunjey
2017-03-21 20:01:47 +09:00
parent ba7d5467be
commit 6f5fda14f0
7 changed files with 297 additions and 68 deletions

View File

@@ -1,3 +1,6 @@
+import torchvision.transforms as T
+
+
 class Config(object):
     """Wrapper class for hyper-parameters."""
     def __init__(self):
@@ -8,6 +11,21 @@ class Config(object):
         self.word_count_threshold = 4
         self.num_threads = 2
         
+        # Image preprocessing in training phase
+        self.train_transform = T.Compose([
+            T.Scale(self.image_size),
+            T.RandomCrop(self.crop_size),
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+        
+        # Image preprocessing in test phase
+        self.test_transform = T.Compose([
+            T.Scale(self.crop_size),
+            T.CenterCrop(self.crop_size),
+            T.ToTensor(),
+            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+        
         # Training
         self.num_epochs = 5
         self.batch_size = 64
@@ -23,4 +41,7 @@ class Config(object):
         # Path
         self.image_path = './data/'
         self.caption_path = './data/annotations/'
         self.vocab_path = './data/'
+        self.model_path = './model/'
+        self.trained_encoder = 'encoder-4-6000.pkl'
+        self.trained_decoder = 'decoder-4-6000.pkl'

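Note on the Config change above: the training transform augments with a random crop and horizontal flip, while the test transform scales and center-crops deterministically, so both phases produce (3, crop_size, crop_size) tensors. A minimal usage sketch, assuming the torchvision of this era (T.Scale was later renamed T.Resize) and a hypothetical image path:

    # Sketch: applying the shared transforms from Config (image path hypothetical).
    from PIL import Image
    from configuration import Config

    config = Config()
    img = Image.open('./data/example.jpg').convert('RGB')
    train_tensor = config.train_transform(img)  # random crop + flip, (3, crop_size, crop_size)
    test_tensor = config.test_transform(img)    # deterministic center crop, same shape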
View File

@@ -2,11 +2,13 @@ import torch
 import torchvision.transforms as transforms
 import torch.utils.data as data
 import os
+import sys
 import pickle
 import numpy as np
 import nltk
 from PIL import Image
 from vocab import Vocabulary
+sys.path.append('../../../coco/PythonAPI')
 from pycocotools.coco import COCO

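The added sys.path.append in the data-loader diff assumes a COCO PythonAPI checkout three directories up, as in this repo's suggested layout. A hedged alternative sketch that prefers an installed pycocotools and only falls back to the local checkout:

    # Sketch: use an installed pycocotools if present, else the local COCO checkout.
    import sys
    try:
        from pycocotools.coco import COCO
    except ImportError:
        sys.path.append('../../../coco/PythonAPI')
        from pycocotools.coco import COCO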
File diff suppressed because one or more lines are too long

View File

@@ -53,14 +53,14 @@ class DecoderRNN(nn.Module):
         return outputs
     
     def sample(self, features, states):
-        """Samples captions for given image features."""
+        """Samples captions for given image features (Greedy search)."""
         sampled_ids = []
         inputs = features.unsqueeze(1)
         for i in range(20):
             hiddens, states = self.lstm(inputs, states)    # (batch_size, 1, hidden_size)
-            outputs = self.linear(hiddens.unsqueeze())     # (batch_size, vocab_size)
+            outputs = self.linear(hiddens.squeeze(1))      # (batch_size, vocab_size)
             predicted = outputs.max(1)[1]
             sampled_ids.append(predicted)
             inputs = self.embed(predicted)
         sampled_ids = torch.cat(sampled_ids, 1)            # (batch_size, 20)
-        return sampled_ids
+        return sampled_ids.squeeze()

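The decoder fix above matters because self.lstm returns hiddens of shape (batch_size, 1, hidden_size): squeeze(1), not the argument-less unsqueeze() (a runtime error, since unsqueeze requires a dim), collapses the sequence dimension before the vocabulary projection. A self-contained sketch of the same greedy search, written against current PyTorch, where outputs.max(1)[1] drops the reduced dimension, hence the extra unsqueeze(1) when feeding the prediction back:

    # Greedy-decoding sketch with stand-in layers shaped like DecoderRNN's.
    import torch
    import torch.nn as nn

    vocab_size, embed_size, hidden_size = 1000, 128, 512
    embed = nn.Embedding(vocab_size, embed_size)
    lstm = nn.LSTM(embed_size, hidden_size, num_layers=1, batch_first=True)
    linear = nn.Linear(hidden_size, vocab_size)

    features = torch.randn(1, embed_size)          # stand-in for the CNN image feature
    inputs, states = features.unsqueeze(1), None   # None -> zero initial states
    sampled_ids = []
    for _ in range(20):                            # max caption length
        hiddens, states = lstm(inputs, states)     # (1, 1, hidden_size)
        outputs = linear(hiddens.squeeze(1))       # (1, vocab_size)
        predicted = outputs.max(1)[1]              # greedy: most probable word id
        sampled_ids.append(predicted)
        inputs = embed(predicted).unsqueeze(1)     # feed the word back as next input
    caption_ids = torch.stack(sampled_ids, 1)      # (1, 20)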
View File

@@ -0,0 +1,5 @@
+matplotlib==2.0.0
+nltk==3.2.2
+numpy==1.12.0
+Pillow==4.0.0
+argparse

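These pins install with pip install -r requirements.txt. argparse is left unpinned because it has shipped with the standard library since Python 2.7/3.2, and torch itself is absent, presumably because PyTorch wheels were distributed from pytorch.org rather than PyPI at the time.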
View File

@@ -1,58 +1,77 @@
-import os
-import numpy as np
-import torch
-import torchvision.transforms as T
-import pickle
-import matplotlib.pyplot as plt
-from PIL import Image
-from model import EncoderCNN, DecoderRNN
 from vocab import Vocabulary
-from torch.autograd import Variable
+from model import EncoderCNN, DecoderRNN
+from configuration import Config
+from PIL import Image
+from torch.autograd import Variable
+import torch
+import torchvision.transforms as T
+import matplotlib.pyplot as plt
+import numpy as np
+import argparse
+import pickle
+import os
 
-# Image processing
-transform = T.Compose([
-    T.CenterCrop(224),
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
 
-# Hyper Parameters
-embed_size = 128
-hidden_size = 512
-num_layers = 1
+def main(params):
+    # Configuration for hyper-parameters
+    config = Config()
 
-# Load vocabulary
-with open('./data/vocab.pkl', 'rb') as f:
-    vocab = pickle.load(f)
+    # Image Preprocessing
+    transform = config.test_transform
 
-# Load an image array
-images = os.listdir('./data/train2014resized/')
-image_path = './data/train2014resized/' + images[12]
-img = Image.open(image_path)
-image = transform(img).unsqueeze(0)
+    # Load vocabulary
+    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
+        vocab = pickle.load(f)
 
-# Load the trained models
-encoder = torch.load('./encoder.pkl')
-decoder = torch.load('./decoder.pkl')
+    # Build Models
+    encoder = EncoderCNN(config.embed_size)
+    encoder.eval()  # evaluation mode (BN uses moving mean/variance)
+    decoder = DecoderRNN(config.embed_size, config.hidden_size,
+                         len(vocab), config.num_layers)
 
-# Encode the image
-feature = encoder(Variable(image).cuda())
+    # Load the trained model parameters
+    encoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                    config.trained_encoder)))
+    decoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                    config.trained_decoder)))
 
-# Set initial states
-state = (Variable(torch.zeros(num_layers, 1, hidden_size).cuda()),
-         Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())
+    # Prepare Image
+    image = Image.open(params['image'])
+    image_tensor = Variable(transform(image).unsqueeze(0))
 
-# Decode the feature to caption
-ids = decoder.sample(feature, state)
+    # Set initial states
+    state = (Variable(torch.zeros(config.num_layers, 1, config.hidden_size)),
+             Variable(torch.zeros(config.num_layers, 1, config.hidden_size)))
 
-words = []
-for id in ids:
-    word = vocab.idx2word[id.data[0, 0]]
-    words.append(word)
-    if word == '<end>':
-        break
-caption = ' '.join(words)
+    # If use gpu
+    if torch.cuda.is_available():
+        encoder.cuda()
+        decoder.cuda()
+        state = [s.cuda() for s in state]
+        image_tensor = image_tensor.cuda()
 
-# Display the image and generated caption
-plt.imshow(img)
-plt.show()
-print (caption)
+    # Generate caption from image
+    feature = encoder(image_tensor)
+    sampled_ids = decoder.sample(feature, state)
+    sampled_ids = sampled_ids.cpu().data.numpy()
+
+    # Decode word_ids to words
+    sampled_caption = []
+    for word_id in sampled_ids:
+        word = vocab.idx2word[word_id]
+        sampled_caption.append(word)
+        if word == '<end>':
+            break
+    sentence = ' '.join(sampled_caption)
+
+    # Print out image and generated caption.
+    print (sentence)
+    plt.imshow(np.asarray(image))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--image', type=str, required=True,
+                        help='image for generating caption')
+    args = parser.parse_args()
+    params = vars(args)
+    main(params)

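With the argparse entry point introduced above, the sampler runs from the command line via the required --image flag, or programmatically from the same module (image path hypothetical):

    # Sketch: bypassing the CLI (hypothetical image path).
    params = {'image': './data/example.png'}
    main(params)

One caveat worth noting: the rewritten script ends with plt.imshow(...) but never calls plt.show(), so outside an interactive matplotlib backend the figure may never actually be displayed.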
View File

@@ -6,7 +6,6 @@ from torch.autograd import Variable
 from torch.nn.utils.rnn import pack_padded_sequence
 import torch
 import torch.nn as nn
-import torchvision.transforms as T
 import numpy as np
 import pickle
 import os
@@ -16,14 +15,13 @@ def main():
     # Configuration for hyper-parameters
     config = Config()
 
+    # Create model directory
+    if not os.path.exists(config.model_path):
+        os.makedirs(config.model_path)
+
     # Image preprocessing
-    transform = T.Compose([
-        T.Scale(config.image_size),    # no resize
-        T.RandomCrop(config.crop_size),
-        T.RandomHorizontalFlip(),
-        T.ToTensor(),
-        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    transform = config.train_transform
 
     # Load vocabulary wrapper
     with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
         vocab = pickle.load(f)
@@ -40,22 +38,28 @@ def main():
     encoder = EncoderCNN(config.embed_size)
     decoder = DecoderRNN(config.embed_size, config.hidden_size,
                          len(vocab), config.num_layers)
-    encoder.cuda()
-    decoder.cuda()
+
+    if torch.cuda.is_available():
+        encoder.cuda()
+        decoder.cuda()
 
     # Loss and Optimizer
     criterion = nn.CrossEntropyLoss()
     params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
     optimizer = torch.optim.Adam(params, lr=config.learning_rate)
 
     # Train the Models
     for epoch in range(config.num_epochs):
         for i, (images, captions, lengths) in enumerate(train_loader):
             
             # Set mini-batch dataset
-            images = Variable(images).cuda()
-            captions = Variable(captions).cuda()
+            images = Variable(images)
+            captions = Variable(captions)
             targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
+
+            if torch.cuda.is_available():
+                images = images.cuda()
+                captions = captions.cuda()
 
             # Forward, Backward and Optimize
             decoder.zero_grad()
@@ -80,5 +84,6 @@ def main():
             torch.save(encoder.state_dict(),
                        os.path.join(config.model_path,
                                     'encoder-%d-%d.pkl' %(epoch+1, i+1)))
+
 if __name__ == '__main__':
     main()
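Two notes on the training changes, hedged since only the diff is visible here. First, the optimizer deliberately updates only the decoder plus the encoder's final resnet.fc layer, so the pretrained CNN backbone stays frozen. Second, targets is packed from captions before the CUDA guard runs, so on a GPU machine the packed targets stay on the CPU while the model outputs live on the GPU, which would likely trip a device mismatch in the loss; packing after the guard (or moving captions first) avoids that. A condensed sketch of the guard pattern the commit introduces:

    # Sketch of the introduced CUDA guard (Variable-era PyTorch).
    import torch

    def maybe_cuda(x):
        # Move a module or tensor to the GPU only when one is available.
        return x.cuda() if torch.cuda.is_available() else x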