diff --git a/labml_nn/transformers/switch/__init__.py b/labml_nn/transformers/switch/__init__.py
index c1d37cd6..cdeccbae 100644
--- a/labml_nn/transformers/switch/__init__.py
+++ b/labml_nn/transformers/switch/__init__.py
@@ -81,27 +81,46 @@ class SwitchFeedForward(Module):
         self.softmax = nn.Softmax(dim=-1)
 
     def __call__(self, x: torch.Tensor):
-        seq_len, bs, d_model = x.shape
+        """
+        * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
+        """
+
+        # Capture the shape to restore the output shape later
+        seq_len, batch_size, d_model = x.shape
+        # Flatten the sequence and batch dimensions
         x = x.view(-1, d_model)
 
+        # Get routing probabilities for each of the tokens.
+        # $$p_i(x) = \frac{e^{h(x)_i}}{\sum^N_j e^{h(x)_j}}$$
+        # where $N$ is the number of experts `n_experts` and
+        # $h(\cdot)$ is the linear transformation of token embeddings.
         route_prob = self.softmax(self.switch(x))
+
+        # Get the maximum routing probabilities and the routes.
+        # We route to the expert with the highest probability
         route_prob_max, routes = torch.max(route_prob, dim=-1)
 
+        # Scale the inputs to the experts by the routing probabilities
         if self.is_scale_prob:
             factor = route_prob_max
+        # Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
         else:
             factor = route_prob_max / route_prob_max.detach()
+        # Multiply by the scaling factor
         x = x * factor.view(-1, 1)
 
-        # Get indexes of vectors going to each route
+        # Get indexes of tokens going to each expert
         indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_switches)]
 
-        # Tensor to store outputs
+        # Initialize an empty tensor to store outputs
         final_output = x.new_zeros(x.shape)
 
-        # Capacity of a route
+        # Capacity of each expert.
+        # $$\mathrm{expert\;capacity} =
+        # \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}}
+        # \times \mathrm{capacity\;factor}$$
         capacity = int(self.capacity_factor * len(x) / self.n_switches)
-        # Number of tokens going to each route
+        # Number of tokens routed to each expert
         counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_switches)])
 
         # Drop tokens
@@ -125,7 +144,7 @@ class SwitchFeedForward(Module):
         final_output[dropped, :] = x[dropped, :]
 
         # Change the shape of the final output
-        final_output = final_output.view(seq_len, bs, d_model)
+        final_output = final_output.view(seq_len, batch_size, d_model)
 
         return final_output, counts, route_prob.sum(0), len(dropped)
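
For reviewers: a minimal standalone sketch of the top-1 routing and expert-capacity computation that the patch documents. The `router` layer name, the tensor sizes, and the random input are illustrative assumptions, not part of the patch; only the routing logic mirrors the diff.

```python
import torch
import torch.nn as nn

# Assumed toy sizes, for illustration only
seq_len, batch_size, d_model, n_experts = 10, 4, 16, 4
capacity_factor = 1.0

# Stand-in for `self.switch`: a linear router over token embeddings
router = nn.Linear(d_model, n_experts)
x = torch.randn(seq_len, batch_size, d_model)

# Flatten the sequence and batch dimensions, as in the patch
x = x.view(-1, d_model)

# $p_i(x)$: routing probabilities, then route each token to its top expert
route_prob = torch.softmax(router(x), dim=-1)
route_prob_max, routes = torch.max(route_prob, dim=-1)

# Indexes of tokens going to each expert
indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(n_experts)]

# Expert capacity = (tokens per batch / number of experts) * capacity factor
capacity = int(capacity_factor * len(x) / n_experts)

counts = [len(idx) for idx in indexes_list]
print(f"capacity per expert: {capacity}, tokens per expert: {counts}")
# Tokens beyond `capacity` at any expert are the ones the module drops
```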
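The `factor = route_prob_max / route_prob_max.detach()` branch is worth a note: the ratio is identically 1 in the forward pass, so expert inputs are left unscaled, but the numerator keeps the routing probabilities in the autograd graph and the router still receives gradients. A toy check of that behavior, with assumed values:

```python
import torch

# Stand-in for `route_prob_max` (assumed toy values)
p = torch.tensor([0.7, 0.9], requires_grad=True)

# Forward value is exactly 1 for every token...
factor = p / p.detach()
print(factor)  # tensor([1., 1.], grad_fn=<DivBackward0>)

# ...but gradients still reach `p`: d(factor_i)/d(p_i) = 1 / p_i
factor.sum().backward()
print(p.grad)  # tensor([1.4286, 1.1111]), i.e. 1 / p
```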