Varuna Jayasiri
2021-01-20 10:36:51 +05:30
parent 5d174f3a7c
commit 35e2bc6c96


@@ -81,27 +81,46 @@ class SwitchFeedForward(Module):
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, x: torch.Tensor):
-        seq_len, bs, d_model = x.shape
+        """
+        * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
+        """
+        # Capture the shape to change shapes later
+        seq_len, batch_size, d_model = x.shape
        # Flatten the sequence and batch dimensions
        x = x.view(-1, d_model)
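# --- A minimal standalone sketch (not part of the commit): the flatten above
# turns `[seq_len, batch_size, d_model]` into one row per token; the sizes
# below are made-up example values. ---
import torch

x_demo = torch.randn(16, 8, 512)   # [seq_len, batch_size, d_model]
x_demo = x_demo.view(-1, 512)      # one row per token
assert x_demo.shape == (128, 512)  # 16 * 8 = 128 tokens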
        # Get routing probabilities for each of the tokens.
        # $$p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^N e^{h(x)_j}}$$
        # where $N$ is the number of experts (`n_switches`) and
        # $h(\cdot)$ is the linear transformation of the token embeddings.
        route_prob = self.softmax(self.switch(x))
        # Get the maximum routing probability and the route for each token.
        # We route each token to the expert with the highest probability.
        route_prob_max, routes = torch.max(route_prob, dim=-1)
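# --- Standalone sketch (not part of the commit) of the routing above, with
# toy logits standing in for `self.switch(x)`: 4 tokens, 3 experts, values
# made up for illustration. ---
import torch

logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3],
                       [0.1, 0.2, 3.0],
                       [1.0, 1.0, 1.0]])
toy_prob = torch.softmax(logits, dim=-1)               # each row sums to 1
toy_prob_max, toy_routes = torch.max(toy_prob, dim=-1)
print(toy_routes)  # tensor([0, 1, 2, 0]) -- the chosen expert per token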
        # Scale the inputs to the experts by the routing probabilities
        if self.is_scale_prob:
            factor = route_prob_max
        # Otherwise, don't scale the values; multiply by
        # $\frac{p}{\hat{p}} = 1$, where $\hat{p}$ is $p$ detached from the
        # computation graph, so that the gradients still flow through $p$
        else:
            factor = route_prob_max / route_prob_max.detach()
        # Multiply by the scaling factor
        x = x * factor.view(-1, 1)
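# --- Standalone sketch (not part of the commit): `p / p.detach()` is exactly
# 1 in the forward pass, but keeps `p` in the autograd graph, so the router
# still receives gradients even when the expert inputs are not scaled. ---
import torch

p = torch.tensor([0.7, 0.9], requires_grad=True)
factor = p / p.detach()        # forward value: tensor([1., 1.])
factor.sum().backward()
print(p.grad)                  # tensor([1.4286, 1.1111]) == 1 / p.detach()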
-        # Get indexes of vectors going to each route
+        # Get indexes of tokens going to each expert
        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_switches)]
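# --- Standalone sketch (not part of the commit): grouping token indexes by
# expert with the same list comprehension, continuing the toy routes from the
# sketch above (3 experts). ---
import torch

toy_routes = torch.tensor([0, 1, 2, 0])
groups = [torch.eq(toy_routes, i).nonzero(as_tuple=True)[0] for i in range(3)]
print(groups)                    # [tensor([0, 3]), tensor([1]), tensor([2])]
print([len(g) for g in groups])  # [2, 1, 1] -- these become `counts` below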
-        # Tensor to store outputs
+        # Initialize an empty tensor to store outputs
        final_output = x.new_zeros(x.shape)
-        # Capacity of a route
+        # Capacity of each expert.
+        # $$\mathrm{expert\;capacity} =
+        # \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}}
+        # \times \mathrm{capacity\;factor}$$
        capacity = int(self.capacity_factor * len(x) / self.n_switches)
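# --- Standalone sketch (not part of the commit): a worked instance of the
# capacity formula with made-up sizes -- 16 * 8 = 128 tokens per batch,
# 4 experts, capacity factor 1.25. ---
capacity_demo = int(1.25 * 128 / 4)
assert capacity_demo == 40  # an expert sent more than 40 tokens overflows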
-        # Number of tokens going to each route
+        # Number of tokens routed to each expert
        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_switches)])

        # Drop tokens
@@ -125,7 +144,7 @@ class SwitchFeedForward(Module):
        final_output[dropped, :] = x[dropped, :]
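        # (`dropped`, built in the lines elided above, holds the indexes of
        # tokens that exceeded an expert's capacity; those tokens bypass the
        # expert FFN and keep their scaled input values.)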
        # Change the shape of the final output
-        final_output = final_output.view(seq_len, bs, d_model)
+        final_output = final_output.view(seq_len, batch_size, d_model)

        return final_output, counts, route_prob.sum(0), len(dropped)
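# --- Standalone sketch (not part of the commit): the extra return values are
# the statistics a trainer would typically feed into the Switch Transformer
# load-balancing loss, $N \sum_i f_i P_i$, where $f_i$ is the fraction of
# tokens routed to expert $i$ and $P_i$ is its mean routing probability; all
# names and numbers below are assumptions for illustration. ---
import torch

n_experts = 4
counts = torch.tensor([40., 30., 30., 28.])           # tokens per expert
route_prob_sums = torch.tensor([35., 31., 33., 29.])  # route_prob.sum(0)
total = counts.sum()
fraction = counts / total            # f_i
mean_prob = route_prob_sums / total  # P_i
load_balancing_loss = n_experts * (fraction * mean_prob).sum()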