mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2026-03-13 09:11:37 +08:00
minor refactoring for batch size in deep residual network
This commit is contained in:
@@ -16,6 +16,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Hyper-parameters
|
||||
num_epochs = 80
|
||||
batch_size = 100
|
||||
learning_rate = 0.001
|
||||
|
||||
# Image preprocessing modules
|
||||
@@ -37,11 +38,11 @@ test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
|
||||
|
||||
# Data loader
|
||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
||||
batch_size=100,
|
||||
batch_size=batch_size,
|
||||
shuffle=True)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||
batch_size=100,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
# 3x3 convolution
|
||||
|
||||
Reference in New Issue
Block a user