🚧 compressive transformer

This commit is contained in:
Varuna Jayasiri
2021-02-17 16:36:52 +05:30
parent 9a9d6c671d
commit ff8f80039a
3 changed files with 513 additions and 5 deletions

View File

@ -9,13 +9,11 @@ summary: A bunch of utility functions and classes
import copy
from torch import nn
from labml_helpers.module import Module
from labml_helpers.module import M, TypedModuleList
def clone_module_list(module: Module, n: int):
def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
"""
## Make a `nn.ModuleList` with clones of a given layer
"""
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
return TypedModuleList([copy.deepcopy(module) for _ in range(n)])