modified the code

yunjey
2017-04-13 19:49:10 +09:00
parent fd445c96b9
commit 0e59313b4b
5 changed files with 134 additions and 88 deletions


@@ -1,48 +1,46 @@
-from vocab import Vocabulary
-from model import EncoderCNN, DecoderRNN
-from configuration import Config
-from PIL import Image
-from torch.autograd import Variable
 import torch
-import torchvision.transforms as T
 import matplotlib.pyplot as plt
 import numpy as np
 import argparse
 import pickle
 import os
+from torch.autograd import Variable
+from torchvision import transforms
+from build_vocab import Vocabulary
+from model import EncoderCNN, DecoderRNN
+from PIL import Image
 
-def main(params):
-    # Configuration for hyper-parameters
-    config = Config()
+def main(args):
+    # Image preprocessing
+    transform = transforms.Compose([
+        transforms.Scale(args.crop_size),
+        transforms.CenterCrop(args.crop_size),
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
-    # Image Preprocessing
-    transform = config.test_transform
-
-    # Load vocabulary
-    with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
+    # Load vocabulary wrapper
+    with open(args.vocab_path, 'rb') as f:
         vocab = pickle.load(f)
 
     # Build Models
-    encoder = EncoderCNN(config.embed_size)
+    encoder = EncoderCNN(args.embed_size)
     encoder.eval()  # evaluation mode (BN uses moving mean/variance)
-    decoder = DecoderRNN(config.embed_size, config.hidden_size,
-                         len(vocab), config.num_layers)
+    decoder = DecoderRNN(args.embed_size, args.hidden_size,
+                         len(vocab), args.num_layers)
 
     # Load the trained model parameters
-    encoder.load_state_dict(torch.load(os.path.join(config.model_path,
-                                                    config.trained_encoder)))
-    decoder.load_state_dict(torch.load(os.path.join(config.model_path,
-                                                    config.trained_decoder)))
+    encoder.load_state_dict(torch.load(args.encoder_path))
+    decoder.load_state_dict(torch.load(args.decoder_path))
 
     # Prepare Image
-    image = Image.open(params['image'])
+    image = Image.open(args.image)
     image_tensor = Variable(transform(image).unsqueeze(0))
 
     # Set initial states
-    state = (Variable(torch.zeros(config.num_layers, 1, config.hidden_size)),
-             Variable(torch.zeros(config.num_layers, 1, config.hidden_size)))
+    state = (Variable(torch.zeros(args.num_layers, 1, args.hidden_size)),
+             Variable(torch.zeros(args.num_layers, 1, args.hidden_size)))
 
     # If use gpu
     if torch.cuda.is_available():
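The middle of the file is unchanged and therefore elided between the two hunks; that is where the caption itself is generated from the objects set up above. A minimal sketch of how that step typically fits together — decoder.sample() and vocab.idx2word are assumed interfaces here, not shown anywhere in this diff:

    # Sketch only (assumed API): greedily decode a caption for the image.
    feature = encoder(image_tensor)               # encode image to a feature vector
    sampled_ids = decoder.sample(feature, state)  # sequence of predicted word ids
    sampled_ids = sampled_ids.cpu().data.numpy()  # pull ids out of the Variable

    # Map word ids back to words, stopping at the end token.
    caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        caption.append(word)
        if word == '<end>':
            break
    print(' '.join(caption))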
@@ -71,7 +69,23 @@ def main(params):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--image', type=str, required=True, help='image for generating caption')
+    parser.add_argument('--image', type=str, required=True,
+                        help='input image for generating caption')
+    parser.add_argument('--encoder_path', type=str, default='./models/encoder-5-3000.pkl',
+                        help='path for trained encoder')
+    parser.add_argument('--decoder_path', type=str, default='./models/decoder-5-3000.pkl',
+                        help='path for trained decoder')
+    parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
+                        help='path for vocabulary wrapper')
+    parser.add_argument('--crop_size', type=int, default=224,
+                        help='size for center cropping images')
+
+    # Model parameters (should be same as parameters in train.py)
+    parser.add_argument('--embed_size', type=int, default=256,
+                        help='dimension of word embedding vectors')
+    parser.add_argument('--hidden_size', type=int, default=512,
+                        help='dimension of lstm hidden states')
+    parser.add_argument('--num_layers', type=int, default=1,
+                        help='number of layers in lstm')
     args = parser.parse_args()
-    params = vars(args)
-    main(params)
+    main(args)
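With the hard-coded Config object gone, the script is driven entirely by command-line flags, and only --image is required; as the comment notes, the model-size flags must match the values used in train.py. A typical invocation might look like the following (both the script name sample.py and the image path are placeholders, since the page does not show the file name):

    python sample.py --image ./data/example.png

The checkpoint defaults (encoder-5-3000.pkl, decoder-5-3000.pkl) appear to encode the epoch and step at which they were saved, so --encoder_path and --decoder_path should point at whatever train.py actually wrote out.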