10import copy
11
12from torch.utils.data import Dataset, IterableDataset
13
14from labml_helpers.module import M, TypedModuleList
def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
    """
    Build a module list holding `n` independent copies of `module`.

    Each entry is produced with `copy.deepcopy`, so the copies do not share
    state with `module` or with each other.

    :param module: the module to duplicate
    :param n: how many copies to create
    :return: a `TypedModuleList` wrapping the `n` deep copies
    """
    clones = []
    for _ in range(n):
        clones.append(copy.deepcopy(module))
    return TypedModuleList(clones)
def cycle_dataloader(data_loader):
    """
    Yield batches from `data_loader` over and over.

    Every time a pass over `data_loader` finishes, iteration is restarted by
    iterating it again, so a re-iterable loader produces an endless stream of
    batches.

    If a complete pass produces no batch at all (the loader is empty, or it
    is a one-shot iterator that has already been consumed), the generator
    ends instead of spinning forever without yielding anything.

    :param data_loader: any iterable of batches
    :return: a generator over the batches
    """
    while True:
        produced_any = False
        for batch in data_loader:
            produced_any = True
            yield batch
        # Nothing came out of this pass, so the next pass would not either;
        # stop here rather than busy-loop and hang the caller.
        if not produced_any:
            return
class MapStyleDataset(Dataset):
    """
    In-memory, index-addressable view of an iterable dataset.

    The constructor iterates the given dataset once and keeps every item in a
    list, so items can afterwards be fetched by index and the length is
    known. The whole dataset is held in memory.
    """

    def __init__(self, dataset: IterableDataset):
        # Load the data into memory, one item at a time
        items = []
        for sample in dataset:
            items.append(sample)
        self.data = items

    def __getitem__(self, idx: int):
        """Get a sample by index"""
        samples = self.data
        return samples[idx]

    def __iter__(self):
        """Create an iterator over the stored samples"""
        samples = self.data
        return iter(samples)

    def __len__(self):
        """Size of the dataset"""
        samples = self.data
        return len(samples)