This is an annotated PyTorch experiment that trains a Switch Transformer as an autoregressive language model (by default, a character-level model on Tiny Shakespeare).
import torch
import torch.nn as nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class AutoregressiveModel(Module):
    """
    ## Auto-regressive model

    Embeds tokens, runs them through the given transformer and projects the
    result to vocabulary logits. The extra statistics returned by the
    transformer (`counts`, `route_prob`, `n_dropped`, `route_prob_max`) are
    passed through unchanged.
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer: d_model features -> n_vocab logits
        self.generator = nn.Linear(d_model, n_vocab)
        # Cached subsequent mask; built lazily in `forward`
        self.mask = None

    def forward(self, x: torch.Tensor):
        """
        Return `(logits, counts, route_prob, n_dropped, route_prob_max)`.

        The mask is sized by `len(x)`, so the first dimension of `x` is
        presumably the sequence length (TODO confirm against the data loader).
        """
        seq_len = len(x)
        # Rebuild the subsequent mask on first use or when the sequence length changes
        stale = self.mask is None or self.mask.size(0) != seq_len
        if stale:
            from labml_nn.transformers.utils import subsequent_mask
            self.mask = subsequent_mask(seq_len).to(x.device)
        # Token embeddings
        embedded = self.src_embed(x)
        # Run through the transformer, which also returns the routing statistics
        hidden, counts, route_prob, n_dropped, route_prob_max = self.transformer(embedded, self.mask)
        # Logits of the next token
        logits = self.generator(hidden)
        return logits, counts, route_prob, n_dropped, route_prob_max


# This extends NLPAutoRegressionConfigs.
# The default configs can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    Training configurations for the switch transformer experiment.
    This extends `NLPAutoRegressionConfigs`; the defaults below can be
    overridden when the experiment is started.
    """

    model: AutoregressiveModel
    transformer: Module

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of experts
    n_experts: int = 4
    # Load balancing coefficient.
    # NOTE: the "ceof" spelling is kept because the attribute name is a config key.
    load_balancing_loss_ceof = 0.01
    # Whether to scale the chosen expert outputs by the routing probability
    is_scale_prob: bool = True
    # Whether to drop tokens
    drop_tokens: bool = False
    # Capacity factor to determine capacity of each model
    capacity_factor: float = 1.0

    def init(self):
        """Initialize the base configs, then register the switch-specific indicators."""
        super().init()
        # Register scalar indicators for the load-balancing loss, routing stats and
        # dropped-token fraction (the `False` argument presumably turns off console
        # printing — confirm against the labml tracker API)
        tracker.set_scalar("lb_loss.*", False)
        tracker.set_scalar("route.*", False)
        tracker.set_scalar("dropped.*", False)

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        `batch` is indexed as `batch[0]` (input tokens) and `batch[1]` (targets).
        """
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Capture model activations only on the last batch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs
            output, counts, route_prob, n_dropped, route_prob_max = self.model(data)

        # Cross entropy loss
        cross_entropy_loss = self.loss_func(output, target)

        # Total number of tokens processed in the current batch, T = sum_i counts_i.
        # The trailing singleton dim is kept so it broadcasts over the expert dim.
        total = counts.sum(dim=-1, keepdim=True)
        # Fraction of tokens routed to each expert, f_i = counts_i / T, where counts_i
        # is the number of tokens whose argmax routing choice is expert i
        route_frac = counts / total
        # Mean routing probability, P_i = (summed routing probability of expert i) / T
        route_prob = route_prob / total
        # Load balancing loss N * sum_i f_i * P_i. That is the loss for a single layer;
        # summing over the whole tensor adds up the losses of all layers.
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()

        # Track stats.
        # `n_dropped` holds one entry per layer; `total` is squeezed so the division stays
        # element-wise per layer instead of broadcasting against the kept singleton dim
        # into an (n_layers, n_layers) matrix.
        tracker.add('dropped.', total.new_tensor(n_dropped) / total.squeeze(-1))
        tracker.add('route.min.', route_frac.min())
        tracker.add('route.max.', route_frac.max())
        tracker.add('route.std.', route_frac.std())
        tracker.add('route.max_prob.', route_prob_max)
        tracker.add("loss.", cross_entropy_loss)
        tracker.add("lb_loss.", load_balancing_loss)

        # Combined loss: the load balancing loss is scaled by a small coefficient
        # (`load_balancing_loss_ceof`, 0.01 by default)
        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    ### Initialize the auto-regressive model
    """
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
    return m.to(c.device)


@option(Configs.transformer)
def switch_transformer(c: Configs):
    """
    ### Initialize the switch transformer
    """
    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    return SwitchTransformer(
        SwitchTransformerLayer(d_model=c.d_model,
                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                              drop_tokens=c.drop_tokens,
                                                              is_scale_prob=c.is_scale_prob,
                                                              n_experts=c.n_experts,
                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                              d_model=c.d_model),
                               dropout_prob=c.dropout),
        c.n_layers)


def main():
    """
    ### Run the experiment
    """
    # Create experiment
    experiment.create(name="switch_transformer", comment='')
    # Create configs
    conf = Configs()
    # Load configurations, overriding the defaults with this dictionary
    experiment.configs(conf,
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'transformer': 'switch_transformer',
                        'n_experts': 4,

                        'drop_tokens': True,
                        'capacity_factor': 1.2,

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


if __name__ == '__main__':
    main()