image captioning completed

This commit is contained in:
yunjey
2017-03-21 20:01:47 +09:00
parent ba7d5467be
commit 6f5fda14f0
7 changed files with 297 additions and 68 deletions


@@ -1,58 +1,77 @@
-import os
-import numpy as np
-import torch
-import torchvision.transforms as T
-import pickle
-import matplotlib.pyplot as plt
-from PIL import Image
-from model import EncoderCNN, DecoderRNN
 from vocab import Vocabulary
-from torch.autograd import Variable
+from model import EncoderCNN, DecoderRNN
+from configuration import Config
+from PIL import Image
+from torch.autograd import Variable
+import torch
+import torchvision.transforms as T
+import matplotlib.pyplot as plt
+import numpy as np
+import argparse
+import pickle
+import os
-# Image processing
-transform = T.Compose([
-    T.CenterCrop(224),
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
-# Hyper Parameters
-embed_size = 128
-hidden_size = 512
-num_layers = 1
-# Load vocabulary
-with open('./data/vocab.pkl', 'rb') as f:
-    vocab = pickle.load(f)
+def main(params):
+    # Configuration for hyper-parameters
+    config = Config()
-# Load an image array
-images = os.listdir('./data/train2014resized/')
-image_path = './data/train2014resized/' + images[12]
-img = Image.open(image_path)
-image = transform(img).unsqueeze(0)
+    # Image Preprocessing
+    transform = config.test_transform
-# Load the trained models
-encoder = torch.load('./encoder.pkl')
-decoder = torch.load('./decoder.pkl')
+    # Load vocabulary
+    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
+        vocab = pickle.load(f)
-# Encode the image
-feature = encoder(Variable(image).cuda())
+    # Build Models
+    encoder = EncoderCNN(config.embed_size)
+    encoder.eval()  # evaluation mode (BN uses moving mean/variance)
+    decoder = DecoderRNN(config.embed_size, config.hidden_size,
+                         len(vocab), config.num_layers)
-# Set initial states
-state = (Variable(torch.zeros(num_layers, 1, hidden_size).cuda()),
-         Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())
+    # Load the trained model parameters
+    encoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                    config.trained_encoder)))
+    decoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                    config.trained_decoder)))
-# Decode the feature to caption
-ids = decoder.sample(feature, state)
-words = []
-for id in ids:
-    word = vocab.idx2word[id.data[0, 0]]
-    words.append(word)
-    if word == '<end>':
-        break
-caption = ' '.join(words)
-# Display the image and generated caption
-plt.imshow(img)
-plt.show()
-print (caption)
+    # Prepare Image
+    image = Image.open(params['image'])
+    image_tensor = Variable(transform(image).unsqueeze(0))
+    # Set initial states
+    state = (Variable(torch.zeros(config.num_layers, 1, config.hidden_size)),
+             Variable(torch.zeros(config.num_layers, 1, config.hidden_size)))
+    # If use gpu
+    if torch.cuda.is_available():
+        encoder.cuda()
+        decoder.cuda()
+        state = [s.cuda() for s in state]
+        image_tensor = image_tensor.cuda()
+    # Generate caption from image
+    feature = encoder(image_tensor)
+    sampled_ids = decoder.sample(feature, state)
+    sampled_ids = sampled_ids.cpu().data.numpy()
+    # Decode word_ids to words
+    sampled_caption = []
+    for word_id in sampled_ids:
+        word = vocab.idx2word[word_id]
+        sampled_caption.append(word)
+        if word == '<end>':
+            break
+    sentence = ' '.join(sampled_caption)
+    # Print out image and generated caption.
+    print (sentence)
+    plt.imshow(np.asarray(image))
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--image', type=str, required=True, help='image for generating caption')
+    args = parser.parse_args()
+    params = vars(args)
+    main(params)
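
With this change the sampler runs as a command-line script instead of picking a fixed file out of ./data/train2014resized/. A typical invocation (assuming the file shown here is saved as sample.py; the image path is only an example):

    python sample.py --image ./data/example.jpg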
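
The new script pulls every hyper-parameter and path from configuration.Config, which lives in one of the other changed files and is not visible in this hunk. Inferring only from the attributes referenced above, a minimal sketch of such a class could look like the following (the attribute names come from their usage in the diff; every value is a placeholder, not the committed one):

    import torchvision.transforms as T

    class Config(object):
        # Model hyper-parameters (placeholder values).
        embed_size = 128
        hidden_size = 512
        num_layers = 1
        # Locations of the vocabulary and trained parameter files (placeholder paths).
        vocab_path = './data/'
        model_path = './models/'
        trained_encoder = 'encoder.pkl'
        trained_decoder = 'decoder.pkl'
        # Preprocessing applied to an image at test time.
        test_transform = T.Compose([
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])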
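
A portability note: torch.autograd.Variable was merged into Tensor in PyTorch 0.4, so the Variable/.data idiom used above is deprecated on current releases. Under that assumption, the same greedy sampling step would today be written roughly as follows; this is a sketch that keeps the encoder, decoder, vocab, and config objects exactly as the diff builds them:

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device).eval()
    decoder.to(device).eval()
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Initial LSTM state: (h0, c0), each of shape (num_layers, batch=1, hidden_size).
    state = (torch.zeros(config.num_layers, 1, config.hidden_size, device=device),
             torch.zeros(config.num_layers, 1, config.hidden_size, device=device))

    with torch.no_grad():  # replaces wrapping inputs in Variable
        feature = encoder(image_tensor)
        sampled_ids = decoder.sample(feature, state).cpu().numpy()

The id-to-word loop and the '<end>' check then work unchanged.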