mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-27 12:03:31 +08:00
Bug fix
import pickle and add args. to vocab_path
This commit is contained in:
@ -3,6 +3,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
from data_loader import get_loader
|
||||
from build_vocab import Vocabulary
|
||||
from model import EncoderCNN, DecoderRNN
|
||||
@ -24,7 +25,7 @@ def main(args):
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# Load vocabulary wrapper.
|
||||
with open(vocab_path, 'rb') as f:
|
||||
with open(args.vocab_path, 'rb') as f:
|
||||
vocab = pickle.load(f)
|
||||
|
||||
# Build data loader
|
||||
@ -115,4 +116,3 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--learning_rate', type=float, default=0.001)
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
main(args)
|
Reference in New Issue
Block a user