此脚本在 CIFAR 10 数据集上训练 ConvMixer。
这并不是试图重现论文的结果。本文使用 PyTorch 图像模型 (timm) 中存在的图像增强进行训练。为了简单起见,我们没有这样做——这会导致我们的验证精度下降。
18from labml import experiment
19from labml.configs import option
20from labml_nn.experiments.cifar10 import CIFAR10Configs23class Configs(CIFAR10Configs):补丁的大小,
32 patch_size: int = 2补丁嵌入中的通道数,
34 d_model: int = 256ConvMixer 层数或深度,
36 n_layers: int = 8深度卷积的内核大小,
38 kernel_size: int = 7任务中的类数
40 n_classes: int = 1043@option(Configs.model)
44def _conv_mixer(c: Configs):48 from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings创建混音器
51 return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
52 PatchEmbeddings(c.d_model, c.patch_size, 3),
53 ClassificationHead(c.d_model, c.n_classes)).to(c.device)56def main():创建实验
58 experiment.create(name='ConvMixer', comment='cifar10')创建配置
60 conf = Configs()装载配置
62 experiment.configs(conf, {优化器
64 'optimizer.optimizer': 'Adam',
65 'optimizer.learning_rate': 2.5e-4,训练周期和批次大小
68 'epochs': 150,
69 'train_batch_size': 64,简单的图像增强
72 'train_dataset': 'cifar10_train_augmented',不要扩充图像以进行验证
74 'valid_dataset': 'cifar10_valid_no_augment',
75 })设置保存/加载的模型
77 experiment.add_pytorch_models({'model': conf.model})开始实验并运行训练循环
79 with experiment.start():
80 conf.run()84if __name__ == '__main__':
85 main()