Mirror of https://github.com/yunjey/pytorch-tutorial.git (synced 2025-07-27 20:13:33 +08:00)

sample.py cannot generate a caption for a grayscale image. In data_loader.py every image is converted to 'RGB' format, but sample.py skips this conversion when loading a single image, so a grayscale input fails with: RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 1, 224, 224] to have 3 channels, but got 1 channels instead. The fix is to call .convert('RGB') in load_image, as the code below does.
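A minimal sketch of the mismatch (the file name gray.png is hypothetical): a single-channel ('L' mode) image becomes a [1, H, W] tensor, which the encoder's first convolution (weight [64, 3, 7, 7]) rejects, while .convert('RGB') yields the expected 3-channel input.

from PIL import Image
from torchvision import transforms

to_tensor = transforms.ToTensor()

gray = Image.open('gray.png')                    # mode 'L' (single channel)
print(to_tensor(gray).shape)                     # torch.Size([1, H, W]) -> triggers the RuntimeError above

rgb = Image.open('gray.png').convert('RGB')      # replicate the channel to three channels
print(to_tensor(rgb).shape)                      # torch.Size([3, H, W]) -> matches the encoder's conv weight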
82 lines · 3.0 KiB · Python
import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
import os
from torchvision import transforms
from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
from PIL import Image


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)

    if transform is not None:
        image = transform(image).unsqueeze(0)

    return image


def main(args):
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    encoder = EncoderCNN(args.embed_size).eval()  # eval mode (batchnorm uses moving mean/variance)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained model parameters
    encoder.load_state_dict(torch.load(args.encoder_path))
    decoder.load_state_dict(torch.load(args.decoder_path))

    # Prepare an image
    image = load_image(args.image, transform)
    image_tensor = image.to(device)

    # Generate a caption from the image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)

    # Convert word_ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    # Print out the image and the generated caption
    print(sentence)
    image = Image.open(args.image)
    plt.imshow(np.asarray(image))
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=True, help='input image for generating caption')
    parser.add_argument('--encoder_path', type=str, default='models/encoder-2-1000.ckpt', help='path for trained encoder')
    parser.add_argument('--decoder_path', type=str, default='models/decoder-2-1000.ckpt', help='path for trained decoder')
    parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')

    # Model parameters (should be same as parameters in train.py)
    parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
    parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
    parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')
    args = parser.parse_args()
    main(args)
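For reference, a typical invocation (the image path is only a placeholder; the encoder/decoder checkpoints and vocabulary fall back to the default paths defined above):

python sample.py --image='png/example.png'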