This trains a model based on Evidential Deep Learning to Quantify Classification Uncertainty on the MNIST dataset.

from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker, experiment
from labml.configs import option, calculate
from labml_nn.helpers.schedule import Schedule, RelativePiecewise
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
    CrossEntropyBayesRisk, SquaredErrorBayesRisk


class Model(nn.Module):
    def __init__(self, dropout: float):
        super().__init__()
        # First convolution layer
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        # ReLU activation
        self.act1 = nn.ReLU()
        # Max-pooling
        self.max_pool1 = nn.MaxPool2d(2, 2)
        # Second convolution layer
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        # ReLU activation
        self.act2 = nn.ReLU()
        # Max-pooling
        self.max_pool2 = nn.MaxPool2d(2, 2)
        # First fully-connected layer that maps to 500 features
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        # ReLU activation
        self.act3 = nn.ReLU()
        # Final fully-connected layer to output evidence for the classes.
        # ReLU or Softplus is applied to this output outside the model
        # to get the non-negative evidence.
        self.fc2 = nn.Linear(500, 10)
        # Dropout for the hidden layer
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x: torch.Tensor):
        """
        `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`
        """
        # Apply first convolution and max pooling.
        # The result has shape `[batch_size, 20, 12, 12]`
        x = self.max_pool1(self.act1(self.conv1(x)))
        # Apply second convolution and max pooling.
        # The result has shape `[batch_size, 50, 4, 4]`
        x = self.max_pool2(self.act2(self.conv2(x)))
        # Flatten the tensor to shape `[batch_size, 50 * 4 * 4]`
        x = x.view(x.shape[0], -1)
        # Apply hidden layer
        x = self.act3(self.fc1(x))
        # Apply dropout
        x = self.dropout(x)
        # Apply final layer and return
        return self.fc2(x)
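

# A quick, illustrative sanity check (not part of the original experiment; the
# function name below is made up). It runs a dummy batch through the model and
# applies Softplus, one of the activations registered later for
# `outputs_to_evidence`, to get non-negative evidence:
def _smoke_test_model():
    model = Model(dropout=0.5)
    # A dummy batch of 4 MNIST-sized images
    logits = model(torch.zeros(4, 1, 28, 28))
    assert logits.shape == (4, 10)
    # Non-negative evidence for each of the 10 classes
    evidence = nn.Softplus()(logits)
    assert (evidence >= 0).all()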


class Configs(MNISTConfigs):
    # KL Divergence regularization loss
    kl_div_loss = KLDivergenceLoss()
    # KL Divergence regularization coefficient schedule
    kl_div_coef: Schedule
    # Schedule points, as (fraction of training, coefficient) pairs
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
    # Stats module for tracking
    stats = TrackStatistics()
    # Dropout
    dropout: float = 0.5
    # Module to convert the model output to non-negative evidence
    outputs_to_evidence: nn.Module

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)

        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        # Training/Evaluation mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # One-hot coded targets
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]
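        # (Equivalent to `torch.nn.functional.one_hot(target, num_classes=10).to(torch.float)`.)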

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get model outputs
        outputs = self.model(data)
        # Get evidences
        evidence = self.outputs_to_evidence(outputs)

        # Calculate loss
        loss = self.loss_func(evidence, target)
        # Calculate KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)

        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)
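
        # The KL regularization is annealed: its coefficient starts near zero and is
        # ramped up over training, so the regularizer does not dominate the loss
        # before the model has learned to produce meaningful evidence.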
        # KL Divergence loss coefficient
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)

        # Total loss
        loss = loss + annealing_coef * kl_div_loss

        # Track statistics
        self.stats(evidence, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(Configs.model)
def mnist_model(c: Configs):
    # Create the model and move it to the configured device
    return Model(c.dropout).to(c.device)


@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
    # Create a relative piecewise schedule over the total number of training samples
    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
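

# A rough sketch of what the schedule above works out to (illustrative only; the
# real implementation is `RelativePiecewise` from `labml_nn.helpers.schedule`, and
# linear interpolation between the points is an assumption here). With the default
# `kl_div_coef_schedule`, the coefficient is 0 at the start of training, 0.01 at
# 20% of training, and 1.0 at the end:
def _piecewise_coef(step: int, total_steps: int,
                    points=((0, 0.), (0.2, 0.01), (1, 1.))) -> float:
    frac = step / total_steps
    for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
        if frac <= x1:
            # Linear interpolation within this segment
            return y0 + (y1 - y0) * (frac - x0) / (x1 - x0)
    return points[-1][1]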


# Loss function options
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

# ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
# Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
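

# The evidence parameterises a Dirichlet distribution over the class probabilities.
# Below is a small illustrative helper (not part of the experiment; the name is made
# up) showing the quantities involved -- roughly what the `u.*` and `prob.*`
# histograms configured in `Configs.init` correspond to:
def _dirichlet_summary(evidence: torch.Tensor):
    # `evidence` has shape `[batch_size, n_classes]` and is non-negative
    alpha = evidence + 1.
    # Dirichlet strength, per sample
    strength = alpha.sum(dim=-1, keepdim=True)
    # Expected class probabilities
    prob = alpha / strength
    # Uncertainty mass: n_classes / strength
    uncertainty = evidence.shape[-1] / strength.squeeze(-1)
    return prob, uncertainty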


def main():
    # Create the experiment
    experiment.create(name='evidence_mnist')
    # Create the configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

        # Pick one of the loss function options:
        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()