在神经网络中提炼知识

这是论文《在神经网络中提炼知识》的 PyTorch 实现/教程。

这是一种使用经过训练的大型网络中的知识来训练小型网络的方法;即从大型网络中提取知识。

当@@

直接在数据和标签上训练时,具有正则化或模型集合(使用 dropout)的大型模型比小型模型的概化效果更好。但是,在大型模型的帮助下,可以训练一个小模型以更好地进行概括。较小的模型在生产环境中会更好:更快、更少的计算、更少的内存。

训练模型的输出概率比标签提供的信息更多,因为它也将非零概率分配给不正确的类。这些概率告诉我们,样本有可能属于某些类别。例如,在对数字进行分类时,当给出数字 7 的图像时,广义模型将给出 7 的概率很高,为 2 提供一个小但非零的概率,同时为其他数字分配几乎为零的概率。蒸馏利用这些信息来更好地训练小型模型。

软目标

概率通常使用 softmax 运算进行计算,

其中是上课的概率是 logit。

我们对小模型进行训练,以最大限度地减少其输出概率分布与大型网络的输出概率分布(软目标)之间的交叉熵或 KL 差异。

这里的问题之一是,大型网络分配给错误类别的概率通常很小,不会造成损失。所以他们通过施加温度来软化概率

其中,的值越高,概率越低。

训练

论文建议在训练小模型时添加第二个损失项来预测实际标签。我们将综合损失计算为两个损失项的加权总和:软目标和实际标签。

蒸馏的数据集称为转移集,本文建议使用相同的训练数据。

我们的实验

我们在 CIFAR-10 数据集上进行训练。我们训练了一个大型模型,该模型参数带有 dropout,它在验证集上给出了 85% 的精度。具有参数的小型模型的精度为80%。

然后,我们使用大型模型的蒸馏来训练小模型,其精度为82%;精度提高了2%。

View Run

74import torch
75import torch.nn.functional
76from torch import nn
77
78from labml import experiment, tracker
79from labml.configs import option
80from labml_helpers.train_valid import BatchIndex
81from labml_nn.distillation.large import LargeModel
82from labml_nn.distillation.small import SmallModel
83from labml_nn.experiments.cifar10 import CIFAR10Configs

配置

从此扩展定义CIFAR10Configs 了所有与数据集相关的配置、优化器和训练循环。

86class Configs(CIFAR10Configs):

小模型

94    model: SmallModel

大型模型

96    large: LargeModel

软目标的 KL 分散损失

98    kl_div_loss = nn.KLDivLoss(log_target=True)

真实标签丢失的交叉熵损失

100    loss_func = nn.CrossEntropyLoss()

温度,

102    temperature: float = 5.

软目标损失的权重。

软目标产生的梯度会被缩放。为了弥补这一点,本文建议将软目标的损失缩小一倍

108    soft_targets_weight: float = 100.

真实标签交叉熵损失的权重

110    label_loss_weight: float = 0.5

培训/验证步骤

我们定义了一个定制的训练/验证步骤,包括蒸馏

112    def step(self, batch: any, batch_idx: BatchIndex):

小模型的训练/评估模式

120        self.model.train(self.mode.is_train)

评估模式中的大型模型

122        self.large.eval()

将数据移动到设备

125        data, target = batch[0].to(self.device), batch[1].to(self.device)

在训练模式下更新全局步长(处理的样本数)

128        if self.mode.is_train:
129            tracker.add_global_step(len(data))

从大型模型中获取输出 logit

132        with torch.no_grad():
133            large_logits = self.large(data)

从小型模型中获取输出 logits

136        output = self.model(data)

软目标

140        soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)

小模型的温度调整概率

143        soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)

计算软目标损失

146        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)

计算真实标签丢失

148        label_loss = self.loss_func(output, target)

两次亏损的加权总和

150        loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss

记录损失

152        tracker.add({"loss.kl_div.": soft_targets_loss,
153                     "loss.nll": label_loss,
154                     "loss.": loss})

计算和记录精度

157        self.accuracy(output, target)
158        self.accuracy.track()

训练模型

161        if self.mode.is_train:

计算梯度

163            loss.backward()

采取优化器步骤

165            self.optimizer.step()

记录每个纪元最后一批的模型参数和梯度

167            if batch_idx.is_last:
168                tracker.add('model', self.model)

清除渐变

170            self.optimizer.zero_grad()

保存跟踪的指标

173        tracker.save()

创建大型模型

176@option(Configs.large)
177def _large_model(c: Configs):
181    return LargeModel().to(c.device)

创建小模型

184@option(Configs.model)
185def _small_student_model(c: Configs):
189    return SmallModel().to(c.device)
192def get_saved_model(run_uuid: str, checkpoint: int):
197    from labml_nn.distillation.large import Configs as LargeConfigs

在评估模式下(无录音)

200    experiment.evaluate()

初始化大型模型训练实验的配置

202    conf = LargeConfigs()

加载保存的配置

204    experiment.configs(conf, experiment.load_configs(run_uuid))

设置用于保存/加载的模型

206    experiment.add_pytorch_models({'model': conf.model})

设置要加载的运行和检查点

208    experiment.load(run_uuid, checkpoint)

开始实验-这将加载模型,并准备所有内容

210    experiment.start()

返回模型

213    return conf.model

使用蒸馏训练小型模型

216def main(run_uuid: str, checkpoint: int):

加载已保存的模型

221    large_model = get_saved_model(run_uuid, checkpoint)

创建实验

223    experiment.create(name='distillation', comment='cifar10')

创建配置

225    conf = Configs()

设置加载的大型模型

227    conf.large = large_model

装载配置

229    experiment.configs(conf, {
230        'optimizer.optimizer': 'Adam',
231        'optimizer.learning_rate': 2.5e-4,
232        'model': '_small_student_model',
233    })

设置保存/加载的模型

235    experiment.add_pytorch_models({'model': conf.model})

从头开始实验

237    experiment.load(None, None)

开始实验并运行训练循环

239    with experiment.start():
240        conf.run()

244if __name__ == '__main__':
245    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)