公共事业

10import copy
11
12from torch.utils.data import Dataset, IterableDataset
13
14from labml_helpers.module import M, TypedModuleList

克隆模块

nn.ModuleList 使用给定模块的克隆制作一个

17def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
23    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])

循环数据加载器

无限加载器,在每个纪元之后回收数据加载器

26def cycle_dataloader(data_loader):
34    while True:
35        for batch in data_loader:
36            yield batch

地图样式数据集

这会将转换IterableDataset 地图样式的数据集,以便我们可以随机排列数据集。

这仅在数据集大小较小且可以保存在内存中时才有效。

39class MapStyleDataset(Dataset):
52    def __init__(self, dataset: IterableDataset):

将数据加载到内存

54        self.data = [d for d in dataset]

按索引获取样本

56    def __getitem__(self, idx: int):
58        return self.data[idx]

创建迭代器

60    def __iter__(self):
62        return iter(self.data)

数据集的大小

64    def __len__(self):
66        return len(self.data)