This is a PyTorch implementation of the paper Accessing Higher-level Representations in Sequential Transformers with Feedback Memory.
Normal transformers process tokens in parallel, and each transformer layer pays attention to the outputs of the previous layer. The Feedback Transformer instead pays attention to the outputs of all layers at previous steps. This adds recurrence, so we need to process tokens one at a time, which slows down training significantly (about 5x to 10x depending on the sequence length). However, the Feedback Transformer is faster when predicting, because you can predict the next token from cached memory vectors.
In order to speed up the training, the paper discusses starting with a short sequence length and gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point.
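As a rough illustration of that idea, here is a minimal sketch of a sequence-length warm-up schedule. The linear ramp and the specific numbers are illustrative, not taken from the paper.

```python
# A minimal sketch of a sequence-length warm-up schedule.
# The ramp shape and constants are illustrative, not the paper's exact recipe.
def seq_len_at(step: int, start_len: int = 32, final_len: int = 512, warmup_steps: int = 10_000) -> int:
    """Linearly grow the training sequence length from `start_len` to `final_len`."""
    if step >= warmup_steps:
        return final_len
    return start_len + (final_len - start_len) * step // warmup_steps
```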
The original Feedback Transformer doesn't keep the outputs of all layers. Instead, it keeps a weighted sum of the outputs of all layers, which reduces the memory needed for caching during prediction. The first half of this file implements this.
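Here is how such a memory vector is formed for one step, in isolation. This is a sketch that mirrors the einsum used later in this file; the tensor sizes are arbitrary.

```python
# One memory vector per step: a softmax-weighted sum over the layer outputs.
# `layer_outputs` holds the embedding plus each layer's output for one step.
import torch

n_layers, batch, d_model = 3, 2, 64
layer_outputs = torch.rand(n_layers + 1, batch, d_model)
weights = torch.ones(n_layers + 1, requires_grad=True)   # learned mixing weights
mem = torch.einsum('lbd,l->bd', layer_outputs, weights.softmax(dim=0))
assert mem.shape == (batch, d_model)                      # only this vector is cached per step
```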
The updated Feedback Transformer shares the weights used to calculate keys and values among the layers. We then calculate the keys and values for each step only once and keep them cached. The second half of this file implements this. We implemented a custom PyTorch function to improve performance.
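As a toy illustration of the caching idea, using plain `nn.Linear` stand-ins rather than the actual projection modules defined later in this file:

```python
# Toy illustration only: `shared_key` / `shared_value` stand in for the real
# projection modules; the point is that each step's memory is projected once.
import torch
from torch import nn

d_model = 64
shared_key = nn.Linear(d_model, d_model)
shared_value = nn.Linear(d_model, d_model)

mem = torch.rand(2, d_model)        # memory vector for one step
cached_key = shared_key(mem)        # computed once per step...
cached_value = shared_value(mem)    # ...and reused by every layer at later steps
```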
Here's the training code and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.
```python
import math
from typing import Optional

import torch
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
```

This module computes recurrent attention similar to the attention from the original transformer paper.
* `heads` is the number of attention heads
* `d_model` is the number of features in the transformer
* `dropout_prob` is the attention dropout probability
* `is_kv_precomputed` is whether key and value tensors are already calculated

```python
class FeedbackAttention(nn.Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
                 is_kv_precomputed: bool = False):
        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        self.heads = heads

        # These transform the `query` for multi-headed attention
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        # These transform the `key` and `value` for multi-headed attention
        if not is_kv_precomputed:
            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
        # Keys and values are already calculated
        else:
            self.key = None
            self.value = None

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # Softmax for attention along the time dimension of `key`
        self.softmax = nn.Softmax(dim=0)

        # Number of relative positions
        self.P = 2 ** 12

        # Relative positional embeddings for the key relative to the query
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
        # Relative positional embedding bias for the key relative to the query
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
        # Positional embeddings for the query are independent of the position of the query
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

        # We store attentions so that they can be used for logging, or other computations if needed
        self.attn = None
```

We use relative positional encodings for attention, similar to the relative multi-head attention from the Transformer-XL paper.
Attention from the current step's query to the key at step $j$ (relative to the current step) is

$$A_j = Q^\top K_j + Q^\top U^K_j + {U^Q}^\top K_j + {U^Q}^\top U^K_j,$$

where $Q$ and $K_j$ are linear transformations of the original embeddings, and $U^Q$ and $U^K_j$ are linear transformations of the positional encodings.

We replace the term ${U^Q}^\top U^K_j$ with $S_j$, a learned bias that depends only on the relative position.
```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # Relative positional embeddings $U^K_j$ for the keys in memory
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
        # Query positional bias $U^Q$
        query_pos_bias = self.query_pos_bias[None, :, :]
        # Key positional bias $S_j$
        key_pos_bias = self.key_pos_bias[-key.shape[0]:]

        # $(Q + U^Q)^\top K_j$
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
        # $Q^\top U^K_j + S_j$
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

        # Sum of the terms
        return ac + bd
```

`query` has shape `[batch_size, d_model]`. `key` and `value` have shape `[seq_len, batch_size, d_model]`.
```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor):
        # Prepare `query`, `key` and `value` for attention computation.
        # `key` and `value` will then have shape `[seq_len, batch_size, heads, d_k]`
        # and `query` will have shape `[batch_size, heads, d_k]`.
        query = self.query(query)
        if self.key:
            key = self.key(key)
        if self.value:
            value = self.value(value)

        # Compute attention scores.
        # This gives a tensor of shape `[seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)
        # Scale scores
        scores *= self.scale
        # Softmax along the sequence (memory) dimension
        attn = self.softmax(scores)
        # Apply dropout
        attn = self.dropout(attn)
        # Multiply by the values
        x = torch.einsum("jbh,jbhd->bhd", attn, value)
        # Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)
        # Output layer
        return self.output(x)
```
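As a quick sanity check of the shapes involved, here is a minimal sketch that runs the attention module on random tensors. It assumes the classes defined in this file are in scope (for example, imported from `labml_nn.transformers.feedback`).

```python
# A minimal shape check (a sketch; assumes `FeedbackAttention` from this file is in scope).
import torch

attn = FeedbackAttention(heads=4, d_model=64)
query = torch.rand(2, 64)        # current step, `[batch_size, d_model]`
memory = torch.rand(5, 2, 64)    # memory of 5 previous steps, `[seq_len, batch_size, d_model]`
out = attn(query=query, key=memory, value=memory)
assert out.shape == (2, 64)      # one output vector per batch element
```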
This implements a single transformer layer in the feedback transformer.

* `d_model` is the number of features in the transformer
* `attn` is the feedback attention module
* `feed_forward` is the position-wise feed-forward layer
* `dropout_prob` is the dropout probability for the dropout layers after attention and feed-forward

```python
class FeedbackTransformerLayer(nn.Module):
    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        # Transformer size
        self.size = d_model

        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

        # Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *,
                x: torch.Tensor,
                key: Optional[torch.Tensor],
                value: Optional[torch.Tensor]):
        # If there is memory
        if key is not None:
            # Normalize the vectors before doing self attention
            z = self.norm_self_attn(x)
            # Run through self attention, i.e. keys and values are from self
            self_attn = self.attn(query=z, key=key, value=value)
            # Add the self attention results
            x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x
```
The `FeedbackTransformer` module stacks these layers and maintains the memory of per-step vectors. `layer` is the feedback transformer layer, which we clone for each layer, and `n_layers` is the number of layers in the transformer.

```python
class FeedbackTransformer(nn.Module):
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        # Memory vectors are computed as a weighted sum of the representations of each layer.
        # This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
        # Softmax for the weights before taking the weighted sum
        self.softmax = nn.Softmax(0)

    def forward(self, x_seq: torch.Tensor):
        # `x_seq` is the input with shape `[seq_len, batch_size, d_model]`

        # Split the input into a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
        # List to store the outputs
        res = []
        # List to store the memory vectors
        mem = []
        # For each input step
        for x in x_seq:
            # List to store layer outputs
            layer_outputs = [x]

            # If there is memory, stack it into a tensor
            mem_tensor = torch.stack(mem) if mem else None

            # Run through each layer
            for layer in self.layers:
                # Get layer output
                x = layer(x=x, key=mem_tensor, value=mem_tensor)
                # Append it to the list of layer outputs
                layer_outputs.append(x)

            # Stack the layer outputs into a tensor
            layer_outputs = torch.stack(layer_outputs)
            # Calculate the memory vector as a weighted sum of layer outputs
            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
            # Append the output to the results
            res.append(x)

        # Stack the output tensors
        res = torch.stack(res)
        # Normalize the output
        return self.norm(res)
```
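To see this module end to end, here is a small usage sketch. It assumes the classes in this file are in scope and that `FeedForward(d_model, d_ff=...)` matches the constructor in `labml_nn.transformers.feed_forward`.

```python
# End-to-end sketch of the first (non-cached) version; constructor arguments for
# `FeedForward` are an assumption about the labml_nn API.
import torch

d_model, heads, n_layers = 64, 4, 3
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model),
                                 feed_forward=FeedForward(d_model, d_ff=256),
                                 dropout_prob=0.1)
model = FeedbackTransformer(layer, n_layers)

x_seq = torch.rand(10, 2, d_model)   # `[seq_len, batch_size, d_model]`
out = model(x_seq)
assert out.shape == (10, 2, d_model)
```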
We implement a custom function instead of appending to a Python list and then calling `torch.stack`. This greatly improves performance over calling `torch.stack` at each step along the sequence. Every time `torch.stack` is called, it creates a new tensor, while this method and the accompanying class `Stack` share memory between steps.
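Stripped of the autograd machinery, the underlying trick looks like this: write each step into a pre-allocated buffer and hand out slices of it, instead of re-stacking a growing Python list.

```python
# The core idea without autograd (a sketch): one pre-allocated buffer, written
# in place and sliced per step. `torch.stack` would allocate a new tensor each step.
import torch

max_len, batch, d = 8, 2, 4
buffer = torch.zeros(max_len, batch, d)
for step in range(3):
    buffer[step] = torch.rand(batch, d)   # write the new step in place
    stacked = buffer[:step + 1]           # a view; shares memory with `buffer`
```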
`StackFunction` is the custom autograd function. The arguments to its `forward` method are:

* `ctx` is the context of the function (which lets us cache things)
* `memory` is the shared memory tensor where we stack and store the values of each step (keys & values)
* `memory_grad` is the shared memory tensor where we store and accumulate the gradients of each step
* `last` is the last value stacked
* `n` is the number of steps (i.e. the size of the stack)

It returns the stacked tensor for steps up to `n`.

```python
class StackFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, memory, memory_grad, last, n):
        # Cache accumulated gradients
        ctx._mem_grad = memory_grad
        # Cache the size of the stack
        ctx._n = n
        # Return the stack
        return memory[:n + 1]
```

`grad_output` is the gradient with respect to the output of the `forward` function. The `backward` method accumulates the gradients in the shared memory tensor and returns the gradient with respect to the `last` result in the stack.

```python
    @staticmethod
    def backward(ctx, grad_output):
        # Get the current size of the stack
        n = ctx._n
        # Get the accumulated gradients
        memory_grad = ctx._mem_grad
        # Add the gradients
        memory_grad[:n + 1] += grad_output
        # Return the gradients w.r.t. the last value in the stack
        return None, None, memory_grad[n], None
```
The `Stack` class manages the shared memory and gradient tensors. `max_len` is the maximum size of the stack.

```python
class Stack:
    def __init__(self, max_len: int):
        self.max_len = max_len
        self.memory = None
        self.memory_grad = None
        self.last = None
        self.n = -1
        self.last_get_n = -1

    def append(self, n: int, value: torch.Tensor):
        # `n` is the size of the stack and `value` is the tensor to be added to it.
        # You need to get (use) the stack after adding a value;
        # otherwise this implementation fails.
        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"

        # Do this without gradients
        with torch.no_grad():
            # Initialize the shared memory tensor that keeps the stack
            if self.memory is None or self.memory.shape[1:] != value.shape:
                # This should only happen when the stack is empty
                assert n == 0
                # Create a tensor for the stack
                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
                # Create a tensor to accumulate the gradients
                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
            # The memory is already initialized, but we are resetting the stack.
            # This could have been a separate function like `reset`, but we found this easier to use.
            elif n == 0:
                # Reset accumulated gradients
                self.memory_grad.fill_(0.)

            # Set the value in the correct position of the stack
            self.memory.data[n] = value.detach()
            # Keep track of the size of the stack (for debugging)
            self.n = n

        # Keep track of the last value added to the stack.
        # We need this to be passed on to `StackFunction`
        # in order to get the gradients propagated backwards.
        self.last = value

    def get(self):
        # Returns the stack.
        # Keep track of the size of the stack when it was used;
        # this is used for a sanity check in `append`.
        self.last_get_n = self.n
        # Take it through `StackFunction` so that `StackFunction.backward`
        # is called by PyTorch during backpropagation.
        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)

    def free(self):
        # Release the memory
        self.memory = None
        self.memory_grad = None
        self.last = None
```
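Here is a small sketch of how `Stack` is meant to be used: append a value at each step, call `get()` after every append, and let a single backward pass route gradients back through the shared buffers. The expected gradient values below assume a simple sum loss.

```python
# Usage sketch for `Stack` (assumes the class above is in scope).
import torch

stack = Stack(max_len=8)
xs = [torch.rand(2, 4, requires_grad=True) for _ in range(3)]
outputs = []
for i, x in enumerate(xs):
    stack.append(i, x)
    outputs.append(stack.get())   # `get()` must be called after every `append()`

# A single backward pass accumulates gradients in the shared buffer and routes
# them back to the step where each value was appended.
loss = sum(o.sum() for o in outputs)
loss.backward()

# xs[0] appears in all three stacked views, xs[2] only in the last one.
assert torch.allclose(xs[0].grad, torch.full((2, 4), 3.))
assert torch.allclose(xs[2].grad, torch.full((2, 4), 1.))
stack.free()
```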
This is the updated feedback transformer module that caches the keys and values.

* `layer` is the feedback transformer layer, which we clone for each layer
* `n_layers` is the number of layers in the transformer
* `d_model` is the number of features in the transformer
* `heads` is the number of attention heads

```python
class FeedbackTransformerKV(nn.Module):
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        # Memory vectors are computed as a weighted sum of the representations of each layer.
        # This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
        # Softmax for the weights before taking the weighted sum
        self.softmax = nn.Softmax(0)

        # Number of features in a head
        d_k = d_model // heads
        # Module to transform embeddings (memory) to get keys
        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
        # Module to transform embeddings (memory) to get values
        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

        # Memory for stacked keys
        self.mem_key = Stack(512)
        # Memory for stacked values
        self.mem_value = Stack(512)

    def forward(self, x_seq: torch.Tensor):
        # `x_seq` is the input with shape `[seq_len, batch_size, d_model]`

        # Split the input into a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
        # List to store the outputs
        res = []
        # For each input step
        for step, x in enumerate(x_seq):
            # List to store layer outputs
            layer_outputs = [x]

            # Stack of keys and values
            key_tensor = None
            value_tensor = None
            # Get the key and value tensors if we are beyond the initial step
            if step > 0:
                key_tensor = self.mem_key.get()
                value_tensor = self.mem_value.get()

            # Run through each layer
            for layer in self.layers:
                # Get layer output
                x = layer(x=x, key=key_tensor, value=value_tensor)
                # Append it to the list of layer outputs
                layer_outputs.append(x)

            # Stack the layer outputs into a tensor
            layer_outputs = torch.stack(layer_outputs)
            # Calculate the memory vector as a weighted sum of layer outputs
            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
            # Calculate the keys from the memory and add them to the stack
            self.mem_key.append(step, self.key(mem))
            # Calculate the values from the memory and add them to the stack
            self.mem_value.append(step, self.value(mem))
            # Append the output to the results
            res.append(x)

        # Stack the output tensors
        res = torch.stack(res)
        # Normalize the output
        return self.norm(res)

    def free(self):
        self.mem_key.free()
        self.mem_value.free()
```
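And a matching sketch for the cached version. Here the attention module is constructed with `is_kv_precomputed=True`, since `FeedbackTransformerKV` projects the keys and values itself (same assumptions about `FeedForward` as the earlier sketch).

```python
# Usage sketch for the key/value-cached version.
import torch

d_model, heads, n_layers = 64, 4, 3
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, is_kv_precomputed=True),
                                 feed_forward=FeedForward(d_model, d_ff=256),
                                 dropout_prob=0.1)
model = FeedbackTransformerKV(layer, n_layers, d_model, heads)

x_seq = torch.rand(10, 2, d_model)   # `[seq_len, batch_size, d_model]`
out = model(x_seq)
assert out.shape == (10, 2, d_model)
model.free()                          # release the shared key/value memory
```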