RETRO training

This is the training code for RETRO (Retrieval-Enhanced Transformer).

View Run

16import torch
17from torch import nn
18from torch.utils.data import DataLoader, RandomSampler
19
20from labml import monit, lab, tracker, experiment, logger
21from labml.logger import Text
22from labml_helpers.datasets.text import TextFileDataset
23from labml_nn.optimizers.noam import Noam
24from labml_nn.transformers.retro import model as retro
25from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
26from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder

Sampler

This class greedily samples from a model.

class Sampler:
    """
    ## Sampler

    This class greedily samples from a model.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        """
        * `device` is the device of the model
        * `model` is the Retro model
        * `tds` is the text dataset (used to get neighbor chunks)
        * `chunk_len` is the length of a chunk
        """
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device
        # Index used to look up nearest-neighbor offsets for a chunk
        self.index = RetroIndex()

    def retrieve_nearest_neighbours(self, chunk: str):
        """
        ### Retrieve nearest neighbors of a given chunk

        Returns a list of strings, one per retrieved offset. Each is a slice of the
        training text starting at that offset and spanning up to `chunk_len * 2`
        characters (shorter if the offset is near the end of the text).
        """
        # Retrieve the offsets of the nearest neighbors.
        # NOTE(review): second argument is passed as `None`; presumably "no offsets
        # to exclude" — confirm against RetroIndex.
        neighbor_offsets = self.index([chunk], None)

        # Get the neighbors (with neighbor length equal to `chunk_len * 2`)
        text = self.tds.train
        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]

        return neighbors

    def sample(self, prompt: str, sample_len: int):
        """
        ### Sample text from the given prompt

        Greedily generates `sample_len` tokens after `prompt` and returns only the
        generated text (the prompt itself is not included).

        Generation runs under `torch.no_grad()` with the model temporarily switched
        to evaluation mode. The model's previous train/eval mode is restored
        afterwards, so a surrounding training loop is unaffected.

        Raises `ValueError` if tokens are requested but `prompt` is shorter than one
        chunk, because neighbors can only be retrieved for complete chunks.
        """
        if sample_len > 0 and len(prompt) < self.chunk_len:
            raise ValueError(
                f'prompt must be at least chunk_len={self.chunk_len} characters long, got {len(prompt)}')

        # To store nearest neighbors as strings (one list of neighbors per chunk)
        neighbors_str = []

        # Sampled text
        sampled = ''

        # Remember the current mode so it can be restored after sampling
        was_training = self.model.training
        self.model.eval()
        try:
            # No gradients are needed for generation; avoid building an autograd graph
            with torch.no_grad():
                # Sample `sample_len` tokens
                for _ in range(sample_len):
                    # Retrieve neighbors for any complete chunks of the prompt
                    # that we have not retrieved neighbors for yet
                    while len(neighbors_str) < len(prompt) // self.chunk_len:
                        # Get the next chunk for which we haven't retrieved neighbors
                        off = len(neighbors_str) * self.chunk_len
                        chunk = prompt[off: off + self.chunk_len]
                        # Retrieve nearest neighbors
                        neighbors_str.append(self.retrieve_nearest_neighbours(chunk))

                    # Tokenize the input
                    src = self.tds.text_to_i(prompt)
                    # Tokenize the retrieved neighbors
                    neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk_neighbors])
                                             for chunk_neighbors in neighbors_str])

                    # Move them to the same device as the model
                    src = src.to(self.device)
                    neighbors = neighbors.to(self.device)

                    # Get model output, adding a leading batch dimension of size 1
                    res = self.model(src[None, :], neighbors[None, :, :, :])

                    # Greedily pick the highest-scoring token at the last position
                    token_text = self.tds.itos[res[0, -1, :].argmax(dim=-1).item()]

                    # Add the sampled token text to the prompt and sample text
                    prompt += token_text
                    sampled += token_text
        finally:
            self.model.train(was_training)

        return sampled

Retro trainer

class Trainer:
    """
    ## Retro trainer

    Each call runs one pass over `dataloader`, updating `model` with `optimizer`.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel,
                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
        """
        * `device` is the device each batch is moved to
        * `model` is the Retro model being trained
        * `dataloader` yields `(src, tgt, neighbors)` batches
        * `optimizer` updates the model parameters
        """
        self.optimizer = optimizer
        self.device = device
        self.dataloader = dataloader
        self.model = model
        self.loss_func = nn.CrossEntropyLoss()

    def __call__(self):
        """
        ### Train the model for an epoch
        """
        dev = self.device
        for _, batch in monit.enum('Train', self.dataloader):
            # Move the `(src, tgt, neighbors)` tensors to the device, in that order
            src, tgt, neighbors = (t.to(dev) for t in batch)

            # Forward pass
            logits = self.model(src, neighbors)

            # Flatten every leading dimension so each position is one row of
            # logits over the last dimension, matched against the flattened targets
            n_classes = logits.shape[-1]
            loss = self.loss_func(logits.view(-1, n_classes), tgt.view(-1))

            # Standard update: clear old gradients, backpropagate, step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Record the loss and advance the global step by the batch size
            tracker.save({'loss.train': loss})
            tracker.add_global_step(len(src))

Create and train a small model

def train():
    """
    ## Create and train a small model
    """

    # Create an experiment
    experiment.create(name='retro_small')

    # Use the first GPU when available; fall back to CPU so the script still runs
    # on machines without CUDA
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load Tiny Shakespeare dataset
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
    # Load the Retro training dataset from `retro_train_dataset.json`
    # (presumably pre-built with retrieved neighbors — TODO confirm)
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

    # Create dataloader
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16

    # Create the nearest neighbor encoder
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    # Create the model
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)
    # Move the model to the device
    model = model.to(device)
    # Create the optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    # Create the `Trainer`
    trainer = Trainer(device, model, train_dl, optimizer)
    # Create the `Sampler`
    sampler = Sampler(device, model, tds, chunk_len)

    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''
    # The escaped prompt for logging never changes, so compute it once
    prompt_display = prompt.replace('\n', '\\n\n')

    # Set models for saving and loading
    experiment.add_pytorch_models(model=model)

    # Start the experiment
    with experiment.start():
        # Train for 32 epochs
        for _ in monit.loop(32):
            # Train
            trainer()
            # Print a new line
            tracker.new_line()
            # Sample from the prompt
            logger.log([(prompt_display, Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
            # Save models
            experiment.save_checkpoint()


if __name__ == '__main__':
    train()