import math

import torch
import torch.nn as nn
from labml_helpers.module import Module

from labml_nn.utils import clone_module_list
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding


class EmbeddingsWithPositionalEncoding(Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe


class EmbeddingsWithLearnedPositionalEncoding(Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe
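
As a quick shape check, here is a minimal usage sketch of the embeddings above; the vocabulary size, sequence length and batch size are made up for illustration. Token indices of shape [seq_len, batch_size] are embedded, scaled by sqrt(d_model), and the positional encodings (shape [seq_len, 1, d_model]) broadcast over the batch dimension:

d_model, n_vocab = 512, 10000
emb = EmbeddingsWithPositionalEncoding(d_model, n_vocab)
x = torch.randint(0, n_vocab, (20, 4))  # token indices, [seq_len, batch_size]
y = emb(x)                              # [seq_len, batch_size, d_model]
assert y.shape == (20, 4, d_model)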
This can act as an encoder layer or a decoder layer.

🗒 Some implementations, including the one in the paper, seem to differ in where the layer normalization is done. Here we apply layer normalization before the attention and feed-forward networks, and then add the original residual vectors. The alternative is to apply layer normalization after adding the residuals, but we found that to be less stable when training. There is a detailed discussion about this in the paper On Layer Normalization in the Transformer Architecture.
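
To make the difference concrete, here is an illustrative sketch of the two orderings for a single residual sub-layer; sublayer stands for either attention or the feed-forward network, and these helper functions are hypothetical, not part of the module below.

def pre_norm_residual(x, sublayer, norm, dropout):
    # Pre-norm (used here): normalize, apply the sub-layer, then add the residual
    return x + dropout(sublayer(norm(x)))

def post_norm_residual(x, sublayer, norm, dropout):
    # Post-norm (as in the original paper): apply the sub-layer, add the residual, then normalize
    return norm(x + dropout(sublayer(x)))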

class TransformerLayer(Module):

d_model is the token embedding size
self_attn is the self attention module
src_attn is the source attention module (when this is used in a decoder)
feed_forward is the feed forward module
dropout_prob is the probability of dropping out after self attention and FFN

    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

Whether to save input to the feed forward layer
        self.is_save_ff_input = False

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor,
                src: torch.Tensor = None,
                src_mask: torch.Tensor = None):

Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
Run through self attention, i.e. keys and values are from self
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
Add the self attention results
        x = x + self.dropout(self_attn)

If a source is provided, get results from attention to the source. This is when you have a decoder layer that pays attention to the encoder outputs.
        if src is not None:
Normalize the vectors
            z = self.norm_src_attn(x)
Attention to the source, i.e. keys and values are from the source
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
Add the source attention results
            x = x + self.dropout(attn_src)

Normalize for the feed-forward network
        z = self.norm_ff(x)
Save the input to the feed-forward layer if specified
        if self.is_save_ff_input:
            self.ff_input = z.clone()
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)

        return x
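
For instance, an encoder-style layer is built without src_attn, while a decoder-style layer supplies it. A rough construction sketch follows; the MultiHeadAttention and FeedForward constructor arguments are assumptions (check mha.py and feed_forward.py), and the hyper-parameters are illustrative.

d_model, heads, d_ff, dropout = 512, 8, 2048, 0.1

# Encoder-style layer: self attention only (constructor arguments assumed)
enc_layer = TransformerLayer(d_model=d_model,
                             self_attn=MultiHeadAttention(heads, d_model, dropout),
                             src_attn=None,
                             feed_forward=FeedForward(d_model, d_ff, dropout),
                             dropout_prob=dropout)

# Decoder-style layer: also attends to the encoder output through src_attn
dec_layer = TransformerLayer(d_model=d_model,
                             self_attn=MultiHeadAttention(heads, d_model, dropout),
                             src_attn=MultiHeadAttention(heads, d_model, dropout),
                             feed_forward=FeedForward(d_model, d_ff, dropout),
                             dropout_prob=dropout)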

class Encoder(Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=mask)
Finally, normalize the vectors
        return self.norm(x)

class Decoder(Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
Finally, normalize the vectors
        return self.norm(x)
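
Here tgt_mask is typically a subsequent (causal) mask so that each position only attends to itself and earlier positions. A minimal sketch of building one follows; the exact mask shape expected by the attention module is an assumption here.

seq_len = 10
# Lower-triangular boolean matrix: position i may attend to positions j <= i
tgt_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Add a trailing dimension if the attention module expects a
# [seq_len, seq_len, batch] shaped mask that broadcasts over the batch (assumed)
tgt_mask = tgt_mask.unsqueeze(-1)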
This predicts the tokens and gives the log softmax of those. You don’t need this if you are using nn.CrossEntropyLoss.

class Generator(Module):
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        return self.projection(x)
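
Since forward returns the raw projection, a training step would either apply a log softmax explicitly or pass the logits straight to nn.CrossEntropyLoss. A brief sketch with illustrative shapes:

n_vocab, d_model = 10000, 512
generator = Generator(n_vocab, d_model)
x = torch.randn(20, 4, d_model)                # decoder output, [seq_len, batch_size, d_model]
targets = torch.randint(0, n_vocab, (20, 4))   # target token indices
logits = generator(x)                          # [seq_len, batch_size, n_vocab]
log_probs = torch.log_softmax(logits, dim=-1)  # explicit log softmax, if you need it
loss = nn.CrossEntropyLoss()(logits.reshape(-1, n_vocab), targets.reshape(-1))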

class EncoderDecoder(Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

This was important from their code. Initialize parameters with Glorot / fan_avg.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
Run the source through encoder
        enc = self.encode(src, src_mask)
Run encodings and targets through decoder
        return self.decode(enc, src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
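
Putting it all together, here is a rough sketch of assembling the full model. As before, the MultiHeadAttention and FeedForward constructor arguments are assumptions, and the hyper-parameters are illustrative; the Encoder and Decoder clone their layer n_layers times internally.

d_model, heads, d_ff, n_layers, dropout = 512, 8, 2048, 6, 0.1
src_vocab, tgt_vocab = 10000, 10000

def attn():
    # Constructor arguments assumed; see mha.py
    return MultiHeadAttention(heads, d_model, dropout)

def ff():
    # Constructor arguments assumed; see feed_forward.py
    return FeedForward(d_model, d_ff, dropout)

model = EncoderDecoder(
    Encoder(TransformerLayer(d_model=d_model, self_attn=attn(), src_attn=None,
                             feed_forward=ff(), dropout_prob=dropout), n_layers),
    Decoder(TransformerLayer(d_model=d_model, self_attn=attn(), src_attn=attn(),
                             feed_forward=ff(), dropout_prob=dropout), n_layers),
    EmbeddingsWithPositionalEncoding(d_model, src_vocab),
    EmbeddingsWithPositionalEncoding(d_model, tgt_vocab),
    Generator(tgt_vocab, d_model),
)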