Configurable Transformer Components

import copy

import torch.nn as nn

from labml.configs import BaseConfigs, option, calculate, aggregate
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
    Encoder, Decoder, Generator, EncoderDecoder

FFN Configurations

Creates a Position-wise FeedForward Network as defined in feed_forward.py.

class FeedForwardConfigs(BaseConfigs):

Position-wise feedforward layer

    ffn: FeedForward

Number of features in the embedding

    d_model: int

Number of features in the hidden layer

    d_ff: int = 2048

Dropout probability

    dropout: float = 0.1

Activation in position-wise feedforward layer

    activation: nn.Module = 'ReLU'

Whether the FFN layer should be gated

    is_gated: bool = False

Whether the first fully connected layer should have a learnable bias

    bias1: bool = True

Whether the second fully connected layer should have a learnable bias

    bias2: bool = True

Whether the fully connected layer for the gate should have a learnable bias

    bias_gate: bool = False

Predefined GLU variants

    glu_variant: str = 'none'

ReLU activation

@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    return nn.ReLU()

GELU activation

$$\operatorname{GELU}(x) = x\,\Phi(x)$$

where $\Phi(x) = P(X \le x)$ for $X \sim \mathcal{N}(0, 1)$ is the standard Gaussian CDF.

It was introduced in the paper Gaussian Error Linear Units.

@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    return nn.GELU()

Initialize a feed forward network

@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    return FeedForward(c.d_model, c.d_ff,
                       dropout=c.dropout,
                       activation=c.activation,
                       is_gated=c.is_gated,
                       bias1=c.bias1,
                       bias2=c.bias2,
                       bias_gate=c.bias_gate)

GLU Variants

These are variants with gated hidden layers for the FFN, as introduced in the paper GLU Variants Improve Transformer. We have omitted the bias terms, as specified in the paper.
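To make the mapping from these configurations to the math explicit, all five variants below share one gated form. This is a summary of the paper's equations (dropout and the omitted biases are left out; writing $W_1$, $V$ and $W_2$ for the first, gate and output projections follows the paper):

$$\operatorname{FFN}_{f}(x, W_1, V, W_2) = \big(f(x W_1) \otimes x V\big)\, W_2$$

where $\otimes$ is element-wise multiplication and $f$ is the gate activation set by each aggregate below: sigmoid for GLU, identity for Bilinear, ReLU for ReGLU, GELU for GEGLU, and Swish for SwiGLU.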

FFN with Gated Linear Units

aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))

FFN with Bilinear hidden layer

aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))

FFN with ReLU gate

aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))

FFN with GELU gate

aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))

FFN with Swish gate

$$\operatorname{FFN}_{\text{SwiGLU}}(x, W_1, V, W_2) = \big(\operatorname{Swish}_1(x W_1) \otimes x V\big)\, W_2$$

where $\operatorname{Swish}_\beta(x) = x\,\sigma(\beta x)$; PyTorch's nn.SiLU implements $\beta = 1$.

aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.SiLU()))
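As a usage sketch (hypothetical caller code, not part of this module), a variant is selected just by setting glu_variant; the aggregates above then fill in the gating flag, the bias flags and the activation:

# Hypothetical caller code: pick the SwiGLU variant.
ffn_conf = FeedForwardConfigs()
ffn_conf.d_model = 512          # embedding size
ffn_conf.glu_variant = 'SwiGLU'
# When this configuration is processed (for example as the nested `ffn`
# of TransformerConfigs below), the 'default' option builds a FeedForward
# module with is_gated=True, no biases, and nn.SiLU() as the activation.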

Transformer Configurations

This defines configurations for a transformer. The configurations are calculated using option functions. These are lazily evaluated, so only the necessary modules are calculated.

class TransformerConfigs(BaseConfigs):

Number of attention heads

    n_heads: int = 8

Transformer embedding size

    d_model: int = 512

Number of layers

    n_layers: int = 6

Dropout probability

    dropout: float = 0.1

Number of tokens in the source vocabulary (for token embeddings)

    n_src_vocab: int

Number of tokens in the target vocabulary (to generate logits for prediction)

    n_tgt_vocab: int

The encoder self attention

    encoder_attn: MultiHeadAttention = 'mha'

The decoder self attention

    decoder_attn: MultiHeadAttention = 'mha'

The decoder memory attention

    decoder_mem_attn: MultiHeadAttention = 'mha'

Configurable Feedforward Layer

    ffn: FeedForwardConfigs

Encoder layer

    encoder_layer: TransformerLayer = 'default'

Decoder layer

    decoder_layer: TransformerLayer = 'default'

Encoder consisting of multiple encoder layers

    encoder: Encoder = 'default'

Decoder consisting of multiple decoder layers

    decoder: Decoder = 'default'

Embedding layer for source

    src_embed: nn.Module = 'fixed_pos'

Embedding layer for target (for decoder)

    tgt_embed: nn.Module = 'fixed_pos'

Logit generator for prediction

    generator: Generator = 'default'

Encoder-decoder

    encoder_decoder: EncoderDecoder

Multi-head Attention

def _mha(c: TransformerConfigs):
    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)

Relative Multi-head Attention

def _relative_mha(c: TransformerConfigs):
    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    return RelativeMultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
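Since both implementations are registered under string names with calculate, switching between them only requires assigning the name. A small sketch (hypothetical caller code):

# Hypothetical caller code: use Transformer-XL style relative attention
# for the encoder self attention, keeping the default MHA elsewhere.
conf = TransformerConfigs()
conf.encoder_attn = 'relative'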

Create feedforward layer configurations

@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    conf = FeedForwardConfigs()
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf
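The two set_default calls mean the nested FFN inherits the transformer's d_model and dropout unless they are overridden. A sketch of overriding the nested configs from the parent, assuming labml's usual dotted-key override syntax for experiment.configs (the experiment name is illustrative):

from labml import experiment
from labml_nn.transformers.configs import TransformerConfigs

experiment.create(name='configs_demo')  # hypothetical experiment
conf = TransformerConfigs()
# Reach into the nested FFN configs with a dotted key;
# d_model and dropout still follow the parent via set_default.
experiment.configs(conf, {'ffn.glu_variant': 'GEGLU'})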

Encoder layer

@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)

Decoder layer

@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)

Encoder

@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    return Encoder(c.encoder_layer, c.n_layers)

Decoder

@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    return Decoder(c.decoder_layer, c.n_layers)

Logit generator

@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    return Generator(c.n_tgt_vocab, c.d_model)

Fixed Positional Embeddings

Source embedding with fixed positional encodings

@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)

Target embedding with fixed positional encodings

@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)

Learned Positional Embeddings

Source embedding with learned positional encodings

@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)

Target embedding with learned positional encodings

@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)

No Positional Embeddings

Source embedding without positional encodings

@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    return nn.Embedding(c.n_src_vocab, c.d_model)

Target embedding without positional encodings

@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    return nn.Embedding(c.n_tgt_vocab, c.d_model)

Encoder-decoder

@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
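Putting it together, a parent experiment typically creates a TransformerConfigs inside one of its own option functions, fills in the vocabulary sizes, and then reads the lazily built modules. A minimal sketch following that pattern; names such as MyConfigs and n_tokens are illustrative, not part of this module:

from labml.configs import BaseConfigs, option
from labml_nn.transformers.configs import TransformerConfigs


class MyConfigs(BaseConfigs):
    # Hypothetical parent configuration for an experiment
    transformer: TransformerConfigs
    n_tokens: int = 65  # illustrative vocabulary size


@option(MyConfigs.transformer, 'Transformer')
def _transformer(c: MyConfigs):
    conf = TransformerConfigs()
    # Share one vocabulary between source and target
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    return conf

Only the modules that are actually read (say conf.transformer.encoder, or the full conf.transformer.encoder_decoder) get built, which is the lazy behaviour described at the top of TransformerConfigs.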