import math
from typing import Set

import torch
from torch import nn
from labml.logger import inspect

We use rotary position embeddings (RoPE) in the self-attention layers. Non-causal self-attention in particular needs explicit positional information, since it cannot infer token positions on its own. The cross-attention layers do not use explicit positional embeddings; we assume the model can represent positional information in the embeddings implicitly.
class RotaryPositionalEmbeddings(nn.Module):

d is the number of features
base is the constant used for calculating $\theta_i = \frac{1}{\text{base}^{2i/d}}$

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
x is the tensor at the head of a key or a query with shape [batch_size, seq_len, n_heads, d]

    def forward(self, x: torch.Tensor):

Extract the shape

        batch_size, seq_len, n_heads, d = x.shape
        d_2 = d // 2

Create position indexes [0, 1, ..., seq_len - 1]

        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

Calculate the product of position index and $\theta_i$

        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)

Concatenate so that for position $m$ we have $[m\theta_0, m\theta_1, ..., m\theta_{d/2 - 1}, m\theta_0, m\theta_1, ..., m\theta_{d/2 - 1}]$

        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

Calculate $[-x^{(d/2)}, -x^{(d/2 + 1)}, ..., -x^{(d - 1)}, x^{(0)}, x^{(1)}, ..., x^{(d/2 - 1)}]$

        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

Rotate each pair of features $(x^{(i)}, x^{(i + d/2)})$ by $m\theta_i$

        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

        return rx
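As a quick illustration (not part of the original implementation, and assuming the RotaryPositionalEmbeddings class above), we can check two properties that follow from the rotation: the shape is unchanged, and since each pair of features is rotated by an orthogonal 2D rotation, the norm of every head vector is preserved.

rope = RotaryPositionalEmbeddings(d=16)
x = torch.randn(2, 10, 4, 16)  # [batch_size, seq_len, n_heads, d]
rx = rope(x)
# The shape is unchanged
assert rx.shape == x.shape
# Rotations are orthogonal, so the norm of each head vector is preserved
assert torch.allclose(x.norm(dim=-1), rx.norm(dim=-1), atol=1e-5)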
class SelfAttention(nn.Module):

d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
is_causal indicates whether this is causal attention (masked)

    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
        super().__init__()

        self.is_causal = is_causal
        self.n_heads = n_heads
        self.d_k = d_k

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer. The paper uses RMSNorm instead.

        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

        self.softmax = nn.Softmax(dim=-1)

Rotary positional embeddings

        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)

Final linear layer

        self.output = nn.Linear(n_heads * d_k, d_model)
attn is the attention matrix of shape [batch_size, n_heads, seq_len, seq_len]

    def mask_attention(self, attn: torch.Tensor):

No masking for non-causal attention

        if not self.is_causal:
            return attn

Create a lower-triangular mask

        mask = torch.tril(attn.new_ones(attn.shape[-2:]))

Filter by the mask

        return attn.masked_fill(mask == 0, float('-inf'))
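To make the masking concrete, here is a small standalone sketch (illustrative values only, not part of the original code) of what torch.tril and masked_fill do to a 3 x 3 score matrix:

scores = torch.zeros(3, 3)  # toy attention scores for a single head
mask = torch.tril(scores.new_ones(scores.shape))
masked = scores.masked_fill(mask == 0, float('-inf'))
# Row i can only attend to positions j <= i; -inf entries become zero after the softmax
# masked is:
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])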
h is the transformer embeddings of shape [batch_size, seq_len, d_model]

    def forward(self, h: torch.Tensor):

Residual connection

        h_res = h

Pre-normalization

        h = self.norm(h)

Get query, key, and values and split them into heads. These will have shapes [batch_size, seq_len, n_heads, d_k]

        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
        q = self.query(h).view(mh_shape)
        k = self.key(h).view(mh_shape)
        v = self.value(h).view(mh_shape)

Apply rotary positional embeddings

        q = self.rotary_pe(q)
        k = self.rotary_pe(k)

Calculate attention scores

        attn = torch.einsum('bihd,bjhd->bhij', q, k)

Scale them by $\frac{1}{\sqrt{d_k}}$

        attn = attn * self.scale

Apply the mask if it's causal attention

        attn = self.mask_attention(attn)

Calculate attention probabilities

        attn = self.softmax(attn)

Get values

        h = torch.einsum("bhij,bjhd->bihd", attn, v)

Change from shape [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_heads * d_k]

        h = h.reshape(*h.shape[:-2], -1)

Apply the final linear layer. The result will have shape [batch_size, seq_len, d_model]

        h = self.output(h)

Add the residual connection

        return h + h_res
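A minimal usage sketch, assuming the SelfAttention class above and illustrative hyper-parameters; it only checks that the layer maps embeddings back to the same shape:

attn_layer = SelfAttention(d_model=16, n_heads=2, d_k=8, is_causal=True)
h = torch.randn(2, 10, 16)  # [batch_size, seq_len, d_model]
assert attn_layer(h).shape == h.shape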
This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.
This is used in the encoder to encode the retrieved chunks based on the input chunks.
We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.

class CrossAttention(nn.Module):

d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head

    def __init__(self, d_model: int, n_heads: int, d_k: int):
        super().__init__()

        self.n_heads = n_heads
        self.d_k = d_k

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.

        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

        self.softmax = nn.Softmax(dim=-1)

Final linear layer

        self.output = nn.Linear(n_heads * d_k, d_model)
e are the retrieved nearest neighbor chunk embeddings with shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input chunks from which the nearest neighbors were retrieved, with shape [batch_size, chunks, chunk_len, d_model]. This is already normalized.

    def forward(self, e: torch.Tensor, h: torch.Tensor):

Residual connection

        e_res = e

Normalize the retrieved chunks

        e = self.norm(e)

Get queries from the retrieved chunks

        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)

Get keys and values from the input chunks

        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

Calculate attention scores for all chunks. Each retrieved neighbor pays attention to the original chunk that retrieved it. This will have shape [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]

        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)

Scale attention scores

        attn = attn * self.scale

Calculate the softmax across the last dimension

        attn = self.softmax(attn)

Gather values

        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)

Change from shape [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] to [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]

        e = e.reshape(*e.shape[:-2], -1)

Apply the final linear layer. The result will have shape [batch_size, chunks, neighbors, neighbor_len, d_model]

        e = self.output(e)

Add the residual connection

        return e + e_res
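A minimal shape sketch, assuming the CrossAttention class above with illustrative sizes; the output has the same shape as the retrieved neighbor embeddings e:

ca = CrossAttention(d_model=16, n_heads=2, d_k=8)
e = torch.randn(2, 3, 2, 6, 16)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 3, 4, 16)     # [batch_size, chunks, chunk_len, d_model]
assert ca(e, h).shape == e.shape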
This is similar to the cross-attention layer defined above.
This is used in the decoder to pay attention to the retrieved neighbor chunks.
We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.

class ChunkedCrossAttention(nn.Module):

d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
chunk_len is the length of a chunk

    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
        super().__init__()

        self.chunk_len = chunk_len
        self.n_heads = n_heads
        self.d_k = d_k

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.

        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

        self.softmax = nn.Softmax(dim=-1)

Final linear layer

        self.output = nn.Linear(n_heads * d_k, d_model)
h are the input embeddings of shape [batch_size, seq_len, d_model]
e are the retrieved nearest neighbors of shape [batch_size, chunks, neighbors, neighbor_len, d_model]

    def forward(self, h: torch.Tensor, e: torch.Tensor):

Get the shape

        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

No attention if there are no chunks (for short inputs when sampling)

        if chunks == 0:
            return h

Residual connection

        h_res = h

Remove the first chunk_len - 1 embeddings. The input pays attention to neighbors retrieved and encoded using only past tokens, so that there is no information leakage. That is, the neighbors retrieved for the first chunk carry information from the first chunk itself. So we shift the sequence to the left by chunk_len - 1 to make sure that information only flows to the right. For example, with chunk_len = 4 the token at position 3 (the last token of the first chunk) becomes the first query of the first chunk of queries, and it attends only to neighbors retrieved from tokens 0..3.
        h = h[:, self.chunk_len - 1:]

Pre-norm

        h = self.norm(h)

Append empty embeddings to the end to be able to split the input into chunks

        if h.shape[1] < chunks * self.chunk_len:
            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)

Reshape the input into chunks.

        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)

Get queries from the input

        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)

Get keys and values from the retrieved neighbors

        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

Calculate attention scores for the input chunks. Each chunk pays attention to the neighbors retrieved by the previous chunk. This will have shape [batch_size, chunks, n_heads, chunk_len, neighbors, neighbor_len]

        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)

Scale attention scores

        attn = attn * self.scale

Apply softmax over the last two dimensions (neighbors, neighbor_len)

        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)

Gather values

        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)

Change from shape [batch_size, chunks, chunk_len, n_heads, d_k] to [batch_size, chunks * chunk_len, n_heads * d_k]

        h = h.reshape(batch_size, chunks * self.chunk_len, -1)

Apply the final linear layer. The result will have shape [batch_size, chunks * chunk_len, d_model]

        h = self.output(h)

Prepend chunk_len - 1 zero embeddings on the left; i.e. shift it back to the right

        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)

Truncate and add the residual connection

        return h[:, :h_res.shape[1]] + h_res
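A minimal shape sketch, assuming the ChunkedCrossAttention class above with illustrative sizes; here seq_len = chunks * chunk_len, and the output matches the input shape:

cca = ChunkedCrossAttention(d_model=16, n_heads=2, d_k=8, chunk_len=4)
h = torch.randn(2, 8, 16)        # [batch_size, seq_len, d_model] with seq_len = 2 chunks * chunk_len 4
e = torch.randn(2, 2, 2, 6, 16)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
assert cca(h, e).shape == h.shape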
This consists of two linear layers and an activation in the middle.

class FeedForward(nn.Module):

d_model is the number of features in transformer embeddings
d_ff is the number of features in the hidden layer

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()

The two linear layers

        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)

ReLU activation

        self.act = nn.ReLU()

Pre-norm layer

        self.norm = nn.LayerNorm(d_model)

h are the embeddings of shape [batch_size, seq_len, d_model]

    def forward(self, h: torch.Tensor):

Residual connection

        h_res = h

Pre-norm

        h = self.norm(h)

First linear layer

        h = self.lin1(h)

Activation

        h = self.act(h)

Second linear layer

        h = self.lin2(h)

Add the residual connection

        return h + h_res
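A minimal usage sketch, assuming the FeedForward class above with illustrative sizes; the hidden layer expands to d_ff and projects back, so the output shape matches the input:

ffw = FeedForward(d_model=16, d_ff=64)
h = torch.randn(2, 10, 16)  # [batch_size, seq_len, d_model]
assert ffw(h).shape == h.shape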
class NearestNeighborEncoder(nn.Module):

chunk_len is the length of a chunk
n_layers is the number of layers in the encoder
ca_layers are the layers with cross attention
d_model is the number of features in embeddings
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the feed-forward networks' hidden layers

    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
                 d_model: int, n_heads: int, d_k: int, d_ff: int):
        super().__init__()
        self.ca_layers = ca_layers
        self.chunk_len = chunk_len

Cross-attention layers

        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])

Bi-directional self-attention layers

        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])

Feed-forward layers

        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

Pre-normalization layer for h

        self.norm_h = nn.LayerNorm(d_model)
e are the token embeddings of the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input token embeddings, of shape [batch_size, seq_len, d_model]
The chunks and neighbors are processed in parallel.

    def forward(self, e: torch.Tensor, h: torch.Tensor):

Get the shape

        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

Split the input into chunks of shape [batch_size, chunks, chunk_len, d_model]

        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)

Pre-norm

        h_split = self.norm_h(h_split)

Keep the index of the next cross-attention layer

        p_ca = 0

For all layers

        for p in range(len(self.attn)):

Bi-directional self attention

            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)

Cross attention if p is in ca_layers

            if p in self.ca_layers:
                e = self.ca[p_ca](e, h_split)

Increment the cross-attention index

                p_ca += 1

Feed-forward layer

            e = self.ffw[p](e)

Return the encoded neighbor embeddings

        return e
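A minimal shape sketch, assuming the NearestNeighborEncoder class above with illustrative sizes; h must cover at least chunks * chunk_len tokens, and the output has the same shape as e:

encoder = NearestNeighborEncoder(chunk_len=4, n_layers=2, ca_layers={1},
                                 d_model=16, n_heads=2, d_k=8, d_ff=32)
e = torch.randn(2, 3, 2, 6, 16)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 12, 16)       # [batch_size, seq_len, d_model] with seq_len >= chunks * chunk_len
assert encoder(e, h).shape == e.shape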
class RetroModel(nn.Module):

n_vocab is the number of tokens in the vocabulary
d_model is the number of features in embeddings
n_layers is the number of layers in the decoder
ca_layers are the layers with cross attention
chunk_len is the length of a chunk
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the feed-forward networks' hidden layers
encoder is the nearest neighbor encoder

    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
        super().__init__()

        self.ca_layers = ca_layers
        self.encoder = encoder

Token embedding layer

        self.emb = nn.Embedding(n_vocab, d_model)

Chunked cross-attention layers

        self.cca = nn.ModuleList(
            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])

Attention layers

        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])

Feed-forward layers

        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

Readout layer

        self.read = nn.Linear(d_model, n_vocab)

Pre-normalization layer for the nearest neighbor embeddings from the encoder

        self.norm_e = nn.LayerNorm(d_model)
x is the input sequence, of shape [batch_size, seq_len]
ret are the retrieved neighbors, of shape [batch_size, chunks, neighbors, neighbor_len]

    def forward(self, x: torch.Tensor, ret: torch.Tensor):

Get the input embeddings

        h = self.emb(x)

Embed the retrieved neighbors

        ret_emb = self.emb(ret)

Keep the index of the next chunked cross-attention layer

        p_ca = 0

For all layers

        for p in range(len(self.attn)):

Causal self attention

            h = self.attn[p](h)

Get the encoder embeddings just before the first cross-attention layer, when p == min(ca_layers)

            if self.ca_layers and p == min(self.ca_layers):
                e = self.encoder(ret_emb, h)

Normalize the encoder embeddings

                e = self.norm_e(e)

Chunked cross attention if p is in ca_layers

            if p in self.ca_layers:
                h = self.cca[p_ca](h, e)

Increment the chunked cross-attention index

                p_ca += 1

Feed-forward layer

            h = self.ffw[p](h)

Readout to get the token logits

        return self.read(h)
Test the model with mock data.

def _test():
    chunk_len = 4
    d_model = 8
    d_ff = 32
    n_heads = 2
    d_k = 4

    device = torch.device('cuda:0')

    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))

    m.to(device)
    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    ret = [
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    ]
    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))

    inspect(res)

if __name__ == '__main__':
    _test()