This is an annotated PyTorch experiment to train a Transformer XL model on the Tiny Shakespeare dataset.
from typing import List

import torch
import torch.nn as nn
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer


The autoregressive model: a token embedding, the Transformer XL, and a final linear layer that produces the logits of the next token.

class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        super().__init__()

Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)

Transformer
        self.transformer = transformer

Final layer
        self.generator = nn.Linear(d_model, n_vocab)

Masks
        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
Length of the memory
        m_len = len(mem[0]) if mem else 0

Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)

Create an all-ones (full visibility) mask for memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

Use the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]

Token embeddings
        x = self.src_embed(x)

Run it through the transformer
        res, mem = self.transformer(x, mem, mask)

Generate logits of the next token
        res = self.generator(res)

        return res, mem
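For intuition about the combined mask, here is a small standalone sketch (not the library's subsequent_mask helper; it builds the causal part with torch.tril and assumes the [query, key, 1] shape convention used above). With 4 query tokens and 3 memory slots the combined mask has shape [4, 7, 1]: memory is fully visible and future tokens are hidden.

import torch

q_len, m_len = 4, 3

# Causal (subsequent) mask over the new tokens: [q_len, q_len, 1]
mask_x = torch.ones(q_len, q_len).tril().bool().unsqueeze(-1)

# Full-visibility mask over the memory: [q_len, m_len, 1]
mask_mem = mask_x.new_ones(q_len, m_len, 1)

# Each query position sees all of the memory plus the non-future tokens
mask = torch.cat((mask_mem, mask_x), dim=1)

assert mask.shape == (q_len, m_len + q_len, 1)
print(mask[..., 0].int())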
Configurations; this inherits from NLPAutoRegressionConfigs.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

Token embedding size
    d_model: int = 128

Number of attention heads
    heads: int = 4

Dropout probability
    dropout: float = 0.0

Number of features in FFN hidden layer
    d_ff: int = 256

Number of transformer layers
    n_layers: int = 6

Number of memories to keep
    mem_len: int = 128

State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()

    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

This will keep the accuracy metric stats and the memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]
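SimpleStateModule acts as a small state container that the training loop saves and restores when the mode changes, so training memories never leak into validation and vice versa. A rough illustrative toy of that idea (not the labml implementation; save_state and restore_state are hypothetical names):

class ToyStateModule:
    """Holds one piece of state, e.g. Transformer XL memories."""

    def __init__(self):
        self._data = None

    def get(self):
        return self._data

    def set(self, data):
        self._data = data

    # Hypothetical hooks: the trainer stashes the current state before
    # switching between training and validation and restores it afterwards.
    def save_state(self):
        return self._data

    def restore_state(self, saved):
        self._data = saved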
Concatenate memories and remove old memories to keep a maximum of mem_len memories.

    def merge_memory(self, old_mem, new_mem):

If it's configured not to use memory
        if self.mem_len == 0:
            return []

Concatenate with old memory
        if old_mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem

Truncate old memories
        if len(mem[0]) > self.mem_len:
            mem = [m[-self.mem_len:] for m in mem]

        return mem
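To make the merging concrete, here is a small self-contained check with dummy tensors (mem_len is set to 4 purely for illustration; memories are stacked along the step dimension):

import torch

mem_len = 4

# One layer's memory: 3 remembered steps plus 2 new steps; shape is [steps, batch, d_model]
old_mem = [torch.zeros(3, 1, 8)]
new_mem = [torch.ones(2, 1, 8)]

# Concatenate along the step dimension, then keep only the last mem_len steps
mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
if len(mem[0]) > mem_len:
    mem = [m[-mem_len:] for m in mem]

assert mem[0].shape == (4, 1, 8)  # the oldest of the 5 steps was dropped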
Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Get memories
        mem = self.memory.get()

Run the model
        output, new_mem = self.model(data, mem)

Merge memory
        mem = self.merge_memory(mem, new_mem)

Update memories
        self.memory.set(mem)

Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

Train the model
        if self.mode.is_train:

Calculate gradients
            loss.backward()

Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step
            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients
            self.optimizer.zero_grad()

Save the tracked metrics
        tracker.save()
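The loss function comes from NLPAutoRegressionConfigs. Assuming it behaves like a standard cross entropy that flattens the sequence and batch dimensions (this is a sketch of the assumed behaviour, not the helper itself), the computation is roughly:

import torch
import torch.nn.functional as F

seq_len, batch_size, n_vocab = 2, 32, 65             # e.g. a small character vocabulary
output = torch.randn(seq_len, batch_size, n_vocab)   # model logits
target = torch.randint(0, n_vocab, (seq_len, batch_size))

# Flatten the sequence and batch dimensions before the usual cross entropy
loss = F.cross_entropy(output.reshape(-1, n_vocab), target.reshape(-1))
print(loss.item())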
Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt
        prompt = self.prompt

Collect output for printing
        log = [(prompt, Text.subtle)]

Memory
        mem = []

Sample 25 tokens
        for i in monit.iterate('Sample', 25):

Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)

Move to device
            data = data.to(self.device)

Get the model output
            output, new_mem = self.model(data, mem)

Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)

Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]

Only feed the last character to the model in the next iteration; the rest will go in as memories
            prompt = prompt[-1:]

Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Update memory
            mem = self.merge_memory(mem, new_mem)

Print the sampled output
        logger.log(log)
Initialize the autoregressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
        TransformerXLLayer(d_model=c.d_model,
                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                           dropout_prob=c.dropout), c.n_layers))
    return m.to(c.device)
Run the experiment

def main():

Create the experiment
    experiment.create(name="transformer_xl", comment='')

Create configs
    conf = Configs()

Load configurations: a dictionary of configuration values to override the defaults
    experiment.configs(conf,
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 2,
                        'mem_len': 32,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
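Note that seq_len is only 2 while mem_len is 32: each step feeds the model a very short segment, but through the recurrence memory every new token can still attend to a much longer context. A quick check of that effective context, using the values above:

seq_len, mem_len = 2, 32

# Each new token can attend to up to mem_len remembered steps
# plus the non-future tokens of the current segment
effective_context = mem_len + seq_len
print(effective_context)  # 34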
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

Start the experiment and run the training loop (TrainValidConfigs.run)
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()