mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-05 16:36:44 +08:00
image caption generation added
This commit is contained in:
52
tutorials/09 - Image Captioning/sample.py
Normal file
52
tutorials/09 - Image Captioning/sample.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""Sample script: generate a caption for one validation image.

Loads a trained EncoderCNN/DecoderRNN pair from disk, encodes a single
resized COCO val2014 image, greedily decodes a word-id sequence, and
displays the image together with the generated caption.

Side effects: reads ./data/vocab.pkl, ./data/val2014resized/, ./encoder.pkl,
./decoder.pkl; requires a CUDA device; opens a matplotlib window.
"""
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable

from model import EncoderCNN, DecoderRNN
from vocab import Vocabulary


# Hyper Parameters (must match the values the checkpoints were trained with)
embed_size = 128
hidden_size = 512
num_layers = 1

# Load vocabulary (maps word ids <-> word strings)
with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

# Load an image array.
# NOTE: os.listdir order is arbitrary; index 12 picks an effectively
# random image from the resized validation set.
images = os.listdir('./data/val2014resized/')
image_path = './data/val2014resized/' + images[12]
# 'rb' is sufficient — the image is only read, never written.
with open(image_path, 'rb') as f:
    img = np.asarray(Image.open(f))
# HWC uint8 -> 1xCxHxW float, scaled from [0, 255] to [-0.5, 0.5]
image = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255 - 0.5

# Load the trained models.
# NOTE(review): torch.load unpickles arbitrary code — only load
# checkpoints from a trusted source.
encoder = torch.load('./encoder.pkl')
decoder = torch.load('./decoder.pkl')

# Encode the image into a feature vector
feature = encoder(Variable(image).cuda())

# Set initial LSTM states (h0, c0), batch size 1; build both the same way
state = (Variable(torch.zeros(num_layers, 1, hidden_size)).cuda(),
         Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())

# Decode the feature to a sequence of word ids
ids = decoder.sample(feature, state)

# Translate word ids to words, stopping at the end-of-sentence token
words = []
for word_id in ids:  # renamed from `id`, which shadows the builtin
    word = vocab.idx2word[word_id.data[0, 0]]
    words.append(word)
    if word == '<end>':
        break
caption = ' '.join(words)

# Display the image and the generated caption
plt.imshow(img)
plt.show()
print(caption)
|
Reference in New Issue
Block a user