ユーティリティ

10import copy
11
12from torch.utils.data import Dataset, IterableDataset
13
14from labml_helpers.module import M, TypedModuleList

クローン・モジュール

nn.ModuleList 与えられたモジュールのクローンを使って作る

17def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
23    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])

サイクルデータローダー

エポックごとにデータローダーをリサイクルする無限ローダー

26def cycle_dataloader(data_loader):
34    while True:
35        for batch in data_loader:
36            yield batch

マップスタイルデータセット

これにより、がマップスタイルのデータセットに変換され、IterableDataset データセットをシャッフルできるようになります。

これは、データセットのサイズが小さく、メモリに保持できる場合にのみ機能します。

39class MapStyleDataset(Dataset):
52    def __init__(self, dataset: IterableDataset):

データをメモリに読み込む

54        self.data = [d for d in dataset]

インデックスでサンプルを取得

56    def __getitem__(self, idx: int):
58        return self.data[idx]

イテレータの作成

60    def __iter__(self):
62        return iter(self.data)

データセットのサイズ

64    def __len__(self):
66        return len(self.data)