| """
 | |
| ---
 | |
| title: Configurable Transformer Components
 | |
| summary: These are configurable components that can be re-used quite easily.
 | |
| ---
 | |
| 
 | |
| # Configurable Transformer Components
 | |
| """
 | |
| import copy
 | |
| 
 | |
| import torch.nn as nn
 | |
| 
 | |
| from labml.configs import BaseConfigs, option, calculate, aggregate
 | |
| from labml_helpers.module import Module
 | |
| from .feed_forward import FeedForward
 | |
| from .mha import MultiHeadAttention
 | |
| from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
 | |
|     Encoder, Decoder, Generator, EncoderDecoder
 | |
| 
 | |
| 
 | |
| class FeedForwardConfigs(BaseConfigs):
 | |
|     """
 | |
|     <a id="FFN">
 | |
|     ## FFN Configurations
 | |
|     </a>
 | |
| 
 | |
|     Creates a Position-wise FeedForward Network defined in
 | |
|     [`feed_forward.py`](feed_forward.html).
 | |
|     """
 | |
|     # Position-wise feedforward layer
 | |
|     ffn: FeedForward
 | |
|     # Number of features in the embedding
 | |
|     d_model: int
 | |
    # Number of features in the hidden layer
    d_ff: int = 2048
    # Dropout probability
    dropout: float = 0.1
    # Activation in position-wise feedforward layer
    activation: nn.Module = 'ReLU'
    # Whether the FFN layer should be gated
    is_gated: bool = False
    # Whether the first fully connected layer should have a learnable bias
    bias1: bool = True
    # Whether the second fully connected layer should have a learnable bias
    bias2: bool = True
    # Whether the fully connected layer for the gate should have a learnable bias
    bias_gate: bool = False
    # Predefined GLU variants
    glu_variant: str = 'none'


@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    """
    ### ReLU activation

    $$\max(0, x)$$
    """
    return nn.ReLU()


@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    """
    ### GELU activation

    $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$

    It was introduced in the paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
    """
    return nn.GELU()


@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    """
    Initialize a [feed forward network](feed_forward.html)
    """
    return FeedForward(c.d_model, c.d_ff,
                       dropout=c.dropout,
                       activation=c.activation,
                       is_gated=c.is_gated,
                       bias1=c.bias1,
                       bias2=c.bias2,
                       bias_gate=c.bias_gate)

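
# A minimal usage sketch (not part of the original module; the experiment name
# and the values below are illustrative). Options are selected by name: `'GELU'`
# picks `_ffn_activation_gelu` above, and `ffn` is built lazily by the
# `'default'` option `_feed_forward` when it is first accessed.
def _ffn_configs_usage_sketch():
    from labml import experiment

    experiment.create(name='ffn_configs_sketch')
    conf = FeedForwardConfigs()
    # Override a few options; anything not listed keeps its default
    experiment.configs(conf, {
        'd_model': 512,
        'd_ff': 2048,
        'activation': 'GELU',
    })
    # Accessing `conf.ffn` triggers the lazy calculation of the module
    return conf.ffn
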
# ## GLU Variants
# These are variants with gated hidden layers for the FFN
# as introduced in the paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202).
# We have omitted the bias terms as specified in the paper.

# ### FFN with Gated Linear Units
#
# $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))

# ### FFN with Bilinear hidden layer
#
# $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))

# ### FFN with ReLU gate
#
# $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))

# ### FFN with GELU gate
#
# $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))

# ### FFN with Swish gate
#
# $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.SiLU()))

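# Picking a GLU variant by name is shorthand for setting the five options it
# aggregates. A sketch of the equivalence (illustrative, not part of the
# original module; `conf` is a `FeedForwardConfigs` being configured through
# `labml.experiment.configs`):
#
#     # via the aggregate
#     experiment.configs(conf, {'d_model': 512, 'glu_variant': 'SwiGLU'})
#
#     # spelled out by hand
#     experiment.configs(conf, {'d_model': 512,
#                               'is_gated': True,
#                               'bias1': False, 'bias2': False, 'bias_gate': False,
#                               'activation': nn.SiLU()})
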

class TransformerConfigs(BaseConfigs):
    """
    <a id="TransformerConfigs">
    ## Transformer Configurations
    </a>

    This defines configurations for a transformer.
    The configurations are calculated using option functions.
    These are evaluated lazily, so only the necessary modules
    are calculated.
|     """
 | |
|     # Number of attention heads
 | |
|     n_heads: int = 8
 | |
|     # Transformer embedding size
 | |
|     d_model: int = 512
 | |
|     # Number of layers
 | |
|     n_layers: int = 6
 | |
|     # Dropout probability
 | |
|     dropout: float = 0.1
 | |
|     # Number of tokens in the source vocabulary (for token embeddings)
 | |
|     n_src_vocab: int
 | |
|     # Number of tokens in the target vocabulary (to generate logits for prediction)
 | |
|     n_tgt_vocab: int
 | |
| 
 | |
|     # The encoder self attention
 | |
|     encoder_attn: MultiHeadAttention = 'mha'
 | |
|     # The decoder self attention
 | |
|     decoder_attn: MultiHeadAttention = 'mha'
 | |
|     # The decoder memory attention
 | |
|     decoder_mem_attn: MultiHeadAttention = 'mha'
 | |
| 
 | |
|     # Configurable Feedforward Layer
 | |
|     ffn: FeedForwardConfigs
 | |
| 
 | |
|     # Encoder layer
 | |
|     encoder_layer: TransformerLayer = 'default'
 | |
|     # Decoder layer
 | |
|     decoder_layer: TransformerLayer = 'default'
 | |
| 
 | |
|     # Encoder consisting of multiple encoder layers
 | |
|     encoder: Encoder = 'default'
 | |
    # Decoder consisting of multiple decoder layers
    decoder: Decoder = 'default'

    # Embedding layer for source
    src_embed: Module = 'fixed_pos'
    # Embedding layer for target (for decoder)
    tgt_embed: Module = 'fixed_pos'

    # Logit generator for prediction
    generator: Generator = 'default'

    # Encoder-decoder
    encoder_decoder: EncoderDecoder

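
# The string defaults above ('mha', 'default', 'fixed_pos', ...) name option
# functions like the ones defined below, and alternatives can be registered with
# the same `@option` decorator. An illustrative sketch (the option name
# 'half_width_ffn' is hypothetical, not part of the original module):
#
#     @option(TransformerConfigs.ffn, 'half_width_ffn')
#     def _half_width_ffn(c: TransformerConfigs):
#         conf = FeedForwardConfigs()
#         conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
#         conf.set_default(FeedForwardConfigs.d_ff, func=lambda: c.d_model * 2)
#         conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
#         return conf
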

# ### Multi-head Attention
def _mha(c: TransformerConfigs):
    return MultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)


# ### Relative Multi-head Attention
def _relative_mha(c: TransformerConfigs):
    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    return RelativeMultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)

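# Since the three attention configs default to 'mha', switching to the relative
# multi-head attention registered above is just a matter of overriding them by
# name. A sketch in the usual labml flow, after `experiment.create(...)`
# (illustrative, not part of the original module; the vocabulary sizes are
# placeholders):
#
#     conf = TransformerConfigs()
#     experiment.configs(conf, {
#         'n_src_vocab': 10000, 'n_tgt_vocab': 10000,  # required, no defaults
#         'encoder_attn': 'relative',
#         'decoder_attn': 'relative',
#         'decoder_mem_attn': 'relative',
#     })
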

@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    """
    Create feedforward layer configurations
    """
    conf = FeedForwardConfigs()
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf

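# Because `ffn` is itself a `FeedForwardConfigs`, its options can be overridden
# through the parent transformer config. A sketch, assuming labml's dotted-key
# overrides for nested configs (keys and values here are illustrative):
#
#     experiment.configs(conf, {
#         'ffn.glu_variant': 'GEGLU',  # GELU-gated FFN variant
#         'ffn.d_ff': 1024,            # narrower hidden layer
#     })
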

@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    """
    Encoder layer
    """
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    """
    Decoder layer
    """
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    """
    Encoder
    """
    return Encoder(c.encoder_layer, c.n_layers)


@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    """
    Decoder
    """
    return Decoder(c.decoder_layer, c.n_layers)


@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    """
    Logit generator
    """
    return Generator(c.n_tgt_vocab, c.d_model)


# ### Fixed Positional Embeddings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    """
    Source embedding with fixed positional encodings
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    """
    Target embedding with fixed positional encodings
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ### Learned Positional Embeddings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    """
    Source embedding with learned positional encodings
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    """
    Target embedding with learned positional encodings
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ### No Positional Embeddings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    """
    Source embedding without positional encodings
    """
    return nn.Embedding(c.n_src_vocab, c.d_model)


@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    """
    Target embedding without positional encodings
    """
    return nn.Embedding(c.n_tgt_vocab, c.d_model)


@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    """
    Encoder-decoder
    """
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
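

# A minimal end-to-end sketch of these configurations (not part of the original
# module; the experiment name, vocabulary sizes and overrides are illustrative).
# Accessing `conf.encoder_decoder` lazily builds the embeddings, encoder,
# decoder and generator selected by the options above.
def _transformer_configs_usage_sketch():
    from labml import experiment

    experiment.create(name='transformer_configs_sketch')
    conf = TransformerConfigs()
    experiment.configs(conf, {
        # Vocabulary sizes have no defaults, so they must be given
        'n_src_vocab': 10000,
        'n_tgt_vocab': 10000,
        # Use learned positional embeddings instead of the fixed default
        'src_embed': 'learned_pos',
        'tgt_embed': 'learned_pos',
    })
    # The full encoder-decoder model, built from the 'default' options above
    return conf.encoder_decoder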
