mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-26 19:48:34 +08:00
modified the code
This commit is contained in:
@ -1,48 +1,46 @@
|
||||
from vocab import Vocabulary
|
||||
from model import EncoderCNN, DecoderRNN
|
||||
from configuration import Config
|
||||
from PIL import Image
|
||||
from torch.autograd import Variable
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import argparse
|
||||
import pickle
|
||||
import os
|
||||
from torch.autograd import Variable
|
||||
from torchvision import transforms
|
||||
from build_vocab import Vocabulary
|
||||
from model import EncoderCNN, DecoderRNN
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def main(params):
|
||||
# Configuration for hyper-parameters
|
||||
config = Config()
|
||||
def main(args):
|
||||
# Image preprocessing
|
||||
transform = transforms.Compose([
|
||||
transforms.Scale(args.crop_size),
|
||||
transforms.CenterCrop(args.crop_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# Image Preprocessing
|
||||
transform = config.test_transform
|
||||
|
||||
# Load vocabulary
|
||||
with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
|
||||
# Load vocabulary wrapper
|
||||
with open(args.vocab_path, 'rb') as f:
|
||||
vocab = pickle.load(f)
|
||||
|
||||
# Build Models
|
||||
encoder = EncoderCNN(config.embed_size)
|
||||
encoder = EncoderCNN(args.embed_size)
|
||||
encoder.eval() # evaluation mode (BN uses moving mean/variance)
|
||||
decoder = DecoderRNN(config.embed_size, config.hidden_size,
|
||||
len(vocab), config.num_layers)
|
||||
decoder = DecoderRNN(args.embed_size, args.hidden_size,
|
||||
len(vocab), args.num_layers)
|
||||
|
||||
|
||||
# Load the trained model parameters
|
||||
encoder.load_state_dict(torch.load(os.path.join(config.model_path,
|
||||
config.trained_encoder)))
|
||||
decoder.load_state_dict(torch.load(os.path.join(config.model_path,
|
||||
config.trained_decoder)))
|
||||
encoder.load_state_dict(torch.load(args.encoder_path))
|
||||
decoder.load_state_dict(torch.load(args.decoder_path))
|
||||
|
||||
# Prepare Image
|
||||
image = Image.open(params['image'])
|
||||
image = Image.open(args.image)
|
||||
image_tensor = Variable(transform(image).unsqueeze(0))
|
||||
|
||||
# Set initial states
|
||||
state = (Variable(torch.zeros(config.num_layers, 1, config.hidden_size)),
|
||||
Variable(torch.zeros(config.num_layers, 1, config.hidden_size)))
|
||||
state = (Variable(torch.zeros(args.num_layers, 1, args.hidden_size)),
|
||||
Variable(torch.zeros(args.num_layers, 1, args.hidden_size)))
|
||||
|
||||
# If use gpu
|
||||
if torch.cuda.is_available():
|
||||
@ -71,7 +69,23 @@ def main(params):
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--image', type=str, required=True, help='image for generating caption')
|
||||
parser.add_argument('--image', type=str, required=True,
|
||||
help='input image for generating caption')
|
||||
parser.add_argument('--encoder_path', type=str, default='./models/encoder-5-3000.pkl',
|
||||
help='path for trained encoder')
|
||||
parser.add_argument('--decoder_path', type=str, default='./models/decoder-5-3000.pkl',
|
||||
help='path for trained decoder')
|
||||
parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
|
||||
help='path for vocabulary wrapper')
|
||||
parser.add_argument('--crop_size', type=int, default=224,
|
||||
help='size for center cropping images')
|
||||
|
||||
# Model parameters (should be same as paramters in train.py)
|
||||
parser.add_argument('--embed_size', type=int , default=256,
|
||||
help='dimension of word embedding vectors')
|
||||
parser.add_argument('--hidden_size', type=int , default=512,
|
||||
help='dimension of lstm hidden states')
|
||||
parser.add_argument('--num_layers', type=int , default=1 ,
|
||||
help='number of layers in lstm')
|
||||
args = parser.parse_args()
|
||||
params = vars(args)
|
||||
main(params)
|
||||
main(args)
|
Reference in New Issue
Block a user