CIFAR10 Experiment

from typing import List

import torch.nn as nn

from labml import lab
from labml.configs import option
from labml_nn.helpers.datasets import CIFAR10Configs as CIFAR10DatasetConfigs
from labml_nn.experiments.mnist import MNISTConfigs

Configurations

This extends the CIFAR-10 dataset configurations (CIFAR10DatasetConfigs) and MNISTConfigs.

class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):

Use the CIFAR-10 dataset by default

    dataset_name: str = 'CIFAR10'
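
For orientation, a driver for these configurations might look roughly like the sketch below. It assumes the usual labml experiment API (experiment.create, experiment.configs, experiment.start / conf.run) and that a model option for conf.model is registered separately, for example with the VGG model defined later on this page; the experiment name and override keys are illustrative.

from labml import experiment

def main():
    # Create an experiment and the configurations object
    experiment.create(name='cifar10_example')
    conf = CIFAR10Configs()
    # Override selected options; these keys are illustrative, not prescribed here
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 2.5e-4})
    # Start the experiment and run the training loop provided by the base configurations
    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()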

Augmented CIFAR-10 train dataset

@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=True,
                   download=True,
                   transform=transforms.Compose([

Pad and crop

                       transforms.RandomCrop(32, padding=4),

Random horizontal flip

                       transforms.RandomHorizontalFlip(),

                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))
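
The Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform applies (x - 0.5) / 0.5 per channel, so the [0, 1] values produced by ToTensor end up in [-1, 1]. A minimal sketch of that mapping (a hypothetical check, not part of the experiment):

import torch
from torchvision.transforms import transforms

# ToTensor yields values in [0, 1]; Normalize maps them to (x - 0.5) / 0.5
norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
x = torch.rand(3, 32, 32)  # a stand-in for a CIFAR-10 image tensor
y = norm(x)
assert torch.allclose(y, (x - 0.5) / 0.5)
print(y.min().item(), y.max().item())  # roughly -1 to 1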

Non-augmented CIFAR-10 validation dataset

@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=False,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))
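
If a different normalization is preferred, the same option mechanism can register an alternative, which can then be selected by name through the configuration value for valid_dataset (assuming labml resolves options by function name). The per-channel statistics below are approximate CIFAR-10 values and only illustrative; in practice the train and validation transforms should normalize consistently.

@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_per_channel():
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=False,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       # Approximate CIFAR-10 per-channel mean and standard deviation
                       transforms.Normalize((0.4914, 0.4822, 0.4465),
                                            (0.2470, 0.2435, 0.2616))
                   ]))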

VGG model for CIFAR-10 classification

class CIFAR10VGGModel(nn.Module):

Convolution and activation combined

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, blocks: List[List[int]]):
        super().__init__()

Five 2×2 pooling layers produce an output of size 1×1, since the CIFAR-10 image size is 32×32

        assert len(blocks) == 5
        layers = []

RGB channels

        in_channels = 3

Number of channels in each layer in each block

        for block in blocks:

Convolution and activation layers (normalization can be added by overriding conv_block)

            for channels in block:
                layers += self.conv_block(in_channels, channels)
                in_channels = channels

Max pooling at end of each block

            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

Create a sequential model with the layers

        self.layers = nn.Sequential(*layers)

Final logits layer

        self.fc = nn.Linear(in_channels, 10)

    def forward(self, x):

The VGG layers

        x = self.layers(x)

Reshape for classification layer

        x = x.view(x.shape[0], -1)

Final linear layer

        return self.fc(x)
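
To make the blocks argument concrete, here is a sketch that builds a VGG-16 style model and pushes a dummy batch through it. The channel configuration is an illustrative choice, not something fixed by this class; any five blocks of channel counts will work.

import torch

# A VGG-16 style configuration: 5 blocks, each followed by 2x2 max pooling
blocks = [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]
model = CIFAR10VGGModel(blocks)

# A dummy batch of 4 CIFAR-10 sized images: 5 poolings reduce 32x32 to 1x1,
# so the flattened features match the final linear layer
x = torch.zeros(4, 3, 32, 32)
logits = model(x)
print(logits.shape)  # torch.Size([4, 10])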