diff --git a/docs/optimizers/noam_lr.png b/docs/optimizers/noam_lr.png
new file mode 100644
index 00000000..b8945b93
Binary files /dev/null and b/docs/optimizers/noam_lr.png differ
diff --git a/docs/optimizers/radam_r_t.png b/docs/optimizers/radam_r_t.png
new file mode 100644
index 00000000..7b77edb9
Binary files /dev/null and b/docs/optimizers/radam_r_t.png differ
diff --git a/docs/transformers/configs.html b/docs/transformers/configs.html
index c9097bfc..b8e20f21 100644
@@ -86,11 +86,15 @@
Creates a Position-wise FeedForward Network as defined in feed_forward.py.

class FeedForwardConfigs(BaseConfigs):
    # Position-wise feedforward layer
    ffn: FeedForward
    # Number of features in the embedding
    d_model: int
    # Number of features in the hidden layer
    d_ff: int = 2048
    # Dropout probability
    dropout: float = 0.1
    # Activation in position-wise feedforward layer
    activation: nn.Module = 'ReLU'
    # Whether the FFN layer should be gated
    is_gated: bool = False
    # Whether the first fully connected layer should have a learnable bias
    bias1: bool = True
    # Whether the second fully connected layer should have a learnable bias
    bias2: bool = True
    # Whether the fully connected layer for the gate should have a learnable bias
    bias_gate: bool = False
    # Predefined GLU variants
    glu_variant: str = 'none'
@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    return nn.ReLU()


# GELU activation, $x \Phi(x)$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$.
# It was introduced in the paper Gaussian Error Linear Units.
@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    return nn.GELU()
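For reference, PyTorch's nn.GELU used here evaluates the exact form $x \Phi(x) = \tfrac{x}{2}\bigl(1 + \operatorname{erf}(x / \sqrt{2})\bigr)$ by default; a tanh-based approximation is also available.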
@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    return FeedForward(c.d_model, c.d_ff,
                       dropout=c.dropout,
                       activation=c.activation,
                       is_gated=c.is_gated,
                       bias1=c.bias1,
                       bias2=c.bias2,
                       bias_gate=c.bias_gate)


These are variants with gated hidden layers for the FFN, as introduced in the paper GLU Variants Improve Transformer. We have omitted the bias terms as specified in the paper.
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))

aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))

aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))

aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))

aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.SiLU()))
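Each aggregate above simply fixes the gating flags and the activation. As a reading aid, selecting the 'SwiGLU' variant is equivalent to constructing the FeedForward module documented further below directly (sizes here are illustrative):

# SwiGLU: gated FFN computing (SiLU(x W1) * (x V)) W2, with all bias terms disabled
ffn = FeedForward(512, 2048, 0.1,
                  activation=nn.SiLU(), is_gated=True,
                  bias1=False, bias2=False, bias_gate=False)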
class TransformerConfigs(BaseConfigs):
    # Number of attention heads
    n_heads: int = 8
    # Transformer embedding size
    d_model: int = 512
    # Number of layers
    n_layers: int = 6
    # Dropout probability
    dropout: float = 0.1
    # Number of tokens in the source vocabulary (for token embeddings)
    n_src_vocab: int
    # Number of tokens in the target vocabulary (to generate logits for prediction)
    n_tgt_vocab: int

    # The encoder self attention
    encoder_attn: MultiHeadAttention = 'mha'
    # The decoder self attention
    decoder_attn: MultiHeadAttention = 'mha'
    decoder_mem_attn: MultiHeadAttention = 'mha'
    # Configurable Feedforward Layer
    ffn: FeedForwardConfigs

    # Encoder layer
    encoder_layer: TransformerLayer = 'default'
    # Decoder layer
    decoder_layer: TransformerLayer = 'default'

    # Encoder consisting of multiple encoder layers
    encoder: Encoder = 'default'
    # Decoder consisting of multiple decoder layers
    decoder: Decoder = 'default'

    src_embed: Module = 'fixed_pos'
    # Embedding layer for target (for decoder)
    tgt_embed: Module = 'fixed_pos'

    generator: Generator = 'default'

    # Encoder-decoder
    encoder_decoder: EncoderDecoder


def _mha(c: TransformerConfigs):
    return MultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)


def _relative_mha(c: TransformerConfigs):
    from .relative_mha import RelativeMultiHeadAttention
    return RelativeMultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)


@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    conf = FeedForwardConfigs()
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf


# Encoder layer
@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


# Decoder layer
@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


# Encoder
@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    return Encoder(c.encoder_layer, c.n_layers)


# Decoder
@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    return Decoder(c.decoder_layer, c.n_layers)


# Logit generator
@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    return Generator(c.n_tgt_vocab, c.d_model)


# Source embedding with fixed positional encodings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)


# Target embedding with fixed positional encodings
@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


# Source embedding with learned positional encodings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)


# Target embedding with learned positional encodings
@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)


# Source embedding without positional encodings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    return nn.Embedding(c.n_src_vocab, c.d_model)


@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    return nn.Embedding(c.n_tgt_vocab, c.d_model)


@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)


Position-wise Feed-Forward Network (FFN), implemented in feed_forward.py.
Sometimes the GELU (Gaussian Error Linear Unit) activation, $x \Phi(x)$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$, is also used instead of ReLU.
This is a generic implementation that supports different variants, including Gated Linear Units (GLU). We have also implemented experiments on these: an experiment that uses labml.configs and a simpler version written from scratch, both included below.
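For reference, the ungated position-wise FFN computes $f(x W_1 + b_1) W_2 + b_2$, while the gated variants implemented below compute $\bigl(f(x W_1) \otimes x V\bigr) W_2$ with the bias terms omitted; $f$ is the sigmoid for GLU, the identity for Bilinear, ReLU for ReGLU, GELU for GEGLU, and Swish/SiLU for SwiGLU.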
import torch
from torch import nn as nn

from labml_helpers.module import Module


class FeedForward(Module):
    def __init__(self, d_model: int, d_ff: int,
                 dropout: float = 0.1,
                 activation=nn.ReLU(),
                 is_gated: bool = False,
                 bias1: bool = True,
                 bias2: bool = True,
                 bias_gate: bool = True):
        super().__init__()
        # Layer one parameterized by weight $W_1$ and bias $b_1$
        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
        # Layer two parameterized by weight $W_2$ and bias $b_2$
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
        # Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
        # Activation function $f$
        self.activation = activation
        # Whether there is a gate
        self.is_gated = is_gated
        if is_gated:
            # If there is a gate, the linear layer to transform inputs to
            # be multiplied by the gate, parameterized by weight $V$ and bias $c$
            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)

    def __call__(self, x: torch.Tensor):
        # $f(x W_1 + b_1)$
        g = self.activation(self.layer1(x))
        # If gated, $f(x W_1 + b_1) \otimes (x V + c)$
        if self.is_gated:
            x = g * self.linear_v(x)
        # Otherwise
        else:
            x = g
        # Apply dropout
        x = self.dropout(x)
        # $(f(x W_1 + b_1) \otimes (x V + c)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$,
        # depending on whether it is gated
        return self.layer2(x)
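A minimal usage sketch (illustrative sizes; the GEGLU settings mirror the aggregate defined in the configs above):

import torch
from torch import nn
from labml_nn.transformers.feed_forward import FeedForward

# Gated FFN with a GELU activation, applied to a (seq_len, batch, d_model) tensor
ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1,
                  activation=nn.GELU(), is_gated=True,
                  bias1=False, bias2=False, bias_gate=False)
x = torch.randn(10, 32, 512)
y = ffn(x)  # same shape as x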
This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network. The reusable & configurable components are defined in configs.py.
import torch
from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_helpers.module import Module

from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator, TransformerConfigs
from labml_nn.transformers.utils import subsequent_mask


class AutoregressiveModel(Module):
    def __init__(self, src_embed: Module, encoder: Encoder, generator: Generator):
        super().__init__()
        # Token embedding module
        self.src_embed = src_embed
        # Transformer based encoder
        self.encoder = encoder
        self.generator = generator
        # This will be initialized on the first call
        self.src_mask = None

    def __call__(self, src: torch.Tensor):
        # Create subsequent mask, so that the transformer can only pay attention to past tokens.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
        # Embed the tokens (src) and run it through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
        # Generate logits of the next token
        return self.generator(res), None
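subsequent_mask comes from labml_nn.transformers.utils. As a rough stand-in for intuition only (an assumption, not the library's exact code or shape convention), a causal mask lets position $i$ attend only to positions $j \le i$:

import torch

def causal_mask(size: int) -> torch.Tensor:
    # Lower-triangular boolean mask; the real subsequent_mask may use a different layout
    return torch.tril(torch.ones(size, size, dtype=torch.bool))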
# The default configs can and will be overridden when we start the experiment
class Configs(NLPAutoRegressionConfigs):
    transformer: TransformerConfigs
    model: AutoregressiveModel


# Initialize the auto-regressive model
@option(Configs.model)
def autoregressive_model(c: Configs):
    m = AutoregressiveModel(c.transformer.src_embed, c.transformer.encoder, c.transformer.generator)
    return m.to(c.device)


# Initialize the configurable transformer encoder for our autoregressive model
@option(Configs.transformer)
def transformer_c(c: Configs):
    tc = TransformerConfigs()
    tc.n_src_vocab = c.n_tokens
    tc.n_tgt_vocab = c.n_tokens

    return tc
def main():
    # Create experiment
    experiment.create(name="glu_variants")
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'prompt_separator': '',
                        'prompt': 'It is ',
                        'text': 'tiny_shakespeare',

                        'optimizer.optimizer': 'Noam',
                        'optimizer.learning_rate': 1.,
                        'optimizer.d_model': 256,

                        'seq_len': 1024,
                        'epochs': 128,
                        'batch_size': 6,
                        'inner_iterations': 10,

                        # GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU.
                        # These are defined in the configurable FFN implementation.
                        'transformer.ffn.glu_variant': 'Bilinear',

                        # Transformer configurations
                        'transformer.d_model': 256,
                        'transformer.ffn.d_ff': 1024,
                        'transformer.n_heads': 8,
                        'transformer.n_layers': 6})
    # This is needed to initialize models
    conf.n_tokens = conf.text.n_tokens
    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))
    # Start the experiment
    with experiment.start():
        # TrainValidConfigs.run
        conf.run()


if __name__ == '__main__':
    main()
This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network. This is a simpler implementation that doesn't use the labml.configs module. We wrote this simpler version to make it easier for readers who are not familiar with that module.
import dataclasses

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask


class AutoregressiveModel(nn.Module):
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()
        # Token embedding module
        self.src_embed = src_embed
        # Transformer based encoder
        self.encoder = encoder
        self.generator = generator
        # This will be initialized on the first call
        self.src_mask = None

    def __call__(self, src: torch.Tensor):
        # Create subsequent mask, so that the transformer can only pay attention to past tokens.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
        # Embed the tokens (src) and run it through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
        # Generate logits of the next token
        return self.generator(res)
@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
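Since Configs is a plain dataclass, a variant can be selected by overriding fields at construction time (illustrative values):

# Train the SwiGLU variant with a smaller model
configs = Configs(d_model=256, d_ff=1024, glu_variant='SwiGLU')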
class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):
        # Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download the file
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
        # Read the downloaded file
        with open(str(path), 'r') as f:
            text = f.read()

        # Extract the characters
        chars = list(set(text))
        # Character to id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
        # Id to character map
        self.itos = {i: c for i, c in enumerate(chars)}
        # Length of a training sample
        self.seq_len = seq_len
        # Data in the form of a tensor of ids
        self.data = self.text_to_i(text)

    # Transform the text into a tensor of ids
    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    # Number of samples in the dataset.
    # This will read the dataset seq_len times in a single epoch.
    def __len__(self):
        return len(self.data) - self.seq_len - 1

    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
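A quick check of the dataset contract (hypothetical snippet; it downloads the text on first use): each item is an (input, target) pair of length seq_len, with the target shifted one character ahead.

ds = TinyShakespeareDataset(seq_len=8)
x, y = ds[0]
assert x.shape == y.shape == (8,)
assert torch.equal(x[1:], y[:-1])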
class Trainer:
    def __init__(self, configs: Configs):
        # Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        # Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        # Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)

        # FFN with Gated Linear Unit
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
        # FFN with Bilinear hidden layer
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
        # FFN with ReLU gate
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
        # FFN with GELU gate
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
        # FFN with Swish gate, where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
        # FFN with ReLU activation
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
        # FFN with GELU activation
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')

        # Number of different characters
        n_chars = len(self.dataset.stoi)
        # Initialize Multi-Head Attention module
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
        # Initialize the Transformer Block
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)
        # Initialize the model with an embedding layer (with fixed positional encoding),
        # a transformer encoder and a linear layer to generate logits
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))
        # Move the model to the current device
        self.model.to(self.device)

        # Initialize Noam optimizer
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)

        # Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
        # Number of training epochs;
        # note that our dataset definition repeats the data seq_len times in a single epoch
        self.epochs = configs.epochs
        # Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
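The Noam optimizer constructed above warms the learning rate up and then decays it with the schedule from Attention Is All You Need; a sketch under that assumption (the labml implementation may differ in detail):

def noam_lr(step: int, d_model: int = 512, warmup: int = 2_000, factor: float = 1.0) -> float:
    # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)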
    def sample(self):
        # Starting prompt
        prompt = 'It is'
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.dataset.itos[output[-1].item()]
            # Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]

        # Print the sampled output
        logger.log(log)

    def train(self):
        # Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
            # Iterate over the minibatches
            for i, batch in monit.enum('Train', self.dataloader):
                # Move data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)

                # Set tracker step, as the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])

                # Set model state to training
                self.model.train()
                # Evaluate the model
                output = self.model(data)

                # Calculate loss
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
                # Log the loss
                tracker.add("loss.train", loss)

                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
                # Take optimizer step
                self.optimizer.step()
                # Log the model parameters and gradients
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
                # Clear the gradients
                self.optimizer.zero_grad()

                # Generate a sample
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

                # Save the tracked metrics
                if (i + 1) % 10 == 0:
                    tracker.save()

            # Save the model
            experiment.save_checkpoint()
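For clarity on the flattening in the loss above, CrossEntropyLoss expects (N, C) logits against (N,) class indices; the shapes below are illustrative and assume the seq-first batch layout produced by transpose_batch:

logits = torch.randn(128, 32, 65)         # (seq_len, batch, n_chars)
target = torch.randint(0, 65, (128, 32))  # (seq_len, batch)
loss = nn.CrossEntropyLoss()(logits.view(-1, logits.shape[-1]), target.view(-1))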
def main():
    # Create experiment
    experiment.create(name="glu_variants")
    # Create configs
    configs = Configs()
    # Load configurations
    experiment.configs(dataclasses.asdict(configs))

    # Create trainer
    trainer = Trainer(configs)
    # Set models for saving and loading
    experiment.add_pytorch_models({'model': trainer.model})

    # Start the experiment
    with experiment.start():
        # Train the model
        trainer.train()


if __name__ == '__main__':
    main()