This is a miniature PyTorch implementation of the paper Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Our implementation has only a few million parameters and does not do model-parallel distributed training; it trains on a single GPU, but we implement the concept of switching as described in the paper.
The Switch Transformer uses different parameters for each token, switching among sets of parameters based on the token. Only a fraction of the parameters is used for any given token, so you can have more parameters at a lower computational cost.
The switching happens at the position-wise feedforward network (FFN) of each transformer block. The position-wise feedforward network consists of two fully connected layers applied in sequence. In the Switch Transformer we have multiple FFNs (multiple experts), and a router chooses which one to use: it outputs a probability for each FFN, and we pick the one with the highest probability and evaluate only that expert. So essentially the computational cost is the same as having a single FFN. In our implementation this doesn't parallelize well when you have many or large FFNs, since it all happens on a single GPU. In a distributed setup you would place each FFN (each very large) on a different device.
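To make the routing step concrete, here is a minimal, self-contained sketch of the idea. It is illustrative only; the sizes, the toy experts, and the variable names are made up and this is not the module defined below.

```python
import torch
from torch import nn

# Toy sizes for illustration
d_model, n_experts = 8, 4

# Each expert is a small two-layer FFN
experts = nn.ModuleList([
    nn.Sequential(nn.Linear(d_model, 32), nn.ReLU(), nn.Linear(32, d_model))
    for _ in range(n_experts)
])
# The router ("switch") scores each expert from the token embedding
switch = nn.Linear(d_model, n_experts)

token = torch.randn(d_model)
route_prob = torch.softmax(switch(token), dim=-1)
prob, expert_idx = torch.max(route_prob, dim=-1)

# Only the selected expert is evaluated, so the per-token compute
# matches that of a single dense FFN
output = experts[int(expert_idx)](token) * prob
```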
The paper introduces another loss term to balance the load among the experts (FFNs) and discusses dropping tokens when routing is not balanced; a sketch of that loss is shown below.
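As a rough illustration, the load balancing loss from the paper is $\alpha N \sum_i f_i P_i$, where $f_i$ is the fraction of tokens routed to expert $i$ and $P_i$ is the fraction of router probability assigned to it, with a small coefficient $\alpha$ (the paper uses $10^{-2}$). It can be computed from the per-expert counts and summed routing probabilities returned by the switching module below. The function name and default here are our own; this sketch is not part of this implementation.

```python
import torch

def load_balancing_loss(counts: torch.Tensor, route_prob_sum: torch.Tensor,
                        alpha: float = 1e-2) -> torch.Tensor:
    # counts[i]: number of tokens routed to expert i (extra output of the module below)
    # route_prob_sum[i]: sum of routing probabilities for expert i over all tokens
    n_experts = counts.numel()
    total_tokens = counts.sum()
    # f_i: fraction of tokens dispatched to expert i
    route_frac = counts / total_tokens
    # P_i: mean routing probability assigned to expert i
    route_prob = route_prob_sum / total_tokens
    return alpha * n_experts * (route_frac * route_prob).sum()
```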
Here’s the training code and a notebook for training a Switch Transformer on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.utils import clone_module_list
class SwitchFeedForward(Module):

- capacity_factor is the capacity of each expert as a factor relative to the ideally balanced load
- drop_tokens specifies whether to drop tokens if more tokens are routed to an expert than its capacity
- is_scale_prob specifies whether to multiply the input to the FFN by the routing probability
- n_experts is the number of experts
- expert is the expert layer, an FFN module
- d_model is the number of features in a token embedding
- d_ff is the number of features in the hidden layer of the FFN
- dropout is the dropout probability in the FFN

    def __init__(self, *,
                 capacity_factor: float,
                 drop_tokens: bool,
                 is_scale_prob: bool,
                 n_experts: int,
                 expert: FeedForward,
                 d_model: int):
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens

Make copies of the FFNs

        self.experts = clone_module_list(expert, n_experts)

Routing layer and softmax

        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)
- x is the input to the switching module with shape [seq_len, batch_size, d_model]

    def forward(self, x: torch.Tensor):

Capture the shape to change shapes later

        seq_len, batch_size, d_model = x.shape

Flatten the sequence and batch dimensions

        x = x.view(-1, d_model)
Get routing probabilities for each of the tokens:

$$p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^N e^{h(x)_j}}$$

where $N$ is the number of experts n_experts and $h(\cdot)$ is the linear transformation of token embeddings.

        route_prob = self.softmax(self.switch(x))
Get the maximum routing probabilities and the routes. We route to the expert with the highest probability.

        route_prob_max, routes = torch.max(route_prob, dim=-1)
Scale the inputs to the experts by the routing probabilities

        if self.is_scale_prob:
            factor = route_prob_max

Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow (a small standalone check of this trick is shown after this method)

        else:
            factor = route_prob_max / route_prob_max.detach()

Multiply by the scaling factor

        x = x * factor.view(-1, 1)
Get indexes of tokens going to each expert

        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]

Initialize an empty tensor to store outputs

        final_output = x.new_zeros(x.shape)
Capacity of each expert:

$$\text{expert capacity} = \frac{\text{tokens in batch}}{\text{number of experts}} \times \text{capacity factor}$$

For example, with 4,096 tokens, 8 experts, and a capacity factor of 1.25, each expert can take at most 640 tokens.

        capacity = int(self.capacity_factor * len(x) / self.n_experts)
Number of tokens routed to each expert

        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])

Initialize an empty list of dropped tokens

        dropped = []
Only drop tokens if drop_tokens is True.

        if self.drop_tokens:
Drop tokens in each of the experts

            for i in range(self.n_experts):

Ignore if the expert is not over capacity

                if len(indexes_list[i]) <= capacity:
                    continue

Shuffle indexes before dropping

                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]

Collect the tokens over capacity as dropped tokens

                dropped.append(indexes_list[i][capacity:])

Keep only the tokens up to the capacity of the expert

                indexes_list[i] = indexes_list[i][:capacity]
Get outputs of the expert FFNs

        route_outputs = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]

Assign to final output

        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = route_outputs[i]
Pass the dropped tokens through unchanged

        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]
Change the shape of the final output back to [seq_len, batch_size, d_model]

        final_output = final_output.view(seq_len, batch_size, d_model)
Return:

- the final output
- the number of tokens routed to each expert
- the sum of routing probabilities for each expert
- the number of tokens dropped

These are used for the load balancing loss and logging.

        return final_output, counts, route_prob.sum(0), len(dropped)
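As an aside, here is a tiny standalone check (purely illustrative, with made-up numbers) of why multiplying by $\frac{p}{\hat{p}} = 1$ leaves the values unchanged while still letting gradients reach the router:

```python
import torch

p = torch.tensor([0.7], requires_grad=True)  # stands in for route_prob_max
x = torch.tensor([2.0])                      # stands in for a token's features

y = x * (p / p.detach())  # p / p.detach() == 1, so y equals x exactly
y.sum().backward()

print(y.item())       # 2.0, unchanged
print(p.grad.item())  # ~2.857 == x / p, so the router still receives a gradient
```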
This is the same as a normal transformer block, with handling for the extra outputs of the switch feed-forward module.
class SwitchTransformerLayer(Module):

- d_model is the token embedding size
- attn is the attention module
- feed_forward is the feed forward module (which is the switching module in this case)
- dropout_prob is the probability of dropping out after self attention and the FFN

    def __init__(self, *,
                 d_model: int,
                 attn: MultiHeadAttention,
                 feed_forward: SwitchFeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor):
Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

Run through self attention, i.e. keys and values are from self

        self_attn = self.attn(query=z, key=z, value=z, mask=mask)

Add the self attention results

        x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the switching feed-forward network

        ff, counts, route_prob, n_dropped = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped
class SwitchTransformer(Module):

    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])
    def forward(self, x: torch.Tensor, mask: torch.Tensor):

Run through each transformer layer

        counts, route_prob, n_dropped = [], [], []
        for layer in self.layers:
            x, f, p, n_d = layer(x=x, mask=mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)

Finally, normalize the vectors

        x = self.norm(x)

        return x, torch.stack(counts), torch.stack(route_prob), n_dropped
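Finally, a rough end-to-end usage sketch under stated assumptions: it assumes the constructor signatures MultiHeadAttention(heads, d_model, dropout_prob) and FeedForward(d_model, d_ff, dropout) from labml_nn, and a causal mask of shape [seq_len, seq_len, 1] as that attention implementation expects; the hyper-parameters are arbitrary. The actual training setup is in the training code mentioned above.

```python
import torch

from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward

# Arbitrary small hyper-parameters for a smoke test
d_model, d_ff, heads, n_experts, n_layers = 128, 256, 4, 4, 2
seq_len, batch_size = 32, 8

switch_ffn = SwitchFeedForward(capacity_factor=1.25,
                               drop_tokens=True,
                               is_scale_prob=True,
                               n_experts=n_experts,
                               expert=FeedForward(d_model, d_ff, 0.1),
                               d_model=d_model)
layer = SwitchTransformerLayer(d_model=d_model,
                               attn=MultiHeadAttention(heads, d_model, 0.1),
                               feed_forward=switch_ffn,
                               dropout_prob=0.1)
model = SwitchTransformer(layer, n_layers)

x = torch.randn(seq_len, batch_size, d_model)
# Causal (subsequent) mask; shape assumed to be [seq_len, seq_len, 1]
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

out, counts, route_prob, n_dropped = model(x, mask)
print(out.shape)     # [seq_len, batch_size, d_model]
print(counts.shape)  # [n_layers, n_experts]: tokens routed to each expert, per layer
```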