mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-24 01:48:26 +08:00
142 lines
4.8 KiB
Python
142 lines
4.8 KiB
Python
# Implementation of https://arxiv.org/pdf/1512.03385.pdf
|
|
# See section 4.2 for model architecture on CIFAR-10
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.datasets as dsets
|
|
import torchvision.transforms as transforms
|
|
from torch.autograd import Variable
|
|
|
|
# Image Preprocessing
|
|
# Training-time augmentation: upscale to 40 pixels, flip at random, take a
# random 32x32 crop, then convert the image to a tensor.
# `transforms.Resize` is the current name of the deprecated `transforms.Scale`.
transform = transforms.Compose([
    transforms.Resize(40),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
])
|
|
|
|
# CIFAR-10 Dataset
|
|
# Training split: stored under ./data/, downloaded on demand, and passed
# through the augmentation pipeline defined in `transform`.
train_dataset = dsets.CIFAR10(
    './data/',
    train=True,
    transform=transform,
    download=True,
)

# Held-out split: read from the same directory; images are only converted to
# tensors, with no augmentation.
test_dataset = dsets.CIFAR10(
    './data/',
    train=False,
    transform=transforms.ToTensor(),
)
|
|
|
|
# Data Loader (Input Pipeline)
|
|
# Mini-batches of 100 samples; the training data is reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=100, shuffle=True)

# Mini-batches of 100 samples, iterated in dataset order for evaluation.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=100, shuffle=False)
|
|
|
|
# 3x3 Convolution
|
|
def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 convolution with one pixel of padding.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution (default 1).

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=3``, ``padding=1`` and
        ``bias=False``.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_options)
|
|
|
|
# Residual Block
|
|
class ResidualBlock(nn.Module):
    """Two-convolution residual block with an optional shortcut projection.

    The main branch is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The shortcut branch is added to that result and a final ReLU is applied.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (default 1); the second
            convolution always uses stride 1.
        downsample: Optional module applied to the block input to produce the
            shortcut branch. When ``None`` the input itself is the shortcut,
            so it must already match the main branch's output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        # Test against None explicitly: the truthiness of a module depends on
        # whether its class defines __len__ (nn.Sequential does), which is not
        # what "a shortcut projection was supplied" means.
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)
        return out
|
|
|
|
# ResNet Module
|
|
class ResNet(nn.Module):
    """Three-stage residual network with a 16-channel 3x3 convolution stem.

    The stages produce 16, 32 and 64 channels; the second and third use
    stride 2 in their first block. The result is average-pooled with an 8x8
    window, flattened and fed to a linear classifier with 64 input features,
    so the pooled map must be 1x1 (presumably 32x32 inputs, matching the
    random 32-pixel crop used for training).

    Args:
        block: Residual block class, called as
            ``block(in_channels, out_channels, stride, downsample)`` for the
            first block of a stage and ``block(out_channels, out_channels)``
            for the rest.
        layers: Sequence with three entries giving the number of blocks in
            stage 1, stage 2 and stage 3 respectively.
        num_classes: Size of the classifier output (default 10).
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Each stage reads its own entry of `layers`. The indices used to be
        # 0, 0, 1, which ignored layers[2] and gave stages 2 and 3 the wrong
        # depth for any non-uniform configuration.
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a conv3x3 + BatchNorm shortcut projection.
        Updates ``self.in_channels`` to ``out_channels`` for the next stage.

        Args:
            block: Residual block class to instantiate.
            out_channels: Number of channels produced by every block.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block (default 1).

        Returns:
            An ``nn.Sequential`` containing the stage's blocks.
        """
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            # The shortcut must match the main branch in both resolution and
            # channel count, so it gets its own strided convolution.
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier output for a batch of images ``x``."""
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        # Flatten everything but the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
|
|
|
|
# Three residual blocks per stage; nn.Module.cuda() moves the parameters to
# the GPU and returns the module itself, so the call can be chained.
resnet = ResNet(ResidualBlock, [3, 3, 3]).cuda()

# Initial learning rate; kept in a module-level variable because the training
# loop divides it during the run.
lr = 0.001

# Cross-entropy loss on the classifier output, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.parameters(), lr=lr)
|
|
|
|
# Training
|
|
num_epochs = 40
# Number of mini-batches per epoch, taken from the loader rather than
# hard-coded so the progress line stays correct if the batch size changes.
steps_per_epoch = len(train_loader)

# Make sure BatchNorm layers use (and update) batch statistics while training.
resnet.train()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.cuda()
        labels = labels.cuda()

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            # loss is a 0-dim tensor; .item() extracts the Python float
            # (indexing it with [0] fails on current PyTorch).
            print("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f"
                  % (epoch + 1, num_epochs, i + 1, steps_per_epoch,
                     loss.item()))

    # Decaying Learning Rate: divide by 3 every 20 epochs. The rate is updated
    # in place so Adam keeps its running moment estimates instead of
    # restarting from a freshly constructed optimizer.
    if (epoch + 1) % 20 == 0:
        lr /= 3
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
|
|
|
|
# Test
|
|
# Switch BatchNorm layers to their running statistics; without this the
# predictions depend on the composition of each test batch.
resnet.eval()

correct = 0
total = 0
# No gradients are needed for inference, so skip building the autograd graph.
with torch.no_grad():
    for images, labels in test_loader:
        outputs = resnet(images.cuda())
        # Index of the largest output along dim 1 is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        # .item() keeps `correct` a plain Python int instead of a tensor.
        correct += (predicted.cpu() == labels).sum().item()

print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))