mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-27 12:03:31 +08:00
Bug fix
import pickle and add args. to vocab_path
This commit is contained in:
@ -3,6 +3,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
from data_loader import get_loader
|
from data_loader import get_loader
|
||||||
from build_vocab import Vocabulary
|
from build_vocab import Vocabulary
|
||||||
from model import EncoderCNN, DecoderRNN
|
from model import EncoderCNN, DecoderRNN
|
||||||
@ -24,7 +25,7 @@ def main(args):
|
|||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
# Load vocabulary wrapper.
|
# Load vocabulary wrapper.
|
||||||
with open(vocab_path, 'rb') as f:
|
with open(args.vocab_path, 'rb') as f:
|
||||||
vocab = pickle.load(f)
|
vocab = pickle.load(f)
|
||||||
|
|
||||||
# Build data loader
|
# Build data loader
|
||||||
@ -115,4 +116,3 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--learning_rate', type=float, default=0.001)
|
parser.add_argument('--learning_rate', type=float, default=0.001)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
main(args)
|
|
Reference in New Issue
Block a user