10import copy
11from torch import nn
12from torch.utils.data import Dataset, IterableDataset
def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList:
    """
    ## Clone Module

    Build an `nn.ModuleList` holding `n` independent deep copies of `module`.
    Each clone starts with identical parameters but shares no storage with
    the original, so the copies can be trained separately.
    """
    replicas = (copy.deepcopy(module) for _ in range(n))
    return nn.ModuleList(replicas)
def cycle_dataloader(data_loader):
    """
    ## Cycle Data Loader

    Yield batches from `data_loader` endlessly, restarting iteration from
    the beginning each time the loader is exhausted.
    """
    while True:
        yield from data_loader
class MapStyleDataset(Dataset):
    """
    ## Map-style dataset

    This converts an `IterableDataset` to a map-style dataset so that we can
    shuffle the dataset.
    This only works when the dataset size is small and can be held in memory.
    """

    def __init__(self, dataset: IterableDataset):
        # Load the data to memory; materializes the entire iterable, so the
        # dataset must be small enough to fit in memory.
        self.data = list(dataset)

    def __getitem__(self, idx: int):
        """Get a sample by index"""
        return self.data[idx]

    def __iter__(self):
        """Create an iterator over the in-memory samples"""
        return iter(self.data)

    def __len__(self):
        """Size of the dataset"""
        return len(self.data)