import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from labml_helpers.module import Module
from labml_nn.utils import clone_module_list

from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding


class EmbeddingsWithPositionalEncoding(Module):
    """
    ## Embed tokens and add [fixed positional encoding](positional_encoding.html)
    """

    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def __call__(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe


class EmbeddingsWithLearnedPositionalEncoding(Module):
    """
    ## Embed tokens and add parameterized positional encodings
    """

    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def __call__(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe


class FeedForward(Module):
    """
    ## Position-wise feed-forward network with a hidden layer
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x: torch.Tensor):
        x = self.layer1(x)
        x = F.relu(x)
        x = self.dropout(x)
        return self.layer2(x)


class TransformerLayer(Module):
    """
    ## Transformer Layer

    This can act as an encoder layer or a decoder layer.

    🗒 Some implementations, including the paper, seem to differ in where the
    layer normalization is done. Here we apply layer normalization before the
    attention and feed-forward networks, and add the original residual vectors.
    The alternative is to apply layer normalization after adding the residuals,
    but we found that to be less stable during training. A detailed discussion
    of this can be found in the paper
    [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745).
    """

    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def __call__(self, *,
                 x: torch.Tensor,
                 mask: torch.Tensor,
                 src: torch.Tensor = None,
                 src_mask: torch.Tensor = None):
        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Run through self attention, i.e. keys and values are from self
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
        # Add the self attention results
        x = x + self.dropout(self_attn)

        # If a source is provided, get results from attention to source.
        # This is when you have a decoder layer that pays attention to
        # encoder outputs.
        if src is not None:
            # Normalize vectors
            z = self.norm_src_attn(x)
            # Attention to source, i.e. keys and values are from the source
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
            # Add the source attention results
            x = x + self.dropout(attn_src)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x


class Encoder(Module):
    """
    ## Transformer Encoder
    """

    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
        # Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=mask)
        # Finally, normalize the vectors
        return self.norm(x)


class Decoder(Module):
    """
    ## Transformer Decoder
    """

    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        # Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
        # Finally, normalize the vectors
        return self.norm(x)


class Generator(Module):
    """
    ## Generator

    This predicts the tokens and gives the log softmax of those.
    You don't need this if you are using `nn.CrossEntropyLoss`.
    """

    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def __call__(self, x):
        return F.log_softmax(self.projection(x), dim=-1)


class EncoderDecoder(Module):
    """
    ## Combined Encoder-Decoder
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

        # This was important from their code.
        # Initialize parameters with Glorot / fan_avg.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        # Run the source through the encoder
        enc = self.encode(src, src_mask)
        # Run the encodings and targets through the decoder
        return self.decode(enc, src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
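

# A minimal usage sketch (not part of the original module): wiring the components
# above into a full encoder-decoder model. The `MultiHeadAttention(heads, d_model,
# dropout_prob)` constructor signature, the helper name `_build_transformer`, and the
# default hyper-parameter values are assumptions for illustration; adjust them to
# match the actual `.mha` implementation and your task.
def _build_transformer(n_src_vocab: int, n_tgt_vocab: int, *,
                       d_model: int = 512, n_heads: int = 8, d_ff: int = 2048,
                       n_layers: int = 6, dropout: float = 0.1) -> EncoderDecoder:
    # Encoder layer: self attention only (no source attention)
    encoder_layer = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    # Decoder layer: self attention plus attention over the encoder outputs
    decoder_layer = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     src_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    # `Encoder` and `Decoder` clone the given layer `n_layers` times internally
    return EncoderDecoder(Encoder(encoder_layer, n_layers),
                          Decoder(decoder_layer, n_layers),
                          EmbeddingsWithPositionalEncoding(d_model, n_src_vocab),
                          EmbeddingsWithPositionalEncoding(d_model, n_tgt_vocab),
                          Generator(n_tgt_vocab, d_model))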