Merge pull request #30 from keishinkickback/master

Bug fixes in /tutorials/09 - Image Captioning/train.py
This commit is contained in:
yunjey
2017-05-06 22:55:47 +09:00
committed by GitHub

View File

@@ -3,6 +3,7 @@ import torch
import torch.nn as nn
import numpy as np
import os
import pickle
from data_loader import get_loader
from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
@@ -24,7 +25,7 @@ def main(args):
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Load vocabulary wrapper.
with open(vocab_path, 'rb') as f:
with open(args.vocab_path, 'rb') as f:
vocab = pickle.load(f)
# Build data loader
@@ -115,4 +116,3 @@ if __name__ == '__main__':
parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args()
print(args)
main(args)