📚 feedback transformer notes

This commit is contained in:
Varuna Jayasiri
2021-01-29 14:34:04 +05:30
parent 2f1918f6db
commit 2a26ecaaa1
5 changed files with 1592 additions and 256 deletions

View File

@ -22,11 +22,19 @@ In order to speed up the training the paper discusses starting with a short sequ
gradually increasing it.
They also discuss using a pretrained parallel transformer as the starting point.
The feedback transformer doesn't keep the outputs of all layers.
The original feedback transformer doesn't keep the outputs of all layers.
Instead it keeps weighted sum of the output of all layers.
This reduces the memory used for caching during prediction.
The first half of this file implements this.
Here's a notebook for training a feedback transformer on Tiny Shakespeare dataset.
The updated feedback transformer shares weights $W^l_k$ and $W^l_v$ used
to calculate keys and values for among the layers.
We then calculate the keys and values for each step only once and keep
them cached.
The [second half](#shared_kv) of this file implements this.
We implemented a custom PyTorch function to improve performance.
Here's [the training code](experiment.html) and a notebook for training a feedback transformer on Tiny Shakespeare dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
@ -39,8 +47,8 @@ import torch
from torch import nn
from labml_helpers.module import Module
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
@ -61,6 +69,7 @@ class FeedbackAttention(Module):
* 'heads' is the number of attention heads
* `d_model` is the number of features in the transformer
* `dropout_prob` is the attention dropout probability
* `is_kv_precomputed` is whether key, value tensors are already calculated
"""
super().__init__()
@ -70,11 +79,13 @@ class FeedbackAttention(Module):
#
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
# These transform the `query` multi-headed attention.
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
# These transform the `key` and `value` fir multi-headed attention.
if not is_kv_precomputed:
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
# Keys and values are already calculated
else:
self.key = None
self.value = None
@ -94,6 +105,8 @@ class FeedbackAttention(Module):
# Relative positional embeddings for key relative to the query.
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
# Relative positional embedding bias for key relative to the query.
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
# Positional embeddings for the query is independent of the position of the query
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
@ -104,27 +117,42 @@ class FeedbackAttention(Module):
"""
### Get attention scores
We use relative positional encodings for attention, similar
to [relative multi-head attention form Transformer-XL paper](../relative_mha.html).
Attention from current step's query to key in step $j$ (relative to current step) is,
\begin{align}
A_{j} &= Q^\top K_j \\
&= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
&= (Q + U^Q)^\top(K_j + U^K_j)
&= (Q + U^Q)^\top(K_j + U^K_j) \\
&= \underset{\color{lightgreen}{A}}{Q^\top K_j} +
\underset{\color{lightgreen}{B}}{Q^\top U^K_j} +
\underset{\color{lightgreen}{C}}{{U^Q}^\top K_j} +
\underset{\color{lightgreen}{D}}{{U^Q}^\top U^K_j}
\end{align}
where $Q, K_j$, are linear transformations of
original embeddings $X^q, X^k_j$
and $U^Q, U^K_j$ are linear transformations of
absolute positional encodings $P_q, P_j$.
positional encodings $P_q, P_j$.
We replace term $\color{lightgreen}{D}$ with $S_j$.
"""
# $U^K_j$
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
# $U^Q$
query_pos_bias = self.query_pos_bias[None, :, :]
# $S_j$
key_pos_bias = self.key_pos_bias[-key.shape[0]:]
# $(Q + U^Q)^\top(K_j + U^K_j)$
# $\underset{\color{lightgreen}{A}}{Q^\top K_j} + \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j}$
ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb)
# $\underset{\color{lightgreen}{B}}{Q^\top U^K_j} + \underset{\color{lightgreen}{D}}{S_j}$
bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
# $A_j$
return ac + bd
def __call__(self, *,
@ -283,23 +311,69 @@ class FeedbackTransformer(Module):
return self.norm(res)
# <a id="shared_kv">
# # Shared keys and values for among layers
# </a>
class StackFunction(torch.autograd.Function):
"""
### Stack Function implementation
We implement a custom function instead of appending to a python list
and then doing `torch.stack`.
This greatly improves the performance over calling `torch.stack` at
each step along the sequence.
Everytime `torch.stack` is called it creates a new tensor, while
this method and the accompanying class `Stack` share memory for each step.
"""
@staticmethod
def forward(ctx, memory, memory_grad, last, n):
"""
* `ctx` is the context of the function (which lets as cache stuff)
* `memory` is the shared memory tensor where we stack and store the values of each step (keys & values)
* `memory_grad` is the shared memory tensor to store and accumulate gradients of each step
* `last` is the last value stacked
* `n` is the number of steps (i.e. size of the stack)
This returns the stacked tensor for steps upto `n`.
"""
# Cache accumulated gradients
ctx._mem_grad = memory_grad
# Cache the size of the stack
ctx._n = n
# Return the stack
return memory[:n + 1]
@staticmethod
def backward(ctx, grad_output):
"""
* `grad_output` is the gradient with respect to the output of about `forward` function
This accumulates the gradients in the shared memory tensor and return the
gradients with respect to the `last` result in the stack.
"""
# Get the current size of the stack
n = ctx._n
# Get the accumulated gradients
memory_grad = ctx._mem_grad
# Add the gradients
memory_grad[:n + 1] += grad_output
# Return the gradients w.r.t to last value in the stack
return None, None, memory_grad[n], None
class Stack:
"""
### Stack Module
This uses the stack function defined above, and does the necessary initializations.
"""
def __init__(self, max_len: int):
"""
* `max_len` is the maximum size of the stack
"""
self.max_len = max_len
self.memory = None
self.memory_grad = None
@ -307,37 +381,70 @@ class Stack:
self.n = -1
self.last_get_n = -1
def append(self, n: int, vector: torch.Tensor):
def append(self, n: int, value: torch.Tensor):
"""
* `n` is the size of the stack
* `value` is the tensor that needs to be added to the stack
"""
# You need to get (use) the stack after adding a value.
# Otherwise this implementation fails
assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
# Do this without gradients
with torch.no_grad():
if self.memory is None or self.memory.shape[1:] != vector.shape:
# Initialize the shared memory tensor to keep the stack
if self.memory is None or self.memory.shape[1:] != value.shape:
# This should only happen when the stack is empty
assert n == 0
self.memory = vector.new_zeros(self.max_len, *vector.shape, requires_grad=False)
self.memory_grad = vector.new_zeros(self.memory.shape, requires_grad=False)
# Create a tensor for the stack
self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
# Create a tensor to accumulate the gradients
self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
# The memory is already initialized but we are resetting the stack.
#
# This could have been another function like `reset`, but
# we found this easier to use.
elif n == 0:
# Reset accumulated gradients
self.memory_grad.fill_(0.)
# memory[n] = vector.detach()
self.memory.data[n] = vector.detach()
# Set the value in the correct position of the stack
self.memory.data[n] = value.detach()
# Keep track of the stack (for debugging)
self.n = n
self.last = vector
# Keep track of the last value added to the stack.
# We need this to be passed on to `StackFunction` in order
# to get the gradients propagated backwards.
self.last = value
def get(self):
"""
Returns the stack
"""
# Keep track of the size of the stack when it was used.
# This is used for a sanity check in `append`.
self.last_get_n = self.n
# Take it all through `StackFunction` so that `StackFunction.backwards`
# is called by PyTorch during backpropagation.
return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
class FeedbackTransformerKV(Module):
"""
## Feedback Transformer Module
## Updated Feedback Transformer Module
This is the updated feedback transformer module that caches the keys and values.
"""
def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
"""
* `layer` is the feedback transformer layer, which we clone for each layer
* `n_layers` is the number of layers in the transformer
* `d_model` is the number of features in the transformer
* 'heads' is the number of attention heads
"""
super().__init__()
@ -351,11 +458,16 @@ class FeedbackTransformerKV(Module):
# Softmax for weights before taking the weighted sum
self.softmax = nn.Softmax(0)
# Number of features in a head
d_k = d_model // heads
# Module to transform embeddings (memory) to get keys
self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
# Module to transform embeddings (memory) to get keys
self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
# Memory for stacked keys
self.mem_key = Stack(512)
# Memory for stacked values
self.mem_value = Stack(512)
def __call__(self, x_seq: torch.Tensor):
@ -372,9 +484,10 @@ class FeedbackTransformerKV(Module):
# List to store layer outputs
layer_outputs = [x]
# If there is memory, stack them into a vector
# Stack of keys and values
key_tensor = None
value_tensor = None
# Get the keys and values tensors if we are beyond the initial step
if step > 0:
key_tensor = self.mem_key.get()
value_tensor = self.mem_value.get()
@ -390,7 +503,9 @@ class FeedbackTransformerKV(Module):
layer_outputs = torch.stack(layer_outputs)
# Calculate the memory vector as a weighted sum of layer outputs
mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
# Calculate the keys from memory and add it to the stack
self.mem_key.append(step, self.key(mem))
# Calculate the values from memory and add it to the stack
self.mem_value.append(step, self.value(mem))
# Append the output to results
res.append(x)

View File

@ -7,6 +7,13 @@ summary: This is training code with notes for a feedback transformer.
# Train Feedback Transformer
This trains a [feedback transformer](index.html) model for auto-regression.
You can pick the original feedback transformer or the new version
where the keys and values are precalculated.
Here's a Colab notebook for training a feedback transformer on Tiny Shakespeare dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
"""
import torch
@ -35,8 +42,9 @@ class AutoregressiveModel(Module):
self.generator = nn.Linear(d_model, n_vocab)
def __call__(self, x: torch.Tensor):
# Embed the tokens
x = self.src_embed(x)
# Embed the tokens (`src`) and run it through the the transformer
# Run it through the the transformer
res = self.transformer(x)
# Generate logits of the next token
return self.generator(res), None
@ -60,6 +68,9 @@ class Configs(NLPAutoRegressionConfigs):
@option(Configs.model)
def feedback_transformer(c: Configs):
"""
Create [original feedback transformer](index.html).
"""
from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \
FeedbackAttention, FeedForward
@ -75,6 +86,9 @@ def feedback_transformer(c: Configs):
@option(Configs.model)
def feedback_transformer_kv(c: Configs):
"""
Create [updated feedback transformer](index.html#kv_shared), with precalculated keys and values.
"""
from labml_nn.transformers.feedback import FeedbackTransformerKV, FeedbackTransformerLayer, \
FeedbackAttention, FeedForward
@ -104,6 +118,7 @@ def main():
'prompt': 'It is',
'prompt_separator': '',
# Use `feedback_transformer` for original feedback transformer
'model': 'feedback_transformer_kv',
'train_loader': 'shuffled_train_loader',
@ -119,7 +134,7 @@ def main():
# Start the experiment
with experiment.start():
# `TrainValidConfigs.run`
# Run the training loop
conf.run()