16import torch
17from torch import nn
18from torch.utils.data import DataLoader, RandomSampler
19
20from labml import monit, lab, tracker, experiment, logger
21from labml.logger import Text
22from labml_helpers.datasets.text import TextFileDataset
23from labml_nn.optimizers.noam import Noam
24from labml_nn.transformers.retro import model as retro
25from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder


class Sampler:
    """
    ## Sampler

    Greedily samples text from a Retro model, retrieving nearest-neighbor
    chunks from the training text as the prompt grows.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        """
        * `device` is the device of the model
        * `model` is the Retro model
        * `tds` is the text dataset (used to get neighbor chunks)
        * `chunk_len` is the length of a chunk
        """
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device

        # Index for looking up offsets of nearest-neighbor chunks
        self.index = RetroIndex()

    def retrieve_nearest_neighbours(self, chunk: str):
        """Return the neighbor texts for `chunk`, each `chunk_len * 2` characters long."""
        # Offsets of the nearest neighbors within the training text
        offsets = self.index([chunk], None)

        # Slice the neighbors (neighbor length is twice the chunk length)
        source = self.tds.train
        span = self.chunk_len * 2
        return [source[start: start + span] for start in offsets[0]]

    def sample(self, prompt: str, sample_len: int):
        """Greedily sample `sample_len` tokens continuing `prompt`; return only the sampled text."""
        # Retrieved neighbor texts, one list per completed prompt chunk
        neighbors_str = []
        # Text sampled so far
        sampled = ''

        for _ in range(sample_len):
            # Retrieve neighbors for any completed chunks we haven't looked up yet
            while len(neighbors_str) < len(prompt) // self.chunk_len:
                start = len(neighbors_str) * self.chunk_len
                neighbors_str.append(
                    self.retrieve_nearest_neighbours(prompt[start: start + self.chunk_len]))

            # Tokenize the prompt and the retrieved neighbors, and move them to the model's device
            src = self.tds.text_to_i(prompt).to(self.device)
            neighbors = torch.stack(
                [torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str]
            ).to(self.device)

            # Run the model and greedily pick the most likely last token
            logits = self.model(src[None, :], neighbors[None, :, :, :])
            next_char = self.tds.itos[logits[0, -1, :].argmax(dim=-1).item()]

            # Grow both the prompt (fed back to the model) and the sampled output
            prompt += next_char
            sampled += next_char

        return sampled
 is the device of the model model
 is the Retro mode dataloader
 is the dataloader for the dataset with pre-retrieved neighbors optimizer
 is the optimizer116    def __init__(self, device: torch.device, model: retro.RetroModel,
117                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):124        self.optimizer = optimizer
125        self.device = device
126        self.dataloader = dataloader
127        self.model = model
128        self.loss_func = nn.CrossEntropyLoss()130    def __call__(self):Iterate through training data
136        for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):Move data to the device
138            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)Forward pass
141            res = self.model(src, neighbors)Calculate loss
143            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))Clear the gradients
146            self.optimizer.zero_grad()Backward pass
148            loss.backward()Optimize the model
150            self.optimizer.step()Save training statistics and increment the global step counter
153            tracker.save({'loss.train': loss})
154            tracker.add_global_step(len(src))157def train():Create an experiment
def train():
    """
    ## Train a small Retro model on Tiny Shakespeare

    Builds the dataset, model, optimizer, trainer and sampler, then trains
    for 32 epochs, sampling from a fixed prompt after each epoch.
    """
    # Create an experiment
    experiment.create(name='retro_small')

    # Use the GPU when available, otherwise fall back to CPU so the
    # script still runs on machines without CUDA
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load Tiny Shakespeare dataset
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Load the Retro dataset with pre-retrieved neighbors
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

    # Create the dataloader
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16

    # Create the nearest neighbor encoder
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    # Create the model
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)
    # Move the model to the device
    model = model.to(device)
    # Create the optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    # Create the Trainer
    trainer = Trainer(device, model, train_dl, optimizer)
    # Create the Sampler
    sampler = Sampler(device, model, tds, chunk_len)
    # Prompt used for sampling after every epoch
    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

    # Set models for saving and loading
    experiment.add_pytorch_models(model=model)

    # Start the experiment
    with experiment.start():
        # Train for 32 epochs
        for epoch in monit.loop(32):
            # Train one epoch
            trainer()
            # Print a new line
            tracker.new_line()
            # Sample from the prompt (newlines made visible for logging)
            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
            # Save models
            experiment.save_checkpoint()


#
if __name__ == '__main__':
    train()