from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.transformers import TransformerConfigs

# We use `CIFAR10Configs`, which defines all the dataset-related configurations,
# the optimizer, and a training loop.
class Configs(CIFAR10Configs):
    """
    Configurations for the Vision Transformer (ViT) CIFAR-10 experiment.

    Extends `CIFAR10Configs` (dataset, optimizer and training-loop settings)
    with the transformer configurations and the patch / classification-head
    sizes that `_vit` uses to build the model.
    """

    # Transformer configurations; the model builder reads `d_model`,
    # `encoder_layer` and `n_layers` from these
    transformer: TransformerConfigs

    # Size of a patch
    patch_size: int = 4
    # Size of the hidden layer in the classification head
    n_hidden_classification: int = 2048
    # Number of classes in the task (10 for CIFAR-10)
    n_classes: int = 10


@option(Configs.transformer)
def _transformer():
    """
    Create transformer configs
    """
    return TransformerConfigs()


@option(Configs.model)
def _vit(c: Configs):
    """
    Create the Vision Transformer model.

    Builds a `VisionTransformer` from the transformer configurations (encoder
    layer, number of layers and `d_model`), patch embeddings, learned
    positional embeddings and a classification head, then moves it to
    `c.device`.

    :param c: the experiment configurations
    :return: the model, already on `c.device`
    """
    # Imported inside the function so the import only runs when this
    # function is called
    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
        PatchEmbeddings

    # Transformer size from Transformer configurations
    d_model = c.transformer.d_model
    # Create a vision transformer.
    # NOTE(review): the literal 3 passed to PatchEmbeddings is presumably the
    # number of input image channels (RGB) — confirm against PatchEmbeddings.
    return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
                             PatchEmbeddings(d_model, c.patch_size, 3),
                             LearnedPositionalEmbeddings(d_model),
                             ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)


def main():
    """Create, configure and run the ViT CIFAR-10 training experiment."""
    # Create experiment
    experiment.create(name='ViT', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        # Transformer embedding size
        'transformer.d_model': 512,

        # Training epochs and batch size
        'epochs': 32,
        'train_batch_size': 64,

        # Augment CIFAR 10 images for training
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment CIFAR 10 images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()