Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
notes
@@ -81,27 +81,46 @@ class SwitchFeedForward(Module):
         self.softmax = nn.Softmax(dim=-1)
 
     def __call__(self, x: torch.Tensor):
-        seq_len, bs, d_model = x.shape
+        """
+        * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
+        """
+
+        # Capture the shape to change shapes later
+        seq_len, batch_size, d_model = x.shape
+        # Flatten the sequence and batch dimensions
         x = x.view(-1, d_model)
 
+        # Get routing probabilities for each of the tokens.
+        # $$p_i(x) = \frac{e^{h(x)_i}}{\sum^N_j e^{h(x)_j}}$$
+        # where $N$ is the number of experts `n_experts` and
+        # $h(\cdot)$ is the linear transformation of token embeddings.
         route_prob = self.softmax(self.switch(x))
 
+        # Get the maximum routing probabilities and the routes.
+        # We route to the expert with highest probability
        route_prob_max, routes = torch.max(route_prob, dim=-1)
 
+        # Scale the inputs to the experts by the routing probabilities
         if self.is_scale_prob:
             factor = route_prob_max
+        # Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
         else:
             factor = route_prob_max / route_prob_max.detach()
+        # Multiply by the scaling factor
         x = x * factor.view(-1, 1)
 
-        # Get indexes of vectors going to each route
+        # Get indexes of tokens going to each expert
         indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_switches)]
 
-        # Tensor to store outputs
+        # Initialize an empty tensor to store outputs
         final_output = x.new_zeros(x.shape)
 
-        # Capacity of a route
+        # Capacity of each expert.
+        # $$\mathrm{expert\;capacity} =
+        # \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}}
+        # \times \mathrm{capacity\;factor}$$
         capacity = int(self.capacity_factor * len(x) / self.n_switches)
-        # Number of tokens going to each route
+        # Number of tokens routed to each expert
         counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_switches)])
         # Drop tokens
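The expert-capacity comment in the hunk above is easiest to check with concrete numbers. Below is a minimal standalone sketch of the routing and capacity computation (not part of the commit; the sizes and the `router` linear layer are illustrative assumptions):

    import torch
    import torch.nn as nn

    # Illustrative sizes: 10 tokens x 4 sequences, 8-dim embeddings, 4 experts.
    seq_len, batch_size, d_model, n_experts = 10, 4, 8, 4
    capacity_factor = 1.0

    x = torch.randn(seq_len * batch_size, d_model)   # tokens, already flattened
    router = nn.Linear(d_model, n_experts)           # plays the role of h(.)

    route_prob = torch.softmax(router(x), dim=-1)    # p_i(x) for each token
    route_prob_max, routes = torch.max(route_prob, dim=-1)

    # expert capacity = tokens-per-batch / number-of-experts * capacity-factor
    capacity = int(capacity_factor * len(x) / n_experts)   # 40 / 4 * 1.0 = 10

    # Token indexes assigned to each expert; any list longer than
    # `capacity` would have its overflow tokens dropped.
    indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(n_experts)]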
@@ -125,7 +144,7 @@ class SwitchFeedForward(Module):
         final_output[dropped, :] = x[dropped, :]
 
         # Change the shape of the final output
-        final_output = final_output.view(seq_len, bs, d_model)
+        final_output = final_output.view(seq_len, batch_size, d_model)
 
         return final_output, counts, route_prob.sum(0), len(dropped)
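One detail in the first hunk deserves a note: when `is_scale_prob` is false, the inputs are multiplied by `route_prob_max / route_prob_max.detach()`, which is identically 1 in the forward pass, so the expert inputs are not actually scaled, yet the routing probability stays in the autograd graph and the router still receives gradients. A quick standalone check (illustrative, not part of the commit):

    import torch

    p = torch.tensor([0.7, 0.9], requires_grad=True)

    # Value is exactly 1, but d(factor_i)/d(p_i) = 1 / p_i.detach() != 0.
    factor = p / p.detach()
    print(factor)            # tensor([1., 1.], grad_fn=<DivBackward0>)

    factor.sum().backward()
    print(p.grad)            # tensor([1.4286, 1.1111]) == 1 / p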
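The diff elides the token-dropping block between the two hunks, but the surrounding lines fix the pattern: dispatch each expert's tokens up to `capacity`, and pass dropped tokens through unchanged. A hedged end-to-end sketch of that pattern (the `experts` list, the random `routes`, and all sizes are hypothetical, not the commit's code):

    import torch
    import torch.nn as nn

    # Hypothetical toy setup: each expert is a single linear layer.
    n_experts, d_model, capacity = 4, 8, 10
    experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_experts)])

    x = torch.randn(40, d_model)                  # flattened tokens
    routes = torch.randint(0, n_experts, (40,))   # stand-in for the router's argmax

    final_output = x.new_zeros(x.shape)
    dropped = []
    for i in range(n_experts):
        idx = torch.eq(routes, i).nonzero(as_tuple=True)[0]
        if len(idx) > capacity:                   # drop tokens over capacity
            dropped.append(idx[capacity:])
            idx = idx[:capacity]
        if len(idx) > 0:
            final_output[idx] = experts[i](x[idx])

    # As in the second hunk above, dropped tokens pass through unchanged.
    if dropped:
        dropped = torch.cat(dropped)
        final_output[dropped, :] = x[dropped, :]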