Transformer Encoder and Decoder Models


import math

import torch
import torch.nn as nn
from labml_helpers.module import Module

from labml_nn.utils import clone_module_list
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding

Embed tokens and add fixed positional encoding

class EmbeddingsWithPositionalEncoding(Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe
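Here is a minimal usage sketch, run in the context of this module (the sizes are made up for illustration; token indices are expected as [seq_len, batch_size]):

# Illustrative sizes, not taken from the implementation
d_model, n_vocab, seq_len, batch_size = 512, 10_000, 20, 4

emb = EmbeddingsWithPositionalEncoding(d_model, n_vocab)
tokens = torch.randint(0, n_vocab, (seq_len, batch_size))  # token indices, [seq_len, batch_size]
x = emb(tokens)                                            # scaled embeddings plus positional encodings
assert x.shape == (seq_len, batch_size, d_model)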

Embed tokens and add parameterized positional encodings

class EmbeddingsWithLearnedPositionalEncoding(Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe
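The only change from the fixed version is that the encodings are trained. A quick sketch of where that difference shows up (sizes are arbitrary): the fixed encodings are a buffer, while the learned encodings are a parameter the optimizer will update.

fixed = EmbeddingsWithPositionalEncoding(d_model=128, n_vocab=1_000)
learned = EmbeddingsWithLearnedPositionalEncoding(d_model=128, n_vocab=1_000)

# The fixed encodings are a buffer: saved with the model but not trained
assert 'positional_encodings' in dict(fixed.named_buffers())
# The learned encodings are a parameter and appear in .parameters()
assert 'positional_encodings' in dict(learned.named_parameters())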

Transformer Layer

This can act as an encoder layer or a decoder layer.

🗒 Some implementations, including the paper, seem to differ in where layer normalization is applied. Here we apply layer normalization before the attention and feed-forward networks and then add the original residual vectors (pre-norm). The alternative is to apply layer normalization after adding the residuals (post-norm), but we found that to be less stable when training. There is a detailed discussion of this in the paper On Layer Normalization in the Transformer Architecture.
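To make the two orderings concrete, here is a small sketch. The sublayer argument stands in for self attention or the feed-forward network, dropout is omitted, and the function names are purely illustrative:

def pre_norm_step(x, sublayer, norm: nn.LayerNorm):
    # What this implementation does: normalize, apply the sub-layer, then add the residual
    return x + sublayer(norm(x))

def post_norm_step(x, sublayer, norm: nn.LayerNorm):
    # What the original paper describes: apply the sub-layer, add the residual, then normalize
    return norm(x + sublayer(x))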

class TransformerLayer(Module):
  • d_model is the token embedding size
  • self_attn is the self attention module
  • src_attn is the source attention module (when this is used in a decoder)
  • feed_forward is the feed forward module
  • dropout_prob is the probability of dropping out after self attention and FFN
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

Whether to save input to the feed forward layer

        self.is_save_ff_input = False

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor,
                src: torch.Tensor = None,
                src_mask: torch.Tensor = None):

Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

Run through self attention, i.e. keys and values are from self

        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)

Add the self attention results

        x = x + self.dropout(self_attn)

If a source is provided, get the results of attention over the source. This is the case when a decoder layer attends to the encoder outputs.

        if src is not None:

Normalize vectors

            z = self.norm_src_attn(x)

Attention to source, i.e. keys and values are from the source (encoder outputs)

            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)

Add the source attention results

            x = x + self.dropout(attn_src)

Normalize for feed-forward

        z = self.norm_ff(x)

Save the input to the feed forward layer if specified

        if self.is_save_ff_input:
            self.ff_input = z.clone()

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x
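A usage sketch for an encoder-style layer (self attention only). This assumes the MultiHeadAttention and FeedForward constructors in this repository take (heads, d_model, dropout_prob) and (d_model, d_ff, dropout) respectively, that inputs have shape [seq_len, batch_size, d_model], and that masks follow a broadcastable [query_len, key_len, batch_size] convention:

d_model, heads, d_ff, seq_len, batch_size = 512, 8, 2048, 10, 2

layer = TransformerLayer(d_model=d_model,
                         self_attn=MultiHeadAttention(heads, d_model, 0.1),
                         feed_forward=FeedForward(d_model, d_ff, 0.1),
                         dropout_prob=0.1)

x = torch.randn(seq_len, batch_size, d_model)
# All-ones mask lets every position attend to every other position;
# use a lower-triangular mask for causal (decoder-style) self attention
mask = torch.ones(seq_len, seq_len, 1).to(torch.bool)
out = layer(x=x, mask=mask)
assert out.shape == x.shape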

Transformer Encoder

class Encoder(Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):

Run through each transformer layer

        for layer in self.layers:
            x = layer(x=x, mask=mask)

Finally, normalize the vectors

        return self.norm(x)
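Stacking the layer into an encoder, under the same assumptions about constructor signatures and shapes as in the sketch above:

d_model, heads, d_ff, n_layers = 512, 8, 2048, 6
seq_len, batch_size = 10, 2

encoder_layer = TransformerLayer(d_model=d_model,
                                 self_attn=MultiHeadAttention(heads, d_model, 0.1),
                                 feed_forward=FeedForward(d_model, d_ff, 0.1),
                                 dropout_prob=0.1)
encoder = Encoder(encoder_layer, n_layers)

src = torch.randn(seq_len, batch_size, d_model)            # already embedded source
src_mask = torch.ones(seq_len, seq_len, 1).to(torch.bool)  # no padding in this toy example
memory = encoder(src, src_mask)                            # [seq_len, batch_size, d_model]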

Transformer Decoder

class Decoder(Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):

Run through each transformer layer

        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)

Finally, normalize the vectors

        return self.norm(x)
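And a sketch for the decoder side. The causal mask is built with torch.tril so that target position i only attends to positions up to i; the [query_len, key_len, batch_size] mask shape is the same assumption as above.

d_model, heads, d_ff, n_layers = 512, 8, 2048, 6
src_len, tgt_len, batch_size = 10, 7, 2

decoder_layer = TransformerLayer(d_model=d_model,
                                 self_attn=MultiHeadAttention(heads, d_model, 0.1),
                                 src_attn=MultiHeadAttention(heads, d_model, 0.1),
                                 feed_forward=FeedForward(d_model, d_ff, 0.1),
                                 dropout_prob=0.1)
decoder = Decoder(decoder_layer, n_layers)

memory = torch.randn(src_len, batch_size, d_model)   # encoder outputs
tgt = torch.randn(tgt_len, batch_size, d_model)      # already embedded target
# Causal mask: target position i attends only to positions <= i
tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).to(torch.bool).unsqueeze(-1)
# Every target position may attend to every source position
src_mask = torch.ones(tgt_len, src_len, 1).to(torch.bool)

out = decoder(tgt, memory, src_mask, tgt_mask)       # [tgt_len, batch_size, d_model]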

Generator

This projects the decoder output to logits over the vocabulary, which you can pass through a log softmax to predict tokens. You don't need the log softmax if you are using nn.CrossEntropyLoss, since it expects the raw logits.

class Generator(Module):
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        return self.projection(x)
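A sketch of feeding the generator output to a loss (sizes are illustrative). nn.CrossEntropyLoss works on the raw logits, so no softmax is applied here:

d_model, n_vocab, tgt_len, batch_size = 512, 10_000, 7, 2

generator = Generator(n_vocab, d_model)
decoder_out = torch.randn(tgt_len, batch_size, d_model)  # output of the Decoder
logits = generator(decoder_out)                          # [tgt_len, batch_size, n_vocab]

targets = torch.randint(0, n_vocab, (tgt_len, batch_size))
loss = nn.CrossEntropyLoss()(logits.reshape(-1, n_vocab), targets.reshape(-1))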

Combined Encoder-Decoder

class EncoderDecoder(Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

This was important in their code: initialize parameters with Glorot / fan_avg (Xavier uniform).

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):

Run the source through encoder

        enc = self.encode(src, src_mask)

Run encodings and targets through decoder

        return self.decode(enc, src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
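Finally, a sketch that wires everything into a full model, under the same assumed constructor signatures and mask convention as in the earlier sketches. Because the same source mask is used for both the encoder self attention and the decoder cross attention, it is given a query length of 1 so it broadcasts over either query length:

d_model, heads, d_ff, n_layers, dropout = 512, 8, 2048, 6, 0.1
src_vocab, tgt_vocab = 10_000, 10_000

def make_layer(with_src_attn: bool):
    # Encoder layers have no source attention; decoder layers do
    return TransformerLayer(
        d_model=d_model,
        self_attn=MultiHeadAttention(heads, d_model, dropout),
        src_attn=MultiHeadAttention(heads, d_model, dropout) if with_src_attn else None,
        feed_forward=FeedForward(d_model, d_ff, dropout),
        dropout_prob=dropout)

model = EncoderDecoder(encoder=Encoder(make_layer(with_src_attn=False), n_layers),
                       decoder=Decoder(make_layer(with_src_attn=True), n_layers),
                       src_embed=EmbeddingsWithPositionalEncoding(d_model, src_vocab),
                       tgt_embed=EmbeddingsWithPositionalEncoding(d_model, tgt_vocab),
                       generator=Generator(tgt_vocab, d_model))

src = torch.randint(0, src_vocab, (10, 2))                   # [src_len, batch_size]
tgt = torch.randint(0, tgt_vocab, (7, 2))                    # [tgt_len, batch_size]
src_mask = torch.ones(1, 10, 1).to(torch.bool)               # broadcasts over query positions
tgt_mask = torch.tril(torch.ones(7, 7)).to(torch.bool).unsqueeze(-1)

out = model(src, tgt, src_mask, tgt_mask)                    # [tgt_len, batch_size, d_model]
logits = model.generator(out)                                # [tgt_len, batch_size, tgt_vocab]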