import math

import torch
import torch.nn as nn

from labml_nn.utils import clone_module_list
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding

Embed tokens and add fixed positional encodings.

class EmbeddingsWithPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe

Embed tokens and add learned (parameterized) positional encodings.

class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe
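As a quick shape check, here is a minimal sketch with arbitrary, hypothetical sizes (the learned variant is used so the sketch does not depend on get_positional_encoding): token ids come in sequence-first as [seq_len, batch_size] and embeddings come out as [seq_len, batch_size, d_model].

d_model, n_vocab = 512, 1000                # hypothetical sizes for illustration
emb = EmbeddingsWithLearnedPositionalEncoding(d_model, n_vocab)

x = torch.randint(0, n_vocab, (10, 4))      # [seq_len=10, batch_size=4] token ids
out = emb(x)                                # scaled embeddings plus positional encodings
assert out.shape == (10, 4, d_model)        # [seq_len, batch_size, d_model]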
This can act as an encoder layer or a decoder layer.

🗒 Some implementations, including the paper, seem to differ in where the layer normalization is done. Here we do a layer normalization before the attention and feed-forward sub-layers, and then add the original residual vectors. The alternative is to do a layer normalization after adding the residuals, but we found this to be less stable when training. A detailed discussion of this can be found in the paper On Layer Normalization in the Transformer Architecture.
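To make the pre-norm vs. post-norm difference concrete, here is a minimal illustrative sketch (not part of the implementation below; sublayer stands for either the self-attention or the feed-forward module):

def pre_norm_block(x, sublayer, norm, dropout):
    # Pre-norm (used here): normalize first, apply the sub-layer, then add the original residual.
    return x + dropout(sublayer(norm(x)))

def post_norm_block(x, sublayer, norm, dropout):
    # Post-norm (as in the original paper): apply the sub-layer, add the residual, then normalize.
    return norm(x + dropout(sublayer(x)))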
class TransformerLayer(nn.Module):

* d_model is the token embedding size
* self_attn is the self attention module
* src_attn is the source attention module (when this is used in a decoder)
* feed_forward is the feed forward module
* dropout_prob is the probability of dropping out after self attention and FFN

    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
Whether to save input to the feed forward layer
        self.is_save_ff_input = False

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor,
                src: torch.Tensor = None,
                src_mask: torch.Tensor = None):
Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
Run through self attention, i.e. keys and values are from self
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
Add the self attention results
        x = x + self.dropout(self_attn)

If a source is provided, get results from attention to source. This is when you have a decoder layer that pays attention to encoder outputs
        if src is not None:
Normalize vectors
            z = self.norm_src_attn(x)
Attention to source, i.e. keys and values are from source
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
Add the source attention results
            x = x + self.dropout(attn_src)

Normalize for feed-forward
        z = self.norm_ff(x)
Save the input to the feed forward layer if specified
        if self.is_save_ff_input:
            self.ff_input = z.clone()
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)

        return x
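As a usage sketch, here is how a single encoder-style layer might be constructed. The MultiHeadAttention(heads, d_model, dropout_prob) and FeedForward(d_model, d_ff, dropout) constructor signatures are assumptions here; check .mha and .feed_forward for the actual ones.

d_model, heads, d_ff, dropout = 512, 8, 2048, 0.1    # hypothetical sizes

encoder_layer = TransformerLayer(
    d_model=d_model,
    self_attn=MultiHeadAttention(heads, d_model, dropout),   # assumed signature
    src_attn=None,                                           # no source attention in an encoder layer
    feed_forward=FeedForward(d_model, d_ff, dropout),        # assumed signature
    dropout_prob=dropout)

A decoder-style layer is built the same way, with an additional MultiHeadAttention module passed as src_attn (see the full assembly sketch at the end of this file).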
The encoder is a stack of copies of the transformer layer above (self attention only), followed by a final layer normalization.

class Encoder(nn.Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=mask)
Finally, normalize the vectors
        return self.norm(x)

The decoder is a similar stack of transformer layers, where each layer also attends to the encoder outputs (memory) through source attention.

class Decoder(nn.Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
Finally, normalize the vectors
        return self.norm(x)
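The tgt_mask given to the decoder is typically a causal (subsequent) mask, so that each position can only attend to itself and earlier positions. Here is a minimal sketch of building one; the mask layout assumed here is [seq_len_q, seq_len_k, batch] with a trailing dimension of 1 that broadcasts over the batch, which should be checked against the conventions in .mha.

def subsequent_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean mask: position i may attend to positions j <= i.
    return torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)

tgt_mask = subsequent_mask(10)    # shape [10, 10, 1], usable as tgt_mask in Decoder.forward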
This projects the decoder output back to the vocabulary to predict the tokens; a log softmax over the result gives log-probabilities. You don't need this if you are using nn.CrossEntropyLoss.

class Generator(nn.Module):
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        return self.projection(x)
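Since forward returns the raw projection scores, a training loss can be computed in either of two equivalent ways. A sketch, where generator, x, targets, and n_vocab are placeholder names:

import torch.nn.functional as F

logits = generator(x)    # [seq_len, batch_size, n_vocab] raw scores

# Option 1: take the log softmax explicitly and use negative log-likelihood.
log_probs = F.log_softmax(logits, dim=-1)
loss = F.nll_loss(log_probs.reshape(-1, n_vocab), targets.reshape(-1))

# Option 2: give the raw scores directly to cross-entropy, which applies log softmax internally.
loss = F.cross_entropy(logits.reshape(-1, n_vocab), targets.reshape(-1))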
The combined encoder-decoder wraps the encoder, the decoder, the source and target embeddings, and the generator.

class EncoderDecoder(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

This was important from their code. Initialize parameters with Glorot / fan_avg.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
Run the source through the encoder
        enc = self.encode(src, src_mask)
Run the encodings and targets through the decoder
        return self.decode(enc, src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
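Putting the pieces together, here is a sketch of assembling a full model. The MultiHeadAttention and FeedForward constructor arguments are assumptions (check .mha and .feed_forward); everything else uses the classes defined in this file. Since Encoder and Decoder clone the layer they are given, a single prototype layer of each kind is enough.

def make_model(src_vocab: int, tgt_vocab: int,
               d_model: int = 512, heads: int = 8, d_ff: int = 2048,
               n_layers: int = 6, dropout: float = 0.1) -> EncoderDecoder:
    # Encoder-style layer: self attention only.
    encoder_layer = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(heads, d_model, dropout),
                                     src_attn=None,
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    # Decoder-style layer: self attention plus source attention over encoder outputs.
    decoder_layer = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(heads, d_model, dropout),
                                     src_attn=MultiHeadAttention(heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    return EncoderDecoder(Encoder(encoder_layer, n_layers),
                          Decoder(decoder_layer, n_layers),
                          EmbeddingsWithPositionalEncoding(d_model, src_vocab),
                          EmbeddingsWithPositionalEncoding(d_model, tgt_vocab),
                          Generator(tgt_vocab, d_model))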