用于量化分类不确定性的证据性深度学习

这是 PyTorch 对《量化分类不确定性的证据深度学习》一文的实现。

Dampster-Shafer 证据理论为信仰群众分配了一组类别(与将概率分配给单个类别不同)。所有子集的质量之和为。单个类别的概率(合理性)可以从这些质量中得出。

将@@

质量分配给所有类的集合意味着它可以是任何一个类;即说 “我不知道”。

如果有类,我们会为每个类分配质量,为所有类分配总体不确定性质量

信仰群众可以根据证据计算,随可见。Paper 使用术语证据来衡量从数据中收集的支持量,以支持将样本分为特定类别。

这与带有参数的狄利克雷分布相对应称为狄利克雷强度。狄利克雷分布是分类分布之上的分布;也就是说,你可以从狄利克雷分布中对类概率进行采样。上课的预期概率

我们得到模型来输出给定输入的证据。我们在最后一层使用诸如 RelUSoftplus 之类的函数来获取

本文提出了一些损失函数来训练模型,我们在下面实现了这些函数。

以下是在 MNIST 数据集上训练模型的训练代码experiment.py

View Run

54import torch
55
56from labml import tracker
57from labml_helpers.module import Module

类型 II 最大似然损失

分布是似然的先验,负对数边际似然是通过积分类概率来计算的

如果目标概率(一热目标)是针对给定样本的,则损失为,

60class MaximumLikelihoodLoss(Module):
  • evidence有形状的[batch_size, n_classes]
  • target有形状的[batch_size, n_classes]
  • 85    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    91        alpha = evidence + 1.

    93        strength = alpha.sum(dim=-1)

    亏损

    96        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)

    整个批次的平均损失

    99        return loss.mean()

    交叉熵损失的贝叶斯风险

    贝叶斯风险是做出错误估算的总体最大成本。它采用一个成本函数,该函数给出了做出错误估计的成本,并根据概率分布将其与所有可能的结果相加。

    这里的代价函数是交叉熵损失,用于一次热编码

    我们整合了这个成本

    函数在哪里。

    102class CrossEntropyBayesRisk(Module):
    • evidence有形状的[batch_size, n_classes]
  • target有形状的[batch_size, n_classes]
  • 132    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    138        alpha = evidence + 1.

    140        strength = alpha.sum(dim=-1)

    亏损

    143        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)

    整个批次的平均损失

    146        return loss.mean()

    误差损失平方时的贝叶斯风险

    这里的成本函数是平方误差,

    我们整合了这个成本

    从狄利克雷分布采样时的预期概率在哪里,方差哪里。

    这给了,

    方程的第一部分是误差项,第二部分是方差。

    149class SquaredErrorBayesRisk(Module):
    • evidence有形状的[batch_size, n_classes]
  • target有形状的[batch_size, n_classes]
  • 195    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    201        alpha = evidence + 1.

    203        strength = alpha.sum(dim=-1)

    205        p = alpha / strength[:, None]

    错误

    208        err = (target - p) ** 2

    方差

    210        var = p * (1 - p) / (strength[:, None] + 1)

    它们的总和

    213        loss = (err + var).sum(dim=-1)

    整个批次的平均损失

    216        return loss.mean()

    KL 背离正则化损失

    如果样本无法正确分类,这会试图将总证据缩小为零。

    首先,我们在移除正确的证据后计算狄利克雷参数。

    其中是 gamma 函数,函数和

    219class KLDivergenceLoss(Module):
    • evidence有形状的[batch_size, n_classes]
  • target有形状的[batch_size, n_classes]
  • 243    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    249        alpha = evidence + 1.

    班级数

    251        n_classes = evidence.shape[-1]

    移除非误导性证据

    254        alpha_tilde = target + (1 - target) * alpha

    256        strength_tilde = alpha_tilde.sum(dim=-1)

    第一学期

    267        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
    268                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
    269                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))

    第二学期

    274        second = (
    275                (alpha_tilde - 1) *
    276                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
    277        ).sum(dim=-1)

    条款总和

    280        loss = first + second

    整个批次的平均损失

    283        return loss.mean()

    追踪统计数据

    该模块计算统计数据并使用 labml 对其进行跟踪tracker

    286class TrackStatistics(Module):
    294    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    班级数

    296        n_classes = evidence.shape[-1]

    与目标正确匹配的预测(基于最高概率的贪婪抽样)

    298        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))

    轨道精度

    300        tracker.add('accuracy.', match.sum() / match.shape[0])

    303        alpha = evidence + 1.

    305        strength = alpha.sum(dim=-1)

    308        expected_probability = alpha / strength[:, None]

    所选(贪婪的最高概率)类别的预期概率

    310        expected_probability, _ = expected_probability.max(dim=-1)

    不确定性质量

    313        uncertainty_mass = n_classes / strength

    追踪正确的预测

    316        tracker.add('u.succ.', uncertainty_mass.masked_select(match))

    追踪错误的预测

    318        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))

    追踪正确的预测

    320        tracker.add('prob.succ.', expected_probability.masked_select(match))

    追踪错误的预测

    322        tracker.add('prob.fail.', expected_probability.masked_select(~match))