diff --git a/docs/optimizers/noam_lr.png b/docs/optimizers/noam_lr.png
new file mode 100644
index 00000000..b8945b93
Binary files /dev/null and b/docs/optimizers/noam_lr.png differ
diff --git a/docs/optimizers/radam_r_t.png b/docs/optimizers/radam_r_t.png
new file mode 100644
index 00000000..7b77edb9
Binary files /dev/null and b/docs/optimizers/radam_r_t.png differ
diff --git a/docs/transformers/configs.html b/docs/transformers/configs.html
index c9097bfc..b8e20f21 100644
--- a/docs/transformers/configs.html
+++ b/docs/transformers/configs.html
@@ -86,11 +86,15 @@
-
+
- +

+

FFN Configurations

+

+

Creates a Position-wise FeedForward Network defined in feed_forward.py.

21class FeedForwardConfigs(BaseConfigs):
@@ -104,7 +108,7 @@

Position-wise feedforward layer

-
23    ffn: FeedForward
+
31    ffn: FeedForward
@@ -115,7 +119,7 @@

Number of features in the embedding

-
25    d_model: int
+
33    d_model: int
@@ -126,7 +130,7 @@

Number of features in the hidden layer

-
27    d_ff: int = 2048
+
35    d_ff: int = 2048
@@ -137,7 +141,7 @@

Dropout probability

-
29    dropout: float = 0.1
+
37    dropout: float = 0.1
@@ -148,7 +152,7 @@

Activation in position-wise feedforward layer

-
31    activation: nn.Module = 'ReLU'
+
39    activation: nn.Module = 'ReLU'
@@ -159,7 +163,7 @@

Whether the FFN layer should be gated

-
33    is_gated: bool = False
+
41    is_gated: bool = False
@@ -170,7 +174,7 @@

Whether the first fully connected layer should have a learnable bias

-
35    bias1: bool = True
+
43    bias1: bool = True
@@ -181,7 +185,7 @@

Whether the second fully connected layer should have a learnable bias

-
37    bias2: bool = True
+
45    bias2: bool = True
@@ -192,7 +196,7 @@

Whether the fully connected layer for the gate should have a learnable bias

-
39    bias_gate: bool = False
+
47    bias_gate: bool = False
@@ -203,7 +207,7 @@

Predefined GLU variants

-
41    glu_variant: str = 'none'
+
49    glu_variant: str = 'none'
@@ -211,11 +215,14 @@ -

ReLU activation

+

ReLU activation

+

+ $$\max(0, x)$$

-
44@option(FeedForwardConfigs.activation, 'ReLU')
-45def _ffn_activation_relu():
+
52@option(FeedForwardConfigs.activation, 'ReLU')
+53def _ffn_activation_relu():
@@ -226,7 +233,7 @@
-
49    return nn.ReLU()
+
59    return nn.ReLU()
@@ -234,11 +241,14 @@ -

GELU activation

+

GELU activation

+

+ $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$

+

It was introduced in paper Gaussian Error Linear Units.

-
52@option(FeedForwardConfigs.activation, 'GELU')
-53def _ffn_activation_gelu():
+
62@option(FeedForwardConfigs.activation, 'GELU')
+63def _ffn_activation_gelu():
@@ -249,7 +259,7 @@
-
57    return nn.GELU()
+
71    return nn.GELU()
@@ -257,11 +267,11 @@ -

Create feedforward layer

+

Initialize a feed forward network

-
60@option(FeedForwardConfigs.ffn, 'default')
-61def _feed_forward(c: FeedForwardConfigs):
+
74@option(FeedForwardConfigs.ffn, 'default')
+75def _feed_forward(c: FeedForwardConfigs):
@@ -272,53 +282,129 @@
-
65    return FeedForward(c.d_model, c.d_ff,
-66                       dropout=c.dropout,
-67                       activation=c.activation,
-68                       is_gated=c.is_gated,
-69                       bias1=c.bias1,
-70                       bias2=c.bias2,
-71                       bias_gate=c.bias_gate)
-72
-73
-74aggregate(FeedForwardConfigs.glu_variant, 'GLU',
-75          (FeedForwardConfigs.is_gated, True),
-76          (FeedForwardConfigs.bias1, False),
-77          (FeedForwardConfigs.bias2, False),
-78          (FeedForwardConfigs.bias_gate, False),
-79          (FeedForwardConfigs.activation, nn.Sigmoid()))
-80
-81aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
-82          (FeedForwardConfigs.is_gated, True),
-83          (FeedForwardConfigs.bias1, False),
-84          (FeedForwardConfigs.bias2, False),
-85          (FeedForwardConfigs.bias_gate, False),
-86          (FeedForwardConfigs.activation, nn.Identity()))
-87aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
-88          (FeedForwardConfigs.is_gated, True),
-89          (FeedForwardConfigs.bias1, False),
-90          (FeedForwardConfigs.bias2, False),
-91          (FeedForwardConfigs.bias_gate, False),
-92          (FeedForwardConfigs.activation, nn.ReLU()))
-93aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
-94          (FeedForwardConfigs.is_gated, True),
-95          (FeedForwardConfigs.bias1, False),
-96          (FeedForwardConfigs.bias2, False),
-97          (FeedForwardConfigs.bias_gate, False),
-98          (FeedForwardConfigs.activation, nn.GELU()))
-99aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
-100          (FeedForwardConfigs.is_gated, True),
-101          (FeedForwardConfigs.bias1, False),
-102          (FeedForwardConfigs.bias2, False),
-103          (FeedForwardConfigs.bias_gate, False),
-104          (FeedForwardConfigs.activation, nn.SiLU()))
+
79    return FeedForward(c.d_model, c.d_ff,
+80                       dropout=c.dropout,
+81                       activation=c.activation,
+82                       is_gated=c.is_gated,
+83                       bias1=c.bias1,
+84                       bias2=c.bias2,
+85                       bias_gate=c.bias_gate)
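For reference, with the defaults listed above (d_ff = 2048, dropout = 0.1, ReLU activation, no gating), the 'default' option is roughly equivalent to constructing the layer by hand; this is a sketch only, and the d_model value is an assumption for illustration:

from torch import nn
from labml_nn.transformers.feed_forward import FeedForward

# what the 'default' ffn option builds, spelled out with the config defaults
ffn = FeedForward(512, 2048,
                  dropout=0.1,
                  activation=nn.ReLU(),
                  is_gated=False,
                  bias1=True,
                  bias2=True,
                  bias_gate=False)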
-
+
+

GLU Variants

+

These are variants with gated hidden layers for the FFN, as introduced in the paper GLU Variants Improve Transformer. We have omitted the bias terms as specified in the paper.

+
+
+
+
+
+
+
+ +

FFN with Gated Linear Units

+

+ $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$

+
+
+
95aggregate(FeedForwardConfigs.glu_variant, 'GLU',
+96          (FeedForwardConfigs.is_gated, True),
+97          (FeedForwardConfigs.bias1, False),
+98          (FeedForwardConfigs.bias2, False),
+99          (FeedForwardConfigs.bias_gate, False),
+100          (FeedForwardConfigs.activation, nn.Sigmoid()))
+
+
+
+
+ +

FFN with Bilinear hidden layer

+

+ $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$

+
+
+
105aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
+106          (FeedForwardConfigs.is_gated, True),
+107          (FeedForwardConfigs.bias1, False),
+108          (FeedForwardConfigs.bias2, False),
+109          (FeedForwardConfigs.bias_gate, False),
+110          (FeedForwardConfigs.activation, nn.Identity()))
+
+
+
+
+ +

FFN with ReLU gate

+

+ $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$

+
+
+
115aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
+116          (FeedForwardConfigs.is_gated, True),
+117          (FeedForwardConfigs.bias1, False),
+118          (FeedForwardConfigs.bias2, False),
+119          (FeedForwardConfigs.bias_gate, False),
+120          (FeedForwardConfigs.activation, nn.ReLU()))
+
+
+
+
+ +

FFN with GELU gate

+

+ $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$

+
+
+
125aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
+126          (FeedForwardConfigs.is_gated, True),
+127          (FeedForwardConfigs.bias1, False),
+128          (FeedForwardConfigs.bias2, False),
+129          (FeedForwardConfigs.bias_gate, False),
+130          (FeedForwardConfigs.activation, nn.GELU()))
+
+
+
+
+ +

FFN with Swish gate

+

+ $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

+
+
+
136aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
+137          (FeedForwardConfigs.is_gated, True),
+138          (FeedForwardConfigs.bias1, False),
+139          (FeedForwardConfigs.bias2, False),
+140          (FeedForwardConfigs.bias_gate, False),
+141          (FeedForwardConfigs.activation, nn.SiLU()))
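To see the gating concretely, here is a minimal sketch in plain PyTorch (tensor sizes and names are illustrative) of $FFN_{ReGLU}$; it mirrors what FeedForward computes when is_gated is true and all biases are disabled:

import torch
from torch import nn

d_model, d_ff = 16, 64
w1 = nn.Linear(d_model, d_ff, bias=False)  # W_1
v = nn.Linear(d_model, d_ff, bias=False)   # V, the gate projection
w2 = nn.Linear(d_ff, d_model, bias=False)  # W_2

x = torch.randn(4, d_model)
# ReGLU: element-wise product of the ReLU-activated projection and the gate
hidden = torch.relu(w1(x)) * v(x)
out = w2(hidden)  # shape (4, d_model)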
+
+
+
+
+

Transformer Configurations

@@ -328,73 +414,7 @@ These are lazy loaded and therefore only the necessary modules are calculated.

-
107class TransformerConfigs(BaseConfigs):
-
-
-
-
- -

Number of attention heads

-
-
-
119    n_heads: int = 8
-
-
-
-
- -

Transformer embedding size

-
-
-
121    d_model: int = 512
-
-
-
-
- -

Number of layers

-
-
-
123    n_layers: int = 6
-
-
-
-
- -

Dropout probability

-
-
-
125    dropout: float = 0.1
-
-
-
-
- -

Number of tokens in the source vocabulary (for token embeddings)

-
-
-
127    n_src_vocab: int
-
-
-
-
- -

Number of tokens in the target vocabulary (to generate logits for prediction)

-
-
-
129    n_tgt_vocab: int
+
144class TransformerConfigs(BaseConfigs):
@@ -402,10 +422,10 @@ are calculated.

-

The encoder self attention

+

Number of attention heads

-
132    encoder_attn: MultiHeadAttention = 'mha'
+
156    n_heads: int = 8
@@ -413,10 +433,10 @@ are calculated.

-

The decoder self attention

+

Transformer embedding size

-
134    decoder_attn: MultiHeadAttention = 'mha'
+
158    d_model: int = 512
@@ -424,10 +444,10 @@ are calculated.

-

The decoder memory attention

+

Number of layers

-
136    decoder_mem_attn: MultiHeadAttention = 'mha'
+
160    n_layers: int = 6
@@ -435,10 +455,10 @@ are calculated.

-

Configurable Feedforward Layer

+

Dropout probability

-
139    ffn: FeedForwardConfigs
+
162    dropout: float = 0.1
@@ -446,10 +466,10 @@ are calculated.

-

Encoder layer

+

Number of tokens in the source vocabulary (for token embeddings)

-
142    encoder_layer: TransformerLayer = 'default'
+
164    n_src_vocab: int
@@ -457,10 +477,10 @@ are calculated.

-

Decoder layer

+

Number of tokens in the target vocabulary (to generate logits for prediction)

-
144    decoder_layer: TransformerLayer = 'default'
+
166    n_tgt_vocab: int
@@ -468,10 +488,10 @@ are calculated.

-

Encoder consisting of multiple encoder layers

+

The encoder self attention

-
147    encoder: Encoder = 'default'
+
169    encoder_attn: MultiHeadAttention = 'mha'
@@ -479,10 +499,10 @@ are calculated.

-

Decoder consisting of multiple decoder layers

+

The decoder self attention

-
149    decoder: Decoder = 'default'
+
171    decoder_attn: MultiHeadAttention = 'mha'
@@ -490,10 +510,10 @@ are calculated.

-

Embedding layer for source

+

The decoder memory attention

-
152    src_embed: Module = 'fixed_pos'
+
173    decoder_mem_attn: MultiHeadAttention = 'mha'
@@ -501,10 +521,10 @@ are calculated.

-

Embedding layer for target (for decoder)

+

Configurable Feedforward Layer

-
154    tgt_embed: Module = 'fixed_pos'
+
176    ffn: FeedForwardConfigs
@@ -512,10 +532,10 @@ are calculated.

-

Logit generator for prediction

+

Encoder layer

-
157    generator: Generator = 'default'
+
179    encoder_layer: TransformerLayer = 'default'
@@ -523,10 +543,10 @@ are calculated.

-

Encoder-decoder

+

Decoder layer

-
160    encoder_decoder: EncoderDecoder
+
181    decoder_layer: TransformerLayer = 'default'
@@ -534,16 +554,10 @@ are calculated.

-

Multi-head Attention

+

Encoder consisting of multiple encoder layers

-
164def _mha(c: TransformerConfigs):
-165    return MultiHeadAttention(c.n_heads, c.d_model)
-166
-167
-168calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
-169calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
-170calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
+
184    encoder: Encoder = 'default'
@@ -551,29 +565,21 @@ are calculated.

-

Relative Multi-head Attention

+

Decoder consisting of multiple decoder layers

-
174def _relative_mha(c: TransformerConfigs):
-175    from .relative_mha import RelativeMultiHeadAttention
-176    return RelativeMultiHeadAttention(c.n_heads, c.d_model)
-177
-178
-179calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
-180calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
-181calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
+
186    decoder: Decoder = 'default'
-
+
-

Create feedforward layer configurations

+

Embedding layer for source

-
184@option(TransformerConfigs.ffn, 'default')
-185def _feed_forward(c: TransformerConfigs):
+
189    src_embed: Module = 'fixed_pos'
@@ -581,25 +587,21 @@ are calculated.

- +

Embedding layer for target (for decoder)

-
189    conf = FeedForwardConfigs()
-190    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
-191    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
-192    return conf
+
191    tgt_embed: Module = 'fixed_pos'
-
+
-

Encoder layer

+

Logit generator for prediction

-
195@option(TransformerConfigs.encoder_layer, 'default')
-196def _encoder_layer(c: TransformerConfigs):
+
194    generator: Generator = 'default'
@@ -607,24 +609,27 @@ are calculated.

- +

Encoder-decoder

-
200    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
-201                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
-202                            dropout_prob=c.dropout)
+
197    encoder_decoder: EncoderDecoder
-
+
-

Decoder layer

+

Multi-head Attention

-
205@option(TransformerConfigs.decoder_layer, 'default')
-206def _decoder_layer(c: TransformerConfigs):
+
201def _mha(c: TransformerConfigs):
+202    return MultiHeadAttention(c.n_heads, c.d_model)
+203
+204
+205calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
+206calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
+207calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
@@ -632,12 +637,17 @@ are calculated.

- +

Relative Multi-head Attention

-
210    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
-211                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
-212                            dropout_prob=c.dropout)
+
211def _relative_mha(c: TransformerConfigs):
+212    from .relative_mha import RelativeMultiHeadAttention
+213    return RelativeMultiHeadAttention(c.n_heads, c.d_model)
+214
+215
+216calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
+217calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
+218calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
@@ -645,11 +655,11 @@ are calculated.

-

Encoder

+

Create feedforward layer configurations

-
215@option(TransformerConfigs.encoder, 'default')
-216def _encoder(c: TransformerConfigs):
+
221@option(TransformerConfigs.ffn, 'default')
+222def _feed_forward(c: TransformerConfigs):
@@ -660,7 +670,10 @@ are calculated.

-
220    return Encoder(c.encoder_layer, c.n_layers)
+
226    conf = FeedForwardConfigs()
+227    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
+228    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
+229    return conf
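This deferred default is why the experiments below can override transformer.d_model and have the FFN pick it up automatically, while FFN-specific values are set under the transformer.ffn. prefix. A sketch of that override style (the values are the ones used in the GLU variants experiment):

overrides = {
    'transformer.d_model': 256,      # the FFN d_model follows this through the set_default lambda
    'transformer.ffn.d_ff': 1024,    # the FFN hidden size is set directly
    'transformer.ffn.glu_variant': 'Bilinear',
}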
@@ -668,11 +681,11 @@ are calculated.

-

Decoder

+

Encoder layer

-
223@option(TransformerConfigs.decoder, 'default')
-224def _decoder(c: TransformerConfigs):
+
232@option(TransformerConfigs.encoder_layer, 'default')
+233def _encoder_layer(c: TransformerConfigs):
@@ -683,7 +696,9 @@ are calculated.

-
228    return Decoder(c.decoder_layer, c.n_layers)
+
237    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
+238                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
+239                            dropout_prob=c.dropout)
@@ -691,11 +706,11 @@ are calculated.

-

Logit generator

+

Decoder layer

-
231@option(TransformerConfigs.generator, 'default')
-232def _generator(c: TransformerConfigs):
+
242@option(TransformerConfigs.decoder_layer, 'default')
+243def _decoder_layer(c: TransformerConfigs):
@@ -706,7 +721,9 @@ are calculated.

-
236    return Generator(c.n_tgt_vocab, c.d_model)
+
247    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
+248                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
+249                            dropout_prob=c.dropout)
@@ -714,12 +731,11 @@ are calculated.

-

Positional Embeddings

-

Source embedding with fixed positional encodings

+

Encoder

-
240@option(TransformerConfigs.src_embed, 'fixed_pos')
-241def _src_embed_with_positional(c: TransformerConfigs):
+
252@option(TransformerConfigs.encoder, 'default')
+253def _encoder(c: TransformerConfigs):
@@ -730,7 +746,7 @@ are calculated.

-
245    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
+
257    return Encoder(c.encoder_layer, c.n_layers)
@@ -738,11 +754,11 @@ are calculated.

-

Target embedding with fixed positional encodings

+

Decoder

-
248@option(TransformerConfigs.tgt_embed, 'fixed_pos')
-249def _tgt_embed_with_positional(c: TransformerConfigs):
+
260@option(TransformerConfigs.decoder, 'default')
+261def _decoder(c: TransformerConfigs):
@@ -753,7 +769,7 @@ are calculated.

-
253    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
+
265    return Decoder(c.decoder_layer, c.n_layers)
@@ -761,12 +777,11 @@ are calculated.

-

Learned Positional Embeddings

-

Source embedding with learned positional encodings

+

Logit generator

-
257@option(TransformerConfigs.src_embed, 'learned_pos')
-258def _src_embed_with_learned_positional(c: TransformerConfigs):
+
268@option(TransformerConfigs.generator, 'default')
+269def _generator(c: TransformerConfigs):
@@ -777,7 +792,7 @@ are calculated.

-
262    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
+
273    return Generator(c.n_tgt_vocab, c.d_model)
@@ -785,11 +800,12 @@ are calculated.

-

Target embedding with learned positional encodings

+

Fixed Positional Embeddings

+

Source embedding with fixed positional encodings

-
265@option(TransformerConfigs.tgt_embed, 'learned_pos')
-266def _tgt_embed_with_learned_positional(c: TransformerConfigs):
+
277@option(TransformerConfigs.src_embed, 'fixed_pos')
+278def _src_embed_with_positional(c: TransformerConfigs):
@@ -800,7 +816,7 @@ are calculated.

-
270    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
+
282    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
@@ -808,12 +824,11 @@ are calculated.

-

No Positional Embeddings

-

Source embedding without positional encodings

+

Target embedding with fixed positional encodings

-
274@option(TransformerConfigs.src_embed, 'no_pos')
-275def _src_embed_without_positional(c: TransformerConfigs):
+
285@option(TransformerConfigs.tgt_embed, 'fixed_pos')
+286def _tgt_embed_with_positional(c: TransformerConfigs):
@@ -824,25 +839,96 @@ are calculated.

-
279    return nn.Embedding(c.n_src_vocab, c.d_model)
+
290    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
-
+
+

Learned Positional Embeddings

+

Source embedding with learned positional encodings

+
+
+
294@option(TransformerConfigs.src_embed, 'learned_pos')
+295def _src_embed_with_learned_positional(c: TransformerConfigs):
+
+
+
+
+
-
282@option(TransformerConfigs.tgt_embed, 'no_pos')
-283def _tgt_embed_without_positional(c: TransformerConfigs):
-284    return nn.Embedding(c.n_tgt_vocab, c.d_model)
-285
-286
-287@option(TransformerConfigs.encoder_decoder, 'default')
-288def _encoder_decoder(c: TransformerConfigs):
-289    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
+
299    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
+
+
+
+
+ +

Target embedding with learned positional encodings

+
+
+
302@option(TransformerConfigs.tgt_embed, 'learned_pos')
+303def _tgt_embed_with_learned_positional(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
307    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
+
+
+
+
+ +

No Positional Embeddings

+

Source embedding without positional encodings

+
+
+
311@option(TransformerConfigs.src_embed, 'no_pos')
+312def _src_embed_without_positional(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
316    return nn.Embedding(c.n_src_vocab, c.d_model)
+
+
+
+
+ + +
+
+
319@option(TransformerConfigs.tgt_embed, 'no_pos')
+320def _tgt_embed_without_positional(c: TransformerConfigs):
+321    return nn.Embedding(c.n_tgt_vocab, c.d_model)
+322
+323
+324@option(TransformerConfigs.encoder_decoder, 'default')
+325def _encoder_decoder(c: TransformerConfigs):
+326    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
diff --git a/docs/transformers/feed_forward.html b/docs/transformers/feed_forward.html index 053e69b2..e263236f 100644 --- a/docs/transformers/feed_forward.html +++ b/docs/transformers/feed_forward.html @@ -84,12 +84,20 @@ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.

Sometimes the GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU. $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$

+

Gated Linear Units

+

This is a generic implementation that supports different variants, including Gated Linear Units (GLU). We have also implemented experiments on these: an experiment that uses labml.configs and a simpler version written from scratch.

+
-
26import torch
-27from torch import nn as nn
-28
-29from labml_helpers.module import Module
+
35import torch
+36from torch import nn as nn
+37
+38from labml_helpers.module import Module
@@ -97,10 +105,10 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU. -

Position-wise feed-forward network (FFN) module

+

FFN module

-
32class FeedForward(Module):
+
41class FeedForward(Module):
@@ -119,13 +127,13 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
-
37    def __init__(self, d_model: int, d_ff: int,
-38                 dropout: float = 0.1,
-39                 activation=nn.ReLU(),
-40                 is_gated: bool = False,
-41                 bias1: bool = True,
-42                 bias2: bool = True,
-43                 bias_gate: bool = True):
+
46    def __init__(self, d_model: int, d_ff: int,
+47                 dropout: float = 0.1,
+48                 activation=nn.ReLU(),
+49                 is_gated: bool = False,
+50                 bias1: bool = True,
+51                 bias2: bool = True,
+52                 bias_gate: bool = True):
@@ -136,14 +144,7 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
-
53        super().__init__()
-54        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
-55        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
-56        self.dropout = nn.Dropout(dropout)
-57        self.activation = activation
-58        self.is_gated = is_gated
-59        if is_gated:
-60            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
+
62        super().__init__()
@@ -151,17 +152,136 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU. +

Layer one parameterized by weight $W_1$ and bias $b_1$

+
+
+
64        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
+
+ +
+
+ +

Layer two parameterized by weight $W_2$ and bias $b_2$

+
+
+
66        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
+
+
+
+
+ +

Hidden layer dropout

+
+
+
68        self.dropout = nn.Dropout(dropout)
+
+
+
+
+ +

Activation function $f$

+
+
+
70        self.activation = activation
+
+
+
+
+ +

Whether there is a gate

+
+
+
72        self.is_gated = is_gated
+73        if is_gated:
+
+
+
+
+ +

If there is a gate, the linear layer that transforms inputs to be multiplied by the gate, parameterized by weight $V$ and bias $c$

+
+
+
76            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
+
+
+
+
+
-
62    def __call__(self, x: torch.Tensor):
-63        g = self.activation(self.layer1(x))
-64        if self.is_gated:
-65            x = g * self.linear_v(x)
-66        else:
-67            x = g
-68        x = self.dropout(x)
-69        return self.layer2(x)
+
78    def __call__(self, x: torch.Tensor):
+
+
+
+
+ +

$f(x W_1 + b_1)$

+
+
+
80        g = self.activation(self.layer1(x))
+
+
+
+
+ +

If gated, $f(x W_1 + b_1) \otimes (x V + b) $

+
+
+
82        if self.is_gated:
+83            x = g * self.linear_v(x)
+
+
+
+
+ +

Otherwise

+
+
+
85        else:
+86            x = g
+
+
+
+
+ +

Apply dropout

+
+
+
88        x = self.dropout(x)
+
+
+
+
+ +

$(f(x W_1 + b_1) \otimes (x V + b)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$ depending on whether it is gated

+
+
+
91        return self.layer2(x)
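A short usage sketch (tensor sizes are illustrative): a GEGLU-style FFN is a gated FeedForward with GELU activation and no biases, and it preserves the input shape.

import torch
from torch import nn
from labml_nn.transformers.feed_forward import FeedForward

ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1, activation=nn.GELU(),
                  is_gated=True, bias1=False, bias2=False, bias_gate=False)
x = torch.randn(10, 16, 512)  # (seq_len, batch_size, d_model)
y = ffn(x)                    # same shape: (10, 16, 512)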
diff --git a/docs/transformers/glu_variants/experiment.html b/docs/transformers/glu_variants/experiment.html
index 8ab55ca6..e73e4151 100644
--- a/docs/transformers/glu_variants/experiment.html
+++ b/docs/transformers/glu_variants/experiment.html
@@ -71,19 +71,21 @@
-

Train Autoregressive Transformer

-

This trains a simple transformer model for auto-regression.

+

Gated Linear Units and Variants

+

This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network. The reusable and configurable components are defined in configs.py.

-
14import torch
-15from labml import experiment
-16from labml.configs import option
-17from labml.utils.pytorch import get_modules
-18from labml_helpers.module import Module
-19
-20from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-21from labml_nn.transformers import Encoder, Generator, TransformerConfigs
-22from labml_nn.transformers.utils import subsequent_mask
+
16import torch
+17from labml import experiment
+18from labml.configs import option
+19from labml.utils.pytorch import get_modules
+20from labml_helpers.module import Module
+21
+22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+23from labml_nn.transformers import Encoder, Generator, TransformerConfigs
+24from labml_nn.transformers.utils import subsequent_mask
@@ -94,7 +96,7 @@

Auto regressive model

-
25class AutoregressiveModel(Module):
+
27class AutoregressiveModel(Module):
@@ -105,8 +107,8 @@
-
30    def __init__(self, src_embed: Module, encoder: Encoder, generator: Generator):
-31        super().__init__()
+
32    def __init__(self, src_embed: Module, encoder: Encoder, generator: Generator):
+33        super().__init__()
@@ -117,7 +119,7 @@

Token embedding module

-
33        self.src_embed = src_embed
+
35        self.src_embed = src_embed
@@ -128,7 +130,7 @@

Transformer based encoder

-
35        self.encoder = encoder
+
37        self.encoder = encoder
@@ -140,7 +142,7 @@ this give logits of the the next token

-
38        self.generator = generator
+
40        self.generator = generator
@@ -151,7 +153,7 @@ this give logits of the the next token

This will be initialized on the first call

-
40        self.src_mask = None
+
42        self.src_mask = None
@@ -162,7 +164,7 @@ this give logits of the the next token

-
42    def __call__(self, src: torch.Tensor):
+
44    def __call__(self, src: torch.Tensor):
@@ -173,8 +175,8 @@ this give logits of the the next token

Create subsequent mask, so that the transformer can only pay attention to past tokens.

-
44        if self.src_mask is None or self.src_mask.size(0) != len(src):
-45            self.src_mask = subsequent_mask(len(src)).to(src.device)
+
46        if self.src_mask is None or self.src_mask.size(0) != len(src):
+47            self.src_mask = subsequent_mask(len(src)).to(src.device)
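The exact tensor layout of this mask is defined in labml_nn.transformers.utils; conceptually it is a lower-triangular (causal) mask, roughly:

import torch

seq_len = 4
# causal[i, j] is True only when j <= i, so position i cannot attend to later positions
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))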
@@ -185,7 +187,7 @@ this give logits of the the next token

Embed the tokens (src) and run it through the transformer

-
47        res = self.encoder(self.src_embed(src), self.src_mask)
+
49        res = self.encoder(self.src_embed(src), self.src_mask)
@@ -196,7 +198,7 @@ this give logits of the the next token

Generate logits of the next token

-
49        return self.generator(res), None
+
51        return self.generator(res), None
@@ -208,7 +210,7 @@ this give logits of the the next token

The default configs can and will be over-ridden when we start the experiment

-
52class Configs(NLPAutoRegressionConfigs):
+
54class Configs(NLPAutoRegressionConfigs):
@@ -219,8 +221,8 @@ this give logits of the the next token

-
59    transformer: TransformerConfigs
-60    model: AutoregressiveModel
+
61    transformer: TransformerConfigs
+62    model: AutoregressiveModel
@@ -231,8 +233,8 @@ this give logits of the the next token

Initialize the auto-regressive model

-
63@option(Configs.model)
-64def autoregressive_model(c: Configs):
+
65@option(Configs.model)
+66def autoregressive_model(c: Configs):
@@ -243,8 +245,8 @@ this give logits of the the next token

-
68    m = AutoregressiveModel(c.transformer.src_embed, c.transformer.encoder, c.transformer.generator)
-69    return m.to(c.device)
+
70    m = AutoregressiveModel(c.transformer.src_embed, c.transformer.encoder, c.transformer.generator)
+71    return m.to(c.device)
@@ -252,11 +254,11 @@ this give logits of the the next token

-

Initialize the configurable transformer encoder for our autoregressive model

+

Initialize the configurable transformer encoder for our autoregressive model.

-
72@option(Configs.transformer)
-73def transformer_c(c: Configs):
+
74@option(Configs.transformer)
+75def transformer_c(c: Configs):
@@ -267,11 +269,11 @@ this give logits of the the next token

-
77    tc = TransformerConfigs()
-78    tc.n_src_vocab = c.n_tokens
-79    tc.n_tgt_vocab = c.n_tokens
-80
-81    return tc
+
79    tc = TransformerConfigs()
+80    tc.n_src_vocab = c.n_tokens
+81    tc.n_tgt_vocab = c.n_tokens
+82
+83    return tc
@@ -282,7 +284,7 @@ this give logits of the the next token

-
84def main():
+
86def main():
@@ -293,7 +295,7 @@ this give logits of the the next token

Create experiment

-
86    experiment.create(name="glu_variants")
+
88    experiment.create(name="glu_variants")
@@ -304,7 +306,7 @@ this give logits of the the next token

Create configs

-
88    conf = Configs()
+
90    conf = Configs()
@@ -315,7 +317,7 @@ this give logits of the the next token

Load configurations

-
90    experiment.configs(conf,
+
92    experiment.configs(conf,
@@ -326,19 +328,19 @@ this give logits of the the next token

A dictionary of configurations to override

-
92                       {'tokenizer': 'character',
-93                        'prompt_separator': '',
-94                        'prompt': 'It is ',
-95                        'text': 'tiny_shakespeare',
-96
-97                        'optimizer.optimizer': 'Noam',
-98                        'optimizer.learning_rate': 1.,
-99                        'optimizer.d_model': 256,
-100
-101                        'seq_len': 1024,
-102                        'epochs': 128,
-103                        'batch_size': 6,
-104                        'inner_iterations': 10,
+
94                       {'tokenizer': 'character',
+95                        'prompt_separator': '',
+96                        'prompt': 'It is ',
+97                        'text': 'tiny_shakespeare',
+98
+99                        'optimizer.optimizer': 'Noam',
+100                        'optimizer.learning_rate': 1.,
+101                        'optimizer.d_model': 256,
+102
+103                        'seq_len': 1024,
+104                        'epochs': 128,
+105                        'batch_size': 6,
+106                        'inner_iterations': 10,
@@ -347,9 +349,11 @@ this give logits of the the next token

#

GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU

+

These are defined in the configurable FFN implementation

-
107                        'transformer.ffn.glu_variant': 'Bilinear',
+
112                        'transformer.ffn.glu_variant': 'Bilinear',
@@ -360,10 +364,10 @@ this give logits of the the next token

Transformer configurations

-
110                        'transformer.d_model': 256,
-111                        'transformer.ffn.d_ff': 1024,
-112                        'transformer.n_heads': 8,
-113                        'transformer.n_layers': 6})
+
115                        'transformer.d_model': 256,
+116                        'transformer.ffn.d_ff': 1024,
+117                        'transformer.n_heads': 8,
+118                        'transformer.n_layers': 6})
@@ -374,7 +378,7 @@ this give logits of the the next token

This is needed to initialize models

-
116    conf.n_tokens = conf.text.n_tokens
+
121    conf.n_tokens = conf.text.n_tokens
@@ -385,7 +389,7 @@ this give logits of the the next token

Set models for saving and loading

-
119    experiment.add_pytorch_models(get_modules(conf))
+
124    experiment.add_pytorch_models(get_modules(conf))
@@ -396,7 +400,7 @@ this give logits of the the next token

Start the experiment

-
122    with experiment.start():
+
127    with experiment.start():
@@ -407,11 +411,11 @@ this give logits of the the next token

TrainValidConfigs.run

-
124        conf.run()
-125
-126
-127if __name__ == '__main__':
-128    main()
+
129        conf.run()
+130
+131
+132if __name__ == '__main__':
+133    main()
diff --git a/docs/transformers/glu_variants/simple.html b/docs/transformers/glu_variants/simple.html index 46e54fae..207f5c42 100644 --- a/docs/transformers/glu_variants/simple.html +++ b/docs/transformers/glu_variants/simple.html @@ -71,25 +71,28 @@ -

Train Autoregressive Transformer

-

This trains a simple transformer model for auto-regression.

+

Gated Linear Units and Variants

+

This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network.

+

This is a simpler implementation that doesn’t use the labml.configs module. We wrote this version to make it easier for readers who are not familiar with that module.

-
13import dataclasses
-14
-15import torch
-16from torch import nn
-17from torch.utils.data import Dataset, DataLoader
+                
17import dataclasses
 18
-19from labml import experiment, lab, tracker, monit, logger
-20from labml.logger import Text
-21from labml.utils.download import download_file
-22from labml_nn.experiments.nlp_autoregression import transpose_batch
-23from labml_nn.optimizers.noam import Noam
-24from labml_nn.transformers import Encoder, MultiHeadAttention
-25from labml_nn.transformers.feed_forward import FeedForward
-26from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
-27from labml_nn.transformers.utils import subsequent_mask
+19import torch
+20from torch import nn
+21from torch.utils.data import Dataset, DataLoader
+22
+23from labml import experiment, lab, tracker, monit, logger
+24from labml.logger import Text
+25from labml.utils.download import download_file
+26from labml_nn.experiments.nlp_autoregression import transpose_batch
+27from labml_nn.optimizers.noam import Noam
+28from labml_nn.transformers import Encoder, MultiHeadAttention
+29from labml_nn.transformers.feed_forward import FeedForward
+30from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
+31from labml_nn.transformers.utils import subsequent_mask
@@ -100,7 +103,7 @@

Auto regressive model

-
30class AutoregressiveModel(nn.Module):
+
34class AutoregressiveModel(nn.Module):
@@ -111,8 +114,8 @@
-
35    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
-36        super().__init__()
+
39    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
+40        super().__init__()
@@ -123,7 +126,7 @@

Token embedding module

-
38        self.src_embed = src_embed
+
42        self.src_embed = src_embed
@@ -134,7 +137,7 @@

Transformer based encoder

-
40        self.encoder = encoder
+
44        self.encoder = encoder
@@ -146,7 +149,7 @@ this give logits of the the next token

-
43        self.generator = generator
+
47        self.generator = generator
@@ -157,7 +160,7 @@ this give logits of the the next token

This will be initialized on the first call

-
45        self.src_mask = None
+
49        self.src_mask = None
@@ -168,7 +171,7 @@ this give logits of the the next token

-
47    def __call__(self, src: torch.Tensor):
+
51    def __call__(self, src: torch.Tensor):
@@ -179,8 +182,8 @@ this give logits of the the next token

Create subsequent mask, so that the transformer can only pay attention to past tokens.

-
49        if self.src_mask is None or self.src_mask.size(0) != len(src):
-50            self.src_mask = subsequent_mask(len(src)).to(src.device)
+
53        if self.src_mask is None or self.src_mask.size(0) != len(src):
+54            self.src_mask = subsequent_mask(len(src)).to(src.device)
@@ -191,7 +194,7 @@ this give logits of the the next token

Embed the tokens (src) and run it through the transformer

-
52        res = self.encoder(self.src_embed(src), self.src_mask)
+
56        res = self.encoder(self.src_embed(src), self.src_mask)
@@ -202,98 +205,19 @@ this give logits of the the next token

Generate logits of the next token

-
54        return self.generator(res)
+
58        return self.generator(res)
-
+
- +

Configurations

-
57@dataclasses.dataclass
-58class Configs:
-59    d_model: int = 512
-60    seq_len: int = 128
-61    batch_size: int = 32
-62    n_layers: int = 6
-63    n_heads: int = 8
-64    dropout: float = 0.1
-65    d_ff: int = 2048
-66    glu_variant: str = 'GLU'
-67    epochs: int = 5
-68    grad_norm_clip: float = 0.5
-69
-70
-71class TinyShakespeareDataset(Dataset):
-72    def __init__(self, seq_len: int):
-73        path = lab.get_data_path() / 'tiny_shakespeare.txt'
-74        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
-75        with open(str(path), 'r') as f:
-76            text = f.read()
-77
-78        chars = list(set(text))
-79        self.stoi = {c: i for i, c in enumerate(chars)}
-80        self.itos = {i: c for i, c in enumerate(chars)}
-81        self.seq_len = seq_len
-82        self.data = self.text_to_i(text)
-83
-84    def text_to_i(self, text: str):
-85        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
-86
-87    def __len__(self):
-88        return len(self.data) - self.seq_len - 1
-89
-90    def __getitem__(self, idx):
-91        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
-92
-93
-94class Trainer:
-95    def __init__(self, configs: Configs):
-96        self.device = torch.device('cpu')
-97        if torch.cuda.is_available():
-98            self.device = torch.device('cuda:0')
-99        self.dataset = TinyShakespeareDataset(configs.seq_len)
-100        self.dataloader = DataLoader(self.dataset, batch_size=configs.batch_size, collate_fn=transpose_batch,
-101                                     shuffle=True)
-102
-103        if configs.glu_variant == 'GLU':
-104            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
-105        elif configs.glu_variant == 'Bilinear':
-106            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
-107        elif configs.glu_variant == 'ReGLU':
-108            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
-109        elif configs.glu_variant == 'GEGLU':
-110            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
-111        elif configs.glu_variant == 'SwiGLU':
-112            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
-113        elif configs.glu_variant == 'ReLU':
-114            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
-115        elif configs.glu_variant == 'GELU':
-116            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
-117        else:
-118            raise ValueError(f'Unknown variant {configs.glu_variant}')
-119
-120        n_chars = len(self.dataset.stoi)
-121        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
-122                                         Encoder(TransformerLayer(
-123                                             d_model=configs.d_model,
-124                                             self_attn=MultiHeadAttention(configs.n_heads, configs.d_model,
-125                                                                          configs.dropout),
-126                                             src_attn=None,
-127                                             feed_forward=ffn,
-128                                             dropout_prob=configs.dropout
-129                                         ), configs.n_layers),
-130                                         nn.Linear(configs.d_model, n_chars))
-131        self.model.to(self.device)
-132
-133        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
-134
-135        self.loss_func = nn.CrossEntropyLoss()
-136        self.epochs = configs.epochs
-137        self.grad_norm_clip = configs.grad_norm_clip
+
61@dataclasses.dataclass
+62class Configs:
@@ -301,10 +225,19 @@ this give logits of the the next token

-

Set tracker configurations

+
-
140        tracker.set_scalar("loss.*", True)
+
66    d_model: int = 512
+67    seq_len: int = 128
+68    batch_size: int = 32
+69    n_layers: int = 6
+70    n_heads: int = 8
+71    dropout: float = 0.1
+72    d_ff: int = 2048
+73    glu_variant: str = 'GLU'
+74    epochs: int = 5
+75    grad_norm_clip: float = 0.5
@@ -312,10 +245,10 @@ this give logits of the the next token

-

Sampling function to generate samples periodically while training

+

Tiny Shakespeare Dataset

-
142    def sample(self):
+
78class TinyShakespeareDataset(Dataset):
@@ -323,10 +256,10 @@ this give logits of the the next token

-

Starting prompt

+
-
148        prompt = 'It is'
+
83    def __init__(self, seq_len: int):
@@ -334,10 +267,10 @@ this give logits of the the next token

-

Collect output for printing

+

Location of the text file

-
150        log = [(prompt, Text.subtle)]
+
85        path = lab.get_data_path() / 'tiny_shakespeare.txt'
@@ -345,10 +278,10 @@ this give logits of the the next token

-

Sample 25 tokens

+

Download the file

-
152        for i in monit.iterate('Sample', 25):
+
87        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
@@ -356,11 +289,11 @@ this give logits of the the next token

-

Tokenize the prompt

+

Read the downloaded file

-
154            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
-155            data = data.to(self.device)
+
89        with open(str(path), 'r') as f:
+90            text = f.read()
@@ -368,10 +301,10 @@ this give logits of the the next token

-

Get the model output

+

Extract the characters

-
157            output = self.model(data)
+
93        chars = list(set(text))
@@ -379,10 +312,10 @@ this give logits of the the next token

-

Get the model prediction (greedy)

+

Character to id (integer) map

-
159            output = output.argmax(dim=-1).squeeze()
+
95        self.stoi = {c: i for i, c in enumerate(chars)}
@@ -390,10 +323,10 @@ this give logits of the the next token

-

Add the prediction to prompt

+

Id to character map

-
161            prompt += self.dataset.itos[output[-1].item()]
+
97        self.itos = {i: c for i, c in enumerate(chars)}
@@ -401,10 +334,10 @@ this give logits of the the next token

-

Add the prediction for logging

+

Length of a training sample

-
163            log += [(self.dataset.itos[output[-1].item()], Text.value)]
+
99        self.seq_len = seq_len
@@ -412,23 +345,21 @@ this give logits of the the next token

-

Print the sampled output

+

Data in the form of a tensor of ids

-
166        logger.log(log)
+
101        self.data = self.text_to_i(text)
-
+
- +

Transform the text into a tensor of ids

-
168    def train(self):
-169        for _ in monit.loop(self.epochs):
-170            for i, batch in monit.enum('Train', self.dataloader):
+
103    def text_to_i(self, text: str):
@@ -436,27 +367,22 @@ this give logits of the the next token

-

Move data to the device

+
-
172                data, target = batch[0].to(self.device), batch[1].to(self.device)
-173
-174                tracker.add_global_step(data.shape[0] * data.shape[1])
-175
-176                self.model.train()
-177                output = self.model(data)
+
107        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
-
+
-

Calculate and log loss

+

Number of samples in the dataset.

+

This will read the dataset seq_len times in a single epoch.

-
180                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
-181                tracker.add("loss.train", loss)
+
109    def __len__(self):
@@ -464,21 +390,21 @@ this give logits of the the next token

-

Calculate gradients

+
-
184                loss.backward()
+
115        return len(self.data) - self.seq_len - 1
-
+
-

Clip gradients

+

Return a sample

-
186                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+
117    def __getitem__(self, idx):
@@ -486,22 +412,21 @@ this give logits of the the next token

-

Take optimizer step

+
-
188                self.optimizer.step()
+
121        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
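To make the sample layout concrete, a toy sketch (the encoding here is made up, not Tiny Shakespeare):

import torch

data, seq_len, idx = torch.arange(6), 4, 1
x, y = data[idx:idx + seq_len], data[idx + 1:idx + seq_len + 1]
# x = [1, 2, 3, 4] and y = [2, 3, 4, 5]: the target is the input shifted right by one position,
# so the model learns to predict character i+1 from characters up to i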
-
+
-

Log the model parameters and gradients on last batch of every epoch

+

Trainer

-
190                if (i + 1) % 100 == 0:
-191                    tracker.add('model', self.model)
+
124class Trainer:
@@ -509,15 +434,10 @@ this give logits of the the next token

-

Clear the gradients

+
-
193                self.optimizer.zero_grad()
-194
-195                if (i + 1) % 100 == 0:
-196                    self.model.eval()
-197                    with torch.no_grad():
-198                        self.sample()
+
129    def __init__(self, configs: Configs):
@@ -525,13 +445,12 @@ this give logits of the the next token

-

Save the tracked metrics

+

Get the device

-
201                if (i + 1) % 10 == 0:
-202                    tracker.save()
-203
-204            experiment.save_checkpoint()
+
131        self.device = torch.device('cpu')
+132        if torch.cuda.is_available():
+133            self.device = torch.device('cuda:0')
@@ -539,10 +458,10 @@ this give logits of the the next token

- +

Initialize the dataset

-
207def main():
+
135        self.dataset = TinyShakespeareDataset(configs.seq_len)
@@ -550,10 +469,13 @@ this give logits of the the next token

-

Create experiment

+

Initialize the dataloader

-
209    experiment.create(name="glu_variants")
+
137        self.dataloader = DataLoader(self.dataset,
+138                                     batch_size=configs.batch_size,
+139                                     collate_fn=transpose_batch,
+140                                     shuffle=True)
@@ -561,10 +483,13 @@ this give logits of the the next token

-

Create configs

+

FFN with Gated Linear Unit

-
211    configs = Configs()
+
144        if configs.glu_variant == 'GLU':
+145            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
@@ -572,13 +497,13 @@ this give logits of the the next token

-

Load configurations

+

FFN with Bilinear hidden layer

-
213    experiment.configs(dataclasses.asdict(configs))
-214
-215    trainer = Trainer(configs)
-216    experiment.add_pytorch_models({'model': trainer.model})
+
148        elif configs.glu_variant == 'Bilinear':
+149            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
@@ -586,10 +511,13 @@ this give logits of the the next token

-

Start the experiment

+

FFN with ReLU gate

-
219    with experiment.start():
+
152        elif configs.glu_variant == 'ReGLU':
+153            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
@@ -597,14 +525,570 @@ this give logits of the the next token

-

TrainValidConfigs.run

+

FFN with GELU gate

-
221        trainer.train()
-222
-223
-224if __name__ == '__main__':
-225    main()
+
156        elif configs.glu_variant == 'GEGLU':
+157            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
+
+ +
+
+ +

FFN with Swish gate
$$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

+
+
+
161        elif configs.glu_variant == 'SwiGLU':
+162            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
+
+
+
+
+ +

FFN with ReLU activation

+
+
+
165        elif configs.glu_variant == 'ReLU':
+166            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
+
+
+
+
+ +

FFN with GELU activation

+
+
+
169        elif configs.glu_variant == 'GELU':
+170            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
+171        else:
+172            raise ValueError(f'Unknown variant {configs.glu_variant}')
+
+
+
+
+ +

Number of different characters

+
+
+
175        n_chars = len(self.dataset.stoi)
+
+
+
+
+ +

Initialize Multi-Head Attention module

+
+
+
178        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
+
+
+
+
+ +

Initialize the Transformer Block

+
+
+
180        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
+181                                             feed_forward=ffn, dropout_prob=configs.dropout)
+
+
+
+
+ +

Initialize the model with an embedding layer (with fixed positional encoding), a transformer encoder, and a linear layer to generate logits.

+
+
+
187        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
+188                                         Encoder(transformer_layer, configs.n_layers),
+189                                         nn.Linear(configs.d_model, n_chars))
+
+
+
+
+ +

Move the model to the current device

+
+
+
192        self.model.to(self.device)
+
+
+
+
+ +

Initialize Noam optimizer

+
+
+
195        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
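A hedged sketch of the schedule this optimizer follows (the standard Noam schedule from the original Transformer paper), using the warmup value above and the default d_model of these configs:

def noam_lr(step: int, d_model: int = 512, warmup: int = 2_000, factor: float = 1.0) -> float:
    # for step >= 1: rises roughly linearly for `warmup` steps, then decays as step ** -0.5
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)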
+
+
+
+
+ +

Cross-entropy loss

+
+
+
198        self.loss_func = nn.CrossEntropyLoss()
+
+
+
+
+ +

Number of training epochs; note that our dataset definition repeats the data seq_len times in a single epoch

+
+
+
201        self.epochs = configs.epochs
+
+
+
+
+ +

Gradient clipping norm

+
+
+
203        self.grad_norm_clip = configs.grad_norm_clip
+
+
+
+
+ +

Set tracker configurations

+
+
+
206        tracker.set_scalar("loss.*", True)
+
+
+
+
+ +

Sampling function to generate samples periodically while training

+
+
+
208    def sample(self):
+
+
+
+
+ +

Starting prompt

+
+
+
214        prompt = 'It is'
+
+
+
+
+ +

Collect output for printing

+
+
+
216        log = [(prompt, Text.subtle)]
+
+
+
+
+ +

Sample 25 tokens

+
+
+
218        for i in monit.iterate('Sample', 25):
+
+
+
+
+ +

Tokenize the prompt

+
+
+
220            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
+221            data = data.to(self.device)
+
+
+
+
+ +

Get the model output

+
+
+
223            output = self.model(data)
+
+
+
+
+ +

Get the model prediction (greedy)

+
+
+
225            output = output.argmax(dim=-1).squeeze()
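A shape sketch for this greedy step, assuming the (seq_len, batch) layout used above:

# model(data) gives logits of shape (prompt_len, 1, n_chars)
# .argmax(dim=-1) -> (prompt_len, 1) greedy character ids
# .squeeze()      -> (prompt_len,); output[-1] is the prediction for the next character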
+
+
+
+
+ +

Add the prediction to prompt

+
+
+
227            prompt += self.dataset.itos[output[-1].item()]
+
+
+
+
+ +

Add the prediction for logging

+
+
+
229            log += [(self.dataset.itos[output[-1].item()], Text.value)]
+
+
+
+
+ +

Print the sampled output

+
+
+
232        logger.log(log)
+
+
+
+
+ +

Train the model

+
+
+
234    def train(self):
+
+
+
+
+ +

Loop for the given number of epochs

+
+
+
240        for _ in monit.loop(self.epochs):
+
+
+
+
+ +

Iterate over the minibatches

+
+
+
242            for i, batch in monit.enum('Train', self.dataloader):
+
+
+
+
+ +

Move data to the device

+
+
+
244                data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+
+
+
+ +

Set tracker step, as the number of characters trained on

+
+
+
247                tracker.add_global_step(data.shape[0] * data.shape[1])
+
+
+
+
+ +

Set model state to training

+
+
+
250                self.model.train()
+
+
+
+
+ +

Evaluate the model

+
+
+
252                output = self.model(data)
+
+
+
+
+ +

Calculate loss

+
+
+
255                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
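A note on the reshaping, assuming transpose_batch yields (seq_len, batch_size) batches as elsewhere in this repository:

# output: (seq_len, batch_size, n_chars) -> output.view(-1, n_chars): (seq_len * batch_size, n_chars)
# target: (seq_len, batch_size)          -> target.view(-1):          (seq_len * batch_size,)
# nn.CrossEntropyLoss then averages the per-character negative log-likelihood over all positions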
+
+
+
+
+ +

Log the loss

+
+
+
257                tracker.add("loss.train", loss)
+
+
+
+
+ +

Calculate gradients

+
+
+
260                loss.backward()
+
+
+
+
+ +

Clip gradients

+
+
+
262                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+
+
+
+
+ +

Take optimizer step

+
+
+
264                self.optimizer.step()
+
+
+
+
+ +

Log the model parameters and gradients

+
+
+
266                if (i + 1) % 100 == 0:
+267                    tracker.add('model', self.model)
+
+
+
+
+ +

Clear the gradients

+
+
+
269                self.optimizer.zero_grad()
+
+
+
+
+ +

Generate a sample

+
+
+
272                if (i + 1) % 100 == 0:
+273                    self.model.eval()
+274                    with torch.no_grad():
+275                        self.sample()
+
+
+
+
+ +

Save the tracked metrics

+
+
+
278                if (i + 1) % 10 == 0:
+279                    tracker.save()
+
+
+
+
+ +

Save the model

+
+
+
282            experiment.save_checkpoint()
+
+
+
+
+ + +
+
+
285def main():
+
+
+
+
+ +

Create experiment

+
+
+
287    experiment.create(name="glu_variants")
+
+
+
+
+ +

Create configs

+
+
+
289    configs = Configs()
+
+
+
+
+ +

Load configurations

+
+
+
291    experiment.configs(dataclasses.asdict(configs))
+
+
+
+
+ +

Create trainer

+
+
+
294    trainer = Trainer(configs)
+
+
+
+
+ +

Set models for saving and loading

+
+
+
296    experiment.add_pytorch_models({'model': trainer.model})
+
+
+
+
+ +

Start the experiment

+
+
+
299    with experiment.start():
+
+
+
+
+ +

Train the model

+
+
+
301        trainer.train()
+302
+303
+304if __name__ == '__main__':
+305    main()
diff --git a/labml_nn/transformers/configs.py b/labml_nn/transformers/configs.py index 2afdb37d..e11c984d 100644 --- a/labml_nn/transformers/configs.py +++ b/labml_nn/transformers/configs.py @@ -19,6 +19,14 @@ from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPosit class FeedForwardConfigs(BaseConfigs): + """ + + ## FFN Configurations + + + Creates a Position-wise FeedForward Network defined in + [`feed_forward.py`](feed_forward.html). + """ # Position-wise feedforward layer ffn: FeedForward # Number of features in the embedding @@ -44,7 +52,9 @@ class FeedForwardConfigs(BaseConfigs): @option(FeedForwardConfigs.activation, 'ReLU') def _ffn_activation_relu(): """ - ReLU activation + ### ReLU activation + + $$\max(0, x)$$ """ return nn.ReLU() @@ -52,7 +62,11 @@ def _ffn_activation_relu(): @option(FeedForwardConfigs.activation, 'GELU') def _ffn_activation_gelu(): """ - GELU activation + ### GELU activation + + $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$ + + It was introduced in paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415). """ return nn.GELU() @@ -60,7 +74,7 @@ def _ffn_activation_gelu(): @option(FeedForwardConfigs.ffn, 'default') def _feed_forward(c: FeedForwardConfigs): """ - Create feedforward layer + Initialize a [feed forward network](feed_forward.html) """ return FeedForward(c.d_model, c.d_ff, dropout=c.dropout, @@ -70,7 +84,14 @@ def _feed_forward(c: FeedForwardConfigs): bias2=c.bias2, bias_gate=c.bias_gate) +# ## GLU Variants +# These are variants with gated hidden layers for the FFN +# as introduced in paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202). +# We have omitted the bias terms as specified in the paper. +# ### FFN with Gated Linear Units +# +# $$FFN_{GLU}(x)(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$ aggregate(FeedForwardConfigs.glu_variant, 'GLU', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), @@ -78,24 +99,40 @@ aggregate(FeedForwardConfigs.glu_variant, 'GLU', (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.Sigmoid())) +# ### FFN with Bilinear hidden layer +# +# $$FFN_{Bilinear}(x)(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$ aggregate(FeedForwardConfigs.glu_variant, 'Bilinear', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), (FeedForwardConfigs.bias2, False), (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.Identity())) + +# ### FFN with ReLU gate +# +# $$FFN_{ReGLU}(x)(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$ aggregate(FeedForwardConfigs.glu_variant, 'ReGLU', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), (FeedForwardConfigs.bias2, False), (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.ReLU())) + +# ### FFN with GELU gate +# +# $$FFN_{GEGLU}(x)(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$ aggregate(FeedForwardConfigs.glu_variant, 'GEGLU', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), (FeedForwardConfigs.bias2, False), (FeedForwardConfigs.bias_gate, False), (FeedForwardConfigs.activation, nn.GELU())) + +# ### FFN with Swish gate +# +# $$FFN_{SwiGLU}(x)(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$ +# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$ aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU', (FeedForwardConfigs.is_gated, True), (FeedForwardConfigs.bias1, False), @@ -236,7 +273,7 @@ def _generator(c: TransformerConfigs): return 
diff --git a/labml_nn/transformers/feed_forward.py b/labml_nn/transformers/feed_forward.py
index a7c92afb..57cedc9b 100644
--- a/labml_nn/transformers/feed_forward.py
+++ b/labml_nn/transformers/feed_forward.py
@@ -21,6 +21,15 @@ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.
 Sometimes the GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
 $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
+
+### Gated Linear Units
+
+This is a generic implementation that supports different variants including
+[Gated Linear Units](https://arxiv.org/abs/2002.05202) (GLU).
+We have also implemented experiments with these variants:
+
+* [experiment that uses `labml.configs`](glu_variants/experiment.html)
+* [simpler version from scratch](glu_variants/simple.html)
 """

 import torch
@@ -31,7 +40,7 @@ from labml_helpers.module import Module
 class FeedForward(Module):
     """
-    ## Position-wise feed-forward network (FFN) module
+    ## FFN module
     """

     def __init__(self, d_model: int, d_ff: int,
@@ -51,19 +60,32 @@ class FeedForward(Module):
         * `bias_gate` specified whether the fully connected layer for the gate should have a learnable bias
         """
         super().__init__()
+        # Layer one, parameterized by weight $W_1$ and bias $b_1$
         self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
+        # Layer two, parameterized by weight $W_2$ and bias $b_2$
         self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
+        # Hidden layer dropout
         self.dropout = nn.Dropout(dropout)
+        # Activation function $f$
         self.activation = activation
+        # Whether there is a gate
         self.is_gated = is_gated
         if is_gated:
+            # If there is a gate, the linear layer to transform inputs to
+            # be multiplied by the gate, parameterized by weight $V$ and bias $c$
            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)

     def __call__(self, x: torch.Tensor):
+        # $f(x W_1 + b_1)$
         g = self.activation(self.layer1(x))
+        # If gated, $f(x W_1 + b_1) \otimes (x V + c)$
         if self.is_gated:
             x = g * self.linear_v(x)
+        # Otherwise
         else:
             x = g
+        # Apply dropout
         x = self.dropout(x)
+        # $(f(x W_1 + b_1) \otimes (x V + c)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
+        # depending on whether it is gated
         return self.layer2(x)
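Below is a minimal usage sketch of the gated path, assuming only the `FeedForward` signature shown above; the SwiGLU settings mirror those used in the training script further down, and the tensor sizes are hypothetical.

```python
import torch
from torch import nn
from labml_nn.transformers.feed_forward import FeedForward

# SwiGLU-style FFN: Swish (SiLU) gate, all biases disabled (hypothetical sizes)
ffn = FeedForward(512, 2048, dropout=0.1, activation=nn.SiLU(),
                  is_gated=True, bias1=False, bias2=False, bias_gate=False)

x = torch.randn(10, 4, 512)  # [seq_len, batch_size, d_model]
y = ffn(x)
# The hidden layer expands to d_ff and layer two projects back to d_model,
# so the output shape matches the input shape
assert y.shape == x.shape
```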
diff --git a/labml_nn/transformers/glu_variants/experiment.py b/labml_nn/transformers/glu_variants/experiment.py
index b4fb4e47..89c21b70 100644
--- a/labml_nn/transformers/glu_variants/experiment.py
+++ b/labml_nn/transformers/glu_variants/experiment.py
@@ -6,9 +6,11 @@ summary: >
   for the position-wise feedforward network (FFN).
 ---

-# Train Autoregressive Transformer
+# Gated Linear Units and Variants

 This trains a simple [transformer](../../) model for auto-regression.
+We try different variants for the [position-wise feedforward network](../feed_forward).
+The reusable & configurable FFN implementation is defined in [`configs.py`](configs.html).
 """

 import torch
@@ -72,7 +74,7 @@ def autoregressive_model(c: Configs):
 @option(Configs.transformer)
 def transformer_c(c: Configs):
     """
-    Initialize the configurable transformer encoder for our autoregressive model
+    Initialize the [configurable transformer](../configs.html) encoder for our autoregressive model.
     """
     tc = TransformerConfigs()
     tc.n_src_vocab = c.n_tokens
@@ -104,6 +106,9 @@ def main():
     'inner_iterations': 10,

     # GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU
+    #
+    # These are defined in the [configurable FFN](../configs.html#FFN)
+    # implementation
     'transformer.ffn.glu_variant': 'Bilinear',

     # Transformer configurations
diff --git a/labml_nn/transformers/glu_variants/simple.py b/labml_nn/transformers/glu_variants/simple.py
index 0173cb29..0cf7f494 100644
--- a/labml_nn/transformers/glu_variants/simple.py
+++ b/labml_nn/transformers/glu_variants/simple.py
@@ -6,9 +6,13 @@ summary: >
   for the position-wise feedforward network (FFN).
 ---

-# Train Autoregressive Transformer
+# Gated Linear Units and Variants

 This trains a simple [transformer](../../) model for auto-regression.
+We try different variants for the [position-wise feedforward network](../feed_forward).
+
+*This is a simpler implementation that doesn't use the [`labml.configs`](experiment.html) module.
+We wrote it to make things easier for readers who are not familiar with that module.*
 """

 import dataclasses
@@ -56,6 +60,9 @@ class AutoregressiveModel(nn.Module):

 @dataclasses.dataclass
 class Configs:
+    """
+    ### Configurations
+    """
     d_model: int = 512
     seq_len: int = 128
     batch_size: int = 32
@@ -69,71 +76,130 @@ class Configs:

 class TinyShakespeareDataset(Dataset):
+    """
+    ### Tiny Shakespeare Dataset
+    """
+
     def __init__(self, seq_len: int):
+        # Location of the text file
         path = lab.get_data_path() / 'tiny_shakespeare.txt'
+        # Download the file
         download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
+        # Read the downloaded file
         with open(str(path), 'r') as f:
             text = f.read()

+        # Extract the characters
         chars = list(set(text))
+        # Character to id (integer) map
         self.stoi = {c: i for i, c in enumerate(chars)}
+        # Id to character map
         self.itos = {i: c for i, c in enumerate(chars)}
+        # Length of a training sample
         self.seq_len = seq_len
+        # Data in the form of a tensor of ids
         self.data = self.text_to_i(text)

     def text_to_i(self, text: str):
+        """
+        Transform the text into a tensor of ids
+        """
         return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

     def __len__(self):
+        """
+        Number of samples in the dataset.
+
+        *This will read the dataset `seq_len` times in a single epoch.*
+        """
         return len(self.data) - self.seq_len - 1

     def __getitem__(self, idx):
+        """
+        Return a sample
+        """
         return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
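To make the dataset behaviour concrete, here is a small sketch of the character-id mapping and the shifted input/target pair returned by `__getitem__`, using a toy string in place of the downloaded Tiny Shakespeare text.

```python
import torch

# Toy corpus standing in for the downloaded Tiny Shakespeare text
text = "to be or not to be"
chars = list(set(text))
stoi = {c: i for i, c in enumerate(chars)}  # character -> id
itos = {i: c for i, c in enumerate(chars)}  # id -> character

# text_to_i: the whole text as a 1-D tensor of ids
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)
assert ''.join(itos[int(i)] for i in data) == text

# __getitem__: the target is the input shifted forward by one character
seq_len, idx = 4, 2
inp = data[idx:idx + seq_len]
tgt = data[idx + 1:idx + seq_len + 1]
assert torch.equal(inp[1:], tgt[:-1])
```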
class Trainer:
+    """
+    ## Trainer
+    """
+
     def __init__(self, configs: Configs):
+        # Get the device
         self.device = torch.device('cpu')
         if torch.cuda.is_available():
             self.device = torch.device('cuda:0')
+        # Initialize the dataset
         self.dataset = TinyShakespeareDataset(configs.seq_len)
-        self.dataloader = DataLoader(self.dataset, batch_size=configs.batch_size, collate_fn=transpose_batch,
+        # Initialize the dataloader
+        self.dataloader = DataLoader(self.dataset,
+                                     batch_size=configs.batch_size,
+                                     collate_fn=transpose_batch,
                                      shuffle=True)

+        # FFN with Gated Linear Unit
+        # $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
         if configs.glu_variant == 'GLU':
             ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
+        # FFN with Bilinear hidden layer
+        # $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
         elif configs.glu_variant == 'Bilinear':
             ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
+        # FFN with ReLU gate
+        # $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
         elif configs.glu_variant == 'ReGLU':
             ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
+        # FFN with GELU gate
+        # $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
         elif configs.glu_variant == 'GEGLU':
             ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
+        # FFN with Swish gate
+        # $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
+        # where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
         elif configs.glu_variant == 'SwiGLU':
             ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
+        # FFN with ReLU activation
+        # $$FFN_{ReLU}(x, W_1, W_2, b_1, b_2) = \text{ReLU}(x W_1 + b_1) W_2 + b_2$$
         elif configs.glu_variant == 'ReLU':
             ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
+        # FFN with GELU activation
+        # $$FFN_{GELU}(x, W_1, W_2, b_1, b_2) = \text{GELU}(x W_1 + b_1) W_2 + b_2$$
         elif configs.glu_variant == 'GELU':
             ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
         else:
             raise ValueError(f'Unknown variant {configs.glu_variant}')
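The `if`/`elif` chain above is essentially a table of (activation, gating) choices. A hedged sketch of an equivalent lookup-table form is shown here; `build_ffn` is a hypothetical helper, not part of the repository, and relies only on the `FeedForward` signature shown earlier. The remainder of the `Trainer` initialization continues in the diff below.

```python
from torch import nn
from labml_nn.transformers.feed_forward import FeedForward


def build_ffn(variant: str, d_model: int, d_ff: int, dropout: float) -> FeedForward:
    # Gated variants: the listed activation goes on the gate path and all biases are dropped
    gated = {'GLU': nn.Sigmoid(), 'Bilinear': nn.Identity(), 'ReGLU': nn.ReLU(),
             'GEGLU': nn.GELU(), 'SwiGLU': nn.SiLU()}
    # Plain variants: the standard (ungated, biased) FFN
    plain = {'ReLU': nn.ReLU(), 'GELU': nn.GELU()}
    if variant in gated:
        return FeedForward(d_model, d_ff, dropout, gated[variant], True, False, False, False)
    if variant in plain:
        return FeedForward(d_model, d_ff, dropout, plain[variant])
    raise ValueError(f'Unknown variant {variant}')
```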
+        # Number of different characters
         n_chars = len(self.dataset.stoi)
+
+        # Initialize the [Multi-Head Attention module](../mha.html)
+        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
+        # Initialize the [Transformer Block](../models.html#TransformerLayer)
+        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
+                                             feed_forward=ffn, dropout_prob=configs.dropout)
+        # Initialize the model with an
+        # [embedding layer](../models.html#EmbeddingsWithPositionalEncoding)
+        # (with fixed positional encoding),
+        # a [transformer encoder](../models.html#Encoder) and
+        # a linear layer to generate logits.
         self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
-                                         Encoder(TransformerLayer(
-                                             d_model=configs.d_model,
-                                             self_attn=MultiHeadAttention(configs.n_heads, configs.d_model,
-                                                                          configs.dropout),
-                                             src_attn=None,
-                                             feed_forward=ffn,
-                                             dropout_prob=configs.dropout
-                                         ), configs.n_layers),
+                                         Encoder(transformer_layer, configs.n_layers),
                                          nn.Linear(configs.d_model, n_chars))
+
+        # Move the model to the current device
         self.model.to(self.device)

+        # Initialize the [Noam optimizer](../../optimizers/noam.html)
         self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
+        # Cross-entropy loss
         self.loss_func = nn.CrossEntropyLoss()
+        # Number of training epochs;
+        # *note that our dataset definition repeats the data `seq_len` times in a single epoch*
         self.epochs = configs.epochs
+        # Gradient clipping norm
         self.grad_norm_clip = configs.grad_norm_clip

         # Set tracker configurations
@@ -166,18 +232,28 @@ class Trainer:
         logger.log(log)

     def train(self):
+        """
+        ### Train the model
+        """
+
+        # Loop for the given number of epochs
         for _ in monit.loop(self.epochs):
+            # Iterate over the minibatches
             for i, batch in monit.enum('Train', self.dataloader):
                 # Move data to the device
                 data, target = batch[0].to(self.device), batch[1].to(self.device)

+                # Set tracker step, as the number of characters trained on
                 tracker.add_global_step(data.shape[0] * data.shape[1])

+                # Set model state to training
                 self.model.train()
+                # Get model outputs
                 output = self.model(data)

-                # Calculate and log loss
+                # Calculate loss
                 loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
+                # Log the loss
                 tracker.add("loss.train", loss)

                 # Calculate gradients
@@ -186,12 +262,13 @@ class Trainer:
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
                 # Take optimizer step
                 self.optimizer.step()
-                # Log the model parameters and gradients on last batch of every epoch
+                # Log the model parameters and gradients
                 if (i + 1) % 100 == 0:
                     tracker.add('model', self.model)
                 # Clear the gradients
                 self.optimizer.zero_grad()

+                # Generate a sample
                 if (i + 1) % 100 == 0:
                     self.model.eval()
                     with torch.no_grad():
@@ -201,6 +278,7 @@ class Trainer:
                 if (i + 1) % 10 == 0:
                     tracker.save()

+        # Save the model
         experiment.save_checkpoint()

@@ -212,12 +290,14 @@ def main():
     # Load configurations
     experiment.configs(dataclasses.asdict(configs))

+    # Create trainer
     trainer = Trainer(configs)
+    # Set models for saving and loading
     experiment.add_pytorch_models({'model': trainer.model})

     # Start the experiment
     with experiment.start():
-        # `TrainValidConfigs.run`
+        # Train the model
         trainer.train()
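As a closing illustration of the training loop above, here is a small sketch of the loss computation: the `[seq_len, batch_size, n_chars]` logits and `[seq_len, batch_size]` targets are flattened before being passed to cross-entropy. The sizes are hypothetical.

```python
import torch
from torch import nn

seq_len, batch_size, n_chars = 128, 32, 65                  # hypothetical sizes
output = torch.randn(seq_len, batch_size, n_chars)          # model logits
target = torch.randint(0, n_chars, (seq_len, batch_size))   # next-character ids

loss_func = nn.CrossEntropyLoss()
# Flatten to [seq_len * batch_size, n_chars] and [seq_len * batch_size]
loss = loss_func(output.view(-1, output.shape[-1]), target.view(-1))
```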