diff --git a/labml_nn/transformers/feedback/__init__.py b/labml_nn/transformers/feedback/__init__.py
index 56f1c337..91dc9815 100644
--- a/labml_nn/transformers/feedback/__init__.py
+++ b/labml_nn/transformers/feedback/__init__.py
@@ -18,38 +18,17 @@ from labml_nn.transformers.models import FeedForward
 from labml_nn.utils import clone_module_list
 
 
-class PrepareQueryForMultiHeadAttention(Module):
-    """
-    ## Prepare query for multi-head attention
-
-    This module does a linear transformation and splits the vector into given
-    number of heads for multi-head attention.
-    This is used to transform **key**, **query**, and **value** vectors.
-    """
-
-    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
-        super().__init__()
-        # Linear layer for linear transform
-        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
-        # Number of heads
-        self.heads = heads
-        # Number of dimensions in vectors in each head
-        self.d_k = d_k
-
-    def __call__(self, x: torch.Tensor):
-        # Input has shape `[seq_len, batch_size, d_model]`
-        batch_size, _ = x.shape
-
-        # Linear transform
-        x = self.linear(x)
-        # Split into heads
-        x = x.view(batch_size, self.heads, self.d_k)
-
-        # Output has shape `[seq_len, batch_size, heads, d_k]`
-        return x
-
-
 class FeedbackAttention(Module):
+    """
+    ## Feedback Attention
+
+    This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
+    but with some modifications.
+
+    📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
+    or [Multi-Head Attention](../mha.html) to improve readability.
+    """
+
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
         super().__init__()
 
@@ -57,7 +36,7 @@ class FeedbackAttention(Module):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareQueryForMultiHeadAttention(d_model, heads, self.d_k, False)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
 
@@ -68,17 +47,60 @@ class FeedbackAttention(Module):
         # Scaling factor before the softmax
         self.scale = 1 / math.sqrt(self.d_k)
 
+        # Softmax for attention along the time dimension of `key`
+        self.softmax = nn.Softmax(dim=0)
+
+        # Number of relative positions
+        self.P = 2 ** 12
+
+        # Relative positional embeddings for key relative to the query.
+        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
+        # Relative positional embedding bias for key relative to the query.
+        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
+        # Positional embeddings for the query are independent of the position of the query
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+
         # We store attentions so that it can used for logging, or other computations if needed
         self.attn = None
 
-        self.P = 2 ** 12
-
-        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
-        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
-        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
-        self.softmax = nn.Softmax(dim=0)
-
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+        r"""
+        ### Get relative attention scores
+
+        With absolute attention
+
+        \begin{align}
+        A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
+        &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+        \color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
+        \color{cyan}{V_i^T} \color{lightgreen}{K_j} +
+        \color{cyan}{V_i^T} \color{lightgreen}{U_j}
+        \end{align}
+
+        where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$ are linear transformations of
+        original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
+        and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
+        absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
+
+        They reason out that the attention to a given key should be the same regardless of
+        the position of the query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
+        with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
+        🤔 May be worthwhile testing without this assumption.
+
+        For the second and third terms relative positional encodings are introduced.
+        So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
+        replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
+        and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
+
+        \begin{align}
+        A^{rel}_{i,j} &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+        \color{cyan}{Q_i^T} \color{orange}{R_{i - j}} +
+        \color{orange}{v^T} \color{lightgreen}{K_j} +
+        \color{orange}{S_{i-j}}
+        \end{align}
+        """
+
+        # $\color{orange}{R_{i - j}}$
         key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
         key_pos_bias = self.key_pos_bias[-key.shape[0]:]
         query_pos_bias = self.query_pos_bias[None, :, :]
diff --git a/labml_nn/transformers/mha.py b/labml_nn/transformers/mha.py
index af093293..76f60c26 100644
--- a/labml_nn/transformers/mha.py
+++ b/labml_nn/transformers/mha.py
@@ -18,10 +18,10 @@ import math
 from typing import Optional
 
 import torch
+from torch import nn as nn
+
 from labml import tracker
 from labml_helpers.module import Module
-from torch import nn as nn
-from torch.nn import functional as F
 
 
 class PrepareForMultiHeadAttention(Module):
@@ -43,43 +43,51 @@ class PrepareForMultiHeadAttention(Module):
         self.d_k = d_k
 
     def __call__(self, x: torch.Tensor):
-        # Input has shape `[seq_len, batch_size, d_model]`
-        seq_len, batch_size, _ = x.shape
+        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
+        # We apply the linear transformation to the last dimension and split that into
+        # the heads.
+        head_shape = x.shape[:-1]
 
         # Linear transform
         x = self.linear(x)
 
-        # Split into heads
-        x = x.view(seq_len, batch_size, self.heads, self.d_k)
-        # Output has shape `[seq_len, batch_size, heads, d_k]`
+        # Split last dimension into heads
+        x = x.view(*head_shape, self.heads, self.d_k)
+
+        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
         return x
 
 
 class MultiHeadAttention(Module):
+    r"""
+    ## Multi-Head Attention Module
+
+    This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
+
+    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+
+    In simple terms, it finds keys that match the query and gets the values of
+    those keys.
+
+    It uses dot-product of query and key as the indicator of how matching they are.
+    Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
+    This is done to avoid large dot-product values causing softmax to
+    give very small gradients when $d_k$ is large.
+
+    Softmax is calculated along the axis of the sequence (or time).
+    """
+
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
         """
-        ## Multi-Head Attention Module
-
         * `heads` is the number of heads.
         * `d_model` is the number of features in the `query`, `key` and `value` vectors.
-
-        This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
-
-        $$\mathop{Attention}(Q, K, V) = \mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
-
-        In simple terms, it finds keys that matches the query, and get the values of
-        those keys.
-
-        It uses dot-product of query and key as the indicator of how matching they are.
-        Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
-        This is done to avoid large dot-product values causing softmax to
-        give very small gradients when $d_k$ is large.
-
-        Softmax is calculate along the axis of of the sequence (or time).
         """
         super().__init__()
+
+        # Number of features per head
         self.d_k = d_model // heads
+        # Number of heads
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
@@ -87,6 +95,9 @@ class MultiHeadAttention(Module):
         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
 
+        # Softmax for attention along the time dimension of `key`
+        self.softmax = nn.Softmax(dim=1)
+
         # Output layer
         self.output = nn.Linear(d_model, d_model)
         # Dropout
@@ -153,7 +164,7 @@ class MultiHeadAttention(Module):
 
         # $softmax$ attention along the key sequence dimension
         # $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
-        attn = F.softmax(scores, dim=1)
+        attn = self.softmax(scores)
 
         # Save attentions if debugging
         tracker.debug('attn', attn)
diff --git a/labml_nn/transformers/relative_mha.py b/labml_nn/transformers/relative_mha.py
index be8d79cb..2428e461 100644
--- a/labml_nn/transformers/relative_mha.py
+++ b/labml_nn/transformers/relative_mha.py
@@ -51,56 +51,77 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
         # The linear transformations doesn't need a bias since we take care of it when
         # calculating scores.
         # However having a bias for `value` might make sense.
-        super().__init__(heads, d_model, dropout_prob, False)
+        super().__init__(heads, d_model, dropout_prob, bias=False)
 
+        # Number of relative positions
         self.P = 2 ** 12
 
+        # Relative positional embeddings for key relative to the query.
+        # We need $2P$ embeddings because the keys can be before or after the query.
         self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
-        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+        # Relative positional embedding bias for key relative to the query.
         self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
+        # Positional embeddings for the query are independent of the position of the query
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
 
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
-        """
+        r"""
+        ### Get relative attention scores
+
         With absolute attention
 
         \begin{align}
-        A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\
-        &= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j
+        A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
+        &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+        \color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
+        \color{cyan}{V_i^T} \color{lightgreen}{K_j} +
+        \color{cyan}{V_i^T} \color{lightgreen}{U_j}
         \end{align}
 
-        where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of
-        orginal embeddings and positional encodings.
+        where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$ are linear transformations of
+        original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
+        and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
+        absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
 
-        They reason out that the attention to a given key should be the same regardless of
-        the position of query. Hence replace $V_i^T K_j$ with a constant $v^T K_j$.
+        They reason out that the attention to a given key should be the same regardless of
+        the position of the query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
+        with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
         🤔 May be worthwhile testing without this assumption.
 
         For the second and third terms relative positional encodings are introduced.
-        So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$.
+        So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
+        replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
+        and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
 
         \begin{align}
-        A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j}
+        A^{rel}_{i,j} &= \underset{\mathbf{A}}{\color{cyan}{Q_i^T} \color{lightgreen}{K_j}} +
+        \underset{\mathbf{B}}{\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}} +
+        \underset{\mathbf{C}}{\color{orange}{v^T} \color{lightgreen}{K_j}} +
+        \underset{\mathbf{D}}{\color{orange}{S_{i-j}}}
         \end{align}
-
         """
-        # $R_{i-j}$ pre-shift
+        # $\color{orange}{R_k}$
         key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
-        # $S_{i-j}$ pre-shift
+        # $\color{orange}{S_k}$
         key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
-        # $v^T$
+        # $\color{orange}{v^T}$
         query_pos_bias = self.query_pos_bias[None, None, :, :]
 
-        # $Q_i^T K_j + v^T K_j$
+        # ${(\mathbf{A} + \mathbf{C})}_{i,j} = \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+        # \color{orange}{v^T} \color{lightgreen}{K_j}$
         ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
-        # $Q_i^T R_{i - j}$ pre-shift
+        # $\mathbf{B'}_{i,k} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
         b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
-        # $S_{i-j}$ pre-shift
+        # $\mathbf{D'}_{i,k} = \color{orange}{S_k}$
         d = key_pos_bias[None, :, None, :]
 
-        # $Q_i^T R_{i - j} + S_{i-j}$
+        # Shift the rows of $\mathbf{(B' + D')}_{i,k}$
+        # to get $$\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}$$
         bd = shift_right(b + d)
+        # Remove extra positions
         bd = bd[:, -key.shape[0]:]
 
+        # Return the sum $\mathbf{(A + B + C + D)}_{i,j}$
        return ac + bd
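
The rewritten `PrepareForMultiHeadAttention.__call__` splits only the last dimension, so the same module can now handle both `[seq_len, batch_size, d_model]` and `[batch_size, d_model]` inputs (the latter is presumably what makes it reusable from `FeedbackAttention`). Below is a minimal shape-check sketch of that behaviour; it is not code from the repository, and the layer, function name and tensor sizes are illustrative only:

```python
import torch
from torch import nn

d_model, heads = 512, 8
d_k = d_model // heads
# Stand-in for the `nn.Linear` inside `PrepareForMultiHeadAttention`
linear = nn.Linear(d_model, heads * d_k, bias=False)

def split_heads(x: torch.Tensor) -> torch.Tensor:
    # Keep every dimension except the last, apply the linear transform,
    # then split the last dimension into `heads x d_k`
    head_shape = x.shape[:-1]
    x = linear(x)
    return x.view(*head_shape, heads, d_k)

# `[seq_len, batch_size, d_model]` input -> `[seq_len, batch_size, heads, d_k]`
print(split_heads(torch.randn(10, 4, d_model)).shape)  # torch.Size([10, 4, 8, 64])
# `[batch_size, d_model]` input -> `[batch_size, heads, d_k]`
print(split_heads(torch.randn(4, d_model)).shape)      # torch.Size([4, 8, 64])
```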
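The `F.softmax(scores, dim=1)` call in `MultiHeadAttention` is replaced by a stored `nn.Softmax(dim=1)` module. A quick sanity-check sketch (again illustrative, with arbitrary sizes) that the two agree, and that with scores in the `ijbh` layout `[query_seq, key_seq, batch, heads]` produced by the `einsum` calls, `dim=1` normalizes over the key positions:

```python
import torch
from torch import nn
from torch.nn import functional as F

# Attention scores in the `ijbh` layout: [query_seq, key_seq, batch, heads]
scores = torch.randn(5, 7, 2, 4)
softmax = nn.Softmax(dim=1)

attn = softmax(scores)
# Same result as the removed functional call
assert torch.allclose(attn, F.softmax(scores, dim=1))
# Each query position's attention sums to one over the 7 key positions
assert torch.allclose(attn.sum(dim=1), torch.ones(5, 2, 4))
print(attn.shape)  # torch.Size([5, 7, 2, 4])
```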