image captioning completed'

This commit is contained in:
yunjey
2017-03-21 20:01:47 +09:00
parent ba7d5467be
commit 6f5fda14f0
7 changed files with 297 additions and 68 deletions

View File

@ -6,7 +6,6 @@ from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
import torch
import torch.nn as nn
import torchvision.transforms as T
import numpy as np
import pickle
import os
@ -16,14 +15,13 @@ def main():
# Configuration for hyper-parameters
config = Config()
# Create model directory
if not os.path.exists(config.model_path):
os.makedirs(config.model_path)
# Image preprocessing
transform = T.Compose([
T.Scale(config.image_size), # no resize
T.RandomCrop(config.crop_size),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transform = config.train_transform
# Load vocabulary wrapper
with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
vocab = pickle.load(f)
@ -40,22 +38,28 @@ def main():
encoder = EncoderCNN(config.embed_size)
decoder = DecoderRNN(config.embed_size, config.hidden_size,
len(vocab), config.num_layers)
encoder.cuda()
decoder.cuda()
if torch.cuda.is_available()
encoder.cuda()
decoder.cuda()
# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
optimizer = torch.optim.Adam(params, lr=config.learning_rate)
# Train the Models
for epoch in range(config.num_epochs):
for i, (images, captions, lengths) in enumerate(train_loader):
# Set mini-batch dataset
images = Variable(images).cuda()
captions = Variable(captions).cuda()
images = Variable(images)
captions = Variable(captions)
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
if torch.cuda.is_available():
images = images.cuda()
captions = captions.cuda()
# Forward, Backward and Optimize
decoder.zero_grad()
@ -80,5 +84,6 @@ def main():
torch.save(encoder.state_dict(),
os.path.join(config.model_path,
'encoder-%d-%d.pkl' %(epoch+1, i+1)))
if __name__ == '__main__':
main()