import copy

from torch.utils.data import Dataset, IterableDataset

from labml_helpers.module import M, TypedModuleList
def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
    """
    Build a `TypedModuleList` holding `n` deep copies of `module`.

    :param module: the module to copy
    :param n: number of copies to make
    :return: a `TypedModuleList` containing the `n` copies
    """
    # deepcopy so the copies share no state with `module` or with each other
    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])


def cycle_dataloader(data_loader):
    """
    Yield batches from `data_loader` forever, restarting it each time it is exhausted.

    Unlike `itertools.cycle`, which caches the items of the first pass and replays
    them, this re-iterates `data_loader` on every pass, so a loader whose iteration
    order changes between passes (e.g. one that shuffles) produces a fresh order
    each time.

    :param data_loader: a re-iterable source of batches
    :raises ValueError: if a full pass over `data_loader` yields nothing -- e.g. an
        empty loader, or a one-shot iterator that has already been consumed.
        Without this check the generator would loop forever without yielding.
    """
    while True:
        produced = False
        for batch in data_loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("cycle_dataloader: data_loader yielded no batches")


class MapStyleDataset(Dataset):
    """
    Wrap an `IterableDataset` as a map-style (indexable, sized) `Dataset`.

    The whole iterable is read into a list on construction, so it must be finite
    and fit in memory.
    """

    def __init__(self, dataset: IterableDataset):
        # Load all the data into memory
        self.data = list(dataset)

    def __getitem__(self, idx: int):
        # Get a sample by index
        return self.data[idx]

    def __iter__(self):
        # Create an iterator over the samples
        return iter(self.data)

    def __len__(self):
        # Size of the dataset
        return len(self.data)