Train a Vision Transformer (ViT) on CIFAR 10

View Run

13from labml import experiment
14from labml.configs import option
15from labml_nn.experiments.cifar10 import CIFAR10Configs
16from labml_nn.transformers import TransformerConfigs

Configurations

We extend CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and a training loop.

class Configs(CIFAR10Configs):
    """
    Configurations for the ViT experiment.

    Extends `CIFAR10Configs` with the transformer settings and the
    sizes used when the Vision Transformer model is assembled.
    """

    # Transformer configurations; the model builder reads `d_model`,
    # `encoder_layer` and `n_layers` from this object
    transformer: TransformerConfigs

    # Size of a patch
    patch_size: int = 4

    # Size of the hidden layer in the classification head
    n_hidden_classification: int = 2048

    # Number of classes in the task
    n_classes: int = 10

Create transformer configs

@option(Configs.transformer)
def _transformer():
    """
    Create transformer configs.

    Registered as the option for `Configs.transformer`; returns a new
    `TransformerConfigs` instance with no arguments passed.
    """
    transformer_configs = TransformerConfigs()
    return transformer_configs

Create model

@option(Configs.model)
def _vit(c: Configs):
    """
    Create the Vision Transformer model.

    Registered as the option for `Configs.model`. Assembles the patch
    embeddings, learned positional embeddings and classification head,
    wraps them in a `VisionTransformer`, and moves the result to
    `c.device`.
    """
    from labml_nn.transformers.vit import (
        ClassificationHead,
        LearnedPositionalEmbeddings,
        PatchEmbeddings,
        VisionTransformer,
    )

    # Transformer size from the transformer configurations
    embed_size = c.transformer.d_model

    # Build the pieces in the same order the arguments were originally
    # evaluated, so module construction order is unchanged
    encoder_layer = c.transformer.encoder_layer
    depth = c.transformer.n_layers
    # NOTE(review): the literal 3 is presumably the number of input image
    # channels (RGB) - confirm against `PatchEmbeddings`
    patch_embeddings = PatchEmbeddings(embed_size, c.patch_size, 3)
    positional_embeddings = LearnedPositionalEmbeddings(embed_size)
    classifier = ClassificationHead(embed_size, c.n_hidden_classification, c.n_classes)

    # Create a vision transformer and place it on the configured device
    vit = VisionTransformer(encoder_layer, depth,
                            patch_embeddings,
                            positional_embeddings,
                            classifier)
    return vit.to(c.device)
def main():
    """
    Create the ViT experiment, apply configuration overrides, register
    the model for saving/loading, and run the training loop.
    """
    # Create experiment
    experiment.create(name='ViT', comment='cifar10')

    # Create configurations
    conf = Configs()

    # Values that override the configuration defaults
    overrides = {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        # Transformer embedding size
        'transformer.d_model': 512,

        # Training epochs and batch size
        'epochs': 1000,
        'train_batch_size': 64,

        # Augment CIFAR 10 images for training
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment CIFAR 10 images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    }

    # Load configurations
    experiment.configs(conf, overrides)

    # Set model for saving/loading
    models = {'model': conf.model}
    experiment.add_pytorch_models(models)

    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()

# Run the experiment only when this file is executed as a script
if __name__ == "__main__":
    main()