This is a PyTorch implementation of the paper An Attention Free Transformer.
This paper replaces the self-attention layer with a new, efficient operation that has a memory complexity of $\mathcal{O}(Td)$, where $T$ is the sequence length and $d$ is the dimensionality of embeddings.
The paper introduces AFT along with AFT-local and AFT-conv. Here we have implemented AFT-local, which pays attention to nearby tokens, in an autoregressive model.
AFT (similar to MHA) first transforms the embeddings $X$ into query $Q = XW^Q$, key $K = XW^K$ and value $V = XW^V$ tensors with learned weights. The output for each position $t \in [1, T]$ is calculated with the following operation:

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})}$$

where $\odot$ is the element-wise product, $\sigma$ is a non-linearity (sigmoid) and $w \in \mathbb{R}^{T \times T}$ is a learned matrix of pair-wise position biases.
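Transcribed directly, the formula can be checked on small inputs. This is a minimal NumPy sketch for a single sequence (no batch dimension), assuming $Q$, $K$ and $V$ have already been projected; `aft_naive` is a hypothetical helper name, not part of this implementation. It loops over $t$, so it does $\mathcal{O}(T^2 d)$ work.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def aft_naive(q: np.ndarray, k: np.ndarray, v: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Direct AFT formula for one sequence.

    q, k, v: [T, d] query, key and value tensors
    w:       [T, T] pair-wise position biases
    returns  [T, d]
    """
    T, d = q.shape
    y = np.zeros((T, d))
    for t in range(T):
        # exp(K_{t'} + w_{t,t'}) for every t', shape [T, d]
        e = np.exp(k + w[t][:, None])
        # sigma(Q_t) times (weighted sum of values) / (sum of weights), per feature
        y[t] = sigmoid(q[t]) * (e * v).sum(axis=0) / e.sum(axis=0)
    return y


rng = np.random.default_rng(0)
T, d = 5, 3
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
w = rng.standard_normal((T, T))
y = aft_naive(q, k, v, w)
print(y.shape)  # (5, 3)
```

With $K = 0$ and $w = 0$ all weights are equal, so each $Y_t$ reduces to $\sigma(Q_t) \odot \frac{1}{T}\sum_{t'} V_{t'}$, which is an easy sanity check.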
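This means that we take a weighted average of the values and multiply it element-wise by $\sigma(Q_t)$; the averaging weights, $\exp(K_{t'} + w_{t,t'})$ normalized over $t'$, are computed separately for each feature. This eliminates the need to calculate the $T \times T$ attention matrix $QK^\top$ that MHA computes for every sequence, and therefore reduces the memory requirement.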
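No per-sequence $T \times T$ attention matrix is needed because $\exp(K_{t'} + w_{t,t'}) = \exp(w_{t,t'}) \exp(K_{t'})$: both sums become products of $\exp(w)$, which depends only on the learned parameters, with $T \times d$ tensors. A minimal NumPy sketch of this factorization for one sequence; `aft_factorized` is a hypothetical name used only for illustration (the PyTorch module in this implementation uses the autoregressive local variant with a loop instead).

```python
import numpy as np


def aft_factorized(q: np.ndarray, k: np.ndarray, v: np.ndarray, w: np.ndarray) -> np.ndarray:
    """AFT for one sequence using exp(K + w) = exp(w) * exp(K).

    q, k, v: [T, d]; w: [T, T]. Intermediates are [T, d], plus exp(w), which is [T, T].
    """
    # Shift by maxima for numerical stability; the factors cancel in num / den
    exp_w = np.exp(w - w.max())            # [T, T]
    exp_k = np.exp(k - k.max(axis=0))      # [T, d]
    num = exp_w @ (exp_k * v)              # sum over t' of exp(w[t, t']) exp(K[t']) V[t']
    den = exp_w @ exp_k                    # sum over t' of exp(w[t, t']) exp(K[t'])
    return 1.0 / (1.0 + np.exp(-q)) * num / den


rng = np.random.default_rng(0)
T, d = 6, 4
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
w = rng.standard_normal((T, T))
y = aft_factorized(q, k, v, w)

# Dividing out sigma(Q) leaves a weighted average, which stays within the range of V
avg = y * (1.0 + np.exp(-q))
print(np.all((avg >= v.min(axis=0) - 1e-12) & (avg <= v.max(axis=0) + 1e-12)))  # True
```

The weights are all positive and normalized per feature, so each output feature (before the $\sigma(Q_t)$ gate) is a convex combination of that feature across the value vectors.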
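AFT Local only applies the learned pair-wise position biases locally:

$$w'_{t,t'} = \begin{cases} w_{t,t'}, & \text{if } \lvert t - t' \rvert < s \\ 0, & \text{otherwise} \end{cases}$$

where $s \le T$ is the local window size.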
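A small NumPy sketch of building $w'$ from $w$. The helper name `local_bias` is hypothetical; indices are 0-based here, which does not change $\lvert t - t' \rvert$.

```python
import numpy as np


def local_bias(w: np.ndarray, s: int) -> np.ndarray:
    """Return w' with w'[t, t2] = w[t, t2] if |t - t2| < s, else 0."""
    idx = np.arange(w.shape[0])
    inside = np.abs(idx[:, None] - idx[None, :]) < s   # [T, T] band of width 2s - 1
    return np.where(inside, w, 0.0)


print(local_bias(np.ones((4, 4)), s=2))
# [[1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [0. 1. 1. 1.]
#  [0. 0. 1. 1.]]
```

With $s = T$ every pair satisfies $\lvert t - t' \rvert < s$, so $w' = w$ and AFT Local reduces to the full AFT.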
Although $w'_{t,t'}$ is $0$ outside the local window, the AFT operation still uses key-value pairs from outside the window, weighted by $\exp(K_{t'})$. This is different from local transformers, where embeddings outside the local window are not visible at all.
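A quick NumPy check of this point. It is a sketch: `aft` and `local_bias` are hypothetical helpers for one sequence, with 0-based positions and the full (non-autoregressive) sum. Changing a value vector far outside the window of position $t$ still changes $Y_t$, because it enters both sums with the positive weight $\exp(K_{t'})$.

```python
import numpy as np


def aft(q, k, v, w):
    """Direct AFT formula for one sequence; q, k, v: [T, d], w: [T, T].

    Builds a [T, T, d] tensor, so this is for small checks only.
    """
    e = np.exp(k[None, :, :] + w[:, :, None])   # e[t, t2] = exp(K[t2] + w[t, t2])
    return (e * v[None]).sum(axis=1) / e.sum(axis=1) / (1.0 + np.exp(-q))


def local_bias(w, s):
    """w'[t, t2] = w[t, t2] if |t - t2| < s, else 0."""
    idx = np.arange(w.shape[0])
    return np.where(np.abs(idx[:, None] - idx[None, :]) < s, w, 0.0)


rng = np.random.default_rng(0)
T, d, s = 8, 3, 2
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
w_local = local_bias(rng.standard_normal((T, T)), s)

y = aft(q, k, v, w_local)
v_changed = v.copy()
v_changed[0] += 10.0                 # position 0 is far outside the window of t = 7
y_changed = aft(q, k, v_changed, w_local)
print(np.abs(y_changed[7] - y[7]).max() > 0)  # True: Y_7 still depends on V_0
```

In a local transformer that masks out positions outside the window, the same change would leave $Y_7$ unchanged.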
Here is the training code for an AFT Local model.
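```python
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
```

This is an implementation of AFT Local for auto-regression, where $Y_t$ only has visibility to tokens up to and including $t$:

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{t} \exp(K_{t'} + w'_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{t} \exp(K_{t'} + w'_{t,t'})}$$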
```python
class AFTLocalAutoregressive(Module):
    def __init__(self, d_model: int, seq_len: int, s: int, bias: bool = True):
        """
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
        * `seq_len` is $T$
        * `s` is the local window size $s$
        * `bias` is whether to have a bias parameter for the transformations for $Q$, $K$ and $V$.
        """
        super().__init__()
        # Local window size $s$
        self.s = s
        # These transform the `query`, `key` and `value` vectors.
        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)
        # Pair-wise positional biases $w \in \mathbb{R}^{T \times T}$
        self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
        # Activation $\sigma$
        self.activation = nn.Sigmoid()
        # Output layer
        self.output = nn.Linear(d_model, d_model)

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store a
        collection of token embeddings for query, key and value.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` should be `None`. We keep this parameter so that we can use this as a
        drop-in replacement for MHA.
        """

        # `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
        seq_len, _, _ = query.shape

        # Apply the learned transformations to get $Q$, $K$ and $V$
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
```

We subtract $\max(K_{t'} + w_{t,t'})$ before calculating the exponents to stabilize the softmax calculation.
If $x_i$ is large, $\exp(x_i)$ becomes huge and the computation of $\frac{\sum\exp(x_i)y_i}{\sum\exp(x_i)}$ becomes unstable. Subtracting the same constant from every $x_i$ before calculating the exponents multiplies the numerator and the denominator by the same factor, so it cancels out, while keeping the exponentials from overflowing. So we subtract $\max(x_i)$, which makes every exponent at most $0$.

Here the maximum is the higher of $\max(K_{t'} + w_{t,t'})$ and $\max(K_{t'})$, because positions outside the local window use $K_{t'}$ without a position bias.
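A tiny NumPy illustration of the overflow problem and of the fix; the logits $x$ and values $y$ are made-up numbers for the example, not taken from the model.

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])   # large logits
y = np.array([1.0, 2.0, 3.0])            # values being averaged

# Naive: exp(1000) overflows to inf, and inf / inf is nan
with np.errstate(over="ignore", invalid="ignore"):
    naive = (np.exp(x) * y).sum() / np.exp(x).sum()

# Stable: subtract max(x) first; every exponent is now <= 0
shifted = np.exp(x - x.max())
stable = (shifted * y).sum() / shifted.sum()

print(naive)             # nan
print(round(stable, 4))  # 2.5752
```

The module applies the same idea, with the maximum taken over both $K_{t'}$ and $K_{t'} + w_{t,t'}$ so that it bounds every exponent it computes.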
```python
        # Maximum logit for each batch element and feature; shape `[batch_size, d_model]`
        max_logit = torch.max(
            # $\max(K_{t'})$
            key.max(dim=0)[0],
            # $\max(K_{t'} + w_{t,t'})$
            (key + self.pos_bias[:seq_len, :seq_len].max(dim=0)[0].view(-1, 1, 1)).max(dim=0)[0]
        )

        # The numerator part $\sum_{t'=1}^{t-s} \exp(K_{t'}) \odot V_{t'}$
        num = key.new_zeros(key.shape[1:])
        # The denominator part $\sum_{t'=1}^{t-s} \exp(K_{t'})$
        den = key.new_zeros(key.shape[1:])
        # Output $Y$
        y = key.new_zeros(key.shape)
        # Iterate over positions $t \in [0, T)$ (0-indexed)
        for t in range(seq_len):
            # $t - s + 1$
            f = t - self.s + 1
            # This actually means $t - s \ge 1$, since we index from $1$ in the
            # math equations and from $0$ in the code
            if f >= 1:
                # $\exp(K_{t-s})$: this position has just left the local window
                exp_l = torch.exp(key[f - 1] - max_logit)
                # Update numerator and denominator parts
                num = num + exp_l * value[f - 1]
                den = den + exp_l
            # Start from the beginning if the local window extends before it
            f = max(0, f)
            # $\exp(K_{t'} + w_{t,t'})$ for positions $t'$ inside the local window
            exp_l = torch.exp(key[f: t + 1] + self.pos_bias[t, f: t + 1].view(-1, 1, 1) - max_logit)
            # Numerator
            n = num + (exp_l * value[f: t + 1]).sum(dim=0)
            # Denominator
            d = den + exp_l.sum(dim=0)
            # $Y_t = \sigma(Q_t) \odot \frac{\text{numerator}}{\text{denominator}}$
            y[t] = self.activation(query[t]) * n / d

        # Output layer
        return self.output(y)
```
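To sanity-check the running-sum loop, here is a NumPy sketch for a single sequence with the projections and output layer left out. The function names are hypothetical: `aft_local_loop` mirrors the loop in `forward`, and `aft_local_reference` evaluates the autoregressive AFT-local formula directly with a masked $T \times T \times d$ tensor.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def aft_local_loop(q, k, v, w, s):
    """Mirror of the module's loop. q, k, v: [T, d]; w: [T, T]; s: window size."""
    T = q.shape[0]
    max_logit = np.maximum(k.max(axis=0), (k + w.max(axis=0)[:, None]).max(axis=0))  # [d]
    num = np.zeros(q.shape[1])
    den = np.zeros(q.shape[1])
    y = np.zeros_like(q)
    for t in range(T):
        f = t - s + 1
        if f >= 1:
            # position t - s has just left the window: add it with bias 0
            e = np.exp(k[f - 1] - max_logit)
            num += e * v[f - 1]
            den += e
        f = max(0, f)
        # positions inside the window use the learned bias w[t, t']
        e = np.exp(k[f:t + 1] + w[t, f:t + 1][:, None] - max_logit)
        y[t] = sigmoid(q[t]) * (num + (e * v[f:t + 1]).sum(axis=0)) / (den + e.sum(axis=0))
    return y


def aft_local_reference(q, k, v, w, s):
    """Direct formula: sum over t' <= t, with w'[t, t'] = w[t, t'] if t - t' < s else 0."""
    T = q.shape[0]
    idx = np.arange(T)
    diff = idx[:, None] - idx[None, :]            # t - t'
    w_local = np.where(diff < s, w, 0.0)          # local biases (only t' <= t is used)
    causal = (diff >= 0)[:, :, None]              # [T, T, 1] autoregressive mask
    e = np.where(causal, np.exp(k[None] + w_local[:, :, None]), 0.0)  # [T, T, d]
    return sigmoid(q) * (e * v[None]).sum(axis=1) / e.sum(axis=1)


rng = np.random.default_rng(0)
T, d, s = 10, 4, 3
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
w = rng.standard_normal((T, T))
print(np.allclose(aft_local_loop(q, k, v, w, s), aft_local_reference(q, k, v, w, s)))  # True
```

The loop keeps only two running sums of shape `[d]` for the positions that have left the window, so apart from the window itself it never holds more than $\mathcal{O}(d)$ extra state per sequence.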