这个脚本在 CIFAR 10 数据集上训练一个 ConvMixer。
这不是试图重现论文的结果。本文使用了 PyTorch 图像模型(timm)中存在的图像增强功能进行训练。我们这样做并不是为了简单起见,这会导致我们的验证准确性下降。
20from labml import experiment
21from labml.configs import option
22from labml_nn.experiments.cifar10 import CIFAR10Configs25class Configs(CIFAR10Configs):补丁的大小,
34 patch_size: int = 2补丁嵌入中的通道数,
36 d_model: int = 256ConvMixer 层数或深度,
38 n_layers: int = 8深度卷积的内核大小,
40 kernel_size: int = 7任务中的类数
42 n_classes: int = 1045@option(Configs.model)
46def _conv_mixer(c: Configs):50 from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings创建混音器
53 return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
54 PatchEmbeddings(c.d_model, c.patch_size, 3),
55 ClassificationHead(c.d_model, c.n_classes)).to(c.device)58def main():创建实验
60 experiment.create(name='ConvMixer', comment='cifar10')创建配置
62 conf = Configs()装载配置
64 experiment.configs(conf, {优化器
66 'optimizer.optimizer': 'Adam',
67 'optimizer.learning_rate': 2.5e-4,训练周期和批次大小
70 'epochs': 150,
71 'train_batch_size': 64,简单的图像增强
74 'train_dataset': 'cifar10_train_augmented',不要扩充图像以进行验证
76 'valid_dataset': 'cifar10_valid_no_augment',
77 })设置保存/加载的模型
79 experiment.add_pytorch_models({'model': conf.model})开始实验并运行训练循环
81 with experiment.start():
82 conf.run()86if __name__ == '__main__':
87 main()