"""Utilities"""

import copy

from torch import nn

from labml_helpers.module import Module

# Make an `nn.ModuleList` containing independent clones of a given layer.

def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList:
    """Make an ``nn.ModuleList`` holding ``n`` independent clones of ``module``.

    Each clone is produced with :func:`copy.deepcopy`, so the clones do not
    share parameter or buffer tensors with ``module`` or with each other.

    Args:
        module: The layer to clone. Any ``torch.nn.Module`` (including
            ``labml_helpers.module.Module`` subclasses) is accepted.
        n: Number of clones to create. ``0`` yields an empty list.

    Returns:
        An ``nn.ModuleList`` of length ``n``.

    Raises:
        ValueError: If ``n`` is negative.
    """
    # A negative count would otherwise silently produce an empty list via
    # ``range(n)``, which almost always hides a caller bug.
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])