from typing import List

import torch.nn as nn

from labml import lab
from labml.configs import option
from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
from labml_helpers.module import Module
from labml_nn.experiments.mnist import MNISTConfigs


class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    ## Configurations

    This extends from the CIFAR 10 dataset configurations from `labml_helpers`
    and from `MNISTConfigs`.
    """
    # Use CIFAR10 dataset by default
    dataset_name: str = 'CIFAR10'


@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    ### Augmented CIFAR 10 train dataset

    Returns the torchvision CIFAR 10 training split stored under
    `lab.get_data_path()` (downloaded there if it is not present yet).
    Each image is randomly padded-and-cropped and randomly flipped
    horizontally, then converted to a tensor and normalized per channel.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=True,
                   download=True,
                   transform=transforms.Compose([
                       # Pad and crop: pad by 4 pixels, then take a random 32x32 crop
                       transforms.RandomCrop(32, padding=4),
                       # Random horizontal flip
                       transforms.RandomHorizontalFlip(),
                       # Convert the image to a tensor with values in [0, 1]
                       transforms.ToTensor(),
                       # (x - 0.5) / 0.5 maps every channel from [0, 1] to [-1, 1]
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))


@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    ### Non-augmented CIFAR 10 validation dataset

    Returns the torchvision CIFAR 10 test split stored under
    `lab.get_data_path()` (downloaded there if it is not present yet),
    with the same tensor conversion and normalization as the training set
    but without any augmentation.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=False,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       # Same normalization as the training set: [0, 1] -> [-1, 1]
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))


class CIFAR10VGGModel(Module):
    """
    ### VGG model for CIFAR-10 classification

    Five blocks of 3x3 convolutions, each block followed by a 2x2 max-pool,
    and a final linear layer that produces the class logits.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Convolution and activation combined.

        The modules held by the returned container are spliced one by one
        into `self.layers` by `__init__`. Because this is a method, subclasses
        can override it to build a different per-layer stack.
        """
        return nn.Sequential(
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, blocks: List[List[int]], n_classes: int = 10):
        """
        * `blocks` has one entry per block; each entry lists the number of
          output channels of every convolution layer in that block.
        * `n_classes` is the number of output logits (10 for CIFAR 10).

        Raises `ValueError` unless `blocks` has exactly 5 entries.
        """
        super().__init__()

        # 5 pooling layers (each halving height and width) reduce a
        # CIFAR 10 image of size $32 \times 32$ to an output of size $1 \times 1$.
        if len(blocks) != 5:
            raise ValueError(f'CIFAR10VGGModel expects exactly 5 blocks, got {len(blocks)}')

        layers = []
        # RGB channels
        in_channels = 3
        # Number of channels in each layer in each block
        for block in blocks:
            # Layers produced by `conv_block` (convolution and activation here)
            for channels in block:
                # `+=` extends the list with the individual modules of the
                # returned container, keeping `self.layers` flat
                layers += self.conv_block(in_channels, channels)
                in_channels = channels
            # Max pooling at end of each block
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

        # Create a sequential model with the layers
        self.layers = nn.Sequential(*layers)
        # Final logits layer; for 32x32 inputs the flattened features have
        # exactly `in_channels` elements (the last conv's output channels)
        self.fc = nn.Linear(in_channels, n_classes)

    def forward(self, x):
        """
        `x` is a batch of images of shape `(batch, 3, height, width)`;
        the linear layer's input size assumes height = width = 32.
        Returns logits of shape `(batch, n_classes)`.
        """
        # The VGG layers
        x = self.layers(x)
        # Reshape for classification layer: flatten everything but the batch dimension
        x = x.view(x.shape[0], -1)
        # Final linear layer
        return self.fc(x)