Switch Transformer Experiment

This is an annotated PyTorch experiment that trains a Switch Transformer on the Tiny Shakespeare dataset.


import torch
import torch.nn as nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

Autoregressive model

class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)

Transformer

        self.transformer = transformer

Final layer

        self.generator = nn.Linear(d_model, n_vocab)
        self.mask = None

    def forward(self, x: torch.Tensor):

Initialize the subsequent (causal) mask if it is not set yet or the sequence length has changed; a small sketch of what this mask looks like follows the class below.

        if self.mask is None or self.mask.size(0) != len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask = subsequent_mask(len(x)).to(x.device)

Token embeddings

        x = self.src_embed(x)

Run it through the transformer

        res, counts, route_prob, n_dropped, route_prob_max = self.transformer(x, self.mask)

Generate logits of the next token

        res = self.generator(res)

        return res, counts, route_prob, n_dropped, route_prob_max
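To make the masking above concrete, here is a minimal sketch of what a subsequent (causal) mask looks like: a lower-triangular boolean matrix in which position $i$ can attend only to positions $\le i$. This is only an illustration; the exact shape and convention returned by subsequent_mask in labml_nn may differ.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True where a query position may attend to a key position,
    # i.e. keys at or before the query position (lower triangular)
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])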

Configurations

This extends NLPAutoRegressionConfigs.

The default configs can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
    transformer: Module

Token embedding size

    d_model: int = 128

Number of attention heads

    heads: int = 4

Dropout probability

    dropout: float = 0.0

Number of features in the FFN hidden layer

    d_ff: int = 256

Number of transformer layers

    n_layers: int = 6

Number of experts

    n_experts: int = 4

Load balancing loss coefficient $\alpha$

    load_balancing_loss_coef: float = 0.01

Whether to scale the chosen expert outputs by the routing probability

    is_scale_prob: bool = True

Whether to drop tokens that exceed the capacity of an expert

    drop_tokens: bool = False

Capacity factor to determine the capacity of each expert

    capacity_factor: float = 1.0
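As a rough illustration of what this factor means (following the Switch Transformer paper; the exact rounding used by SwitchFeedForward in labml_nn may differ), each expert's capacity is the ideally balanced share of the tokens in a batch scaled by the capacity factor. Tokens routed to an expert beyond that capacity are dropped when drop_tokens is enabled.

def expert_capacity(n_tokens: int, n_experts: int, capacity_factor: float) -> int:
    # Tokens an expert accepts per batch before further tokens routed to it are dropped
    return int(capacity_factor * n_tokens / n_experts)

# With seq_len = 64, batch_size = 32 and 4 experts (the settings used in main() below):
expert_capacity(64 * 32, 4, 1.0)  # 512
expert_capacity(64 * 32, 4, 1.2)  # 614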
    def init(self):
        super().init()

Initialize tracking indicators

        tracker.set_scalar("lb_loss.*", False)
        tracker.set_scalar("route.*", False)
        tracker.set_scalar("dropped.*", False)

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last):

Get model outputs.

            output, counts, route_prob, n_dropped, route_prob_max = self.model(data)

Calculate the cross entropy loss

        cross_entropy_loss = self.loss_func(output, target)

Total number of tokens processed, $T$, in the current batch $\mathscr{B}$

        total = counts.sum(dim=-1, keepdims=True)

Fraction of tokens routed to each expert, $f_i = \frac{1}{T} \sum_{x \in \mathscr{B}} \mathbb{1}\{\mathrm{argmax}\, p(x) = i\}$; that is, the fraction of tokens in the batch whose routing argmax is equal to $i$.

        route_frac = counts / total

Mean routing probability of each expert, $P_i = \frac{1}{T} \sum_{x \in \mathscr{B}} p_i(x)$

        route_prob = route_prob / total

Load balancing loss $\mathscr{L} = N \sum_{i=1}^{N} f_i \cdot P_i$ is the loss for a single switch layer; here we take the sum of the losses across all layers. A small numeric sketch follows the step method below.

        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()

Track stats

        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
        tracker.add('route.min.', route_frac.min())
        tracker.add('route.max.', route_frac.max())
        tracker.add('route.std.', route_frac.std())
        tracker.add('route.max_prob.', route_prob_max)
        tracker.add("loss.", cross_entropy_loss)
        tracker.add("lb_loss.", load_balancing_loss)

Combined loss. The load balancing loss is multiplied by a coefficient $\alpha$ which is set to something small like $\alpha = 0.01$.

        loss = cross_entropy_loss + self.load_balancing_loss_coef * load_balancing_loss

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()
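To make the load balancing terms above concrete, here is a small numeric sketch for a single switch layer with toy routing statistics (in the actual step the statistics are stacked across all layers before being summed):

import torch

n_experts = 4
counts = torch.tensor([4., 2., 1., 1.])          # tokens whose routing argmax was each expert
route_prob = torch.tensor([3.5, 2.0, 1.3, 1.2])  # sum of routing probabilities per expert

total = counts.sum()    # T = 8 tokens
f = counts / total      # f_i: fraction of tokens routed to each expert
P = route_prob / total  # P_i: mean routing probability of each expert

# Load balancing loss L = N * sum_i f_i * P_i, minimized when routing is uniform
lb_loss = n_experts * (f * P).sum()
print(lb_loss)  # ~1.28; perfectly uniform routing would give the minimum value 1.0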

Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
    return m.to(c.device)

Initialize the switch transformer

@option(Configs.transformer)
def switch_transformer(c: Configs):
    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    return SwitchTransformer(
        SwitchTransformerLayer(d_model=c.d_model,
                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                              drop_tokens=c.drop_tokens,
                                                              is_scale_prob=c.is_scale_prob,
                                                              n_experts=c.n_experts,
                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                              d_model=c.d_model),
                               dropout_prob=c.dropout),
        c.n_layers)
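The two options above are resolved lazily by the configs framework, but the pieces also compose directly. Here is a minimal standalone sketch, assuming labml_nn is installed; the vocabulary size of 65 (character-level Tiny Shakespeare) and the [seq_len, batch_size] input layout are assumptions for illustration:

import torch

from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward

d_model, heads, d_ff, n_experts, n_layers = 128, 4, 256, 4, 2
transformer = SwitchTransformer(
    SwitchTransformerLayer(d_model=d_model,
                           attn=MultiHeadAttention(heads, d_model, 0.0),
                           feed_forward=SwitchFeedForward(capacity_factor=1.0,
                                                          drop_tokens=False,
                                                          is_scale_prob=True,
                                                          n_experts=n_experts,
                                                          expert=FeedForward(d_model, d_ff, 0.0),
                                                          d_model=d_model),
                           dropout_prob=0.0),
    n_layers)
model = AutoregressiveModel(n_vocab=65, d_model=d_model, transformer=transformer)

x = torch.randint(0, 65, (64, 32))  # token ids, assumed [seq_len, batch_size] layout
logits, counts, route_prob, n_dropped, route_prob_max = model(x)
print(logits.shape)  # expected: torch.Size([64, 32, 65])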

Run the experiment

def main():

Create experiment

    experiment.create(name="switch_transformer", comment='')

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'transformer': 'switch_transformer',
                        'n_experts': 4,

                        'drop_tokens': True,
                        'capacity_factor': 1.2,

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
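A note on the optimizer overrides: 'Noam' refers to the warmup followed by inverse-square-root learning-rate schedule from the original Transformer paper, roughly $\eta_t \propto d_{model}^{-0.5} \cdot \min\big(t^{-0.5},\, t \cdot w^{-1.5}\big)$ with $w$ warmup steps, so the configured optimizer.learning_rate of $1.$ acts as a scale factor on that schedule rather than a fixed step size. (The exact form of the scheduler in labml is an assumption based on the standard formulation.)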

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run the training and validation loop (TrainValidConfigs.run)

        conf.run()


if __name__ == '__main__':
    main()
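If this file lives at labml_nn/transformers/switch/experiment.py in an installed copy of the package (an assumption about the repository layout), the experiment can be launched as a module with: python -m labml_nn.transformers.switch.experiment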