import pickle and add args. to vocab_path
This commit is contained in:
keishinkickback
2017-05-05 16:46:15 -04:00
parent 25ce8300ce
commit b2c7bd7bfc

View File

@ -3,6 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import os import os
import pickle
from data_loader import get_loader from data_loader import get_loader
from build_vocab import Vocabulary from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN from model import EncoderCNN, DecoderRNN
@ -24,7 +25,7 @@ def main(args):
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Load vocabulary wrapper. # Load vocabulary wrapper.
with open(vocab_path, 'rb') as f: with open(args.vocab_path, 'rb') as f:
vocab = pickle.load(f) vocab = pickle.load(f)
# Build data loader # Build data loader
@ -115,4 +116,3 @@ if __name__ == '__main__':
parser.add_argument('--learning_rate', type=float, default=0.001) parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(args)