This trains a model based on Evidential Deep Learning to Quantify Classification Uncertainty on the MNIST dataset.
from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker, experiment
from labml.configs import option, calculate
from labml_helpers.module import Module
from labml_helpers.schedule import Schedule, RelativePiecewise
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
    CrossEntropyBayesRisk, SquaredErrorBayesRisk


class Model(Module):
    def __init__(self, dropout: float):
        super().__init__()
        # First convolution layer
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        # ReLU activation
        self.act1 = nn.ReLU()
        # Max-pooling
        self.max_pool1 = nn.MaxPool2d(2, 2)
        # Second convolution layer
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        # ReLU activation
        self.act2 = nn.ReLU()
        # Max-pooling
        self.max_pool2 = nn.MaxPool2d(2, 2)
        # First fully-connected layer that maps the flattened features to 500 features
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        # ReLU activation
        self.act3 = nn.ReLU()
        # Final fully-connected layer to output the evidence for the classes.
        # The ReLU or Softplus activation is applied to this output outside the model
        # to get the non-negative evidence.
        self.fc2 = nn.Linear(500, 10)
        # Dropout for the hidden layer
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x: torch.Tensor):
        # `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`

        # Apply first convolution and max pooling.
        # The result has shape `[batch_size, 20, 12, 12]`
        x = self.max_pool1(self.act1(self.conv1(x)))
        # Apply second convolution and max pooling.
        # The result has shape `[batch_size, 50, 4, 4]`
        x = self.max_pool2(self.act2(self.conv2(x)))
        # Flatten the tensor to shape `[batch_size, 50 * 4 * 4]`
        x = x.view(x.shape[0], -1)
        # Apply the hidden layer
        x = self.act3(self.fc1(x))
        # Apply dropout
        x = self.dropout(x)
        # Apply the final layer and return
        return self.fc2(x)
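

# A quick illustrative shape check (a minimal sketch, not used by the training code
# below): the raw model outputs still need a ReLU/Softplus applied outside the model
# to become non-negative evidence.
def _check_model_shapes():
    model = Model(dropout=0.5)
    x = torch.randn(4, 1, 28, 28)        # a batch of 4 MNIST-sized images
    outputs = model(x)                    # raw outputs of shape [4, 10]
    evidence = nn.Softplus()(outputs)     # non-negative evidence of shape [4, 10]
    assert evidence.shape == (4, 10)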


class Configs(MNISTConfigs):
    # KL Divergence regularization loss
    kl_div_loss = KLDivergenceLoss()
    # KL Divergence regularization coefficient schedule
    kl_div_coef: Schedule
    # Schedule points for the KL Divergence regularization coefficient
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
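    # The points above are (fraction of training, coefficient value) pairs: assuming
    # `RelativePiecewise` interpolates between them over the total number of training
    # steps, the coefficient stays near zero for the first 20% of training (reaching
    # 0.01) and then ramps up to 1 by the end, phasing the KL term in gradually.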
    # Stats module for tracking
    stats = TrackStatistics()
    # Dropout
    dropout: float = 0.5
    # Module to convert the model output to non-negative evidence
    outputs_to_evidence: Module

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)
        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        # Set training/evaluation mode
        self.model.train(self.mode.is_train)
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
        # One-hot coded targets
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]
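        # (Equivalently, the one-hot targets could be built with
        # `torch.nn.functional.one_hot(target, num_classes=10).float()`;
        # indexing into the identity matrix above does the same thing.)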
        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
        # Get model outputs
        outputs = self.model(data)
        # Get evidences
        evidence = self.outputs_to_evidence(outputs)
        # Calculate loss
        loss = self.loss_func(evidence, target)
        # Calculate KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)
        # KL Divergence loss coefficient
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)
        # Total loss
        loss = loss + annealing_coef * kl_div_loss
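        # That is, the objective being optimized is
        # $\mathcal{L} = \mathcal{L}_{\text{risk}} + \lambda_t \, \mathcal{L}_{KL}$ with
        # $\lambda_t = \min(1, \text{kl\_div\_coef}(t))$, so the KL regularization is
        # annealed in gradually over training.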

        # Track statistics
        self.stats(evidence, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(Configs.model)
def mnist_model(c: Configs):
    return Model(c.dropout).to(c.device)


@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
    # Create a relative piecewise schedule
    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))


# Register the loss function options
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

# ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
# Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
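

# For reference, a minimal sketch (following the paper's formulation; this helper is
# illustrative and not used by the experiment code) of how the non-negative evidence
# relates to the Dirichlet parameters, class probabilities and uncertainty tracked
# during training:
def _evidence_to_uncertainty(evidence: torch.Tensor):
    # Dirichlet parameters $\alpha_k = e_k + 1$
    alpha = evidence + 1.
    # Dirichlet strength $S = \sum_k \alpha_k$
    strength = alpha.sum(dim=-1, keepdim=True)
    # Expected class probabilities $\hat{p}_k = \alpha_k / S$
    prob = alpha / strength
    # Uncertainty mass $u = K / S$, where $K$ is the number of classes
    uncertainty = evidence.shape[-1] / strength
    return prob, uncertainty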


def main():
    # Create experiment
    experiment.create(name='evidence_mnist')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

        # Alternative loss functions:
        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()