13from labml import experiment
14from labml.configs import option
15from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.transformers import TransformerConfigs

# We use `CIFAR10Configs`, which defines all the dataset-related
# configurations, the optimizer, and a training loop.
class Configs(CIFAR10Configs):
    """
    ## Configurations

    Extends `CIFAR10Configs`, adding the transformer configuration and the
    ViT-specific hyper-parameters (patch size, classification-head width,
    and number of output classes).
    """
    # Transformer configurations
    transformer: TransformerConfigs

    # Size of a patch
    patch_size: int = 4
    # Size of the hidden layer in classification head
    n_hidden_classification: int = 2048
    # Number of classes in the task
    n_classes: int = 10
@option(Configs.transformer)
def _transformer():
    """Create the transformer configurations with default values."""
    return TransformerConfigs()


@option(Configs.model)
def _vit(c: Configs):
    """Create the Vision Transformer model from the configurations."""
    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
        PatchEmbeddings

    # Transformer size, taken from the transformer configurations
    d_model = c.transformer.d_model
    # Create a vision transformer and move it to the configured device
    return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
                             PatchEmbeddings(d_model, c.patch_size, 3),
                             LearnedPositionalEmbeddings(d_model),
                             ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)


def main():
    """Create, configure, and run the ViT CIFAR-10 experiment."""
    # Create experiment
    experiment.create(name='ViT', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        # Transformer embedding size
        'transformer.d_model': 512,

        # Training epochs and batch size
        'epochs': 1000,
        'train_batch_size': 64,

        # Augment CIFAR 10 images for training
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment CIFAR 10 images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()