From 5d174f3a7cb845d1d56a70700fab0bfd0167250a Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Wed, 20 Jan 2021 10:14:21 +0530
Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9A=20switch=20transformer=20notes?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 labml_nn/transformers/models.py            |  8 ++-
 labml_nn/transformers/switch/__init__.py   | 66 +++++++++++++++++++---
 labml_nn/transformers/switch/experiment.py |  8 +--
 3 files changed, 69 insertions(+), 13 deletions(-)

diff --git a/labml_nn/transformers/models.py b/labml_nn/transformers/models.py
index 940beff9..2a76c7fe 100644
--- a/labml_nn/transformers/models.py
+++ b/labml_nn/transformers/models.py
@@ -58,11 +58,17 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
 
 class FeedForward(Module):
     """
-    ## Position-wise feed-forward network with hidden layer
+    ## Position-wise feed-forward network (FFN) with hidden layer
     """
 
     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
+        """
+        * `d_model` is the number of features in a token embedding
+        * `d_ff` is the number of features in the hidden layer of the FFN
+        * `dropout` is the dropout probability for the hidden layer
+        * `activation` is the activation function to apply on the hidden layer outputs
+        """
         super().__init__()
         self.layer1 = nn.Linear(d_model, d_ff)
         self.layer2 = nn.Linear(d_ff, d_model)
diff --git a/labml_nn/transformers/switch/__init__.py b/labml_nn/transformers/switch/__init__.py
index a3feada2..c1d37cd6 100644
--- a/labml_nn/transformers/switch/__init__.py
+++ b/labml_nn/transformers/switch/__init__.py
@@ -1,3 +1,41 @@
+"""
+---
+title: Switch Transformer
+summary: >
+  This is an annotated implementation/tutorial of a miniature version of the Switch Transformer in PyTorch.
+---
+
+# Switch Transformer
+
+This is a miniature implementation of the paper
+[Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961).
+Our implementation only has a few million parameters and doesn't do model parallel distributed training.
+It does single-GPU training, but we implement the concept of switching as described in the paper.
+
+The Switch Transformer uses different parameters for each token by switching among parameters
+based on the token. Only a fraction of the parameters is chosen for each token, so you
+can have more parameters without increasing the computational cost.
+
+The switching happens at the Position-wise Feedforward network (FFN) of each transformer block.
+The position-wise feedforward network consists of two sequentially applied fully connected layers.
+In the switch transformer we have multiple FFNs (multiple experts), and
+we choose which one to use based on a router.
+The router outputs a set of probabilities for picking an FFN,
+and we pick the one with the highest probability and evaluate only that one.
+So essentially the computational cost is the same as having a single FFN.
+In our implementation this doesn't parallelize well when you have many or large FFNs since it all
+happens on a single GPU.
+In a distributed setup you would have each FFN (each very large) on a different device.
+
+The paper introduces another loss term to balance load among the experts (FFNs) and
+discusses dropping tokens when routing is not balanced.
+
+Here's a notebook for training a switch transformer on the Tiny Shakespeare dataset.
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
+"""
+
 import torch
 from torch import nn
 
@@ -9,27 +47,39 @@ from labml_nn.utils import clone_module_list
 
 
 class SwitchFeedForward(Module):
     """
-    ## Position-wise feed-forward network with hidden layer
+    ## Routing among multiple FFNs
     """
 
     def __init__(self, *,
                  capacity_factor: float,
                  drop_tokens: bool,
                  is_scale_prob: bool,
-                 n_switches: int,
+                 n_experts: int,
                  d_model: int,
                  d_ff: int,
                  dropout: float = 0.1):
-
+        """
+        * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
+        * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than the capacity
+        * `is_scale_prob` specifies whether to multiply the input to the FFN by the routing probability
+        * `n_experts` is the number of experts
+        * `d_model` is the number of features in a token embedding
+        * `d_ff` is the number of features in the hidden layer of the FFN
+        * `dropout` is the dropout probability in the FFN
+        """
         super().__init__()
+        self.capacity_factor = capacity_factor
         self.is_scale_prob = is_scale_prob
-        self.units = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_switches)])
-        self.switch = nn.Linear(d_model, n_switches)
-        self.softmax = nn.Softmax(dim=-1)
-        self.n_switches = n_switches
+        self.n_switches = n_experts
         self.drop_tokens = drop_tokens
 
+        # FFN modules for each expert
+        self.experts = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_experts)])
+        # Routing layer and softmax
+        self.switch = nn.Linear(d_model, n_experts)
+        self.softmax = nn.Softmax(dim=-1)
+
     def __call__(self, x: torch.Tensor):
         seq_len, bs, d_model = x.shape
         x = x.view(-1, d_model)
@@ -63,7 +113,7 @@ class SwitchFeedForward(Module):
             dropped.append(indexes_list[i][capacity:])
             indexes_list[i] = indexes_list[i][:capacity]
 
-        route_outputs = [self.units[i](x[indexes_list[i], :]) for i in range(self.n_switches)]
+        route_outputs = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_switches)]
 
         # Assign to final output
         for i in range(self.n_switches):
diff --git a/labml_nn/transformers/switch/experiment.py b/labml_nn/transformers/switch/experiment.py
index 8c1f21d4..905446a0 100644
--- a/labml_nn/transformers/switch/experiment.py
+++ b/labml_nn/transformers/switch/experiment.py
@@ -47,7 +47,7 @@ class Configs(NLPAutoRegressionConfigs):
     dropout: float = 0.0
     d_ff: int = 256
     n_layers: int = 6
-    n_switches = 4
+    n_experts: int = 4
     load_balancing_loss_ceof = 0.01
     is_scale_prob: bool = True
     drop_tokens: bool = False
@@ -89,7 +89,7 @@ class Configs(NLPAutoRegressionConfigs):
             tracker.add('route.std.', route_frac.std())
             # for i in range(self.n_switches):
             #     tracker.add(f'route.{i}', route_frac[:, i].mean())
-            load_balancing_loss = self.n_switches * (route_frac * route_prob).sum()
+            load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
             tracker.add("loss.", loss)
             tracker.add("lb_loss.", loss)
             loss = loss + self.load_balancing_loss_ceof * load_balancing_loss
@@ -133,7 +133,7 @@ def switch_transformer(c: Configs):
         feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                        drop_tokens=c.drop_tokens,
                                        is_scale_prob=c.is_scale_prob,
-                                       n_switches=c.n_switches,
+                                       n_experts=c.n_experts,
                                        d_model=c.d_model,
                                        d_ff=c.d_ff,
                                        dropout=c.dropout),
@@ -158,7 +158,7 @@ def main():
         'transformer': 'switch_transformer',
 
         'is_scale_prob': False,
-        'n_switches': 4,
+        'n_experts': 4,
         'drop_tokens': True,
         'capacity_factor': 1.2,
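
The module docstring added in this patch describes the core switching idea in prose: a router produces a softmax over the experts, each token is sent only to the expert with the highest probability, and the result can be scaled by that probability (what `is_scale_prob` toggles). Below is a minimal, self-contained sketch of that idea, not the patch's code: the class name `NaiveSwitchFFN` is hypothetical, and it omits the capacity limit, token dropping, and load-balancing statistics that the patch's `SwitchFeedForward` handles.

```python
# A minimal sketch of single-expert ("switch") routing, assuming PyTorch.
# `NaiveSwitchFFN` is a hypothetical name; unlike the patch's SwitchFeedForward
# it has no capacity limit, no token dropping and no load-balancing statistics.
import torch
from torch import nn


class NaiveSwitchFFN(nn.Module):
    def __init__(self, n_experts: int, d_model: int, d_ff: int):
        super().__init__()
        # One position-wise FFN (two linear layers) per expert
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        ])
        # Router: one logit per expert for every token
        self.switch = nn.Linear(d_model, n_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_len, batch_size, d_model] -> flatten to a list of tokens
        seq_len, bs, d_model = x.shape
        x = x.view(-1, d_model)
        # Routing probabilities and the argmax expert for each token
        route_prob = torch.softmax(self.switch(x), dim=-1)
        max_prob, routes = torch.max(route_prob, dim=-1)
        # Evaluate only the chosen expert for each token
        out = x.new_zeros(x.shape)
        for i, expert in enumerate(self.experts):
            idx = torch.eq(routes, i).nonzero(as_tuple=True)[0]
            if idx.numel() > 0:
                out[idx] = expert(x[idx])
        # Scale by the routing probability (roughly what `is_scale_prob` toggles)
        out = out * max_prob.unsqueeze(-1)
        return out.view(seq_len, bs, d_model)


if __name__ == '__main__':
    ffn = NaiveSwitchFFN(n_experts=4, d_model=64, d_ff=256)
    y = ffn(torch.randn(10, 2, 64))  # [seq_len=10, batch=2, d_model=64]
    print(y.shape)  # torch.Size([10, 2, 64])
```

Because each token goes through exactly one expert, the per-token compute matches a single FFN even though the parameter count grows with `n_experts`.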
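
The `SwitchFeedForward` hunk above trims each expert's token list to a capacity (`dropped.append(indexes_list[i][capacity:])`). Here is a standalone sketch of that rule with illustrative names; the formula `capacity = capacity_factor * n_tokens / n_experts` is an assumption derived from "a factor relative to ideally balanced load".

```python
# Sketch of the capacity / token-dropping rule; names are illustrative and the
# capacity formula is assumed from "a factor relative to ideally balanced load".
from typing import List, Tuple
import torch


def apply_capacity(indexes_list: List[torch.Tensor],
                   n_tokens: int,
                   n_experts: int,
                   capacity_factor: float) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Trim each expert's token-index list to its capacity; return (kept, dropped)."""
    # Ideally balanced load is n_tokens / n_experts tokens per expert
    capacity = int(capacity_factor * n_tokens / n_experts)
    kept, dropped = [], []
    for idx in indexes_list:
        if len(idx) > capacity:
            kept.append(idx[:capacity])
            dropped.append(idx[capacity:])
        else:
            kept.append(idx)
    return kept, dropped
```

Dropped tokens simply skip the FFN; in the paper they are passed on to the next layer through the residual connection.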
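
Finally, the experiment hunk adds the auxiliary load-balancing term `load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()`. A standalone sketch of that computation is below; the argument names are illustrative, with the per-expert token counts and routing-probability sums assumed to come from the routing step.

```python
# Standalone sketch of the load-balancing loss used in the experiment hunk above.
# Argument names are illustrative; they would be collected during routing.
import torch


def load_balancing_loss(counts: torch.Tensor,
                        route_prob_sum: torch.Tensor,
                        n_tokens: int) -> torch.Tensor:
    """
    counts:         [n_experts] number of tokens routed to each expert
    route_prob_sum: [n_experts] sum over tokens of each expert's routing probability
    n_tokens:       total number of tokens in the batch
    """
    n_experts = counts.numel()
    route_frac = counts.float() / n_tokens   # fraction of tokens per expert
    route_prob = route_prob_sum / n_tokens   # mean routing probability per expert
    # Minimized when both are uniform at 1 / n_experts, i.e. when routing is balanced
    return n_experts * (route_frac * route_prob).sum()
```

In the experiment this term is added to the cross-entropy loss scaled by `load_balancing_loss_ceof`.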