""" --- title: Configurable Transformer Components summary: These are configurable components that can be re-used quite easily. --- # Configurable Transformer Components """ import copy import torch.nn as nn from labml.configs import BaseConfigs, option, calculate, aggregate from .feed_forward import FeedForward from .mha import MultiHeadAttention from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \ Encoder, Decoder, Generator, EncoderDecoder class FeedForwardConfigs(BaseConfigs): """ ## FFN Configurations Creates a Position-wise FeedForward Network defined in [`feed_forward.py`](feed_forward.html). """ # Position-wise feedforward layer ffn: FeedForward # Number of features in the embedding d_model: int # Number of features in in the hidden layer d_ff: int = 2048 # Dropout probability dropout: float = 0.1 # Activation in position-wise feedforward layer activation: nn.Module = 'ReLU' # Whether the FFN layer should be gated is_gated: bool = False # Whether the first fully connected layer should have a learnable bias bias1: bool = True # Whether the second fully connected layer should have a learnable bias bias2: bool = True # Whether the fully connected layer for the gate should have a learnable bias bias_gate: bool = False # Predefined GLU variants glu_variant: str = 'none' @option(FeedForwardConfigs.activation, 'ReLU') def _ffn_activation_relu(): """ ### ReLU activation $$\max(0, x)$$ """ return nn.ReLU() @option(FeedForwardConfigs.activation, 'GELU') def _ffn_activation_gelu(): """ ### GELU activation $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$ It was introduced in paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415). """ return nn.GELU() @option(FeedForwardConfigs.ffn, 'default') def _feed_forward(c: FeedForwardConfigs): """ Initialize a [feed forward network](feed_forward.html) """ return FeedForward(c.d_model, c.d_ff, dropout=c.dropout, activation=c.activation, is_gated=c.is_gated, bias1=c.bias1, bias2=c.bias2, bias_gate=c.bias_gate) # ## GLU Variants # These are variants with gated hidden layers for the FFN # as introduced in paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202). # We have omitted the bias terms as specified in the paper. # ### FFN with Gated Linear Units # # $$FFN_{GLU}(x)(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$ aggregate(FeedForwardConfigs.glu_variant, 'GLU', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), (FeedForwardConfigs.bias2, False), (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.Sigmoid())) # ### FFN with Bilinear hidden layer # # $$FFN_{Bilinear}(x)(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$ aggregate(FeedForwardConfigs.glu_variant, 'Bilinear', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), (FeedForwardConfigs.bias2, False), (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.Identity())) # ### FFN with ReLU gate # # $$FFN_{ReGLU}(x)(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$ aggregate(FeedForwardConfigs.glu_variant, 'ReGLU', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), (FeedForwardConfigs.bias2, False), (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.ReLU())) # ### FFN with GELU gate # # $$FFN_{GEGLU}(x)(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$ aggregate(FeedForwardConfigs.glu_variant, 'GEGLU', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), (FeedForwardConfigs.bias2, False), (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.GELU())) # ### FFN with Swish gate # # $$FFN_{SwiGLU}(x)(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$ # where $\text{Swish}_\beta(x) = x \sigma(\beta x)$ aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), (FeedForwardConfigs.bias2, False), (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.SiLU())) class TransformerConfigs(BaseConfigs): """ ## Transformer Configurations This defines configurations for a transformer. The configurations are calculate using option functions. These are lazy loaded and therefore only the necessary modules are calculated. """ # Number of attention heads n_heads: int = 8 # Transformer embedding size d_model: int = 512 # Number of layers n_layers: int = 6 # Dropout probability dropout: float = 0.1 # Number of tokens in the source vocabulary (for token embeddings) n_src_vocab: int # Number of tokens in the target vocabulary (to generate logits for prediction) n_tgt_vocab: int # The encoder self attention encoder_attn: MultiHeadAttention = 'mha' # The decoder self attention decoder_attn: MultiHeadAttention = 'mha' # The decoder memory attention decoder_mem_attn: MultiHeadAttention = 'mha' # Configurable Feedforward Layer ffn: FeedForwardConfigs # Encoder layer encoder_layer: TransformerLayer = 'default' # Decoder layer decoder_layer: TransformerLayer = 'default' # Encoder consisting of multiple encoder layers encoder: Encoder = 'default' # Encoder consisting of multiple decoder layers decoder: Decoder = 'default' # Embedding layer for source src_embed: nn.Module = 'fixed_pos' # Embedding layer for target (for decoder) tgt_embed: nn.Module = 'fixed_pos' # Logit generator for prediction generator: Generator = 'default' # Encoder-decoder encoder_decoder: EncoderDecoder # ### Multi-head Attention def _mha(c: TransformerConfigs): return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout) calculate(TransformerConfigs.encoder_attn, 'mha', _mha) calculate(TransformerConfigs.decoder_attn, 'mha', _mha) calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha) # ### Relative Multi-head Attention def _relative_mha(c: TransformerConfigs): from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention return RelativeMultiHeadAttention(c.n_heads, c.d_model) calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha) calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha) calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha) @option(TransformerConfigs.ffn, 'default') def _feed_forward(c: TransformerConfigs): """ Create feedforward layer configurations """ conf = FeedForwardConfigs() conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model) conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout) return conf @option(TransformerConfigs.encoder_layer, 'default') def _encoder_layer(c: TransformerConfigs): """ Encoder layer """ return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn, src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn), dropout_prob=c.dropout) @option(TransformerConfigs.decoder_layer, 'default') def _decoder_layer(c: TransformerConfigs): """ Decoder layer """ return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn, src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn), dropout_prob=c.dropout) @option(TransformerConfigs.encoder, 'default') def _encoder(c: TransformerConfigs): """ Encoder """ return Encoder(c.encoder_layer, c.n_layers) @option(TransformerConfigs.decoder, 'default') def _decoder(c: TransformerConfigs): """ Decoder """ return Decoder(c.decoder_layer, c.n_layers) @option(TransformerConfigs.generator, 'default') def _generator(c: TransformerConfigs): """ Logit generator """ return Generator(c.n_tgt_vocab, c.d_model) # ### Fixed Positional Embeddings @option(TransformerConfigs.src_embed, 'fixed_pos') def _src_embed_with_positional(c: TransformerConfigs): """ Source embedding with fixed positional encodings """ return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab) @option(TransformerConfigs.tgt_embed, 'fixed_pos') def _tgt_embed_with_positional(c: TransformerConfigs): """ Target embedding with fixed positional encodings """ return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab) # ### Learned Positional Embeddings @option(TransformerConfigs.src_embed, 'learned_pos') def _src_embed_with_learned_positional(c: TransformerConfigs): """ Source embedding with learned positional encodings """ return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab) @option(TransformerConfigs.tgt_embed, 'learned_pos') def _tgt_embed_with_learned_positional(c: TransformerConfigs): """ Target embedding with learned positional encodings """ return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab) # ### No Positional Embeddings @option(TransformerConfigs.src_embed, 'no_pos') def _src_embed_without_positional(c: TransformerConfigs): """ Source embedding without positional encodings """ return nn.Embedding(c.n_src_vocab, c.d_model) @option(TransformerConfigs.tgt_embed, 'no_pos') def _tgt_embed_without_positional(c: TransformerConfigs): return nn.Embedding(c.n_tgt_vocab, c.d_model) @option(TransformerConfigs.encoder_decoder, 'default') def _encoder_decoder(c: TransformerConfigs): return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)