import torch.nn as nn

from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm


class Configs(CIFAR10Configs):
    """
    Experiment configurations.

    Reuses `CIFAR10Configs` unchanged; the `model` option of this class is
    supplied by `_small_model` below.
    """


class SmallModel(CIFAR10VGGModel):
    """
    Small VGG-style CIFAR-10 model built on `CIFAR10VGGModel`.

    Every convolution is followed by the labml_nn `BatchNorm` layer and a
    ReLU activation (see `conv_block`).
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Create a convolution layer followed by batch normalization and activation.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :return: an `nn.Sequential` of Conv2d -> BatchNorm -> ReLU
        """
        return nn.Sequential(
            # 3x3 convolution; padding=1 with the default stride of 1 keeps
            # the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Batch normalization over the output channels.
            # NOTE(review): track_running_stats=False presumably means batch
            # statistics are also used at evaluation time — confirm against
            # labml_nn's BatchNorm implementation.
            BatchNorm(out_channels, track_running_stats=False),
            # ReLU activation, applied in place
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        # Create the model with the given convolution sizes (channels).
        # How the inner lists are grouped is decided by CIFAR10VGGModel —
        # presumably one stage per inner list; confirm there.
        super().__init__([[32, 32], [64, 64], [128], [128], [128]])


@option(Configs.model)
def _small_model(c: Configs):
    """
    Create the model for the `model` configuration option.

    Builds a `SmallModel` and moves it to the configured device.
    """
    return SmallModel().to(c.device)


def main():
    """Create, configure and run the CIFAR-10 small-model experiment."""
    # Create the experiment
    experiment.create(name='cifar10', comment='small model')
    # Create the configurations
    conf = Configs()
    # Load configuration overrides
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    })
    # Register the model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Log the number of trainable parameters in the model
    n_params = sum(p.numel() for p in conf.model.parameters() if p.requires_grad)
    logger.inspect(params=n_params)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()