diff --git a/tutorials/09 - Image Captioning/model.py b/tutorials/09 - Image Captioning/model.py
index 448c201..fcef106 100644
--- a/tutorials/09 - Image Captioning/model.py
+++ b/tutorials/09 - Image Captioning/model.py
@@ -10,9 +10,15 @@ class EncoderCNN(nn.Module):
         """Load pretrained ResNet-152 and replace top fc layer."""
         super(EncoderCNN, self).__init__()
         self.resnet = models.resnet152(pretrained=True)
-        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
+        # For efficient memory usage.
         for param in self.resnet.parameters():
             param.requires_grad = False
+        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
+        self.init_weights()
+
+    def init_weights(self):
+        self.resnet.fc.weight.data.uniform_(-0.1, 0.1)
+        self.resnet.fc.bias.data.fill_(0)
 
     def forward(self, images):
         """Extract image feature vectors."""
@@ -30,6 +36,11 @@ class DecoderRNN(nn.Module):
         self.embed = nn.Embedding(vocab_size, embed_size)
         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
         self.linear = nn.Linear(hidden_size, vocab_size)
+
+    def init_weights(self):
+        self.embed.weight.data.uniform_(-0.1, 0.1)
+        self.linear.weight.data.uniform_(-0.1, 0.1)
+        self.linear.bias.data.fill_(0)
 
     def forward(self, features, captions, lengths):
         """Decode image feature vectors and generate caption."""
diff --git a/tutorials/09 - Image Captioning/sample.py b/tutorials/09 - Image Captioning/sample.py
index 2886da5..a444d08 100644
--- a/tutorials/09 - Image Captioning/sample.py
+++ b/tutorials/09 - Image Captioning/sample.py
@@ -1,6 +1,7 @@
 import os
 import numpy as np
 import torch
+import torchvision.transforms as T
 import pickle
 import matplotlib.pyplot as plt
 from PIL import Image
@@ -8,6 +9,12 @@ from model import EncoderCNN, DecoderRNN
 from vocab import Vocabulary
 from torch.autograd import Variable
 
+# Image processing
+transform = T.Compose([
+    T.CenterCrop(224),
+    T.ToTensor(),
+    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+
 # Hyper Parameters
 embed_size = 128
 hidden_size = 512
@@ -18,11 +25,10 @@ with open('./data/vocab.pkl', 'rb') as f:
     vocab = pickle.load(f)
 
 # Load an image array
-images = os.listdir('./data/val2014resized/')
-image_path = './data/val2014resized/' + images[12]
-with open(image_path, 'r+b') as f:
-    img = np.asarray(Image.open(f))
-image = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255 - 0.5
+images = os.listdir('./data/train2014resized/')
+image_path = './data/train2014resized/' + images[12]
+img = Image.open(image_path)
+image = transform(img).unsqueeze(0)
 
 # Load the trained models
 encoder = torch.load('./encoder.pkl')
diff --git a/tutorials/09 - Image Captioning/train.py b/tutorials/09 - Image Captioning/train.py
index 7998616..d97976d 100644
--- a/tutorials/09 - Image Captioning/train.py
+++ b/tutorials/09 - Image Captioning/train.py
@@ -1,6 +1,6 @@
 from data import get_loader
 from vocab import Vocabulary
-from models import EncoderCNN, DecoderRNN
+from model import EncoderCNN, DecoderRNN
 from torch.autograd import Variable
 from torch.nn.utils.rnn import pack_padded_sequence
 import torch
@@ -10,10 +10,11 @@ import torchvision.transforms as T
 import pickle
 
 # Hyper Parameters
-num_epochs = 5
-batch_size = 100
-embed_size = 128
+num_epochs = 1
+batch_size = 32
+embed_size = 256
 hidden_size = 512
+crop_size = 224
 num_layers = 1
 learning_rate = 0.001
 train_image_path = './data/train2014resized/'
@@ -21,6 +22,7 @@ train_json_path = './data/annotations/captions_train2014.json'
 
 # Image Preprocessing
 transform = T.Compose([
+    T.RandomCrop(crop_size),
     T.RandomHorizontalFlip(),
     T.ToTensor(),
     T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
@@ -42,7 +44,8 @@ decoder.cuda()
 
 # Loss and Optimizer
 criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
+params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
+optimizer = torch.optim.Adam(params, lr=learning_rate)
 
 # Train the Decoder
 for epoch in range(num_epochs):
@@ -63,7 +66,7 @@ for epoch in range(num_epochs):
         if i % 100 == 0:
             print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                   %(epoch, num_epochs, i, total_step, loss.data[0], np.exp(loss.data[0])))
-
+
 # Save the Model
 torch.save(decoder, 'decoder.pkl')
 torch.save(encoder, 'encoder.pkl')
\ No newline at end of file
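
Note on the model.py hunk: the order of operations is the substance of the fix. Previously the new fc layer was created before the freeze loop, so `param.requires_grad = False` froze it along with the pretrained weights; moving the replacement below the loop leaves only the fresh fc layer trainable. A minimal standalone sketch of that setup (embed_size here is assumed to match the new value in train.py):

import torch.nn as nn
import torchvision.models as models

embed_size = 256  # assumed to match train.py

# Freeze every pretrained weight first, then swap in a fresh fc layer;
# a newly constructed nn.Linear requires gradients by default.
resnet = models.resnet152(pretrained=True)
for param in resnet.parameters():
    param.requires_grad = False
resnet.fc = nn.Linear(resnet.fc.in_features, embed_size)

# Only the new fc weight and bias should remain trainable.
print([name for name, p in resnet.named_parameters() if p.requires_grad])
# expected: ['fc.weight', 'fc.bias']

The same reasoning explains the optimizer change in train.py: Adam is now built from list(decoder.parameters()) + list(encoder.resnet.fc.parameters()), so it only receives the parameters that are actually meant to be updated.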
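
The sample.py hunk replaces the hand-rolled numpy normalization (/ 255 - 0.5, which mapped pixels to [-0.5, 0.5]) with the same torchvision Normalize used in training (which maps to roughly [-1, 1]), so train- and test-time inputs now share the same statistics while dropping the random augmentations. A sketch of the deterministic test-time pipeline; 'example.jpg' is a placeholder path:

import torchvision.transforms as T
from PIL import Image

# Deterministic pipeline: CenterCrop stands in for the
# RandomCrop + RandomHorizontalFlip used during training.
transform = T.Compose([
    T.CenterCrop(224),
    T.ToTensor(),  # PIL image (HWC uint8) -> tensor (CHW float in [0, 1])
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])  # -> roughly [-1, 1]

img = Image.open('example.jpg').convert('RGB')  # placeholder image path
image = transform(img).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
print(image.size())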
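
Finally, the perplexity reported in the train.py loop is just the exponential of the per-token cross-entropy loss: a loss of 3.0, for example, corresponds to a perplexity of exp(3.0) ≈ 20.1, i.e. the model is about as uncertain as a uniform choice among 20 words.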