This is a PyTorch implementation of the paper Evidential Deep Learning to Quantify Classification Uncertainty.
Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (unlike assigning a probability to a single class). The sum of the masses of all subsets is $1$. Individual class probabilities (plausibilities) can be derived from these masses.
Assigning a mass to the set of all classes means it can be any one of the classes; i.e. saying "I don't know".
If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and an overall uncertainty mass $u \ge 0$ to all classes, such that $u + \sum_{k=1}^K b_k = 1$.
Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term evidence as a measure of the amount of support collected from data in favor of a sample being classified into a certain class.
This corresponds to a Dirichlet distribution with parameters $\alpha_k = e_k + 1$, and $S = \sum_{k=1}^K \alpha_k$ is known as the Dirichlet strength. A Dirichlet distribution $D(\mathbf{p} \vert \alpha)$ is a distribution over categorical distributions; i.e. you can sample class probabilities from a Dirichlet distribution. The expected probability for class $k$ is $\hat{p}_k = \frac{\alpha_k}{S}$.
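As a small worked example of these definitions, the sketch below computes belief masses $b_k = e_k / S$, the uncertainty mass $u = K / S$ and the expected probabilities $\alpha_k / S$ for a single sample in plain Python. The helper name `dirichlet_summary` and the evidence values are made up for illustration and are not part of the implementation.

```python
def dirichlet_summary(evidence):
    """Return (belief masses, uncertainty mass, expected probabilities)."""
    n_classes = len(evidence)
    # alpha_k = e_k + 1
    alpha = [e + 1.0 for e in evidence]
    # Dirichlet strength S = sum_k alpha_k
    strength = sum(alpha)
    # b_k = e_k / S
    belief = [e / strength for e in evidence]
    # u = K / S
    uncertainty = n_classes / strength
    # Expected probability alpha_k / S
    expected_p = [a / strength for a in alpha]
    return belief, uncertainty, expected_p


# No evidence at all: all the mass goes to "I don't know"
no_evidence = dirichlet_summary([0.0, 0.0, 0.0])
# Strong evidence for the first class: small uncertainty mass
strong_evidence = dirichlet_summary([27.0, 0.0, 0.0])
```

With zero evidence the uncertainty mass is $1$ and the expected probabilities are uniform; with evidence $27$ for one of three classes, $S = 30$ and the uncertainty mass drops to $0.1$.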
We get the model to output evidences $\mathbf{e} = f(\mathbf{x} \vert \Theta)$ for a given input $\mathbf{x}$. We use a function such as ReLU or Softplus at the final layer to get $f(\mathbf{x} \vert \Theta) \ge 0$.
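A minimal sketch of such an output layer, assuming a toy fully connected network: the layer sizes and the input batch below are arbitrary and are not the MNIST model used in the experiment. The only point is that a final Softplus keeps the evidence non-negative.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
n_features, n_classes = 4, 3

model = nn.Sequential(
    nn.Linear(n_features, 16),
    nn.ReLU(),
    nn.Linear(16, n_classes),
    # Softplus maps any real number to a non-negative value, so the output is valid evidence
    nn.Softplus(),
)

x = torch.randn(5, n_features)
# Evidence with shape [batch_size, n_classes]
evidence = model(x)
# Dirichlet parameters alpha_k = e_k + 1
alpha = evidence + 1.
```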
The paper proposes a few loss functions to train the model, which we have implemented below.
The training code in experiment.py trains a model on the MNIST dataset.
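    import torch

    from labml import tracker
    from labml_helpers.module import Module

The first loss is the Type II maximum likelihood loss. The distribution $D(\mathbf{p} \vert \alpha)$ is a prior on the likelihood $Multi(\mathbf{y} \vert \mathbf{p})$, and the negative log marginal likelihood is calculated by integrating over class probabilities $\mathbf{p}$.
If target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is,
\begin{align} \mathcal{L}(\Theta) &= -\log \Bigg( \int \prod_{k=1}^K p_k^{y_k} \frac{1}{B(\alpha)} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \Bigg) \\ &= \sum_{k=1}^K y_k \big( \log S - \log \alpha_k \big) \end{align}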
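To make the closed form concrete, here is a small plain-Python sketch of the per-sample loss $\sum_k y_k (\log S - \log \alpha_k)$ with $\alpha_k = e_k + 1$ and $S = \sum_k \alpha_k$. The function name `type2_ml_loss` and the evidence vectors are illustrative assumptions. For a one-hot target the loss equals $-\log(\alpha_k / S)$ for the true class $k$.

```python
import math


def type2_ml_loss(evidence, target):
    """Per-sample loss: sum_k y_k (log S - log alpha_k)."""
    alpha = [e + 1.0 for e in evidence]
    strength = sum(alpha)
    return sum(y * (math.log(strength) - math.log(a))
               for y, a in zip(target, alpha))


target = [1.0, 0.0, 0.0]
# No evidence: the loss is log(3)
loss_uninformed = type2_ml_loss([0.0, 0.0, 0.0], target)
# Evidence for the correct class: the loss is log(30 / 28)
loss_correct = type2_ml_loss([27.0, 0.0, 0.0], target)
# Evidence for a wrong class: the loss is log(30)
loss_wrong = type2_ml_loss([0.0, 27.0, 0.0], target)
```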
In MaximumLikelihoodLoss, evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes] and target is $\mathbf{y}$ with shape [batch_size, n_classes].

    class MaximumLikelihoodLoss(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # alpha_k = e_k + 1
            alpha = evidence + 1.
            # S = sum_k alpha_k
            strength = alpha.sum(dim=-1)
            # Losses: sum_k y_k (log S - log alpha_k)
            loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
            # Mean loss over the batch
            return loss.mean()

Bayes risk is the expected cost of making incorrect estimates. It takes a cost function that gives the cost of making an incorrect estimate and integrates it over all possible outcomes, weighted by their probability.
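Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$, $$\sum_{k=1}^K -y_k \log p_k$$
We integrate this cost over all $\mathbf{p}$,
\begin{align} \mathcal{L}(\Theta) &= \int \Big[ \sum_{k=1}^K -y_k \log p_k \Big] \frac{1}{B(\alpha)} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \\ &= \sum_{k=1}^K y_k \big( \psi(S) - \psi(\alpha_k) \big) \end{align}
where $\psi(\cdot)$ is the digamma function.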
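The closed form $\sum_k y_k (\psi(S) - \psi(\alpha_k))$, with $\psi$ the digamma function, can be sanity-checked against a Monte Carlo estimate of the expected cross-entropy under the Dirichlet distribution. This is only an illustrative check with made-up parameters $\alpha = (5, 2, 1)$; the sample count and seed are arbitrary.

```python
import torch

torch.manual_seed(0)

# Hypothetical Dirichlet parameters and a one-hot target for a single sample
alpha = torch.tensor([5., 2., 1.])
target = torch.tensor([1., 0., 0.])

# Closed form: sum_k y_k (digamma(S) - digamma(alpha_k))
strength = alpha.sum()
closed_form = (target * (torch.digamma(strength) - torch.digamma(alpha))).sum()

# Monte Carlo: sample class probabilities p ~ D(p | alpha) and average -log p_k
# of the true class (the only class with a non-zero target)
samples = torch.distributions.Dirichlet(alpha).sample((200_000,))
true_class = int(target.argmax())
monte_carlo = -samples[:, true_class].log().mean()
```

For integer parameters the closed form reduces to a difference of harmonic sums; here $\psi(8) - \psi(5) = \frac{1}{5} + \frac{1}{6} + \frac{1}{7}$, and the sampled estimate should agree up to Monte Carlo noise.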
In CrossEntropyBayesRisk, evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes] and target is $\mathbf{y}$ with shape [batch_size, n_classes].

    class CrossEntropyBayesRisk(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # alpha_k = e_k + 1
            alpha = evidence + 1.
            # S = sum_k alpha_k
            strength = alpha.sum(dim=-1)
            # Losses: sum_k y_k (digamma(S) - digamma(alpha_k))
            loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
            # Mean loss over the batch
            return loss.mean()

Here the cost function is squared error, $$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$
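We integrate this cost over all $\mathbf{p}$,
\begin{align} \mathcal{L}(\Theta) &= \int \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big] \frac{1}{B(\alpha)} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \\ &= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 - 2 y_k p_k + p_k^2 \Big] \\ &= \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big) \end{align}
where $\mathbb{E}[p_k] = \hat{p}_k = \frac{\alpha_k}{S}$ is the expected probability when sampled from the Dirichlet distribution, and $\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$, where $\text{Var}(p_k) = \frac{\alpha_k (S - \alpha_k)}{S^2 (S + 1)} = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$ is the variance.
This gives,
\begin{align} \mathcal{L}(\Theta) &= \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big) \\ &= \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\ &= \sum_{k=1}^K \Big( \big( y_k - \mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\ &= \sum_{k=1}^K \Big( ( y_k - \hat{p}_k )^2 + \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1} \Big) \end{align}
The first part of each summand, $(y_k - \hat{p}_k)^2$, is the error term and the second part is the variance.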
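A worked example of the final expression, $\sum_k \big( (y_k - \hat{p}_k)^2 + \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1} \big)$ with $\hat{p}_k = \alpha_k / S$, in plain Python. The helper name `squared_error_bayes_risk` and the evidence values are assumptions for illustration; the function returns the error and variance terms separately so each can be inspected.

```python
def squared_error_bayes_risk(evidence, target):
    """Return the (error, variance) terms of the per-sample loss."""
    # alpha_k = e_k + 1 and S = sum_k alpha_k
    alpha = [e + 1.0 for e in evidence]
    strength = sum(alpha)
    # Expected probabilities alpha_k / S
    p_hat = [a / strength for a in alpha]
    # Error term: sum_k (y_k - p_hat_k)^2
    err = sum((y - p) ** 2 for y, p in zip(target, p_hat))
    # Variance term: sum_k p_hat_k (1 - p_hat_k) / (S + 1)
    var = sum(p * (1.0 - p) / (strength + 1.0) for p in p_hat)
    return err, var


# alpha = (5, 2, 1), so S = 8 and p_hat = (0.625, 0.25, 0.125)
err, var = squared_error_bayes_risk([4.0, 1.0, 0.0], [1.0, 0.0, 0.0])
loss = err + var
```

For this input the error term is $7/32$ and the variance term is $17/288$, which sum to $5/18$.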
In SquaredErrorBayesRisk, evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes] and target is $\mathbf{y}$ with shape [batch_size, n_classes].

    class SquaredErrorBayesRisk(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # alpha_k = e_k + 1
            alpha = evidence + 1.
            # S = sum_k alpha_k
            strength = alpha.sum(dim=-1)
            # Expected probabilities alpha_k / S
            p = alpha / strength[:, None]
            # Error (y_k - p_k)^2
            err = (target - p) ** 2
            # Variance p_k (1 - p_k) / (S + 1)
            var = p * (1 - p) / (strength[:, None] + 1)
            # Sum of them
            loss = (err + var).sum(dim=-1)
            # Mean loss over the batch
            return loss.mean()

The KL divergence regularization loss tries to shrink the total evidence to zero if the sample cannot be correctly classified.
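First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$, the Dirichlet parameters after removing the correct evidence. The loss is the KL divergence between the Dirichlet distribution with these parameters and the uniform Dirichlet distribution,
\begin{align} &KL \Big[ D(\mathbf{p} \vert \tilde{\alpha}) \Big\Vert D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle) \Big] \\ &= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)} {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big] \end{align}
where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the digamma function and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.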
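The two terms of this KL divergence can be cross-checked against PyTorch's built-in Dirichlet KL divergence. The sketch below uses made-up evidence for two samples: the first has misleading evidence for a wrong class, the second has evidence only for the correct class, so after removing the correct evidence it reduces to the uniform Dirichlet and the penalty is zero.

```python
import torch
from torch.distributions import Dirichlet, kl_divergence

# Hypothetical evidence and one-hot targets for two samples
evidence = torch.tensor([[4., 1., 0.], [0., 0., 9.]])
target = torch.tensor([[1., 0., 0.], [0., 0., 1.]])

alpha = evidence + 1.
n_classes = evidence.shape[-1]

# Remove the evidence of the correct class: alpha_tilde = y + (1 - y) * alpha
alpha_tilde = target + (1 - target) * alpha
strength_tilde = alpha_tilde.sum(dim=-1)

# First term: log Gamma(S_tilde) - log Gamma(K) - sum_k log Gamma(alpha_tilde_k)
first = (torch.lgamma(strength_tilde)
         - torch.lgamma(torch.tensor(float(n_classes)))
         - torch.lgamma(alpha_tilde).sum(dim=-1))
# Second term: sum_k (alpha_tilde_k - 1) (digamma(alpha_tilde_k) - digamma(S_tilde))
second = ((alpha_tilde - 1)
          * (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
          ).sum(dim=-1)
kl = first + second

# Reference value from torch.distributions against the uniform Dirichlet
reference = kl_divergence(Dirichlet(alpha_tilde),
                          Dirichlet(torch.ones_like(alpha_tilde)))
```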
In KLDivergenceLoss, evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes] and target is $\mathbf{y}$ with shape [batch_size, n_classes].
The first term computed in the code is
\begin{align} &\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)} {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\ &= \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big) - \log \Gamma(K) - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k) \end{align}

    class KLDivergenceLoss(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # alpha_k = e_k + 1
            alpha = evidence + 1.
            # Number of classes
            n_classes = evidence.shape[-1]
            # Remove non-misleading evidence: alpha_tilde_k = y_k + (1 - y_k) alpha_k
            alpha_tilde = target + (1 - target) * alpha
            # S_tilde = sum_k alpha_tilde_k
            strength_tilde = alpha_tilde.sum(dim=-1)
            # The first term
            first = (torch.lgamma(alpha_tilde.sum(dim=-1))
                     - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                     - (torch.lgamma(alpha_tilde)).sum(dim=-1))
            # The second term: sum_k (alpha_tilde_k - 1) (digamma(alpha_tilde_k) - digamma(S_tilde))
            second = (
                    (alpha_tilde - 1) *
                    (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
            ).sum(dim=-1)
            # Sum of the terms
            loss = first + second
            # Mean loss over the batch
            return loss.mean()

    class TrackStatistics(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # Number of classes
            n_classes = evidence.shape[-1]
            # Predictions that match the target (greedy selection of the class with the highest evidence)
            match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
            # Track accuracy
            tracker.add('accuracy.', match.sum() / match.shape[0])
            # alpha_k = e_k + 1
            alpha = evidence + 1.
            # S = sum_k alpha_k
            strength = alpha.sum(dim=-1)
            # Expected probabilities alpha_k / S
            expected_probability = alpha / strength[:, None]
            # Expected probability of the selected (greedy, highest probability) class
            expected_probability, _ = expected_probability.max(dim=-1)
            # Uncertainty mass u = K / S
            uncertainty_mass = n_classes / strength
            # Track u for correct predictions
            tracker.add('u.succ.', uncertainty_mass.masked_select(match))
            # Track u for incorrect predictions
            tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
            # Track the expected probability for correct predictions
            tracker.add('prob.succ.', expected_probability.masked_select(match))
            # Track the expected probability for incorrect predictions
            tracker.add('prob.fail.', expected_probability.masked_select(~match))
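To see what these statistics look like without the tracker, the sketch below recomputes the match mask and the uncertainty mass $K / S$ (with $S$ the sum of evidence plus one per class) on a tiny hand-made batch and splits the uncertainty by correct and incorrect predictions. The evidence and targets are invented for illustration; a well-behaved model would be expected to show higher uncertainty on its failures, but nothing here is a measured result.

```python
import torch

# Hypothetical batch: two confident correct predictions and one weak wrong one
evidence = torch.tensor([[27., 0., 0.],
                         [0.5, 1.0, 0.2],
                         [0., 9., 0.]])
target = torch.tensor([[1., 0., 0.],
                       [1., 0., 0.],
                       [0., 1., 0.]])

n_classes = evidence.shape[-1]
# Which predictions match the target
match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
accuracy = match.float().mean()

# Uncertainty mass u = K / S with S = sum_k (e_k + 1)
strength = (evidence + 1.).sum(dim=-1)
uncertainty_mass = n_classes / strength

# Split the uncertainty by correct and incorrect predictions
u_succ = uncertainty_mass.masked_select(match)
u_fail = uncertainty_mass.masked_select(~match)
```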