import copy

from torch.utils.data import Dataset, IterableDataset

from labml_helpers.module import M, TypedModuleList


def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
    """
    Make ``n`` independent copies of ``module``.

    :param module: the module to copy
    :param n: number of copies to make
    :return: a ``TypedModuleList`` holding the ``n`` copies
    """
    # deepcopy (not a shallow copy) so the copies share no state
    # with each other or with the original ``module``
    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])


def cycle_dataloader(data_loader):
    """
    Iterate over ``data_loader`` endlessly.

    A new pass over ``data_loader`` is started each time the previous pass
    is exhausted. If a complete pass yields no batches (an empty loader, or a
    one-shot iterator that has already been consumed), the generator stops
    instead of looping forever without ever yielding.

    :param data_loader: a re-iterable source of batches
    :return: a generator over the batches
    """
    while True:
        produced_any = False
        for batch in data_loader:
            produced_any = True
            yield batch
        # Without this guard an empty pass would make ``while True`` spin
        # forever without yielding, hanging the caller.
        if not produced_any:
            return
class MapStyleDataset(Dataset):
    """
    Map-style dataset built from an iterable dataset.

    This converts an ``IterableDataset`` to a map-style dataset so that the
    dataset can be shuffled. All samples are loaded into memory up front, so
    this only works when the dataset is small enough to be held in memory.
    """

    def __init__(self, dataset: IterableDataset):
        """
        :param dataset: the dataset to load; it is iterated over exactly once
        """
        super().__init__()
        # Load the data to memory
        self.data = list(dataset)

    def __getitem__(self, idx: int):
        """Get a sample by index."""
        return self.data[idx]

    def __iter__(self):
        """Create an iterator over the loaded samples."""
        return iter(self.data)

    def __len__(self):
        """Size of the dataset."""
        return len(self.data)