mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
	relative attention notes
| @@ -18,38 +18,17 @@ from labml_nn.transformers.models import FeedForward | ||||
| from labml_nn.utils import clone_module_list | ||||
|  | ||||
|  | ||||
| class PrepareQueryForMultiHeadAttention(Module): | ||||
|     """ | ||||
|     ## Prepare query for multi-head attention | ||||
|  | ||||
|     This module does a linear transformation and splits the vector into the given | ||||
|     number of heads for multi-head attention. | ||||
|     This is used to transform the **query** vector. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, d_model: int, heads: int, d_k: int, bias: bool): | ||||
|         super().__init__() | ||||
|         # Linear layer for linear transform | ||||
|         self.linear = nn.Linear(d_model, heads * d_k, bias=bias) | ||||
|         # Number of heads | ||||
|         self.heads = heads | ||||
|         # Number of dimensions in vectors in each head | ||||
|         self.d_k = d_k | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor): | ||||
|         # Input has shape `[batch_size, d_model]` since the query is a single step | ||||
|         batch_size, _ = x.shape | ||||
|  | ||||
|         # Linear transform | ||||
|         x = self.linear(x) | ||||
|         # Split into heads | ||||
|         x = x.view(batch_size, self.heads, self.d_k) | ||||
|  | ||||
|         # Output has shape `[batch_size, heads, d_k]` | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class FeedbackAttention(Module): | ||||
|     """ | ||||
|     ## Feedback Attention | ||||
|  | ||||
|     This is very similar to [Relative Multi-Head Attention](../relative_mha.html) | ||||
|     but with some modifications. | ||||
|  | ||||
|     📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html) | ||||
|      or [Multi-Head Attention](../mha.html) to improve readability. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1): | ||||
|         super().__init__() | ||||
|  | ||||
| @@ -57,7 +36,7 @@ class FeedbackAttention(Module): | ||||
|         self.heads = heads | ||||
|  | ||||
|         # These transform the `query`, `key` and `value` vectors for multi-headed attention. | ||||
|         self.query = PrepareQueryForMultiHeadAttention(d_model, heads, self.d_k, False) | ||||
|         self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False) | ||||
|         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False) | ||||
|         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False) | ||||
|  | ||||
| @@ -68,17 +47,60 @@ class FeedbackAttention(Module): | ||||
|         # Scaling factor before the softmax | ||||
|         self.scale = 1 / math.sqrt(self.d_k) | ||||
|  | ||||
|         # Softmax for attention along the time dimension of `key` | ||||
|         self.softmax = nn.Softmax(dim=0) | ||||
|  | ||||
|         # Number of relative positions | ||||
|         self.P = 2 ** 12 | ||||
|  | ||||
|         # Relative positional embeddings for key relative to the query. | ||||
|         self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True) | ||||
|         # Relative positional embedding bias for key relative to the query. | ||||
|         self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True) | ||||
|         # The positional bias for the query is independent of the position of the query | ||||
|         self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True) | ||||
|  | ||||
|         # We store the attentions so that they can be used for logging, or for other computations if needed | ||||
|         self.attn = None | ||||
|  | ||||
|         self.P = 2 ** 12 | ||||
|  | ||||
|         self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True) | ||||
|         self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True) | ||||
|         self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True) | ||||
|         self.softmax = nn.Softmax(dim=0) | ||||
|  | ||||
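
A small usage sketch of the constructor above, as a sanity check. The hyper-parameter values, the import path, and the assumption that `d_k = d_model // heads` (as in `MultiHeadAttention` further down) are illustrative guesses, not taken from the repository's experiment configurations.

    # Hedged usage sketch; hyper-parameters and the import path are assumptions.
    from labml_nn.transformers.feedback import FeedbackAttention  # path is a guess

    attn = FeedbackAttention(heads=8, d_model=512)
    # Assuming d_k = d_model // heads = 64, as in MultiHeadAttention:
    assert attn.scale == 1 / 8                                # 1 / sqrt(64)
    assert attn.key_pos_embeddings.shape == (2 ** 12, 8, 64)  # P relative positions per head
    assert attn.query_pos_bias.shape == (8, 64)
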
|     def get_scores(self, query: torch.Tensor, key: torch.Tensor): | ||||
|         """ | ||||
|         ### Get relative attention scores | ||||
|  | ||||
|         With absolute attention | ||||
|  | ||||
|         \begin{align} | ||||
|         A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\ | ||||
|                       &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} + | ||||
|                          \color{cyan}{Q_i^T} \color{lightgreen}{U_j} + | ||||
|                          \color{cyan}{V_i^T} \color{lightgreen}{K_j} + | ||||
|                          \color{cyan}{V_i^T} \color{lightgreen}{U_j} | ||||
|         \end{align} | ||||
|  | ||||
|         where $\color{cyan}{Q_i}$ and $\color{lightgreen}{K_j}$ are linear transformations of | ||||
|          original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$ | ||||
|          and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of | ||||
|          absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$. | ||||
|  | ||||
|         They reason that the attention to a given key should be the same regardless of | ||||
|         the position of the query. Hence they replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$ | ||||
|         with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$. | ||||
|         🤔 May be worthwhile testing without this assumption. | ||||
|  | ||||
|         For the second and fourth terms, relative positional encodings are introduced. | ||||
|         So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is | ||||
|         replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$ | ||||
|         and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$. | ||||
|  | ||||
|         \begin{align} | ||||
|         A^{rel}_{i,j} &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} + | ||||
|                          \color{cyan}{Q_i^T} \color{orange}{R_{i - j}} + | ||||
|                          \color{orange}{v^T} \color{lightgreen}{K_j} + | ||||
|                          \color{orange}{S_{i-j}} | ||||
|         \end{align} | ||||
|         """ | ||||
|  | ||||
|         # $\color{orange}{R_{i - j}}$ | ||||
|         key_pos_emb = self.key_pos_embeddings[-key.shape[0]:] | ||||
|         key_pos_bias = self.key_pos_bias[-key.shape[0]:] | ||||
|         query_pos_bias = self.query_pos_bias[None, :, :] | ||||
|  | ||||
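
The hunk above cuts off before the rest of `get_scores`, so the sketch below only illustrates the shape bookkeeping for the single-step query used by feedback attention. The einsum patterns and the helper function name are assumptions, not the repository's code.

    import torch

    def feedback_scores_sketch(query, key, key_pos_embeddings, key_pos_bias, query_pos_bias):
        # query:              [batch_size, heads, d_k]   (a single step, no time axis)
        # key:                [key_len, batch_size, heads, d_k]
        # key_pos_embeddings: [P, heads, d_k]; key_pos_bias: [P, heads]; query_pos_bias: [heads, d_k]
        key_len = key.shape[0]
        r = key_pos_embeddings[-key_len:]          # last key_len rows of the R table
        s = key_pos_bias[-key_len:]                # last key_len rows of the S table
        # Content term plus the constant query bias: one score per key position
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias[None], key)
        # Relative-position terms
        bd = torch.einsum('bhd,jhd->jbh', query, r) + s[:, None, :]
        return ac + bd                             # [key_len, batch_size, heads], softmax over dim 0

    # Shape check with dummy tensors (sizes are arbitrary)
    scores = feedback_scores_sketch(torch.zeros(2, 8, 64), torch.zeros(5, 2, 8, 64),
                                    torch.zeros(4096, 8, 64), torch.zeros(4096, 8),
                                    torch.zeros(8, 64))
    assert scores.shape == (5, 2, 8)
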
| @@ -18,10 +18,10 @@ import math | ||||
| from typing import Optional | ||||
|  | ||||
| import torch | ||||
| from torch import nn as nn | ||||
|  | ||||
| from labml import tracker | ||||
| from labml_helpers.module import Module | ||||
| from torch import nn as nn | ||||
| from torch.nn import functional as F | ||||
|  | ||||
|  | ||||
| class PrepareForMultiHeadAttention(Module): | ||||
| @@ -43,29 +43,28 @@ class PrepareForMultiHeadAttention(Module): | ||||
|         self.d_k = d_k | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor): | ||||
|         # Input has shape `[seq_len, batch_size, d_model]` | ||||
|         seq_len, batch_size, _ = x.shape | ||||
|         # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`. | ||||
|         # We apply the linear transformation to the last dimension and split that into | ||||
|         # the heads. | ||||
|         head_shape = x.shape[:-1] | ||||
|  | ||||
|         # Linear transform | ||||
|         x = self.linear(x) | ||||
|         # Split into heads | ||||
|         x = x.view(seq_len, batch_size, self.heads, self.d_k) | ||||
|  | ||||
|         # Output has shape `[seq_len, batch_size, heads, d_k]` | ||||
|         # Split last dimension into heads | ||||
|         x = x.view(*head_shape, self.heads, self.d_k) | ||||
|  | ||||
|         # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]` | ||||
|         return x | ||||
|  | ||||
|  | ||||
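
Before moving on to `MultiHeadAttention`, a quick shape check of the reshaping above. The sizes are arbitrary, and the import path and keyword-argument names follow the constructor shown earlier but are otherwise assumptions.

    import torch
    from labml_nn.transformers.mha import PrepareForMultiHeadAttention  # path is a guess

    prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=False)

    # 3D input: a sequence of embeddings
    x3 = torch.zeros(10, 2, 512)                # [seq_len, batch_size, d_model]
    assert prep(x3).shape == (10, 2, 8, 64)     # [seq_len, batch_size, heads, d_k]

    # 2D input: a single step per batch element (e.g. the feedback attention query)
    x2 = torch.zeros(2, 512)                    # [batch_size, d_model]
    assert prep(x2).shape == (2, 8, 64)         # [batch_size, heads, d_k]
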
| class MultiHeadAttention(Module): | ||||
|     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True): | ||||
|         """ | ||||
|     r""" | ||||
|     ## Multi-Head Attention Module | ||||
|  | ||||
|         * `heads` is the number of heads. | ||||
|         * `d_model` is the number of features in the `query`, `key` and `value` vectors. | ||||
|  | ||||
|     This computes scaled multi-headed attention for given `query`, `key` and `value` vectors. | ||||
|  | ||||
|         $$\mathop{Attention}(Q, K, V) = \mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$ | ||||
|     $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$ | ||||
|  | ||||
|     In simple terms, it finds keys that match the query and gets the values of | ||||
|      those keys. | ||||
| @@ -78,8 +77,17 @@ class MultiHeadAttention(Module): | ||||
|     Softmax is calculated along the axis of the sequence (or time). | ||||
|     """ | ||||
|  | ||||
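
A minimal sketch of the computation the formula above describes, in the `[seq_len, batch_size, heads, d_k]` layout used throughout this file. The einsum patterns mirror the one in the relative attention code later in this diff; treat this as an illustration, not the module's actual forward pass.

    import math
    import torch

    def scaled_dot_product_sketch(q, k, v):
        # q, k, v: [seq_len, batch_size, heads, d_k]
        d_k = q.shape[-1]
        scores = torch.einsum('ibhd,jbhd->ijbh', q, k) / math.sqrt(d_k)  # [q_len, k_len, batch, heads]
        attn = torch.softmax(scores, dim=1)                              # softmax over key positions
        return torch.einsum('ijbh,jbhd->ibhd', attn, v)                  # [q_len, batch, heads, d_k]

    q = torch.randn(10, 2, 8, 64)
    out = scaled_dot_product_sketch(q, q, q)    # self-attention on dummy data
    assert out.shape == (10, 2, 8, 64)
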
|     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True): | ||||
|         """ | ||||
|         * `heads` is the number of heads. | ||||
|         * `d_model` is the number of features in the `query`, `key` and `value` vectors. | ||||
|         """ | ||||
|  | ||||
|         super().__init__() | ||||
|  | ||||
|         # Number of features per head | ||||
|         self.d_k = d_model // heads | ||||
|         # Number of heads | ||||
|         self.heads = heads | ||||
|  | ||||
|         # These transform the `query`, `key` and `value` vectors for multi-headed attention. | ||||
| @@ -87,6 +95,9 @@ class MultiHeadAttention(Module): | ||||
|         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias) | ||||
|         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias) | ||||
|  | ||||
|         # Softmax for attention along the time dimension of `key` | ||||
|         self.softmax = nn.Softmax(dim=1) | ||||
|  | ||||
|         # Output layer | ||||
|         self.output = nn.Linear(d_model, d_model) | ||||
|         # Dropout | ||||
| @@ -153,7 +164,7 @@ class MultiHeadAttention(Module): | ||||
|  | ||||
|         # $softmax$ attention along the key sequence dimension | ||||
|         # $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$ | ||||
|         attn = F.softmax(scores, dim=1) | ||||
|         attn = self.softmax(scores) | ||||
|  | ||||
|         # Save attentions if debugging | ||||
|         tracker.debug('attn', attn) | ||||
|  | ||||
| @@ -51,56 +51,77 @@ class RelativeMultiHeadAttention(MultiHeadAttention): | ||||
|         # The linear transformations don't need a bias since we take care of it when | ||||
|         # calculating the scores. | ||||
|         # However, having a bias for `value` might make sense. | ||||
|         super().__init__(heads, d_model, dropout_prob, False) | ||||
|         super().__init__(heads, d_model, dropout_prob, bias=False) | ||||
|  | ||||
|         # Number of relative positions | ||||
|         self.P = 2 ** 12 | ||||
|  | ||||
|         # Relative positional embeddings for key relative to the query. | ||||
|         # We need $2P$ embeddings because the keys can be before or after the query. | ||||
|         self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True) | ||||
|         self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True) | ||||
|         # Relative positional embedding bias for key relative to the query. | ||||
|         self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True) | ||||
|         # The positional bias for the query is independent of the position of the query | ||||
|         self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True) | ||||
|  | ||||
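
A short check of the sizes involved in the relative-position lookup in `get_scores` below. For a query of length q and keys of length k the offsets i - j take q + k - 1 distinct values (from -(k - 1) to q - 1), while the slice taken from the 2P-row table has q + k rows; the surplus is what the "Remove extra positions" step below trims after the shift. The sizes here are arbitrary.

    import torch

    P = 2 ** 12
    heads, d_k = 8, 64
    key_pos_embeddings = torch.zeros(2 * P, heads, d_k)

    q_len, k_len = 10, 12                               # arbitrary lengths
    sliced = key_pos_embeddings[P - q_len: P + k_len]   # same slice as in get_scores below
    assert sliced.shape[0] == q_len + k_len             # one row more than the q + k - 1 offsets
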
|     def get_scores(self, query: torch.Tensor, key: torch.Tensor): | ||||
|         """ | ||||
|         r""" | ||||
|         ### Get relative attention scores | ||||
|  | ||||
|         With absolute attention | ||||
|  | ||||
|         \begin{align} | ||||
|         A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\ | ||||
|                       &= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j | ||||
|         A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\ | ||||
|                       &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} + | ||||
|                          \color{cyan}{Q_i^T} \color{lightgreen}{U_j} + | ||||
|                          \color{cyan}{V_i^T} \color{lightgreen}{K_j} + | ||||
|                          \color{cyan}{V_i^T} \color{lightgreen}{U_j} | ||||
|         \end{align} | ||||
|  | ||||
|         where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of | ||||
|          original embeddings and positional encodings. | ||||
|         where $\color{cyan}{Q_i}$ and $\color{lightgreen}{K_j}$ are linear transformations of | ||||
|          original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$ | ||||
|          and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of | ||||
|          absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$. | ||||
|  | ||||
|         They reason that the attention to a given key should be the same regardless of | ||||
|         the position of query. Hence replace $V_i^T K_j$ with a constant $v^T K_j$. | ||||
|         the position of the query. Hence they replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$ | ||||
|         with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$. | ||||
|         🤔 May be worthwhile testing without this assumption. | ||||
|  | ||||
|         For the second and fourth terms, relative positional encodings are introduced. | ||||
|         So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$. | ||||
|         So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is | ||||
|         replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$ | ||||
|         and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$. | ||||
|  | ||||
|         \begin{align} | ||||
|         A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j} | ||||
|         A^{rel}_{i,j} &= \underset{\mathbf{A}}{\color{cyan}{Q_i^T} \color{lightgreen}{K_j}} + | ||||
|                          \underset{\mathbf{B}}{\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}} + | ||||
|                          \underset{\mathbf{C}}{\color{orange}{v^T} \color{lightgreen}{K_j}} + | ||||
|                          \underset{\mathbf{D}}{\color{orange}{S_{i-j}}} | ||||
|         \end{align} | ||||
|  | ||||
|         """ | ||||
|  | ||||
|         # $R_{i-j}$ pre-shift | ||||
|         # $\color{orange}{R_k}$ | ||||
|         key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]] | ||||
|         # $S_{i-j}$ pre-shift | ||||
|         # $\color{orange}{S_k}$ | ||||
|         key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]] | ||||
|         # $v^T$ | ||||
|         # $\color{orange}{v^T}$ | ||||
|         query_pos_bias = self.query_pos_bias[None, None, :, :] | ||||
|  | ||||
|         # $Q_i^T K_j + v^T K_j$ | ||||
|         # ${(\mathbf{A} + \mathbf{C})}_{i,j} = \color{cyan}{Q_i^T} \color{lightgreen}{K_j} + | ||||
|         # \color{orange}{v^T} \color{lightgreen}{K_j}$ | ||||
|         ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key) | ||||
|         # $Q_i^T R_{i - j}$ pre-shift | ||||
|         # $\mathbf{B'}_{i,k} = \color{cyan}{Q_i^T} \color{orange}{R_k}$ | ||||
|         b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb) | ||||
|         # $S_{i-j}$ pre-shift | ||||
|         # $\mathbf{D'}_{i,k} = \color{orange}{S_k}$ | ||||
|         d = key_pos_bias[None, :, None, :] | ||||
|         # $Q_i^T R_{i - j} + S_{i-j}$ | ||||
|         # Shift the rows of $\mathbf{(B' + D')}_{i,k}$ | ||||
|         # to get $$\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}$$ | ||||
|         bd = shift_right(b + d) | ||||
|         # Remove extra positions | ||||
|         bd = bd[:, -key.shape[0]:] | ||||
|  | ||||
|         # Return the sum $\mathbf{(A + B + C + D)}_{i,j}$ | ||||
|         return ac + bd | ||||
|  | ||||
|  | ||||
|  | ||||
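
The `shift_right` helper called above is not part of this diff. The toy function below only illustrates the kind of pad-and-reshape row shift that turns B'[i, k], indexed by table position k, into B[i, j], indexed by key position j: each row ends up displaced one column further than the row above it. It is a sketch of the idea under those assumptions, not the repository's implementation, and it ignores the batch and head dimensions.

    import torch

    def toy_shift(x: torch.Tensor) -> torch.Tensor:
        # x: [rows, cols]. Append a zero column, flatten, and re-view with the same shape,
        # so that row i is displaced i columns to the right relative to row 0.
        rows, cols = x.shape
        padded = torch.cat([x, x.new_zeros(rows, 1)], dim=1)   # [rows, cols + 1]
        return padded.reshape(-1)[:rows * cols].reshape(rows, cols)

    x = torch.arange(12).reshape(3, 4)
    print(toy_shift(x))
    # tensor([[0, 1, 2, 3],
    #         [0, 4, 5, 6],
    #         [7, 0, 8, 9]])
    # The wrapped-around entries in the last row (the 7 and the stray 0) are the kind of
    # extra positions that the real computation later discards with `bd[:, -key.shape[0]:]`.
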