This is an annotated PyTorch experiment to train a switch transformer.
import torch
import torch.nn as nn

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: nn.Module):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer
        self.transformer = transformer
Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        self.mask = None

    def forward(self, x: torch.Tensor):
Initialize the subsequent mask
        if self.mask is None or self.mask.size(0) != len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask = subsequent_mask(len(x)).to(x.device)
Token embeddings
        x = self.src_embed(x)
Run it through the transformer
        res, counts, route_prob, n_dropped, route_prob_max = self.transformer(x, self.mask)
Generate logits of the next token
        res = self.generator(res)
        return res, counts, route_prob, n_dropped, route_prob_max
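
As an aside, the subsequent mask built in forward() is a standard causal mask: position $i$ may attend only to positions up to and including $i$. Below is a minimal sketch of that idea using torch.tril; it is an illustration only, and the exact shape returned by labml_nn.transformers.utils.subsequent_mask may differ.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean matrix: entry [i, j] is True iff j <= i,
    # so each position can attend to itself and to earlier positions only.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])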

This extends NLPAutoRegressionConfigs. The default configs can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
    transformer: nn.Module

Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of experts
    n_experts: int = 4
Load balancing coefficient
    load_balancing_loss_ceof = 0.01
Whether to scale the chosen expert outputs by the routing probability
    is_scale_prob: bool = True
Whether to drop tokens
    drop_tokens: bool = False
Capacity factor to determine the capacity of each expert
    capacity_factor: float = 1.0
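    # For reference, the Switch Transformer derives each expert's token capacity
    # roughly as int(capacity_factor * tokens_in_batch / n_experts); e.g. with
    # batch_size=32, seq_len=64 and 4 experts, capacity_factor=1.0 gives 512 tokens
    # per expert and 1.2 gives 614. (The exact computation lives in SwitchFeedForward.)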

    def init(self):
        super().init()
Initialize tracking indicators
        tracker.set_scalar("lb_loss.*", False)
        tracker.set_scalar("route.*", False)
        tracker.set_scalar("dropped.*", False)

    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get model outputs.
        output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
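        # A reading of these outputs, based on how they are used below: for each switch
        # layer, `counts[i]` is the number of tokens routed to expert `i` and
        # `route_prob[i]` is the sum of routing probabilities assigned to expert `i`;
        # `n_dropped` is the number of tokens dropped because an expert was at capacity,
        # and `route_prob_max` is the routing probability of each token's selected expert.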
Calculate the cross entropy loss
        cross_entropy_loss = self.loss_func(output, target)
Total number of tokens processed, $T$, in the current batch
        total = counts.sum(dim=-1, keepdims=True)
Fraction of tokens routed to each expert, $f_i = \frac{1}{T} \sum_{x \in \mathscr{B}} \mathbf{1}\{\mathop{\mathrm{argmax}} p(x) = i\}$; that is, the count of tokens where the argmax of $p(x)$ is equal to $i$, divided by $T$.
        route_frac = counts / total
Mean routing probability, $P_i = \frac{1}{T} \sum_{x \in \mathscr{B}} p_i(x)$
        route_prob = route_prob / total
Load balancing loss, $\mathscr{L} = N \sum_{i=1}^N f_i \cdot P_i$. $\mathscr{L}$ is the loss for a single layer and here we are taking the sum of losses across all layers.
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
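        # Worked example for intuition (illustrative only, not part of the training
        # logic): with perfectly balanced routing, every f_i = P_i = 1/N, so the loss
        # is N * N * (1/N * 1/N) = 1; when routing and router probabilities concentrate
        # on fewer experts, the loss rises above 1.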
Track stats
        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
        tracker.add('route.min.', route_frac.min())
        tracker.add('route.max.', route_frac.max())
        tracker.add('route.std.', route_frac.std())
        tracker.add('route.max_prob.', route_prob_max)
        tracker.add("loss.", cross_entropy_loss)
        tracker.add("lb_loss.", load_balancing_loss)
Combined loss. The load balancing loss is multiplied by a coefficient $\alpha$ which is set to something small like $\alpha = 0.01$.
        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()

@option(Configs.model)
def autoregressive_model(c: Configs):
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
    return m.to(c.device)

@option(Configs.transformer)
def switch_transformer(c: Configs):
    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    return SwitchTransformer(
        SwitchTransformerLayer(d_model=c.d_model,
                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                              drop_tokens=c.drop_tokens,
                                                              is_scale_prob=c.is_scale_prob,
                                                              n_experts=c.n_experts,
                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                              d_model=c.d_model),
                               dropout_prob=c.dropout),
        c.n_layers)


def main():
Create experiment
    experiment.create(name="switch_transformer", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
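                        # 'Noam' selects the warmup-then-decay learning-rate schedule from
                        # "Attention Is All You Need": lr is proportional to
                        # d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)),
                        # i.e. it ramps up during warmup and then decays with the inverse
                        # square root of the step count.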
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'transformer': 'switch_transformer',
                        'n_experts': 4,

                        'drop_tokens': True,
                        'capacity_factor': 1.2,

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })

Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run the training and validation loop (TrainValidConfigs.run)
        conf.run()


if __name__ == '__main__':
    main()