这是 PyTorch 对《量化分类不确定性的证据深度学习》一文的实现。
Dampster-Shafer 证据理论为信仰群众分配了一组类别(与将概率分配给单个类别不同)。所有子集的质量之和为。单个类别的概率(合理性)可以从这些质量中得出。
将@@质量分配给所有类的集合意味着它可以是任何一个类;即说 “我不知道”。
如果有类,我们会为每个类分配质量,为所有类分配总体不确定性质量。
信仰群众可以根据证据计算,随处可见。Paper 使用术语证据来衡量从数据中收集的支持量,以支持将样本分为特定类别。
这与带有参数的狄利克雷分布相对应,称为狄利克雷强度。狄利克雷分布是分类分布之上的分布;也就是说,你可以从狄利克雷分布中对类概率进行采样。上课的预期概率为。
我们得到模型来输出给定输入的证据。我们在最后一层使用诸如 RelU 或 Softplus 之类的函数来获取。
本文提出了一些损失函数来训练模型,我们在下面实现了这些函数。
以下是在 MNIST 数据集上训练模型的训练代码experiment.py
。
54import torch
55
56from labml import tracker
57from labml_helpers.module import Module60class MaximumLikelihoodLoss(Module):85    def forward(self, evidence: torch.Tensor, target: torch.Tensor):91        alpha = evidence + 1.93        strength = alpha.sum(dim=-1)亏损
96        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)整个批次的平均损失
99        return loss.mean()贝叶斯风险是做出错误估算的总体最大成本。它采用一个成本函数,该函数给出了做出错误估计的成本,并根据概率分布将其与所有可能的结果相加。
这里的代价函数是交叉熵损失,用于一次热编码
我们整合了这个成本
函数在哪里。
102class CrossEntropyBayesRisk(Module):132    def forward(self, evidence: torch.Tensor, target: torch.Tensor):138        alpha = evidence + 1.140        strength = alpha.sum(dim=-1)亏损
143        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)整个批次的平均损失
146        return loss.mean()149class SquaredErrorBayesRisk(Module):195    def forward(self, evidence: torch.Tensor, target: torch.Tensor):201        alpha = evidence + 1.203        strength = alpha.sum(dim=-1)205        p = alpha / strength[:, None]错误
208        err = (target - p) ** 2方差
210        var = p * (1 - p) / (strength[:, None] + 1)它们的总和
213        loss = (err + var).sum(dim=-1)整个批次的平均损失
216        return loss.mean()219class KLDivergenceLoss(Module):243    def forward(self, evidence: torch.Tensor, target: torch.Tensor):249        alpha = evidence + 1.班级数
251        n_classes = evidence.shape[-1]移除非误导性证据
254        alpha_tilde = target + (1 - target) * alpha256        strength_tilde = alpha_tilde.sum(dim=-1)267        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
268                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
269                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))第二学期
274        second = (
275                (alpha_tilde - 1) *
276                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
277        ).sum(dim=-1)条款总和
280        loss = first + second整个批次的平均损失
283        return loss.mean()294    def forward(self, evidence: torch.Tensor, target: torch.Tensor):班级数
296        n_classes = evidence.shape[-1]与目标正确匹配的预测(基于最高概率的贪婪抽样)
298        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))轨道精度
300        tracker.add('accuracy.', match.sum() / match.shape[0])303        alpha = evidence + 1.305        strength = alpha.sum(dim=-1)308        expected_probability = alpha / strength[:, None]所选(贪婪的最高概率)类别的预期概率
310        expected_probability, _ = expected_probability.max(dim=-1)不确定性质量
313        uncertainty_mass = n_classes / strength追踪正确的预测
316        tracker.add('u.succ.', uncertainty_mass.masked_select(match))追踪错误的预测
318        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))追踪正确的预测
320        tracker.add('prob.succ.', expected_probability.masked_select(match))追踪错误的预测
322        tracker.add('prob.fail.', expected_probability.masked_select(~match))