This is a PyTorch implementation of the paper An Attention Free Transformer.
This paper replaces the self-attention layer with a new, efficient operation that has a memory complexity of $\mathcal{O}(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of embeddings.
The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which pays attention to nearby tokens in an autoregressive model.
AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output for each position $t \in [1, T]$ is calculated with the following operation:

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^T \exp(K_{t'} + w_{t,t'})}$$

where $\odot$ is the element-wise product, $\sigma$ is a non-linearity (sigmoid) and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.
This means that we take a weighted average of the values and multiply it element-wise by the sigmoid of the query. This eliminates the need to calculate the $T \times T$ attention matrix that MHA requires, and therefore reduces the memory requirement.
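As a rough illustration of the operation described above, here is a minimal plain-Python sketch (no PyTorch, no batching, no masking). It assumes the output is the sigmoid of the query multiplied by the average of the values weighted by $\exp(K_{t'} + w_{t,t'})$, which matches the numerator and denominator sums computed later in this implementation. The function name `aft_full` and the toy inputs are hypothetical and only serve the example.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def aft_full(Q, K, V, w):
    """Naive AFT for one sequence. Q, K, V are T x d nested lists; w is T x T."""
    T, d = len(Q), len(Q[0])
    Y = [[0.0] * d for _ in range(T)]
    for t in range(T):
        for i in range(d):
            # Weighted sum of values and the matching normalizer, per feature
            num = sum(math.exp(K[tp][i] + w[t][tp]) * V[tp][i] for tp in range(T))
            den = sum(math.exp(K[tp][i] + w[t][tp]) for tp in range(T))
            Y[t][i] = sigmoid(Q[t][i]) * num / den
    return Y


# Toy inputs (assumed for illustration): zero queries, keys and biases.
# Every position then gets equal weight, so the output is
# sigmoid(0) = 0.5 times the plain mean of the values.
T, d = 3, 2
Q = [[0.0] * d for _ in range(T)]
K = [[0.0] * d for _ in range(T)]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w = [[0.0] * T for _ in range(T)]
Y = aft_full(Q, K, V, w)
print(Y[0])  # [1.5, 2.0]
```

The loops only ever hold one running sum per position and feature. The class below computes the same sums with `einsum` over whole tensors instead.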
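AFT-local applies the learned pair-wise position biases only locally:

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{for } |t - t'| \lt s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.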
Although $w'_{t,t'}$ is $0$ outside the local window, the AFT operation still uses key-value pairs from other areas. This is different from local transformers, where embeddings outside the local window are not visible at all.
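To make this point concrete, the following plain-Python sketch zeroes the biases outside a window and then looks at the unnormalized weights $\exp(K_{t'} + w'_{t,t'})$ for one query position. The helper `local_bias`, the bias values of 1.0 and the all-zero keys are assumptions chosen for illustration, not values from the paper or from this implementation.

```python
import math


def local_bias(w, s):
    """Return w' with w'[t][t'] = w[t][t'] if |t - t'| < s, else 0."""
    T = len(w)
    return [[w[t][tp] if abs(t - tp) < s else 0.0 for tp in range(T)]
            for t in range(T)]


T, s = 5, 2
w = [[1.0] * T for _ in range(T)]  # hypothetical learned biases
w_local = local_bias(w, s)

# Unnormalized weights for query position 0, with all keys set to 0
weights = [math.exp(0.0 + w_local[0][tp]) for tp in range(T)]

print(w_local[0])   # [1.0, 1.0, 0.0, 0.0, 0.0]
print(weights[2:])  # [1.0, 1.0, 1.0]
```

Positions outside the window get a bias of zero, which gives a weight of $\exp(0) = 1$ rather than $0$, so their values still enter the average.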
Here is the training code for an AFT-local model.
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module


class AFTLocal(Module):
d_model is the number of features in the query, key and value vectors.
seq_len is $T$.
local_window_size is the local window size $s$.
bias is whether to have a bias parameter for the transformations for $Q$, $K$ and $V$.
    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
        super().__init__()
Local window size $s$
        self.local_window_size = local_window_size
These transform the query, key and value vectors.
        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)
Pair-wise positional biases $w \in \mathbb{R}^{T \times T}$
        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
Mask for $w_{t,t'}$
        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
Activation $\sigma$
        self.activation = nn.Sigmoid()
Output layer
        self.output = nn.Linear(d_model, d_model)

    @staticmethod
    def create_local_mask(seq_len, local_window_size):
Initialize to ones
        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
Make $t' - t \ge s$ zero
        local_mask = torch.tril(local_mask, local_window_size - 1)
Make $t - t' \ge s$ zero
        local_mask = torch.triu(local_mask, -(local_window_size - 1))

        return local_mask
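The two-step band construction above can be sketched without PyTorch. In the snippet below, the helpers `tril` and `triu` are hypothetical plain-Python stand-ins that follow the diagonal-offset convention of `torch.tril` and `torch.triu`: keep entries with `column - row <= k` and `column - row >= k` respectively. The sizes `seq_len=5` and a window of 2 are arbitrary choices for the example.

```python
def tril(m, k):
    """Keep entries with column - row <= k, set the rest to False."""
    return [[v and (j - i <= k) for j, v in enumerate(row)]
            for i, row in enumerate(m)]


def triu(m, k):
    """Keep entries with column - row >= k, set the rest to False."""
    return [[v and (j - i >= k) for j, v in enumerate(row)]
            for i, row in enumerate(m)]


def create_local_mask(seq_len, local_window_size):
    # Start from all ones, then cut away both sides of the band
    mask = [[True] * seq_len for _ in range(seq_len)]
    mask = tril(mask, local_window_size - 1)     # drop t' - t >= s
    mask = triu(mask, -(local_window_size - 1))  # drop t - t' >= s
    return mask


rows = [''.join('1' if v else '0' for v in row)
        for row in create_local_mask(5, 2)]
for row in rows:
    print(row)
# 11000
# 11100
# 01110
# 00111
# 00011
```

Each row keeps only the positions with $|t - t'| \lt s$, which is the band where the learned biases are applied.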
query, key and value are the tensors that store the collection of token embeddings for query, key and value. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size], and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
query, key and value have shape [seq_len, batch_size, d_model]
        seq_len, _, _ = query.shape

        if mask is not None:
mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcast.
            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]
Transform query, key and value embeddings
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
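Get the local position biases $w'_{t,t'}$ by applying the local mask, then set the positions hidden by mask (a boolean tensor) to $-\infty$ so that they get zero weight
        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
        pos_bias = pos_bias.unsqueeze(-1)
        # mask is optional, so only apply it when given; the out-of-place
        # masked_fill lets a per-batch mask broadcast against the size-1 batch dimension
        if mask is not None:
            pos_bias = pos_bias.masked_fill(~mask, float('-inf'))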
We compute $\exp(w_{t,t'})$, $\exp(K_{t'}) \odot V_{t'}$ and $\exp(K_{t'})$ separately and do a matrix multiplication. We use einsum for clarity.
We subtract $\max_{t'}(K_{t'})$ and $\max_{t'}(w_{t,t'})$ before calculating the exponents to stabilize the softmax calculation.
If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum\exp(x_i)y_i}{\sum\exp(x_i)}$ becomes unstable. Subtracting a constant from every $x_i$ before calculating the exponents scales the numerator and the denominator by the same factor, which cancels out, so the result is unchanged. We therefore subtract $\max(x_i)$ to stabilize the computation.
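The effect of subtracting the maximum can be checked in plain Python. This is a small sketch with made-up inputs and hypothetical function names. Note that `math.exp` raises `OverflowError` on overflow, whereas floating-point tensor code would typically produce `inf` and then `nan` from `inf / inf` instead.

```python
import math


def weighted_avg_naive(x, y):
    num = sum(math.exp(xi) * yi for xi, yi in zip(x, y))
    den = sum(math.exp(xi) for xi in x)
    return num / den


def weighted_avg_stable(x, y):
    m = max(x)  # the constant to subtract; it cancels in num / den
    num = sum(math.exp(xi - m) * yi for xi, yi in zip(x, y))
    den = sum(math.exp(xi - m) for xi in x)
    return num / den


# Large logits (assumed for illustration): exp(1000) does not fit in a float
x = [1000.0, 999.0]
y = [1.0, 3.0]

try:
    weighted_avg_naive(x, y)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

stable = weighted_avg_stable(x, y)

# For small inputs both versions agree, since the constant cancels out
small_naive = weighted_avg_naive([1.0, 0.0], y)
small_stable = weighted_avg_stable([1.0, 0.0], y)
```

After the shift, the largest exponent is $\exp(0) = 1$, so every term stays in $(0, 1]$ and the ratio is well defined.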
        max_key = key.max(dim=0, keepdims=True)[0]
        max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]
$\exp \big(K_{t'} - \max_{t'}(K_{t'})\big)$
        exp_key = torch.exp(key - max_key)
$\exp \big(w_{t,t'} - \max_{t'}(w_{t,t'})\big)$
        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)
The numerator part $\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'}) \odot V_{t'}$
        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)
The denominator part $\sum_{t'=1}^T \exp(w_{t,t'}) \odot \exp(K_{t'})$
        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)
Output
        y = self.activation(query) * num / den
Output layer
        return self.output(y)
Test local mask
def _test_local_mask():
    from labml.logger import inspect
    inspect(AFTLocal.create_local_mask(10, 4))


if __name__ == '__main__':
    _test_local_mask()