diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 72c09fb8..af3bc135 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -253,14 +253,21 @@
This trains a feedback transformer model for auto-regression. You can pick the original feedback transformer or the new version where the keys and values are precalculated.

Here's a Colab notebook for training a feedback transformer on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_helpers.module import Module

from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator, TransformerConfigs
from labml_nn.transformers.utils import subsequent_mask


class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        self.transformer = transformer
        self.generator = nn.Linear(d_model, n_vocab)

    def __call__(self, x: torch.Tensor):
        # Embed the tokens
        x = self.src_embed(x)
        # Run it through the transformer
        res = self.transformer(x)
        # Generate logits of the next token
        return self.generator(res), None

The default configs can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

    d_model: int = 512
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6

Create the original feedback transformer.
@option(Configs.model)
def feedback_transformer(c: Configs):
    from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \
        FeedbackAttention, FeedForward

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FeedbackTransformer(
            FeedbackTransformerLayer(d_model=c.d_model,
                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout),
                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                     dropout_prob=c.dropout),
            c.n_layers)).to(c.device)

Create the updated feedback transformer, with precalculated keys and values.
@option(Configs.model)
def feedback_transformer_kv(c: Configs):
    from labml_nn.transformers.feedback import FeedbackTransformerKV, FeedbackTransformerLayer, \
        FeedbackAttention, FeedForward

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FeedbackTransformerKV(
            FeedbackTransformerLayer(d_model=c.d_model,
                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout,
                                                            is_kv_precomputed=True),
                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                     dropout_prob=c.dropout),
            c.n_layers, c.d_model, c.heads)).to(c.device)


def main():
    # Create the experiment
    experiment.create(name="feedback_transformer")
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        # Use `feedback_transformer` for the original feedback transformer
                        'model': 'feedback_transformer_kv',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 64,
                        'inner_iterations': 25})

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # Run the training loop
        conf.run()


if __name__ == '__main__':
    main()

In order to speed up the training, the paper discusses starting with a short sequence length and gradually increasing it. It also discusses using a pretrained parallel transformer as the starting point.
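As a rough, illustrative sketch of that idea (a linear length warm-up; the function name, the epoch counts, and the start/final lengths below are assumptions for illustration, not something defined in the paper or in this codebase):

def sequence_length_schedule(epoch: int, *, start_len: int = 32,
                             final_len: int = 128, warmup_epochs: int = 16) -> int:
    # Hypothetical helper: grow the training sequence length linearly over the
    # first `warmup_epochs` epochs, then keep it at `final_len`.
    if epoch >= warmup_epochs:
        return final_len
    return start_len + (final_len - start_len) * epoch // warmup_epochs

Such a schedule could, for example, be used to recompute `seq_len` before each epoch; the actual warm-up recipe used in the paper may differ.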
The original feedback transformer doesn't keep the outputs of all layers. Instead it keeps a weighted sum of the outputs of all layers. This reduces the memory used for caching during prediction. The first half of this file implements this.

The updated feedback transformer shares the weights $W^l_k$ and $W^l_v$ used to calculate keys and values among the layers. We then calculate the keys and values for each step only once and keep them cached. The second half of this file implements this. We implemented a custom PyTorch function to improve performance.

Here's the training code and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.
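Stated as math (in the notation of the code below, where $w$ is a learned weight vector over the input embedding $x^0_t$ and the $L$ layer outputs $x^1_t, \dots, x^L_t$ at step $t$), the memory vector that later steps attend to is the softmax-weighted sum

$$m_t = \sum_{l=0}^{L} \operatorname{softmax}(w)_l \, x^l_t$$

which is exactly the `einsum('lbd,l->bd', ...)` over the stacked layer outputs in the implementation.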
import math
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list


class FeedbackAttention(Module):
    """
    * `heads` is the number of attention heads
    * `d_model` is the number of features in the transformer
    * `dropout_prob` is the attention dropout probability
    * `is_kv_precomputed` is whether key, value tensors are already calculated
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
                 is_kv_precomputed: bool = False):
        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        self.heads = heads

        # This transforms the `query` for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        # These transform the `key` and `value` for multi-headed attention.
        if not is_kv_precomputed:
            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
        # Keys and values are already calculated
        else:
            self.key = None
            self.value = None

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
        # Softmax for attention along the time dimension of `key`
        self.softmax = nn.Softmax(dim=0)

        # Number of relative positions
        self.P = 2 ** 12

        # Relative positional embeddings for key relative to the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
        # Relative positional embedding bias for key relative to the query.
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
        # Positional embeddings for the query are independent of the position of the query.
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

        # We store the attentions so that they can be used for logging,
        # or other computations if needed.
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        r"""
        We use relative positional encodings for attention, similar to the relative
        multi-head attention of the Transformer-XL paper.

        Attention from the current step's query to the key in step $j$
        (relative to the current step) is,

        $$A_j = \underset{\color{lightgreen}{A}}{Q^\top K_j}
              + \underset{\color{lightgreen}{B}}{Q^\top U^K_j}
              + \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j}
              + \underset{\color{lightgreen}{D}}{{U^Q}^\top U^K_j}$$

        where $Q, K_j$ are linear transformations of the original embeddings $X^q, X^k_j$
        and $U^Q, U^K_j$ are linear transformations of the positional encodings $P_q, P_j$.

        We replace the term $\color{lightgreen}{D}$ with $S_j$.
        """
        # $U^K_j$
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
        # $U^Q$
        query_pos_bias = self.query_pos_bias[None, :, :]
        # $S_j$
        key_pos_bias = self.key_pos_bias[-key.shape[0]:]

        # $\underset{\color{lightgreen}{A}}{Q^\top K_j} + \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j}$
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
        # $\underset{\color{lightgreen}{B}}{Q^\top U^K_j} + \underset{\color{lightgreen}{D}}{S_j}$
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

        # $A_j$
        return ac + bd

    def __call__(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor):
        """
        * `query` has shape `[batch_size, d_model]`
        * `key` and `value` have shape `[seq_len, batch_size, d_model]`
        """
        # Prepare `query`, `key` and `value` for attention computation.
        # `key` and `value` will then have shape `[seq_len, batch_size, heads, d_k]`
        # and `query` will have shape `[batch_size, heads, d_k]`.
        query = self.query(query)
        if self.key:
            key = self.key(key)
        if self.value:
            value = self.value(value)

        # Compute attention scores.
        # Results in a tensor of shape `[seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)
        # Scale scores $\frac{1}{\sqrt{d_k}}$
        scores *= self.scale
        # Softmax
        attn = self.softmax(scores)
        # Apply dropout
        attn = self.dropout(attn)
        # Multiply by the values
        x = torch.einsum("jbh,jbhd->bhd", attn, value)
        # Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)
        # Output layer
        return self.output(x)
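As a quick shape check, here is a minimal sketch of calling this attention module on dummy tensors (the numbers are arbitrary, and it assumes `FeedbackAttention` is importable from `labml_nn.transformers.feedback` as in the experiment code above):

# Illustrative only: batch of 2, d_model of 16, 4 heads, 3 memory steps.
attn_module = FeedbackAttention(4, 16, 0.1)
q = torch.zeros(2, 16)        # query for the current step: [batch_size, d_model]
kv = torch.zeros(3, 2, 16)    # memory from previous steps: [seq_len, batch_size, d_model]
out = attn_module(query=q, key=kv, value=kv)
print(out.shape)              # torch.Size([2, 16])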
This implements a single transformer layer in the feedback transformer.


class FeedbackTransformerLayer(Module):
    """
    * `d_model` is the number of features in the transformer
    * `attn` is the feedback attention module
    * `feed_forward` is the position-wise feed-forward network
    * `dropout_prob` is the dropout probability for the layer
    """

    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        # Transformer size $d_{model}$
        self.size = d_model

        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

        # Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def __call__(self, *,
                 x: torch.Tensor,
                 key: Optional[torch.Tensor],
                 value: Optional[torch.Tensor]):
        # If there is memory
        if key is not None:
            # Normalize the vectors before doing self attention
            z = self.norm_self_attn(x)
            # Run through self attention, i.e. keys and values are from self
            self_attn = self.attn(query=z, key=key, value=value)
            # Add the self attention results
            x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x


class FeedbackTransformer(Module):
    """
    * `layer` is the feedback transformer layer, which we clone for each layer
    * `n_layers` is the number of layers in the transformer
    """

    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        # Memory vectors are computed as a weighted sum of representations of each layer.
        # This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
        # Softmax for the weights before taking the weighted sum
        self.softmax = nn.Softmax(0)

    def __call__(self, x_seq: torch.Tensor):
        """
        * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
        """
        # Split the input to a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
        # List to store the outputs
        res = []
        # List to store the memory vectors
        mem = []
        # For each input step
        for x in x_seq:
            # List to store layer outputs
            layer_outputs = [x]

            # If there is memory, stack it into a tensor
            mem_tensor = torch.stack(mem) if mem else None

            # Run through each layer
            for layer in self.layers:
                # Get layer output
                x = layer(x=x, key=mem_tensor, value=mem_tensor)
                # Append it to the list of layer outputs
                layer_outputs.append(x)

            # Stack the layer outputs into a tensor
            layer_outputs = torch.stack(layer_outputs)
            # Calculate the memory vector as a weighted sum of layer outputs
            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
            # Append the output to the results
            res.append(x)

        # Stack the output tensors
        res = torch.stack(res)
        # Normalize the output
        return self.norm(res)


We implement a custom function instead of appending to a Python list and then doing torch.stack. This greatly improves the performance over calling torch.stack at each step along the sequence. Every time torch.stack is called it creates a new tensor, while this method and the accompanying class Stack share memory for each step.
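As a toy illustration of the difference (this is just the underlying idea, not the class below): appending to a list and stacking at every step copies all previous steps each time, whereas writing into a preallocated buffer touches only one slot per step and slices out a view.

# Per-step torch.stack: step t copies all t + 1 tensors into a brand-new tensor.
mem_list = []
for t in range(4):
    mem_list.append(torch.zeros(2, 8))
    stacked = torch.stack(mem_list)    # new allocation of shape [t + 1, 2, 8]

# Shared buffer: allocate once, write one slot per step, slice a view when needed.
buffer = torch.zeros(4, 2, 8)
for t in range(4):
    buffer[t] = torch.zeros(2, 8)
    stacked = buffer[:t + 1]           # a view; earlier steps are not copied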
class StackFunction(torch.autograd.Function):
    """
    * `ctx` is the context of the function (which lets us cache stuff)
    * `memory` is the shared memory tensor where we stack and store the values of each step (keys & values)
    * `memory_grad` is the shared memory tensor to store and accumulate gradients of each step
    * `last` is the last value stacked
    * `n` is the number of steps (i.e. the size of the stack)

    This returns the stacked tensor for steps up to `n`.
    """

    @staticmethod
    def forward(ctx, memory, memory_grad, last, n):
        # Cache accumulated gradients
        ctx._mem_grad = memory_grad
        # Cache the size of the stack
        ctx._n = n
        # Return the stack
        return memory[:n + 1]

    @staticmethod
    def backward(ctx, grad_output):
        """
        * `grad_output` is the gradient with respect to the output of the `forward` function above.

        This accumulates the gradients in the shared memory tensor and returns the
        gradients with respect to the `last` result in the stack.
        """
        # Get the current size of the stack
        n = ctx._n
        # Get the accumulated gradients
        memory_grad = ctx._mem_grad
        # Add the gradients
        memory_grad[:n + 1] += grad_output
        # Return the gradients w.r.t. the last value in the stack
        return None, None, memory_grad[n], None


This uses the stack function defined above, and does the necessary initializations.


class Stack:
    """
    * `max_len` is the maximum size of the stack
    """

    def __init__(self, max_len: int):
        self.max_len = max_len
        self.memory = None
        self.memory_grad = None
        self.last = None
        self.n = -1
        self.last_get_n = -1

    def append(self, n: int, value: torch.Tensor):
        """
        * `n` is the size of the stack
        * `value` is the tensor that needs to be added to the stack
        """
        # You need to get (use) the stack after adding a value.
        # Otherwise this implementation fails.
        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"

        # Do this without gradients
        with torch.no_grad():
            # Initialize the shared memory tensor to keep the stack
            if self.memory is None or self.memory.shape[1:] != value.shape:
                # This should only happen when the stack is empty
                assert n == 0
                # Create a tensor for the stack
                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
                # Create a tensor to accumulate the gradients
                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
            # The memory is already initialized but we are resetting the stack.
            # This could have been another function like `reset`, but
            # we found this easier to use.
            elif n == 0:
                # Reset accumulated gradients
                self.memory_grad.fill_(0.)

            # Set the value in the correct position of the stack
            self.memory.data[n] = value.detach()
            # Keep track of the size of the stack (for debugging)
            self.n = n

        # Keep track of the last value added to the stack.
        # We need this to be passed on to `StackFunction` in order
        # to get the gradients propagated backwards.
        self.last = value

    def get(self):
        """
        Returns the stack.
        """
        # Keep track of the size of the stack when it was used.
        # This is used for a sanity check in `append`.
        self.last_get_n = self.n
        # Take it all through `StackFunction` so that `StackFunction.backward`
        # is called by PyTorch during backpropagation.
        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)


This is the updated feedback transformer module that caches the keys and values.
class FeedbackTransformerKV(Module):
    """
    * `layer` is the feedback transformer layer, which we clone for each layer
    * `n_layers` is the number of layers in the transformer
    * `d_model` is the number of features in the transformer
    * `heads` is the number of attention heads
    """

    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        # Memory vectors are computed as a weighted sum of representations of each layer.
        # This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
        # Softmax for the weights before taking the weighted sum
        self.softmax = nn.Softmax(0)

        # Number of features in a head
        d_k = d_model // heads
        # Module to transform embeddings (memory) to get keys
        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
        # Module to transform embeddings (memory) to get values
        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

        # Memory for stacked keys
        self.mem_key = Stack(512)
        # Memory for stacked values
        self.mem_value = Stack(512)

    def __call__(self, x_seq: torch.Tensor):
        """
        * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
        """
        # Split the input to a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
        # List to store the outputs
        res = []
        # For each input step
        for step, x in enumerate(x_seq):
            # List to store layer outputs
            layer_outputs = [x]

            # Stack of keys and values
            key_tensor = None
            value_tensor = None
            # Get the keys and values tensors if we are beyond the initial step
            if step > 0:
                key_tensor = self.mem_key.get()
                value_tensor = self.mem_value.get()

            # Run through each layer
            for layer in self.layers:
                # Get layer output
                x = layer(x=x, key=key_tensor, value=value_tensor)
                # Append it to the list of layer outputs
                layer_outputs.append(x)

            # Stack the layer outputs into a tensor
            layer_outputs = torch.stack(layer_outputs)
            # Calculate the memory vector as a weighted sum of layer outputs
            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
            # Calculate the keys from the memory and add them to the stack
            self.mem_key.append(step, self.key(mem))
            # Calculate the values from the memory and add them to the stack
            self.mem_value.append(step, self.value(mem))
            # Append the output to the results
            res.append(x)

        # Stack the output tensors
        res = torch.stack(res)
        # Normalize the output
        return self.norm(res)
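Finally, a rough, self-contained usage sketch of the updated module, assuming the classes above are importable from `labml_nn.transformers.feedback` (the hyper-parameters are made up for illustration and much smaller than the experiment defaults):

import torch
from labml_nn.transformers.feedback import (FeedbackAttention, FeedbackTransformerKV,
                                             FeedbackTransformerLayer, FeedForward)

d_model, heads, d_ff, n_layers = 64, 4, 128, 2
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, 0.1, is_kv_precomputed=True),
                                 feed_forward=FeedForward(d_model, d_ff, 0.1),
                                 dropout_prob=0.1)
model = FeedbackTransformerKV(layer, n_layers, d_model, heads)

x = torch.zeros(16, 2, d_model)    # [seq_len, batch_size, d_model]
print(model(x).shape)              # torch.Size([16, 2, 64])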