import copy

from torch import nn
from torch.utils.data import Dataset, IterableDataset


def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList:
    """Return an ``nn.ModuleList`` holding ``n`` deep copies of ``module``.

    Each element is produced by ``copy.deepcopy``, so the clones are
    independent objects: they share no parameters with ``module`` or with
    each other.
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


def cycle_dataloader(data_loader):
    """Yield batches from ``data_loader`` endlessly, restarting it after each pass.

    ``data_loader`` is iterated afresh on every pass rather than having its
    first pass cached (as ``itertools.cycle`` would do). A loader that
    produces a different order each time it is iterated therefore keeps
    doing so.

    If a complete pass yields no batches, the generator stops. This happens
    with an empty loader, or with a one-shot iterator that is already
    exhausted. Without the stop, the loop would spin forever without ever
    yielding. Stopping mirrors ``itertools.cycle`` on an empty iterable.
    """
    while True:
        produced_any = False
        for batch in data_loader:
            produced_any = True
            yield batch
        if not produced_any:
            # Re-iterating an empty source can never yield; stop instead of
            # busy-looping.
            return


# This converts an IterableDataset to a map-style dataset so that we can
# shuffle the dataset.
# This only works when the dataset size is small and can be held in memory.
class MapStyleDataset(Dataset):
    """Map-style wrapper that materialises an iterable dataset in memory.

    Every sample of the wrapped dataset is read once, at construction time,
    into a Python list. That gives the wrapper a length and random access by
    index, which is what shuffling the dataset requires. Because all samples
    are held in memory, use it only for datasets small enough to fit.
    """

    def __init__(self, dataset: IterableDataset):
        """Load every sample of ``dataset`` into memory.

        :param dataset: the iterable dataset to materialise; it is iterated
            exactly once, here.
        """
        # A plain comprehension only iterates ``dataset``; unlike
        # ``list(dataset)`` it never probes ``__len__`` for a length hint.
        self.data = [sample for sample in dataset]

    def __len__(self) -> int:
        """Return the number of samples that were loaded."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

    def __iter__(self):
        """Return an iterator over the loaded samples, in load order."""
        return iter(self.data)