Mirror of https://github.com/yunjey/pytorch-tutorial.git
edit image captioning code
@@ -56,4 +56,4 @@ $ python sample.py --image='png/example.png'
 <br>
 
 ## Pretrained model
-If you do not want to train the model from scratch, you can use a pretrained model. I have provided the pretrained model as a zip file. You can download the file [here](https://www.dropbox.com/s/bmo30z81a4v7m0r/pretrained_model.zip?dl=0) and extract it to `./models/` directory using `unzip pretrained_model.zip`.
+If you do not want to train the model from scratch, you can use a pretrained model. I have provided the pretrained model as a zip file. You can download the pretrained model [here](https://www.dropbox.com/s/ne0ixz5d58ccbbz/pretrained_model.zip?dl=0) and vocabulary file [here](https://www.dropbox.com/s/26adb7y9m98uisa/vocap.zip?dl=0). Note that you should extract pretrained_model.zip to `./models/` and vocab.pkl to `./data/`.
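The new README text says the vocabulary file should end up at `./data/vocab.pkl`. A minimal sketch of reading that pickled wrapper back, assuming the `Vocabulary` class is defined in `build_vocab.py` and is importable (unpickling needs the original class definition on the path):

```python
import pickle

# Unpickling needs the class that was pickled; assumption: the wrapper
# class Vocabulary lives in build_vocab.py next to this script.
from build_vocab import Vocabulary  # noqa: F401

with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

# build_vocab.py in the diff below prints len(vocab), so the wrapper supports len().
print("Loaded vocabulary of size %d" % len(vocab))
```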
@@ -59,7 +59,7 @@ def main(args):
                         threshold=args.threshold)
     vocab_path = args.vocab_path
     with open(vocab_path, 'wb') as f:
-        pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
+        pickle.dump(vocab, f)
     print("Total vocabulary size: %d" %len(vocab))
     print("Saved the vocabulary wrapper to '%s'" %vocab_path)
 
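The change above only drops the explicit protocol argument, so `pickle` falls back to its default protocol; the file is still readable with a plain `pickle.load`. A small self-contained comparison of the two calls, using a stand-in dict in place of the real vocabulary wrapper:

```python
import pickle

vocab = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3}  # stand-in object

# Old line: pin the newest binary pickle protocol available.
with open('vocab_old.pkl', 'wb') as f:
    pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)

# New line: use pickle's default protocol.
with open('vocab_new.pkl', 'wb') as f:
    pickle.dump(vocab, f)

# Either file loads back the same object; load() detects the protocol itself.
with open('vocab_old.pkl', 'rb') as f:
    assert pickle.load(f) == vocab
with open('vocab_new.pkl', 'rb') as f:
    assert pickle.load(f) == vocab
```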
@@ -67,7 +67,7 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--caption_path', type=str,
-                        default='./data/annotations/captions_train2014.json',
+                        default='/usr/share/mscoco/annotations/captions_train2014.json',
                         help='path for train annotation file')
     parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
                         help='path for saving vocabulary wrapper')
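The hunk above only swaps the default value of `--caption_path`; both paths can still be overridden at run time. A self-contained argparse sketch of the same two options, with the defaults taken from the diff:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--caption_path', type=str,
                    default='/usr/share/mscoco/annotations/captions_train2014.json',
                    help='path for train annotation file')
parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
                    help='path for saving vocabulary wrapper')

# parse_args([]) keeps the defaults; passing e.g. ['--caption_path', 'other.json']
# would override them, just like flags given on the command line.
args = parser.parse_args([])
print(args.caption_path)
print(args.vocab_path)
```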
@@ -72,6 +72,7 @@ def main(args):
 
     # Print out image and generated caption.
     print (sentence)
+    image = Image.open(args.image)
     plt.imshow(np.asarray(image))
 
 if __name__ == '__main__':
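The added line loads the picture from `args.image` with PIL right before it is handed to matplotlib. A stripped-down sketch of just that display step; the image path comes from the README example above, and the caption string is a placeholder:

```python
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

image_path = 'png/example.png'  # example path from the README
sentence = '<start> a placeholder caption . <end>'  # stand-in for the generated caption

# Print out image and generated caption.
print(sentence)
image = Image.open(image_path)   # the line added in the hunk above
plt.imshow(np.asarray(image))    # convert the PIL image to an ndarray for imshow
plt.show()
```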