diff --git a/docs/sitemap.xml b/docs/sitemap.xml index 1d86c497..cfe8a8ae 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -344,7 +344,7 @@ https://nn.labml.ai/diffusion/ddpm/index.html - 2021-10-21T16:30:00+00:00 + 2022-03-21T16:30:00+00:00 1.00 @@ -729,7 +729,7 @@ https://nn.labml.ai/transformers/retro/model.html - 2022-03-12T16:30:00+00:00 + 2022-03-21T16:30:00+00:00 1.00 diff --git a/docs/transformers/retro/model.html b/docs/transformers/retro/model.html index e69de29b..9ceec9fc 100644 --- a/docs/transformers/retro/model.html +++ b/docs/transformers/retro/model.html @@ -0,0 +1,2055 @@ + + + + + + + + + + + + + + + + + + + + + + + RETRO model + + + + + + + + + + +
+
+
+
+


+


+
+
+
+
+ +

RETRO model

+

This is the model definition for RETRO.

+

View Run

+ +
+
+
16import math
+17from typing import Set
+18
+19import torch
+20from torch import nn
+21
+22from labml.logger import inspect
+
+
+
+
+ +

RoPE embeddings

+

We use rotary position embeddings in the self-attention layers. We assume that causal attention embeds positional information into the embeddings implicitly, so explicit position encodings matter less there; non-causal self-attention, however, needs explicit positional information because it cannot infer it.
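As a quick shape check, here is a minimal usage sketch (with arbitrary example sizes) of the RotaryPositionalEmbeddings class defined below:

import torch

batch_size, seq_len, n_heads, d_k = 2, 10, 4, 16  # example sizes only
rope = RotaryPositionalEmbeddings(d_k)
x = torch.randn(batch_size, seq_len, n_heads, d_k)
# The rotation mixes feature pairs per position but keeps the shape unchanged
assert rope(x).shape == (batch_size, seq_len, n_heads, d_k)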

+ +
+
+
25class RotaryPositionalEmbeddings(nn.Module):
+
+
+
+
+ +
  • d + is the number of features
  • +
  • base + is the constant used for calculating $\theta_i = \text{base}^{-2(i-1)/d}$
+ +
+
+
36    def __init__(self, d: int, base: int = 10_000):
+
+
+
+
+ + +
+
+
41        super().__init__()
+
+
+
+
+ +

+ +
+
+
43        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
+
+
+
+
+ +
  • x + is the tensor at the head of a key or a query, with shape [batch_size, seq_len, n_heads, d] +
+ +
+
+
45    def forward(self, x: torch.Tensor):
+
+
+
+
+ +

Extract the shape

+ +
+
+
50        batch_size, seq_len, n_heads, d = x.shape
+
+
+
+
+ +

+ +
+
+
53        d_2 = d // 2
+
+
+
+
+ +

Create position indexes [0, 1, ..., seq_len - 1] +

+ +
+
+
56        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
+
+
+
+
+ +

Calculate the product of the position index and $\theta_i$

+ +
+
+
59        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
+
+
+
+
+ +

Concatenate so that for row $m$ we have $[m \theta_0, m \theta_1, ..., m \theta_{d/2 - 1}, m \theta_0, m \theta_1, ..., m \theta_{d/2 - 1}]$

+ +
+
+
63        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
+
+
+
+ +

Calculate $[-x^{(d/2)}, -x^{(d/2 + 1)}, ..., -x^{(d - 1)}, x^{(0)}, x^{(1)}, ..., x^{(d/2 - 1)}]$

+ +
+
+
67        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
+
+
+
+ +

Calculate

$$rx^{(i)}_m = x^{(i)}_m \cos m\theta_i - x^{(i + d/2)}_m \sin m\theta_i$$
$$rx^{(i + d/2)}_m = x^{(i + d/2)}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i$$

for $i \in \{0, 1, ..., d/2 - 1\}$

+ +
+
+
79        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])
+
+
+
+
+ +

+ +
+
+
82        return rx
+
+
+
+
+ +

Self-Attention Layer

+

This applies causal and non-causal multi-headed self-attention.
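As a minimal sketch of how this layer is shaped (arbitrary example sizes, using the SelfAttention class defined below), note that the is_causal flag only changes the masking, not the output shape:

import torch

d_model, n_heads, d_k = 8, 2, 4  # example sizes only
attn_causal = SelfAttention(d_model, n_heads, d_k, is_causal=True)
attn_bidir = SelfAttention(d_model, n_heads, d_k, is_causal=False)
h = torch.randn(3, 12, d_model)  # [batch_size, seq_len, d_model]
# Both variants return embeddings with the same shape as the input
assert attn_causal(h).shape == h.shape
assert attn_bidir(h).shape == h.shape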

+ +
+
+
85class SelfAttention(nn.Module):
+
+
+
+
+ +
  • d_model + is the number of features in transformer embeddings
  • +
  • n_heads + is the number of attention heads
  • +
  • d_k + is the number of features per head
  • +
  • is_causal + indicates whether this is causal attention (masked)
+ +
+
+
92    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
+
+
+
+
+ + +
+
+
99        super().__init__()
+100
+101        self.is_causal = is_causal
+102        self.n_heads = n_heads
+103        self.d_k = d_k
+
+
+
+
+ +

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

+ +
+
+
106        self.scale = 1 / math.sqrt(self.d_k)
+
+
+
+
+ +

Linear layers for query, key and value heads.

+ +
+
+
109        self.query = nn.Linear(d_model, n_heads * d_k)
+110        self.key = nn.Linear(d_model, n_heads * d_k)
+111        self.value = nn.Linear(d_model, n_heads * d_k)
+
+
+
+
+ +

Pre-norm layer. The paper uses RMSNorm instead.

+ +
+
+
114        self.norm = nn.LayerNorm(d_model)
+
+
+
+
+ +

Softmax for attention probabilities

+ +
+
+
117        self.softmax = nn.Softmax(dim=-1)
+
+
+
+
+ +

Rotary positional embeddings

+ +
+
+
120        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
+
+
+
+
+ +

Final linear layer

+ +
+
+
123        self.output = nn.Linear(n_heads * d_k, d_model)
+
+
+
+
+ +

Mask the attention layer for causal attention

+
  • attn + is the attention matrix of shape [batch_size, n_heads, seq_len, seq_len] +
+ +
+
+
125    def mask_attention(self, attn: torch.Tensor):
+
+
+
+
+ +

No masking for non-causal attention

+ +
+
+
133        if not self.is_causal:
+134            return attn
+
+
+
+
+ +

Create a triangular mask

+ +
+
+
137        mask = torch.tril(attn.new_ones(attn.shape[-2:]))
+
+
+
+
+ +

Filter by the mask

+ +
+
+
139        return attn.masked_fill(mask == 0, float('-inf'))
+
+
+
+
+ +
  • h + is the transformer embeddings of shape [batch_size, seq_len, d_model] +
+ +
+
+
141    def forward(self, h: torch.Tensor):
+
+
+
+
+ +

Residual connection

+ +
+
+
147        h_res = h
+
+
+
+
+ +

Pre-normalization

+ +
+
+
150        h = self.norm(h)
+
+
+
+
+ +

Get queries, keys, and values and split them into heads. These will have shapes [batch_size, seq_len, n_heads, d_k] +

+ +
+
+
154        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
+155        q = self.query(h).view(mh_shape)
+156        k = self.key(h).view(mh_shape)
+157        v = self.value(h).view(mh_shape)
+
+
+
+
+ +

Apply rotary positional embeddings

+ +
+
+
160        q = self.rotary_pe(q)
+161        k = self.rotary_pe(k)
+
+
+
+
+ +

Calculate attentions

+ +
+
+
164        attn = torch.einsum('bihd,bjhd->bhij', q, k)
+
+
+
+
+ +

Scale it by $\frac{1}{\sqrt{d_k}}$

+ +
+
+
166        attn = attn * self.scale
+
+
+
+
+ +

Apply masks if it's causal attention

+ +
+
+
169        attn = self.mask_attention(attn)
+
+
+
+
+ +

Calculate attention probabilities

+ +
+
+
172        attn = self.softmax(attn)
+
+
+
+
+ +

Get values

+ +
+
+
175        h = torch.einsum("bhij,bjhd->bihd", attn, v)
+
+
+
+
+ +

Change from shape [batch_size, seq_len, n_heads, d_k] + to [batch_size, seq_len, n_heads * d_k] +

+ +
+
+
179        h = h.reshape(*h.shape[:-2], -1)
+
+
+
+
+ +

Apply final linear layer. The result will have shape [batch_size, seq_len, d_model] +

+ +
+
+
183        h = self.output(h)
+
+
+
+
+ +

Add the residual connection

+ +
+
+
186        return h + h_res
+
+
+
+
+ +

Cross-Attention Layer

+

This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.

+

This is used in the encoder to encode the retrieved chunks based on the input chunks.

+

We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
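A minimal usage sketch of the CrossAttention class defined below, with arbitrary example sizes, to make the expected tensor shapes concrete:

import torch

d_model, n_heads, d_k = 8, 2, 4  # example sizes only
ca = CrossAttention(d_model, n_heads, d_k)
# Retrieved neighbor embeddings: [batch_size, chunks, neighbors, neighbor_len, d_model]
e = torch.randn(3, 2, 2, 6, d_model)
# Input chunk embeddings (already normalized): [batch_size, chunks, chunk_len, d_model]
h = torch.randn(3, 2, 4, d_model)
# The output has the same shape as the retrieved neighbor embeddings
assert ca(e, h).shape == e.shape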

+ +
+
+
189class CrossAttention(nn.Module):
+
+
+
+
+ +
  • d_model + is the number of features in transformer embeddings
  • +
  • n_heads + is the number of attention heads
  • +
  • d_k + is the number of features per head
+ +
+
+
203    def __init__(self, d_model: int, n_heads: int, d_k: int):
+
+
+
+
+ + +
+
+
209        super().__init__()
+210
+211        self.n_heads = n_heads
+212        self.d_k = d_k
+
+
+
+
+ +

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

+ +
+
+
215        self.scale = 1 / math.sqrt(self.d_k)
+
+
+
+
+ +

Linear layers for query, key and value heads.

+ +
+
+
218        self.query = nn.Linear(d_model, n_heads * d_k)
+219        self.key = nn.Linear(d_model, n_heads * d_k)
+220        self.value = nn.Linear(d_model, n_heads * d_k)
+
+
+
+
+ +

Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.

+ +
+
+
223        self.norm = nn.LayerNorm(d_model)
+
+
+
+
+ +

Softmax for attention probabilities

+ +
+
+
226        self.softmax = nn.Softmax(dim=-1)
+
+
+
+
+ +

Final linear layer

+ +
+
+
229        self.output = nn.Linear(n_heads * d_k, d_model)
+
+
+
+
+ +
  • e + are the retrieved nearest neighbor chunk embeddings with shape [batch_size, chunks, neighbors, neighbor_len, d_model] +
  • +
  • h + are the input chunks from which the nearest neighbors were retrieved with shape [batch_size, chunks, chunk_len, d_model] +. This is already normalized.
+ +
+
+
231    def forward(self, e: torch.Tensor, h: torch.Tensor):
+
+
+
+
+ +

Residual connection

+ +
+
+
240        e_res = e
+
+
+
+
+ +

Normalize retrieved chunks

+ +
+
+
243        e = self.norm(e)
+
+
+
+
+ +

Get query from the retrieved chunks

+ +
+
+
246        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
+
+
+
+
+ +

Get keys and values from the input chunks

+ +
+
+
248        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
+249        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
+
+
+
+
+ +

Calculate attention scores for all chunks. Each retrieved neighbor will pay attention to the original chunk that retrieved it. This will have shape [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len] +

+ +
+
+
254        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
+
+
+
+
+ +

Scale attention scores

+ +
+
+
256        attn = attn * self.scale
+
+
+
+
+ +

Calculate softmax across the last dimension

+ +
+
+
259        attn = self.softmax(attn)
+
+
+
+
+ +

Gather values

+ +
+
+
262        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
+
+
+
+
+ +

Change from shape [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] + to [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k] +

+ +
+
+
266        e = e.reshape(*e.shape[:-2], -1)
+
+
+
+
+ +

Apply final linear layer. The result will have shape [batch_size, chunks, neighbors, neighbor_len, d_model] +

+ +
+
+
270        e = self.output(e)
+
+
+
+
+ +

Add residual connection

+ +
+
+
273        return e + e_res
+
+
+
+
+ +

Chunked Cross-Attention Layer

+

This is similar to the cross-attention layer defined above.

+

This is used in the decoder to pay attention to the retrieved neighbor chunks.

+

We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
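A minimal usage sketch of the ChunkedCrossAttention class defined below, with arbitrary example sizes (here the sequence length equals chunks * chunk_len):

import torch

d_model, n_heads, d_k, chunk_len = 8, 2, 4, 4  # example sizes only
cca = ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len)
# Input embeddings: [batch_size, seq_len, d_model]
h = torch.randn(3, 8, d_model)
# Retrieved neighbors: [batch_size, chunks, neighbors, neighbor_len, d_model]
e = torch.randn(3, 2, 2, 6, d_model)
# The output keeps the shape of the input embeddings
assert cca(h, e).shape == h.shape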

+ +
+
+
276class ChunkedCrossAttention(nn.Module):
+
+
+
+
+ +
  • d_model + is the number of features in transformer embeddings
  • +
  • n_heads + is the number of attention heads
  • +
  • d_k + is the number of features per head
  • +
  • chunk_len + is the length of a chunk
+ +
+
+
288    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
+
+
+
+
+ + +
+
+
296        super().__init__()
+297
+298        self.chunk_len = chunk_len
+299        self.n_heads = n_heads
+300        self.d_k = d_k
+
+
+
+
+ +

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

+ +
+
+
303        self.scale = 1 / math.sqrt(self.d_k)
+
+
+
+
+ +

Linear layers for query, key and value heads.

+ +
+
+
306        self.query = nn.Linear(d_model, n_heads * d_k)
+307        self.key = nn.Linear(d_model, n_heads * d_k)
+308        self.value = nn.Linear(d_model, n_heads * d_k)
+
+
+
+
+ +

Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.

+ +
+
+
311        self.norm = nn.LayerNorm(d_model)
+
+
+
+
+ +

Softmax for attention probabilities

+ +
+
+
314        self.softmax = nn.Softmax(dim=-1)
+
+
+
+
+ +

Final linear layer

+ +
+
+
317        self.output = nn.Linear(n_heads * d_k, d_model)
+
+
+
+
+ +

  • h + are the input embeddings of shape [batch_size, seq_len, d_model] +
  • e + are the retrieved nearest neighbors of shape [batch_size, chunks, neighbors, neighbor_len, d_model] +

+ +
+
+
319    def forward(self, h: torch.Tensor, e: torch.Tensor):
+
+
+
+
+ +

Get shape

+ +
+
+
326        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
+
+
+
+
+ +

No attention if there are no chunks (for short inputs when sampling)

+ +
+
+
329        if chunks == 0:
+330            return h
+
+
+
+
+ +

Residual connection

+ +
+
+
333        h_res = h
+
+
+
+
+ +

Remove the first chunk_len - 1 + embeddings. The input pays attention only to neighbors that were retrieved and encoded using past tokens, so there is no information leakage. That is, the neighbors retrieved for the first chunk carry information from the first chunk itself, so by shifting the sequence to the left by chunk_len - 1 + we make sure that information only flows to the right (see the short numeric sketch below).

+ +
+
+
341        h = h[:, self.chunk_len - 1:]
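A small numeric illustration of this shift (a sketch only; chunk_len = 4 and a 10-token sequence are arbitrary example sizes):

import torch

chunk_len, seq_len = 4, 10  # example sizes only
positions = torch.arange(seq_len)    # tensor([0, 1, ..., 9])
shifted = positions[chunk_len - 1:]  # tensor([3, 4, 5, 6, 7, 8, 9])
# After padding to a multiple of chunk_len and reshaping, chunk 0 holds
# positions [3, 4, 5, 6] and chunk 1 holds [7, 8, 9, pad], so no position
# attends to neighbors retrieved using tokens from its own future.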
+
+
+
+
+ +

Pre-norm

+ +
+
+
343        h = self.norm(h)
+
+
+
+
+ +

Append empty embeddings to the end to be able to split the input into chunks

+ +
+
+
345        if h.shape[1] < chunks * self.chunk_len:
+346            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
+
+
+
+
+ +

Reshape the input into chunks.

+ +
+
+
348        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
+
+
+
+
+ +

Get query from the input

+ +
+
+
351        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
+
+
+
+
+ +

Get keys and values from the retrieved neighbors

+ +
+
+
353        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
+354        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
+
+
+
+
+ +

Calculate attention scores for input chunks. Each chunk will pay attention to neighbors retrieved by the previous chunk. This will have shape [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len] +

+ +
+
+
359        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
+
+
+
+
+ +

Scale attention scores

+ +
+
+
361        attn = attn * self.scale
+
+
+
+
+ +

Apply softmax over the last two dimensions neighbors, neighbor_len +

+ +
+
+
364        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
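The line above normalizes over the two trailing dimensions at once by flattening them, applying softmax, and restoring the shape. A standalone sketch of that pattern, with arbitrary example sizes:

import torch

attn = torch.randn(3, 2, 2, 4, 2, 6)  # [..., neighbors, neighbor_len]
# Flatten the last two dimensions, softmax, then restore the original shape
probs = torch.softmax(attn.view(*attn.shape[:-2], -1), dim=-1).view(attn.shape)
# The probabilities over (neighbors, neighbor_len) sum to one
assert torch.allclose(probs.sum(dim=(-2, -1)), torch.ones(3, 2, 2, 4))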
+
+
+
+
+ +

Gather values

+ +
+
+
367        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
+
+
+
+
+ +

Change from shape [batch_size, chunks, chunk_len, n_heads, d_k] + to [batch_size, chunks * chunk_len, n_heads * d_k] +

+ +
+
+
371        h = h.reshape(batch_size, chunks * self.chunk_len, -1)
+
+
+
+
+ +

Apply final linear layer. The result will have shape [batch_size, chunks * chunk_len, d_model] +

+ +
+
+
375        h = self.output(h)
+
+
+
+
+ +

Prepend chunk_len - 1 + zero embeddings on the left, i.e. shift it back to the right

+ +
+
+
378        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
+
+
+
+
+ +

Truncate and add the residual connection

+ +
+
+
381        return h[:, :h_res.shape[1]] + h_res
+
+
+
+
+ +

Position-wise Feed Forward Layer

+

This consists of two linear layers and an activation in the middle.
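Roughly, the layer below computes $\mathrm{FFN}(h) = W_2\,\mathrm{ReLU}(W_1\,\mathrm{LN}(h)) + h$ (biases omitted). A minimal shape-check sketch with arbitrary example sizes:

import torch

d_model, d_ff = 8, 32  # example sizes only
ffw = FeedForward(d_model, d_ff)
h = torch.randn(3, 12, d_model)  # [batch_size, seq_len, d_model]
# The hidden layer expands to d_ff and projects back, so the shape is preserved
assert ffw(h).shape == h.shape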

+ +
+
+
384class FeedForward(nn.Module):
+
+
+
+
+ +
  • d_model + is the number of features in transformer embeddings
  • +
  • d_ff + is the number of features in the hidden layer
+ +
+
+
391    def __init__(self, d_model: int, d_ff: int):
+
+
+
+
+ + +
+
+
397        super().__init__()
+
+
+
+
+ +

The two linear layers

+ +
+
+
400        self.lin1 = nn.Linear(d_model, d_ff)
+401        self.lin2 = nn.Linear(d_ff, d_model)
+
+
+
+
+ +

ReLU Activation

+ +
+
+
404        self.act = nn.ReLU()
+
+
+
+
+ +

Pre-norm layer

+ +
+
+
407        self.norm = nn.LayerNorm(d_model)
+
+
+
+
+ +

h + are the embeddings of shape [batch_size, seq_len, d_model] +

+ +
+
+
409    def forward(self, h: torch.Tensor):
+
+
+
+
+ +

Residual

+ +
+
+
415        h_res = h
+
+
+
+
+ +

Pre-norm

+ +
+
+
417        h = self.norm(h)
+
+
+
+
+ +

First linear layer

+ +
+
+
419        h = self.lin1(h)
+
+
+
+
+ +

Activation

+ +
+
+
421        h = self.act(h)
+
+
+
+
+ +

Second linear layer

+ +
+
+
423        h = self.lin2(h)
+
+
+
+
+ +

Add the residual connection

+ +
+
+
426        return h + h_res
+
+
+
+
+ +

Nearest Neighbor Encoder

+

This module encodes the retrieved nearest neighbors.
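A minimal usage sketch of the NearestNeighborEncoder class defined below, with arbitrary example sizes, to make the expected input and output shapes concrete:

import torch

chunk_len, d_model, n_heads, d_k, d_ff = 4, 8, 2, 4, 32  # example sizes only
enc = NearestNeighborEncoder(chunk_len, n_layers=2, ca_layers={1},
                             d_model=d_model, n_heads=n_heads, d_k=d_k, d_ff=d_ff)
# Neighbor embeddings: [batch_size, chunks, neighbors, neighbor_len, d_model]
e = torch.randn(3, 2, 2, 6, d_model)
# Input embeddings: [batch_size, seq_len, d_model] with seq_len >= chunks * chunk_len
h = torch.randn(3, 10, d_model)
# The encoded neighbors keep the shape of e
assert enc(e, h).shape == e.shape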

+ +
+
+
429class NearestNeighborEncoder(nn.Module):
+
+
+
+
+ +
  • chunk_len + is the length of a chunk
  • +
  • n_layers + is the number of layers in the encoder
  • +
  • ca_layers + are the layers with cross attention
  • +
  • d_model + is the number of features in embeddings
  • +
  • n_heads + is the number of heads in attention layers
  • +
  • d_k + is the size of attention heads
  • +
  • d_ff + is the size of the feed-forward network's hidden layer
+ +
+
+
436    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
+437                 d_model: int, n_heads: int, d_k: int, d_ff: int):
+
+
+
+
+ + +
+
+
448        super().__init__()
+449        self.ca_layers = ca_layers
+450        self.chunk_len = chunk_len
+
+
+
+
+ +

Cross-attention layers

+ +
+
+
452        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
+
+
+
+
+ +

Bi-directional self attention layers

+ +
+
+
454        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
+
+
+
+
+ +

Feed forward layers

+ +
+
+
456        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
+
+
+
+
+ +

Pre-normalization layer for the input chunk embeddings $h$

+ +
+
+
459        self.norm_h = nn.LayerNorm(d_model)
+
+
+
+
+ +
  • e + are token embeddings of the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model] +
+
  • h + are the input token embeddings, of shape [batch_size, seq_len, d_model] +
+

The chunks and neighbors are processed in parallel.
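Concretely, processing in parallel means the chunk and neighbor dimensions are folded into the batch dimension before self-attention, which is what e.view(-1, neighbor_len, d_model) does in the code below. A standalone sketch of that reshape, with arbitrary example sizes:

import torch

batch_size, chunks, neighbors, neighbor_len, d_model = 3, 2, 2, 6, 8  # example sizes only
e = torch.randn(batch_size, chunks, neighbors, neighbor_len, d_model)
# Each retrieved neighbor becomes an independent sequence in the batch
flat = e.view(-1, neighbor_len, d_model)
assert flat.shape == (batch_size * chunks * neighbors, neighbor_len, d_model)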

+ +
+
+
461    def forward(self, e: torch.Tensor, h: torch.Tensor):
+
+
+
+
+ +

Get shape

+ +
+
+
474        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
+
+
+
+
+ +

+ +
+
+
477        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
+
+
+
+
+ +

Pre-norm

+ +
+
+
480        h_split = self.norm_h(h_split)
+
+
+
+
+ +

Keep the index of the cross attention layer

+ +
+
+
483        p_ca = 0
+
+
+
+
+ +

For all layers

+ +
+
+
485        for p in range(len(self.attn)):
+
+
+
+
+ +

Bi-directional self attention

+ +
+
+
488            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
+
+
+
+
+ +

Cross attention if p is in ca_layers

+ +
+
+
491            if p in self.ca_layers:
+
+
+
+
+ +

+ +
+
+
493                e = self.ca[p_ca](e, h_split)
+
+
+
+
+ +

Increment the cross attention index

+ +
+
+
495                p_ca += 1
+
+
+
+
+ +

Feed forward layer

+ +
+
+
498            e = self.ffw[p](e)
+
+
+
+
+ +

Return the encoded neighbor embeddings

+ +
+
+
501        return e
+
+
+
+
+ +

Retro Model

+

This is the RETRO decoder.

+ +
+
+
504class RetroModel(nn.Module):
+
+
+
+
+ +
  • n_vocab + is the number of tokens in the vocabulary
  • +
  • d_model + is the number of features in embeddings
  • +
  • n_layers + is the number of layers in the decoder
  • +
  • ca_layers + are the layers with cross attention
  • +
  • chunk_len + is the length of a chunk
  • +
  • n_heads + is the number of heads in attention layers
  • +
  • d_k + is the size of attention heads
  • +
  • d_ff + is the size of the feed-forward network's hidden layer
  • +
  • encoder + is the nearest neighbor encoder
+ +
+
+
511    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
+512                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
+
+
+
+
+ + +
+
+
524        super().__init__()
+525
+526        self.ca_layers = ca_layers
+527        self.encoder = encoder
+
+
+
+
+ +

Token embedding layer

+ +
+
+
530        self.emb = nn.Embedding(n_vocab, d_model)
+
+
+
+
+ +

Chunked cross attention layers

+ +
+
+
532        self.cca = nn.ModuleList(
+533            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
+
+
+
+
+ +

Attention layers

+ +
+
+
535        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
+
+
+
+
+ +

Feed forward layers

+ +
+
+
537        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
+
+
+
+
+ +

Readout layer

+ +
+
+
539        self.read = nn.Linear(d_model, n_vocab)
+
+
+
+
+ +

Pre-normalization layer for the nearest neighbor embeddings from the encoder

+ +
+
+
543        self.norm_e = nn.LayerNorm(d_model)
+
+
+
+
+ +
  • x + is the input sequence, of shape [batch_size, seq_len] +
  • +
  • ret + are the retrieved neighbors of shape [batch_size, chunks, neighbors, neighbor_len] +
+ +
+
+
545    def forward(self, x: torch.Tensor, ret: torch.Tensor):
+
+
+
+
+ +

Get input embeddings

+ +
+
+
554        h = self.emb(x)
+
+
+
+
+ +

Embeddings of the retrieved neighbors.

+

We use the same embedding layer for both the input and the retrieved neighbors.

+ +
+
+
560        ret_emb = self.emb(ret)
+
+
+
+
+ +

Keep index of the chunked cross attention layer

+ +
+
+
563        p_ca = 0
+
+
+
+
+ +

For all layers

+ +
+
+
565        for p in range(len(self.attn)):
+
+
+
+
+ +

Causal self attention

+ +
+
+
567            h = self.attn[p](h)
+
+
+
+
+ +

Get the encoder embeddings just before the first cross-attention layer, i.e. when p == min(ca_layers)

+ +
+
+
571            if self.ca_layers and p == min(self.ca_layers):
+
+
+
+
+ +

+

We pass the embeddings of the retrieved neighbors, together with the input embeddings, to the encoder.

+ +
+
+
575                e = self.encoder(ret_emb, h)
+
+
+
+
+ +

Normalize encoder embeddings

+ +
+
+
577                e = self.norm_e(e)
+
+
+
+
+ +

Chunked cross-attention if p is in ca_layers

+ +
+
+
580            if p in self.ca_layers:
+
+
+
+
+ +

+ +
+
+
582                h = self.cca[p_ca](h, e)
+
+
+
+
+ +

Increment chunked cross-attention index

+ +
+
+
584                p_ca += 1
+
+
+
+
+ +

+ +
+
+
587            h = self.ffw[p](h)
+
+
+
+
+ +

+ +
+
+
590        return self.read(h)
+
+
+
+
+ +

Test the model with fake data

+ +
+
+
593def _test():
+
+
+
+
+ + +
+
+
597    chunk_len = 4
+598    d_model = 8
+599    d_ff = 32
+600    n_heads = 2
+601    d_k = 4
+602
+603    device = torch.device('cuda:0')
+604
+605    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
+606                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
+607
+608    m.to(device)
+609    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
+610    ret = [
+611        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
+612        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
+613    ]
+614    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
+615
+616    inspect(res)
+
+
+
+
+ +

+ +
+
+
620if __name__ == '__main__':
+621    _test()
+
+
+ +
+ + + + \ No newline at end of file