Merge branch 'master' of github.com:lab-ml/nn

merge
This commit is contained in:
Varuna Jayasiri
2021-09-18 14:18:26 +05:30
7 changed files with 242 additions and 253 deletions

View File

@@ -143,16 +143,13 @@ class SwitchFeedForward(Module):
dropped = torch.cat(dropped)
final_output[dropped, :] = x[dropped, :]
# Scale the outputs of the experts by the routing probabilities
if self.is_scale_prob:
factor = route_prob_max
# Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
# (this is just something we experimented with)
# Multiply the expert outputs by the routing probabilities $y = p_i(x) E_i(x)$
final_output = final_output * route_prob_max.view(-1, 1)
else:
factor = route_prob_max / route_prob_max.detach()
# Multiply by the scaling factor
final_output = final_output * factor.view(-1, 1)
# Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
# (this is something we experimented with).
final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)
# Change the shape of the final output back to `[seq_len, batch_size, d_model]`
final_output = final_output.view(seq_len, batch_size, d_model)