mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2026-03-13 09:11:37 +08:00
Merge pull request #120 from arisliang/master
minor refactoring and fix
This commit is contained in:
@@ -5,7 +5,7 @@ import torchvision.transforms as transforms
|
||||
|
||||
|
||||
# Hyper-parameters
|
||||
input_size = 784
|
||||
input_size = 28 * 28 # 784
|
||||
num_classes = 10
|
||||
num_epochs = 5
|
||||
batch_size = 100
|
||||
@@ -43,7 +43,7 @@ total_step = len(train_loader)
|
||||
for epoch in range(num_epochs):
|
||||
for i, (images, labels) in enumerate(train_loader):
|
||||
# Reshape images to (batch_size, input_size)
|
||||
images = images.reshape(-1, 28*28)
|
||||
images = images.reshape(-1, input_size)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(images)
|
||||
@@ -64,7 +64,7 @@ with torch.no_grad():
|
||||
correct = 0
|
||||
total = 0
|
||||
for images, labels in test_loader:
|
||||
images = images.reshape(-1, 28*28)
|
||||
images = images.reshape(-1, input_size)
|
||||
outputs = model(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
|
||||
@@ -16,6 +16,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Hyper-parameters
|
||||
num_epochs = 80
|
||||
batch_size = 100
|
||||
learning_rate = 0.001
|
||||
|
||||
# Image preprocessing modules
|
||||
@@ -37,11 +38,11 @@ test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
|
||||
|
||||
# Data loader
|
||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
||||
batch_size=100,
|
||||
batch_size=batch_size,
|
||||
shuffle=True)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||
batch_size=100,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
# 3x3 convolution
|
||||
|
||||
@@ -85,6 +85,7 @@ for epoch in range(num_epochs):
|
||||
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
|
||||
|
||||
# Test the model
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
Reference in New Issue
Block a user