This is a miniature PyTorch implementation of the paper Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Our implementation has only a few million parameters and does not do model-parallel distributed training. It trains on a single GPU, but it implements the concept of switching as described in the paper.
The Switch Transformer uses different parameters for each token by switching among sets of parameters based on the token. Only a fraction of the parameters is used for each token, so the model can have many more parameters without a proportional increase in computational cost per token.
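To make the trade-off concrete, here is a back-of-the-envelope count with made-up sizes. It assumes each expert is a two-layer FFN with biases (`d_model` → `d_ff` → `d_model`) and that the router is a single linear layer; `ffn_params` and `switch_params` are hypothetical helpers, not part of the implementation.

```python
from typing import Tuple


def ffn_params(d_model: int, d_ff: int) -> int:
    # Two linear layers with biases: d_model -> d_ff -> d_model
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)


def switch_params(d_model: int, d_ff: int, n_experts: int) -> Tuple[int, int]:
    """Return (total parameters, parameters used per token) for one switching FFN layer."""
    # Linear routing layer with bias: d_model -> n_experts
    router = d_model * n_experts + n_experts
    total = n_experts * ffn_params(d_model, d_ff) + router
    # Each token passes through the router and exactly one expert
    active = ffn_params(d_model, d_ff) + router
    return total, active


total, active = switch_params(d_model=128, d_ff=512, n_experts=4)
print(total, active)  # 527364 132228
```

With 4 experts the layer stores roughly four times the parameters of a single FFN, while each token still only touches the router and one expert.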
The switching happens at the Position-wise Feedforward network (FFN) of each transformer block. A position-wise feedforward network consists of two sequential fully connected layers. In a switch transformer there are multiple FFNs (multiple experts), and a router chooses which one to use. The router outputs a probability for each FFN; we pick the FFN with the highest probability and evaluate only that one. So, apart from the small cost of the router, the computational cost per token is the same as having a single FFN. Our implementation does not parallelize well when there are many or large FFNs, since everything runs on a single GPU. In a distributed setup, each FFN (each very large) would be placed on a different device.
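The top-1 routing described above can be sketched in a few lines. This is a minimal illustration, not the implementation that follows: `top1_switch` is a hypothetical name, there is no capacity limit or token dropping, and it multiplies the expert output by the router probability (the paper's $y = p_i(x) E_i(x)$) rather than scaling the expert input.

```python
import torch
from torch import nn


def top1_switch(x: torch.Tensor, router: nn.Linear, experts: nn.ModuleList) -> torch.Tensor:
    """Send each row of x ([n_tokens, d_model]) to its highest-probability expert."""
    # One probability distribution over the experts for every token
    probs = torch.softmax(router(x), dim=-1)
    # Top-1 routing: keep only the most probable expert and its probability
    p_max, routes = probs.max(dim=-1)
    out = torch.zeros_like(x)
    for i, expert in enumerate(experts):
        # Indexes of the tokens routed to expert i
        idx = (routes == i).nonzero(as_tuple=True)[0]
        if idx.numel() > 0:
            # Only these tokens are evaluated by expert i
            out[idx] = expert(x[idx])
    # Gate each expert output by its router probability
    return out * p_max.unsqueeze(-1)


torch.manual_seed(0)
d_model, n_experts = 8, 4
router = nn.Linear(d_model, n_experts)
experts = nn.ModuleList(
    [nn.Sequential(nn.Linear(d_model, 32), nn.ReLU(), nn.Linear(32, d_model)) for _ in range(n_experts)]
)
x = torch.randn(10, d_model)
y = top1_switch(x, router, experts)
print(y.shape)  # torch.Size([10, 8])
```

Each expert runs once on its own subset of tokens, so the total FFN work across experts is the same as one FFN applied to all tokens.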
The paper also introduces an auxiliary loss term to balance the load among the experts (FFNs), and discusses dropping tokens when routing is not balanced.
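For reference, the paper's auxiliary load-balancing loss can be written as follows for a batch $\mathcal{B}$ of $T$ tokens, $N$ experts and router probabilities $p_i(x)$, with $\alpha$ a small weighting hyperparameter:

$$
\mathcal{L}_{\mathrm{aux}} = \alpha \, N \sum_{i=1}^{N} f_i \, P_i,
\qquad
f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbb{1}\left\{\arg\max_j p_j(x) = i\right\},
\qquad
P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x)
$$

Here $f_i$ is the fraction of tokens routed to expert $i$ and $P_i$ is the fraction of router probability it receives. With perfectly uniform routing, $f_i = P_i = \frac{1}{N}$, so $N \sum_i f_i P_i = N \cdot N \cdot \frac{1}{N^2} = 1$ and the loss equals $\alpha$.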
Here’s the training code and a notebook for training a switch transformer on the Tiny Shakespeare dataset.
```python
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.utils import clone_module_list


class SwitchFeedForward(Module):
    def __init__(self, *,
                 capacity_factor: float,
                 drop_tokens: bool,
                 is_scale_prob: bool,
                 n_experts: int,
                 expert: FeedForward,
                 d_model: int):
        """
        * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
        * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than the capacity
        * `is_scale_prob` specifies whether to multiply the input to the FFN by the routing probability
        * `n_experts` is the number of experts
        * `expert` is the expert layer, an FFN module; the FFN's hidden-layer size (`d_ff`)
          and dropout probability are set on this module
        * `d_model` is the number of features in a token embedding
        """
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens

        # Make copies of the FFNs
        self.experts = clone_module_list(expert, n_experts)
        # Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
        """
        # Capture the shape to change shapes later
        seq_len, batch_size, d_model = x.shape
        # Flatten the sequence and batch dimensions
        x = x.view(-1, d_model)
```

Get the routing probabilities for each of the tokens,

$$p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^{N} e^{h(x)_j}}$$

where $N$ is the number of experts `n_experts` and $h(\cdot)$ is the linear transformation of token embeddings (the `switch` layer).
```python
        route_prob = self.softmax(self.switch(x))

        # Get the maximum routing probabilities and the routes.
        # We route to the expert with the highest probability
        route_prob_max, routes = torch.max(route_prob, dim=-1)
```

Scale the inputs to the experts by the routing probabilities.
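When `is_scale_prob` is false, the code that follows does not scale the inputs; it multiplies them by $\frac{p}{\hat{p}}$, where $\hat{p}$ is `route_prob_max.detach()`. A small standalone check with made-up numbers shows that this factor is exactly 1 in value while the router probability still receives a gradient:

```python
import torch

# Made-up top-1 router probabilities for two tokens
p = torch.tensor([0.7, 0.4], requires_grad=True)
# p / p.detach() is exactly 1 in value, but it stays in the autograd graph
factor = p / p.detach()

# Made-up token values standing in for the expert inputs
x = torch.tensor([2.0, 3.0])
(x * factor).sum().backward()

print(factor.detach())  # tensor([1., 1.])
# The router still gets a gradient: d(x * p / p_hat)/dp = x / p_hat
print(p.grad)
```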
```python
        if self.is_scale_prob:
            factor = route_prob_max
        else:
            # Don't scale the values, but multiply by p / p_hat = 1 so that the gradients flow
            factor = route_prob_max / route_prob_max.detach()
        # Multiply by the scaling factor
        x = x * factor.view(-1, 1)

        # Get indexes of tokens going to each expert
        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]

        # Initialize an empty tensor to store outputs
        final_output = x.new_zeros(x.shape)
```

Capacity of each expert, rounded down to an integer:

$$\mathrm{expert\;capacity} = \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}} \times \mathrm{capacity\;factor}$$
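As a worked example with made-up sizes: 32 sequences of 64 tokens give 2048 tokens per batch, so with 4 experts and a capacity factor of 1.25 each expert can take $1.25 \times 2048 / 4 = 640$ tokens. The plain-Python sketch below mirrors the `int(...)` computation and counts how many tokens would overflow for a made-up uneven routing; `expert_capacity` and `n_overflow` are hypothetical helper names.

```python
from typing import List


def expert_capacity(capacity_factor: float, n_tokens: int, n_experts: int) -> int:
    # Same rounding as the implementation: int() truncates the fractional part
    return int(capacity_factor * n_tokens / n_experts)


def n_overflow(counts: List[int], capacity: int) -> int:
    # Tokens routed to an expert beyond its capacity
    return sum(max(0, c - capacity) for c in counts)


cap = expert_capacity(1.25, n_tokens=64 * 32, n_experts=4)
print(cap)  # 640
# 2048 tokens routed unevenly: only the first expert is over capacity
print(n_overflow([900, 500, 400, 248], cap))  # 260
```

With `drop_tokens` enabled, those overflowing tokens skip the experts and are passed through to the output.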
```python
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
        # Number of tokens routed to each expert
        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])

        # Initialize an empty list of dropped tokens
        dropped = []
        # Only drop tokens if `drop_tokens` is `True`
        if self.drop_tokens:
            # Drop tokens in each of the experts
            for i in range(self.n_experts):
                # Ignore if the expert is not over capacity
                if len(indexes_list[i]) <= capacity:
                    continue
                # Shuffle indexes before dropping
                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
                # Collect the tokens over capacity as dropped tokens
                dropped.append(indexes_list[i][capacity:])
                # Keep only the tokens up to the capacity of the expert
                indexes_list[i] = indexes_list[i][:capacity]

        # Get outputs of the expert FFNs
        route_outputs = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]

        # Assign to final output
        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = route_outputs[i]

        # Pass through the dropped tokens
        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]

        # Change the shape of the final output back to `[seq_len, batch_size, d_model]`
        final_output = final_output.view(seq_len, batch_size, d_model)

        # Return
        # * the final output
        # * number of tokens routed to each expert
        # * sum of probabilities for each expert
        # * number of tokens dropped.
        # These are used for the load balancing loss and logging
        return final_output, counts, route_prob.sum(0), len(dropped)
```

`SwitchTransformerLayer` is the same as a normal transformer block, except that it handles the extra outputs of the switching feed-forward module.
```python
class SwitchTransformerLayer(Module):
    def __init__(self, *,
                 d_model: int,
                 attn: MultiHeadAttention,
                 feed_forward: SwitchFeedForward,
                 dropout_prob: float):
        """
        * `d_model` is the token embedding size
        * `attn` is the attention module
        * `feed_forward` is the feed-forward module (which is the switching module in this case)
        * `dropout_prob` is the probability of dropping out after self-attention and FFN
        """
        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor):
        # Normalize the vectors before doing self-attention
        z = self.norm_self_attn(x)
        # Run through self-attention, i.e. keys and values are from self
        self_attn = self.attn(query=z, key=z, value=z, mask=mask)
        # Add the self-attention results
        x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the switching feed-forward network
        ff, counts, route_prob, n_dropped = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped


class SwitchTransformer(Module):
    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        # Run through each transformer layer
        counts, route_prob, n_dropped = [], [], []
        for layer in self.layers:
            x, f, p, n_d = layer(x=x, mask=mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)
        # Finally, normalize the vectors
        x = self.norm(x)

        return x, torch.stack(counts), torch.stack(route_prob), n_dropped
```
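`SwitchTransformer` returns the per-layer `counts` and summed routing probabilities stacked into tensors of shape `[n_layers, n_experts]`, which is what a training loop needs for the load-balancing loss. The sketch below is one way to compute the paper's $N \sum_i f_i P_i$ term from those outputs; `load_balancing_loss` is a hypothetical helper, the tensors are made up, and the result would still be multiplied by a small coefficient before being added to the language-modeling loss.

```python
import torch


def load_balancing_loss(counts: torch.Tensor, route_prob: torch.Tensor) -> torch.Tensor:
    """counts, route_prob: [n_layers, n_experts], as returned by SwitchTransformer."""
    n_experts = counts.shape[-1]
    # Tokens per layer: every token is routed to exactly one expert before any dropping
    total = counts.sum(dim=-1, keepdim=True)
    # f_i: fraction of the tokens routed to expert i
    route_frac = counts / total
    # P_i: mean router probability of expert i (route_prob holds sums over tokens)
    mean_prob = route_prob / total
    # N * sum_i f_i * P_i for each layer, summed over the layers
    return n_experts * (route_frac * mean_prob).sum()


# 100 tokens spread evenly over 4 experts in 2 layers, with uniform router probabilities:
# each layer contributes exactly 1
balanced = torch.full((2, 4), 25.0)
print(load_balancing_loss(balanced, balanced))  # tensor(2.)

# One layer where every token and all probability mass go to a single expert
skewed = torch.tensor([[100.0, 0.0, 0.0, 0.0]])
print(load_balancing_loss(skewed, skewed))  # tensor(4.)
```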