mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 17:57:14 +08:00
Evidential Deep Learning to Quantify Classification Uncertainty (#85)
This commit is contained in:
@ -94,6 +94,10 @@ Solving games with incomplete information such as poker with CFR.
|
||||
|
||||
* [PonderNet](adaptive_computation/ponder_net/index.html)
|
||||
|
||||
#### ✨ [Uncertainty](uncertainty/index.html)
|
||||
|
||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
@ -102,12 +106,12 @@ pip install labml-nn
|
||||
|
||||
### Citing LabML
|
||||
|
||||
If you use LabML for academic research, please cite the library using the following BibTeX entry.
|
||||
If you use this for academic research, please cite it using the following BibTeX entry.
|
||||
|
||||
```bibtex
|
||||
@misc{labml,
|
||||
author = {Varuna Jayasiri, Nipun Wijerathne},
|
||||
title = {LabML: A library to organize machine learning experiments},
|
||||
title = {labml.ai Annotated Paper Implementations},
|
||||
year = {2020},
|
||||
url = {https://nn.labml.ai/},
|
||||
}
|
||||
|
||||
@ -72,7 +72,10 @@ def main():
|
||||
# Create configurations
|
||||
conf = MNISTConfigs()
|
||||
# Load configurations
|
||||
experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
|
||||
experiment.configs(conf, {
|
||||
'optimizer.optimizer': 'Adam',
|
||||
'optimizer.learning_rate': 0.001,
|
||||
})
|
||||
# Start the experiment and run the training loop
|
||||
with experiment.start():
|
||||
conf.run()
|
||||
|
||||
13
labml_nn/uncertainty/__init__.py
Normal file
13
labml_nn/uncertainty/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
"""
|
||||
---
|
||||
title: Neural Networks with Uncertainty Estimation
|
||||
summary: >
|
||||
A set of PyTorch implementations/tutorials related to uncertainty estimation
|
||||
---
|
||||
|
||||
# Neural Networks with Uncertainty Estimation
|
||||
|
||||
These are neural network architectures that estimate the uncertainty of the predictions.
|
||||
|
||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](evidence/index.html)
|
||||
"""
|
||||
315
labml_nn/uncertainty/evidence/__init__.py
Normal file
315
labml_nn/uncertainty/evidence/__init__.py
Normal file
@ -0,0 +1,315 @@
|
||||
"""
|
||||
---
|
||||
title: "Evidential Deep Learning to Quantify Classification Uncertainty"
|
||||
summary: >
|
||||
A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification
|
||||
Uncertainty.
|
||||
---
|
||||
|
||||
# Evidential Deep Learning to Quantify Classification Uncertainty
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation of the paper
|
||||
[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
|
||||
|
||||
[Dampster-Shafer Theory of Evidence](https://en.wikipedia.org/wiki/Dempster%E2%80%93Shafer_theory)
|
||||
assigns belief masses a set of classes (unlike assigning a probability to a single class).
|
||||
Sum of the masses of all subsets is $1$.
|
||||
Individual class probabilities (plausibilities) can be derived from these masses.
|
||||
|
||||
Assigning a mass to the set of all classes means it can be any one of the classes; i.e. saying "I don't know".
|
||||
|
||||
If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and
|
||||
an overall uncertainty mass $u \ge 0$ to all classes.
|
||||
|
||||
$$u + \sum_{k=1}^K b_k = 1$$
|
||||
|
||||
Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$
|
||||
and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
|
||||
Paper uses term evidence as a measure of the amount of support
|
||||
collected from data in favor of a sample to be classified into a certain class.
|
||||
|
||||
This corresponds to a [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution)
|
||||
with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
|
||||
$\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
|
||||
Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
|
||||
is a distribution over categorical distribution; i.e. you can sample class probabilities
|
||||
from a Dirichlet distribution.
|
||||
The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.
|
||||
|
||||
We get the model to output evidences
|
||||
$$\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)$$
|
||||
for a given input $\mathbf{x}$.
|
||||
We use a function such as
|
||||
[ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) or a
|
||||
[Softplus](https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html)
|
||||
at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.
|
||||
|
||||
The paper proposes a few loss functions to train the model, which we have implemented below.
|
||||
|
||||
Here is the [training code `experiment.py`](experiment.html) to train a model on MNIST dataset.
|
||||
|
||||
[](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from labml import tracker
|
||||
from labml_helpers.module import Module
|
||||
|
||||
|
||||
class MaximumLikelihoodLoss(Module):
|
||||
"""
|
||||
<a id="MaximumLikelihoodLoss"></a>
|
||||
## Type II Maximum Likelihood Loss
|
||||
|
||||
The distribution D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}}) is a prior on the likelihood
|
||||
$Multi(\mathbf{y} \vert p)$,
|
||||
and the negative log marginal likelihood is calculated by integrating over class probabilities
|
||||
$\mathbf{p}$.
|
||||
|
||||
If target probabilities (one-hot targets) are $y_k$ for a given sample the loss is,
|
||||
|
||||
\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&= -\log \Bigg(
|
||||
\int
|
||||
\prod_{k=1}^K p_k^{y_k}
|
||||
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
|
||||
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
|
||||
d\mathbf{p}
|
||||
\Bigg ) \\
|
||||
&= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
|
||||
\end{align}
|
||||
"""
|
||||
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
|
||||
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
|
||||
"""
|
||||
# $\color{cyan}{\alpha_k} = e_k + 1$
|
||||
alpha = evidence + 1.
|
||||
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
|
||||
strength = alpha.sum(dim=-1)
|
||||
|
||||
# Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$
|
||||
loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
|
||||
|
||||
# Mean loss over the batch
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class CrossEntropyBayesRisk(Module):
|
||||
"""
|
||||
<a id="CrossEntropyBayesRisk"></a>
|
||||
## Bayes Risk with Cross Entropy Loss
|
||||
|
||||
Bayes risk is the overall maximum cost of making incorrect estimates.
|
||||
It takes a cost function that gives the cost of making an incorrect estimate
|
||||
and sums it over all possible outcomes based on probability distribution.
|
||||
|
||||
Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$
|
||||
$$\sum_{k=1}^K -y_k \log p_k$$
|
||||
|
||||
We integrate this cost over all $\mathbf{p}$
|
||||
|
||||
\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&= -\log \Bigg(
|
||||
\int
|
||||
\Big[ \sum_{k=1}^K -y_k \log p_k \Big]
|
||||
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
|
||||
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
|
||||
d\mathbf{p}
|
||||
\Bigg ) \\
|
||||
&= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
|
||||
\end{align}
|
||||
|
||||
where $\psi(\cdot)$ is the $digamma$ function.
|
||||
"""
|
||||
|
||||
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
|
||||
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
|
||||
"""
|
||||
# $\color{cyan}{\alpha_k} = e_k + 1$
|
||||
alpha = evidence + 1.
|
||||
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
|
||||
strength = alpha.sum(dim=-1)
|
||||
|
||||
# Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$
|
||||
loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
|
||||
|
||||
# Mean loss over the batch
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class SquaredErrorBayesRisk(Module):
|
||||
"""
|
||||
<a id="SquaredErrorBayesRisk"></a>
|
||||
## Bayes Risk with Squared Error Loss
|
||||
|
||||
Here the cost function is squared error,
|
||||
$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$
|
||||
|
||||
We integrate this cost over all $\mathbf{p}$
|
||||
|
||||
\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&= -\log \Bigg(
|
||||
\int
|
||||
\Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
|
||||
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
|
||||
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
|
||||
d\mathbf{p}
|
||||
\Bigg ) \\
|
||||
&= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big] \\
|
||||
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
|
||||
\end{align}
|
||||
|
||||
Where $$\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$$
|
||||
is the expected probability when sampled from the Dirichlet distribution
|
||||
and $$\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$$
|
||||
where
|
||||
$$\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
|
||||
= \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$$
|
||||
is the variance.
|
||||
|
||||
This gives,
|
||||
\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big) \\
|
||||
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
|
||||
&= \sum_{k=1}^K \Big( \big( y_k -\mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
|
||||
&= \sum_{k=1}^K \Big( ( y_k -\hat{p}_k)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Big)
|
||||
\end{align}
|
||||
|
||||
This first part of the equation $\big(y_k -\mathbb{E}[p_k]\big)^2$ is the error term and
|
||||
the second part is the variance.
|
||||
"""
|
||||
|
||||
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
|
||||
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
|
||||
"""
|
||||
# $\color{cyan}{\alpha_k} = e_k + 1$
|
||||
alpha = evidence + 1.
|
||||
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
|
||||
strength = alpha.sum(dim=-1)
|
||||
# $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
|
||||
p = alpha / strength[:, None]
|
||||
|
||||
# Error $(y_k -\hat{p}_k)^2$
|
||||
err = (target - p) ** 2
|
||||
# Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$
|
||||
var = p * (1 - p) / (strength[:, None] + 1)
|
||||
|
||||
# Sum of them
|
||||
loss = (err + var).sum(dim=-1)
|
||||
|
||||
# Mean loss over the batch
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class KLDivergenceLoss(Module):
|
||||
"""
|
||||
<a id="KLDivergenceLoss"></a>
|
||||
## KL Divergence Regularization Loss
|
||||
|
||||
This tries to shrink the total evidence to zero if the sample cannot be correctly classified.
|
||||
|
||||
First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$ the
|
||||
Dirichlet parameters after remove the correct evidence.
|
||||
|
||||
\begin{align}
|
||||
&KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big \Vert
|
||||
D(\mathbf{p} \vert <1, \dots, 1>\Big] \\
|
||||
&= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
|
||||
{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg)
|
||||
+ \sum_{k=1}^K (\tilde{\alpha}_k - 1)
|
||||
\Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
|
||||
\end{align}
|
||||
|
||||
where $\Gamma(\cdot)$ is the gamma function,
|
||||
$\psi(\cdot)$ is the $digamma$ function and
|
||||
$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
|
||||
"""
|
||||
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
|
||||
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
|
||||
"""
|
||||
# $\color{cyan}{\alpha_k} = e_k + 1$
|
||||
alpha = evidence + 1.
|
||||
# Number of classes
|
||||
n_classes = evidence.shape[-1]
|
||||
# Remove non-misleading evidence
|
||||
# $$\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$$
|
||||
alpha_tilde = target + (1 - target) * alpha
|
||||
# $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
|
||||
strength_tilde = alpha_tilde.sum(dim=-1)
|
||||
|
||||
# The first term
|
||||
# \begin{align}
|
||||
# &\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
|
||||
# {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\
|
||||
# &= \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)
|
||||
# - \log \Gamma(K)
|
||||
# - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)
|
||||
# \end{align}
|
||||
first = (torch.lgamma(alpha_tilde.sum(dim=-1))
|
||||
- torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
|
||||
- (torch.lgamma(alpha_tilde)).sum(dim=-1))
|
||||
|
||||
# The second term
|
||||
# $$\sum_{k=1}^K (\tilde{\alpha}_k - 1)
|
||||
# \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$$
|
||||
second = (
|
||||
(alpha_tilde - 1) *
|
||||
(torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
|
||||
).sum(dim=-1)
|
||||
|
||||
# Sum of the terms
|
||||
loss = first + second
|
||||
|
||||
# Mean loss over the batch
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class TrackStatistics(Module):
|
||||
"""
|
||||
<a id="TrackStatistics"></a>
|
||||
### Track statistics
|
||||
|
||||
This module computes statistics and tracks them with [labml `tracker`](https://docs.labml.ai/api/tracker.html).
|
||||
"""
|
||||
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
|
||||
# Number of classes
|
||||
n_classes = evidence.shape[-1]
|
||||
# Predictions that correctly match with the target (greedy sampling based on highest probability)
|
||||
match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
|
||||
# Track accuracy
|
||||
tracker.add('accuracy.', match.sum() / match.shape[0])
|
||||
|
||||
# $\color{cyan}{\alpha_k} = e_k + 1$
|
||||
alpha = evidence + 1.
|
||||
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
|
||||
strength = alpha.sum(dim=-1)
|
||||
|
||||
# $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
|
||||
expected_probability = alpha / strength[:, None]
|
||||
# Expected probability of the selected (greedy highset probability) class
|
||||
expected_probability, _ = expected_probability.max(dim=-1)
|
||||
|
||||
# Uncertainty mass $u = \frac{K}{S}$
|
||||
uncertainty_mass = n_classes / strength
|
||||
|
||||
# Track $u$ for correctly predictions
|
||||
tracker.add('u.succ.', uncertainty_mass.masked_select(match))
|
||||
# Track $u$ for incorrect predictions
|
||||
tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
|
||||
# Track $\hat{p}_k$ for correctly predictions
|
||||
tracker.add('prob.succ.', expected_probability.masked_select(match))
|
||||
# Track $\hat{p}_k$ for incorrect predictions
|
||||
tracker.add('prob.fail.', expected_probability.masked_select(~match))
|
||||
225
labml_nn/uncertainty/evidence/experiment.py
Normal file
225
labml_nn/uncertainty/evidence/experiment.py
Normal file
@ -0,0 +1,225 @@
|
||||
"""
|
||||
---
|
||||
title: "Evidential Deep Learning to Quantify Classification Uncertainty Experiment"
|
||||
summary: >
|
||||
This trains is EDL model on MNIST
|
||||
---
|
||||
|
||||
# [Evidential Deep Learning to Quantify Classification Uncertainty](index.html) Experiment
|
||||
|
||||
This trains a model based on [Evidential Deep Learning to Quantify Classification Uncertainty](index.html)
|
||||
on MNIST dataset.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
|
||||
from labml import tracker, experiment
|
||||
from labml.configs import option, calculate
|
||||
from labml_helpers.module import Module
|
||||
from labml_helpers.schedule import Schedule, RelativePiecewise
|
||||
from labml_helpers.train_valid import BatchIndex
|
||||
from labml_nn.experiments.mnist import MNISTConfigs
|
||||
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
|
||||
CrossEntropyBayesRisk, SquaredErrorBayesRisk
|
||||
|
||||
|
||||
class Model(Module):
|
||||
"""
|
||||
## LeNet based model fro MNIST classification
|
||||
"""
|
||||
|
||||
def __init__(self, dropout: float):
|
||||
super().__init__()
|
||||
# First $5x5$ convolution layer
|
||||
self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
|
||||
# ReLU activation
|
||||
self.act1 = nn.ReLU()
|
||||
# $2x2$ max-pooling
|
||||
self.max_pool1 = nn.MaxPool2d(2, 2)
|
||||
# Second $5x5$ convolution layer
|
||||
self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
|
||||
# ReLU activation
|
||||
self.act2 = nn.ReLU()
|
||||
# $2x2$ max-pooling
|
||||
self.max_pool2 = nn.MaxPool2d(2, 2)
|
||||
# First fully-connected layer that maps to $500$ features
|
||||
self.fc1 = nn.Linear(50 * 4 * 4, 500)
|
||||
# ReLU activation
|
||||
self.act3 = nn.ReLU()
|
||||
# Final fully connected layer to output evidence for $10$ classes.
|
||||
# The ReLU or Softplus activation is applied to this outside the model to get the
|
||||
# non-negative evidence
|
||||
self.fc2 = nn.Linear(500, 10)
|
||||
# Dropout for the hidden layer
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`
|
||||
"""
|
||||
# Apply first convolution and max pooling.
|
||||
# The result has shape `[batch_size, 20, 12, 12]`
|
||||
x = self.max_pool1(self.act1(self.conv1(x)))
|
||||
# Apply second convolution and max pooling.
|
||||
# The result has shape `[batch_size, 50, 4, 4]`
|
||||
x = self.max_pool2(self.act2(self.conv2(x)))
|
||||
# Flatten the tensor to shape `[batch_size, 50 * 4 * 4]`
|
||||
x = x.view(x.shape[0], -1)
|
||||
# Apply hidden layer
|
||||
x = self.act3(self.fc1(x))
|
||||
# Apply dropout
|
||||
x = self.dropout(x)
|
||||
# Apply final layer and return
|
||||
return self.fc2(x)
|
||||
|
||||
|
||||
class Configs(MNISTConfigs):
|
||||
"""
|
||||
## Configurations
|
||||
|
||||
We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations.
|
||||
"""
|
||||
|
||||
# [KL Divergence regularization](index.html#KLDivergenceLoss)
|
||||
kl_div_loss = KLDivergenceLoss()
|
||||
# KL Divergence regularization coefficient schedule
|
||||
kl_div_coef: Schedule
|
||||
# KL Divergence regularization coefficient schedule
|
||||
kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
|
||||
# [Stats module](index.html#TrackStatistics) for tracking
|
||||
stats = TrackStatistics()
|
||||
# Dropout
|
||||
dropout: float = 0.5
|
||||
# Module to convert the model output to non-zero evidences
|
||||
outputs_to_evidence: Module
|
||||
|
||||
def init(self):
|
||||
"""
|
||||
### Initialization
|
||||
"""
|
||||
# Set tracker configurations
|
||||
tracker.set_scalar("loss.*", True)
|
||||
tracker.set_scalar("accuracy.*", True)
|
||||
tracker.set_histogram('u.*', True)
|
||||
tracker.set_histogram('prob.*', False)
|
||||
tracker.set_scalar('annealing_coef.*', False)
|
||||
tracker.set_scalar('kl_div_loss.*', False)
|
||||
|
||||
#
|
||||
self.state_modules = []
|
||||
|
||||
def step(self, batch: Any, batch_idx: BatchIndex):
|
||||
"""
|
||||
### Training or validation step
|
||||
"""
|
||||
|
||||
# Training/Evaluation mode
|
||||
self.model.train(self.mode.is_train)
|
||||
|
||||
# Move data to the device
|
||||
data, target = batch[0].to(self.device), batch[1].to(self.device)
|
||||
|
||||
# One-hot coded targets
|
||||
eye = torch.eye(10).to(torch.float).to(self.device)
|
||||
target = eye[target]
|
||||
|
||||
# Update global step (number of samples processed) when in training mode
|
||||
if self.mode.is_train:
|
||||
tracker.add_global_step(len(data))
|
||||
|
||||
# Get model outputs
|
||||
outputs = self.model(data)
|
||||
# Get evidences $e_k \ge 0$
|
||||
evidence = self.outputs_to_evidence(outputs)
|
||||
|
||||
# Calculate loss
|
||||
loss = self.loss_func(evidence, target)
|
||||
# Calculate KL Divergence regularization loss
|
||||
kl_div_loss = self.kl_div_loss(evidence, target)
|
||||
tracker.add("loss.", loss)
|
||||
tracker.add("kl_div_loss.", kl_div_loss)
|
||||
|
||||
# KL Divergence loss coefficient $\lambda_t$
|
||||
annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
|
||||
tracker.add("annealing_coef.", annealing_coef)
|
||||
|
||||
# Total loss
|
||||
loss = loss + annealing_coef * kl_div_loss
|
||||
|
||||
# Track statistics
|
||||
self.stats(evidence, target)
|
||||
|
||||
# Train the model
|
||||
if self.mode.is_train:
|
||||
# Calculate gradients
|
||||
loss.backward()
|
||||
# Take optimizer step
|
||||
self.optimizer.step()
|
||||
# Clear the gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Save the tracked metrics
|
||||
tracker.save()
|
||||
|
||||
|
||||
@option(Configs.model)
|
||||
def mnist_model(c: Configs):
|
||||
"""
|
||||
### Create model
|
||||
"""
|
||||
return Model(c.dropout).to(c.device)
|
||||
|
||||
|
||||
@option(Configs.kl_div_coef)
|
||||
def kl_div_coef(c: Configs):
|
||||
"""
|
||||
### KL Divergence Loss Coefficient Schedule
|
||||
"""
|
||||
|
||||
# Create a [relative piecewise schedule](https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise)
|
||||
return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
|
||||
|
||||
|
||||
# [Maximum Likelihood Loss](index.html#MaximumLikelihoodLoss)
|
||||
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
|
||||
# [Cross Entropy Bayes Risk](index.html#CrossEntropyBayesRisk)
|
||||
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
|
||||
# [Squared Error Bayes Risk](index.html#SquaredErrorBayesRisk)
|
||||
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())
|
||||
|
||||
# ReLU to calculate evidence
|
||||
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
|
||||
# Softplus to calculate evidence
|
||||
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
|
||||
|
||||
|
||||
def main():
|
||||
# Create experiment
|
||||
experiment.create(name='evidence_mnist')
|
||||
# Create configurations
|
||||
conf = Configs()
|
||||
# Load configurations
|
||||
experiment.configs(conf, {
|
||||
'optimizer.optimizer': 'Adam',
|
||||
'optimizer.learning_rate': 0.001,
|
||||
'optimizer.weight_decay': 0.005,
|
||||
|
||||
# 'loss_func': 'max_likelihood_loss',
|
||||
# 'loss_func': 'cross_entropy_bayes_risk',
|
||||
'loss_func': 'squared_error_bayes_risk',
|
||||
|
||||
'outputs_to_evidence': 'softplus',
|
||||
|
||||
'dropout': 0.5,
|
||||
})
|
||||
# Start the experiment and run the training loop
|
||||
with experiment.start():
|
||||
conf.run()
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
8
labml_nn/uncertainty/evidence/readme.md
Normal file
8
labml_nn/uncertainty/evidence/readme.md
Normal file
@ -0,0 +1,8 @@
|
||||
# [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation of the paper
|
||||
[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
|
||||
|
||||
Here is the [training code `experiment.py`](https://nn.labml.ai/uncertainty/evidence/experiment.html) to train a model on MNIST dataset.
|
||||
|
||||
[](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
|
||||
5
labml_nn/uncertainty/readme.md
Normal file
5
labml_nn/uncertainty/readme.md
Normal file
@ -0,0 +1,5 @@
|
||||
# [Neural Networks with Uncertainty Estimation](https://nn.labml.ai/uncertainty/index.html)
|
||||
|
||||
These are neural network architectures that estimate the uncertainty of the predictions.
|
||||
|
||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
|
||||
Reference in New Issue
Block a user