diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 98c691af..e4cf5fa1 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -603,14 +603,14 @@
 https://nn.labml.ai/transformers/switch/index.html
- 2021-08-17T16:30:00+00:00
+ 2021-09-17T16:30:00+00:00
 1.00

 https://nn.labml.ai/transformers/switch/experiment.html
- 2021-09-06T16:30:00+00:00
+ 2021-09-17T16:30:00+00:00
 1.00

diff --git a/docs/transformers/switch/experiment.html b/docs/transformers/switch/experiment.html
index c9d45783..5acfe778 100644
--- a/docs/transformers/switch/experiment.html
+++ b/docs/transformers/switch/experiment.html
@@ -69,16 +69,18 @@

Switch Transformer Experiment

This is an annotated PyTorch experiment to train a switch transformer.

+

Open In Colab
+View Run

-
12import torch
-13import torch.nn as nn
-14
-15from labml import experiment, tracker
-16from labml.configs import option
-17from labml_helpers.module import Module
-18from labml_helpers.train_valid import BatchIndex
-19from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+
15import torch
+16import torch.nn as nn
+17
+18from labml import experiment, tracker
+19from labml.configs import option
+20from labml_helpers.module import Module
+21from labml_helpers.train_valid import BatchIndex
+22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
@@ -89,7 +91,7 @@

Autoregressive model

-
22class AutoregressiveModel(Module):
+
25class AutoregressiveModel(Module):
@@ -100,8 +102,8 @@
-
27    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
-28        super().__init__()
+
30    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
+31        super().__init__()
@@ -112,7 +114,7 @@

Token embedding module

-
30        self.src_embed = nn.Embedding(n_vocab, d_model)
+
33        self.src_embed = nn.Embedding(n_vocab, d_model)
@@ -123,7 +125,7 @@

Transformer

-
32        self.transformer = transformer
+
35        self.transformer = transformer
@@ -134,8 +136,8 @@

Final layer

-
34        self.generator = nn.Linear(d_model, n_vocab)
-35        self.mask = None
+
37        self.generator = nn.Linear(d_model, n_vocab)
+38        self.mask = None
@@ -146,7 +148,7 @@
-
37    def forward(self, x: torch.Tensor):
+
40    def forward(self, x: torch.Tensor):
@@ -157,9 +159,9 @@

Initialize the subsequent mask

-
39        if self.mask is None or self.mask.size(0) != len(x):
-40            from labml_nn.transformers.utils import subsequent_mask
-41            self.mask = subsequent_mask(len(x)).to(x.device)
+
42        if self.mask is None or self.mask.size(0) != len(x):
+43            from labml_nn.transformers.utils import subsequent_mask
44        if self.mask is None or self.mask.size(0) != len(x):
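+43            from labml_nn.transformers.utils import subsequent_mask
+44            self.mask = subsequent_mask(len(x)).to(x.device)

The mask is built once and cached until the sequence length changes. `subsequent_mask` produces a causal mask, so each position can attend only to itself and to earlier positions. Below is a minimal sketch of such a mask, assuming the usual lower-triangular convention (the library helper may add an extra dimension for broadcasting); it is an illustration, not the library code.

import torch

def causal_mask_sketch(seq_len: int) -> torch.Tensor:
    # Entry [i, j] is True when j <= i: position i may attend to
    # itself and to earlier positions only.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(causal_mask_sketch(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])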
@@ -170,7 +172,7 @@

Token embeddings

-
43        x = self.src_embed(x)
+
46        x = self.src_embed(x)
@@ -181,7 +183,7 @@

Run it through the transformer

-
45        res, counts, route_prob, n_dropped = self.transformer(x, self.mask)
+
48        res, counts, route_prob, n_dropped = self.transformer(x, self.mask)
@@ -192,7 +194,7 @@

Generate logits of the next token

-
47        res = self.generator(res)
+
50        res = self.generator(res)
@@ -203,7 +205,7 @@
-
49        return res, counts, route_prob, n_dropped
+
52        return res, counts, route_prob, n_dropped
@@ -216,7 +218,7 @@

The default configs can and will be overridden when we start the experiment

-
52class Configs(NLPAutoRegressionConfigs):
+
55class Configs(NLPAutoRegressionConfigs):
@@ -227,8 +229,8 @@
-
61    model: AutoregressiveModel
-62    transformer: Module
+
64    model: AutoregressiveModel
+65    transformer: Module
@@ -239,7 +241,7 @@

Token embedding size

-
65    d_model: int = 128
+
68    d_model: int = 128
@@ -250,7 +252,7 @@

Number of attention heads

-
67    heads: int = 4
+
70    heads: int = 4
@@ -261,7 +263,7 @@

Dropout probability

-
69    dropout: float = 0.0
+
72    dropout: float = 0.0
@@ -272,7 +274,7 @@

Number of features in FFN hidden layer

-
71    d_ff: int = 256
+
74    d_ff: int = 256
@@ -283,7 +285,7 @@

Number of transformer layers

-
73    n_layers: int = 6
+
76    n_layers: int = 6
@@ -294,7 +296,7 @@

Number of experts

-
75    n_experts: int = 4
+
78    n_experts: int = 4
@@ -305,7 +307,7 @@

Load balancing coefficient

-
77    load_balancing_loss_ceof = 0.01
+
80    load_balancing_loss_ceof = 0.01
@@ -316,7 +318,7 @@

Whether to scale the chosen expert outputs by the routing probability

-
79    is_scale_prob: bool = True
+
82    is_scale_prob: bool = True
@@ -327,7 +329,7 @@

Whether to drop tokens

-
81    drop_tokens: bool = False
+
84    drop_tokens: bool = False
@@ -338,7 +340,7 @@

Capacity factor to determine the capacity of each expert

-
83    capacity_factor: float = 1.0
+
86    capacity_factor: float = 1.0
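In the Switch Transformer, each expert's capacity is derived from the batch: roughly capacity = capacity_factor × (tokens in the batch) / n_experts, and when drop_tokens is enabled any tokens routed beyond that capacity are dropped. A small worked example with the values used later in this experiment, assuming that formula (the exact rounding inside SwitchFeedForward may differ):

seq_len, batch_size = 64, 32          # values from the experiment configs below
n_experts = 4
capacity_factor = 1.2

tokens = seq_len * batch_size                          # 2048 tokens per batch
capacity = int(capacity_factor * tokens / n_experts)   # 614 tokens per expert
print(tokens, capacity)                                # 2048 614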
@@ -349,8 +351,8 @@
-
85    def init(self):
-86        super().init()
+
88    def init(self):
+89        super().init()
@@ -361,9 +363,9 @@

Initialize tracking indicators

-
88        tracker.set_scalar("lb_loss.*", False)
-89        tracker.set_scalar("route.*", False)
-90        tracker.set_scalar("dropped.*", False)
+
91        tracker.set_scalar("lb_loss.*", False)
+92        tracker.set_scalar("route.*", False)
+93        tracker.set_scalar("dropped.*", False)
@@ -374,7 +376,7 @@

Training or validation step

-
92    def step(self, batch: any, batch_idx: BatchIndex):
+
95    def step(self, batch: any, batch_idx: BatchIndex):
@@ -385,7 +387,7 @@

Move data to the device

-
98        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
101        data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -396,8 +398,8 @@

Update global step (number of tokens processed) when in training mode

-
101        if self.mode.is_train:
-102            tracker.add_global_step(data.shape[0] * data.shape[1])
+
104        if self.mode.is_train:
+105            tracker.add_global_step(data.shape[0] * data.shape[1])
@@ -408,7 +410,7 @@

Whether to capture model outputs

-
105        with self.mode.update(is_log_activations=batch_idx.is_last):
+
108        with self.mode.update(is_log_activations=batch_idx.is_last):
@@ -419,7 +421,7 @@

Get model outputs.

-
107            output, counts, route_prob, n_dropped = self.model(data)
+
110            output, counts, route_prob, n_dropped = self.model(data)
@@ -430,7 +432,7 @@

Calculate the cross entropy loss

-
110        cross_entropy_loss = self.loss_func(output, target)
+
113        cross_entropy_loss = self.loss_func(output, target)
@@ -441,7 +443,7 @@

Total number of tokens processed, $T$, in the current batch $\mathscr{B}$

-
112        total = counts.sum(dim=-1, keepdims=True)
+
115        total = counts.sum(dim=-1, keepdims=True)
@@ -454,7 +456,7 @@ $f_i$ is the count of tokens where the argmax of $p(x)$ is equal to $i$.

-
116        route_frac = counts / total
+
119        route_frac = counts / total
@@ -467,7 +469,7 @@ $f_i$ is the count of tokens where the argmax of $p(x)$ is equal to $i$.

-
119        route_prob = route_prob / total
+
122        route_prob = route_prob / total
@@ -481,7 +483,7 @@ $\mathscr{L}$ is the loss for a single layer and here we are taking the sum of losses across all layers.

-
124        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
+
127        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
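A small numeric sketch of this load-balancing term for a single layer, mirroring the lines above with made-up routing statistics for four experts; it is only an illustration, not part of the experiment code. When routing is perfectly balanced ($f_i = P_i = 1/N$) the term equals 1, and it grows as routing becomes more skewed.

import torch

n_experts = 4
# Hypothetical per-expert statistics for one layer, one batch of 2048 tokens.
counts = torch.tensor([900., 500., 400., 248.])      # tokens routed to each expert
route_prob = torch.tensor([870., 520., 410., 248.])  # summed routing probabilities

total = counts.sum(dim=-1, keepdim=True)   # T = 2048
route_frac = counts / total                # f_i
route_prob = route_prob / total            # P_i

load_balancing_loss = n_experts * (route_frac * route_prob).sum()
print(load_balancing_loss)   # ~1.21 for these numbers; 1.0 when perfectly balanced

# The experiment later combines it with the cross entropy loss as
#   loss = cross_entropy_loss + load_balancing_loss_ceof * load_balancing_loss
# with the coefficient defaulting to 0.01.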
@@ -492,12 +494,12 @@ taking the sum of losses across all layers.

Track stats

-
127        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
-128        tracker.add('route.min.', route_frac.min())
-129        tracker.add('route.max.', route_frac.max())
-130        tracker.add('route.std.', route_frac.std())
-131        tracker.add("loss.", cross_entropy_loss)
-132        tracker.add("lb_loss.", load_balancing_loss)
+
130        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
+131        tracker.add('route.min.', route_frac.min())
+132        tracker.add('route.max.', route_frac.max())
+133        tracker.add('route.std.', route_frac.std())
+134        tracker.add("loss.", cross_entropy_loss)
+135        tracker.add("lb_loss.", load_balancing_loss)
@@ -510,7 +512,7 @@ The load balancing loss is multiplied by a coefficient $\alpha$ which is set to something small like $\alpha = 0.01$.

-
137        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss
+
140        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss
@@ -521,8 +523,8 @@ set to something small like $\alpha = 0.01$.

Calculate and log accuracy

-
140        self.accuracy(output, target)
-141        self.accuracy.track()
+
143        self.accuracy(output, target)
+144        self.accuracy.track()
@@ -533,7 +535,7 @@ set to something small like $\alpha = 0.01$.

Train the model

-
144        if self.mode.is_train:
+
147        if self.mode.is_train:
@@ -544,7 +546,7 @@ set to something small like $\alpha = 0.01$.

Calculate gradients

-
146            loss.backward()
+
149            loss.backward()
@@ -555,7 +557,7 @@ set to something small like $\alpha = 0.01$.

Clip gradients

-
148            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+
151            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
@@ -566,7 +568,7 @@ set to something small like $\alpha = 0.01$.

Take optimizer step

-
150            self.optimizer.step()
+
153            self.optimizer.step()
@@ -577,8 +579,8 @@ set to something small like $\alpha = 0.01$.

Log the model parameters and gradients on the last batch of every epoch

-
152            if batch_idx.is_last:
-153                tracker.add('model', self.model)
+
155            if batch_idx.is_last:
+156                tracker.add('model', self.model)
@@ -589,7 +591,7 @@ set to something small like $\alpha = 0.01$.

Clear the gradients

-
155            self.optimizer.zero_grad()
+
158            self.optimizer.zero_grad()
@@ -600,7 +602,7 @@ set to something small like $\alpha = 0.01$.

Save the tracked metrics

-
158        tracker.save()
+
161        tracker.save()
@@ -611,8 +613,8 @@ set to something small like $\alpha = 0.01$.

Initialize the autoregressive model

-
161@option(Configs.model)
-162def autoregressive_model(c: Configs):
+
164@option(Configs.model)
+165def autoregressive_model(c: Configs):
@@ -623,8 +625,8 @@ set to something small like $\alpha = 0.01$.

-
166    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
-167    return m.to(c.device)
+
169    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
+170    return m.to(c.device)
@@ -635,8 +637,8 @@ set to something small like $\alpha = 0.01$.

Initialize the switch transformer

-
170@option(Configs.transformer)
-171def switch_transformer(c: Configs):
+
173@option(Configs.transformer)
+174def switch_transformer(c: Configs):
@@ -647,21 +649,21 @@ set to something small like $\alpha = 0.01$.

-
175    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
-176    from labml_nn.transformers import MultiHeadAttention
-177    from labml_nn.transformers.feed_forward import FeedForward
-178
-179    return SwitchTransformer(
-180        SwitchTransformerLayer(d_model=c.d_model,
-181                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
-182                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
-183                                                              drop_tokens=c.drop_tokens,
-184                                                              is_scale_prob=c.is_scale_prob,
-185                                                              n_experts=c.n_experts,
-186                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
-187                                                              d_model=c.d_model),
-188                               dropout_prob=c.dropout),
-189        c.n_layers)
+
178    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
+179    from labml_nn.transformers import MultiHeadAttention
+180    from labml_nn.transformers.feed_forward import FeedForward
+181
+182    return SwitchTransformer(
+183        SwitchTransformerLayer(d_model=c.d_model,
+184                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
+185                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
+186                                                              drop_tokens=c.drop_tokens,
+187                                                              is_scale_prob=c.is_scale_prob,
+188                                                              n_experts=c.n_experts,
+189                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
+190                                                              d_model=c.d_model),
+191                               dropout_prob=c.dropout),
+192        c.n_layers)
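The SwitchFeedForward built here implements top-1 routing: a linear switch layer assigns each token a probability per expert, the token is dispatched to its argmax expert, and the chosen probability is used both for scaling the output and for the load-balancing statistics. A simplified, self-contained sketch of that routing step follows, using the same conventions as the routing code shown further down (routes, indexes_list, final_output); it is an illustration only and ignores expert capacity and token dropping.

import torch
import torch.nn as nn

def top1_route_sketch(x: torch.Tensor, switch: nn.Linear, experts: nn.ModuleList) -> torch.Tensor:
    """Simplified top-1 routing over flattened tokens x of shape [tokens, d_model]."""
    route_prob = torch.softmax(switch(x), dim=-1)            # [tokens, n_experts]
    route_prob_max, routes = torch.max(route_prob, dim=-1)   # chosen prob and expert index per token

    final_output = x.new_zeros(x.shape)
    for i, expert in enumerate(experts):
        # Indexes of tokens whose argmax expert is i
        idx = torch.eq(routes, i).nonzero(as_tuple=True)[0]
        final_output[idx] = expert(x[idx])

    # Scale by the routing probability so the switch layer receives gradients
    return final_output * route_prob_max.view(-1, 1)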
@@ -672,7 +674,7 @@ set to something small like $\alpha = 0.01$.

Run the experiment

-
192def main():
+
195def main():
@@ -683,7 +685,7 @@ set to something small like $\alpha = 0.01$.

Create experiment

-
197    experiment.create(name="switch_transformer", comment='')
+
200    experiment.create(name="switch_transformer", comment='')
@@ -694,7 +696,7 @@ set to something small like $\alpha = 0.01$.

Create configs

-
199    conf = Configs()
+
202    conf = Configs()
@@ -705,7 +707,7 @@ set to something small like $\alpha = 0.01$.

Load configurations

-
201    experiment.configs(conf,
+
204    experiment.configs(conf,
@@ -716,28 +718,27 @@ set to something small like $\alpha = 0.01$.

A dictionary of configurations to override

-
203                       {'tokenizer': 'character',
-204                        'text': 'tiny_shakespeare',
-205                        'optimizer.learning_rate': 1.,
-206                        'optimizer.optimizer': 'Noam',
-207                        'prompt': 'It is',
-208                        'prompt_separator': '',
-209
-210                        'transformer': 'switch_transformer',
-211                        'is_scale_prob': False,
-212                        'n_experts': 4,
-213
-214                        'drop_tokens': True,
-215                        'capacity_factor': 1.2,
-216
-217                        'train_loader': 'shuffled_train_loader',
-218                        'valid_loader': 'shuffled_valid_loader',
-219
-220                        'seq_len': 64,
-221                        'epochs': 128,
-222                        'batch_size': 32,
-223                        'inner_iterations': 25,
-224                        })
+
206                       {'tokenizer': 'character',
+207                        'text': 'tiny_shakespeare',
+208                        'optimizer.learning_rate': 1.,
+209                        'optimizer.optimizer': 'Noam',
+210                        'prompt': 'It is',
+211                        'prompt_separator': '',
+212
+213                        'transformer': 'switch_transformer',
+214                        'n_experts': 4,
+215
+216                        'drop_tokens': True,
+217                        'capacity_factor': 1.2,
+218
+219                        'train_loader': 'shuffled_train_loader',
+220                        'valid_loader': 'shuffled_valid_loader',
+221
+222                        'seq_len': 64,
+223                        'epochs': 128,
+224                        'batch_size': 32,
+225                        'inner_iterations': 25,
+226                        })
@@ -748,7 +749,7 @@ set to something small like $\alpha = 0.01$.

Set models for saving and loading

-
227    experiment.add_pytorch_models({'model': conf.model})
+
229    experiment.add_pytorch_models({'model': conf.model})
@@ -759,7 +760,7 @@ set to something small like $\alpha = 0.01$.

Start the experiment

-
230    with experiment.start():
+
232    with experiment.start():
@@ -770,7 +771,7 @@ set to something small like $\alpha = 0.01$.

TrainValidConfigs.run

-
232        conf.run()
+
234        conf.run()
@@ -781,8 +782,8 @@ set to something small like $\alpha = 0.01$.

-
236if __name__ == '__main__':
-237    main()
+
238if __name__ == '__main__':
+239    main()
40import torch
 41from torch import nn
 42
 43from labml_helpers.module import Module
-44from labml_nn.transformers.mha import MultiHeadAttention
-45from labml_nn.transformers.feed_forward import FeedForward
+44from labml_nn.transformers.feed_forward import FeedForward
+45from labml_nn.transformers.mha import MultiHeadAttention
 46from labml_nn.utils import clone_module_list
@@ -244,11 +244,10 @@ We route to the expert with highest probability

-

Scale the inputs to the experts by the routing probabilities

+

Get indexes of tokens going to each expert

-
105        if self.is_scale_prob:
-106            factor = route_prob_max
+
105        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]
@@ -256,11 +255,10 @@ We route to the expert with highest probability

-

Don’t scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow

+

Initialize an empty tensor to store outputs

-
108        else:
-109            factor = route_prob_max / route_prob_max.detach()
+
108        final_output = x.new_zeros(x.shape)
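This hunk rearranges the routing code, but the trick behind is_scale_prob is worth spelling out: when it is off, the outputs are multiplied by route_prob_max / route_prob_max.detach(), which is numerically 1 (the values are unchanged) yet still lets gradients reach the routing probabilities. A standalone sketch of that identity-with-gradient trick, not tied to the library code:

import torch

p = torch.tensor([0.7, 0.9], requires_grad=True)  # routing probabilities for two tokens
values = torch.tensor([2.0, 3.0])                  # expert outputs

out = values * (p / p.detach())   # numerically identical to values...
out.sum().backward()
print(out)      # tensor([2., 3.], grad_fn=<MulBackward0>)
print(p.grad)   # ...but p still gets gradients: values / p = tensor([2.8571, 3.3333])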
@@ -268,39 +266,6 @@ We route to the expert with highest probability

-
Multiply by the scaling factor

-
111        x = x * factor.view(-1, 1)

-
Get indexes of tokens going to each expert

-
114        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]

-
Initialize an empty tensor to store outputs

-
117        final_output = x.new_zeros(x.shape)

Capacity of each expert.