This is an annotated PyTorch experiment to train a compressive transformer model.
from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression

A named tuple that holds the memories and the compressed memories of each layer.

class CompressedMemory(NamedTuple):
28    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]

An autoregressive model that wraps the compressive transformer with a token embedding and a final linear layer.

class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)

Transformer

        self.transformer = transformer

Final layer

        self.generator = nn.Linear(d_model, n_vocab)

Masks (built lazily in forward)

        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: CompressedMemory):

Get memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []

Total length of the memory and compressed memory (for masks)

        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])

Create a subsequent (causal) mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)

Create an all ones (full visibility) mask for memory

        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

Concatenate the masks if there is memory (there is a small mask sketch after this method)
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

Use only the subsequent mask otherwise

        else:
            mask = self.mask_x[:len(x), :len(x)]
Token embeddings

        x = self.src_embed(x)

Run it through the transformer

        res, mem = self.transformer(x, mem, c_mem, mask)

Generate logits of the next token

        res = self.generator(res)

        return res, mem
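Here is a minimal, self-contained sketch of how the combined mask behaves, using plain torch calls in place of the library helper and assuming subsequent_mask produces a lower-triangular mask of shape [seq_len, seq_len, 1]:

import torch

seq_len, m_len = 4, 2

# Causal (subsequent) mask over the current tokens:
# position i may attend to positions <= i; shape [seq_len, seq_len, 1]
mask_x = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

# Memory is fully visible to every token; shape [seq_len, m_len, 1]
mask_mem = mask_x.new_ones(seq_len, m_len, 1)

# Concatenate along the key dimension: each query sees all memory slots
# plus the non-future tokens of the current segment
mask = torch.cat((mask_mem, mask_x), dim=1)
print(mask.squeeze(-1).long())
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1]])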
The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

Token embedding size

    d_model: int = 128

Number of attention heads

    heads: int = 4

Dropout probability

    dropout: float = 0.0

Number of features in FFN hidden layer

    d_ff: int = 256

Number of transformer layers

    n_layers: int = 6

Number of memories to keep

    mem_len: int = 8

State module to maintain memories when switching between training and validation

    memory = SimpleStateModule()

Attention Reconstruction Loss

    attention_reconstruction_loss: AttentionReconstructionLoss

Compression rate

    compression_rate: int = 4

Compressed memory length

    c_mem_len: int = 128

    def init(self):

Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Do not print the attention reconstruction loss in the terminal

        tracker.set_scalar("ar_loss.*", False)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

This will keep the accuracy metric stats and memories separate for training and validation.

        self.state_modules = [self.accuracy, self.memory]
Concatenate new memories and compress the oldest memories.

    @torch.no_grad()
    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:

If the configurations specify not to use memory

        if self.mem_len == 0 and self.c_mem_len == 0:
            return CompressedMemory([], []), []

Get memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []

Concatenate new memories with old memory

        if mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
        else:
            mem = new_mem

Compress the oldest memories if there are more memories than mem_len
        if len(mem[0]) > self.mem_len:

Calculate the number of compressed memories to make, $n_{cm} = \lceil (n - N_{mem}) / c \rceil$, where $n$ is the number of memories we have, $N_{mem}$ is the maximum number of memories we maintain (mem_len), and $c$ is the compression rate. There is a worked sketch of this bookkeeping after this method.

            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate

Number of memories to compress, $n_{cm} \cdot c$

            n_old = n_c_mem * self.compression_rate

A list to keep memories that need to be compressed for each layer.
            mem_to_compress = []

A list to keep the memories that do not get compressed for each layer.

            uncompressed_mem = []

Iterate through memories of each layer.

            for m in mem:

Split the memories at $n_{cm} \cdot c$

                cm, m = torch.split(m, [n_old, len(m) - n_old])

Collect memories to compress

                mem_to_compress.append(cm)

Collect remaining memories

                uncompressed_mem.append(m)

Update the memories

            mem = uncompressed_mem

Compress the memories
            new_c_mem = []
            for i, layer in enumerate(self.model.transformer.layers):
                new_c_mem.append(layer.compress(mem_to_compress[i]))

Concatenate newly compressed memories with old compressed memories

            if c_mem:
                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]

If there are no old compressed memories

            else:
                c_mem = new_c_mem

Truncate old compressed memories if there are more than c_mem_len

            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]

No memories are compressed if the number of memories is less than mem_len
        else:
            mem_to_compress = []

Return the memories and the memories that were compressed. The compressed memories are needed for the attention reconstruction loss computation.

        return CompressedMemory(mem, c_mem), mem_to_compress
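As a worked example of the bookkeeping above, here is a minimal, self-contained sketch for a single layer, using mean-pooling as a stand-in for the learned Conv1dCompression module (13 memories, mem_len = 8, and a compression rate of 4 are illustrative numbers, not the experiment's settings):

import torch

mem_len, compression_rate = 8, 4

# Pretend one layer has accumulated 13 memory steps with d_model = 2
mem = torch.arange(26, dtype=torch.float).view(13, 2)

# Number of compressed memories to create: ceil((13 - 8) / 4) = 2
n_c_mem = (len(mem) - mem_len + compression_rate - 1) // compression_rate
# Number of oldest memories to compress: 2 * 4 = 8
n_old = n_c_mem * compression_rate

# Split off the oldest memories; the rest stay uncompressed
to_compress, mem = mem[:n_old], mem[n_old:]

# Mean-pool groups of compression_rate steps (stand-in for Conv1dCompression)
c_mem = to_compress.view(n_c_mem, compression_rate, -1).mean(dim=1)

print(len(mem), len(c_mem))  # 5 uncompressed memories, 2 compressed memories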
Training/validation step.

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last):

Get memories

            mem = self.memory.get()

Run the model

            output, new_mem = self.model(data, mem)

Merge and compress memory

            mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)

Update memories

            self.memory.set(mem)

Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate attention reconstruction loss if memories were compressed in this step

        if mem_to_compress:

Get attention reconstruction loss

            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)

Track attention reconstruction loss

            tracker.add("ar_loss.", ar_loss)

Add attention reconstruction loss to loss

            loss = loss + ar_loss

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model
        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()
Sampling function to generate samples periodically while training.

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Empty memory

        mem = CompressedMemory([], [])

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)

Move to device

            data = data.to(self.device)

Get the model output

            output, new_mem = self.model(data, mem)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze(1)

Add the prediction to the prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Only feed the last character to the model in the next iteration; the rest will go in as memories

            prompt = prompt[-1:]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Update and compress memory

            mem, _ = self.merge_compress_memory(mem, new_mem)

Print the sampled output

        logger.log(log)
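A quick shape check on the greedy selection above, assuming the model emits logits of shape [seq_len, batch, n_vocab] and the batch dimension is 1 during sampling:

import torch

seq_len, n_vocab = 5, 65
logits = torch.randn(seq_len, 1, n_vocab)  # [seq_len, batch = 1, n_vocab]

pred = logits.argmax(dim=-1)  # token ids, shape [seq_len, 1]
pred = pred.squeeze(1)        # shape [seq_len]
print(pred.shape, int(pred[-1]))  # pred[-1] is the prediction for the next token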
Initialize the autoregressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
    return m.to(c.device)
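As a quick sanity check on shapes: assuming Conv1dCompression is a 1D convolution whose kernel size and stride both equal the compression rate, a memory of length L becomes L / c compressed memories. A sketch with a plain nn.Conv1d standing in for it:

import torch
import torch.nn as nn

compression_rate, d_model = 4, 128

# Strided convolution: kernel size and stride equal the compression rate
conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

mem = torch.randn(16, 2, d_model)  # [seq_len, batch, d_model]
x = mem.permute(1, 2, 0)           # nn.Conv1d expects [batch, channels, seq_len]
c_mem = conv(x).permute(2, 0, 1)   # back to [seq_len, batch, d_model]

print(c_mem.shape)  # torch.Size([4, 2, 128]), i.e. 16 / 4 = 4 compressed steps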
Initialize the attention reconstruction loss

@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    return AttentionReconstructionLoss(c.model.transformer.layers)

Run the experiment

def main():

Create experiment

    experiment.create(name="compressive_transformer", comment='')

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        'compression_rate': 2,
                        })

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run the training loop (TrainValidConfigs.run)

        conf.run()


if __name__ == '__main__':
    main()