Files
2023-04-02 12:10:18 +05:30

86 lines
2.4 KiB
Python

"""
---
title: Train ConvMixer on CIFAR 10
summary: >
  Train ConvMixer on CIFAR 10
---
# Train a [ConvMixer](index.html) on CIFAR 10
This script trains a ConvMixer on the CIFAR 10 dataset.
This is not an attempt to reproduce the results of the paper.
The paper trains with the image augmentations
from [PyTorch Image Models (timm)](https://github.com/rwightman/pytorch-image-models).
For simplicity we use only simple augmentations instead, which lowers our validation accuracy.
"""
from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs
class Configs(CIFAR10Configs):
    """
    ## Configurations

    Builds on [`CIFAR10Configs`](../experiments/cifar10.html), which already supplies
    the dataset-related configurations, the optimizer and the training loop.
    Only the ConvMixer model hyper-parameters are declared here.
    """

    # Side length of each image patch, $p$
    patch_size: int = 2

    # Channel count of the patch embeddings, $h$
    d_model: int = 256

    # Depth $d$, i.e. how many [ConvMixer layers](#ConvMixerLayer) are stacked
    n_layers: int = 8

    # Kernel size $k$ of the depth-wise convolution
    kernel_size: int = 7

    # How many classes the classifier predicts
    n_classes: int = 10
@option(Configs.model)
def _conv_mixer(c: Configs):
    """
    ### Create model

    Assembles a ConvMixer from its components, sized by the configurations `c`,
    and moves it to `c.device`.
    """
    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

    # Components are built in the same order as before so that parameter
    # initialisation draws from the random number generator identically.
    mixer_layer = ConvMixerLayer(c.d_model, c.kernel_size)
    depth = c.n_layers
    # The literal `3` is presumably the number of input image channels (RGB) — confirm
    # against `PatchEmbeddings`' signature.
    patch_embeddings = PatchEmbeddings(c.d_model, c.patch_size, 3)
    head = ClassificationHead(c.d_model, c.n_classes)

    # Combine the pieces and place the model on the configured device
    model = ConvMixer(mixer_layer, depth, patch_embeddings, head)
    return model.to(c.device)
def main():
    """
    Set up the ConvMixer CIFAR 10 experiment, apply the configuration
    overrides, and run the training loop.
    """
    # Values that override the defaults of `Configs`
    overrides = {
        # Adam optimizer with learning rate $2.5 \cdot 10^{-4}$
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        # Number of training epochs and the training batch size
        'epochs': 150,
        'train_batch_size': 64,
        # Train on simply-augmented images; validate on un-augmented ones
        'train_dataset': 'cifar10_train_augmented',
        'valid_dataset': 'cifar10_valid_no_augment',
    }

    experiment.create(name='ConvMixer', comment='cifar10')
    conf = Configs()
    experiment.configs(conf, overrides)

    # Register the model with the experiment for saving/loading
    models_to_track = {'model': conf.model}
    experiment.add_pytorch_models(models_to_track)

    # Start the experiment and hand control to the training loop
    with experiment.start():
        conf.run()
# Run the experiment only when this file is executed as a script, not when imported
if __name__ == '__main__':
    main()