"""
This is an annotated PyTorch experiment to train a Transformer XL model.
"""
from typing import List

import torch
import torch.nn as nn
from labml.logger import Text

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer


class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        # Masks
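        # (created lazily on the first forward pass and rebuilt if the input grows)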
        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
        # Length of the memory
        m_len = len(mem[0]) if mem else 0
        # Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        # Create an all ones (full visibility) mask for memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        # Concatenate the masks if there is memory
        if m_len:
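            # Memory positions are fully visible to every token, while the current tokens
            # use the causal mask; the combined mask has shape `[len(x), m_len + len(x), 1]`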
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
        # Use the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]

        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
        # Generate logits of the next token
        res = self.generator(res)

        return res, mem


class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of memories to keep
    mem_len: int = 128
    # State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # This will keep the accuracy metric stats and memories separate for training and validation
        self.state_modules = [self.accuracy, self.memory]

    def merge_memory(self, old_mem, new_mem):
        """
        Concatenate new memories with the old memories and remove the oldest ones to keep
        a maximum of `mem_len` memories (a standalone sketch of this logic follows the
        `Configs` class below).
        """
        # If it's configured not to use memory
        if self.mem_len == 0:
            return []

        # Concatenate with old memory
        if old_mem:
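            # Concatenation is along the first (sequence) dimension, separately for each layer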
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem

        # Truncate old memories
        if len(mem[0]) > self.mem_len:
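            # Keep only the most recent `mem_len` steps of each layer's memory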
            mem = [m[-self.mem_len:] for m in mem]

        return mem

    def step(self, batch: any, batch_idx: BatchIndex):
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
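            # `data` has shape `[seq_len, batch_size]`, so this counts the tokens in the batch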
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get memories
            mem = self.memory.get()
            # Run the model
            output, new_mem = self.model(data, mem)
            # Merge memory
            mem = self.merge_memory(mem, new_mem)
            # Update memories
            self.memory.set(mem)

        # Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Memory
        mem = []
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            # Move to device
            data = data.to(self.device)
            # Get the model output
            output, new_mem = self.model(data, mem)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
            # Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Only feed the last character to the model in the next iteration; the rest goes in as memories
            prompt = prompt[-1:]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
            # Update memory
            mem = self.merge_memory(mem, new_mem)

        # Print the sampled output
        logger.log(log)
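

# A minimal standalone sketch (not part of the original experiment) of the memory
# merging done by `Configs.merge_memory` above, assuming per-layer memories of shape
# `[m_len, batch_size, d_model]`. `_merge_memory_sketch` is a hypothetical helper
# added here only for illustration.
def _merge_memory_sketch():
    mem_len = 4
    # Old memories for two layers (3 steps each) and new memories (2 steps each)
    old_mem = [torch.zeros(3, 2, 8) for _ in range(2)]
    new_mem = [torch.ones(2, 2, 8) for _ in range(2)]
    # Concatenate along the sequence dimension and keep only the last `mem_len` steps
    mem = [torch.cat((m, x), dim=0)[-mem_len:] for m, x in zip(old_mem, new_mem)]
    # Each layer now holds exactly `mem_len` steps, ending with the newest ones
    assert all(m.shape == (mem_len, 2, 8) for m in mem)
    assert torch.equal(mem[0][-2:], new_mem[0])
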


@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
        TransformerXLLayer(d_model=c.d_model,
                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                           dropout_prob=c.dropout), c.n_layers))
    return m.to(c.device)


def main():
    # Create experiment
    experiment.create(name="transformer_xl", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
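                        # 'Noam' selects the warmup-then-decay learning-rate schedule from
                        # the original Transformer paper; the learning rate below scales it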
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

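                        # A short sequence length per step; longer-range context is carried
                        # by the Transformer XL memories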
                        'seq_len': 2,
                        'mem_len': 32,
                        'epochs': 128,
                        'batch_size': 32,
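                        # Number of inner iterations per epoch (training is interleaved
                        # with validation and sampling)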
                        'inner_iterations': 25,
                        })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run the training loop (`TrainValidConfigs.run`)
        conf.run()


if __name__ == '__main__':
    main()