This commit is contained in:
Varuna Jayasiri
2021-09-17 12:06:41 +05:30
parent f87879e780
commit 40eb9cab4e
2 changed files with 86 additions and 96 deletions

View File

@ -143,7 +143,6 @@ class SwitchFeedForward(Module):
dropped = torch.cat(dropped)
final_output[dropped, :] = x[dropped, :]
# Scale the expert outputs by the routing probabilities
if self.is_scale_prob:
# Multiply by the expert outputs by the probabilities $y = p_i(x) E_i(x)$
final_output = final_output * route_prob_max.view(-1, 1)