diff --git a/labml_nn/transformers/switch/experiment.py b/labml_nn/transformers/switch/experiment.py index d03a1c56..31e24d9e 100644 --- a/labml_nn/transformers/switch/experiment.py +++ b/labml_nn/transformers/switch/experiment.py @@ -114,7 +114,7 @@ class Configs(NLPAutoRegressionConfigs): # Total number of tokens processed, $T$, in the current batch $\mathscr{B}$ total = counts.sum(dim=-1, keepdims=True) # Fraction of tokens routed to each expert - # $$f_i = \frac{1}{T} \sum_{x \in \mathscr{B}} \unicode{x1D7D9} \{ \mathop{argmax} p(x), i \}$$ + # $$f_i = \frac{1}{T} \sum_{x \in \mathscr{B}} \mathbf{1} \{ \mathop{argmax} p(x), i \}$$ # $f_i$ is the count of tokens where the argmax of $p(x)$ is equal to $i$. route_frac = counts / total # Mean routing probability