在 CIFAR 10 上训练 C onvMixer

这个脚本在 CIFAR 10 数据集上训练一个 ConvMixer。

这不是试图重现论文的结果。本文使用了 PyTorch 图像模型(timm)中存在的图像增强功能进行训练。我们这样做并不是为了简单起见,这会导致我们的验证准确性下降。

View Run

20from labml import experiment
21from labml.configs import option
22from labml_nn.experiments.cifar10 import CIFAR10Configs

配置

我们使用CIFAR10Configs 它来定义所有与数据集相关的配置、优化器和训练循环。

25class Configs(CIFAR10Configs):

补丁的大小,

34    patch_size: int = 2

补丁嵌入中的通道数,

36    d_model: int = 256

ConvMixer 层数或深度,

38    n_layers: int = 8

深度卷积的内核大小,

40    kernel_size: int = 7

任务中的类数

42    n_classes: int = 10

创建模型

45@option(Configs.model)
46def _conv_mixer(c: Configs):
50    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

创建混音器

53    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
54                     PatchEmbeddings(c.d_model, c.patch_size, 3),
55                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
58def main():

创建实验

60    experiment.create(name='ConvMixer', comment='cifar10')

创建配置

62    conf = Configs()

装载配置

64    experiment.configs(conf, {

优化器

66        'optimizer.optimizer': 'Adam',
67        'optimizer.learning_rate': 2.5e-4,

训练周期和批次大小

70        'epochs': 150,
71        'train_batch_size': 64,

简单的图像增强

74        'train_dataset': 'cifar10_train_augmented',

不要扩充图像以进行验证

76        'valid_dataset': 'cifar10_valid_no_augment',
77    })

设置保存/加载的模型

79    experiment.add_pytorch_models({'model': conf.model})

开始实验并运行训练循环

81    with experiment.start():
82        conf.run()

86if __name__ == '__main__':
87    main()