Feedback Transformer

This is an implementation of the paper Accessing Higher-level Representations in Sequential Transformers with Feedback Memory.

Normal transformers process tokens in parallel, and each transformer layer attends to the outputs of the previous layer. The Feedback Transformer instead attends to the outputs of all layers at previous steps. This adds recurrence, so tokens have to be processed one by one. That slows down training significantly (about 5x to 10x, depending on the sequence length). However, the Feedback Transformer is faster at prediction time, because the next token can be predicted directly once the memory vectors of earlier steps are cached.

In order to speed up the training the paper discusses starting with a short sequence length and gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point.
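The paragraph does not give the schedule itself. As an illustration only, here is a hypothetical linear curriculum. The function name `seq_len_schedule`, its linear shape and all parameter values are assumptions, not the paper's settings.

```python
def seq_len_schedule(step: int, total_steps: int, min_len: int, max_len: int) -> int:
    """Hypothetical linear curriculum: grow the training sequence length
    from min_len to max_len over total_steps optimizer steps."""
    if total_steps <= 0:
        return max_len
    # Fraction of the curriculum completed, clamped to [0, 1]
    frac = min(max(step / total_steps, 0.0), 1.0)
    return int(round(min_len + frac * (max_len - min_len)))


# Sequence lengths at a few points of a 1,000-step curriculum
for step in (0, 250, 500, 1000, 2000):
    print(step, seq_len_schedule(step, 1000, 32, 256))
```

Short sequences early mean fewer sequential steps per batch, which is where the recurrent model spends most of its training time.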

The Feedback Transformer doesn't keep the outputs of all layers. Instead, it keeps a weighted sum of the outputs of all layers. This reduces the memory used for caching during prediction: only one vector per step has to be stored, instead of one per layer.
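A minimal sketch of that weighted sum for a single step, with arbitrary sizes. The outputs of the input layer and every transformer layer are mixed by a softmax over learnable weights, so each step contributes one `[batch_size, d_model]` memory vector no matter how many layers there are. The variable names are illustrative. The weights start at ones, like `self.weights` in `FeedbackTransformer`, which makes the initial mix a plain average.

```python
import torch

n_layers, batch_size, d_model = 4, 2, 8
# Outputs of every layer for one step, including the input (index 0):
# shape [n_layers + 1, batch_size, d_model]
layer_outputs = torch.randn(n_layers + 1, batch_size, d_model)
# Learnable mixing weights, one per entry of layer_outputs
weights = torch.ones(n_layers + 1)
# Memory vector for this step: softmax-weighted sum over the layer axis
mem_t = torch.einsum('lbd,l->bd', layer_outputs, torch.softmax(weights, dim=0))
print(mem_t.shape)  # torch.Size([2, 8])
```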

Here's a notebook for training a Feedback Transformer on the Tiny Shakespeare dataset.


import math
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.utils import clone_module_list

Feedback Attention

This module computes recurrent attention, similar to the attention in the original transformer paper.

class FeedbackAttention(Module):
  • heads is the number of attention heads
  • d_model is the number of features in the transformer
  • dropout_prob is the attention dropout probability
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__()

Number of features per head

        self.d_k = d_model // heads
        self.heads = heads

These transform the query, key and value vectors for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

Scaling factor before the softmax

        self.scale = 1 / math.sqrt(self.d_k)

Softmax for attention along the time dimension of key

        self.softmax = nn.Softmax(dim=0)
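The scores are laid out as `[seq_len, batch_size, heads]`, with the key position first, so normalizing with `dim=0` gives one distribution over memory positions per batch element and head. A small check with arbitrary example sizes, including the $\frac{1}{\sqrt{d_k}}$ scaling applied before the softmax:

```python
import math

import torch
from torch import nn

seq_len, batch_size, heads, d_k = 5, 2, 3, 16
scores = torch.randn(seq_len, batch_size, heads)
# Scale by 1 / sqrt(d_k), then normalize over dim 0 (the key/time axis)
attn = nn.Softmax(dim=0)(scores * (1 / math.sqrt(d_k)))
# Each (batch, head) pair now holds a distribution over the seq_len keys
print(attn.sum(dim=0))
```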

Number of relative positions

        self.P = 2 ** 12

Relative positional embeddings for key relative to the query.

        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)

The positional embedding for the query is independent of the query's position

        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

We store attentions so that they can be used for logging or other computations if needed

        self.attn = None

Get attention scores

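$$A_j = \mathrm{lin}_q(X^q + P_q)^\top \mathrm{lin}_k(X^k_j + P_j) = (Q + U^Q)^\top (K_j + U^K_j)$$

where $Q, K_j$ are linear transformations of the original embeddings $X^q, X^k_j$, and $U^Q, U^K_j$ are linear transformations of the absolute positional encodings $P_q, P_j$. In the code, $U^K_j$ is taken from learned embeddings indexed by the key's position relative to the query, and $U^Q$ is a single learned bias.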
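Splitting a projected (embedding + positional encoding) into a content term plus a position term only works if the projection is linear with no bias; the code passes `bias=False` for the query and key projections. A quick PyTorch check of that property, using an arbitrary `nn.Linear` as a stand-in for the key projection (the variable names are illustrative):

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 8
lin_k = nn.Linear(d_model, d_model, bias=False)
x_k, p_j = torch.randn(d_model), torch.randn(d_model)

# Without a bias, projecting (embedding + position) equals the sum of the
# separate projections, i.e. lin_k(X^k_j + P_j) = K_j + U^K_j
lhs = lin_k(x_k + p_j)
rhs = lin_k(x_k) + lin_k(p_j)
print(torch.allclose(lhs, rhs, atol=1e-6))  # True

# With a bias, the right-hand side would count the bias twice
lin_b = nn.Linear(d_model, d_model, bias=True)
print(torch.allclose(lin_b(x_k + p_j), lin_b(x_k) + lin_b(p_j)))  # False
```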

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

$U^K_j$

        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]

$U^Q$

        query_pos_bias = self.query_pos_bias[None, :, :]

$(Q + U^Q)^\top(K_j + U^K_j)$

        return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
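Two details of `get_scores` can be checked in isolation. First, slicing `key_pos_embeddings[-key.shape[0]:]` keeps the last `seq_len` embeddings, so the most recent memory step always pairs with the same embedding row, however long the memory is. Second, the einsum is a batched dot product over `d_k`. The sizes are arbitrary, with `P` reduced from `2 ** 12` to 16 for illustration.

```python
import torch

P, heads, d_k = 16, 2, 4
seq_len, batch_size = 3, 5
key_pos_embeddings = torch.randn(P, heads, d_k)
query_pos_bias = torch.randn(heads, d_k)

query = torch.randn(batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)

# Last seq_len rows: the most recent key (index seq_len - 1) always gets
# key_pos_embeddings[P - 1], whatever the current memory length is
key_pos_emb = key_pos_embeddings[-key.shape[0]:]

scores = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias[None],
                      key + key_pos_emb[:, None])

# Same computation with broadcasting: dot product over the d_k axis
q = (query + query_pos_bias)[None]        # [1, batch_size, heads, d_k]
k = key + key_pos_emb[:, None, :, :]      # [seq_len, batch_size, heads, d_k]
reference = (q * k).sum(dim=-1)           # [seq_len, batch_size, heads]
print(scores.shape)  # torch.Size([3, 5, 2])
```

Because the slice returns at most `P` rows, a memory longer than `P` would make the addition fail with a shape mismatch, so `P` caps the usable memory length.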
  • query has shape [batch_size, d_model]
  • key and value have shape [seq_len, batch_size, d_model]
    def __call__(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor):

Prepare query, key and value for attention computation. key and value will then have shape [seq_len, batch_size, heads, d_k], and query will have shape [batch_size, heads, d_k].

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
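`PrepareForMultiHeadAttention` is imported from `labml_nn`, and its internals are not shown in this section; the comments only state the resulting shapes. Assuming it behaves like a linear projection followed by splitting the features into heads (an assumption), a hypothetical stand-in `ToyPrepareForMHA` reproduces those shapes for both the single-step query and the sequence of keys and values:

```python
import torch
from torch import nn


class ToyPrepareForMHA(nn.Module):
    """Hypothetical stand-in: project to heads * d_k features, then split the
    last dimension into [heads, d_k], keeping any leading dimensions."""

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head_shape = x.shape[:-1]
        x = self.linear(x)
        return x.view(*head_shape, self.heads, self.d_k)


prep = ToyPrepareForMHA(d_model=16, heads=4, d_k=4, bias=False)
print(prep(torch.randn(2, 16)).shape)     # query: torch.Size([2, 4, 4])
print(prep(torch.randn(7, 2, 16)).shape)  # key/value: torch.Size([7, 2, 4, 4])
```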

Compute attention scores. This results in a tensor of shape [seq_len, batch_size, heads]

        scores = self.get_scores(query, key)

Scale scores by $\frac{1}{\sqrt{d_k}}$

        scores *= self.scale

Softmax

        attn = self.softmax(scores)

Apply dropout

        attn = self.dropout(attn)

Multiply by the values

        x = torch.einsum("jbh,jbhd->bhd", attn, value)

Concatenate multiple heads

        x = x.reshape(x.shape[0], -1)
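The value einsum is an attention-weighted sum over the memory axis `j`, and the reshape lays the heads side by side in one vector of `heads * d_k` features. With arbitrary sizes:

```python
import torch

seq_len, batch_size, heads, d_k = 6, 2, 3, 4
attn = torch.softmax(torch.randn(seq_len, batch_size, heads), dim=0)
value = torch.randn(seq_len, batch_size, heads, d_k)

# Attention-weighted sum of values over the memory (seq_len) axis
x = torch.einsum("jbh,jbhd->bhd", attn, value)
# Equivalent computation with broadcasting
reference = (attn[..., None] * value).sum(dim=0)

# Concatenate heads: [batch_size, heads, d_k] -> [batch_size, heads * d_k]
concat = x.reshape(x.shape[0], -1)
print(concat.shape)  # torch.Size([2, 12])
```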

Output layer

        return self.output(x)

Feedback Transformer Layer

This implements a single transformer layer in the feedback transformer.

class FeedbackTransformerLayer(Module):
  • d_model is the number of features in the transformer
  • attn is the feedback attention module
  • feed_forward is the position-wise feed forward layer
  • dropout_prob is the dropout probability for dropout layers after attention and feed-forward
    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()

Transformer size $d_{model}$

        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

Normalization layers

        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def __call__(self, *,
                 x: torch.Tensor,
                 mem: Optional[torch.Tensor]):

If there is memory

        if mem is not None:

Normalize the vectors before doing self attention

            z = self.norm_self_attn(x)

Run through attention over the memory, i.e. keys and values come from the memory vectors of previous steps

            self_attn = self.attn(query=z, key=mem, value=mem)

Add the attention results

            x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)
        return x
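On the first step there is no memory, so the layer reduces to its feed-forward sub-block. Here is a minimal sketch of the same pre-norm residual structure, with dropout omitted. It uses two stand-ins that are assumptions, not the components used here: PyTorch's `nn.MultiheadAttention` (which has no relative positional embeddings) in place of `FeedbackAttention`, and a small MLP in place of labml's `FeedForward`. `layer_step` is a hypothetical name.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, batch_size = 8, 2
norm_self_attn, norm_ff = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
# Stand-ins (assumptions): built-in multi-head attention instead of
# FeedbackAttention, and a two-layer MLP instead of FeedForward
attn = nn.MultiheadAttention(d_model, num_heads=2)
ff = nn.Sequential(nn.Linear(d_model, 32), nn.ReLU(), nn.Linear(32, d_model))


def layer_step(x, mem):
    """Hypothetical pre-norm residual step for one token (dropout omitted).
    x: [batch_size, d_model]; mem: [mem_len, batch_size, d_model] or None."""
    if mem is not None:
        z = norm_self_attn(x)
        # nn.MultiheadAttention expects [seq_len, batch_size, d_model] inputs,
        # so the single query step gets a leading axis of length 1
        attn_out, _ = attn(z[None], mem, mem)
        x = x + attn_out[0]
    return x + ff(norm_ff(x))


x = torch.randn(batch_size, d_model)
first = layer_step(x, None)  # first step: feed-forward sub-block only
later = layer_step(x, torch.randn(3, batch_size, d_model))
print(first.shape, later.shape)
```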

Feedback Transformer Module

class FeedbackTransformer(Module):
  • layer is the feedback transformer layer, which we clone for each layer
  • n_layers is the number of layers in the transformer
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

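Memory vectors are computed as a weighted sum of the representations from each layer. This parameter holds the weights of that sum. There are n_layers + 1 of them because the input embedding is included alongside the output of every layer.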

        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

Softmax for weights before taking the weighted sum

        self.softmax = nn.Softmax(0)
  • x_seq is the input with shape [seq_len, batch_size, d_model]
    def __call__(self, x_seq: torch.Tensor):

Split the input into a list along the sequence axis

        x_seq = torch.unbind(x_seq, dim=0)

List to store the outputs

        res = []

List to store the memory vectors

        mem = []

For each input step

        for x in x_seq:

List to store layer outputs

            layer_outputs = [x]

If there is memory, stack it into a tensor of shape [mem_len, batch_size, d_model]

            mem_tensor = torch.stack(mem) if mem else None

Run through each layer

            for layer in self.layers:

Get layer output

                x = layer(x=x, mem=mem_tensor)

Append it to the list of layer outputs

                layer_outputs.append(x)

Stack the layer outputs to a tensor

            layer_outputs = torch.stack(layer_outputs)

Calculate the memory vector as a weighted sum of layer outputs

            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))

Append the output to results

            res.append(x)

Stack the output tensors

        res = torch.stack(res)

Normalize the output

        return self.norm(res)
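The step-by-step loop can be exercised end to end with toy components. The sketch below uses a hypothetical `ToyFeedbackLayer` (pre-norm single-head attention over the memory plus a residual linear feed-forward, no dropout) in place of `FeedbackTransformerLayer`. It re-implements the loop above, without the final normalization, as a hypothetical `feedback_forward` function. It checks one consequence of the design: step t only reads memory vectors from steps before t, so the outputs for a prefix do not change when more tokens are appended. That property is what lets prediction cache the memory vectors and extend them one token at a time.

```python
import torch
from torch import nn


class ToyFeedbackLayer(nn.Module):
    """Hypothetical simplified layer: pre-norm single-head attention over the
    memory, then a residual linear feed-forward. No dropout."""

    def __init__(self, d_model: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ff = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, mem):
        # x: [batch_size, d_model]; mem: [mem_len, batch_size, d_model] or None
        if mem is not None:
            z = self.norm(x)
            scores = torch.einsum('bd,jbd->jb', z, mem) / mem.shape[-1] ** 0.5
            x = x + torch.einsum('jb,jbd->bd', scores.softmax(dim=0), mem)
        return x + self.ff(x)


def feedback_forward(layers, weights, x_seq):
    """Hypothetical re-implementation of the FeedbackTransformer loop
    (without the final LayerNorm). x_seq: [seq_len, batch_size, d_model]."""
    mem, res = [], []
    for x in torch.unbind(x_seq, dim=0):
        # Memory holds one vector per previous step only
        mem_tensor = torch.stack(mem) if mem else None
        outputs = [x]
        for layer in layers:
            x = layer(x, mem_tensor)
            outputs.append(x)
        # Softmax-weighted sum over the input and all layer outputs
        mem.append(torch.einsum('lbd,l->bd', torch.stack(outputs), weights.softmax(dim=0)))
        res.append(x)
    return torch.stack(res)


torch.manual_seed(0)
d_model, n_layers = 8, 3
layers = nn.ModuleList([ToyFeedbackLayer(d_model) for _ in range(n_layers)])
weights = torch.ones(n_layers + 1)
x_seq = torch.randn(5, 2, d_model)

with torch.no_grad():
    full = feedback_forward(layers, weights, x_seq)
    prefix = feedback_forward(layers, weights, x_seq[:3])

# The first 3 outputs do not depend on the last 2 inputs
print(torch.allclose(full[:3], prefix))  # True
```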