Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Commit: relative attention notes (author: Varuna Jayasiri)
@@ -18,38 +18,17 @@ from labml_nn.transformers.models import FeedForward
 from labml_nn.utils import clone_module_list
 
 
-class PrepareQueryForMultiHeadAttention(Module):
-    """
-    ## Prepare query for multi-head attention
-
-    This module does a linear transformation and splits the vector into given
-    number of heads for multi-head attention.
-    This is used to transform **key**, **query**, and **value** vectors.
-    """
-
-    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
-        super().__init__()
-        # Linear layer for linear transform
-        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
-        # Number of heads
-        self.heads = heads
-        # Number of dimensions in vectors in each head
-        self.d_k = d_k
-
-    def __call__(self, x: torch.Tensor):
-        # Input has shape `[seq_len, batch_size, d_model]`
-        batch_size, _ = x.shape
-
-        # Linear transform
-        x = self.linear(x)
-        # Split into heads
-        x = x.view(batch_size, self.heads, self.d_k)
-
-        # Output has shape `[seq_len, batch_size, heads, d_k]`
-        return x
-
-
 class FeedbackAttention(Module):
+    """
+    ## Feedback Attention
+
+    This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
+    but with some modifications.
+
+    📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
+     or [Multi-Head Attention](../mha.html) to improve readability.
+    """
+
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
         super().__init__()
 
@@ -57,7 +36,7 @@ class FeedbackAttention(Module):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareQueryForMultiHeadAttention(d_model, heads, self.d_k, False)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
 
@@ -68,17 +47,60 @@ class FeedbackAttention(Module):
         # Scaling factor before the softmax
         self.scale = 1 / math.sqrt(self.d_k)
 
+        # Softmax for attention along the time dimension of `key`
+        self.softmax = nn.Softmax(dim=0)
+
+        # Number of relative positions
+        self.P = 2 ** 12
+
+        # Relative positional embeddings for key relative to the query.
+        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
+        # Relative positional embedding bias for key relative to the query.
+        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
+        # Positional embeddings for the query are independent of the position of the query
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+
         # We store attentions so that it can used for logging, or other computations if needed
         self.attn = None
 
-        self.P = 2 ** 12
-
-        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
-        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
-        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
-        self.softmax = nn.Softmax(dim=0)
-
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+        """
+        ### Get relative attention scores
+
+        With absolute attention
+
+        \begin{align}
+        A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
+                      &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
+                         \color{cyan}{V_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{V_i^T} \color{lightgreen}{U_j}
+        \end{align}
+
+        where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$ are linear transformations of
+         original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
+         and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
+         absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
+
+        They reason out that the attention to a given key should be the same regardless of
+        the position of the query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
+        with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
+        🤔 May be worthwhile testing without this assumption.
+
+        For the second and third terms relative positional encodings are introduced.
+        So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
+        replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
+        and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
+
+        \begin{align}
+        A^{rel}_{i,j} &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{Q_i^T} \color{orange}{R_{i - j}} +
+                         \color{orange}{v^T} \color{lightgreen}{K_j} +
+                         \color{orange}{S_{i-j}}
+        \end{align}
+        """
+
+        # $\color{orange}{R_{i - j}}$
         key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
         key_pos_bias = self.key_pos_bias[-key.shape[0]:]
         query_pos_bias = self.query_pos_bias[None, :, :]
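
The hunk above only shows the parameter setup and the opening of `FeedbackAttention.get_scores`. As a rough sketch of how these pieces combine for a single query step (the shapes, einsum strings, and variable names below are inferred from the diff, not copied from the module), the relative scores can be assembled like this:

    import torch

    # Toy sizes, chosen only for this illustration.
    seq_len, batch_size, heads, d_k, P = 5, 2, 4, 8, 2 ** 12

    # A single query step and the keys gathered so far, following the
    # `[seq_len, batch_size, heads, d_k]` layout used in these modules.
    query = torch.randn(batch_size, heads, d_k)
    key = torch.randn(seq_len, batch_size, heads, d_k)

    # Learned relative-position parameters, mirroring the diff above.
    key_pos_embeddings = torch.zeros(P, heads, d_k)   # R
    key_pos_bias = torch.zeros(P, heads)              # S
    query_pos_bias = torch.zeros(heads, d_k)          # v

    # Only the last `seq_len` relative positions are needed: the single query
    # sits at the end, so every key is at or before it.
    r = key_pos_embeddings[-seq_len:]                 # [seq_len, heads, d_k]
    s = key_pos_bias[-seq_len:]                       # [seq_len, heads]

    # A + C: (Q + v)^T K_j, contracted over d_k
    ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
    # B: Q^T R_{i-j}
    b = torch.einsum('bhd,jhd->jbh', query, r)
    # D: S_{i-j}, broadcast over the batch
    d = s[:, None, :]

    scores = ac + b + d                               # [seq_len, batch_size, heads]
    attn = torch.softmax(scores, dim=0)               # softmax along the key time dimension
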
@@ -18,10 +18,10 @@ import math
 from typing import Optional
 
 import torch
+from torch import nn as nn
 
 from labml import tracker
 from labml_helpers.module import Module
-from torch import nn as nn
-from torch.nn import functional as F
 
 
 class PrepareForMultiHeadAttention(Module):
@@ -43,43 +43,51 @@ class PrepareForMultiHeadAttention(Module):
         self.d_k = d_k
 
     def __call__(self, x: torch.Tensor):
-        # Input has shape `[seq_len, batch_size, d_model]`
-        seq_len, batch_size, _ = x.shape
+        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
+        # We apply the linear transformation to the last dimension and split that into
+        # the heads
+        head_shape = x.shape[:-1]
 
         # Linear transform
         x = self.linear(x)
-        # Split into heads
-        x = x.view(seq_len, batch_size, self.heads, self.d_k)
 
-        # Output has shape `[seq_len, batch_size, heads, d_k]`
+        # Split last dimension into heads
+        x = x.view(*head_shape, self.heads, self.d_k)
+
+        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
         return x
 
 
 class MultiHeadAttention(Module):
+    r"""
+    ## Multi-Head Attention Module
+
+    This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
+
+    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+
+    In simple terms, it finds keys that match the query and gets the values of
+     those keys.
+
+    It uses the dot-product of query and key as the indicator of how matching they are.
+    Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
+    This is done to avoid large dot-product values causing softmax to
+    give very small gradients when $d_k$ is large.
+
+    Softmax is calculated along the axis of the sequence (or time).
+    """
+
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
         """
-        ## Multi-Head Attention Module
-
         * `heads` is the number of heads.
         * `d_model` is the number of features in the `query`, `key` and `value` vectors.
-
-        This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
-
-        $$\mathop{Attention}(Q, K, V) = \mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
-
-        In simple terms, it finds keys that matches the query, and get the values of
-         those keys.
-
-        It uses dot-product of query and key as the indicator of how matching they are.
-        Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
-        This is done to avoid large dot-product values causing softmax to
-        give very small gradients when $d_k$ is large.
-
-        Softmax is calculate along the axis of of the sequence (or time).
         """
 
         super().__init__()
 
+        # Number of features per head
         self.d_k = d_model // heads
+        # Number of heads
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
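
The rewritten `PrepareForMultiHeadAttention.__call__` keys the reshape off `x.shape[:-1]`, so the same module now handles a full sequence `[seq_len, batch_size, d_model]` as well as a single step `[batch_size, d_model]` (which is what `FeedbackAttention` feeds it). A minimal sketch of that reshape, using an ad-hoc linear layer rather than the module itself:

    import torch
    from torch import nn

    heads, d_k, d_model = 4, 8, 32
    linear = nn.Linear(d_model, heads * d_k, bias=False)

    def split_heads(x: torch.Tensor) -> torch.Tensor:
        # Transform the last dimension and split it into heads,
        # keeping whatever leading dimensions the input had.
        head_shape = x.shape[:-1]
        x = linear(x)
        return x.view(*head_shape, heads, d_k)

    seq = torch.randn(10, 2, d_model)   # [seq_len, batch_size, d_model]
    step = torch.randn(2, d_model)      # [batch_size, d_model]
    print(split_heads(seq).shape)       # torch.Size([10, 2, 4, 8])
    print(split_heads(step).shape)      # torch.Size([2, 4, 8])
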
@@ -87,6 +95,9 @@ class MultiHeadAttention(Module):
         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
 
+        # Softmax for attention along the time dimension of `key`
+        self.softmax = nn.Softmax(dim=1)
+
         # Output layer
         self.output = nn.Linear(d_model, d_model)
         # Dropout
@@ -153,7 +164,7 @@ class MultiHeadAttention(Module):
 
         # $softmax$ attention along the key sequence dimension
         # $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
-        attn = F.softmax(scores, dim=1)
+        attn = self.softmax(scores)
 
         # Save attentions if debugging
         tracker.debug('attn', attn)
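
The class docstring added above states the formula $\underset{seq}{\mathop{softmax}}\big(QK^T/\sqrt{d_k}\big)V$, and this hunk replaces the inline `F.softmax(scores, dim=1)` with the `self.softmax` module configured for the key-sequence dimension. A compressed sketch of that whole path under the `[seq_len, batch_size, heads, d_k]` layout (an illustration of the formula, not the file's actual forward code):

    import math
    import torch

    seq_len, batch_size, heads, d_k = 6, 2, 4, 8

    # `query`, `key`, `value` already split per head: [seq_len, batch_size, heads, d_k]
    query = torch.randn(seq_len, batch_size, heads, d_k)
    key = torch.randn(seq_len, batch_size, heads, d_k)
    value = torch.randn(seq_len, batch_size, heads, d_k)

    # Q K^T for every (query position i, key position j) pair
    scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
    # Scale by 1/sqrt(d_k) so large dot products don't saturate the softmax
    scores = scores * (1 / math.sqrt(d_k))
    # Softmax along the key-sequence dimension (dim=1 in the [i, j, b, h] layout)
    attn = torch.softmax(scores, dim=1)
    # Weighted sum of the values
    out = torch.einsum('ijbh,jbhd->ibhd', attn, value)   # [seq_len, batch_size, heads, d_k]
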
@@ -51,56 +51,77 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
         # The linear transformations doesn't need a bias since we take care of it when
         # calculating scores.
         # However having a bias for `value` might make sense.
-        super().__init__(heads, d_model, dropout_prob, False)
+        super().__init__(heads, d_model, dropout_prob, bias=False)
 
+        # Number of relative positions
         self.P = 2 ** 12
 
+        # Relative positional embeddings for key relative to the query.
+        # We need $2P$ embeddings because the keys can be before or after the query.
         self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
-        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+        # Relative positional embedding bias for key relative to the query.
         self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
+        # Positional embeddings for the query are independent of the position of the query
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
 
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
-        """
+        r"""
+        ### Get relative attention scores
+
         With absolute attention
 
         \begin{align}
-        A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\
-                      &= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j
+        A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
+                      &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
+                         \color{cyan}{V_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{V_i^T} \color{lightgreen}{U_j}
         \end{align}
 
-        where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of
-         orginal embeddings and positional encodings.
+        where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$ are linear transformations of
+         original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
+         and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
+         absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
 
-        They reason out that the attention to a given key should be the same regardless of
-        the position of query. Hence replace $V_i^T K_j$ with a constant $v^T K_j$.
+        They reason out that the attention to a given key should be the same regardless of
+        the position of the query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
+        with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
         🤔 May be worthwhile testing without this assumption.
 
         For the second and third terms relative positional encodings are introduced.
-        So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$.
+        So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
+        replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
+        and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
 
         \begin{align}
-        A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j}
+        A^{rel}_{i,j} &= \underset{\mathbf{A}}{\color{cyan}{Q_i^T} \color{lightgreen}{K_j}} +
+                         \underset{\mathbf{B}}{\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}} +
+                         \underset{\mathbf{C}}{\color{orange}{v^T} \color{lightgreen}{K_j}} +
+                         \underset{\mathbf{D}}{\color{orange}{S_{i-j}}}
         \end{align}
 
         """
 
-        # $R_{i-j}$ pre-shift
+        # $\color{orange}{R_k}$
         key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
-        # $S_{i-j}$ pre-shift
+        # $\color{orange}{S_k}$
         key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
-        # $v^T$
+        # $\color{orange}{v^T}$
         query_pos_bias = self.query_pos_bias[None, None, :, :]
 
-        # $Q_i^T K_j + v^T K_j$
+        # ${(\mathbf{A} + \mathbf{C})}_{i,j} = \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+        # \color{orange}{v^T} \color{lightgreen}{K_j}$
         ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
-        # $Q_i^T R_{i - j}$ pre-shift
+        # $\mathbf{B'}_{i,k} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
-        # $S_{i-j}$ pre-shift
+        # $\mathbf{D'}_{i,k} = \color{orange}{S_k}$
         d = key_pos_bias[None, :, None, :]
-        # $Q_i^T R_{i - j} + S_{i-j}$
+        # Shift the rows of $\mathbf{(B' + D')}_{i,k}$
+        # to get $$\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}$$
         bd = shift_right(b + d)
+        # Remove extra positions
         bd = bd[:, -key.shape[0]:]
 
+        # Return the sum $\mathbf{(A + B + C + D)}_{i,j}$
         return ac + bd
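
The term $\mathbf{B}_{i,j} = Q_i^T R_{i-j}$ would naively need a separate embedding lookup for every $(i, j)$ pair. Because it depends only on the relative distance, the code instead multiplies each query against all relative slots once (`b`) and then re-indexes each row, which is what `shift_right(b + d)` followed by `bd[:, -key.shape[0]:]` accomplishes. The toy below demonstrates that row-wise re-indexing with an explicit slice per row; the indexing convention (and its sign, which is immaterial for learned embeddings) is chosen just for this sketch, and the repository's `shift_right` achieves the same effect with a zero-pad and reshape instead of a Python loop:

    import torch

    L, d_k = 4, 3
    q = torch.randn(L, d_k)
    # One embedding per relative distance j - i in [-(L - 1), L - 1],
    # stored at index L - 1 + (j - i)  (a convention picked for this sketch).
    r = torch.randn(2 * L - 1, d_k)

    # Reference: look up Q_i^T R_{j-i} directly for every pair.
    ref = torch.empty(L, L)
    for i in range(L):
        for j in range(L):
            ref[i, j] = q[i] @ r[L - 1 + (j - i)]

    # Trick: one matmul against all relative slots, then shift each row.
    b = q @ r.t()                                   # b[i, k] = Q_i^T R'_k, shape [L, 2L - 1]
    shifted = torch.stack([b[i, L - 1 - i:2 * L - 1 - i] for i in range(L)])

    assert torch.allclose(ref, shifted)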