feedback transformer notes

Varuna Jayasiri
2021-01-10 10:22:46 +05:30
parent 5874106161
commit 1dc4ff825f
2 changed files with 69 additions and 49 deletions


@@ -22,23 +22,31 @@ class FeedbackAttention(Module):
"""
## Feedback Attention
This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
but with some modifications.
📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
or [Multi-Head Attention](../mha.html) to improve readability.
This module computes recurrent attention similar to the attention mechanism in the original transformer paper.
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^T K}{\sqrt{d_k}}\Bigg)V$$
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
"""
* `heads` is the number of attention heads
* `d_model` is the number of features in the transformer
* `dropout_prob` is the attention dropout probability
"""
super().__init__()
# Number of features per head
self.d_k = d_model // heads
#
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
# Output layer
self.output = nn.Linear(d_model, d_model)
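As a quick illustration (not part of this commit), here is a minimal sketch of the attention computation the docstring above describes, with hypothetical sizes: a single query step attends over `j` memory steps, and the softmax runs along the memory (sequence) axis.
import torch
# Illustrative shapes only: one query step attends over j memory steps
batch, heads, d_k, j = 2, 4, 16, 5
query = torch.randn(batch, heads, d_k)     # current step
key = torch.randn(j, batch, heads, d_k)    # memory of key vectors
value = torch.randn(j, batch, heads, d_k)  # memory of value vectors
# Scores for each memory step: [j, batch, heads]
scores = torch.einsum('bhd,jbhd->jbh', query, key) / d_k ** 0.5
# Softmax along the memory (sequence) axis
attn = torch.softmax(scores, dim=0)
# Weighted sum of value vectors: [batch, heads, d_k]
out = torch.einsum('jbh,jbhd->bhd', attn, value)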
@@ -55,8 +63,6 @@ class FeedbackAttention(Module):
# Relative positional embeddings for key relative to the query.
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
# Relative positional embedding bias for key relative to the query.
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
# Positional embedding bias for the query, which is independent of the position of the query
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
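A small sketch of how these parameter shapes line up (sizes are hypothetical; `P` is assumed to be the maximum number of memory steps), showing how the embeddings for the most recent `j` positions are selected, as `get_scores` does below.
import torch
import torch.nn as nn
# Hypothetical sizes; `P` would be the maximum number of memory steps
P, heads, d_k, j = 8, 4, 16, 5
key_pos_embeddings = nn.Parameter(torch.zeros(P, heads, d_k))
query_pos_bias = nn.Parameter(torch.zeros(heads, d_k))
# `get_scores` takes the embeddings of the last j positions
key_pos_emb = key_pos_embeddings[-j:]  # shape [j, heads, d_k]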
@@ -65,47 +71,27 @@ class FeedbackAttention(Module):
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
"""
### Get relative attention scores
### Get attention scores
\begin{align}
A_{j} &= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
&= \color{cyan}{Q^T} \color{lightgreen}{K_j} +
\color{cyan}{Q^T} \color{lightgreen}{U_j} +
\color{cyan}{V^T} \color{lightgreen}{K_j} +
\color{cyan}{V^T} \color{lightgreen}{U_j}
A_{j} &= Q^T K_j \\
&= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
&= (\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})
\end{align}
where $\color{cyan}{Q}$ and $\color{lightgreen}{K_j}$ are linear transformations of the
original embeddings $\color{cyan}{X^q}, \color{lightgreen}{X^k_j}$,
and $\color{cyan}{V}$ and $\color{lightgreen}{U_j}$ are linear transformations of the
absolute positional encodings $\color{cyan}{P}, \color{lightgreen}{P_j}$.
We replace $\color{cyan}{V^T} \color{lightgreen}{U_j}$ with
$S_j$.
\begin{align}
A^{rel}_{j} &= \underset{\mathbf{A}}{\color{cyan}{Q^T} \color{lightgreen}{K_j}} +
\underset{\mathbf{B}}{\color{cyan}{Q^T} \color{lightgreen}{U_j}} +
\underset{\mathbf{C}}{\color{cyan}{V^T} \color{lightgreen}{K_j}} +
\underset{\mathbf{D}}{\color{orange}{S_j}}
\end{align}
"""
# $\color{lightgreen}{U_j}$
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
# $\color{orange}{S_j}$
key_pos_bias = self.key_pos_bias[-key.shape[0]:]
# $\color{cyan}{V^T}$
query_pos_bias = self.query_pos_bias[None, :, :]
# $\underset{\mathbf{A}}{\color{cyan}{Q^T} \color{lightgreen}{K_j}} +
# \underset{\mathbf{C}}{\color{cyan}{V^T} \color{lightgreen}{K_j}}$
ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
# $\underset{\mathbf{B}}{\color{cyan}{Q^T} \color{lightgreen}{U_j}} +
# \underset{\mathbf{D}}{\color{orange}{S_j}}$
bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
return ac + bd
# $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j})$
return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
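The new version folds the four terms $Q^T K_j + Q^T U_j + V^T K_j + V^T U_j$ into this single einsum by adding the query bias to the query and the positional embeddings to the keys. A minimal broadcasting sketch with hypothetical shapes (not part of this commit):
import torch
batch, heads, d_k, j = 2, 4, 16, 5
query = torch.randn(batch, heads, d_k)       # [batch, heads, d_k]
key = torch.randn(j, batch, heads, d_k)      # [j, batch, heads, d_k]
query_pos_bias = torch.randn(1, heads, d_k)  # broadcasts over batch
key_pos_emb = torch.randn(j, heads, d_k)     # one embedding per memory position
# (Q^T + V^T)(K_j + U_j) for every memory step: result is [j, batch, heads]
scores = torch.einsum('bhd,jbhd->jbh',
                      query + query_pos_bias,
                      key + key_pos_emb[:, None, :, :])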
def __call__(self, *,
query: torch.Tensor,
@@ -147,22 +133,39 @@ class FeedbackAttention(Module):
class FeedbackTransformerLayer(Module):
"""
## Feedback Transformer Layer
This implements a single transformer layer in the feedback transformer.
"""
def __init__(self, *,
d_model: int,
attn: FeedbackAttention,
feed_forward: FeedForward,
dropout_prob: float):
"""
* `d_model` is the number of features in the transformer
* `attn` is the feedback attention module
* `feed_forward` is the position-wise feed-forward layer
* `dropout_prob` is the dropout probability for the dropout layers after attention and feed-forward
"""
super().__init__()
# Transformer size $d_{model}$
self.size = d_model
#
self.attn = attn
self.feed_forward = feed_forward
self.dropout = nn.Dropout(dropout_prob)
# Normalization layers
self.norm_self_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])
def __call__(self, *,
x: torch.Tensor,
mem: Optional[torch.Tensor]):
# If there is memory
if mem is not None:
# Normalize the vectors before doing self attention
z = self.norm_self_attn(x)
@@ -178,49 +181,66 @@ class FeedbackTransformerLayer(Module):
# Add the feed-forward results back
x = x + self.dropout(ff)
#
return x
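Both halves of this layer follow the pre-norm residual pattern: normalize, apply the sub-layer, then add the (dropped-out) result back to the input. A self-contained sketch of that pattern, with an illustrative module name that is not from the repository:
import torch
import torch.nn as nn
class PreNormResidual(nn.Module):
    """Illustrative pre-norm residual wrapper: x + dropout(f(norm(x)))."""
    def __init__(self, d_model: int, f: nn.Module, dropout_prob: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm([d_model])
        self.f = f
        self.dropout = nn.Dropout(dropout_prob)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.dropout(self.f(self.norm(x)))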
class FeedbackTransformer(Module):
"""
## Transformer Encoder
## Feedback Transformer Module
"""
def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
"""
* `layer` is the feedback transformer layer, which we clone for each layer
* `n_layers` is the number of layers in the transformer
"""
super().__init__()
# Make copies of the transformer layer
self.layers = clone_module_list(layer, n_layers)
# Final normalization layer
self.norm = nn.LayerNorm([layer.size])
#
# Memory vectors are computed as a weighted sum of representations of each layer.
# This is the weights parameter for that.
self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
#
# Softmax for weights before taking the weighted sum
self.softmax = nn.Softmax(0)
def __call__(self, x_seq: torch.Tensor):
"""
* `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
"""
# Split the input into a list along the sequence axis
x_seq = torch.unbind(x_seq, dim=0)
# List to store the outputs
res = []
# List to store the memory vectors
mem = []
# For each input step
for x in x_seq:
# List of embeddings from each layer
emb = [x]
# List to store layer outputs
layer_outputs = [x]
# If there is memory, stack the memory vectors into a tensor
mem_tensor = torch.stack(mem) if mem else None
# Run through each layer
for layer in self.layers:
# Get layer output
x = layer(x=x, mem=mem_tensor)
emb.append(x)
# Append them to the list of layer outputs
layer_outputs.append(x)
# Stack embeddings
emb = torch.stack(emb)
# Weighted sum of embeddings
mem.append(torch.einsum('lbd,l->bd', emb, self.softmax(self.weights)))
# Stack the layer outputs to a tensor
layer_outputs = torch.stack(layer_outputs)
# Calculate the memory vector as a weighted sum of layer outputs
mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
# Append the output to results
res.append(x)
# Stack the output tensors
res = torch.stack(res)
# Normalize
# Normalize the output
return self.norm(res)
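A hedged usage sketch (not part of this commit): the import path below is an assumption about the repository layout, and the feed-forward block is a stand-in so the example stays self-contained. Each step's memory vector is the softmax-weighted sum of that step's layer outputs, which is what lets lower layers at later steps attend to higher-layer information from earlier steps.
import torch
import torch.nn as nn
# Assumed import path for the classes in this diff; adjust to your checkout
from labml_nn.transformers.feedback import (FeedbackAttention,
                                            FeedbackTransformer,
                                            FeedbackTransformerLayer)
# Stand-in position-wise feed-forward block, used here only so the example
# stays self-contained (the repository provides its own FeedForward class)
class ToyFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
d_model, heads, seq_len, batch_size = 64, 4, 10, 2
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model),
                                 feed_forward=ToyFeedForward(d_model, 256),
                                 dropout_prob=0.1)
model = FeedbackTransformer(layer, n_layers=3)
x = torch.randn(seq_len, batch_size, d_model)
out = model(x)  # shape [seq_len, batch_size, d_model]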


@@ -91,9 +91,9 @@ class MultiHeadAttention(Module):
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
# Softmax for attention along the time dimension of `key`
self.softmax = nn.Softmax(dim=1)
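For context, a sketch of what a "prepare" module like `PrepareForMultiHeadAttention` typically does; this is an assumption about its internals rather than the repository's code: one linear projection from `d_model` to `heads * d_k`, then a reshape that splits the last dimension into heads.
import torch
import torch.nn as nn
class PrepareForMHA(nn.Module):
    """Illustrative 'prepare' module: project, then split into heads."""
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., d_model] -> [..., heads, d_k]
        head_shape = x.shape[:-1]
        x = self.linear(x)
        return x.view(*head_shape, self.heads, self.d_k)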