diff --git a/docs/activations/index.html b/docs/activations/index.html new file mode 100644 index 00000000..062b3757 --- /dev/null +++ b/docs/activations/index.html @@ -0,0 +1,112 @@ + + + + + + + + + + + + + + + + + + + + + + + __init__.py + + + + + + + +
+
+
+
+

+ home + activations +

+

+ + + Github + + Join Slack + + Twitter +

+
+
+
+
+ + +
+
+
1from .swish import Swish
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/activations/swish.html b/docs/activations/swish.html new file mode 100644 index 00000000..9e2ff0a3 --- /dev/null +++ b/docs/activations/swish.html @@ -0,0 +1,149 @@ + + + + + + + + + + + + + + + + + + + + + + + swish.py + + + + + + + +
+
+
+
+

+ home + activations +

+

+ + + Github + + Join Slack + + Twitter +

+
+
+
+
+ + +
+
+
1import torch
+2from torch import nn
+
+
+
+
+ + +
+
+
5class Swish(nn.Module):
+
+
+
+
+ + +
+
+
6    def __init__(self):
+7        super().__init__()
+8        self.sigmoid = nn.Sigmoid()
+
+
+
+
+ + +
+
+
10    def forward(self, x: torch.Tensor) -> torch.Tensor:
+11        return x * self.sigmoid(x)
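A minimal usage sketch of the module above (a sketch only; it assumes the documented file is importable as labml_nn.activations.swish):

import torch
from labml_nn.activations.swish import Swish  # assumed import path

act = Swish()
x = torch.randn(2, 3)
y = act(x)  # computes x * sigmoid(x), element-wise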
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/capsule_networks/index.html b/docs/capsule_networks/index.html index 497a3e37..9b048843 100644 --- a/docs/capsule_networks/index.html +++ b/docs/capsule_networks/index.html @@ -43,6 +43,7 @@

+ home capsule_networks

diff --git a/docs/capsule_networks/mnist.html b/docs/capsule_networks/mnist.html index c5b897ad..c0502ede 100644 --- a/docs/capsule_networks/mnist.html +++ b/docs/capsule_networks/mnist.html @@ -43,6 +43,7 @@

+ home capsule_networks

diff --git a/docs/experiments/index.html b/docs/experiments/index.html index f91748b6..d7635d38 100644 --- a/docs/experiments/index.html +++ b/docs/experiments/index.html @@ -43,6 +43,7 @@

+ home experiments

diff --git a/docs/experiments/nlp_autoregression.html b/docs/experiments/nlp_autoregression.html index 73b4910f..f6776e01 100644 --- a/docs/experiments/nlp_autoregression.html +++ b/docs/experiments/nlp_autoregression.html @@ -43,6 +43,7 @@

+ home experiments

@@ -76,16 +77,16 @@ 12 13import torch 14import torch.nn as nn -15from labml import lab, monit, logger, tracker -16from labml.configs import option -17from labml.logger import Text -18from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset -19from labml_helpers.device import DeviceConfigs -20from labml_helpers.metrics.accuracy import Accuracy -21from labml_helpers.module import Module -22from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex -23from torch.utils.data import DataLoader -24 +15from torch.utils.data import DataLoader +16 +17from labml import lab, monit, logger, tracker +18from labml.configs import option +19from labml.logger import Text +20from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset +21from labml_helpers.device import DeviceConfigs +22from labml_helpers.metrics.accuracy import Accuracy +23from labml_helpers.module import Module +24from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex 25from labml_nn.optimizers.configs import OptimizerConfigs

@@ -654,8 +655,8 @@ This is not implemented yet. 😜

Default optimizer configurations

-
176@option(NLPAutoRegressionConfigs.optimizer)
-177def _optimizer(c: NLPAutoRegressionConfigs):
+
177@option(NLPAutoRegressionConfigs.optimizer)
+178def _optimizer(c: NLPAutoRegressionConfigs):
@@ -666,12 +667,12 @@ This is not implemented yet. 😜

-
182    optimizer = OptimizerConfigs()
-183    optimizer.parameters = c.model.parameters()
-184    optimizer.optimizer = 'Adam'
-185    optimizer.d_model = c.d_model
-186
-187    return optimizer
+
183    optimizer = OptimizerConfigs()
+184    optimizer.parameters = c.model.parameters()
+185    optimizer.optimizer = 'Adam'
+186    optimizer.d_model = c.d_model
+187
+188    return optimizer
@@ -682,8 +683,8 @@ This is not implemented yet. 😜

Get number of tokens

-
190@option(NLPAutoRegressionConfigs.n_tokens)
-191def _n_tokens(c: NLPAutoRegressionConfigs):
+
191@option(NLPAutoRegressionConfigs.n_tokens)
+192def _n_tokens(c: NLPAutoRegressionConfigs):
@@ -694,7 +695,7 @@ This is not implemented yet. 😜

-
195    return c.text.n_tokens
+
196    return c.text.n_tokens
@@ -711,8 +712,8 @@ You can switch by setting,

as the configurations dictionary when starting the experiment.

-
198@option(NLPAutoRegressionConfigs.tokenizer)
-199def basic_english():
+
199@option(NLPAutoRegressionConfigs.tokenizer)
+200def basic_english():
@@ -723,8 +724,8 @@ You can switch by setting,

-
213    from torchtext.data import get_tokenizer
-214    return get_tokenizer('basic_english')
+
214    from torchtext.data import get_tokenizer
+215    return get_tokenizer('basic_english')
@@ -735,7 +736,7 @@ You can switch by setting,

Character level tokenizer

-
217def character_tokenizer(x: str):
+
218def character_tokenizer(x: str):
@@ -746,7 +747,7 @@ You can switch by setting,

-
221    return list(x)
+
222    return list(x)
@@ -757,8 +758,8 @@ You can switch by setting,

Character level tokenizer configuration

-
224@option(NLPAutoRegressionConfigs.tokenizer)
-225def character():
+
225@option(NLPAutoRegressionConfigs.tokenizer)
+226def character():
@@ -769,7 +770,7 @@ You can switch by setting,

-
229    return character_tokenizer
+
230    return character_tokenizer
@@ -781,8 +782,8 @@ You can switch by setting,

It will download from the url if not present

-
232@option(NLPAutoRegressionConfigs.text)
-233def tiny_shakespeare(c: NLPAutoRegressionConfigs):
+
233@option(NLPAutoRegressionConfigs.text)
+234def tiny_shakespeare(c: NLPAutoRegressionConfigs):
@@ -793,10 +794,10 @@ You can switch by setting,

-
239    return TextFileDataset(
-240        lab.get_data_path() / 'tiny_shakespeare.txt',
-241        c.tokenizer,
-242        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
+
240    return TextFileDataset(
+241        lab.get_data_path() / 'tiny_shakespeare.txt',
+242        c.tokenizer,
+243        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
@@ -807,8 +808,8 @@ You can switch by setting,

Sequential training data loader

-
245@option(NLPAutoRegressionConfigs.train_loader)
-246def sequential_train_loader(c: NLPAutoRegressionConfigs):
+
246@option(NLPAutoRegressionConfigs.train_loader)
+247def sequential_train_loader(c: NLPAutoRegressionConfigs):
@@ -819,10 +820,10 @@ You can switch by setting,

-
250    return SequentialDataLoader(text=c.text.train,
-251                                dataset=c.text,
-252                                batch_size=c.batch_size,
-253                                seq_len=c.seq_len)
+
251    return SequentialDataLoader(text=c.text.train,
+252                                dataset=c.text,
+253                                batch_size=c.batch_size,
+254                                seq_len=c.seq_len)
@@ -833,8 +834,8 @@ You can switch by setting,

Sequential validation data loader

-
256@option(NLPAutoRegressionConfigs.valid_loader)
-257def sequential_valid_loader(c: NLPAutoRegressionConfigs):
+
257@option(NLPAutoRegressionConfigs.valid_loader)
+258def sequential_valid_loader(c: NLPAutoRegressionConfigs):
@@ -845,10 +846,10 @@ You can switch by setting,

-
261    return SequentialDataLoader(text=c.text.valid,
-262                                dataset=c.text,
-263                                batch_size=c.batch_size,
-264                                seq_len=c.seq_len)
+
262    return SequentialDataLoader(text=c.text.valid,
+263                                dataset=c.text,
+264                                batch_size=c.batch_size,
+265                                seq_len=c.seq_len)
@@ -861,7 +862,7 @@ You can switch by setting,

We need to transpose it to be sequence first.

-
267def transpose_batch(batch):
+
268def transpose_batch(batch):
@@ -872,7 +873,7 @@ We need to transpose it to be sequence first.

-
275    transposed_data = list(zip(*batch))
+
276    transposed_data = list(zip(*batch))
@@ -883,10 +884,10 @@ We need to transpose it to be sequence first.

Stack the batch along the second dimension dim=1

-
277    src = torch.stack(transposed_data[0], dim=1)
-278    tgt = torch.stack(transposed_data[1], dim=1)
-279
-280    return src, tgt
+
278    src = torch.stack(transposed_data[0], dim=1)
+279    tgt = torch.stack(transposed_data[1], dim=1)
+280
+281    return src, tgt
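A quick shape sketch of transpose_batch with hypothetical sizes (the import path matches the one used by the GLU-variants training script later in this diff):

import torch
from labml_nn.experiments.nlp_autoregression import transpose_batch

# Hypothetical batch of 4 (src, tgt) pairs, each of sequence length 16
batch = [(torch.zeros(16, dtype=torch.long), torch.ones(16, dtype=torch.long)) for _ in range(4)]
src, tgt = transpose_batch(batch)
assert src.shape == (16, 4)  # sequence-first: [seq_len, batch_size]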
@@ -897,8 +898,8 @@ We need to transpose it to be sequence first.

Shuffled training data loader

-
283@option(NLPAutoRegressionConfigs.train_loader)
-284def shuffled_train_loader(c: NLPAutoRegressionConfigs):
+
284@option(NLPAutoRegressionConfigs.train_loader)
+285def shuffled_train_loader(c: NLPAutoRegressionConfigs):
@@ -909,12 +910,12 @@ We need to transpose it to be sequence first.

-
288    return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
-289                                                 dataset=c.text,
-290                                                 seq_len=c.seq_len),
-291                      batch_size=c.batch_size,
-292                      collate_fn=transpose_batch,
-293                      shuffle=True)
+
289    return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
+290                                                 dataset=c.text,
+291                                                 seq_len=c.seq_len),
+292                      batch_size=c.batch_size,
+293                      collate_fn=transpose_batch,
+294                      shuffle=True)
@@ -925,8 +926,8 @@ We need to transpose it to be sequence first.

Shuffled validation data loader

-
296@option(NLPAutoRegressionConfigs.valid_loader)
-297def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
+
297@option(NLPAutoRegressionConfigs.valid_loader)
+298def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
@@ -937,12 +938,12 @@ We need to transpose it to be sequence first.

-
301    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
-302                                                 dataset=c.text,
-303                                                 seq_len=c.seq_len),
-304                      batch_size=c.batch_size,
-305                      collate_fn=transpose_batch,
-306                      shuffle=True)
+
302    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
+303                                                 dataset=c.text,
+304                                                 seq_len=c.seq_len),
+305                      batch_size=c.batch_size,
+306                      collate_fn=transpose_batch,
+307                      shuffle=True)
diff --git a/docs/gan/cycle_gan.html b/docs/gan/cycle_gan.html index 6ffe2cdc..6f7e46ac 100644 --- a/docs/gan/cycle_gan.html +++ b/docs/gan/cycle_gan.html @@ -43,6 +43,7 @@

+ home gan

diff --git a/docs/gan/dcgan.html b/docs/gan/dcgan.html index 19ce9da8..51398fe9 100644 --- a/docs/gan/dcgan.html +++ b/docs/gan/dcgan.html @@ -43,6 +43,7 @@

+ home gan

diff --git a/docs/gan/index.html b/docs/gan/index.html index 7f3a23e3..1f22b87c 100644 --- a/docs/gan/index.html +++ b/docs/gan/index.html @@ -43,6 +43,7 @@

+ home gan

diff --git a/docs/gan/simple_mnist_experiment.html b/docs/gan/simple_mnist_experiment.html index 17d6da2d..4a16a338 100644 --- a/docs/gan/simple_mnist_experiment.html +++ b/docs/gan/simple_mnist_experiment.html @@ -43,6 +43,7 @@

+ home gan

diff --git a/docs/hypernetworks/experiment.html b/docs/hypernetworks/experiment.html index f5ba9527..0ac8b4c1 100644 --- a/docs/hypernetworks/experiment.html +++ b/docs/hypernetworks/experiment.html @@ -43,6 +43,7 @@

+ home hypernetworks

diff --git a/docs/hypernetworks/hyper_lstm.html b/docs/hypernetworks/hyper_lstm.html index 8060f419..daaa7402 100644 --- a/docs/hypernetworks/hyper_lstm.html +++ b/docs/hypernetworks/hyper_lstm.html @@ -43,6 +43,7 @@

+ home hypernetworks

diff --git a/docs/hypernetworks/index.html b/docs/hypernetworks/index.html index 5129531f..ef887806 100644 --- a/docs/hypernetworks/index.html +++ b/docs/hypernetworks/index.html @@ -43,6 +43,7 @@

+ home hypernetworks

diff --git a/docs/index.html b/docs/index.html index e6e20bf5..7325d024 100644 --- a/docs/index.html +++ b/docs/index.html @@ -43,6 +43,7 @@

+ home

diff --git a/docs/lstm/index.html b/docs/lstm/index.html index d23a3133..1c4cddbb 100644 --- a/docs/lstm/index.html +++ b/docs/lstm/index.html @@ -43,6 +43,7 @@

+ home lstm

diff --git a/docs/optimizers/ada_belief.html b/docs/optimizers/ada_belief.html index 36552463..d741231c 100644 --- a/docs/optimizers/ada_belief.html +++ b/docs/optimizers/ada_belief.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/adam.html b/docs/optimizers/adam.html index 38bc940b..2a8899c5 100644 --- a/docs/optimizers/adam.html +++ b/docs/optimizers/adam.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/adam_warmup.html b/docs/optimizers/adam_warmup.html index 91e563ab..3c67642d 100644 --- a/docs/optimizers/adam_warmup.html +++ b/docs/optimizers/adam_warmup.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/adam_warmup_cosine_decay.html b/docs/optimizers/adam_warmup_cosine_decay.html index ceab76a5..f3d444dd 100644 --- a/docs/optimizers/adam_warmup_cosine_decay.html +++ b/docs/optimizers/adam_warmup_cosine_decay.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/amsgrad.html b/docs/optimizers/amsgrad.html index 265c5fef..a77dff9b 100644 --- a/docs/optimizers/amsgrad.html +++ b/docs/optimizers/amsgrad.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/configs.html b/docs/optimizers/configs.html index add354f3..c2cf1480 100644 --- a/docs/optimizers/configs.html +++ b/docs/optimizers/configs.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/index.html b/docs/optimizers/index.html index 2352d559..7ca877fe 100644 --- a/docs/optimizers/index.html +++ b/docs/optimizers/index.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/mnist_experiment.html b/docs/optimizers/mnist_experiment.html index 5d79d16f..bb14af3e 100644 --- a/docs/optimizers/mnist_experiment.html +++ b/docs/optimizers/mnist_experiment.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/noam.html b/docs/optimizers/noam.html index bf7c8e0c..948006b1 100644 --- a/docs/optimizers/noam.html +++ b/docs/optimizers/noam.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/performance_test.html b/docs/optimizers/performance_test.html index 7f9b2d8a..61c24836 100644 --- a/docs/optimizers/performance_test.html +++ b/docs/optimizers/performance_test.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/optimizers/radam.html b/docs/optimizers/radam.html index 05dadd02..56fc81eb 100644 --- a/docs/optimizers/radam.html +++ b/docs/optimizers/radam.html @@ -43,6 +43,7 @@

+ home optimizers

diff --git a/docs/recurrent_highway_networks/index.html b/docs/recurrent_highway_networks/index.html index 960f0133..05825489 100644 --- a/docs/recurrent_highway_networks/index.html +++ b/docs/recurrent_highway_networks/index.html @@ -43,6 +43,7 @@

+ home recurrent_highway_networks

diff --git a/docs/rl/dqn/experiment.html b/docs/rl/dqn/experiment.html index 797ff1b6..cce434ac 100644 --- a/docs/rl/dqn/experiment.html +++ b/docs/rl/dqn/experiment.html @@ -43,6 +43,7 @@

+ home rl dqn

diff --git a/docs/rl/dqn/index.html b/docs/rl/dqn/index.html index 07fb43c8..46af7e79 100644 --- a/docs/rl/dqn/index.html +++ b/docs/rl/dqn/index.html @@ -43,6 +43,7 @@

+ home rl dqn

diff --git a/docs/rl/dqn/model.html b/docs/rl/dqn/model.html index 06ed8a05..30ee293d 100644 --- a/docs/rl/dqn/model.html +++ b/docs/rl/dqn/model.html @@ -43,6 +43,7 @@

+ home rl dqn

diff --git a/docs/rl/dqn/replay_buffer.html b/docs/rl/dqn/replay_buffer.html index 255c3003..a25eb71f 100644 --- a/docs/rl/dqn/replay_buffer.html +++ b/docs/rl/dqn/replay_buffer.html @@ -43,6 +43,7 @@

+ home rl dqn

diff --git a/docs/rl/game.html b/docs/rl/game.html index cef45573..5e9157f4 100644 --- a/docs/rl/game.html +++ b/docs/rl/game.html @@ -43,6 +43,7 @@

+ home rl

diff --git a/docs/rl/index.html b/docs/rl/index.html index 9c6c4a06..3e1e8740 100644 --- a/docs/rl/index.html +++ b/docs/rl/index.html @@ -43,6 +43,7 @@

+ home rl

diff --git a/docs/rl/ppo/experiment.html b/docs/rl/ppo/experiment.html index 47601f61..e045ee6d 100644 --- a/docs/rl/ppo/experiment.html +++ b/docs/rl/ppo/experiment.html @@ -43,6 +43,7 @@

+ home rl ppo

diff --git a/docs/rl/ppo/gae.html b/docs/rl/ppo/gae.html index 08a6a2f6..b9ad7c4e 100644 --- a/docs/rl/ppo/gae.html +++ b/docs/rl/ppo/gae.html @@ -43,6 +43,7 @@

+ home rl ppo

diff --git a/docs/rl/ppo/index.html b/docs/rl/ppo/index.html index 0ed0d7f8..6ab16831 100644 --- a/docs/rl/ppo/index.html +++ b/docs/rl/ppo/index.html @@ -43,6 +43,7 @@

+ home rl ppo

diff --git a/docs/sketch_rnn/index.html b/docs/sketch_rnn/index.html index 33bf3753..fdbd3fc5 100644 --- a/docs/sketch_rnn/index.html +++ b/docs/sketch_rnn/index.html @@ -43,6 +43,7 @@

+ home sketch_rnn

diff --git a/docs/transformers/configs.html b/docs/transformers/configs.html index ac771d08..c9097bfc 100644 --- a/docs/transformers/configs.html +++ b/docs/transformers/configs.html @@ -43,6 +43,7 @@

+ home transformers

@@ -75,19 +76,249 @@

9import copy
 10
 11import torch.nn as nn
-12from labml.configs import BaseConfigs, option, calculate
-13from labml_helpers.module import Module
-14
-15from .mha import MultiHeadAttention
-16from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, FeedForward, \
-17    TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder
+12 +13from labml.configs import BaseConfigs, option, calculate, aggregate +14from labml_helpers.module import Module +15from .feed_forward import FeedForward +16from .mha import MultiHeadAttention +17from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \ +18 Encoder, Decoder, Generator, EncoderDecoder
-
+
+ +
+
+
21class FeedForwardConfigs(BaseConfigs):
+
+
+
+
+ +

Position-wise feedforward layer

+
+
+
23    ffn: FeedForward
+
+
+
+
+ +

Number of features in the embedding

+
+
+
25    d_model: int
+
+
+
+
+ +

Number of features in the hidden layer

+
+
+
27    d_ff: int = 2048
+
+
+
+
+ +

Dropout probability

+
+
+
29    dropout: float = 0.1
+
+
+
+
+ +

Activation in position-wise feedforward layer

+
+
+
31    activation: nn.Module = 'ReLU'
+
+
+
+
+ +

Whether the FFN layer should be gated

+
+
+
33    is_gated: bool = False
+
+
+
+
+ +

Whether the first fully connected layer should have a learnable bias

+
+
+
35    bias1: bool = True
+
+
+
+
+ +

Whether the second fully connected layer should have a learnable bias

+
+
+
37    bias2: bool = True
+
+
+
+
+ +

Whether the fully connected layer for the gate should have a learnable bias

+
+
+
39    bias_gate: bool = False
+
+
+
+
+ +

Predefined GLU variants

+
+
+
41    glu_variant: str = 'none'
+
+
+
+
+ +

ReLU activation

+
+
+
44@option(FeedForwardConfigs.activation, 'ReLU')
+45def _ffn_activation_relu():
+
+
+
+
+ + +
+
+
49    return nn.ReLU()
+
+
+
+
+ +

GELU activation

+
+
+
52@option(FeedForwardConfigs.activation, 'GELU')
+53def _ffn_activation_gelu():
+
+
+
+
+ + +
+
+
57    return nn.GELU()
+
+
+
+
+ +

Create feedforward layer

+
+
+
60@option(FeedForwardConfigs.ffn, 'default')
+61def _feed_forward(c: FeedForwardConfigs):
+
+
+
+
+ + +
+
+
65    return FeedForward(c.d_model, c.d_ff,
+66                       dropout=c.dropout,
+67                       activation=c.activation,
+68                       is_gated=c.is_gated,
+69                       bias1=c.bias1,
+70                       bias2=c.bias2,
+71                       bias_gate=c.bias_gate)
+72
+73
+74aggregate(FeedForwardConfigs.glu_variant, 'GLU',
+75          (FeedForwardConfigs.is_gated, True),
+76          (FeedForwardConfigs.bias1, False),
+77          (FeedForwardConfigs.bias2, False),
+78          (FeedForwardConfigs.bias_gate, False),
+79          (FeedForwardConfigs.activation, nn.Sigmoid()))
+80
+81aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
+82          (FeedForwardConfigs.is_gated, True),
+83          (FeedForwardConfigs.bias1, False),
+84          (FeedForwardConfigs.bias2, False),
+85          (FeedForwardConfigs.bias_gate, False),
+86          (FeedForwardConfigs.activation, nn.Identity()))
+87aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
+88          (FeedForwardConfigs.is_gated, True),
+89          (FeedForwardConfigs.bias1, False),
+90          (FeedForwardConfigs.bias2, False),
+91          (FeedForwardConfigs.bias_gate, False),
+92          (FeedForwardConfigs.activation, nn.ReLU()))
+93aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
+94          (FeedForwardConfigs.is_gated, True),
+95          (FeedForwardConfigs.bias1, False),
+96          (FeedForwardConfigs.bias2, False),
+97          (FeedForwardConfigs.bias_gate, False),
+98          (FeedForwardConfigs.activation, nn.GELU()))
+99aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
+100          (FeedForwardConfigs.is_gated, True),
+101          (FeedForwardConfigs.bias1, False),
+102          (FeedForwardConfigs.bias2, False),
+103          (FeedForwardConfigs.bias_gate, False),
+104          (FeedForwardConfigs.activation, nn.SiLU()))
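A hedged usage sketch of these aggregate options: selecting a glu_variant sets is_gated, the bias flags, and the activation together, so an experiment only needs to override a single key. The sketch below mirrors the GLU-variants experiment later in this diff; the import path is assumed.

from labml import experiment
from labml_nn.transformers.glu_variants.experiment import Configs  # assumed path

conf = Configs()
experiment.create(name="glu_variants")
# One override picks the whole variant (gating, biases and activation)
experiment.configs(conf, {'transformer.ffn.glu_variant': 'GEGLU',
                          'transformer.d_model': 256,
                          'transformer.ffn.d_ff': 1024})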
+
+
+
+
+

Transformer Configurations

@@ -97,194 +328,7 @@ These are lazy loaded and therefore only the necessary modules are calculated.

-
20class TransformerConfigs(BaseConfigs):
-
-
-
-
- -

Number of attention heads

-
-
-
32    n_heads: int = 8
-
-
-
-
- -

Transformer embedding size

-
-
-
34    d_model: int = 512
-
-
-
-
- -

Number of layers

-
-
-
36    n_layers: int = 6
-
-
-
-
- -

Number of features in position-wise feedforward layer

-
-
-
38    d_ff: int = 2048
-
-
-
-
- -

Dropout probability

-
-
-
40    dropout: float = 0.1
-
-
-
-
- -

Number of tokens in the source vocabulary (for token embeddings)

-
-
-
42    n_src_vocab: int
-
-
-
-
- -

Number of tokens in the target vocabulary (to generate logits for prediction)

-
-
-
44    n_tgt_vocab: int
-
-
-
-
- -

The encoder self attention

-
-
-
47    encoder_attn: MultiHeadAttention = 'mha'
-
-
-
-
- -

The decoder self attention

-
-
-
49    decoder_attn: MultiHeadAttention = 'mha'
-
-
-
-
- -

The decoder memory attention

-
-
-
51    decoder_mem_attn: MultiHeadAttention = 'mha'
-
-
-
-
- -

Position-wise feedforward layer

-
-
-
53    feed_forward: FeedForward
-
-
-
-
- -

Activation in position-wise feedforward layer

-
-
-
55    feed_forward_activation: nn.Module = 'ReLU'
-
-
-
-
- -

Encoder layer

-
-
-
58    encoder_layer: TransformerLayer = 'default'
-
-
-
-
- -

Decoder layer

-
-
-
60    decoder_layer: TransformerLayer = 'default'
-
-
-
-
- -

Encoder consisting of multiple encoder layers

-
-
-
63    encoder: Encoder = 'default'
-
-
-
-
- -

Encoder consisting of multiple decoder layers

-
-
-
65    decoder: Decoder = 'default'
-
-
-
-
- -

Embedding layer for source

-
-
-
68    src_embed: Module = 'fixed_pos'
+
107class TransformerConfigs(BaseConfigs):
@@ -292,10 +336,10 @@ are calculated.

-

Embedding layer for target (for decoder)

+

Number of attention heads

-
70    tgt_embed: Module = 'fixed_pos'
+
119    n_heads: int = 8
@@ -303,10 +347,10 @@ are calculated.

-

Logit generator for prediction

+

Transformer embedding size

-
73    generator: Generator = 'default'
+
121    d_model: int = 512
@@ -314,22 +358,21 @@ are calculated.

-

Encoder-decoder

+

Number of layers

-
76    encoder_decoder: EncoderDecoder
+
123    n_layers: int = 6
-
+
-

ReLU activation

+

Dropout probability

-
79@option(TransformerConfigs.feed_forward_activation, 'ReLU')
-80def _feed_forward_activation_relu():
+
125    dropout: float = 0.1
@@ -337,22 +380,21 @@ are calculated.

- +

Number of tokens in the source vocabulary (for token embeddings)

-
84    return nn.ReLU()
+
127    n_src_vocab: int
-
+
-

GELU activation

+

Number of tokens in the target vocabulary (to generate logits for prediction)

-
87@option(TransformerConfigs.feed_forward_activation, 'GELU')
-88def _feed_forward_activation_relu():
+
129    n_tgt_vocab: int
@@ -360,22 +402,21 @@ are calculated.

- +

The encoder self attention

-
92    return nn.GELU()
+
132    encoder_attn: MultiHeadAttention = 'mha'
-
+
-

Create feedforward layer

+

The decoder self attention

-
95@option(TransformerConfigs.feed_forward, 'default')
-96def _feed_forward(c: TransformerConfigs):
+
134    decoder_attn: MultiHeadAttention = 'mha'
@@ -383,10 +424,10 @@ are calculated.

- +

The decoder memory attention

-
100    return FeedForward(c.d_model, c.d_ff, c.dropout, c.feed_forward_activation)
+
136    decoder_mem_attn: MultiHeadAttention = 'mha'
@@ -394,15 +435,10 @@ are calculated.

-

Multi-head Attention

+

Configurable Feedforward Layer

-
104def _mha(c: TransformerConfigs):
-105    return MultiHeadAttention(c.n_heads, c.d_model)
-106
-107calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
-108calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
-109calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
+
139    ffn: FeedForwardConfigs
@@ -410,29 +446,21 @@ are calculated.

-

Relative Multi-head Attention

-
-
-
113def _relative_mha(c: TransformerConfigs):
-114    from .relative_mha import RelativeMultiHeadAttention
-115    return RelativeMultiHeadAttention(c.n_heads, c.d_model)
-116
-117
-118calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
-119calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
-120calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
-
-
-
-
-

Encoder layer

-
123@option(TransformerConfigs.encoder_layer, 'default')
-124def _encoder_layer(c: TransformerConfigs):
+
142    encoder_layer: TransformerLayer = 'default'
+
+
+
+
+ +

Decoder layer

+
+
+
144    decoder_layer: TransformerLayer = 'default'
@@ -440,24 +468,21 @@ are calculated.

- +

Encoder consisting of multiple encoder layers

-
128    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
-129                            src_attn=None, feed_forward=copy.deepcopy(c.feed_forward),
-130                            dropout_prob=c.dropout)
+
147    encoder: Encoder = 'default'
-
+
-

Decoder layer

+

Encoder consisting of multiple decoder layers

-
133@option(TransformerConfigs.decoder_layer, 'default')
-134def _decoder_layer(c: TransformerConfigs):
+
149    decoder: Decoder = 'default'
@@ -465,24 +490,21 @@ are calculated.

- +

Embedding layer for source

-
138    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
-139                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.feed_forward),
-140                            dropout_prob=c.dropout)
+
152    src_embed: Module = 'fixed_pos'
-
+
-

Encoder

+

Embedding layer for target (for decoder)

-
143@option(TransformerConfigs.encoder, 'default')
-144def _encoder(c: TransformerConfigs):
+
154    tgt_embed: Module = 'fixed_pos'
@@ -490,22 +512,21 @@ are calculated.

- +

Logit generator for prediction

-
148    return Encoder(c.encoder_layer, c.n_layers)
+
157    generator: Generator = 'default'
-
+
-

Decoder

+

Encoder-decoder

-
151@option(TransformerConfigs.decoder, 'default')
-152def _decoder(c: TransformerConfigs):
+
160    encoder_decoder: EncoderDecoder
@@ -513,151 +534,168 @@ are calculated.

- +

Multi-head Attention

-
156    return Decoder(c.decoder_layer, c.n_layers)
+
164def _mha(c: TransformerConfigs):
+165    return MultiHeadAttention(c.n_heads, c.d_model)
+166
+167
+168calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
+169calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
+170calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
-
+
+

Relative Multi-head Attention

+
+
+
174def _relative_mha(c: TransformerConfigs):
+175    from .relative_mha import RelativeMultiHeadAttention
+176    return RelativeMultiHeadAttention(c.n_heads, c.d_model)
+177
+178
+179calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
+180calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
+181calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
+
+
+
+
+ +

Create feedforward layer configurations

+
+
+
184@option(TransformerConfigs.ffn, 'default')
+185def _feed_forward(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
189    conf = FeedForwardConfigs()
+190    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
+191    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
+192    return conf
+
+
+
+
+ +

Encoder layer

+
+
+
195@option(TransformerConfigs.encoder_layer, 'default')
+196def _encoder_layer(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
200    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
+201                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
+202                            dropout_prob=c.dropout)
+
+
+
+
+ +

Decoder layer

+
+
+
205@option(TransformerConfigs.decoder_layer, 'default')
+206def _decoder_layer(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
210    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
+211                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
+212                            dropout_prob=c.dropout)
+
+
+
+
+ +

Encoder

+
+
+
215@option(TransformerConfigs.encoder, 'default')
+216def _encoder(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
220    return Encoder(c.encoder_layer, c.n_layers)
+
+
+
+
+ +

Decoder

+
+
+
223@option(TransformerConfigs.decoder, 'default')
+224def _decoder(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
228    return Decoder(c.decoder_layer, c.n_layers)
+
+
+
+
+

Logit generator

-
159@option(TransformerConfigs.generator, 'default')
-160def _generator(c: TransformerConfigs):
-
-
-
-
- - -
-
-
164    return Generator(c.n_tgt_vocab, c.d_model)
-
-
-
-
- -

Positional Embeddings

-

Source embedding with fixed positional encodings

-
-
-
168@option(TransformerConfigs.src_embed, 'fixed_pos')
-169def _src_embed_with_positional(c: TransformerConfigs):
-
-
-
-
- - -
-
-
173    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
-
-
-
-
- -

Target embedding with fixed positional encodings

-
-
-
176@option(TransformerConfigs.tgt_embed, 'fixed_pos')
-177def _tgt_embed_with_positional(c: TransformerConfigs):
-
-
-
-
- - -
-
-
181    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
-
-
-
-
- -

Learned Positional Embeddings

-

Source embedding with learned positional encodings

-
-
-
185@option(TransformerConfigs.src_embed, 'learned_pos')
-186def _src_embed_with_learned_positional(c: TransformerConfigs):
-
-
-
-
- - -
-
-
190    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
-
-
-
-
- -

Target embedding with learned positional encodings

-
-
-
193@option(TransformerConfigs.tgt_embed, 'learned_pos')
-194def _tgt_embed_with_learned_positional(c: TransformerConfigs):
-
-
-
-
- - -
-
-
198    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
-
-
-
-
- -

No Positional Embeddings

-

Source embedding without positional encodings

-
-
-
202@option(TransformerConfigs.src_embed, 'no_pos')
-203def _src_embed_without_positional(c: TransformerConfigs):
-
-
-
-
- - -
-
-
207    return nn.Embedding(c.n_src_vocab, c.d_model)
+
231@option(TransformerConfigs.generator, 'default')
+232def _generator(c: TransformerConfigs):
@@ -668,14 +706,143 @@ are calculated.

-
210@option(TransformerConfigs.tgt_embed, 'no_pos')
-211def _tgt_embed_without_positional(c: TransformerConfigs):
-212    return nn.Embedding(c.n_tgt_vocab, c.d_model)
-213
-214
-215@option(TransformerConfigs.encoder_decoder, 'default')
-216def _encoder_decoder(c: TransformerConfigs):
-217    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
+
236    return Generator(c.n_tgt_vocab, c.d_model)
+
+
+
+
+ +

Positional Embeddings

+

Source embedding with fixed positional encodings

+
+
+
240@option(TransformerConfigs.src_embed, 'fixed_pos')
+241def _src_embed_with_positional(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
245    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
+
+
+
+
+ +

Target embedding with fixed positional encodings

+
+
+
248@option(TransformerConfigs.tgt_embed, 'fixed_pos')
+249def _tgt_embed_with_positional(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
253    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
+
+
+
+
+ +

Learned Positional Embeddings

+

Source embedding with learned positional encodings

+
+
+
257@option(TransformerConfigs.src_embed, 'learned_pos')
+258def _src_embed_with_learned_positional(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
262    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
+
+
+
+
+ +

Target embedding with learned positional encodings

+
+
+
265@option(TransformerConfigs.tgt_embed, 'learned_pos')
+266def _tgt_embed_with_learned_positional(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
270    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
+
+
+
+
+ +

No Positional Embeddings

+

Source embedding without positional encodings

+
+
+
274@option(TransformerConfigs.src_embed, 'no_pos')
+275def _src_embed_without_positional(c: TransformerConfigs):
+
+
+
+
+ + +
+
+
279    return nn.Embedding(c.n_src_vocab, c.d_model)
+
+
+
+
+ + +
+
+
282@option(TransformerConfigs.tgt_embed, 'no_pos')
+283def _tgt_embed_without_positional(c: TransformerConfigs):
+284    return nn.Embedding(c.n_tgt_vocab, c.d_model)
+285
+286
+287@option(TransformerConfigs.encoder_decoder, 'default')
+288def _encoder_decoder(c: TransformerConfigs):
+289    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
diff --git a/docs/transformers/feed_forward.html b/docs/transformers/feed_forward.html new file mode 100644 index 00000000..053e69b2 --- /dev/null +++ b/docs/transformers/feed_forward.html @@ -0,0 +1,200 @@ + + + + + + + + + + + + + + + + + + + + + + + Position-wise Feed-Forward Network (FFN) + + + + + + + +
+
+
+
+

+ home + transformers +

+

+ + + Github + + Join Slack + + Twitter +

+
+
+
+
+ +

Position-wise Feed-Forward Network (FFN)

+

FFN consists of two fully connected layers. +The number of dimensions in the hidden layer, $d_{ff}$, is generally set to around +four times that of the token embedding, $d_{model}$. +So it is sometimes also called the expand-and-contract network.

+

There is an activation at the hidden layer, which is +usually set to ReLU (Rectified Linear Unit) activation, +$$\max(0, x)$$

+

That is, the FFN function is, +$$FFN(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$ +where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.

+

Sometimes the +GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU. +$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
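As a concrete instance of the gated form implemented below (a sketch that follows the gating in the module's forward pass, with biases omitted), the GEGLU variant computes
$$\text{GEGLU}(x) = (\text{GELU}(x W_1) \otimes x V) W_2$$
where $W_1$ is the first linear layer, $V$ is the gate projection, $W_2$ is the second linear layer, and $\otimes$ is element-wise multiplication.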

+
+
+
26import torch
+27from torch import nn as nn
+28
+29from labml_helpers.module import Module
+
+
+
+
+ +

Position-wise feed-forward network (FFN) module

+
+
+
32class FeedForward(Module):
+
+
+
+
+ +
    +
  • d_model is the number of features in a token embedding
  • +
  • d_ff is the number of features in the hidden layer of the FFN
  • +
  • dropout is the dropout probability for the hidden layer
  • +
  • is_gated specifies whether the hidden layer is gated
  • +
  • bias1 specifies whether the first fully connected layer should have a learnable bias
  • +
  • bias2 specifies whether the second fully connected layer should have a learnable bias
  • +
  • bias_gate specifies whether the fully connected layer for the gate should have a learnable bias
  • +
+
+
+
37    def __init__(self, d_model: int, d_ff: int,
+38                 dropout: float = 0.1,
+39                 activation=nn.ReLU(),
+40                 is_gated: bool = False,
+41                 bias1: bool = True,
+42                 bias2: bool = True,
+43                 bias_gate: bool = True):
+
+
+
+
+ + +
+
+
53        super().__init__()
+54        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
+55        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
+56        self.dropout = nn.Dropout(dropout)
+57        self.activation = activation
+58        self.is_gated = is_gated
+59        if is_gated:
+60            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
+
+
+
+
+ + +
+
+
62    def __call__(self, x: torch.Tensor):
+63        g = self.activation(self.layer1(x))
+64        if self.is_gated:
+65            x = g * self.linear_v(x)
+66        else:
+67            x = g
+68        x = self.dropout(x)
+69        return self.layer2(x)
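A hedged construction sketch for a gated variant, following the constructor signature above and the same argument pattern used by the GLU-variants trainer later in this diff:

import torch
from torch import nn
from labml_nn.transformers.feed_forward import FeedForward  # path documented above

# GEGLU: gated hidden layer, GELU activation, no biases
ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1,
                  activation=nn.GELU(), is_gated=True,
                  bias1=False, bias2=False, bias_gate=False)
x = torch.randn(10, 2, 512)  # [seq_len, batch_size, d_model]
y = ffn(x)                   # same shape as x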
+
+
+
+
+ + + + + \ No newline at end of file diff --git a/docs/transformers/feedback/index.html b/docs/transformers/feedback/index.html index 634d47a7..7f602909 100644 --- a/docs/transformers/feedback/index.html +++ b/docs/transformers/feedback/index.html @@ -43,6 +43,7 @@

+ home transformers feedback

@@ -99,7 +100,7 @@ This reduces the memory used for caching during prediction.

40 41from labml_helpers.module import Module 42from labml_nn.transformers.mha import PrepareForMultiHeadAttention -43from labml_nn.transformers.models import FeedForward +43from labml_nn.transformers.feed_forward import FeedForward 44from labml_nn.utils import clone_module_list
diff --git a/docs/transformers/glu_variants/experiment.html b/docs/transformers/glu_variants/experiment.html new file mode 100644 index 00000000..8ab55ca6 --- /dev/null +++ b/docs/transformers/glu_variants/experiment.html @@ -0,0 +1,450 @@ + + + + + + + + + + + + + + + + + + + + + + + Gated Linear Units and Variants + + + + + + + +
+
+
+
+

+ home + transformers + glu_variants +

+

+ + + Github + + Join Slack + + Twitter +

+
+
+
+
+ +

Train Autoregressive Transformer

+

This trains a simple transformer model for auto-regression.

+
+
+
14import torch
+15from labml import experiment
+16from labml.configs import option
+17from labml.utils.pytorch import get_modules
+18from labml_helpers.module import Module
+19
+20from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+21from labml_nn.transformers import Encoder, Generator, TransformerConfigs
+22from labml_nn.transformers.utils import subsequent_mask
+
+
+
+
+ +

Autoregressive model

+
+
+
25class AutoregressiveModel(Module):
+
+
+
+
+ + +
+
+
30    def __init__(self, src_embed: Module, encoder: Encoder, generator: Generator):
+31        super().__init__()
+
+
+
+
+ +

Token embedding module

+
+
+
33        self.src_embed = src_embed
+
+
+
+
+ +

Transformer based encoder

+
+
+
35        self.encoder = encoder
+
+
+
+
+ +

Next token generation layer; +this gives logits of the next token

+
+
+
38        self.generator = generator
+
+
+
+
+ +

This will be initialized on the first call

+
+
+
40        self.src_mask = None
+
+
+
+
+ + +
+
+
42    def __call__(self, src: torch.Tensor):
+
+
+
+
+ +

Create subsequent mask, so that the transformer can only pay attention to past tokens.

+
+
+
44        if self.src_mask is None or self.src_mask.size(0) != len(src):
+45            self.src_mask = subsequent_mask(len(src)).to(src.device)
+
+
+
+
+ +

Embed the tokens (src) and run it through the transformer

+
+
+
47        res = self.encoder(self.src_embed(src), self.src_mask)
+
+
+
+
+ +

Generate logits of the next token

+
+
+
49        return self.generator(res), None
+
+
+
+
+ +

Configurations

+

The default configs can and will be overridden when we start the experiment

+
+
+
52class Configs(NLPAutoRegressionConfigs):
+
+
+
+
+ + +
+
+
59    transformer: TransformerConfigs
+60    model: AutoregressiveModel
+
+
+
+
+ +

Initialize the auto-regressive model

+
+
+
63@option(Configs.model)
+64def autoregressive_model(c: Configs):
+
+
+
+
+ + +
+
+
68    m = AutoregressiveModel(c.transformer.src_embed, c.transformer.encoder, c.transformer.generator)
+69    return m.to(c.device)
+
+
+
+
+ +

Initialize the configurable transformer encoder for our autoregressive model

+
+
+
72@option(Configs.transformer)
+73def transformer_c(c: Configs):
+
+
+
+
+ + +
+
+
77    tc = TransformerConfigs()
+78    tc.n_src_vocab = c.n_tokens
+79    tc.n_tgt_vocab = c.n_tokens
+80
+81    return tc
+
+
+
+
+ + +
+
+
84def main():
+
+
+
+
+ +

Create experiment

+
+
+
86    experiment.create(name="glu_variants")
+
+
+
+
+ +

Create configs

+
+
+
88    conf = Configs()
+
+
+
+
+ +

Load configurations

+
+
+
90    experiment.configs(conf,
+
+
+
+
+ +

A dictionary of configurations to override

+
+
+
92                       {'tokenizer': 'character',
+93                        'prompt_separator': '',
+94                        'prompt': 'It is ',
+95                        'text': 'tiny_shakespeare',
+96
+97                        'optimizer.optimizer': 'Noam',
+98                        'optimizer.learning_rate': 1.,
+99                        'optimizer.d_model': 256,
+100
+101                        'seq_len': 1024,
+102                        'epochs': 128,
+103                        'batch_size': 6,
+104                        'inner_iterations': 10,
+
+
+
+
+ +

GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU

+
+
+
107                        'transformer.ffn.glu_variant': 'Bilinear',
+
+
+
+
+ +

Transformer configurations

+
+
+
110                        'transformer.d_model': 256,
+111                        'transformer.ffn.d_ff': 1024,
+112                        'transformer.n_heads': 8,
+113                        'transformer.n_layers': 6})
+
+
+
+
+ +

This is needed to initialize models

+
+
+
116    conf.n_tokens = conf.text.n_tokens
+
+
+
+
+ +

Set models for saving and loading

+
+
+
119    experiment.add_pytorch_models(get_modules(conf))
+
+
+
+
+ +

Start the experiment

+
+
+
122    with experiment.start():
+
+
+
+
+ +

TrainValidConfigs.run

+
+
+
124        conf.run()
+125
+126
+127if __name__ == '__main__':
+128    main()
+
+
+
+
+ + + + + \ No newline at end of file diff --git a/docs/transformers/glu_variants/index.html b/docs/transformers/glu_variants/index.html new file mode 100644 index 00000000..8699e399 --- /dev/null +++ b/docs/transformers/glu_variants/index.html @@ -0,0 +1,102 @@ + + + + + + + + + + + + + + + + + + + + + + + None + + + + + + + +
+
+
+
+

+ home + transformers + glu_variants +

+

+ + + Github + + Join Slack + + Twitter +

+
+
+
+
+ + + + + \ No newline at end of file diff --git a/docs/transformers/glu_variants/simple.html b/docs/transformers/glu_variants/simple.html new file mode 100644 index 00000000..46e54fae --- /dev/null +++ b/docs/transformers/glu_variants/simple.html @@ -0,0 +1,643 @@ + + + + + + + + + + + + + + + + + + + + + + + Gated Linear Units and Variants + + + + + + + +
+
+
+
+

+ home + transformers + glu_variants +

+

+ + + Github + + Join Slack + + Twitter +

+
+
+
+
+ +

Train Autoregressive Transformer

+

This trains a simple transformer model for auto-regression.

+
+
+
13import dataclasses
+14
+15import torch
+16from torch import nn
+17from torch.utils.data import Dataset, DataLoader
+18
+19from labml import experiment, lab, tracker, monit, logger
+20from labml.logger import Text
+21from labml.utils.download import download_file
+22from labml_nn.experiments.nlp_autoregression import transpose_batch
+23from labml_nn.optimizers.noam import Noam
+24from labml_nn.transformers import Encoder, MultiHeadAttention
+25from labml_nn.transformers.feed_forward import FeedForward
+26from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
+27from labml_nn.transformers.utils import subsequent_mask
+
+
+
+
+ +

Autoregressive model

+
+
+
30class AutoregressiveModel(nn.Module):
+
+
+
+
+ + +
+
+
35    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
+36        super().__init__()
+
+
+
+
+ +

Token embedding module

+
+
+
38        self.src_embed = src_embed
+
+
+
+
+ +

Transformer based encoder

+
+
+
40        self.encoder = encoder
+
+
+
+
+ +

Next token generation layer; +this gives logits of the next token

+
+
+
43        self.generator = generator
+
+
+
+
+ +

This will be initialized on the first call

+
+
+
45        self.src_mask = None
+
+
+
+
+ + +
+
+
47    def __call__(self, src: torch.Tensor):
+
+
+
+
+ +

Create subsequent mask, so that the transformer can only pay attention to past tokens.

+
+
+
49        if self.src_mask is None or self.src_mask.size(0) != len(src):
+50            self.src_mask = subsequent_mask(len(src)).to(src.device)
+
+
+
+
+ +

Embed the tokens (src) and run it through the transformer

+
+
+
52        res = self.encoder(self.src_embed(src), self.src_mask)
+
+
+
+
+ +

Generate logits of the next token

+
+
+
54        return self.generator(res)
+
+
+
+
+ + +
+
+
57@dataclasses.dataclass
+58class Configs:
+59    d_model: int = 512
+60    seq_len: int = 128
+61    batch_size: int = 32
+62    n_layers: int = 6
+63    n_heads: int = 8
+64    dropout: float = 0.1
+65    d_ff: int = 2048
+66    glu_variant: str = 'GLU'
+67    epochs: int = 5
+68    grad_norm_clip: float = 0.5
+69
+70
+71class TinyShakespeareDataset(Dataset):
+72    def __init__(self, seq_len: int):
+73        path = lab.get_data_path() / 'tiny_shakespeare.txt'
+74        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
+75        with open(str(path), 'r') as f:
+76            text = f.read()
+77
+78        chars = list(set(text))
+79        self.stoi = {c: i for i, c in enumerate(chars)}
+80        self.itos = {i: c for i, c in enumerate(chars)}
+81        self.seq_len = seq_len
+82        self.data = self.text_to_i(text)
+83
+84    def text_to_i(self, text: str):
+85        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
+86
+87    def __len__(self):
+88        return len(self.data) - self.seq_len - 1
+89
+90    def __getitem__(self, idx):
+91        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
+92
+93
+94class Trainer:
+95    def __init__(self, configs: Configs):
+96        self.device = torch.device('cpu')
+97        if torch.cuda.is_available():
+98            self.device = torch.device('cuda:0')
+99        self.dataset = TinyShakespeareDataset(configs.seq_len)
+100        self.dataloader = DataLoader(self.dataset, batch_size=configs.batch_size, collate_fn=transpose_batch,
+101                                     shuffle=True)
+102
+103        if configs.glu_variant == 'GLU':
+104            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
+105        elif configs.glu_variant == 'Bilinear':
+106            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
+107        elif configs.glu_variant == 'ReGLU':
+108            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
+109        elif configs.glu_variant == 'GEGLU':
+110            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
+111        elif configs.glu_variant == 'SwiGLU':
+112            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
+113        elif configs.glu_variant == 'ReLU':
+114            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
+115        elif configs.glu_variant == 'GELU':
+116            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
+117        else:
+118            raise ValueError(f'Unknown variant {configs.glu_variant}')
+119
+120        n_chars = len(self.dataset.stoi)
+121        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
+122                                         Encoder(TransformerLayer(
+123                                             d_model=configs.d_model,
+124                                             self_attn=MultiHeadAttention(configs.n_heads, configs.d_model,
+125                                                                          configs.dropout),
+126                                             src_attn=None,
+127                                             feed_forward=ffn,
+128                                             dropout_prob=configs.dropout
+129                                         ), configs.n_layers),
+130                                         nn.Linear(configs.d_model, n_chars))
+131        self.model.to(self.device)
+132
+133        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
+134
+135        self.loss_func = nn.CrossEntropyLoss()
+136        self.epochs = configs.epochs
+137        self.grad_norm_clip = configs.grad_norm_clip
+
+
+
+
+ +

Set tracker configurations

+
+
+
140        tracker.set_scalar("loss.*", True)
+
+
+
+
+ +

Sampling function to generate samples periodically while training

+
+
+
142    def sample(self):
+
+
+
+
+ +

Starting prompt

+
+
+
148        prompt = 'It is'
+
+
+
+
+ +

Collect output for printing

+
+
+
150        log = [(prompt, Text.subtle)]
+
+
+
+
+ +

Sample 25 tokens

+
+
+
152        for i in monit.iterate('Sample', 25):
+
+
+
+
+ +

Tokenize the prompt

+
+
+
154            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
+155            data = data.to(self.device)
+
+
+
+
+ +

Get the model output

+
+
+
157            output = self.model(data)
+
+
+
+
+ +

Get the model prediction (greedy)

+
+
+
159            output = output.argmax(dim=-1).squeeze()
+
+
+
+
+ +

Add the prediction to prompt

+
+
+
161            prompt += self.dataset.itos[output[-1].item()]
+
+
+
+
+ +

Add the prediction for logging

+
+
+
163            log += [(self.dataset.itos[output[-1].item()], Text.value)]
+
+
+
+
+ +

Print the sampled output

+
+
+
166        logger.log(log)
+
+
+
+
+ + +
+
+
168    def train(self):
+169        for _ in monit.loop(self.epochs):
+170            for i, batch in monit.enum('Train', self.dataloader):
+
+
+
+
+ +

Move data to the device

+
+
+
172                data, target = batch[0].to(self.device), batch[1].to(self.device)
+173
+174                tracker.add_global_step(data.shape[0] * data.shape[1])
+175
+176                self.model.train()
+177                output = self.model(data)
+
+
+
+
+ +

Calculate and log loss

+
+
+
180                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
+181                tracker.add("loss.train", loss)
+
+
+
+
+ +

Calculate gradients

+
+
+
184                loss.backward()
+
+
+
+
+ +

Clip gradients

+
+
+
186                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+
+
+
+
+ +

Take optimizer step

+
+
+
188                self.optimizer.step()
+
+
+
+
+ +

Log the model parameters and gradients on the last batch of every epoch

+
+
+
190                if (i + 1) % 100 == 0:
+191                    tracker.add('model', self.model)
+
+
+
+
+ +

Clear the gradients

+
+
+
193                self.optimizer.zero_grad()
+194
+195                if (i + 1) % 100 == 0:
+196                    self.model.eval()
+197                    with torch.no_grad():
+198                        self.sample()
+
+
+
+
+ +

Save the tracked metrics

+
+
+
201                if (i + 1) % 10 == 0:
+202                    tracker.save()
+203
+204            experiment.save_checkpoint()
+
+
+
+
+ + +
+
+
207def main():
+
+
+
+
+ +

Create experiment

+
+
+
209    experiment.create(name="glu_variants")
+
+
+
+
+ +

Create configs

+
+
+
211    configs = Configs()
+
+
+
+
+ +

Load configurations

+
+
+
213    experiment.configs(dataclasses.asdict(configs))
+214
+215    trainer = Trainer(configs)
+216    experiment.add_pytorch_models({'model': trainer.model})
+
+
+
+
+ +

Start the experiment

+
+
+
219    with experiment.start():
+
+
+
+
+ +

TrainValidConfigs.run

+
+
+
221        trainer.train()
+222
+223
+224if __name__ == '__main__':
+225    main()
+
+
+
+
+ + + + + \ No newline at end of file diff --git a/docs/transformers/gpt/index.html b/docs/transformers/gpt/index.html index bfac8d7c..c7196140 100644 --- a/docs/transformers/gpt/index.html +++ b/docs/transformers/gpt/index.html @@ -43,6 +43,7 @@

+ home transformers gpt

@@ -348,7 +349,7 @@ or if the size of the mask is different

GPT uses GELU activation for position wise feedforward

-
121    conf.feed_forward_activation = 'GELU'
+
121    conf.ffn.activation = 'GELU'
@@ -777,7 +778,7 @@ per epoch

248        'transformer.d_model': 512,
-249        'transformer.d_ff': 2048,
+249        'transformer.ffn.d_ff': 2048,
 250        'transformer.n_heads': 8,
 251        'transformer.n_layers': 6
 252    })
diff --git a/docs/transformers/index.html b/docs/transformers/index.html index f8e8d5b9..4d553eef 100644 --- a/docs/transformers/index.html +++ b/docs/transformers/index.html @@ -43,6 +43,7 @@

+ home transformers

diff --git a/docs/transformers/knn/build_index.html b/docs/transformers/knn/build_index.html index 6b566252..90dde5f4 100644 --- a/docs/transformers/knn/build_index.html +++ b/docs/transformers/knn/build_index.html @@ -43,6 +43,7 @@

+ home transformers knn

diff --git a/docs/transformers/knn/eval_knn.html b/docs/transformers/knn/eval_knn.html index 4decaf5c..eaf38945 100644 --- a/docs/transformers/knn/eval_knn.html +++ b/docs/transformers/knn/eval_knn.html @@ -43,6 +43,7 @@

+ home transformers knn

diff --git a/docs/transformers/knn/index.html b/docs/transformers/knn/index.html index 45667dc0..53ff31a9 100644 --- a/docs/transformers/knn/index.html +++ b/docs/transformers/knn/index.html @@ -43,6 +43,7 @@

+ home transformers knn

diff --git a/docs/transformers/knn/train_model.html b/docs/transformers/knn/train_model.html index 6a62436d..e0f50505 100644 --- a/docs/transformers/knn/train_model.html +++ b/docs/transformers/knn/train_model.html @@ -43,6 +43,7 @@

+ home transformers knn

@@ -413,7 +414,7 @@ final token generator from configurable transformer

126                        'transformer.d_model': 256,
-127                        'transformer.d_ff': 1024,
+127                        'transformer.ffn.d_ff': 1024,
 128                        'transformer.n_heads': 8,
 129                        'transformer.n_layers': 6})
diff --git a/docs/transformers/label_smoothing_loss.html b/docs/transformers/label_smoothing_loss.html index 26128794..9f4cd6c1 100644 --- a/docs/transformers/label_smoothing_loss.html +++ b/docs/transformers/label_smoothing_loss.html @@ -43,6 +43,7 @@

+ home transformers

diff --git a/docs/transformers/mha.html b/docs/transformers/mha.html index 954413ec..d7308d1e 100644 --- a/docs/transformers/mha.html +++ b/docs/transformers/mha.html @@ -43,6 +43,7 @@

+ home transformers

diff --git a/docs/transformers/models.html b/docs/transformers/models.html index cc550712..9d9f0b2e 100644 --- a/docs/transformers/models.html +++ b/docs/transformers/models.html @@ -43,6 +43,7 @@

+ home transformers

@@ -79,8 +80,9 @@ 15from labml_helpers.module import Module 16 17from labml_nn.utils import clone_module_list -18from .mha import MultiHeadAttention -19from .positional_encoding import get_positional_encoding

+18from .feed_forward import FeedForward +19from .mha import MultiHeadAttention +20from .positional_encoding import get_positional_encoding
@@ -93,7 +95,7 @@

-
22class EmbeddingsWithPositionalEncoding(Module):
+
23class EmbeddingsWithPositionalEncoding(Module):
@@ -104,11 +106,11 @@
-
29    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-30        super().__init__()
-31        self.linear = nn.Embedding(n_vocab, d_model)
-32        self.d_model = d_model
-33        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
+
30    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+31        super().__init__()
+32        self.linear = nn.Embedding(n_vocab, d_model)
+33        self.d_model = d_model
+34        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
@@ -119,9 +121,9 @@
-
35    def __call__(self, x: torch.Tensor):
-36        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
-37        return self.linear(x) * math.sqrt(self.d_model) + pe
+
36    def __call__(self, x: torch.Tensor):
+37        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
+38        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -134,7 +136,7 @@

-
40class EmbeddingsWithLearnedPositionalEncoding(Module):
+
41class EmbeddingsWithLearnedPositionalEncoding(Module):
@@ -145,11 +147,11 @@
-
47    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-48        super().__init__()
-49        self.linear = nn.Embedding(n_vocab, d_model)
-50        self.d_model = d_model
-51        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
+
48    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+49        super().__init__()
+50        self.linear = nn.Embedding(n_vocab, d_model)
+51        self.d_model = d_model
+52        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
@@ -160,9 +162,9 @@
-
53    def __call__(self, x: torch.Tensor):
-54        pe = self.positional_encodings[:x.shape[0]]
-55        return self.linear(x) * math.sqrt(self.d_model) + pe
+
54    def __call__(self, x: torch.Tensor):
+55        pe = self.positional_encodings[:x.shape[0]]
+56        return self.linear(x) * math.sqrt(self.d_model) + pe
-
-
58class FeedForward(Module):
-
-
-
-
- -
    -
  • d_model is the number of features in a token embedding
  • -
  • d_ff is the number of features in the hidden layer of the FFN
  • -
  • dropout is dropout probability for the hidden layer
  • -
  • activation is the activation function to apply on the hidden layer outputs
  • -
-
-
-
65    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
-
-
-
-
- - -
-
-
72        super().__init__()
-73        self.layer1 = nn.Linear(d_model, d_ff)
-74        self.layer2 = nn.Linear(d_ff, d_model)
-75        self.dropout = nn.Dropout(dropout)
-76        self.activation = activation
-
-
-
-
- - -
-
-
78    def __call__(self, x: torch.Tensor):
-79        x = self.layer1(x)
-80        x = self.activation(x)
-81        x = self.dropout(x)
-82        return self.layer2(x)
-
-
-
-
-
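The block removed above now lives in its own module, labml_nn/transformers/feed_forward.py, and models.py pulls it back in with from .feed_forward import FeedForward (see the import hunk above). For orientation, a self-contained sketch of the same computation; it subclasses nn.Module instead of the repo's labml_helpers Module, so read it as an illustration of the relocated class rather than the file itself:

import torch
from torch import nn

class FeedForward(nn.Module):
    # FFN(x) = layer2(dropout(activation(layer1(x)))), matching the removed block above.
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)    # expand: d_model -> d_ff
        self.layer2 = nn.Linear(d_ff, d_model)    # contract: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer1(x)
        x = self.activation(x)
        x = self.dropout(x)
        return self.layer2(x)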

Transformer Layer

@@ -243,13 +186,13 @@ We found a detailed discussion about this in paper On Layer Normalization in the Transformer Architecture.

-
85class TransformerLayer(Module):
+
59class TransformerLayer(Module):
-
+
  • d_model is the token embedding size
  • @@ -260,12 +203,69 @@ We found a detailed discussion about this in paper
-
103    def __init__(self, *,
-104                 d_model: int,
-105                 self_attn: MultiHeadAttention,
-106                 src_attn: MultiHeadAttention = None,
-107                 feed_forward: FeedForward,
-108                 dropout_prob: float):
+
77    def __init__(self, *,
+78                 d_model: int,
+79                 self_attn: MultiHeadAttention,
+80                 src_attn: MultiHeadAttention = None,
+81                 feed_forward: FeedForward,
+82                 dropout_prob: float):
+
+
+
+
+ + +
+
+
90        super().__init__()
+91        self.size = d_model
+92        self.self_attn = self_attn
+93        self.src_attn = src_attn
+94        self.feed_forward = feed_forward
+95        self.dropout = nn.Dropout(dropout_prob)
+96        self.norm_self_attn = nn.LayerNorm([d_model])
+97        if self.src_attn is not None:
+98            self.norm_src_attn = nn.LayerNorm([d_model])
+99        self.norm_ff = nn.LayerNorm([d_model])
+
+
+
+
+ +

Whether to save input to the feed forward layer

+
+
+
101        self.is_save_ff_input = False
+
+
+
+
+ + +
+
+
103    def __call__(self, *,
+104                 x: torch.Tensor,
+105                 mask: torch.Tensor,
+106                 src: torch.Tensor = None,
+107                 src_mask: torch.Tensor = None):
+
+
+
+
+ +

Normalize the vectors before doing self attention

+
+
+
109        z = self.norm_self_attn(x)
@@ -273,19 +273,10 @@ We found a detailed discussion about this in paper - +

Run through self attention, i.e. keys and values are from self

-
116        super().__init__()
-117        self.size = d_model
-118        self.self_attn = self_attn
-119        self.src_attn = src_attn
-120        self.feed_forward = feed_forward
-121        self.dropout = nn.Dropout(dropout_prob)
-122        self.norm_self_attn = nn.LayerNorm([d_model])
-123        if self.src_attn is not None:
-124            self.norm_src_attn = nn.LayerNorm([d_model])
-125        self.norm_ff = nn.LayerNorm([d_model])
+
111        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
@@ -293,10 +284,10 @@ We found a detailed discussion about this in paper -

Whether to save input to the feed forward layer

+

Add the self attention results

-
127        self.is_save_ff_input = False
+
113        x = x + self.dropout(self_attn)
@@ -304,14 +295,12 @@ We found a detailed discussion about this in paper - +

If a source is provided, get results from attention to source. +This is when you have a decoder layer that pays attention to +encoder outputs

-
129    def __call__(self, *,
-130                 x: torch.Tensor,
-131                 mask: torch.Tensor,
-132                 src: torch.Tensor = None,
-133                 src_mask: torch.Tensor = None):
+
118        if src is not None:
@@ -319,10 +308,10 @@ We found a detailed discussion about this in paper -

Normalize the vectors before doing self attention

+

Normalize vectors

-
135        z = self.norm_self_attn(x)
+
120            z = self.norm_src_attn(x)
@@ -330,10 +319,10 @@ We found a detailed discussion about this in paper -

Run through self attention, i.e. keys and values are from self

+

Attention to source. i.e. keys and values are from source

-
137        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+
122            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
@@ -341,10 +330,10 @@ We found a detailed discussion about this in paper -

Add the self attention results

+

Add the source attention results

-
139        x = x + self.dropout(self_attn)
+
124            x = x + self.dropout(attn_src)
@@ -352,12 +341,10 @@ We found a detailed discussion about this in paper -

If a source is provided, get results from attention to source. -This is when you have a decoder layer that pays attention to -encoder outputs

+

Normalize for feed-forward

-
144        if src is not None:
+
127        z = self.norm_ff(x)
@@ -365,10 +352,11 @@ encoder outputs

-

Normalize vectors

+

Save the input to the feed forward layer if specified

-
146            z = self.norm_src_attn(x)
+
129        if self.is_save_ff_input:
+130            self.ff_input = z.clone()
@@ -376,10 +364,10 @@ encoder outputs

-

Attention to source. i.e. keys and values are from source

+

Pass through the feed-forward network

-
148            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
+
132        ff = self.feed_forward(z)
@@ -387,21 +375,25 @@ encoder outputs

-

Add the source attention results

+

Add the feed-forward results back

-
150            x = x + self.dropout(attn_src)
+
134        x = x + self.dropout(ff)
+135
+136        return x
-
+
-

Normalize for feed-forward

+

+

Transformer Encoder

+

-
153        z = self.norm_ff(x)
+
139class Encoder(Module):
@@ -409,11 +401,11 @@ encoder outputs

-

Save the input to the feed forward layer if specified

+
-
155        if self.is_save_ff_input:
-156            self.ff_input = z.clone()
+
146    def __init__(self, layer: TransformerLayer, n_layers: int):
+147        super().__init__()
@@ -421,10 +413,10 @@ encoder outputs

-

Pass through the feed-forward network

+

Make copies of the transformer layer

-
158        ff = self.feed_forward(z)
+
149        self.layers = clone_module_list(layer, n_layers)
@@ -432,25 +424,21 @@ encoder outputs

-

Add the feed-forward results back

+

Final normalization layer

-
160        x = x + self.dropout(ff)
-161
-162        return x
+
151        self.norm = nn.LayerNorm([layer.size])
-
+
-
165class Encoder(Module):
+
153    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
@@ -458,11 +446,11 @@ encoder outputs

- +

Run through each transformer layer

-
172    def __init__(self, layer: TransformerLayer, n_layers: int):
-173        super().__init__()
+
155        for layer in self.layers:
+156            x = layer(x=x, mask=mask)
@@ -470,21 +458,23 @@ encoder outputs

-

Make copies of the transformer layer

+

Finally, normalize the vectors

-
175        self.layers = clone_module_list(layer, n_layers)
+
158        return self.norm(x)
-
+
-

Final normalization layer

+

+

Transformer Decoder

+

-
177        self.norm = nn.LayerNorm([layer.size])
+
161class Decoder(Module):
@@ -495,7 +485,8 @@ encoder outputs

-
179    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+
168    def __init__(self, layer: TransformerLayer, n_layers: int):
+169        super().__init__()
@@ -503,11 +494,10 @@ encoder outputs

-

Run through each transformer layer

+

Make copies of the transformer layer

-
181        for layer in self.layers:
-182            x = layer(x=x, mask=mask)
+
171        self.layers = clone_module_list(layer, n_layers)
@@ -515,23 +505,21 @@ encoder outputs

-

Finally, normalize the vectors

+

Final normalization layer

-
184        return self.norm(x)
+
173        self.norm = nn.LayerNorm([layer.size])
-
+
-
187class Decoder(Module):
+
175    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -539,11 +527,11 @@ encoder outputs

- +

Run through each transformer layer

-
194    def __init__(self, layer: TransformerLayer, n_layers: int):
-195        super().__init__()
+
177        for layer in self.layers:
+178            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
@@ -551,21 +539,25 @@ encoder outputs

-

Make copies of the transformer layer

+

Finally, normalize the vectors

-
197        self.layers = clone_module_list(layer, n_layers)
+
180        return self.norm(x)
-
+
-

Final normalization layer

+

+

Generator

+

+

This predicts the tokens and gives the log softmax of those. +You don’t need this if you are using nn.CrossEntropyLoss.

-
199        self.norm = nn.LayerNorm([layer.size])
+
183class Generator(Module):
@@ -576,7 +568,9 @@ encoder outputs

-
201    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+
193    def __init__(self, n_vocab: int, d_model: int):
+194        super().__init__()
+195        self.projection = nn.Linear(d_model, n_vocab)
@@ -584,37 +578,41 @@ encoder outputs

-

Run through each transformer layer

+
-
203        for layer in self.layers:
-204            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
+
197    def __call__(self, x):
+198        return self.projection(x)
-
+
-

Finally, normalize the vectors

+

+

Combined Encoder-Decoder

+

-
206        return self.norm(x)
+
201class EncoderDecoder(Module):
-
+
-

-

Generator

-

-

This predicts the tokens and gives the log softmax of those. -You don’t need this if you are using nn.CrossEntropyLoss.

+
-
209class Generator(Module):
+
208    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
+209        super().__init__()
+210        self.encoder = encoder
+211        self.decoder = decoder
+212        self.src_embed = src_embed
+213        self.tgt_embed = tgt_embed
+214        self.generator = generator
@@ -622,12 +620,13 @@ You don’t need this if you are using nn.CrossEntropyLoss.

- +

This was important from their code. +Initialize parameters with Glorot / fan_avg.

-
219    def __init__(self, n_vocab: int, d_model: int):
-220        super().__init__()
-221        self.projection = nn.Linear(d_model, n_vocab)
+
218        for p in self.parameters():
+219            if p.dim() > 1:
+220                nn.init.xavier_uniform_(p)
@@ -638,21 +637,18 @@ You don’t need this if you are using nn.CrossEntropyLoss.

-
223    def __call__(self, x):
-224        return self.projection(x)
+
222    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
-
+
-

-

Combined Encoder-Decoder

-

+

Runs the source through encoder

-
227class EncoderDecoder(Module):
+
224        enc = self.encode(src, src_mask)
@@ -660,16 +656,10 @@ You don’t need this if you are using nn.CrossEntropyLoss.

- +

Run encodings and targets through decoder

-
234    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
-235        super().__init__()
-236        self.encoder = encoder
-237        self.decoder = decoder
-238        self.src_embed = src_embed
-239        self.tgt_embed = tgt_embed
-240        self.generator = generator
+
226        return self.decode(enc, src_mask, tgt, tgt_mask)
@@ -677,13 +667,11 @@ You don’t need this if you are using nn.CrossEntropyLoss.

-

This was important from their code. -Initialize parameters with Glorot / fan_avg.

+
-
244        for p in self.parameters():
-245            if p.dim() > 1:
-246                nn.init.xavier_uniform_(p)
+
228    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
+229        return self.encoder(self.src_embed(src), src_mask)
@@ -694,53 +682,8 @@ Initialize parameters with Glorot / fan_avg.

-
248    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
-
-
-
-
- -

Runs the source through encoder

-
-
-
250        enc = self.encode(src, src_mask)
-
-
-
-
- -

Run encodings and targets through decoder

-
-
-
252        return self.decode(enc, src_mask, tgt, tgt_mask)
-
-
-
-
- - -
-
-
254    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
-255        return self.encoder(self.src_embed(src), src_mask)
-
-
-
-
- - -
-
-
257    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
-258        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
+
231    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
+232        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
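To see how the renumbered pieces above fit together after the FeedForward split, here is a hedged wiring sketch. It only uses constructor signatures visible in this diff (TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder, EmbeddingsWithPositionalEncoding, MultiHeadAttention, FeedForward); the hyper-parameter values are examples, not repo defaults:

from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import (
    EmbeddingsWithPositionalEncoding, TransformerLayer,
    Encoder, Decoder, Generator, EncoderDecoder)

d_model, d_ff, heads, n_layers, n_vocab, dropout = 256, 1024, 8, 6, 10_000, 0.1

def make_layer(with_src_attn: bool) -> TransformerLayer:
    # Pre-norm transformer layer; decoder layers additionally attend to the encoder memory.
    return TransformerLayer(
        d_model=d_model,
        self_attn=MultiHeadAttention(heads, d_model, dropout),
        src_attn=MultiHeadAttention(heads, d_model, dropout) if with_src_attn else None,
        feed_forward=FeedForward(d_model, d_ff, dropout),
        dropout_prob=dropout)

model = EncoderDecoder(
    Encoder(make_layer(False), n_layers),
    Decoder(make_layer(True), n_layers),
    EmbeddingsWithPositionalEncoding(d_model, n_vocab),   # source embeddings
    EmbeddingsWithPositionalEncoding(d_model, n_vocab),   # target embeddings
    Generator(n_vocab, d_model))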
diff --git a/docs/transformers/positional_encoding.html b/docs/transformers/positional_encoding.html index 17b64d41..75feec3f 100644 --- a/docs/transformers/positional_encoding.html +++ b/docs/transformers/positional_encoding.html @@ -43,6 +43,7 @@

+ home transformers

diff --git a/docs/transformers/relative_mha.html b/docs/transformers/relative_mha.html index e1b7bb2f..6a89c9d4 100644 --- a/docs/transformers/relative_mha.html +++ b/docs/transformers/relative_mha.html @@ -43,6 +43,7 @@

+ home transformers

diff --git a/docs/transformers/switch/experiment.html b/docs/transformers/switch/experiment.html index cc742059..834382df 100644 --- a/docs/transformers/switch/experiment.html +++ b/docs/transformers/switch/experiment.html @@ -43,6 +43,7 @@

+ home transformers switch

@@ -650,17 +651,17 @@ set to something small like $\alpha = 0.01$.

171    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
 172    from labml_nn.transformers import MultiHeadAttention
-173
-174    return SwitchTransformer(
-175        SwitchTransformerLayer(d_model=c.d_model,
-176                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
-177                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
-178                                                              drop_tokens=c.drop_tokens,
-179                                                              is_scale_prob=c.is_scale_prob,
-180                                                              n_experts=c.n_experts,
-181                                                              d_model=c.d_model,
-182                                                              d_ff=c.d_ff,
-183                                                              dropout=c.dropout),
+173    from labml_nn.transformers.feed_forward import FeedForward
+174
+175    return SwitchTransformer(
+176        SwitchTransformerLayer(d_model=c.d_model,
+177                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
+178                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
+179                                                              drop_tokens=c.drop_tokens,
+180                                                              is_scale_prob=c.is_scale_prob,
+181                                                              n_experts=c.n_experts,
+182                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
+183                                                              d_model=c.d_model),
 184                               dropout_prob=c.dropout),
 185        c.n_layers)
diff --git a/docs/transformers/switch/index.html b/docs/transformers/switch/index.html index 361c367e..fd2eb4cc 100644 --- a/docs/transformers/switch/index.html +++ b/docs/transformers/switch/index.html @@ -43,6 +43,7 @@

+ home transformers switch

@@ -75,11 +76,11 @@ Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Our implementation only has a few million parameters and doesn’t do model parallel distributed training. It does single GPU training but we implement the concept of switching as described in the paper.

-

The Switch Transformer is uses different parameters for each tokens by switching among parameters, +

The Switch Transformer uses different parameters for each token by switching among parameters, based on the token. So only a fraction of parameters is chosen for each token, so you -can have more parameters but a less computational cost.

+can have more parameters but less computational cost.

The switching happens at the Position-wise Feedforward network (FFN) of each transformer block. -Position-wise feedforward network is a two sequential fully connected layers. +The position-wise feedforward network is two fully connected layers applied sequentially. In the switch transformer we have multiple FFNs (multiple experts) and we choose which one to use based on a router. The router outputs a set of probabilities for picking an FFN, @@ -100,7 +101,7 @@ discusses dropping tokens when routing is not balanced.

41 42from labml_helpers.module import Module 43from labml_nn.transformers.mha import MultiHeadAttention -44from labml_nn.transformers.models import FeedForward +44from labml_nn.transformers.feed_forward import FeedForward 45from labml_nn.utils import clone_module_list
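To make the top-1 routing described above concrete, here is a conceptual, self-contained sketch. It illustrates the idea only and is not the repo's SwitchFeedForward, which additionally handles expert capacity, token dropping and the is_scale_prob option; the sizes and the plain nn.Sequential experts are assumptions for the example:

import torch
from torch import nn

d_model, d_ff, n_experts = 256, 1024, 4

# One FFN per expert and a single linear router that produces routing logits.
experts = nn.ModuleList([
    nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
    for _ in range(n_experts)])
router = nn.Linear(d_model, n_experts)

x = torch.randn(10, d_model)              # 10 tokens, flattened across batch and sequence
probs = torch.softmax(router(x), dim=-1)  # routing probabilities per token
top_prob, top_expert = probs.max(dim=-1)  # top-1 switch: each token goes to a single expert

out = torch.zeros_like(x)
for i in range(n_experts):
    routed = top_expert == i              # tokens routed to expert i
    if routed.any():
        out[routed] = experts[i](x[routed])
out = out * top_prob.unsqueeze(-1)        # scale by the routing probability so the router gets gradients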
@@ -125,6 +126,7 @@ discusses dropping tokens when routing is not balanced.

  • drop_tokens specifies whether to drop tokens if more tokens are routed to an expert than the capacity
  • is_scale_prob specifies whether to multiply the input to the FFN by the routing probability
  • n_experts is the number of experts
  • +
  • expert is the expert layer, an FFN module
  • d_model is the number of features in a token embedding
  • d_ff is the number of features in the hidden layer of the FFN
  • dropout is dropout probability in the FFN
  • @@ -136,9 +138,8 @@ discusses dropping tokens when routing is not balanced.

    55 drop_tokens: bool, 56 is_scale_prob: bool, 57 n_experts: int, -58 d_model: int, -59 d_ff: int, -60 dropout: float = 0.1):
    +58 expert: FeedForward, +59 d_model: int):
    @@ -162,10 +163,10 @@ discusses dropping tokens when routing is not balanced.

    -

    FFN modules for each expert

    +

    make copies of the FFNs

    -
    78        self.experts = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_experts)])
    +
    78        self.experts = clone_module_list(expert, n_experts)
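With this change the caller builds one prototype expert and SwitchFeedForward clones it n_experts times via clone_module_list. A hedged construction sketch using the keyword names from the updated signature; the capacity factor, flags and sizes are example values:

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.switch import SwitchFeedForward

d_model, d_ff, dropout = 256, 1024, 0.1

switch_ffn = SwitchFeedForward(
    capacity_factor=1.0,    # example: head-room each expert gets over an even split of tokens
    drop_tokens=True,       # drop tokens routed beyond an expert's capacity
    is_scale_prob=True,     # scale by the routing probability
    n_experts=4,
    expert=FeedForward(d_model, d_ff, dropout),  # prototype FFN, cloned per expert
    d_model=d_model)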
    @@ -471,7 +472,7 @@ These are used for the load balancing loss and logging

    #

    Switch Transformer Block

    -

    This is same as [normal transformer block](FFN modules +

This is the same as a normal transformer block, with handling of the extra outputs of the switch feedforward module.

    diff --git a/docs/transformers/utils.html b/docs/transformers/utils.html index a5885e41..09d25726 100644 --- a/docs/transformers/utils.html +++ b/docs/transformers/utils.html @@ -43,6 +43,7 @@

    + home transformers

    diff --git a/docs/utils.html b/docs/utils.html index c22231c0..53ca60dd 100644 --- a/docs/utils.html +++ b/docs/utils.html @@ -43,6 +43,7 @@

    + home

diff --git a/labml_nn/experiments/nlp_autoregression.py b/labml_nn/experiments/nlp_autoregression.py
index 383f874a..2b481ab5 100644
--- a/labml_nn/experiments/nlp_autoregression.py
+++ b/labml_nn/experiments/nlp_autoregression.py
@@ -12,6 +12,8 @@ from typing import Callable
 
 import torch
 import torch.nn as nn
+from torch.utils.data import DataLoader
+
 from labml import lab, monit, logger, tracker
 from labml.configs import option
 from labml.logger import Text
@@ -20,8 +22,6 @@ from labml_helpers.device import DeviceConfigs
 from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
 from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
-from torch.utils.data import DataLoader
-
 from labml_nn.optimizers.configs import OptimizerConfigs
@@ -173,6 +173,7 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         # Print the sampled output
         logger.log(log)
 
+
 @option(NLPAutoRegressionConfigs.optimizer)
 def _optimizer(c: NLPAutoRegressionConfigs):
     """
diff --git a/labml_nn/transformers/feed_forward.py b/labml_nn/transformers/feed_forward.py
index 82aa613a..a7c92afb 100644
--- a/labml_nn/transformers/feed_forward.py
+++ b/labml_nn/transformers/feed_forward.py
@@ -1,3 +1,28 @@
+"""
+---
+title: Position-wise Feed-Forward Network (FFN)
+summary: Documented reusable implementation of the position-wise feedforward network.
+---
+
+# Position-wise Feed-Forward Network (FFN)
+
+The FFN consists of two fully connected layers.
+The number of dimensions in the hidden layer, $d_{ff}$, is generally set to around
+four times that of the token embedding, $d_{model}$.
+So it is sometimes also called the expand-and-contract network.
+
+There is an activation at the hidden layer, which is
+usually set to the ReLU (Rectified Linear Unit) activation, $$\max(0, x)$$
+
+That is, the FFN function is,
+$$FFN(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$
+where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.
+
+Sometimes the
+GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
+$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
+"""
+
 import torch
 from torch import nn as nn
@@ -6,9 +31,7 @@ from labml_helpers.module import Module
 
 class FeedForward(Module):
     """
-
-    ## Position-wise feed-forward network (FFN) with hidden layer
-
+    ## Position-wise feed-forward network (FFN) module
     """
 
     def __init__(self, d_model: int, d_ff: int,
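A quick usage check of the relocated module. The dropout argument is taken from the call sites visible in this diff; the activation keyword is assumed to carry over from the version removed from models.py (the new signature is truncated above), so treat that line as an assumption:

import torch
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward

d_model, d_ff = 256, 1024
x = torch.randn(32, 10, d_model)    # (seq_len, batch, d_model); the FFN acts on the last dimension

ffn = FeedForward(d_model, d_ff, 0.1)
assert ffn(x).shape == x.shape      # applied position-wise, so the shape is preserved

# GELU variant mentioned in the new docstring (activation kwarg assumed, as in the old signature):
ffn_gelu = FeedForward(d_model, d_ff, 0.1, activation=nn.GELU())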