mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-26 19:48:34 +08:00
image captioning completed'
This commit is contained in:
@ -6,7 +6,6 @@ from torch.autograd import Variable
|
||||
from torch.nn.utils.rnn import pack_padded_sequence
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import pickle
|
||||
import os
|
||||
@ -16,14 +15,13 @@ def main():
|
||||
# Configuration for hyper-parameters
|
||||
config = Config()
|
||||
|
||||
# Create model directory
|
||||
if not os.path.exists(config.model_path):
|
||||
os.makedirs(config.model_path)
|
||||
|
||||
# Image preprocessing
|
||||
transform = T.Compose([
|
||||
T.Scale(config.image_size), # no resize
|
||||
T.RandomCrop(config.crop_size),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.ToTensor(),
|
||||
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
transform = config.train_transform
|
||||
|
||||
# Load vocabulary wrapper
|
||||
with open(os.path.join(config.vocab_path, 'vocab.pkl'), 'rb') as f:
|
||||
vocab = pickle.load(f)
|
||||
@ -40,22 +38,28 @@ def main():
|
||||
encoder = EncoderCNN(config.embed_size)
|
||||
decoder = DecoderRNN(config.embed_size, config.hidden_size,
|
||||
len(vocab), config.num_layers)
|
||||
encoder.cuda()
|
||||
decoder.cuda()
|
||||
|
||||
if torch.cuda.is_available()
|
||||
encoder.cuda()
|
||||
decoder.cuda()
|
||||
|
||||
# Loss and Optimizer
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
|
||||
optimizer = torch.optim.Adam(params, lr=config.learning_rate)
|
||||
|
||||
|
||||
# Train the Models
|
||||
for epoch in range(config.num_epochs):
|
||||
for i, (images, captions, lengths) in enumerate(train_loader):
|
||||
|
||||
# Set mini-batch dataset
|
||||
images = Variable(images).cuda()
|
||||
captions = Variable(captions).cuda()
|
||||
images = Variable(images)
|
||||
captions = Variable(captions)
|
||||
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
|
||||
|
||||
if torch.cuda.is_available():
|
||||
images = images.cuda()
|
||||
captions = captions.cuda()
|
||||
|
||||
# Forward, Backward and Optimize
|
||||
decoder.zero_grad()
|
||||
@ -80,5 +84,6 @@ def main():
|
||||
torch.save(encoder.state_dict(),
|
||||
os.path.join(config.model_path,
|
||||
'encoder-%d-%d.pkl' %(epoch+1, i+1)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Reference in New Issue
Block a user