This is an annotated PyTorch experiment to train a Switch Transformer on the Tiny Shakespeare dataset.

import torch
import torch.nn as nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer to generate logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)
        # The subsequent mask; created lazily on the first forward pass
        self.mask = None

    def forward(self, x: torch.Tensor):
        # Initialize the subsequent mask if it has not been created yet
        # or if the sequence length has changed
        if self.mask is None or self.mask.size(0) != len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask = subsequent_mask(len(x)).to(x.device)
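        # A rough sketch of what such a causal mask looks like, assuming the
        # usual lower-triangular layout (the helper's exact shape may differ):
        # `torch.tril(torch.ones(n, n)).bool()`, i.e. position $i$ may only
        # attend to positions $j \le i$.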
        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, counts, route_prob, n_dropped = self.transformer(x, self.mask)
        # Generate logits of the next token
        res = self.generator(res)

        return res, counts, route_prob, n_dropped


class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
    transformer: Module

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in the FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of experts
    n_experts: int = 4
    # Load balancing coefficient
    load_balancing_loss_ceof = 0.01
    # Whether to scale the chosen expert outputs by the routing probability
    is_scale_prob: bool = True
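    # In the Switch Transformer paper the chosen expert's output is multiplied by
    # its routing probability $p_i(x)$, which is also what gives the router a
    # gradient signal; this flag toggles that scaling.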
    # Whether to drop tokens that exceed an expert's capacity
    drop_tokens: bool = False
    # Capacity factor to determine the capacity of each expert
    capacity_factor: float = 1.0
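    # A rough sketch of how expert capacity is derived in the Switch Transformer
    # setup (the exact rounding used by the implementation may differ):
    #   expert_capacity = int(capacity_factor * tokens_in_batch / n_experts)
    # With `drop_tokens=True`, tokens routed to an expert beyond this capacity
    # are not processed by that expert.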

    def init(self):
        super().init()
        # Initialize tracking indicators
        tracker.set_scalar("lb_loss.*", False)
        tracker.set_scalar("route.*", False)
        tracker.set_scalar("dropped.*", False)

    def step(self, batch: any, batch_idx: BatchIndex):
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs
            output, counts, route_prob, n_dropped = self.model(data)

        # Calculate cross entropy loss
        cross_entropy_loss = self.loss_func(output, target)

        # Total number of tokens processed, $T$, in the current batch $\mathscr{B}$
        total = counts.sum(dim=-1, keepdims=True)
        # Fraction of tokens routed to each expert; $f_i$ is the count of tokens
        # where the argmax of $p(x)$ is equal to $i$, divided by $T$
        route_frac = counts / total
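        # That is, $f_i = \frac{1}{T} \sum_{x \in \mathscr{B}} \mathbb{1} \{ \mathrm{argmax} \, p(x) = i \}$,
        # as defined in the Switch Transformer paper.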
        # Mean routing probability
        route_prob = route_prob / total
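        # That is, $P_i = \frac{1}{T} \sum_{x \in \mathscr{B}} p_i(x)$, the router
        # probability for expert $i$ averaged over the batch.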
        # Load balancing loss
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
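        # This matches the auxiliary loss of the Switch Transformer paper,
        # $\mathcal{L}_{balance} = N \sum_{i=1}^{N} f_i P_i$, which takes its
        # minimum value when routing is uniform across the $N$ experts.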

        # Track stats
        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
        tracker.add('route.min.', route_frac.min())
        tracker.add('route.max.', route_frac.max())
        tracker.add('route.std.', route_frac.std())
        tracker.add("loss.", cross_entropy_loss)
        tracker.add("lb_loss.", load_balancing_loss)

        # Combined loss. The load balancing loss is multiplied by a coefficient
        # $\alpha$, which is set to something small like $\alpha = 0.01$.
        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(Configs.model)
def autoregressive_model(c: Configs):
    # Initialize the autoregressive model and move it to the configured device
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
    return m.to(c.device)


@option(Configs.transformer)
def switch_transformer(c: Configs):
    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    return SwitchTransformer(
        SwitchTransformerLayer(d_model=c.d_model,
                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                              drop_tokens=c.drop_tokens,
                                                              is_scale_prob=c.is_scale_prob,
                                                              n_experts=c.n_experts,
                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                              d_model=c.d_model),
                               dropout_prob=c.dropout),
        c.n_layers)
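
# Note: each `SwitchTransformerLayer` built above swaps the usual position-wise FFN
# for a `SwitchFeedForward`, where a router picks one of `n_experts` `FeedForward`
# copies per token, so only a single expert runs for each position.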


def main():
    # Create experiment
    experiment.create(name="switch_transformer", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
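                        # 'Noam' refers to the warm-up learning-rate schedule from
                        # "Attention Is All You Need", roughly
                        # $d_{model}^{-0.5} \cdot \min(step^{-0.5}, step \cdot warmup^{-1.5})$;
                        # the learning rate of 1.0 above acts as a multiplier on it.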
                        'prompt': 'It is',
                        'prompt_separator': '',

                        # Use the switch transformer defined above
                        'transformer': 'switch_transformer',
                        'is_scale_prob': False,
                        'n_experts': 4,

                        # Drop tokens that exceed expert capacity
                        'drop_tokens': True,
                        'capacity_factor': 1.2,

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run the training and validation loop (`TrainValidConfigs.run`)
        conf.run()


if __name__ == '__main__':
    main()