10from typing import List
11
12import torch.nn as nn
13
14from labml import lab
15from labml.configs import option
16from labml_nn.helpers.datasets import CIFAR10Configs as CIFAR10DatasetConfigs
17from labml_nn.experiments.mnist import MNISTConfigs
class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    ## CIFAR-10 experiment configurations

    Mixes the CIFAR-10 dataset configurations (`labml_nn.helpers.datasets`)
    into the experiment configurations of `labml_nn.experiments.mnist.MNISTConfigs`,
    and picks CIFAR-10 as the default dataset.
    """

    # Default dataset: CIFAR-10
    dataset_name: str = 'CIFAR10'
@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    ### Augmented CIFAR-10 training dataset

    Returns the CIFAR-10 training split rooted at the labml data path
    (`download=True` lets torchvision fetch it when it is not already there).
    Each image is randomly padded-and-cropped, randomly flipped horizontally,
    converted to a tensor, and normalized per channel with mean 0.5 and
    standard deviation 0.5.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    data_root = str(lab.get_data_path())

    train_transform = transforms.Compose([
        # Pad 4 pixels on every side, then take a random 32x32 crop
        transforms.RandomCrop(32, padding=4),
        # Mirror left-to-right at random
        transforms.RandomHorizontalFlip(),
        # PIL image -> tensor, then shift/scale each RGB channel
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    return CIFAR10(data_root, train=True, download=True, transform=train_transform)
@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    ### CIFAR-10 validation dataset (no augmentation)

    Returns the CIFAR-10 test split rooted at the labml data path
    (`download=True` lets torchvision fetch it when it is not already there).
    Images are only converted to tensors and normalized per channel with
    mean 0.5 and standard deviation 0.5; no random augmentation is applied.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    data_root = str(lab.get_data_path())

    # Same normalization as the training set, without crop/flip augmentation
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    return CIFAR10(data_root, train=False, download=True, transform=eval_transform)
class CIFAR10VGGModel(nn.Module):
    """
    ## VGG-style model for CIFAR-10 classification

    Five blocks of 3x3 convolutions, each block closed by a 2x2 max-pool,
    followed by a single linear layer that produces 10 logits.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Build one convolution step: a 3x3 convolution with padding 1 (so height
        and width are unchanged) followed by an in-place ReLU.

        The returned `nn.Sequential` is iterated by `__init__`, so its child
        modules are spliced directly into one flat list of layers.
        """
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        activation = nn.ReLU(inplace=True)
        return nn.Sequential(conv, activation)

    def __init__(self, blocks: List[List[int]]):
        """
        :param blocks: exactly five lists, one per block; each inner list holds
            the output channel count of every convolution in that block.
        """
        super().__init__()

        # Each block ends with a 2x2, stride-2 max-pool that halves height and
        # width. Five of them reduce a 32x32 image to 32 / 2^5 = 1, i.e. 1x1,
        # so the flattened features have as many entries as the last channel count.
        assert len(blocks) == 5

        modules = []
        # The input has 3 (RGB) channels
        channels_in = 3
        for block_channels in blocks:
            for channels_out in block_channels:
                # Splice in whatever `conv_block` produces (here: conv + ReLU)
                modules.extend(self.conv_block(channels_in, channels_out))
                channels_in = channels_out
            # Close the block with a max-pool
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # One flat sequential stack of all conv/activation/pool layers
        self.layers = nn.Sequential(*modules)
        # Classifier producing 10 logits from the final channel vector
        self.fc = nn.Linear(channels_in, 10)

    def forward(self, x):
        """
        :param x: image batch with 3 channels; assumes 32x32 spatial size so the
            pooled features are 1x1 and match `self.fc` — TODO confirm for other inputs.
        :return: logits of shape `(batch, 10)`
        """
        features = self.layers(x)
        # Flatten everything except the batch dimension
        flat = features.view(features.shape[0], -1)
        return self.fc(flat)