This is a PyTorch implementation of the paper Deep Residual Learning for Image Recognition.
ResNets train layers as residual functions to overcome the degradation problem. The degradation problem is that the accuracy of deep neural networks degrades when the number of layers becomes very high: accuracy increases as layers are added, then saturates, and then starts to degrade.
The paper argues that deeper models should perform at least as well as shallower models because the extra layers can just learn to perform an identity mapping.
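If $\mathcal{H}(x)$ is the mapping that needs to be learned by a few layers, they train the residual function

$$\mathcal{F}(x) = \mathcal{H}(x) - x$$

instead. The original function then becomes $\mathcal{F}(x) + x$.

In this case, learning the identity mapping for $\mathcal{H}(x)$ is equivalent to learning $\mathcal{F}(x)$ to be $0$, which is easier to learn.

In the parameterized form this can be written as

$$\mathcal{F}(x, \{W_i\}) + x$$

and when the feature map sizes of $\mathcal{F}(x, \{W_i\})$ and $x$ are different, the paper suggests doing a linear projection with learned weights $W_s$:

$$\mathcal{F}(x, \{W_i\}) + W_s x$$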
The paper experimented with zero padding instead of linear projections and found linear projections to work better. When the feature map sizes match, it found identity mapping to be better than linear projections.
The residual function $\mathcal{F}$ should have more than one layer; otherwise the sum $\mathcal{F}(x, \{W_i\}) + W_s x$ has no non-linearity and behaves like a single linear layer.
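The two claims above (a zero residual gives the identity, and a one-layer residual collapses to a linear map) can be checked on small vectors. The following is a minimal plain-Python sketch, not the paper's or this implementation's code; `matvec`, `relu`, `residual_two_layer`, and `residual_one_layer` are hypothetical helpers, and the weight values are made up for illustration.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Matrix-vector product with plain lists."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def residual_two_layer(x: Vector, w1: Matrix, w2: Matrix) -> Vector:
    """F(x) + x with a two-layer residual F(x) = W2 relu(W1 x)."""
    f = matvec(w2, relu(matvec(w1, x)))
    return [a + b for a, b in zip(f, x)]


def residual_one_layer(x: Vector, w: Matrix) -> Vector:
    """F(x) + x with a single linear layer F(x) = W x."""
    return [a + b for a, b in zip(matvec(w, x), x)]


x = [1.0, -2.0]

# With all residual weights at zero, F(x) = 0 and the block is the identity
zero = [[0.0, 0.0], [0.0, 0.0]]
assert residual_two_layer(x, zero, zero) == x

# With a single linear layer, W x + x equals (W + I) x: still one linear map
w = [[0.5, 1.0], [-1.0, 2.0]]
w_plus_identity = [[1.5, 1.0], [-1.0, 3.0]]
assert residual_one_layer(x, w) == matvec(w_plus_identity, x)
```

The second assertion is the reason a residual branch needs at least two layers with a non-linearity between them.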
Here is the training code for training a ResNet on CIFAR-10.
from typing import List, Optional

import torch
from torch import nn

from labml_helpers.module import Module


class ShortcutProjection(Module):
    """Linear projection for the shortcut connection."""

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        """
        * `in_channels` is the number of channels in x
        * `out_channels` is the number of channels in F(x, {W_i})
        * `stride` is the stride length in the convolution operation for F.
          We do the same stride on the shortcut connection, to match the feature-map size.
        """
        super().__init__()

        # Convolution layer for linear projection
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        # The paper suggests adding batch normalization after each convolution operation
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor):
        # Convolution and batch normalization
        return self.bn(self.conv(x))

This implements the residual block described in the paper. It has two $3 \times 3$ convolution layers.
The first convolution layer maps from `in_channels` to `out_channels`, where `out_channels` is higher than `in_channels` when we reduce the feature map size with a stride length greater than $1$.
The second convolution layer maps from `out_channels` to `out_channels` and always has a stride length of 1.
Both convolution layers are followed by batch normalization.
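The stride handling described above can be sanity-checked with the standard convolution output-size formula, `floor((size + 2 * padding - kernel_size) / stride) + 1` (for dilation 1). The following is a minimal plain-Python sketch, not part of the implementation; `conv_out_size` is a hypothetical helper. It suggests that a $3 \times 3$ convolution with padding 1 and a $1 \times 1$ convolution without padding give the same spatial size for any stride, which is why a shortcut projection can reuse the block's stride.

```python
def conv_out_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution with dilation 1 (hypothetical helper)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Compare the main path (3x3, padding 1) with a shortcut projection (1x1, no padding)
for size in (32, 17, 8):
    for stride in (1, 2, 3):
        main_path = conv_out_size(size, kernel_size=3, stride=stride, padding=1)
        shortcut = conv_out_size(size, kernel_size=1, stride=stride, padding=0)
        assert main_path == shortcut

print(conv_out_size(32, kernel_size=3, stride=2, padding=1))  # 16
```

Both expressions reduce to `(size - 1) // stride + 1`, so the two tensors can be added element-wise.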
class ResidualBlock(Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int):
        """
        * `in_channels` is the number of channels in x
        * `out_channels` is the number of output channels
        * `stride` is the stride length in the convolution operation.
        """
        super().__init__()

        # First 3x3 convolution layer, this maps to `out_channels`
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        # Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(out_channels)
        # First activation function (ReLU)
        self.act1 = nn.ReLU()

        # Second 3x3 convolution layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut connection should be a projection if the stride length is not 1
        # or if the number of channels changes
        if stride != 1 or in_channels != out_channels:
            # Projection
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
            # Identity
            self.shortcut = nn.Identity()

        # Second activation function (ReLU), applied after adding the shortcut
        self.act2 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input of shape `[batch_size, in_channels, height, width]`
        """
        # Get the shortcut connection
        shortcut = self.shortcut(x)
        # First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
        # Second convolution
        x = self.bn2(self.conv2(x))
        # Activation function after adding the shortcut
        return self.act2(x + shortcut)

This implements the bottleneck block described in the paper. It has $1 \times 1$, $3 \times 3$, and $1 \times 1$ convolution layers.
The first convolution layer maps from `in_channels` to `bottleneck_channels` with a $1 \times 1$ convolution, where `bottleneck_channels` is lower than `in_channels`.
The second $3 \times 3$ convolution layer maps from `bottleneck_channels` to `bottleneck_channels`. This can have a stride length greater than $1$ when we want to compress the feature map size.
The third, final $1 \times 1$ convolution layer maps to `out_channels`. `out_channels` is higher than `in_channels` if the stride length is greater than $1$; otherwise, `out_channels` is equal to `in_channels`.
`bottleneck_channels` is less than `in_channels`, and the $3 \times 3$ convolution is performed on this shrunk space (hence the bottleneck). The two $1 \times 1$ convolutions decrease and then increase the number of channels.
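To get a feel for what the bottleneck saves, the convolution weights of the two block types can be counted by hand. The following is a rough plain-Python sketch under stated assumptions: it counts convolution weights only (no biases, no batch normalization parameters), the helper names are hypothetical, and the channel sizes 256 and 64 are illustrative values chosen here rather than taken from this implementation.

```python
def conv_weights(in_channels: int, out_channels: int, kernel_size: int) -> int:
    """Number of weights in a square 2D convolution, ignoring bias (hypothetical helper)."""
    return kernel_size * kernel_size * in_channels * out_channels


def residual_block_weights(channels: int) -> int:
    # Two 3x3 convolutions at full width
    return 2 * conv_weights(channels, channels, 3)


def bottleneck_block_weights(channels: int, bottleneck_channels: int) -> int:
    # 1x1 reduce, 3x3 in the shrunk space, 1x1 expand
    return (conv_weights(channels, bottleneck_channels, 1)
            + conv_weights(bottleneck_channels, bottleneck_channels, 3)
            + conv_weights(bottleneck_channels, channels, 1))


print(residual_block_weights(256))        # 1179648
print(bottleneck_block_weights(256, 64))  # 69632
```

With these illustrative sizes the bottleneck block uses roughly 17 times fewer convolution weights while keeping the same input and output width.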
class BottleneckResidualBlock(Module):
    def __init__(self, in_channels: int, bottleneck_channels: int, out_channels: int, stride: int):
        """
        * `in_channels` is the number of channels in x
        * `bottleneck_channels` is the number of channels for the 3x3 convolution
        * `out_channels` is the number of output channels
        * `stride` is the stride length in the 3x3 convolution operation.
        """
        super().__init__()

        # First 1x1 convolution layer, this maps to `bottleneck_channels`
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)
        # Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        # First activation function (ReLU)
        self.act1 = nn.ReLU()

        # Second 3x3 convolution layer
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1)
        # Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        # Second activation function (ReLU)
        self.act2 = nn.ReLU()

        # Third 1x1 convolution layer, this maps to `out_channels`
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)
        # Batch normalization after the third convolution
        self.bn3 = nn.BatchNorm2d(out_channels)

        # The shortcut connection should be a projection if the stride length is not 1
        # or if the number of channels changes
        if stride != 1 or in_channels != out_channels:
            # Projection
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
            # Identity
            self.shortcut = nn.Identity()

        # Third activation function (ReLU), applied after adding the shortcut
        self.act3 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input of shape `[batch_size, in_channels, height, width]`
        """
        # Get the shortcut connection
        shortcut = self.shortcut(x)
        # First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
        # Second convolution and activation
        x = self.act2(self.bn2(self.conv2(x)))
        # Third convolution
        x = self.bn3(self.conv3(x))
        # Activation function after adding the shortcut
        return self.act3(x + shortcut)

This is the base of the ResNet model, without the final linear layer and softmax for classification.
The ResNet is made of stacked residual blocks or bottleneck residual blocks. The feature map size is halved after a few blocks with a block of stride length $2$. The number of channels is increased when the feature map size is reduced. Finally, the feature map is average pooled to get a vector representation.
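The stacking rule described above can be laid out as a list of per-block settings. The following is a minimal plain-Python sketch, not the class itself; `plan_blocks` is a hypothetical helper, the example sizes are made up, and it assumes, as the description states, that only the first block of each feature map size after the first one uses a stride of 2.

```python
from typing import List, Tuple


def plan_blocks(n_blocks: List[int], n_channels: List[int]) -> List[Tuple[int, int, int]]:
    """Return (in_channels, out_channels, stride) for every block, in order."""
    assert len(n_blocks) == len(n_channels)
    plan = []
    prev_channels = n_channels[0]
    for stage, (count, channels) in enumerate(zip(n_blocks, n_channels)):
        for j in range(count):
            # Halve the feature map at the start of every stage except the first
            stride = 2 if j == 0 and stage > 0 else 1
            plan.append((prev_channels, channels, stride))
            prev_channels = channels
    return plan


for block in plan_blocks([2, 2, 2], [16, 32, 64]):
    print(block)
# (16, 16, 1)
# (16, 16, 1)
# (16, 32, 2)
# (32, 32, 1)
# (32, 64, 2)
# (64, 64, 1)
```

Every block whose stride is 2 or whose channel count changes is one that would need a projection on its shortcut rather than the identity.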
class ResNetBase(Module):
    def __init__(self, n_blocks: List[int], n_channels: List[int],
                 bottlenecks: Optional[List[int]] = None,
                 img_channels: int = 3, first_kernel_size: int = 7):
        """
        * `n_blocks` is a list of the number of blocks for each feature map size.
        * `n_channels` is the number of channels for each feature map size.
        * `bottlenecks` is the number of channels in the bottlenecks.
          If this is `None`, residual blocks are used.
        * `img_channels` is the number of channels in the input.
        * `first_kernel_size` is the kernel size of the initial convolution layer
        """
        super().__init__()

        # Number of blocks and number of channels for each feature map size
        assert len(n_blocks) == len(n_channels)
        # If bottleneck residual blocks are used, the number of channels in the bottlenecks
        # should be provided for each feature map size
        assert bottlenecks is None or len(bottlenecks) == len(n_channels)

        # Initial convolution layer maps from `img_channels` to the number of channels
        # in the first residual block (`n_channels[0]`)
        self.conv = nn.Conv2d(img_channels, n_channels[0],
                              kernel_size=first_kernel_size, stride=2, padding=first_kernel_size // 2)
        # Batch norm after the initial convolution
        self.bn = nn.BatchNorm2d(n_channels[0])

        # List of blocks
        blocks = []
        # Number of channels from the previous layer (or block)
        prev_channels = n_channels[0]
        # Loop through each feature map size
        for i, channels in enumerate(n_channels):
            # The first block for a new feature map size has a stride length of 2,
            # except for the very first block
            stride = 1 if len(blocks) == 0 else 2

            if bottlenecks is None:
                # Residual block that maps from `prev_channels` to `channels`
                blocks.append(ResidualBlock(prev_channels, channels, stride=stride))
            else:
                # Bottleneck residual block that maps from `prev_channels` to `channels`
                blocks.append(BottleneckResidualBlock(prev_channels, bottlenecks[i], channels,
                                                      stride=stride))

            # Change the number of channels
            prev_channels = channels
            # Add the rest of the blocks (no change in feature map size or channels)
            for _ in range(n_blocks[i] - 1):
                if bottlenecks is None:
                    blocks.append(ResidualBlock(channels, channels, stride=1))
                else:
                    blocks.append(BottleneckResidualBlock(channels, bottlenecks[i], channels, stride=1))

        # Stack the blocks
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """
        * `x` has shape `[batch_size, img_channels, height, width]`
        """
        # Initial convolution and batch normalization
        x = self.bn(self.conv(x))
        # Residual (or bottleneck) blocks
        x = self.blocks(x)
        # Change `x` from shape `[batch_size, channels, h, w]` to `[batch_size, channels, h * w]`
        x = x.view(x.shape[0], x.shape[1], -1)
        # Global average pooling
        return x.mean(dim=-1)
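The last two steps flatten the spatial dimensions and average over them, leaving one value per channel. The following is a minimal plain-Python sketch of that computation for a single sample, using nested lists in place of tensors; `global_average_pool` is a hypothetical helper and the input values are made up.

```python
from typing import List


def global_average_pool(feature_map: List[List[List[float]]]) -> List[float]:
    """Average each channel of a [channels][height][width] nested list."""
    pooled = []
    for channel in feature_map:
        # Flatten height and width into one list, like the reshape to h * w
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


# Two channels of size 2x2
feature_map = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [2.0, 2.0]],
]
print(global_average_pool(feature_map))  # [2.5, 1.0]
```

The output has one entry per channel regardless of the spatial size.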