Utilities

import copy

from torch.utils.data import Dataset, IterableDataset

from labml_helpers.module import M, TypedModuleList

Clone Module

Make an nn.ModuleList containing n independent deep copies of a given module

def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
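
Why deep copies matter can be sketched without PyTorch. In the illustration below, TinyLayer and clone_list are hypothetical stand-ins (they are not part of labml_helpers); clone_list follows the same copy.deepcopy pattern but returns a plain list instead of a TypedModuleList.

```python
import copy


class TinyLayer:
    """Hypothetical stand-in for a module; holds a mutable list of weights."""

    def __init__(self, weights):
        self.weights = list(weights)


def clone_list(module, n):
    # Same deep-copy pattern as clone_module_list, minus the TypedModuleList wrapper.
    return [copy.deepcopy(module) for _ in range(n)]


layer = TinyLayer([1.0, 2.0])
clones = clone_list(layer, 3)

# Mutating one clone leaves the original and the other clones untouched.
clones[0].weights[0] = 99.0

# By contrast, list repetition only repeats references to a single object.
shared = [layer] * 3
```

Each clone owns its own state, which is what stacked layers need if every layer is to learn separate parameters; the entries of shared all point at the same object.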

Cycle Data Loader

An infinite generator that restarts the data loader each time it finishes an epoch

def cycle_dataloader(data_loader):
    while True:
        for batch in data_loader:
            yield batch
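
A minimal usage sketch, assuming a plain list of batches stands in for a real DataLoader (the generator only needs an object that can be iterated repeatedly). Because the stream never ends, the caller decides how many batches to draw, here with itertools.islice.

```python
from itertools import islice


def cycle_dataloader(data_loader):
    # Restart iteration over the loader every time it is exhausted.
    while True:
        for batch in data_loader:
            yield batch


# A plain list of two "batches" stands in for a DataLoader here.
batches = [[0, 1], [2, 3]]

# Take five batches; the stream wraps around after every two.
first_five = list(islice(cycle_dataloader(batches), 5))
```

The loader should be non-empty and re-iterable. An empty loader, or a one-shot iterator that is already exhausted, would make the outer loop spin forever without yielding anything.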

Map Style Dataset

This converts an IterableDataset to a map-style dataset so that we can shuffle the dataset.

This only works when the dataset is small enough to be held in memory.

class MapStyleDataset(Dataset):
    def __init__(self, dataset: IterableDataset):

Load the data into memory

        self.data = [d for d in dataset]

Get a sample by index

    def __getitem__(self, idx: int):
        return self.data[idx]

Create an iterator

    def __iter__(self):
        return iter(self.data)

Size of the dataset

    def __len__(self):
        return len(self.data)
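
The reason random access enables shuffling can be sketched without PyTorch. Below, stream and InMemoryDataset are hypothetical stand-ins: InMemoryDataset mirrors the methods of MapStyleDataset but does not inherit from Dataset, and stream plays the role of an iterable-only source.

```python
import random


def stream():
    # Hypothetical iterable source that can only be read front to back.
    for i in range(5):
        yield i * i


class InMemoryDataset:
    """Torch-free stand-in mirroring MapStyleDataset."""

    def __init__(self, dataset):
        # Materialize the whole stream once; this is the memory cost.
        self.data = [d for d in dataset]

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


dataset = InMemoryDataset(stream())

# Random access makes shuffling possible: permute the indices, then look up.
indices = list(range(len(dataset)))
random.Random(0).shuffle(indices)
shuffled = [dataset[i] for i in indices]
```

With the real class, the wrapped dataset would typically be passed to a DataLoader with shuffle=True, an option that DataLoader rejects for an IterableDataset.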