This commit is contained in:
Varuna Jayasiri
2021-01-20 10:19:22 +05:30
parent 5d174f3a7c
commit ec1cb8b27b

View File

@ -110,6 +110,7 @@ class SwitchFeedForward(Module):
for i in range(self.n_switches): for i in range(self.n_switches):
if len(indexes_list[i]) <= capacity: if len(indexes_list[i]) <= capacity:
continue continue
indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
dropped.append(indexes_list[i][capacity:]) dropped.append(indexes_list[i][capacity:])
indexes_list[i] = indexes_list[i][:capacity] indexes_list[i] = indexes_list[i][:capacity]