mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-02 04:37:46 +08:00
feedback transformer notes
@ -22,23 +22,31 @@ class FeedbackAttention(Module):
"""
## Feedback Attention

This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
but with some modifications.

📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
or [Multi-Head Attention](../mha.html) to improve readability.

This module computes recurrent attention similar to the attention in the original
transformer paper.

$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^T K}{\sqrt{d_k}}\Bigg)V$$
"""

def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
"""
* `heads` is the number of attention heads
* `d_model` is the number of features in the transformer
* `dropout_prob` is the attention dropout probability
"""

super().__init__()

# Number of features per head
self.d_k = d_model // heads
#
self.heads = heads

# These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

# Output layer
self.output = nn.Linear(d_model, d_model)

@ -55,8 +63,6 @@ class FeedbackAttention(Module):

# Relative positional embeddings for key relative to the query.
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
# Relative positional embedding bias for key relative to the query.
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
# The positional embedding for the query is independent of the query's position
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

@ -65,47 +71,27 @@ class FeedbackAttention(Module):

def get_scores(self, query: torch.Tensor, key: torch.Tensor):
"""
### Get relative attention scores
### Get attention scores

\begin{align}
A_{j} &= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
&= \color{cyan}{Q^T} \color{lightgreen}{K_j} +
\color{cyan}{Q^T} \color{lightgreen}{U_j} +
\color{cyan}{V^T} \color{lightgreen}{K_j} +
\color{cyan}{V^T} \color{lightgreen}{U_j}
A_{j} &= Q^T K_j \\
&= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
&= (\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})
\end{align}

where $\color{cyan}{Q}, \color{lightgreen}{K_j}$ are linear transformations of
original embeddings $\color{cyan}{X^q}, \color{lightgreen}{X^k_j}$
and $\color{cyan}{V}, \color{lightgreen}{U_j}$ are linear transformations of
absolute positional encodings $\color{cyan}{P}, \color{lightgreen}{P_j}$.

We replace $\color{cyan}{V^T} \color{lightgreen}{U_j}$ with $S_j$.

\begin{align}
A^{rel}_{j} &= \underset{\mathbf{A}}{\color{cyan}{Q^T} \color{lightgreen}{K_j}} +
\underset{\mathbf{B}}{\color{cyan}{Q^T} \color{lightgreen}{U_j}} +
\underset{\mathbf{C}}{\color{cyan}{V^T} \color{lightgreen}{K_j}} +
\underset{\mathbf{D}}{\color{orange}{S_j}}
\end{align}
"""

# $\color{lightgreen}{U_j}$
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
# $\color{orange}{S_j}$
key_pos_bias = self.key_pos_bias[-key.shape[0]:]
# $\color{cyan}{V^T}$
query_pos_bias = self.query_pos_bias[None, :, :]

# $\underset{\mathbf{A}}{\color{cyan}{Q^T} \color{lightgreen}{K_j}} +
# \underset{\mathbf{C}}{\color{cyan}{V^T} \color{lightgreen}{K_j}}$
ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
# $\underset{\mathbf{B}}{\color{cyan}{Q^T} \color{lightgreen}{U_j}} +
# \underset{\mathbf{D}}{\color{orange}{S_j}}$
bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

return ac + bd
# $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})$
return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
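
The new return line folds the four score terms into a single product. A quick numerical
check of that identity with dummy tensors (the names and shapes below are illustrative
only; they just follow the einsum patterns above):

import torch

j, batch, heads, d_k = 4, 2, 3, 8
query = torch.randn(batch, heads, d_k)
key = torch.randn(j, batch, heads, d_k)
key_pos_emb = torch.randn(j, heads, d_k)     # U_j
query_pos_bias = torch.randn(heads, d_k)     # V

a = torch.einsum('bhd,jbhd->jbh', query, key)                            # Q^T K_j
b = torch.einsum('bhd,jhd->jbh', query, key_pos_emb)                     # Q^T U_j
c = torch.einsum('hd,jbhd->jbh', query_pos_bias, key)                    # V^T K_j
d = torch.einsum('hd,jhd->jh', query_pos_bias, key_pos_emb)[:, None, :]  # V^T U_j

combined = torch.einsum('bhd,jbhd->jbh',
                        query + query_pos_bias[None, :, :],
                        key + key_pos_emb[:, None, :, :])
assert torch.allclose(a + b + c + d, combined, atol=1e-5)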

def __call__(self, *,
query: torch.Tensor,
@ -147,22 +133,39 @@ class FeedbackAttention(Module):


class FeedbackTransformerLayer(Module):
"""
## Feedback Transformer Layer

This implements a single transformer layer in the feedback transformer.
"""

def __init__(self, *,
d_model: int,
attn: FeedbackAttention,
feed_forward: FeedForward,
dropout_prob: float):
"""
* `d_model` is the number of features in the transformer
* `attn` is the feedback attention module
* `feed_forward` is the position-wise feed forward layer
* `dropout_prob` is the dropout probability for dropout layers after attention and feed-forward
"""
super().__init__()
# Transformer size $d_{model}$
self.size = d_model
#
self.attn = attn
self.feed_forward = feed_forward
self.dropout = nn.Dropout(dropout_prob)

# Normalization layers
self.norm_self_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])

def __call__(self, *,
x: torch.Tensor,
mem: Optional[torch.Tensor]):
# If there is memory
if mem is not None:
# Normalize the vectors before doing self attention
z = self.norm_self_attn(x)

@ -178,49 +181,66 @@ class FeedbackTransformerLayer(Module):

# Add the feed-forward results back
x = x + self.dropout(ff)

#
return x
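
A hedged usage sketch of a single layer: the import paths and the `FeedForward`
constructor arguments are assumptions for illustration, not taken from this diff;
only `FeedbackAttention` and `FeedbackTransformerLayer` come from this file.

import torch
from labml_nn.transformers.feed_forward import FeedForward          # assumed path
from labml_nn.transformers.feedback import FeedbackAttention, FeedbackTransformerLayer  # assumed path

d_model, heads, batch = 64, 4, 2
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, 0.1),
                                 feed_forward=FeedForward(d_model, 256, 0.1),  # assumed (d_model, d_ff, dropout)
                                 dropout_prob=0.1)

x = torch.randn(batch, d_model)        # embedding of the current step
mem = torch.randn(3, batch, d_model)   # stacked memory from 3 earlier steps
y = layer(x=x, mem=mem)                # -> [batch, d_model]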


class FeedbackTransformer(Module):
"""
## Transformer Encoder
## Feedback Transformer Module
"""

def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
"""
* `layer` is the feedback transformer layer, which is cloned for each layer of the transformer
* `n_layers` is the number of layers in the transformer
"""

super().__init__()
# Make copies of the transformer layer
self.layers = clone_module_list(layer, n_layers)
# Final normalization layer
self.norm = nn.LayerNorm([layer.size])
#
# Memory vectors are computed as a weighted sum of representations of each layer.
# This is the weights parameter for that.
self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
#
# Softmax for weights before taking the weighted sum
self.softmax = nn.Softmax(0)
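
The memory vector for each step is a softmax-weighted sum of the stacked layer outputs
(including the input embedding), so the learned weights form a convex combination over
the `n_layers + 1` representations. A minimal sketch of that computation with dummy
tensors, separate from the module itself:

import torch

n_layers, batch, d_model = 3, 2, 8
layer_outputs = torch.randn(n_layers + 1, batch, d_model)   # input embedding + each layer's output
weights = torch.ones(n_layers + 1)                          # corresponds to self.weights

probs = torch.softmax(weights, dim=0)                       # sums to 1 over the layer axis
mem_step = torch.einsum('lbd,l->bd', layer_outputs, probs)  # memory vector for one step, [batch, d_model]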

def __call__(self, x_seq: torch.Tensor):
"""
* `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
"""

# Split the input to a list along the sequence axis
x_seq = torch.unbind(x_seq, dim=0)
# List to store the outputs
res = []
# List to store the memory vectors
mem = []
# For each input step
for x in x_seq:
# List of embeddings from each layer
emb = [x]
# List to store layer outputs
layer_outputs = [x]

# If there is memory, stack it into a tensor
mem_tensor = torch.stack(mem) if mem else None

# Run through each layer
for layer in self.layers:
# Get layer output
x = layer(x=x, mem=mem_tensor)
emb.append(x)
# Append it to the list of layer outputs
layer_outputs.append(x)

# Stack embeddings
emb = torch.stack(emb)
# Weighted sum of embeddings
mem.append(torch.einsum('lbd,l->bd', emb, self.softmax(self.weights)))
# Stack the layer outputs to a tensor
layer_outputs = torch.stack(layer_outputs)
# Calculate the memory vector as a weighted sum of layer outputs
mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
# Append the output to results
res.append(x)

# Stack the output tensors
res = torch.stack(res)
# Normalize
# Normalize the output
return self.norm(res)
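
An end-to-end usage sketch of the full module under the same assumptions as the layer
example above (assumed import paths and `FeedForward` signature; arbitrary sizes):

import torch
from labml_nn.transformers.feed_forward import FeedForward          # assumed path
from labml_nn.transformers.feedback import (FeedbackAttention, FeedbackTransformer,
                                            FeedbackTransformerLayer)  # assumed path

d_model, heads, seq_len, batch = 64, 4, 10, 2
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, 0.1),
                                 feed_forward=FeedForward(d_model, 256, 0.1),  # assumed (d_model, d_ff, dropout)
                                 dropout_prob=0.1)
model = FeedbackTransformer(layer, n_layers=6)

x_seq = torch.randn(seq_len, batch, d_model)   # [seq_len, batch_size, d_model]
out = model(x_seq)                             # [seq_len, batch_size, d_model]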
@ -91,9 +91,9 @@ class MultiHeadAttention(Module):
self.heads = heads

# These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

# Softmax for attention along the time dimension of `key`
self.softmax = nn.Softmax(dim=1)