diff --git a/tutorials/09 - Image Captioning/train.py b/tutorials/09 - Image Captioning/train.py index 05578f5..fd8855d 100644 --- a/tutorials/09 - Image Captioning/train.py +++ b/tutorials/09 - Image Captioning/train.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import numpy as np import os +import pickle from data_loader import get_loader from build_vocab import Vocabulary from model import EncoderCNN, DecoderRNN @@ -24,7 +25,7 @@ def main(args): transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # Load vocabulary wrapper. - with open(vocab_path, 'rb') as f: + with open(args.vocab_path, 'rb') as f: vocab = pickle.load(f) # Build data loader @@ -115,4 +116,3 @@ if __name__ == '__main__': parser.add_argument('--learning_rate', type=float, default=0.001) args = parser.parse_args() print(args) - main(args) \ No newline at end of file