"""
|
|
---
|
|
title: Configurable Transformer Components
|
|
summary: These are configurable components that can be re-used quite easily.
|
|
---
|
|
|
|
# Configurable Transformer Components
|
|
"""
|
|
import copy

import torch.nn as nn

from labml.configs import BaseConfigs, option, calculate, aggregate
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
    Encoder, Decoder, Generator, EncoderDecoder


class FeedForwardConfigs(BaseConfigs):
    """
    <a id="FFN"></a>

    ## FFN Configurations

    Creates a Position-wise FeedForward Network defined in
    [`feed_forward.py`](feed_forward.html).
    """
    # Position-wise feedforward layer
    ffn: FeedForward
    # Number of features in the embedding
    d_model: int
    # Number of features in the hidden layer
    d_ff: int = 2048
    # Dropout probability
    dropout: float = 0.1
    # Activation in position-wise feedforward layer
    activation: nn.Module = 'ReLU'
    # Whether the FFN layer should be gated
    is_gated: bool = False
    # Whether the first fully connected layer should have a learnable bias
    bias1: bool = True
    # Whether the second fully connected layer should have a learnable bias
    bias2: bool = True
    # Whether the fully connected layer for the gate should have a learnable bias
    bias_gate: bool = False
    # Predefined GLU variants
    glu_variant: str = 'none'


@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    """
    ### ReLU activation

    $$\max(0, x)$$
    """
    return nn.ReLU()


@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    """
    ### GELU activation

    $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$

    It was introduced in the paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
    """
    return nn.GELU()


@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    """
    Initialize a [feed forward network](feed_forward.html)
    """
    return FeedForward(c.d_model, c.d_ff,
                       dropout=c.dropout,
                       activation=c.activation,
                       is_gated=c.is_gated,
                       bias1=c.bias1,
                       bias2=c.bias2,
                       bias_gate=c.bias_gate)
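
# A minimal usage sketch (illustrative, not part of this file): set plain values
# on the config object and let the `'default'` option above build the module.
# Since configs are computed lazily, the network is only constructed on first access.
#
#     conf = FeedForwardConfigs()
#     conf.d_model = 512
#     conf.activation = 'GELU'  # picks the option registered above
#     ffn = conf.ffn  # built by `_feed_forward` on demand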

# ## GLU Variants
# These are variants with gated hidden layers for the FFN
# as introduced in the paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202).
# We have omitted the bias terms as specified in the paper.

# ### FFN with Gated Linear Units
#
# $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))
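
# Setting the single `glu_variant` option above is equivalent to setting each of
# the aggregated options by hand (illustrative):
#
#     conf = FeedForwardConfigs()
#     conf.glu_variant = 'GLU'
#     # same effect as:
#     #   conf.is_gated = True
#     #   conf.bias1 = conf.bias2 = conf.bias_gate = False
#     #   conf.activation = nn.Sigmoid()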

# ### FFN with Bilinear hidden layer
#
# $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))

# ### FFN with ReLU gate
#
# $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))

# ### FFN with GELU gate
#
# $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))

# ### FFN with Swish gate
#
# $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.SiLU()))
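
# To make the gating concrete, here is a plain PyTorch sketch of the SwiGLU
# computation above (hypothetical shapes; the real module lives in
# [`feed_forward.py`](feed_forward.html)):
#
#     import torch
#     import torch.nn.functional as F
#
#     x = torch.randn(10, 512)            # (seq_len, d_model)
#     w1 = torch.randn(512, 2048)         # $W_1$
#     v = torch.randn(512, 2048)          # $V$
#     w2 = torch.randn(2048, 512)         # $W_2$
#     hidden = F.silu(x @ w1) * (x @ v)   # $\text{Swish}_1(x W_1) \otimes x V$
#     out = hidden @ w2                   # project back to `d_model`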


class TransformerConfigs(BaseConfigs):
    """
    <a id="TransformerConfigs"></a>

    ## Transformer Configurations

    This defines configurations for a transformer.
    The configurations are calculated using option functions.
    These are lazily loaded and therefore only the necessary modules
    are calculated.
    """
    # Number of attention heads
    n_heads: int = 8
    # Transformer embedding size
    d_model: int = 512
    # Number of layers
    n_layers: int = 6
    # Dropout probability
    dropout: float = 0.1
    # Number of tokens in the source vocabulary (for token embeddings)
    n_src_vocab: int
    # Number of tokens in the target vocabulary (to generate logits for prediction)
    n_tgt_vocab: int

    # The encoder self-attention
    encoder_attn: MultiHeadAttention = 'mha'
    # The decoder self-attention
    decoder_attn: MultiHeadAttention = 'mha'
    # The decoder memory attention
    decoder_mem_attn: MultiHeadAttention = 'mha'

    # Configurable feedforward layer
    ffn: FeedForwardConfigs

    # Encoder layer
    encoder_layer: TransformerLayer = 'default'
    # Decoder layer
    decoder_layer: TransformerLayer = 'default'

    # Encoder consisting of multiple encoder layers
    encoder: Encoder = 'default'
    # Decoder consisting of multiple decoder layers
    decoder: Decoder = 'default'

    # Embedding layer for source
    src_embed: nn.Module = 'fixed_pos'
    # Embedding layer for target (for decoder)
    tgt_embed: nn.Module = 'fixed_pos'

    # Logit generator for prediction
    generator: Generator = 'default'

    # Encoder-decoder
    encoder_decoder: EncoderDecoder
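
# Because options are lazily computed, a sketch like the following (illustrative,
# outside this file) builds only the modules `encoder_decoder` actually needs;
# options that are never accessed are never evaluated.
#
#     conf = TransformerConfigs()
#     conf.n_src_vocab = 10000
#     conf.n_tgt_vocab = 10000
#     model = conf.encoder_decoder  # builds encoder, decoder, embeddings, and generator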


# ### Multi-head Attention
def _mha(c: TransformerConfigs):
    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)


# ### Relative Multi-head Attention
def _relative_mha(c: TransformerConfigs):
    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    return RelativeMultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
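
# Since `calculate` registers these under string names, swapping the attention
# implementation is a one-line change (illustrative):
#
#     conf = TransformerConfigs()
#     conf.encoder_attn = 'relative'  # Transformer-XL style relative attention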


@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    """
    Create feedforward layer configurations
    """
    conf = FeedForwardConfigs()
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf
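
# `ffn` is itself a config object, so FFN options can be overridden through the
# parent config while `d_model` and `dropout` keep tracking the transformer's
# values via the defaults above. A sketch (this override style is an assumption
# for illustration, not an API guarantee):
#
#     conf = TransformerConfigs()
#     conf.ffn.glu_variant = 'SwiGLU'  # use a gated FFN in every layer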


@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    """
    Encoder layer
    """
    # The configured FFN module is deep-copied, so every layer gets its own instance
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    """
    Decoder layer
    """
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    """
    Encoder
    """
    return Encoder(c.encoder_layer, c.n_layers)


@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    """
    Decoder
    """
    return Decoder(c.decoder_layer, c.n_layers)
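
# `Encoder` and `Decoder` (see [`models.py`](models.html)) take one prototype
# layer plus `n_layers` and clone the prototype, so the stacked layers do not
# share weights. For example (illustrative):
#
#     conf = TransformerConfigs()
#     conf.n_layers = 12  # the defaults above then stack 12 cloned layers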


@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    """
    Logit generator
    """
    return Generator(c.n_tgt_vocab, c.d_model)


# ### Fixed Positional Embeddings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    """
    Source embedding with fixed positional encodings
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    """
    Target embedding with fixed positional encodings
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ### Learned Positional Embeddings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    """
    Source embedding with learned positional encodings
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    """
    Target embedding with learned positional encodings
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ### No Positional Embeddings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    """
    Source embedding without positional encodings
    """
    return nn.Embedding(c.n_src_vocab, c.d_model)


@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    """
    Target embedding without positional encodings
    """
    return nn.Embedding(c.n_tgt_vocab, c.d_model)


@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    """
    Encoder-decoder
    """
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
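
# Putting it together, a hypothetical end-to-end configuration (illustrative
# only; nothing here is required by this module):
#
#     conf = TransformerConfigs()
#     conf.n_src_vocab = 10000
#     conf.n_tgt_vocab = 10000
#     conf.src_embed = 'learned_pos'
#     conf.tgt_embed = 'learned_pos'
#     conf.dropout = 0.2
#     model = conf.encoder_decoder  # an `EncoderDecoder` assembled from the options above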