这是论文《在神经网络中提炼知识》的 PyTorch 实现/教程。
这是一种使用经过训练的大型网络中的知识来训练小型网络的方法;即从大型网络中提取知识。
当@@直接在数据和标签上训练时,具有正则化或模型集合(使用 dropout)的大型模型比小型模型的概化效果更好。但是,在大型模型的帮助下,可以训练一个小模型以更好地进行概括。较小的模型在生产环境中会更好:更快、更少的计算、更少的内存。
训练模型的输出概率比标签提供的信息更多,因为它也将非零概率分配给不正确的类。这些概率告诉我们,样本有可能属于某些类别。例如,在对数字进行分类时,当给出数字 7 的图像时,广义模型将给出 7 的概率很高,为 2 提供一个小但非零的概率,同时为其他数字分配几乎为零的概率。蒸馏利用这些信息来更好地训练小型模型。
概率通常使用 softmax 运算进行计算,
其中是上课的概率,是 logit。
我们对小模型进行训练,以最大限度地减少其输出概率分布与大型网络的输出概率分布(软目标)之间的交叉熵或 KL 差异。
这里的问题之一是,大型网络分配给错误类别的概率通常很小,不会造成损失。所以他们通过施加温度来软化概率,
其中,的值越高,概率越低。
论文建议在训练小模型时添加第二个损失项来预测实际标签。我们将综合损失计算为两个损失项的加权总和:软目标和实际标签。
蒸馏的数据集称为转移集,本文建议使用相同的训练数据。
我们在 CIFAR-10 数据集上进行训练。我们训练了一个大型模型,该模型的参数带有 dropout,它在验证集上给出了 85% 的精度。具有参数的小型模型的精度为80%。
然后,我们使用大型模型的蒸馏来训练小模型,其精度为82%;精度提高了2%。
74import torch
75import torch.nn.functional
76from torch import nn
77
78from labml import experiment, tracker
79from labml.configs import option
80from labml_helpers.train_valid import BatchIndex
81from labml_nn.distillation.large import LargeModel
82from labml_nn.distillation.small import SmallModel
83from labml_nn.experiments.cifar10 import CIFAR10Configs86class Configs(CIFAR10Configs):小模型
94 model: SmallModel大型模型
96 large: LargeModel软目标的 KL 分散损失
98 kl_div_loss = nn.KLDivLoss(log_target=True)真实标签丢失的交叉熵损失
100 loss_func = nn.CrossEntropyLoss()温度,
102 temperature: float = 5.108 soft_targets_weight: float = 100.真实标签交叉熵损失的权重
110 label_loss_weight: float = 0.5112 def step(self, batch: any, batch_idx: BatchIndex):小模型的训练/评估模式
120 self.model.train(self.mode.is_train)评估模式中的大型模型
122 self.large.eval()将数据移动到设备
125 data, target = batch[0].to(self.device), batch[1].to(self.device)在训练模式下更新全局步长(处理的样本数)
128 if self.mode.is_train:
129 tracker.add_global_step(len(data))从大型模型中获取输出 logit、
132 with torch.no_grad():
133 large_logits = self.large(data)从小型模型中获取输出 logits、
136 output = self.model(data)软目标
140 soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)小模型的温度调整概率
143 soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)计算软目标损失
146 soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)计算真实标签丢失
148 label_loss = self.loss_func(output, target)两次亏损的加权总和
150 loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss记录损失
152 tracker.add({"loss.kl_div.": soft_targets_loss,
153 "loss.nll": label_loss,
154 "loss.": loss})计算和记录精度
157 self.accuracy(output, target)
158 self.accuracy.track()训练模型
161 if self.mode.is_train:计算梯度
163 loss.backward()采取优化器步骤
165 self.optimizer.step()记录每个纪元最后一批的模型参数和梯度
167 if batch_idx.is_last:
168 tracker.add('model', self.model)清除渐变
170 self.optimizer.zero_grad()保存跟踪的指标
173 tracker.save()176@option(Configs.large)
177def _large_model(c: Configs):181 return LargeModel().to(c.device)184@option(Configs.model)
185def _small_student_model(c: Configs):189 return SmallModel().to(c.device)197 from labml_nn.distillation.large import Configs as LargeConfigs在评估模式下(无录音)
200 experiment.evaluate()初始化大型模型训练实验的配置
202 conf = LargeConfigs()加载保存的配置
204 experiment.configs(conf, experiment.load_configs(run_uuid))设置用于保存/加载的模型
206 experiment.add_pytorch_models({'model': conf.model})设置要加载的运行和检查点
208 experiment.load(run_uuid, checkpoint)开始实验-这将加载模型,并准备所有内容
210 experiment.start()返回模型
213 return conf.model使用蒸馏训练小型模型
216def main(run_uuid: str, checkpoint: int):加载已保存的模型
221 large_model = get_saved_model(run_uuid, checkpoint)创建实验
223 experiment.create(name='distillation', comment='cifar10')创建配置
225 conf = Configs()设置加载的大型模型
227 conf.large = large_model装载配置
229 experiment.configs(conf, {
230 'optimizer.optimizer': 'Adam',
231 'optimizer.learning_rate': 2.5e-4,
232 'model': '_small_student_model',
233 })设置保存/加载的模型
235 experiment.add_pytorch_models({'model': conf.model})从头开始实验
237 experiment.load(None, None)开始实验并运行训练循环
239 with experiment.start():
240 conf.run()244if __name__ == '__main__':
245 main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)