mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-26 19:48:34 +08:00
image captioning completed'
This commit is contained in:
@ -1,3 +1,6 @@
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class Config(object):
|
||||
"""Wrapper class for hyper-parameters."""
|
||||
def __init__(self):
|
||||
@ -8,6 +11,21 @@ class Config(object):
|
||||
self.word_count_threshold = 4
|
||||
self.num_threads = 2
|
||||
|
||||
# Image preprocessing in training phase
|
||||
self.train_transform = T.Compose([
|
||||
T.Scale(self.image_size),
|
||||
T.RandomCrop(self.crop_size),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.ToTensor(),
|
||||
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# Image preprocessing in test phase
|
||||
self.test_transform = T.Compose([
|
||||
T.Scale(self.crop_size),
|
||||
T.CenterCrop(self.crop_size),
|
||||
T.ToTensor(),
|
||||
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# Training
|
||||
self.num_epochs = 5
|
||||
self.batch_size = 64
|
||||
@ -23,4 +41,7 @@ class Config(object):
|
||||
# Path
|
||||
self.image_path = './data/'
|
||||
self.caption_path = './data/annotations/'
|
||||
self.vocab_path = './data/'
|
||||
self.vocab_path = './data/'
|
||||
self.model_path = './model/'
|
||||
self.trained_encoder = 'encoder-4-6000.pkl'
|
||||
self.trained_decoder = 'decoder-4-6000.pkl'
|
Reference in New Issue
Block a user