This script trains a ConvMixer on the CIFAR-10 dataset.

This is not an attempt to reproduce the results of the paper. The paper trains with the image augmentations from PyTorch Image Models (timm); we skip those here for simplicity, which lowers our validation accuracy.
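As a rough illustration of what a stronger training pipeline could look like (this is an assumption for the sketch, not the exact timm recipe from the paper), one might use torchvision transforms such as the ones below. The normalization statistics are the commonly quoted CIFAR-10 values, and wiring such a transform into CIFAR10Configs as a new dataset option is not shown here.

import torchvision.transforms as T

# A rough stand-in for heavier timm-style training augmentations
# (an illustrative assumption, not the paper's exact recipe).
strong_train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.RandAugment(),                     # requires torchvision >= 0.11
    T.ToTensor(),
    # Commonly used (approximate) CIFAR-10 channel statistics
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    T.RandomErasing(p=0.25),
])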
from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs

We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and the training loop.
class Configs(CIFAR10Configs):
    # Size of a patch
    patch_size: int = 2
    # Number of channels in patch embeddings
    d_model: int = 256
    # Number of ConvMixer layers (depth)
    n_layers: int = 8
    # Kernel size of the depth-wise convolution
    kernel_size: int = 7
    # Number of classes in the task
    n_classes: int = 10
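With these defaults, a 32×32 CIFAR-10 image is split into a 16×16 grid of 2×2 patches, each embedded into 256 channels. Assuming the patch embedding is a convolution with stride equal to the patch size, as described in the ConvMixer paper, a minimal shape check looks like this (the Conv2d below is a hypothetical stand-in for PatchEmbeddings, not the labml_nn module itself):

import torch
import torch.nn as nn

# Hypothetical stand-in for PatchEmbeddings: a convolution with stride = patch size
patch_embed = nn.Conv2d(3, 256, kernel_size=2, stride=2)
x = torch.zeros(4, 3, 32, 32)        # a small batch of CIFAR-10 sized images
print(patch_embed(x).shape)          # torch.Size([4, 256, 16, 16]) -> 16x16 patch grid, 256 channels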
@option(Configs.model)
def _conv_mixer(c: Configs):
    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

    # Create ConvMixer
    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
                     PatchEmbeddings(c.d_model, c.patch_size, 3),
                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
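For reference, a single ConvMixer layer, as described in the paper, mixes spatial information with a depth-wise convolution wrapped in a residual connection, and then mixes channels with a point-wise 1×1 convolution. The block below is a plain-PyTorch paraphrase of that idea, a sketch only, not the labml_nn ConvMixerLayer itself:

import torch
import torch.nn as nn

class ConvMixerBlock(nn.Module):
    # A rough sketch of one ConvMixer layer (paper-style, assumed structure):
    # depth-wise conv (spatial mixing) with a residual, then 1x1 conv (channel mixing)
    def __init__(self, d_model: int, kernel_size: int):
        super().__init__()
        self.depth_wise = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size, groups=d_model, padding='same'),
            nn.GELU(),
            nn.BatchNorm2d(d_model),
        )
        self.point_wise = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.depth_wise(x)   # residual around the spatial (depth-wise) mixing
        return self.point_wise(x)    # channel mixing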
def main():
    # Create the experiment
    experiment.create(name='ConvMixer', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        # Training epochs and batch size
        'epochs': 150,
        'train_batch_size': 64,
        # Simple image augmentations for training
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
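For a quick sanity check before committing to the full 150 epochs, the same entry point can be reused with a smaller override dictionary. This is a minimal sketch using only the calls already shown above; the function name and the two-epoch setting are illustrative assumptions:

def quick_test():
    # Hypothetical short run to verify the pipeline end to end
    experiment.create(name='ConvMixer', comment='cifar10 smoke test')
    conf = Configs()
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'epochs': 2,               # just enough to confirm training runs
        'train_batch_size': 64,
    })
    experiment.add_pytorch_models({'model': conf.model})
    with experiment.start():
        conf.run()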