Train a ConvMixer on CIFAR-10

This script trains a ConvMixer on the CIFAR-10 dataset.

This is not an attempt to reproduce the results of the paper. The paper trains with the image augmentations available in PyTorch Image Models (timm). For simplicity we do not use them here, which lowers our validation accuracy.
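The script instead relies on simple augmentations (the cifar10_train_augmented dataset selected in the configuration below). The following is a minimal pure-Python sketch of what such simple augmentation typically looks like. It assumes the common CIFAR-10 recipe of a random 32 x 32 crop taken from the image zero-padded by 4 pixels, followed by a random horizontal flip; the exact transforms behind cifar10_train_augmented are defined in labml_nn.experiments.cifar10 and may differ. It works on one channel stored as nested lists, and every helper name is hypothetical.

```python
import random
from typing import List

# Hypothetical representation: a single-channel H x W image as nested lists
Image = List[List[float]]


def pad(img: Image, padding: int, fill: float = 0.0) -> Image:
    """Pad an image with `fill` by `padding` pixels on every side."""
    width = len(img[0]) + 2 * padding
    rows = [[fill] * padding + row + [fill] * padding for row in img]
    top = [[fill] * width for _ in range(padding)]
    bottom = [[fill] * width for _ in range(padding)]
    return top + rows + bottom


def random_crop(img: Image, size: int, padding: int, rng: random.Random) -> Image:
    """Pad the image, then cut out a random `size` x `size` window."""
    padded = pad(img, padding)
    top = rng.randint(0, len(padded) - size)      # randint is inclusive on both ends
    left = rng.randint(0, len(padded[0]) - size)
    return [row[left:left + size] for row in padded[top:top + size]]


def random_horizontal_flip(img: Image, rng: random.Random, p: float = 0.5) -> Image:
    """Mirror the image left-to-right with probability `p`."""
    if rng.random() < p:
        return [row[::-1] for row in img]
    return [row[:] for row in img]


def augment(img: Image, rng: random.Random) -> Image:
    """Assumed recipe: padded random crop (4 px) followed by a random flip."""
    size = len(img)
    return random_horizontal_flip(random_crop(img, size, padding=4, rng=rng), rng)


# Usage: augment a synthetic 32 x 32 channel with a seeded generator
rng = random.Random(0)
image = [[float(32 * r + c) for c in range(32)] for r in range(32)]
augmented = augment(image, rng)
```

Each epoch then sees slightly shifted and mirrored views of every image, a much lighter pipeline than the timm augmentations used in the paper.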


from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs

Configurations

We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and a training loop.

class Configs(CIFAR10Configs):

Size of a patch

    patch_size: int = 2

Number of channels in patch embeddings

    d_model: int = 256

Number of ConvMixer layers or depth

    n_layers: int = 8

Kernel size of the depth-wise convolution

    kernel_size: int = 7

Number of classes in the task

    n_classes: int = 10
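With these defaults the feature-map sizes can be traced by hand. Below is a minimal pure-Python shape trace. It assumes 32 x 32 RGB CIFAR-10 inputs, a patch embedding implemented as a convolution whose kernel size and stride both equal patch_size (as described in the ConvMixer paper), and a depth-wise convolution padded by (kernel_size - 1) // 2 so that it keeps the spatial size, which only works for odd kernel sizes such as 7. The helper names are hypothetical and not part of labml_nn.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def convmixer_shapes(image_size: int = 32, in_channels: int = 3, patch_size: int = 2,
                     d_model: int = 256, n_layers: int = 8, kernel_size: int = 7,
                     n_classes: int = 10) -> List[Tuple[str, Tuple[int, ...]]]:
    """(name, shape) pairs for one image, ignoring the batch dimension."""
    shapes = [("input", (in_channels, image_size, image_size))]
    # Patch embedding: convolution with kernel = stride = patch_size
    side = conv_out(image_size, patch_size, stride=patch_size)
    shapes.append(("patch_embeddings", (d_model, side, side)))
    # Each layer: "same"-padded depth-wise conv, then a 1x1 point-wise conv
    pad = (kernel_size - 1) // 2
    for i in range(n_layers):
        side = conv_out(conv_out(side, kernel_size, padding=pad), 1)
        shapes.append((f"layer_{i}", (d_model, side, side)))
    # Classification head: global average pooling, then a linear layer
    shapes.append(("pooled", (d_model,)))
    shapes.append(("logits", (n_classes,)))
    return shapes


for name, shape in convmixer_shapes():
    print(name, shape)
```

Under these assumptions the 32 x 32 image becomes a 16 x 16 grid of 256-channel patch embeddings, every ConvMixer layer keeps that 256 x 16 x 16 shape, and the head pools it into 256 features before mapping them to 10 logits.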

Create model

@option(Configs.model)
def _conv_mixer(c: Configs):
    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

Create ConvMixer

    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
                     PatchEmbeddings(c.d_model, c.patch_size, 3),  # 3 input channels (RGB)
                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
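_conv_mixer assembles the patch embeddings, n_layers ConvMixer layers, and the classification head. To see where the parameters go, here is a rough pure-Python parameter count. It assumes the layer structure described in the ConvMixer paper (patch convolution, depth-wise convolution and 1x1 point-wise convolution, each followed by BatchNorm, plus a linear head), biases on every convolution and on the linear layer, and affine BatchNorm; these are PyTorch defaults, but labml_nn.conv_mixer may differ in detail. On the real model, sum(p.numel() for p in model.parameters()) gives the authoritative number. convmixer_param_count is a hypothetical helper.

```python
def convmixer_param_count(d_model: int = 256, n_layers: int = 8, kernel_size: int = 7,
                          patch_size: int = 2, n_classes: int = 10,
                          in_channels: int = 3) -> int:
    """Assumed parameter count; BatchNorm running stats are buffers, not parameters."""
    bn = 2 * d_model  # BatchNorm2d learnable scale and shift
    # Patch embedding: Conv2d(in_channels, d_model, patch_size, stride=patch_size) + BatchNorm
    patch = in_channels * d_model * patch_size ** 2 + d_model + bn
    # Depth-wise conv: one kernel_size x kernel_size filter per channel (groups = d_model)
    depth_wise = d_model * kernel_size ** 2 + d_model + bn
    # Point-wise conv: a 1x1 conv that mixes all channels
    point_wise = d_model * d_model + d_model + bn
    # Head: linear layer on the pooled d_model-dimensional features
    head = d_model * n_classes + n_classes
    return patch + n_layers * (depth_wise + point_wise) + head


print(convmixer_param_count())
```

With the defaults this gives 643,338 parameters, of which the eight 1x1 point-wise convolutions (with their BatchNorm) contribute 530,432, roughly 82%. The 7x7 depth-wise convolutions stay cheap because each filter sees only one channel.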
def main():

Create experiment

    experiment.create(name='ConvMixer', comment='cifar10')

Create configurations

    conf = Configs()

Load configurations

    experiment.configs(conf, {

Optimizer

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

Training epochs and batch size

        'epochs': 150,
        'train_batch_size': 64,

Simple image augmentations

        'train_dataset': 'cifar10_train_augmented',

Do not augment images for validation

        'valid_dataset': 'cifar10_valid_no_augment',
    })
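The optimizer settings above select Adam with a learning rate of 2.5e-4. As a reminder of what one update does, here is a minimal pure-Python sketch of the Adam rule for a flat list of parameters. The betas and epsilon are assumptions (0.9, 0.999 and 1e-8, the defaults of torch.optim.Adam); labml's optimizer configuration may use other values or add weight decay, which this sketch leaves out. adam_step is a hypothetical helper.

```python
import math
from typing import List, Tuple


def adam_step(params: List[float], grads: List[float], m: List[float], v: List[float],
              t: int, lr: float = 2.5e-4, beta1: float = 0.9, beta2: float = 0.999,
              eps: float = 1e-8) -> Tuple[List[float], List[float], List[float]]:
    """One Adam update; `t` is the 1-based step count used for bias correction."""
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        m_i = beta1 * m_i + (1 - beta1) * g       # running mean of gradients
        v_i = beta2 * v_i + (1 - beta2) * g * g   # running mean of squared gradients
        m_hat = m_i / (1 - beta1 ** t)            # bias-corrected estimates
        v_hat = v_i / (1 - beta2 ** t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


# Usage: minimize f(x) = x^2 (gradient 2x) for 100 steps starting from x = 1
x, m, v = [1.0], [0.0], [0.0]
for t in range(1, 101):
    x, m, v = adam_step(x, [2 * x[0]], m, v, t)
```

Because the update divides the bias-corrected mean gradient by the square root of its second moment, the first step moves each parameter by roughly the learning rate, whatever the gradient's magnitude.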

Set the model for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment and run the training loop

    with experiment.start():
        conf.run()
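conf.run() runs the training loop that CIFAR10Configs provides, for the 150 epochs configured above. Here is a back-of-the-envelope count of optimizer steps. It assumes the standard 50,000-image CIFAR-10 training split, one optimizer step per batch, and that a final partial batch is kept unless drop_last is set (keeping it is the PyTorch DataLoader default). training_steps is a hypothetical helper.

```python
import math


def training_steps(n_train: int = 50_000, batch_size: int = 64, epochs: int = 150,
                   drop_last: bool = False) -> int:
    """Total optimizer steps, assuming one step per batch."""
    if drop_last:
        per_epoch = n_train // batch_size           # discard the last partial batch
    else:
        per_epoch = math.ceil(n_train / batch_size)  # keep the last partial batch
    return per_epoch * epochs


print(training_steps())
```

With a batch size of 64 that is 782 batches per epoch and 117,300 steps over 150 epochs (117,150 if the last partial batch is dropped).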

if __name__ == '__main__':
    main()