This is a PyTorch implementation of the paper Evidential Deep Learning to Quantify Classification Uncertainty.
Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (unlike assigning a probability to a single class). The masses of all subsets sum to $1$. Individual class probabilities (plausibilities) can be derived from these masses.
Assigning a mass to the set of all classes means the sample could belong to any one of them; it amounts to saying “I don’t know”.
If there are $K$ classes, we assign a mass $b_k \ge 0$ to each class and an overall uncertainty mass $u \ge 0$ to the set of all classes, such that $u + \sum_{k=1}^K b_k = 1$.
Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term evidence for the amount of support collected from data in favor of classifying a sample into a certain class.
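As a minimal sketch of the formulas above, the following computes $b_k$ and $u$ from a batch of evidence vectors. The helper name `belief_and_uncertainty` and the example values are illustrative assumptions, not part of the implementation on this page.

```python
import torch


def belief_and_uncertainty(evidence: torch.Tensor):
    """Hypothetical helper: returns b_k = e_k / S and u = K / S per sample."""
    n_classes = evidence.shape[-1]
    # S = sum_k (e_k + 1)
    strength = (evidence + 1.).sum(dim=-1, keepdim=True)
    belief = evidence / strength
    uncertainty = n_classes / strength.squeeze(-1)
    return belief, uncertainty


# First sample has no evidence at all, second has evidence for class 0 only
evidence = torch.tensor([[0., 0., 0.],
                         [9., 0., 0.]])
belief, uncertainty = belief_and_uncertainty(evidence)
# With zero evidence all the mass goes to u; the masses always sum to 1
total_mass = belief.sum(dim=-1) + uncertainty
```

With no evidence, $S = K$ and therefore $u = 1$; as evidence accumulates, $u$ shrinks toward zero.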
This corresponds to a Dirichlet distribution with parameters $\color{orange}{\alpha_k} = e_k + 1$, where $\color{orange}{\alpha_0} = S = \sum_{k=1}^K \color{orange}{\alpha_k}$ is known as the Dirichlet strength. The Dirichlet distribution $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$ is a distribution over categorical distributions; i.e. you can sample class probabilities from it. The expected probability for class $k$ is $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$.
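To make the idea of sampling class probabilities concrete, here is a small sketch using `torch.distributions.Dirichlet`. The evidence values are made up for illustration.

```python
import torch
from torch.distributions import Dirichlet

torch.manual_seed(0)

# Illustrative evidence for a 3-class problem
evidence = torch.tensor([4., 1., 0.])
# alpha_k = e_k + 1
alpha = evidence + 1.
# Dirichlet strength S
strength = alpha.sum()
# Expected probability p_hat_k = alpha_k / S
expected_p = alpha / strength

dirichlet = Dirichlet(alpha)
# Each sample is itself a categorical distribution over the 3 classes
samples = dirichlet.sample((5,))
```

Each row of `samples` is a probability vector, and the mean of the distribution agrees with $\hat{p}_k$.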
We get the model to output evidence for a given input $\mathbf{x}$. We use a function such as ReLU or Softplus at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.
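A possible final layer could look like the sketch below. The class name `EvidenceHead` and the layer sizes are assumptions for illustration; the actual model used for training lives in experiment.py.

```python
import torch
from torch import nn


class EvidenceHead(nn.Module):
    """Hypothetical output head that maps features to non-negative evidence."""

    def __init__(self, n_features: int, n_classes: int):
        super().__init__()
        self.linear = nn.Linear(n_features, n_classes)
        # Softplus keeps the output positive; nn.ReLU() would also satisfy e >= 0
        self.activation = nn.Softplus()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(self.linear(x))


head = EvidenceHead(n_features=16, n_classes=10)
# Evidence with shape [batch_size, n_classes]
evidence = head(torch.randn(4, 16))
```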
The paper proposes a few loss functions to train the model, which we have implemented below.
The training code in experiment.py trains a model on the MNIST dataset.
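
    import torch

    from labml import tracker
    from labml_helpers.module import Module

For the Type II maximum likelihood loss, the distribution $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$ is a prior on the likelihood $Multi(\mathbf{y} \vert \mathbf{p})$, and the negative log marginal likelihood is calculated by integrating over class probabilities $\mathbf{p}$.

If target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is,

$$\mathcal{L}(\Theta) = -\log \Bigg( \int \prod_{k=1}^K p_k^{y_k} D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}}) d\mathbf{p} \Bigg) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)$$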
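As a sanity check on the closed form $\sum_k y_k (\log S - \log \alpha_k)$, the sketch below compares it against a Monte Carlo estimate of the negative log marginal likelihood. The Dirichlet parameters, the sample count, and the tolerance are arbitrary choices for illustration.

```python
import torch
from torch.distributions import Dirichlet

torch.manual_seed(0)

# Illustrative Dirichlet parameters and a one-hot target for class 0
alpha = torch.tensor([5., 2., 1.])
target = torch.tensor([1., 0., 0.])
strength = alpha.sum()

# Closed form: sum_k y_k (log S - log alpha_k)
closed_form = (target * (strength.log() - alpha.log())).sum()

# Monte Carlo: -log E_p[ prod_k p_k^{y_k} ] with p ~ D(p | alpha)
samples = Dirichlet(alpha).sample((200_000,))
likelihood = (samples ** target).prod(dim=-1)
monte_carlo = -likelihood.mean().log()
```

For a one-hot target the loss reduces to $-\log \hat{p}_c$ for the true class $c$, here $\log \frac{8}{5}$.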

    class MaximumLikelihoodLoss(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
            # target is $\mathbf{y}$ with shape [batch_size, n_classes]

            # $\color{orange}{\alpha_k} = e_k + 1$
            alpha = evidence + 1.
            # $S = \sum_{k=1}^K \color{orange}{\alpha_k}$
            strength = alpha.sum(dim=-1)

            # Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)$
            loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)

            # Mean loss over the batch
            return loss.mean()

Bayes risk is the expected cost of making incorrect estimates. It takes a cost function that gives the cost of making an incorrect estimate and averages it over all possible outcomes, weighted by the probability distribution.

Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$,

$$\sum_{k=1}^K -y_k \log p_k$$

We integrate this cost over all $\mathbf{p}$,

$$\mathcal{L}(\Theta) = \int \Bigg[ \sum_{k=1}^K -y_k \log p_k \Bigg] D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}}) d\mathbf{p} = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)$$

where $\psi(\cdot)$ is the $digamma$ function.
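The digamma expression can be checked numerically in the same way. This is only a sketch with made-up parameters: it averages the cross-entropy cost over Dirichlet samples and compares the result with $\psi(S) - \psi(\alpha_c)$ for the true class $c$.

```python
import torch
from torch.distributions import Dirichlet

torch.manual_seed(0)

# Illustrative Dirichlet parameters and a one-hot target for class 0
alpha = torch.tensor([5., 2., 1.])
target = torch.tensor([1., 0., 0.])
strength = alpha.sum()

# Closed form: sum_k y_k (psi(S) - psi(alpha_k))
closed_form = (target * (torch.digamma(strength) - torch.digamma(alpha))).sum()

# Monte Carlo: E_p[-log p_c] where c is the true class and p ~ D(p | alpha)
samples = Dirichlet(alpha).sample((200_000,))
p_true = samples[:, target.argmax()]
monte_carlo = -p_true.log().mean()
```

Because $\psi(x + 1) = \psi(x) + \frac{1}{x}$, the closed form here equals $\frac{1}{5} + \frac{1}{6} + \frac{1}{7}$.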
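
    class CrossEntropyBayesRisk(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
            # target is $\mathbf{y}$ with shape [batch_size, n_classes]

            # $\color{orange}{\alpha_k} = e_k + 1$
            alpha = evidence + 1.
            # $S = \sum_{k=1}^K \color{orange}{\alpha_k}$
            strength = alpha.sum(dim=-1)

            # Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)$
            loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)

            # Mean loss over the batch
            return loss.mean()

Here the cost function is squared error,

$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$

We integrate this cost over all $\mathbf{p}$,

$$\mathcal{L}(\Theta) = \int \Bigg[ \sum_{k=1}^K (y_k - p_k)^2 \Bigg] D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}}) d\mathbf{p} = \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)$$

where $\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$ is the expected probability when sampled from the Dirichlet distribution, and $\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$, where $\text{Var}(p_k) = \frac{\color{orange}{\alpha_k}(S - \color{orange}{\alpha_k})}{S^2 (S + 1)} = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$ is the variance.

This gives,

$$\mathcal{L}(\Theta) = \sum_{k=1}^K \Big( \big(y_k -\mathbb{E}[p_k]\big)^2 + \text{Var}(p_k) \Big) = \sum_{k=1}^K \Bigg( (y_k -\hat{p}_k)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Bigg)$$

The first part of the equation, $\big(y_k -\mathbb{E}[p_k]\big)^2$, is the error term and the second part is the variance.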
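The decomposition into an error term and a variance term can be verified against the `mean` and `variance` properties of `torch.distributions.Dirichlet`. The parameter values below are made up for illustration.

```python
import torch
from torch.distributions import Dirichlet

# Illustrative batch with a single sample; target is class 0
alpha = torch.tensor([[5., 2., 1.]])
target = torch.tensor([[1., 0., 0.]])
strength = alpha.sum(dim=-1, keepdim=True)

# p_hat_k = alpha_k / S
p_hat = alpha / strength
# Error term (y_k - p_hat_k)^2
err = (target - p_hat) ** 2
# Variance term p_hat_k (1 - p_hat_k) / (S + 1)
var = p_hat * (1 - p_hat) / (strength + 1)
loss = (err + var).sum(dim=-1)

# Reference: E[(y_k - p_k)^2] = y_k^2 - 2 y_k E[p_k] + E[p_k]^2 + Var(p_k)
dirichlet = Dirichlet(alpha)
expected_sq = (target ** 2
               - 2 * target * dirichlet.mean
               + dirichlet.mean ** 2
               + dirichlet.variance)
reference = expected_sq.sum(dim=-1)
```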

    class SquaredErrorBayesRisk(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
            # target is $\mathbf{y}$ with shape [batch_size, n_classes]

            # $\color{orange}{\alpha_k} = e_k + 1$
            alpha = evidence + 1.
            # $S = \sum_{k=1}^K \color{orange}{\alpha_k}$
            strength = alpha.sum(dim=-1)
            # $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$
            p = alpha / strength[:, None]

            # Error $(y_k -\hat{p}_k)^2$
            err = (target - p) ** 2
            # Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$
            var = p * (1 - p) / (strength[:, None] + 1)

            # Sum of them
            loss = (err + var).sum(dim=-1)

            # Mean loss over the batch
            return loss.mean()

The KL divergence regularization term tries to shrink the total evidence to zero if the sample cannot be correctly classified.

First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{orange}{\alpha_k}$, the Dirichlet parameters after removing the correct evidence. The loss is the KL divergence between the Dirichlet distribution with these parameters and the uniform Dirichlet distribution,

$$KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big\Vert D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle) \Big] = \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$$

where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the $digamma$ function and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.
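One way to check the two-term expression is to compare it with `torch.distributions.kl_divergence` applied to two Dirichlet distributions. The sketch below uses made-up parameters; the second sample has no misleading evidence, so its $\tilde{\alpha}$ is all ones and the divergence should be zero.

```python
import torch
from torch.distributions import Dirichlet, kl_divergence

# Illustrative batch: both samples belong to class 0
alpha = torch.tensor([[5., 2., 1.],
                      [5., 1., 1.]])
target = torch.tensor([[1., 0., 0.],
                       [1., 0., 0.]])
n_classes = alpha.shape[-1]

# Remove the evidence for the correct class: alpha_tilde = y + (1 - y) * alpha
alpha_tilde = target + (1 - target) * alpha
strength_tilde = alpha_tilde.sum(dim=-1)

# First term: log Gamma(S_tilde) - log Gamma(K) - sum_k log Gamma(alpha_tilde_k)
first = (torch.lgamma(strength_tilde)
         - torch.lgamma(torch.tensor(float(n_classes)))
         - torch.lgamma(alpha_tilde).sum(dim=-1))
# Second term: sum_k (alpha_tilde_k - 1) (psi(alpha_tilde_k) - psi(S_tilde))
second = ((alpha_tilde - 1)
          * (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
          ).sum(dim=-1)
closed_form = first + second

# Reference value from the library's registered Dirichlet-Dirichlet KL
uniform = Dirichlet(torch.ones_like(alpha_tilde))
reference = kl_divergence(Dirichlet(alpha_tilde), uniform)
```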

    class KLDivergenceLoss(Module):
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
            # target is $\mathbf{y}$ with shape [batch_size, n_classes]

            # $\color{orange}{\alpha_k} = e_k + 1$
            alpha = evidence + 1.
            # Number of classes
            n_classes = evidence.shape[-1]
            # Remove non-misleading evidence
            alpha_tilde = target + (1 - target) * alpha
            # $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
            strength_tilde = alpha_tilde.sum(dim=-1)

            # The first term
            first = (torch.lgamma(strength_tilde)
                     - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                     - (torch.lgamma(alpha_tilde)).sum(dim=-1))

            # The second term
            second = (
                    (alpha_tilde - 1) *
                    (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
            ).sum(dim=-1)

            # Sum of the terms
            loss = first + second

            # Mean loss over the batch
            return loss.mean()


    class TrackStatistics(Module):
        # Tracks accuracy, uncertainty mass and expected probability,
        # separately for correct and incorrect predictions
        def forward(self, evidence: torch.Tensor, target: torch.Tensor):
            # Number of classes
            n_classes = evidence.shape[-1]
            # Predictions that correctly match the target (greedy sampling based on highest probability)
            match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
            # Track accuracy
            tracker.add('accuracy.', match.sum() / match.shape[0])

            # $\color{orange}{\alpha_k} = e_k + 1$
            alpha = evidence + 1.
            # $S = \sum_{k=1}^K \color{orange}{\alpha_k}$
            strength = alpha.sum(dim=-1)

            # $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$
            expected_probability = alpha / strength[:, None]
            # Expected probability of the selected (greedy highest probability) class
            expected_probability, _ = expected_probability.max(dim=-1)

            # Uncertainty mass $u = \frac{K}{S}$
            uncertainty_mass = n_classes / strength

            # Track $u$ for correct predictions
            tracker.add('u.succ.', uncertainty_mass.masked_select(match))
            # Track $u$ for incorrect predictions
            tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
            # Track $\hat{p}_k$ for correct predictions
            tracker.add('prob.succ.', expected_probability.masked_select(match))
            # Track $\hat{p}_k$ for incorrect predictions
            tracker.add('prob.fail.', expected_probability.masked_select(~match))
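
To inspect the same quantities without the labml tracker, a standalone sketch could return them as a dictionary. The function name `uncertainty_statistics`, the dictionary keys, and the example tensors are assumptions made for illustration.

```python
import torch


def uncertainty_statistics(evidence: torch.Tensor, target: torch.Tensor) -> dict:
    """Hypothetical helper: accuracy and mean uncertainty mass for hits and misses."""
    n_classes = evidence.shape[-1]
    # Greedy prediction compared against the one-hot target
    match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
    # S = sum_k (e_k + 1) and u = K / S
    strength = (evidence + 1.).sum(dim=-1)
    uncertainty_mass = n_classes / strength
    # Note: the mean over an empty selection is NaN,
    # e.g. when every prediction in the batch is correct
    return {
        'accuracy': match.float().mean().item(),
        'u_succ': uncertainty_mass[match].mean().item(),
        'u_fail': uncertainty_mass[~match].mean().item(),
    }


# First sample: strong evidence for the correct class 0
# Second sample: weak evidence for class 1 while the target is class 2
evidence = torch.tensor([[9., 0., 0.],
                         [0., 1., 0.]])
target = torch.tensor([[1., 0., 0.],
                       [0., 0., 1.]])
stats = uncertainty_statistics(evidence, target)
```

In this toy batch the misclassified sample carries less total evidence, so its uncertainty mass is higher than that of the correctly classified one.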