import math
from typing import Set

import torch
from torch import nn

from labml.logger import inspect
We use rotary position embeddings in the self-attention layers. We do not use explicit positional embeddings in the cross-attention layers, where we assume the positional information is already embedded in the inputs; non-causal (bi-directional) self-attention, by contrast, needs explicit positional information because it cannot infer token positions on its own.
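For reference, the module below implements the standard RoPE transform: feature $i$ is paired with feature $i + d/2$, and each pair at position $m$ is rotated by the angle $m\theta_i$,

$$
\begin{pmatrix} \hat{x}^{(i)}_m \\ \hat{x}^{(i + d/2)}_m \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x^{(i)}_m \\ x^{(i + d/2)}_m \end{pmatrix},
\qquad
\theta_i = \frac{1}{\text{base}^{2i/d}}, \quad i = 0, \dots, \tfrac{d}{2} - 1
$$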
class RotaryPositionalEmbeddings(nn.Module):
d is the number of features
base is the constant used for calculating $\theta_i = \frac{1}{\text{base}^{2i/d}}$
    def __init__(self, d: int, base: int = 10_000):
        super().__init__()

        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
x is the tensor at the head of a key or a query, with shape [batch_size, seq_len, n_heads, d]
    def forward(self, x: torch.Tensor):
Extract the shape
        batch_size, seq_len, n_heads, d = x.shape

        d_2 = d // 2
Create position indexes [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
Calculate the product of the position index and $\theta_i$
        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, ..., m\theta_{d/2 - 1}, m\theta_0, m\theta_1, ..., m\theta_{d/2 - 1}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
Calculate $[-x^{(d/2)}, -x^{(d/2 + 1)}, ..., -x^{(d - 1)}, x^{(0)}, x^{(1)}, ..., x^{(d/2 - 1)}]$
        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
Rotate the features element-wise: $rx = x \cos(m\theta) + \text{neg\_half\_x} \sin(m\theta)$
        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

        return rx
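A minimal shape check for the module above; the tensor sizes are arbitrary and chosen only for illustration (this assumes the imports and the class defined above).

rope = RotaryPositionalEmbeddings(d=16)
xq = torch.randn(2, 10, 4, 16)     # [batch_size, seq_len, n_heads, d], arbitrary sizes
assert rope(xq).shape == xq.shape  # the rotation preserves the shape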
class SelfAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
is_causal indicates whether this is causal attention (masked)
    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
        super().__init__()

        self.is_causal = is_causal
        self.n_heads = n_heads
        self.d_k = d_k
To scale attentions before softmax by $\frac{1}{\sqrt{d_k}}$
        self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
        self.softmax = nn.Softmax(dim=-1)
Rotary positional embeddings
        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
Final linear layer
        self.output = nn.Linear(n_heads * d_k, d_model)
attn is the attention matrix of shape [batch_size, n_heads, seq_len, seq_len]
    def mask_attention(self, attn: torch.Tensor):
No masking for non-causal attention
        if not self.is_causal:
            return attn
Create a triangular mask
        mask = torch.tril(attn.new_ones(attn.shape[-2:]))
Filter by the mask
        return attn.masked_fill(mask == 0, float('-inf'))
h are the transformer embeddings, of shape [batch_size, seq_len, d_model]
    def forward(self, h: torch.Tensor):
Residual connection
        h_res = h
Pre-normalization
        h = self.norm(h)
Get the query, key, and value and split them into heads. These will have shape [batch_size, seq_len, n_heads, d_k]
        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
        q = self.query(h).view(mh_shape)
        k = self.key(h).view(mh_shape)
        v = self.value(h).view(mh_shape)
Apply rotary positional embeddings
        q = self.rotary_pe(q)
        k = self.rotary_pe(k)
Calculate attention scores
        attn = torch.einsum('bihd,bjhd->bhij', q, k)
Scale them by $\frac{1}{\sqrt{d_k}}$
        attn = attn * self.scale
Apply the mask if it's causal attention
        attn = self.mask_attention(attn)
Calculate attention probabilities
        attn = self.softmax(attn)
Get values
        h = torch.einsum("bhij,bjhd->bihd", attn, v)
Change from shape [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_heads * d_k]
        h = h.reshape(*h.shape[:-2], -1)
Apply final linear layer. The result will have shape [batch_size, seq_len, d_model]
        h = self.output(h)
Add the residual connection
        return h + h_res
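A small usage sketch for the layer above, with arbitrary sizes chosen only for illustration; with is_causal=True the triangular mask means position i can only attend to positions j <= i.

attn_layer = SelfAttention(d_model=8, n_heads=2, d_k=4, is_causal=True)
h = torch.randn(2, 10, 8)                 # [batch_size, seq_len, d_model], arbitrary sizes
assert attn_layer(h).shape == (2, 10, 8)  # the output keeps the input shape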
This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries. This is used in the encoder to encode the retrieved chunks based on the input chunks. We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
class CrossAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
    def __init__(self, d_model: int, n_heads: int, d_k: int):
        super().__init__()

        self.n_heads = n_heads
        self.d_k = d_k
To scale attentions before softmax by $\frac{1}{\sqrt{d_k}}$
        self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
        self.softmax = nn.Softmax(dim=-1)
Final linear layer
        self.output = nn.Linear(n_heads * d_k, d_model)
e are the retrieved nearest neighbor chunk embeddings, with shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input chunks from which the nearest neighbors were retrieved, with shape [batch_size, chunks, chunk_len, d_model]. This is already normalized.
    def forward(self, e: torch.Tensor, h: torch.Tensor):
Residual connection
        e_res = e
Normalize retrieved chunks
        e = self.norm(e)
Get query from the retrieved chunks
        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
Get keys and values from the input chunks
        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
Calculate attention scores for all chunks. Each retrieved neighbor will pay attention to the original chunk that retrieved it. This will have shape [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]
        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
Scale attention scores
        attn = attn * self.scale
Calculate softmax across the last dimension
        attn = self.softmax(attn)
Gather values
        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
Change from shape [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] to [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]
        e = e.reshape(*e.shape[:-2], -1)
Apply final linear layer. The result will have shape [batch_size, chunks, neighbors, neighbor_len, d_model]
        e = self.output(e)
Add residual connection
        return e + e_res
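A shape sketch for the cross-attention layer above; the sizes are hypothetical and only illustrate how the retrieved-neighbor and input-chunk tensors line up.

cross = CrossAttention(d_model=8, n_heads=2, d_k=4)
e = torch.randn(2, 3, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model], arbitrary sizes
h = torch.randn(2, 3, 4, 8)     # [batch_size, chunks, chunk_len, d_model]
assert cross(e, h).shape == e.shape  # neighbors attend to their chunk; the shape of e is kept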
This is similar to the cross-attention layer defined above. This is used in the decoder to pay attention to the retrieved neighbor chunks. We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
class ChunkedCrossAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
chunk_len is the length of a chunk
    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
        super().__init__()

        self.chunk_len = chunk_len
        self.n_heads = n_heads
        self.d_k = d_k
To scale attentions before softmax by $\frac{1}{\sqrt{d_k}}$
        self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
        self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
        self.softmax = nn.Softmax(dim=-1)
Final linear layer
        self.output = nn.Linear(n_heads * d_k, d_model)
h are the input embeddings, of shape [batch_size, seq_len, d_model]
e are the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
    def forward(self, h: torch.Tensor, e: torch.Tensor):
Get shape
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
No attention if there are no chunks (for short inputs when sampling)
        if chunks == 0:
            return h
Residual connection
        h_res = h
Remove the first chunk_len - 1 embeddings. The input should only pay attention to neighbors that were retrieved and encoded using past tokens, so that there is no information leakage: the neighbors retrieved for the first chunk contain information from that first chunk itself. By shifting the sequence to the left by chunk_len - 1 we make sure that information only flows to the right (the shapes are illustrated in the sketch after this class).
        h = h[:, self.chunk_len - 1:]
Pre-norm
        h = self.norm(h)
Append empty embeddings to the end to be able to split the input into chunks
        if h.shape[1] < chunks * self.chunk_len:
            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
Reshape the input into chunks.
        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
Get query from the input
        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
Get keys and values from the retrieved neighbors
        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
Calculate attention scores for input chunks. Each chunk will pay attention to neighbors retrieved by the previous chunk. This will have shape [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]
        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
Scale attention scores
        attn = attn * self.scale
Apply softmax over the last two dimensions (neighbors, neighbor_len) by flattening them, applying softmax, and reshaping back
        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
Gather values
        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
Change from shape [batch_size, chunks, chunk_len, n_heads, d_k] to [batch_size, chunks * chunk_len, n_heads * d_k]
        h = h.reshape(batch_size, chunks * self.chunk_len, -1)
Apply final linear layer. The result will have shape [batch_size, chunks * chunk_len, d_model]
        h = self.output(h)
Prepend chunk_len - 1 zero embeddings to the left; i.e. shift the sequence back to the right
        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
Truncate and add the residual connection
        return h[:, :h_res.shape[1]] + h_res
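A shape sketch for chunked cross-attention, under hypothetical sizes (chunk_len = 4, two chunks of retrieved neighbors); internally the layer shifts the input left by chunk_len - 1, pads and chunks it, attends to the neighbors, and shifts the result back.

cca = ChunkedCrossAttention(d_model=8, n_heads=2, d_k=4, chunk_len=4)
h = torch.randn(2, 10, 8)       # [batch_size, seq_len, d_model], arbitrary sizes
e = torch.randn(2, 2, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
assert cca(h, e).shape == h.shape  # the output stays aligned with the input sequence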
This consists of two linear layers and an activation in the middle.
class FeedForward(nn.Module):
d_model is the number of features in transformer embeddings
d_ff is the number of features in the hidden layer
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
The two linear layers
        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)
ReLU Activation
        self.act = nn.ReLU()
Pre-norm layer
        self.norm = nn.LayerNorm(d_model)
h are the embeddings of shape [batch_size, seq_len, d_model]
    def forward(self, h: torch.Tensor):
Residual
        h_res = h
Pre-norm
        h = self.norm(h)
First linear layer
        h = self.lin1(h)
Activation
        h = self.act(h)
Second linear layer
        h = self.lin2(h)
Add the residual connection
        return h + h_res
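A one-line shape check for the feed-forward block, with arbitrary sizes chosen only for illustration.

ffw = FeedForward(d_model=8, d_ff=32)
assert ffw(torch.randn(2, 10, 8)).shape == (2, 10, 8)  # shape is preserved by the block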
class NearestNeighborEncoder(nn.Module):
chunk_len is the length of a chunk
n_layers is the number of layers in the encoder
ca_layers are the layers with cross attention
d_model is the number of features in embeddings
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the feed-forward network's hidden layer
    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
                 d_model: int, n_heads: int, d_k: int, d_ff: int):
        super().__init__()
        self.ca_layers = ca_layers
        self.chunk_len = chunk_len
Cross-attention layers
        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
Bi-directional self attention layers
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
Feed forward layers
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
Pre-normalization layer for the input embeddings $h$
        self.norm_h = nn.LayerNorm(d_model)
e are the token embeddings of the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input token embeddings, of shape [batch_size, seq_len, d_model]
The chunks and neighbors are processed in parallel.
    def forward(self, e: torch.Tensor, h: torch.Tensor):
Get shape
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
Split the input into chunks of shape [batch_size, chunks, chunk_len, d_model]
        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
Pre-norm
        h_split = self.norm_h(h_split)
Keep the index of the cross attention layer
        p_ca = 0
For all layers
        for p in range(len(self.attn)):
Bi-directional self attention
            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
Cross attention if p is in ca_layers
            if p in self.ca_layers:
                e = self.ca[p_ca](e, h_split)
Increment the cross attention index
                p_ca += 1
Feed forward layer
            e = self.ffw[p](e)
Return the encoded neighbor embeddings
        return e
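A shape sketch for the encoder above, with hypothetical sizes (two layers, cross attention at layer 1); the input must cover at least chunks * chunk_len tokens.

enc = NearestNeighborEncoder(chunk_len=4, n_layers=2, ca_layers={1},
                             d_model=8, n_heads=2, d_k=4, d_ff=32)
e = torch.randn(2, 3, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model], arbitrary sizes
h = torch.randn(2, 12, 8)       # [batch_size, seq_len, d_model], seq_len >= chunks * chunk_len
assert enc(e, h).shape == e.shape  # encoded neighbors keep their shape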
class RetroModel(nn.Module):
n_vocab is the number of tokens in the vocabulary
d_model is the number of features in embeddings
n_layers is the number of layers in the decoder
ca_layers are the layers with cross attention
chunk_len is the length of a chunk
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the feed-forward network's hidden layer
encoder is the nearest neighbor encoder
    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
        super().__init__()

        self.ca_layers = ca_layers
        self.encoder = encoder
Token embedding layer
        self.emb = nn.Embedding(n_vocab, d_model)
Chunked cross attention layers
        self.cca = nn.ModuleList(
            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
Attention layers
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
Feed forward layers
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
Readout layer
        self.read = nn.Linear(d_model, n_vocab)
Pre-normalization layer for the nearest neighbor embeddings from the encoder
        self.norm_e = nn.LayerNorm(d_model)
x is the input sequence, of shape [batch_size, seq_len]
ret are the retrieved neighbors, of shape [batch_size, chunks, neighbors, neighbor_len]
    def forward(self, x: torch.Tensor, ret: torch.Tensor):
Get input embeddings
        h = self.emb(x)
Embeddings of the retrieved neighbors. We use the same embedding layer for both the input and the neighbors.
        ret_emb = self.emb(ret)
Keep index of the chunked cross attention layer
        p_ca = 0
For all layers
        for p in range(len(self.attn)):
Causal self attention
            h = self.attn[p](h)
Get the encoder embeddings before the first cross-attention layer, i.e. when p == min(ca_layers)
            if self.ca_layers and p == min(self.ca_layers):
                e = self.encoder(ret_emb, h)
Normalize encoder embeddings
                e = self.norm_e(e)
Chunked cross-attention if p is in ca_layers
            if p in self.ca_layers:
                h = self.cca[p_ca](h, e)
Increment chunked cross-attention index
                p_ca += 1
Feed forward layer
            h = self.ffw[p](h)
Readout layer to get the logits
        return self.read(h)
Test the model with fake data
def _test():
    chunk_len = 4
    d_model = 8
    d_ff = 32
    n_heads = 2
    d_k = 4

    device = torch.device('cuda:0')

    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))

    m.to(device)
    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    ret = [
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    ]
    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))

    inspect(res)


if __name__ == '__main__':
    _test()