mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-06 09:23:21 +08:00
image caption generation added
This commit is contained in:
52
tutorials/09 - Image Captioning/sample.py
Normal file
52
tutorials/09 - Image Captioning/sample.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import pickle
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
from model import EncoderCNN, DecoderRNN
|
||||||
|
from vocab import Vocabulary
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
# Hyper Parameters
|
||||||
|
embed_size = 128
|
||||||
|
hidden_size = 512
|
||||||
|
num_layers = 1
|
||||||
|
|
||||||
|
# Load vocabulary
|
||||||
|
with open('./data/vocab.pkl', 'rb') as f:
|
||||||
|
vocab = pickle.load(f)
|
||||||
|
|
||||||
|
# Load an image array
|
||||||
|
images = os.listdir('./data/val2014resized/')
|
||||||
|
image_path = './data/val2014resized/' + images[12]
|
||||||
|
with open(image_path, 'r+b') as f:
|
||||||
|
img = np.asarray(Image.open(f))
|
||||||
|
image = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255 - 0.5
|
||||||
|
|
||||||
|
# Load the trained models
|
||||||
|
encoder = torch.load('./encoder.pkl')
|
||||||
|
decoder = torch.load('./decoder.pkl')
|
||||||
|
|
||||||
|
# Encode the image
|
||||||
|
feature = encoder(Variable(image).cuda())
|
||||||
|
|
||||||
|
# Set initial states
|
||||||
|
state = (Variable(torch.zeros(num_layers, 1, hidden_size).cuda()),
|
||||||
|
Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())
|
||||||
|
|
||||||
|
# Decode the feature to caption
|
||||||
|
ids = decoder.sample(feature, state)
|
||||||
|
|
||||||
|
words = []
|
||||||
|
for id in ids:
|
||||||
|
word = vocab.idx2word[id.data[0, 0]]
|
||||||
|
words.append(word)
|
||||||
|
if word == '<end>':
|
||||||
|
break
|
||||||
|
caption = ' '.join(words)
|
||||||
|
|
||||||
|
# Display the image and generated caption
|
||||||
|
plt.imshow(img)
|
||||||
|
plt.show()
|
||||||
|
print (caption)
|
Reference in New Issue
Block a user