This is an annotated PyTorch experiment to train a switch transformer.

import torch
import torch.nn as nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        self.mask = None

    def forward(self, x: torch.Tensor):
        # Initialize the subsequent mask
        if self.mask is None or self.mask.size(0) != len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask = subsequent_mask(len(x)).to(x.device)
        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, counts, route_prob, n_dropped = self.transformer(x, self.mask)
        # Generate logits of the next token
        res = self.generator(res)
        return res, counts, route_prob, n_dropped
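subsequent_mask is imported lazily inside forward above. For reference, here is a minimal sketch of an equivalent causal mask, assuming the function's job is simply to block attention to future positions; the exact tensor layout returned by labml_nn.transformers.utils.subsequent_mask may differ.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True where query position i may attend to key position j, i.e. j <= i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(causal_mask(4))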
# The default configs can and will be over-ridden when we start the experiment
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
    transformer: Module

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of experts
    n_experts: int = 4
    # Load balancing coefficient
    load_balancing_loss_ceof = 0.01
    # Whether to scale the chosen expert outputs by the routing probability
    is_scale_prob: bool = True
    # Whether to drop tokens
    drop_tokens: bool = False
    # Capacity factor to determine the capacity of each expert
    capacity_factor: float = 1.0

    def init(self):
        super().init()
        # Initialize tracking indicators
        tracker.set_scalar("lb_loss.*", False)
        tracker.set_scalar("route.*", False)
        tracker.set_scalar("dropped.*", False)

    def step(self, batch: any, batch_idx: BatchIndex):
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs
            output, counts, route_prob, n_dropped = self.model(data)
        # Calculate the cross entropy loss
        cross_entropy_loss = self.loss_func(output, target)
        # Total number of tokens processed, $T$, in the current batch $\mathscr{B}$
        total = counts.sum(dim=-1, keepdims=True)
        # Fraction of tokens routed to each expert
        route_frac = counts / total
        # Mean routing probability of each expert
        route_prob = route_prob / total
        # Load balancing loss
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
        # Track stats
        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
        tracker.add('route.min.', route_frac.min())
        tracker.add('route.max.', route_frac.max())
        tracker.add('route.std.', route_frac.std())
        tracker.add("loss.", cross_entropy_loss)
        tracker.add("lb_loss.", load_balancing_loss)
        # Combined loss: cross entropy plus the weighted load balancing loss
        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss
        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()
        # Save the tracked metrics
        tracker.save()
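The load balancing term above follows the auxiliary loss of the Switch Transformer paper. With $N$ the number of experts, $f_i$ the fraction of tokens routed to expert $i$ (route_frac) and $P_i$ the mean routing probability assigned to expert $i$ (the normalized route_prob), the loss optimized in step is

$\mathcal{L} = \mathcal{L}_{CE} + \alpha \, N \sum_{i=1}^{N} f_i P_i$

where $\alpha$ is load_balancing_loss_ceof (0.01 by default). For perfectly uniform routing, $f_i = P_i = \frac{1}{N}$ and the balancing term equals $N \cdot N \cdot \frac{1}{N^2} = 1$; routing that concentrates tokens and probability mass on a few experts pushes it higher, so minimizing it encourages balanced expert usage.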
@option(Configs.model)
def autoregressive_model(c: Configs):
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
    return m.to(c.device)


@option(Configs.transformer)
def switch_transformer(c: Configs):
    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    return SwitchTransformer(
        SwitchTransformerLayer(d_model=c.d_model,
                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                              drop_tokens=c.drop_tokens,
                                                              is_scale_prob=c.is_scale_prob,
                                                              n_experts=c.n_experts,
                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                              d_model=c.d_model),
                               dropout_prob=c.dropout),
        c.n_layers)
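SwitchFeedForward, constructed above, performs top-1 ("switch") routing: a learned gate scores every token against the n_experts experts, each token is dispatched to its highest-probability expert, and the per-expert token counts and routing probabilities are what the experiment's load balancing loss consumes. Below is a minimal standalone sketch of that routing step; the gate name and return layout are illustrative assumptions, not the library's exact interface.

import torch
import torch.nn as nn

def route_top1(x: torch.Tensor, gate: nn.Linear):
    # x: [n_tokens, d_model], gate: nn.Linear(d_model, n_experts)
    route_prob = torch.softmax(gate(x), dim=-1)              # routing probabilities per token
    route_prob_max, routes = torch.max(route_prob, dim=-1)   # winning probability and expert index
    # Number of tokens assigned to each expert
    counts = torch.bincount(routes, minlength=gate.out_features)
    return route_prob_max, routes, counts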
def main():
    # Create experiment
    experiment.create(name="switch_transformer", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'transformer': 'switch_transformer',
                        'n_experts': 4,

                        'drop_tokens': True,
                        'capacity_factor': 1.2,

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # TrainValidConfigs.run
        conf.run()


if __name__ == '__main__':
    main()
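The overrides above select the 'Noam' optimizer with a base learning rate of 1. Assuming this follows the standard schedule from the original Transformer paper (an assumption about the labml optimizer configuration, not spelled out in this file), the effective learning rate at step $s$ with warmup $w$ is

$lr(s) = d_{model}^{-0.5} \cdot \min\left(s^{-0.5},\; s \cdot w^{-1.5}\right)$

scaled by the base learning rate, so the rate warms up linearly and then decays with the inverse square root of the step count.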
The following excerpt is from the Switch Transformer module itself (labml_nn.transformers.switch), which the experiment imports above.

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.utils import clone_module_list

Inside SwitchFeedForward, once the routing probabilities and routes have been computed, tokens are dispatched to their experts:

        # Get indexes of tokens going to each expert
        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]

        # Initialize an empty tensor to store outputs
        final_output = x.new_zeros(x.shape)

        # Capacity of each expert.

This revision removes the block below, which previously scaled the expert inputs by the routing probabilities before dispatch:

        # Scale the inputs to the experts by the routing probabilities
        if self.is_scale_prob:
            factor = route_prob_max
        # Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
        else:
            factor = route_prob_max / route_prob_max.detach()
        # Multiply by the scaling factor
        x = x * factor.view(-1, 1)
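Two details referenced above, shown as hedged sketches rather than the library's exact code. First, the is_scale_prob=False branch multiplies by route_prob_max / route_prob_max.detach(), which equals 1 in the forward pass but still lets gradients reach the routing probabilities:

import torch

p = torch.tensor([0.7, 0.9], requires_grad=True)   # routing probabilities of two tokens
x = torch.ones(2)                                   # stand-in for the values being scaled

factor = p / p.detach()        # equals 1 element-wise in the forward pass
(x * factor).sum().backward()
print(factor)                  # tensor([1., 1.], ...)
print(p.grad)                  # non-zero: gradients flow back into the router

Second, the "Capacity of each expert." line whose code is cut off above corresponds, per the Switch Transformer paper, to capacity = capacity_factor * tokens_per_batch / n_experts. With the experiment's overrides (batch_size 32, seq_len 64, hence 2048 tokens per batch, 4 experts, capacity factor 1.2) that works out to about 1.2 * 2048 / 4 ≈ 614 tokens per expert; with drop_tokens=True, tokens beyond that budget are assumed to be passed through via the residual path rather than processed by an expert, as described in the paper (not shown in this excerpt).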