diff --git a/labml_nn/transformers/switch/__init__.py b/labml_nn/transformers/switch/__init__.py index c1d37cd6..c98bd228 100644 --- a/labml_nn/transformers/switch/__init__.py +++ b/labml_nn/transformers/switch/__init__.py @@ -110,6 +110,7 @@ class SwitchFeedForward(Module): for i in range(self.n_switches): if len(indexes_list[i]) <= capacity: continue + indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))] dropped.append(indexes_list[i][capacity:]) indexes_list[i] = indexes_list[i][:capacity]