This is a PyTorch implementation/tutorial of the paper Distilling the Knowledge in a Neural Network.
It's a way of training a small network using the knowledge in a trained larger network; i.e. distilling the knowledge from the large network.
A large model with regularization, or an ensemble of models (using dropout), generalizes better than a small model when both are trained directly on the data and labels. However, a small model can be trained to generalize better with the help of a large model. Smaller models are better suited to production: they are faster and need less compute and memory.
The output probabilities of a trained model carry more information than the labels, because the model also assigns non-zero probabilities to incorrect classes. These probabilities tell us how likely a sample is to belong to each of the other classes. For instance, when classifying digits and given an image of a 7, a well-generalized model will give a high probability to 7 and a small but non-zero probability to 2, while assigning almost zero probability to the other digits. Distillation uses this information to train a small model better.
The probabilities are usually computed with a softmax operation,

$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$

where $q_i$ is the probability for class $i$ and $z_i$ is the logit.
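The softmax above can be sketched in plain Python, without PyTorch. The logits are hypothetical values for the ten digit classes, given an image of a 7. Subtracting the largest logit before exponentiating is a standard numerical-stability trick. It does not change the result because the factor cancels between numerator and denominator.

```python
import math

def softmax(logits):
    """Convert a list of logits z_i into probabilities q_i."""
    # Subtract the max logit first; the result is unchanged,
    # but exp() cannot overflow for large logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for digits 0-9 given an image of a 7:
# large for 7, moderate for 2, low for the rest.
logits = [-2.0, -1.0, 3.0, -1.5, -2.0, -1.0, -2.5, 8.0, -0.5, -1.0]
probs = softmax(logits)
print(max(range(10), key=lambda i: probs[i]))  # 7
```

Nearly all of the mass goes to 7, and the remainder is spread unevenly. Class 2 gets noticeably more than the other digits, and that uneven remainder is the information distillation tries to transfer.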
We train the small model to minimize the Cross entropy or KL Divergence between its output probability distribution and the large network's output probability distribution (soft targets).
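Minimizing cross entropy and minimizing KL divergence against fixed soft targets are interchangeable for the student. This follows from the identity $H(p, q) = H(p) + \mathrm{KL}(p \,\|\, q)$, where the teacher's entropy $H(p)$ does not depend on the student. The sketch below checks this identity numerically. The two distributions are made-up examples, not outputs of the models in this tutorial.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i log q_i  (p: teacher targets, q: student)."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def entropy(p):
    """H(p) = -sum_i p_i log p_i."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

teacher = [0.7, 0.2, 0.1]   # hypothetical soft targets
student = [0.5, 0.3, 0.2]   # hypothetical student probabilities

# H(p, q) = H(p) + KL(p || q); H(p) is constant w.r.t. the student,
# so both losses give the same gradient for the student's parameters.
gap = cross_entropy(teacher, student) - (entropy(teacher) + kl_divergence(teacher, student))
print(abs(gap) < 1e-12)  # True
```

The KL form has the practical advantage of being exactly zero when the student matches the teacher, which makes the logged loss easier to read.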
One of the problems here is that the probabilities the large network assigns to incorrect classes are often very small and contribute almost nothing to the loss. The paper therefore softens the probabilities by applying a temperature $T$,

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

where higher values of $T$ produce softer probabilities.
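The temperature softmax can be illustrated in plain Python. The three logits are hypothetical, standing in for "7", "2" and some other digit. The helper divides every logit by $T$ before the usual softmax, so $T = 1$ recovers the ordinary softmax.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # shift for numerical stability; cancels out
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [8.0, 3.0, -1.0]  # hypothetical logits: "7", "2", another digit
for T in (1.0, 5.0, 20.0):
    q = softmax_with_temperature(logits, T)
    print(T, [round(x, 3) for x in q])
```

At $T = 1$ the second class gets under 1% of the probability mass. At $T = 5$ it gets roughly 24%. The ranking of the classes never changes, so the soft targets still agree with the teacher's prediction while exposing how it ranks the wrong classes.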
The paper suggests adding a second loss term for predicting the actual labels when training the small model. We calculate the composite loss as the weighted sum of the two loss terms: the soft-target loss and the true-label loss.
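Here is a single-sample, pure-Python sketch of the composite loss, under these assumptions:

* The soft term is $\mathrm{KL}$ between the teacher's and the student's temperature-softened distributions.
* The label term is ordinary cross entropy at $T = 1$.
* The soft term is weighted by $T^2$ by default, following the paper's suggestion.

The function name `distillation_loss` and the default weights are illustrative, not tuned values. A real implementation would operate on batches of tensors.

```python
import math

def softmax(logits, T=1.0):
    """Temperature softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label,
                      T=5.0, soft_weight=None, label_weight=0.5):
    """Return (total, soft_loss, label_loss) for one sample (hypothetical helper).

    soft_weight defaults to T**2, the scaling the paper suggests for the
    soft-target term; label_weight is an illustrative value.
    """
    if soft_weight is None:
        soft_weight = T ** 2
    p = softmax(teacher_logits, T)  # soft targets from the teacher
    q = softmax(student_logits, T)  # softened student distribution
    soft_loss = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    # The true-label term uses the student's ordinary softmax (T = 1)
    label_loss = -math.log(softmax(student_logits)[label])
    total = soft_weight * soft_loss + label_weight * label_loss
    return total, soft_loss, label_loss

teacher = [8.0, 3.0, -1.0]  # hypothetical teacher logits
student = [2.0, 1.0, 0.0]   # hypothetical student logits
total, soft, hard = distillation_loss(student, teacher, label=0)
print(total, soft, hard)
```

When the student's logits equal the teacher's, the soft term vanishes. Only the weighted label loss remains in that case.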
The dataset for distillation is called the transfer set, and the paper suggests using the same training data.
We train on the CIFAR-10 dataset. We train a large model with dropout, and it gives an accuracy of 85% on the validation set. A small model trained directly on the labels gives an accuracy of 80%.
We then train the small model with distillation from the large model, and it gives an accuracy of 82%, an increase of 2 percentage points.
```python
from typing import Any

import torch
import torch.nn.functional
from torch import nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.train_valid import BatchIndex
from labml_nn.distillation.large import LargeModel
from labml_nn.distillation.small import SmallModel
from labml_nn.experiments.cifar10 import CIFAR10Configs


# This extends CIFAR10Configs, which defines all the dataset-related
# configurations, the optimizer, and a training loop.
class Configs(CIFAR10Configs):
    # The small model
    model: SmallModel
    # The large model
    large: LargeModel
    # KL divergence loss for soft targets (both inputs are log-probabilities).
    # Note: the default reduction='mean' averages over every element
    # (batch size x number of classes), not only over the batch.
    kl_div_loss = nn.KLDivLoss(log_target=True)
    # Cross entropy loss for the true labels
    loss_func = nn.CrossEntropyLoss()
    # Temperature, T
    temperature: float = 5.
    # Weight for the soft targets loss.
    # The gradients produced by soft targets get scaled by 1/T^2. To compensate
    # for this, the paper suggests scaling the soft targets loss by a factor of T^2.
    soft_targets_weight: float = 100.
    # Weight for the true label cross entropy loss
    label_loss_weight: float = 0.5

    def step(self, batch: Any, batch_idx: BatchIndex):
        # Training/evaluation mode for the small model
        self.model.train(self.mode.is_train)
        # Large model in evaluation mode
        self.large.eval()

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get the output logits, v_i, from the large model (no gradients needed)
        with torch.no_grad():
            large_logits = self.large(data)

        # Get the output logits, z_i, from the small model
        output = self.model(data)

        # Soft targets: log softmax(v_i / T)
        soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)
        # Temperature-adjusted log-probabilities of the small model: log softmax(z_i / T)
        soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)

        # Calculate the soft targets loss
        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)
        # Calculate the true label loss
        label_loss = self.loss_func(output, target)
        # Weighted sum of the two losses
        loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss
        # Log the losses
        tracker.add({"loss.kl_div.": soft_targets_loss,
                     "loss.nll": label_loss,
                     "loss.": loss})

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(Configs.large)
def _large_model(c: Configs):
    # Create the large model
    return LargeModel().to(c.device)


@option(Configs.model)
def _small_student_model(c: Configs):
    # Create the small model
    return SmallModel().to(c.device)


def get_saved_model(run_uuid: str, checkpoint: int):
    # Load the trained large model from a saved run
    from labml_nn.distillation.large import Configs as LargeConfigs

    # In evaluation mode (no recording)
    experiment.evaluate()
    # Initialize configs of the large model training experiment
    conf = LargeConfigs()
    # Load saved configs
    experiment.configs(conf, experiment.load_configs(run_uuid))
    # Set models for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Set which run and checkpoint to load
    experiment.load(run_uuid, checkpoint)
    # Start the experiment; this will load the model and prepare everything
    experiment.start()

    # Return the model
    return conf.model


def main(run_uuid: str, checkpoint: int):
    # Train a small model with distillation
    # Load saved model
    large_model = get_saved_model(run_uuid, checkpoint)
    # Create experiment
    experiment.create(name='distillation', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Set the loaded large model
    conf.large = large_model
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'model': '_small_student_model',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start experiment from scratch
    experiment.load(None, None)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)
```
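The reason for weighting the soft-target loss by $T^2$ can be checked numerically. For $C = -\sum_i p_i \log q_i$ with $p = \mathrm{softmax}(v / T)$ and $q = \mathrm{softmax}(z / T)$, the gradient is $\partial C / \partial z_i = (q_i - p_i) / T$. For large $T$ and zero-mean logits, this is approximately $(z_i - v_i) / (N T^2)$, where $N$ is the number of classes.

The pure-Python sketch below does two things:

* It checks the analytic gradient against central finite differences.
* It compares gradient norms at $T = 5$ and $T = 10$.

The logits are hypothetical zero-mean values, and the helper names are illustrative.

```python
import math

def softmax(logits, T):
    """Temperature softmax of a list of logits."""
    scaled = [x / T for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_cross_entropy(z, v, T):
    """C = -sum_i p_i log q_i with p = softmax(v / T), q = softmax(z / T)."""
    p, q = softmax(v, T), softmax(z, T)
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

def analytic_grad(z, v, T):
    """dC/dz_i = (q_i - p_i) / T."""
    p, q = softmax(v, T), softmax(z, T)
    return [(qi - pi) / T for pi, qi in zip(p, q)]

def numeric_grad(z, v, T, eps=1e-6):
    """Central finite differences of C with respect to each student logit."""
    grads = []
    for i in range(len(z)):
        up = list(z)
        up[i] += eps
        down = list(z)
        down[i] -= eps
        diff = soft_cross_entropy(up, v, T) - soft_cross_entropy(down, v, T)
        grads.append(diff / (2 * eps))
    return grads

def norm(g):
    return math.sqrt(sum(x * x for x in g))

teacher = [2.0, -1.0, -1.0]  # hypothetical zero-mean teacher logits v_i
student = [1.0, 0.0, -1.0]   # hypothetical zero-mean student logits z_i

for T in (5.0, 10.0):
    g = analytic_grad(student, teacher, T)
    # Raw gradient norm, then the norm after multiplying the loss by T^2
    print(T, norm(g), T ** 2 * norm(g))
```

Doubling $T$ shrinks the raw gradient by roughly a factor of four. After multiplying by $T^2$, the two norms come out close to each other. This is why the relative contribution of the soft and hard targets stays roughly unchanged when the temperature is varied.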