from typing import List

import torch.nn as nn

from labml import lab
from labml.configs import option
from labml_nn.helpers.datasets import CIFAR10Configs as CIFAR10DatasetConfigs
from labml_nn.experiments.mnist import MNISTConfigs


class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    ## Configurations

    This extends the CIFAR-10 dataset configurations and `MNISTConfigs`,
    combining the dataset setup with the standard training loop configurations.
    """
    # Use CIFAR10 dataset by default
    dataset_name: str = 'CIFAR10'
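

# The block below is a minimal usage sketch, not part of the original module.
# It assumes the usual labml experiment API (`experiment.create`,
# `experiment.configs`, `experiment.start`) and the `optimizer.*` configuration
# keys inherited from `MNISTConfigs`; the tiny model registered here is purely
# illustrative, since real experiments register their own `model` option.
def _example_cifar10_experiment():  # hypothetical helper, not called on import
    from labml import experiment

    @option(CIFAR10Configs.model)
    def _toy_model(c: CIFAR10Configs):
        # A toy classifier just to make the sketch self-contained
        return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(c.device)

    conf = CIFAR10Configs()
    experiment.create(name='cifar10_example')
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 2.5e-4})
    with experiment.start():
        conf.run()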


@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    ### Augmented CIFAR-10 train dataset
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=True,
                   download=True,
                   transform=transforms.Compose([
                       # Pad and crop
                       transforms.RandomCrop(32, padding=4),
                       # Random horizontal flip
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))
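

# A small sanity check of the augmentation pipeline above (an illustrative
# sketch, not part of the original module): pad-then-crop keeps the image size
# at 32x32, and normalizing with mean 0.5 and std 0.5 maps pixel values from
# [0, 1] to [-1, 1].
def _check_augmentation():  # hypothetical helper, not called on import
    from PIL import Image
    from torchvision.transforms import transforms

    aug = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    x = aug(Image.new('RGB', (32, 32), color=(255, 255, 255)))
    assert x.shape == (3, 32, 32)
    # After normalization every value lies in [-1, 1]
    assert bool(x.min() >= -1.0) and bool(x.max() <= 1.0)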


@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    ### Non-augmented CIFAR-10 validation dataset
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=False,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))


class CIFAR10VGGModel(nn.Module):
    """
    ### VGG-style model for CIFAR-10 classification
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Convolution and activation combined
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, blocks: List[List[int]]):
        super().__init__()

        # CIFAR-10 images are $32 \times 32$, so 5 pooling layers (each halving
        # the spatial size) produce an output of size $1 \times 1$
        assert len(blocks) == 5
        layers = []
        # RGB channels
        in_channels = 3
        # Number of channels in each layer in each block
        for block in blocks:
            # Convolution and activation layers (subclasses may add normalization in `conv_block`)
            for channels in block:
                layers += self.conv_block(in_channels, channels)
                in_channels = channels
            # Max pooling at end of each block
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

        # Create a sequential model with the layers
        self.layers = nn.Sequential(*layers)
        # Final logits layer
        self.fc = nn.Linear(in_channels, 10)

    def forward(self, x):
        # The VGG layers
        x = self.layers(x)
        # Reshape for classification layer
        x = x.view(x.shape[0], -1)
        # Final linear layer
        return self.fc(x)
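

# A quick shape check (an illustrative sketch, not part of the original module):
# build the model with a VGG-16 style channel specification and pass a dummy
# batch through it. Five max-pooling layers reduce the 32x32 input to 1x1, so
# the flattened features line up with the final linear layer.
if __name__ == '__main__':
    import torch

    vgg16_blocks = [[64, 64], [128, 128], [256, 256, 256],
                    [512, 512, 512], [512, 512, 512]]
    model = CIFAR10VGGModel(vgg16_blocks)
    logits = model(torch.randn(4, 3, 32, 32))
    # One logit per CIFAR-10 class
    assert logits.shape == (4, 10)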