diff --git a/tutorials/03-advanced/image_captioning/README.md b/tutorials/03-advanced/image_captioning/README.md
index 7a658a2..dddcca2 100644
--- a/tutorials/03-advanced/image_captioning/README.md
+++ b/tutorials/03-advanced/image_captioning/README.md
@@ -56,4 +56,4 @@ $ python sample.py --image='png/example.png'
 ## Pretrained model
-If you do not want to train the model from scratch, you can use a pretrained model. I have provided the pretrained model as a zip file. You can download the file [here](https://www.dropbox.com/s/bmo30z81a4v7m0r/pretrained_model.zip?dl=0) and extract it to `./models/` directory using `unzip pretrained_model.zip`.
+If you do not want to train the model from scratch, you can use a pretrained model. I have provided the pretrained model as a zip file. You can download the pretrained model [here](https://www.dropbox.com/s/ne0ixz5d58ccbbz/pretrained_model.zip?dl=0) and the vocabulary file [here](https://www.dropbox.com/s/26adb7y9m98uisa/vocap.zip?dl=0). Note that you should extract `pretrained_model.zip` to `./models/` and `vocab.pkl` to `./data/`.
diff --git a/tutorials/03-advanced/image_captioning/build_vocab.py b/tutorials/03-advanced/image_captioning/build_vocab.py
index 612920a..883fc69 100644
--- a/tutorials/03-advanced/image_captioning/build_vocab.py
+++ b/tutorials/03-advanced/image_captioning/build_vocab.py
@@ -59,7 +59,7 @@ def main(args):
                         threshold=args.threshold)
     vocab_path = args.vocab_path
     with open(vocab_path, 'wb') as f:
-        pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
+        pickle.dump(vocab, f)
     print("Total vocabulary size: %d" %len(vocab))
     print("Saved the vocabulary wrapper to '%s'" %vocab_path)

@@ -67,7 +67,7 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--caption_path', type=str,
-                        default='./data/annotations/captions_train2014.json',
+                        default='/usr/share/mscoco/annotations/captions_train2014.json',
                         help='path for train annotation file')
     parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
                         help='path for saving vocabulary wrapper')
diff --git a/tutorials/03-advanced/image_captioning/sample.py b/tutorials/03-advanced/image_captioning/sample.py
index acf6271..ce1a999 100644
--- a/tutorials/03-advanced/image_captioning/sample.py
+++ b/tutorials/03-advanced/image_captioning/sample.py
@@ -47,7 +47,7 @@ def main(args):
     encoder.load_state_dict(torch.load(args.encoder_path))
     decoder.load_state_dict(torch.load(args.decoder_path))

-    # Prepare Image
+    # Prepare Image
     image = load_image(args.image, transform)
     image_tensor = to_var(image, volatile=True)

@@ -72,6 +72,7 @@ def main(args):
     # Print out image and generated caption.
     print (sentence)
+    image = Image.open(args.image)
     plt.imshow(np.asarray(image))

 if __name__ == '__main__':
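
For reference, a minimal sketch of the download-and-extract step the updated README describes, assuming `wget` and `unzip` are available. The `?dl=1` query parameter and the `-d` target directories are assumptions for illustration, not part of the patch:

```bash
# Fetch the pretrained weights and the vocabulary archive, then unpack them
# into the directories sample.py expects (./models/ and ./data/).
# The ?dl=1 suffix is assumed to make Dropbox serve the raw file directly.
wget -O pretrained_model.zip "https://www.dropbox.com/s/ne0ixz5d58ccbbz/pretrained_model.zip?dl=1"
wget -O vocap.zip "https://www.dropbox.com/s/26adb7y9m98uisa/vocap.zip?dl=1"
unzip pretrained_model.zip -d ./models/   # encoder/decoder weights -> ./models/
unzip vocap.zip -d ./data/                # vocab.pkl -> ./data/
```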