from pathlib import PurePath, Path
from typing import Optional, List

import torch
import torch.utils.data
from labml import lab
from labml import monit
from labml.logger import inspect
from labml.utils.download import download_file

from labml_nn.neox.tokenizer import get_tokenizer


def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
    """
    :param path: is the location of the text file
    :param url: is the URL to download the file from
    :param filter_subset: is the number of characters to filter.
        Use this during testing when trying large datasets
    :return: the text content
    """
    path = Path(path)

    # Download if it doesn't exist
    if not path.exists():
        if not url:
            raise FileNotFoundError(str(path))
        else:
            download_file(url, path)

    with monit.section("Load data"):
        # Load data
        with open(str(path), 'r') as f:
            text = f.read()
        # Filter
        if filter_subset:
            text = text[:filter_subset]

    return text
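
# A minimal usage sketch (not part of the original module): load the tiny
# Shakespeare text, keeping only the first 1,000 characters; the file name
# and `filter_subset` value here are illustrative.
def _demo_load_text():
    text = load_text(lab.get_data_path() / 'tiny_shakespeare.txt',
                     'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt',
                     filter_subset=1_000)
    inspect(characters=len(text))
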

class NeoXDataset(torch.utils.data.Dataset):
    def __init__(self, tokens: List[int], seq_len: int):
        """
        :param tokens: is the list of token ids
        :param seq_len: is the sequence length of a single training sample
        """
        self.seq_len = seq_len
        # Number of samples
        n_samples = len(tokens) // seq_len
        self.n_samples = n_samples
        # Truncate
        tokens = tokens[:n_samples * seq_len + 1]
        # Create a PyTorch tensor
        self.tokens = torch.tensor(tokens)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx: int):
        # The input is the `seq_len` tokens starting at `offset`;
        # the target is the same window shifted right by one token
        offset = idx * self.seq_len
        return self.tokens[offset:offset + self.seq_len], self.tokens[offset + 1:offset + 1 + self.seq_len]
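
# A minimal sketch (not part of the original module) of how `NeoXDataset`
# builds autoregressive training pairs: with token ids `0..9` and
# `seq_len = 3` there are `10 // 3 = 3` samples, and each target is the
# input window shifted right by one token.
def _demo_dataset():
    ds = NeoXDataset(list(range(10)), seq_len=3)
    # First sample: input `[0, 1, 2]`, target `[1, 2, 3]`
    x, y = ds[0]
    inspect(x=x.tolist(), y=y.tolist(), n_samples=len(ds))
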

DATASETS = {
    'tiny_shakespeare': {
        'file': 'tiny_shakespeare.txt',
        'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    }
}


def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
    """
    :param seq_len: is the sequence length of a single training sample
    :param dataset_name: is the name of the dataset
    :param truncate: is the number of samples to keep; keeps all samples if it's not positive
    :return: the dataset
    """
    ds = DATASETS[dataset_name]
    # Load the content
    text = load_text(lab.get_data_path() / ds['file'], ds['url'])
    # Tokenize
    tokenizer = get_tokenizer()
    tokens = tokenizer.encode_batch([text])[0]

    # Truncate the token ids if `truncate` is positive
    if truncate > 0:
        token_ids = tokens.ids[:truncate * seq_len]
    else:
        token_ids = tokens.ids

    return NeoXDataset(token_ids, seq_len)
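
# A minimal usage sketch (not part of the original module): wrap the dataset
# in a standard PyTorch `DataLoader`. The batch size and `truncate` value
# here are illustrative.
def _demo_data_loader():
    dataset = get_training_data(seq_len=32, truncate=16)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    # Each batch is an `(inputs, targets)` pair of tensors
    # with shape `[batch_size, seq_len]`
    inputs, targets = next(iter(data_loader))
    inspect(inputs=inputs.shape, targets=targets.shape)
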

def _test():
    dataset = get_training_data()

    inspect(tokens=len(dataset.tokens))


if __name__ == '__main__':
    _test()