import copy

import torch.nn as nn

from labml.configs import BaseConfigs, option, calculate, aggregate
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
    Encoder, Decoder, Generator, EncoderDecoder


class FeedForwardConfigs(BaseConfigs):
    # Position-wise feedforward layer
    ffn: FeedForward
    # Number of features in the embedding
    d_model: int
    # Number of features in the hidden layer
    d_ff: int = 2048
    # Dropout probability
    dropout: float = 0.1
    # Activation in position-wise feedforward layer
    activation: nn.Module = 'ReLU'
    # Whether the FFN layer should be gated
    is_gated: bool = False
    # Whether the first fully connected layer should have a learnable bias
    bias1: bool = True
    # Whether the second fully connected layer should have a learnable bias
    bias2: bool = True
    # Whether the fully connected layer for the gate should have a learnable bias
    bias_gate: bool = False
    # Predefined GLU variants
    glu_variant: str = 'none'


# ReLU activation
@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    return nn.ReLU()


# GELU activation
@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    return nn.GELU()


# Initialize a feed forward network
@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    return FeedForward(c.d_model, c.d_ff,
                       dropout=c.dropout,
                       activation=c.activation,
                       is_gated=c.is_gated,
                       bias1=c.bias1,
                       bias2=c.bias2,
                       bias_gate=c.bias_gate)


# These are variants with gated hidden layers for the FFN,
# as introduced in the paper "GLU Variants Improve Transformer".
# We have omitted the bias terms as specified in the paper.
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.SiLU()))
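
# For example, selecting `glu_variant='GEGLU'` is, via the aggregate above, the
# same as setting `is_gated=True`, `bias1=bias2=bias_gate=False` and
# `activation=nn.GELU()`. A rough usage sketch follows (it assumes the usual
# labml experiment API and a hypothetical experiment name; treat it as an
# illustration, not part of this module):
#
#     from labml import experiment
#     experiment.create(name='ffn_glu_demo')  # hypothetical experiment name
#     ffn_conf = FeedForwardConfigs()
#     ffn_conf.d_model = 512
#     experiment.configs(ffn_conf, {'glu_variant': 'GEGLU'})
#     ffn = ffn_conf.ffn  # a gated FeedForward with GELU activation and no biases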


# This defines the configurations for a transformer. The configurations are
# calculated using option functions. These are lazily loaded, so only the
# necessary modules are calculated.
class TransformerConfigs(BaseConfigs):
    # Number of attention heads
    n_heads: int = 8
    # Transformer embedding size
    d_model: int = 512
    # Number of layers
    n_layers: int = 6
    # Dropout probability
    dropout: float = 0.1
    # Number of tokens in the source vocabulary (for token embeddings)
    n_src_vocab: int
    # Number of tokens in the target vocabulary (to generate logits for prediction)
    n_tgt_vocab: int
    # The encoder self attention
    encoder_attn: MultiHeadAttention = 'mha'
    # The decoder self attention
    decoder_attn: MultiHeadAttention = 'mha'
    # The decoder memory attention
    decoder_mem_attn: MultiHeadAttention = 'mha'
    # Configurable feedforward layer
    ffn: FeedForwardConfigs
    # Encoder layer
    encoder_layer: TransformerLayer = 'default'
    # Decoder layer
    decoder_layer: TransformerLayer = 'default'
    # Encoder consisting of multiple encoder layers
    encoder: Encoder = 'default'
    # Decoder consisting of multiple decoder layers
    decoder: Decoder = 'default'
    # Embedding layer for source
    src_embed: nn.Module = 'fixed_pos'
    # Embedding layer for target (for decoder)
    tgt_embed: nn.Module = 'fixed_pos'
    # Logit generator for prediction
    generator: Generator = 'default'
    # Encoder-decoder
    encoder_decoder: EncoderDecoder


def _mha(c: TransformerConfigs):
    # Standard multi-head attention
    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)


def _relative_mha(c: TransformerConfigs):
    # Relative multi-head attention (from Transformer-XL)
    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    return RelativeMultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)


# Create feedforward layer configurations
@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    conf = FeedForwardConfigs()
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf
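
# The nested feedforward configs pick up `d_model` and `dropout` from the
# transformer configs by default, while the remaining options can still be
# overridden individually. A hedged sketch (it assumes labml's dotted-path
# config overrides; not part of this module):
#
#     conf = TransformerConfigs()
#     experiment.configs(conf, {'ffn.d_ff': 1024,              # hidden layer size
#                               'ffn.glu_variant': 'SwiGLU'})  # gated variant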


# Encoder layer
@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


# Decoder layer
@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


# Encoder
@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    return Encoder(c.encoder_layer, c.n_layers)


# Decoder
@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    return Decoder(c.decoder_layer, c.n_layers)


# Logit generator
@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    return Generator(c.n_tgt_vocab, c.d_model)


# Source embedding with fixed positional encodings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)


# Target embedding with fixed positional encodings
@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


# Source embedding with learned positional encodings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)


# Target embedding with learned positional encodings
@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)


# Source embedding without positional encodings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    return nn.Embedding(c.n_src_vocab, c.d_model)


# Target embedding without positional encodings
@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    return nn.Embedding(c.n_tgt_vocab, c.d_model)


# Encoder-decoder
@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
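

# Usage sketch (not part of this module): the options above are lazily
# calculated, so building `encoder_decoder` only constructs the modules it
# actually needs. The pattern below follows the usual labml experiment API;
# the experiment name and override keys are illustrative assumptions.
#
#     from labml import experiment
#     from labml_nn.transformers.configs import TransformerConfigs
#
#     experiment.create(name='transformer_configs_demo')  # hypothetical name
#     conf = TransformerConfigs()
#     conf.n_src_vocab = 10_000
#     conf.n_tgt_vocab = 10_000
#     # Pick options by name, e.g. relative multi-head attention for the encoder
#     experiment.configs(conf, {'encoder_attn': 'relative',
#                               'ffn.glu_variant': 'SwiGLU'})
#     model = conf.encoder_decoder  # triggers calculation of the required options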