This is a PyTorch implementation of the paper Accessing Higher-level Representations in Sequential Transformers with Feedback Memory.
Normal transformers process tokens in parallel, and each layer attends to the outputs of the previous layer. The Feedback Transformer instead attends to the outputs of all layers from previous steps. This adds recurrence, so tokens have to be processed one by one, which slows down training significantly (about 5X to 10X depending on the sequence length). At prediction time, however, the Feedback Transformer is faster than a normal transformer, because the next token can be predicted from cached memory vectors.
In order to speed up the training, the paper discusses starting with a short sequence length and gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point.
The original feedback transformer doesn’t keep the outputs of all layers. Instead, it keeps a weighted sum of the outputs of all layers. This reduces the memory used for caching during prediction. The first half of this file implements this.
The updated feedback transformer shares the weights $W^l_k$ and $W^l_v$, used to calculate keys and values, among the layers. We then calculate the keys and values for each step only once and keep them cached. The second half of this file implements this. We implement a custom PyTorch autograd function to improve performance.
Here’s the training code and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.
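As a rough sketch of how the modules below fit together (this is not the linked training code; the import path labml_nn.transformers.feedback and the sizes are assumptions for illustration):

import torch

from labml_nn.transformers.feed_forward import FeedForward
# Assumes this file is importable as labml_nn.transformers.feedback
from labml_nn.transformers.feedback import (FeedbackAttention, FeedbackTransformer,
                                            FeedbackTransformerLayer)

# Illustrative sizes only
d_model, heads, n_layers = 64, 4, 2

layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model),
                                 feed_forward=FeedForward(d_model, d_model * 4),
                                 dropout_prob=0.1)
model = FeedbackTransformer(layer, n_layers)

# Input has shape [seq_len, batch_size, d_model]
x = torch.randn(10, 2, d_model)
out = model(x)  # [10, 2, d_model]

The model processes the input step by step internally, so the call looks the same as for a parallel transformer.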
import math
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list

This module computes recurrent attention, similar to the attention in the original transformer paper.
class FeedbackAttention(Module):
heads is the number of attention heads
d_model is the number of features in the transformer
dropout_prob is the attention dropout probability
is_kv_precomputed is whether key and value tensors are already calculated
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
                 is_kv_precomputed: bool = False):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
Number of heads
        self.heads = heads
These transform the query for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
These transform the key and value for multi-headed attention.
        if not is_kv_precomputed:
            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
Keys and values are already calculated
        else:
            self.key = None
            self.value = None
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
Softmax for attention along the time dimension of key
        self.softmax = nn.Softmax(dim=0)
Number of relative positions
        self.P = 2 ** 12
Relative positional embeddings for key relative to the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
Relative positional embedding bias for key relative to the query.
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
Positional embeddings for the query are independent of the position of the query.
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
We store attentions so that they can be used for logging or other computations if needed.
        self.attn = None

We use relative positional encodings for attention, similar to the relative multi-head attention from the Transformer-XL paper.
Attention from the current step’s query to the key in step $j$ (relative to the current step) is,

$$A_j = \underset{\color{lightgreen}{A}}{Q^\top K_j} + \underset{\color{lightgreen}{B}}{Q^\top U^K_j} + \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j} + \underset{\color{lightgreen}{D}}{{U^Q}^\top U^K_j}$$

where $Q$ and $K_j$ are linear transformations of the original embeddings $X^q, X^k_j$ and $U^Q, U^K_j$ are linear transformations of the positional encodings $P_q, P_j$.
We replace term $\color{lightgreen}{D}$ with $S_j$ (key_pos_bias), a learned bias for each relative position.
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
$U^K_j$
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
$U^Q$
        query_pos_bias = self.query_pos_bias[None, :, :]
$S_j$
        key_pos_bias = self.key_pos_bias[-key.shape[0]:]
$\underset{\color{lightgreen}{A}}{Q^\top K_j} + \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j}$
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
$\underset{\color{lightgreen}{B}}{Q^\top U^K_j} + \underset{\color{lightgreen}{D}}{S_j}$
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
$A_j$
        return ac + bd

query has shape [batch_size, d_model]
key and value have shape [seq_len, batch_size, d_model]
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor):
Prepare query, key and value for attention computation.
key and value will then have shape [seq_len, batch_size, heads, d_k]
and query will have shape [batch_size, heads, d_k]
        query = self.query(query)
        if self.key:
            key = self.key(key)
        if self.value:
            value = self.value(value)
Compute attention scores.
Results in a tensor of shape [seq_len, batch_size, heads]
        scores = self.get_scores(query, key)
Scale scores by $\frac{1}{\sqrt{d_k}}$
        scores *= self.scale
Softmax
        attn = self.softmax(scores)
Apply dropout
        attn = self.dropout(attn)
Multiply by the values
        x = torch.einsum("jbh,jbhd->bhd", attn, value)
Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)
Output layer
        return self.output(x)
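As a quick shape check of FeedbackAttention on its own (a minimal sketch; sizes are arbitrary and it assumes the class defined above is in scope):

import torch

seq_len, batch_size, heads, d_model = 8, 2, 4, 64

attn = FeedbackAttention(heads, d_model)

q = torch.randn(batch_size, d_model)             # query for the current step
mem = torch.randn(seq_len, batch_size, d_model)  # memory from previous steps

out = attn(query=q, key=mem, value=mem)
assert out.shape == (batch_size, d_model)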
This implements a single transformer layer in the feedback transformer.

class FeedbackTransformerLayer(Module):
d_model is the number of features in the transformer
attn is the feedback attention module
feed_forward is the position-wise feed forward layer
dropout_prob is the dropout probability for dropout layers after attention and feed-forward
    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
Transformer size $d_{model}$
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *,
                x: torch.Tensor,
                key: Optional[torch.Tensor],
                value: Optional[torch.Tensor]):
If there is memory
        if key is not None:
Normalize the vectors before doing self attention
            z = self.norm_self_attn(x)
Run through the attention layer with the memory as keys and values
            self_attn = self.attn(query=z, key=key, value=value)
Add the self attention results
            x = x + self.dropout(self_attn)
Normalize for feed-forward
        z = self.norm_ff(x)
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)

        return x

class FeedbackTransformer(Module):
layer is the feedback transformer layer, which we clone for each layer
n_layers is the number of layers in the transformer
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
Memory vectors are computed as a weighted sum of representations of each layer. This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
Softmax for weights before taking the weighted sum
        self.softmax = nn.Softmax(0)

x_seq is the input with shape [seq_len, batch_size, d_model]
    def forward(self, x_seq: torch.Tensor):
Split the input to a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
List to store the outputs
        res = []
List to store the memory vectors
        mem = []
For each input step
        for x in x_seq:
List to store layer outputs
            layer_outputs = [x]
If there is memory, stack the memory vectors into a tensor
            mem_tensor = torch.stack(mem) if mem else None
Run through each layer
            for layer in self.layers:
Get layer output
                x = layer(x=x, key=mem_tensor, value=mem_tensor)
Append them to the list of layer outputs
                layer_outputs.append(x)
Stack the layer outputs to a tensor
            layer_outputs = torch.stack(layer_outputs)
Calculate the memory vector as a weighted sum of layer outputs
            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
Append the output to results
            res.append(x)
Stack the output tensors
        res = torch.stack(res)
Normalize the output
        return self.norm(res)
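To illustrate the caching claim from the introduction: at prediction time the memory vectors can be cached, so generating each new token needs only a single pass through the layers (with the model in eval mode). A hypothetical helper along these lines (predict_step is not part of this file) could reuse the model's layers, weights and norm:

import torch

@torch.no_grad()
def predict_step(model: FeedbackTransformer, x: torch.Tensor, mem: list):
    # `x` is the embedding of the new token, shape [batch_size, d_model];
    # `mem` is a Python list of cached memory vectors, one per previous step.
    mem_tensor = torch.stack(mem) if mem else None
    layer_outputs = [x]
    for layer in model.layers:
        x = layer(x=x, key=mem_tensor, value=mem_tensor)
        layer_outputs.append(x)
    # Cache the memory vector for this step (weighted sum of layer outputs)
    mem.append(torch.einsum('lbd,l->bd',
                            torch.stack(layer_outputs),
                            model.softmax(model.weights)))
    return model.norm(x)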
We implement a custom autograd function instead of appending to a Python list and then calling torch.stack. This greatly improves performance over calling torch.stack at each step along the sequence: every time torch.stack is called it creates a new tensor, while this method and the accompanying class Stack share memory for each step.
class StackFunction(torch.autograd.Function):
ctx is the context of the function (which lets us cache stuff)
memory is the shared memory tensor where we stack and store the values of each step (keys & values)
memory_grad is the shared memory tensor to store and accumulate gradients of each step
last is the last value stacked
n is the number of steps (i.e. the size of the stack)
This returns the stacked tensor for steps up to n.
    @staticmethod
    def forward(ctx, memory, memory_grad, last, n):
Cache accumulated gradients
        ctx._mem_grad = memory_grad
Cache the size of the stack
        ctx._n = n
Return the stack
        return memory[:n + 1]

grad_output is the gradient with respect to the output of the forward function above.
This accumulates the gradients in the shared memory tensor and returns the gradients with respect to the last result in the stack.
    @staticmethod
    def backward(ctx, grad_output):
Get the current size of the stack
        n = ctx._n
Get the accumulated gradients
        memory_grad = ctx._mem_grad
Add the gradients
        memory_grad[:n + 1] += grad_output
Return the gradients w.r.t. the last value in the stack
        return None, None, memory_grad[n], None

class Stack:
max_len is the maximum size of the stack
    def __init__(self, max_len: int):
        self.max_len = max_len
        self.memory = None
        self.memory_grad = None
        self.last = None
        self.n = -1
        self.last_get_n = -1

n is the position in the stack where the value is added
value is the tensor that needs to be added to the stack
    def append(self, n: int, value: torch.Tensor):
You need to get (use) the stack after adding a value; otherwise this implementation fails.
        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
Do this without gradients
        with torch.no_grad():
Initialize the shared memory tensor to keep the stack
            if self.memory is None or self.memory.shape[1:] != value.shape:
This should only happen when the stack is empty
                assert n == 0
Create a tensor for the stack
                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
Create a tensor to accumulate the gradients
                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
The memory is already initialized but we are resetting the stack. This could have been another function like reset, but we found this easier to use.
            elif n == 0:
Reset accumulated gradients
                self.memory_grad.fill_(0.)
Set the value in the correct position of the stack
            self.memory.data[n] = value.detach()
Keep track of the stack (for debugging)
            self.n = n
Keep track of the last value added to the stack. We need this to be passed on to StackFunction in order to get the gradients propagated backwards.
        self.last = value

Returns the stack
    def get(self):
Keep track of the size of the stack when it was used. This is used for a sanity check in append.
        self.last_get_n = self.n
Take it all through StackFunction so that StackFunction.backward is called by PyTorch during backpropagation.
        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
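A small usage sketch of Stack on its own (sizes are arbitrary). As the assertion in append enforces, get must be called after each append before the next value is pushed:

import torch

stack = Stack(max_len=16)

x0 = torch.randn(2, 4, requires_grad=True)
stack.append(0, x0)
s0 = stack.get()   # shape [1, 2, 4]

x1 = torch.randn(2, 4, requires_grad=True)
stack.append(1, x1)
s1 = stack.get()   # shape [2, 2, 4]

# Gradients reach the last appended value through StackFunction.backward
s1.sum().backward()
assert x1.grad.shape == (2, 4)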
This is the updated feedback transformer module that caches the keys and values.

class FeedbackTransformerKV(Module):
layer is the feedback transformer layer, which we clone for each layer
n_layers is the number of layers in the transformer
d_model is the number of features in the transformer
heads is the number of attention heads
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
Memory vectors are computed as a weighted sum of representations of each layer. This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
Softmax for weights before taking the weighted sum
        self.softmax = nn.Softmax(0)
Number of features in a head
        d_k = d_model // heads
Module to transform embeddings (memory) to get keys
        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
Module to transform embeddings (memory) to get values
        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
Memory for stacked keys
        self.mem_key = Stack(512)
Memory for stacked values
        self.mem_value = Stack(512)

x_seq is the input with shape [seq_len, batch_size, d_model]
    def forward(self, x_seq: torch.Tensor):
Split the input to a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
List to store the outputs
        res = []
For each input step
        for step, x in enumerate(x_seq):
List to store layer outputs
            layer_outputs = [x]
Stack of keys and values
            key_tensor = None
            value_tensor = None
Get the keys and values tensors if we are beyond the initial step
            if step > 0:
                key_tensor = self.mem_key.get()
                value_tensor = self.mem_value.get()
Run through each layer
            for layer in self.layers:
Get layer output
                x = layer(x=x, key=key_tensor, value=value_tensor)
Append them to the list of layer outputs
                layer_outputs.append(x)
Stack the layer outputs to a tensor
            layer_outputs = torch.stack(layer_outputs)
Calculate the memory vector as a weighted sum of layer outputs
            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
Calculate the keys from memory and add them to the stack
            self.mem_key.append(step, self.key(mem))
Calculate the values from memory and add them to the stack
            self.mem_value.append(step, self.value(mem))
Append the output to results
            res.append(x)
Stack the output tensors
        res = torch.stack(res)
Normalize the output
        return self.norm(res)
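Finally, a usage sketch of this cached keys-and-values variant (sizes are illustrative; it assumes the classes in this file and the FeedForward import above are in scope). The attention module is constructed with is_kv_precomputed=True because FeedbackTransformerKV computes keys and values from the memory itself:

import torch

d_model, heads, n_layers = 64, 4, 2

layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, is_kv_precomputed=True),
                                 feed_forward=FeedForward(d_model, d_model * 4),
                                 dropout_prob=0.1)
model = FeedbackTransformerKV(layer, n_layers, d_model, heads)

x = torch.randn(10, 2, d_model)  # [seq_len, batch_size, d_model]
out = model(x)                   # [10, 2, d_model]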