This is a PyTorch implementation of the paper An Attention Free Transformer.
This paper replaces the self-attention layer with a new efficient operation that has memory complexity of $\mathcal{O}(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of embeddings.
The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which pays attention to nearby tokens in an autoregressive model.
AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output $Y_t$ for each position $t$ is calculated with the following operation.

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp\big(K_{t'} + w_{t,t'}\big) \odot V_{t'}}{\sum_{t'=1}^{T} \exp\big(K_{t'} + w_{t,t'}\big)}$$

where $\odot$ is element-wise product, $\sigma(\cdot)$ is a non-linearity (sigmoid) and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.
This means that we take a weighted average of the values and multiply it by the query. This eliminates the need to calculate the $T \times T$ attention matrix that MHA requires, and therefore reduces the memory requirement.
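To make the operation concrete, here is a minimal reference sketch (not the implementation below) that evaluates the equation directly for a single unbatched sequence; the function name aft_full_reference and the tensor shapes are illustrative assumptions:

```python
import torch


def aft_full_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, w: torch.Tensor):
    # q, k, v have shape [T, d]; w has shape [T, T] (pair-wise position biases)
    T, d = q.shape
    y = torch.empty(T, d)
    for t in range(T):
        # exp(K_{t'} + w_{t,t'}) for every position t', shape [T, d]
        weights = torch.exp(k + w[t].unsqueeze(-1))
        # Weighted average of the values, gated element-wise by sigmoid(Q_t)
        y[t] = torch.sigmoid(q[t]) * (weights * v).sum(dim=0) / weights.sum(dim=0)
    return y
```

Note that the weights depend only on the keys and the learned position biases, not on query-key dot products.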
AFT-local only applies learned pair-wise position biases locally:
$$w_{t,t'} = \begin{cases} w_{t,t'}, & \text{if } |t - t'| < s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.
Although $w_{t,t'}$ is zero outside the local window, the AFT operation still uses key-value pairs from other areas. This is different from local transformers, where embeddings outside the local window are not visible at all.
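For instance, with $T = 4$ and $s = 2$, only biases with $|t - t'| < 2$ are kept; a small illustrative sketch (the variable names are ours, not from the code below):

```python
import torch

T, s = 4, 2
w = torch.randn(T, T)                           # learned pair-wise position biases
t = torch.arange(T)
window = (t[:, None] - t[None, :]).abs() < s    # True where |t - t'| < s
w_local = torch.where(window, w, torch.zeros_like(w))
# w_local keeps a band of width 2s - 1 around the diagonal and is 0 elsewhere.
# Since exp(0) = 1, key-value pairs outside the window still contribute to the sums.
```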
Here is the training code for an AFT-local model.
```python
from typing import Optional

import torch
from torch import nn


class AFTLocal(nn.Module):
```

d_model is the number of features in the query, key and value vectors. seq_len is the sequence length $T$. local_window_size is the local window size $s$. bias is whether to have a bias parameter for the transformations of $Q$, $K$ and $V$.
```python
    def __init__(self, d_model: int, seq_len: int, local_window_size: int, bias: bool = True):
        super().__init__()

        # Local window size s
        self.local_window_size = local_window_size
        # These transform the query, key and value vectors
        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)
        # Pair-wise positional biases w
        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
        # Mask for the local window
        self.local_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
        # Activation σ (sigmoid)
        self.activation = nn.Sigmoid()
        # Output layer
        self.output = nn.Linear(d_model, d_model)
```
```python
    @staticmethod
    def create_local_mask(seq_len, local_window_size):
        # Initialize to ones
        local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
        # Make t' - t >= s zero
        local_mask = torch.tril(local_mask, local_window_size - 1)
        # Make t - t' >= s zero
        local_mask = torch.triu(local_mask, -(local_window_size - 1))

        return local_mask
```
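As a quick sanity check (not part of the module), the tril/triu construction above should agree with the condition $|t - t'| < s$; a hypothetical check using the class defined here:

```python
import torch

seq_len, s = 5, 2
mask = AFTLocal.create_local_mask(seq_len, s)
t = torch.arange(seq_len)
assert torch.equal(mask, (t[:, None] - t[None, :]).abs() < s)
# For seq_len=5, s=2 this is the band matrix
# 1 1 0 0 0
# 1 1 1 0 0
# 0 1 1 1 0
# 0 0 1 1 1
# 0 0 0 1 1
```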
query, key and value are the tensors that store a collection of token embeddings for query, key and value. They have shape [seq_len, batch_size, d_model]. mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
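For example, an autoregressive (causal) mask in this layout can be built as below. This is a sketch rather than the project's mask utility, and it uses a batch dimension of 1 so that it broadcasts across the batch:

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    # mask[i, j, 0] is True when query position i may attend to key-value position j <= i
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Add a broadcastable batch dimension: [seq_len, seq_len, 1]
    return mask.unsqueeze(-1)
```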
```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        # query, key and value have shape [seq_len, batch_size, d_model]
        seq_len, _, _ = query.shape

        if mask is not None:
            # mask has shape [seq_len_q, seq_len_k, batch_size],
            # where the first dimension is the query dimension.
            # If the query dimension is equal to 1 it will be broadcasted.
            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

        # Transform query, key and value embeddings
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Get the learned pair-wise position biases restricted to the local window
        pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
        pos_bias = pos_bias.unsqueeze(-1)
        # Apply the attention mask by setting masked-out positions to -inf
        if mask is not None:
            pos_bias.masked_fill_(~mask, float('-inf'))
```

We subtract $\max_{t'} K_{t'}$ and $\max_{t'} w_{t,t'}$ before calculating the exponents to stabilize the softmax calculation.
If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum_i \exp(x_i) y_i}{\sum_i \exp(x_i)}$ becomes unstable. Subtracting a constant before calculating the exponent cancels out between the numerator and the denominator and can help stabilize the computation. So we subtract $\max_i x_i$ to stabilize the computation.
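A tiny numeric sketch of why this works: shifting all exponents by a constant leaves the ratio unchanged but keeps exp from overflowing (the values below are illustrative only):

```python
import torch

x = torch.tensor([1000.0, 999.0])   # large scores: exp(x) overflows to inf in float32
y = torch.tensor([1.0, 2.0])

naive = (torch.exp(x) * y).sum() / torch.exp(x).sum()   # inf / inf -> nan
shifted = torch.exp(x - x.max())                        # exp([0, -1]) is finite
stable = (shifted * y).sum() / shifted.sum()            # ≈ 1.2689, the correct weighted average
```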
```python
        # Subtract the maximums for numerical stability
        max_key = key.max(dim=0, keepdims=True)[0]
        max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]

        # exp(K_{t'} - max K) and exp(w_{t,t'} - max w)
        exp_key = torch.exp(key - max_key)
        exp_pos_bias = torch.exp(pos_bias - max_pos_bias)

        # The numerator part: sum over t' of exp(w_{t,t'}) * exp(K_{t'}) ⊙ V_{t'}
        num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)
        # The denominator part: sum over t' of exp(w_{t,t'}) * exp(K_{t'})
        den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)
```
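To unpack the two einsums, here is an equivalent (much slower) explicit loop, written as a standalone sketch rather than part of the module; exp_pos_bias is assumed to carry a trailing broadcast dimension of size 1 as in the code above:

```python
import torch


def aft_sums_reference(exp_pos_bias: torch.Tensor, exp_key: torch.Tensor, value: torch.Tensor):
    # exp_pos_bias: [seq_len, seq_len, 1]; exp_key, value: [seq_len, batch_size, d_model]
    seq_len, batch_size, d_model = exp_key.shape
    num = torch.zeros(seq_len, batch_size, d_model)
    den = torch.zeros(seq_len, batch_size, d_model)
    for i in range(seq_len):        # query position t
        for j in range(seq_len):    # key-value position t'
            num[i] += exp_pos_bias[i, j] * exp_key[j] * value[j]
            den[i] += exp_pos_bias[i, j] * exp_key[j]
    return num, den
```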
```python
        # Y_t = σ(Q_t) ⊙ num / den
        y = self.activation(query) * num / den

        # Output layer
        return self.output(y)
```
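A minimal usage sketch of the module (the hyper-parameters and the causal mask here are illustrative, not taken from the training configuration):

```python
import torch

seq_len, batch_size, d_model = 16, 2, 64
aft = AFTLocal(d_model=d_model, seq_len=seq_len, local_window_size=4)

x = torch.randn(seq_len, batch_size, d_model)
# Causal mask with a broadcastable batch dimension: [seq_len, seq_len, 1]
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

y = aft(query=x, key=x, value=x, mask=mask)
assert y.shape == (seq_len, batch_size, d_model)
```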
Test local mask

```python
def _test_local_mask():
    from labml.logger import inspect
    inspect(AFTLocal.create_local_mask(10, 4))


if __name__ == '__main__':
    _test_local_mask()
```