This is a PyTorch implementation of the paper Accessing Higher-level Representations in Sequential Transformers with Feedback Memory.
In a standard transformer, tokens are processed in parallel and each layer attends to the outputs of the previous layer. The Feedback Transformer instead attends to the outputs of all layers at previous steps. This adds recurrence, so tokens have to be processed one at a time, which slows down training significantly (roughly 5x to 10x, depending on the sequence length). During prediction, however, the Feedback Transformer is faster, because the memory vectors of previous steps can be cached and reused when predicting the next token.
To speed up training, the paper discusses starting with a short sequence length and gradually increasing it. It also discusses using a pretrained parallel transformer as the starting point.
The Feedback Transformer doesn't keep the outputs of all layers separately. Instead, it keeps a single memory vector per step, computed as a weighted sum of the outputs of all layers. This reduces the memory needed for caching during prediction.
Here is a notebook for training a feedback transformer on the Tiny Shakespeare dataset.
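To make the memory idea concrete, here is a minimal sketch (not part of the implementation below; sizes and names are made up) of how one step's memory vector can be formed as a softmax-weighted sum of the embedding and the layer outputs:

import torch

n_layers, batch_size, d_model = 4, 2, 16
# Outputs of the embedding and of each layer for a single step
layer_outputs = torch.randn(n_layers + 1, batch_size, d_model)
# One learnable mixing weight per output
weights = torch.ones(n_layers + 1)
# Memory vector for this step: [batch_size, d_model]
mem_step = torch.einsum('lbd,l->bd', layer_outputs, torch.softmax(weights, dim=0))
assert mem_step.shape == (batch_size, d_model)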
import math
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.utils import clone_module_list
This module computes recurrent attention, similar to the attention in the original transformer paper.
class FeedbackAttention(Module):
heads is the number of attention heads, d_model is the number of features in the transformer, and dropout_prob is the attention dropout probability.
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
        self.heads = heads
These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
Softmax for attention along the time dimension of key
        self.softmax = nn.Softmax(dim=0)
Number of relative positions
        self.P = 2 ** 12
Relative positional embeddings for key relative to the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
The positional embedding for the query is independent of the position of the query (it is a single bias per head).
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
We store the attention weights so that they can be used for logging or other computations if needed.
        self.attn = None
The attention score from the current step's query to the key at step $j$ is $A_j = (Q + U^Q)^\top (K_j + U^K_j)$, where $Q, K_j$ are linear transformations of the original embeddings $X^q, X^k_j$, and $U^Q, U^K_j$ are linear transformations of the absolute positional encodings $P_q, P_j$.
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
$U^K_j$
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
$U^Q$
        query_pos_bias = self.query_pos_bias[None, :, :]
$(Q + U^Q)^\top(K_j + U^K_j)$
        return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
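As an illustrative shape check for this einsum (sizes made up, independent of the class above): the query plus the query bias has shape [batch_size, heads, d_k], the key plus the key positional embeddings has shape [seq_len, batch_size, heads, d_k], and the scores come out as [seq_len, batch_size, heads].

q = torch.randn(2, 4, 16)     # [batch_size, heads, d_k]
k = torch.randn(8, 2, 4, 16)  # [seq_len, batch_size, heads, d_k]
scores = torch.einsum('bhd,jbhd->jbh', q, k)
assert scores.shape == (8, 2, 4)  # [seq_len, batch_size, heads]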
query has shape [batch_size, d_model]; key and value have shape [seq_len, batch_size, d_model].
    def __call__(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor):
Prepare query, key and value for the attention computation. key and value will then have shape [seq_len, batch_size, heads, d_k] and query will have shape [batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
Compute attention scores
Results in a tensor of shape [seq_len, batch_size, heads]
        scores = self.get_scores(query, key)
Scale the scores by $\frac{1}{\sqrt{d_k}}$
        scores *= self.scale
Softmax
        attn = self.softmax(scores)
Apply dropout
        attn = self.dropout(attn)
Multiply by the values
        x = torch.einsum("jbh,jbhd->bhd", attn, value)
Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)
Output layer
        return self.output(x)
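A small usage sketch of FeedbackAttention (made-up sizes; assuming, as in the rest of this implementation, that the labml_helpers Module wrapper lets us call the module directly): a single query step attends over the memory of previous steps.

attention = FeedbackAttention(heads=4, d_model=64)
query = torch.randn(2, 64)    # [batch_size, d_model] - current step
mem = torch.randn(10, 2, 64)  # [mem_len, batch_size, d_model] - previous steps
out = attention(query=query, key=mem, value=mem)
assert out.shape == (2, 64)   # [batch_size, d_model]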
This implements a single transformer layer in the feedback transformer.
class FeedbackTransformerLayer(Module):
d_model is the number of features in the transformer, attn is the feedback attention module, feed_forward is the position-wise feed-forward layer, and dropout_prob is the dropout probability for the dropout layers after attention and feed-forward.
    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
Transformer size $d_{model}$
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def __call__(self, *,
                 x: torch.Tensor,
                 mem: Optional[torch.Tensor]):
If there is memory
        if mem is not None:
Normalize the vectors before doing self attention
            z = self.norm_self_attn(x)
Run through attention, where the keys and values come from the memory of previous steps
            self_attn = self.attn(query=z, key=mem, value=mem)
Add the self attention results
            x = x + self.dropout(self_attn)
Normalize for feed-forward
        z = self.norm_ff(x)
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)
        return x
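A usage sketch for a single layer (made-up sizes; FeedForward is the position-wise feed-forward module imported above, assumed to take d_model and d_ff as in labml_nn). The layer processes one step x given the memory of previous steps, or mem=None at the very first step.

layer = FeedbackTransformerLayer(
    d_model=64,
    attn=FeedbackAttention(heads=4, d_model=64),
    feed_forward=FeedForward(64, d_ff=256),
    dropout_prob=0.1)
x = torch.randn(2, 64)        # [batch_size, d_model] - current step
mem = torch.randn(10, 2, 64)  # memory vectors of previous steps
out = layer(x=x, mem=mem)
assert out.shape == (2, 64)
assert layer(x=x, mem=None).shape == (2, 64)  # first step, no memory yet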
class FeedbackTransformer(Module):
layer is the feedback transformer layer, which we clone for each layer, and n_layers is the number of layers in the transformer.
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
Memory vectors are computed as a weighted sum of representations of each layer. This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
Softmax for weights before taking the weighted sum
        self.softmax = nn.Softmax(0)
x_seq is the input with shape [seq_len, batch_size, d_model]
    def __call__(self, x_seq: torch.Tensor):
Split the input into a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
List to store the outputs
        res = []
List to store the memory vectors
        mem = []
For each input step
        for x in x_seq:
List to store layer outputs
            layer_outputs = [x]
If there is memory, stack the memory vectors into a tensor
            mem_tensor = torch.stack(mem) if mem else None
Run through each layer
            for layer in self.layers:
Get layer output
                x = layer(x=x, mem=mem_tensor)
Append it to the list of layer outputs
                layer_outputs.append(x)
Stack the layer outputs into a tensor
            layer_outputs = torch.stack(layer_outputs)
Calculate the memory vector as a weighted sum of layer outputs
            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
Append the output to results
            res.append(x)
Stack the output tensors
        res = torch.stack(res)
Normalize the output
        return self.norm(res)
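Finally, an end-to-end usage sketch (made-up sizes, with random embeddings standing in for real token embeddings): build the model from a single layer, then run a short sequence through it.

d_model = 64
layer = FeedbackTransformerLayer(
    d_model=d_model,
    attn=FeedbackAttention(heads=4, d_model=d_model),
    feed_forward=FeedForward(d_model, d_ff=256),
    dropout_prob=0.1)
model = FeedbackTransformer(layer, n_layers=6)
x_seq = torch.randn(16, 2, d_model)  # [seq_len, batch_size, d_model]
out = model(x_seq)
assert out.shape == (16, 2, d_model)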