mirror of https://github.com/yunjey/pytorch-tutorial.git, synced 2025-07-27 12:03:31 +08:00
image captioning completed
@@ -1,3 +1,6 @@
+import torchvision.transforms as T
 class Config(object):
     """Wrapper class for hyper-parameters."""
     def __init__(self):
@@ -8,6 +11,21 @@ class Config(object):
         self.word_count_threshold = 4
         self.num_threads = 2
+
+        # Image preprocessing in training phase
+        self.train_transform = T.Compose([
+            T.Scale(self.image_size),
+            T.RandomCrop(self.crop_size),
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+        # Image preprocessing in test phase
+        self.test_transform = T.Compose([
+            T.Scale(self.crop_size),
+            T.CenterCrop(self.crop_size),
+            T.ToTensor(),
+            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
         # Training
         self.num_epochs = 5
         self.batch_size = 64
@@ -24,3 +42,6 @@ class Config(object):
         self.image_path = './data/'
         self.caption_path = './data/annotations/'
         self.vocab_path = './data/'
+        self.model_path = './model/'
+        self.trained_encoder = 'encoder-4-6000.pkl'
+        self.trained_decoder = 'decoder-4-6000.pkl'
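The hunks above appear to come from the tutorial's configuration.py: the commit hoists the train/test preprocessing pipelines into the Config wrapper so train.py and sample.py can share them. Note that T.Scale is the torchvision API of that era; later releases renamed it to T.Resize. A minimal sketch of the same idea on a current torchvision, assuming the common 256-pixel resize and 224-pixel crop (the real image_size/crop_size values live outside the hunk shown):

import torchvision.transforms as T

class Config(object):
    """Wrapper class for hyper-parameters."""
    def __init__(self):
        # Assumed values; the originals are defined outside the hunk above.
        self.image_size = 256
        self.crop_size = 224
        # Training: resize, then random crop + flip for augmentation.
        self.train_transform = T.Compose([
            T.Resize(self.image_size),          # T.Scale in 2017-era torchvision
            T.RandomCrop(self.crop_size),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # Test: deterministic resize + center crop, same normalization.
        self.test_transform = T.Compose([
            T.Resize(self.crop_size),
            T.CenterCrop(self.crop_size),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])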
@@ -2,11 +2,13 @@ import torch
 import torchvision.transforms as transforms
 import torch.utils.data as data
 import os
+import sys
 import pickle
 import numpy as np
 import nltk
 from PIL import Image
 from vocab import Vocabulary
+sys.path.append('../../../coco/PythonAPI')
 from pycocotools.coco import COCO
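This import hunk belongs to the data loader; the sys.path.append line pulls in a locally built pycocotools from a sibling checkout of the COCO API. A rough sketch (not the repo's exact code, and the annotation path is hypothetical) of how these imports cooperate when turning one COCO annotation into training ids:

import nltk
from pycocotools.coco import COCO

coco = COCO('./data/annotations/captions_train2014.json')  # hypothetical path
ann_id = list(coco.anns.keys())[0]                         # pick one annotation
caption = coco.anns[ann_id]['caption']
tokens = nltk.tokenize.word_tokenize(str(caption).lower())
# With the tutorial's Vocabulary wrapper (callable, word -> index):
# ids = [vocab('<start>')] + [vocab(t) for t in tokens] + [vocab('<end>')]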
tutorials/09 - Image Captioning/evaluate_model.ipynb (new file, 177 lines)
File diff suppressed because one or more lines are too long
@@ -53,14 +53,14 @@ class DecoderRNN(nn.Module):
         return outputs

     def sample(self, features, states):
-        """Samples captions for given image features."""
+        """Samples captions for given image features (Greedy search)."""
         sampled_ids = []
         inputs = features.unsqueeze(1)
         for i in range(20):
             hiddens, states = self.lstm(inputs, states)    # (batch_size, 1, hidden_size)
-            outputs = self.linear(hiddens.unsqueeze())     # (batch_size, vocab_size)
+            outputs = self.linear(hiddens.squeeze(1))      # (batch_size, vocab_size)
             predicted = outputs.max(1)[1]
             sampled_ids.append(predicted)
             inputs = self.embed(predicted)
         sampled_ids = torch.cat(sampled_ids, 1)            # (batch_size, 20)
-        return sampled_ids
+        return sampled_ids.squeeze()
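The -/+ pairs in the body fix a shape bug in the greedy decoding loop: the LSTM returns hiddens of shape (batch_size, 1, hidden_size), while nn.Linear of that era expected 2-D input, so the time dimension has to be squeezed out. The removed call would in fact raise a TypeError, since unsqueeze requires a dim argument. A standalone sketch of the shapes, using today's tensor API rather than Variable:

import torch
import torch.nn as nn

batch_size, hidden_size, vocab_size = 4, 512, 10000
linear = nn.Linear(hidden_size, vocab_size)
hiddens = torch.zeros(batch_size, 1, hidden_size)  # LSTM output for one step
outputs = linear(hiddens.squeeze(1))               # (batch_size, vocab_size)
print(outputs.shape)                               # torch.Size([4, 10000])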
tutorials/09 - Image Captioning/requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
+matplotlib==2.0.0
+nltk==3.2.2
+numpy==1.12.0
+Pillow==4.0.0
+argparse
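The pins match early-2017 releases; installing them is the usual

pip install -r "tutorials/09 - Image Captioning/requirements.txt"

Note that argparse has shipped with the standard library since Python 2.7/3.2, so that entry is redundant on any modern interpreter.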
@@ -1,58 +1,77 @@
-import os
-import numpy as np
-from vocab import Vocabulary
-from model import EncoderCNN, DecoderRNN
 from configuration import Config
-from PIL import Image
-from torch.autograd import Variable
 import torch
-import torchvision.transforms as T
-import pickle
+import matplotlib.pyplot as plt
+from PIL import Image
+from model import EncoderCNN, DecoderRNN
+from vocab import Vocabulary
+from torch.autograd import Variable
+import numpy as np
+import argparse
+import pickle
+import os

-# Image processing
-transform = T.Compose([
-    T.CenterCrop(224),
-    T.ToTensor(),
-    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
-
-# Hyper Parameters
-embed_size = 128
-hidden_size = 512
-num_layers = 1
+def main(params):
+    # Configuration for hyper-parameters
+    config = Config()
+
+    # Image Preprocessing
+    transform = config.test_transform

     # Load vocabulary
-with open('./data/vocab.pkl', 'rb') as f:
+    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
         vocab = pickle.load(f)

-# Load an image array
-images = os.listdir('./data/train2014resized/')
-image_path = './data/train2014resized/' + images[12]
-img = Image.open(image_path)
-image = transform(img).unsqueeze(0)
+    # Build Models
+    encoder = EncoderCNN(config.embed_size)
+    encoder.eval()  # evaluation mode (BN uses moving mean/variance)
+    decoder = DecoderRNN(config.embed_size, config.hidden_size,
+                         len(vocab), config.num_layers)

-# Load the trained models
-encoder = torch.load('./encoder.pkl')
-decoder = torch.load('./decoder.pkl')
-
-# Encode the image
-feature = encoder(Variable(image).cuda())
+    # Load the trained model parameters
+    encoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                    config.trained_encoder)))
+    decoder.load_state_dict(torch.load(os.path.join(config.model_path,
+                                                    config.trained_decoder)))
+
+    # Prepare Image
+    image = Image.open(params['image'])
+    image_tensor = Variable(transform(image).unsqueeze(0))

     # Set initial states
-state = (Variable(torch.zeros(num_layers, 1, hidden_size).cuda()),
-         Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())
+    state = (Variable(torch.zeros(config.num_layers, 1, config.hidden_size)),
+             Variable(torch.zeros(config.num_layers, 1, config.hidden_size)))

-# Decode the feature to caption
-ids = decoder.sample(feature, state)
+    # If use gpu
+    if torch.cuda.is_available():
+        encoder.cuda()
+        decoder.cuda()
+        state = [s.cuda() for s in state]
+        image_tensor = image_tensor.cuda()

-words = []
-for id in ids:
-    word = vocab.idx2word[id.data[0, 0]]
-    words.append(word)
+    # Generate caption from image
+    feature = encoder(image_tensor)
+    sampled_ids = decoder.sample(feature, state)
+    sampled_ids = sampled_ids.cpu().data.numpy()
+
+    # Decode word_ids to words
+    sampled_caption = []
+    for word_id in sampled_ids:
+        word = vocab.idx2word[word_id]
+        sampled_caption.append(word)
+        if word == '<end>':
+            break
-caption = ' '.join(words)
+    sentence = ' '.join(sampled_caption)

-# Display the image and generated caption
-plt.imshow(img)
-plt.show()
-print (caption)
+    # Print out image and generated caption.
+    print (sentence)
+    plt.imshow(np.asarray(image))
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--image', type=str, required=True,
+                        help='image for generating caption')
+    args = parser.parse_args()
+    params = vars(args)
+    main(params)
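The rewrite turns sample.py from a hard-coded script into a small CLI: models are rebuilt and restored via load_state_dict (more robust than torch.load on a whole pickled module, which breaks when the class definition moves), CUDA becomes optional, and the input comes from an --image flag. A typical invocation, with a hypothetical image path, would be

python sample.py --image ./data/example.png

One caveat: as committed, the script ends on plt.imshow without a plt.show(), so with a non-interactive matplotlib backend the figure never actually appears.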
@@ -6,7 +6,6 @@ from torch.autograd import Variable
 from torch.nn.utils.rnn import pack_padded_sequence
 import torch
 import torch.nn as nn
-import torchvision.transforms as T
 import numpy as np
 import pickle
 import os
@@ -16,13 +15,12 @@ def main():
     # Configuration for hyper-parameters
     config = Config()

     # Create model directory
     if not os.path.exists(config.model_path):
         os.makedirs(config.model_path)

     # Image preprocessing
-    transform = T.Compose([
-        T.Scale(config.image_size),    # no resize
-        T.RandomCrop(config.crop_size),
-        T.RandomHorizontalFlip(),
-        T.ToTensor(),
-        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    transform = config.train_transform

     # Load vocabulary wrapper
     with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
@@ -40,6 +38,8 @@ def main():
     encoder = EncoderCNN(config.embed_size)
     decoder = DecoderRNN(config.embed_size, config.hidden_size,
                          len(vocab), config.num_layers)
+
+    if torch.cuda.is_available():
+        encoder.cuda()
+        decoder.cuda()

@@ -53,10 +53,14 @@ def main():
     for i, (images, captions, lengths) in enumerate(train_loader):

         # Set mini-batch dataset
-        images = Variable(images).cuda()
-        captions = Variable(captions).cuda()
+        images = Variable(images)
+        captions = Variable(captions)
         targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
+
+        if torch.cuda.is_available():
+            images = images.cuda()
+            captions = captions.cuda()

         # Forward, Backward and Optimize
         decoder.zero_grad()
         encoder.zero_grad()
@@ -80,5 +84,6 @@ def main():
                 torch.save(encoder.state_dict(),
                            os.path.join(config.model_path,
                                         'encoder-%d-%d.pkl' %(epoch+1, i+1)))

 if __name__ == '__main__':
     main()
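One detail worth flagging in the training hunk: targets is built by pack_padded_sequence before the optional .cuda() moves, so on a GPU run it presumably gets moved to the device somewhere outside the lines shown, or the loss would mix CPU and GPU tensors. The [0] index takes the .data field of the returned PackedSequence: captions padded to (batch_size, max_len) are flattened time-major with pad positions dropped. A standalone sketch on the modern API (no Variable needed):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

captions = torch.tensor([[1, 2, 3, 4],    # length 4
                         [5, 6, 7, 0]])   # length 3 plus one pad
lengths = [4, 3]                          # sorted, longest first
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
print(targets)  # tensor([1, 5, 2, 6, 3, 7, 4]) -- pads gone, time-major order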