这个脚本在 CIFAR 10 数据集上训练一个 ConvMixer。
这不是试图重现论文的结果。本文使用了 PyTorch 图像模型(timm)中存在的图像增强功能进行训练。我们这样做并不是为了简单起见,这会导致我们的验证准确性下降。
20from labml import experiment
21from labml.configs import option
22from labml_nn.experiments.cifar10 import CIFAR10Configs25class Configs(CIFAR10Configs):补丁的大小,
34    patch_size: int = 2补丁嵌入中的通道数,
36    d_model: int = 256ConvMixer 层数或深度,
38    n_layers: int = 8深度卷积的内核大小,
40    kernel_size: int = 7任务中的类数
42    n_classes: int = 1045@option(Configs.model)
46def _conv_mixer(c: Configs):50    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings创建混音器
53    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
54                     PatchEmbeddings(c.d_model, c.patch_size, 3),
55                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)58def main():创建实验
60    experiment.create(name='ConvMixer', comment='cifar10')创建配置
62    conf = Configs()装载配置
64    experiment.configs(conf, {优化器
66        'optimizer.optimizer': 'Adam',
67        'optimizer.learning_rate': 2.5e-4,训练周期和批次大小
70        'epochs': 150,
71        'train_batch_size': 64,简单的图像增强
74        'train_dataset': 'cifar10_train_augmented',不要扩充图像以进行验证
76        'valid_dataset': 'cifar10_valid_no_augment',
77    })设置保存/加载的模型
79    experiment.add_pytorch_models({'model': conf.model})开始实验并运行训练循环
81    with experiment.start():
82        conf.run()86if __name__ == '__main__':
87    main()