📚 switch transformer notes

This commit is contained in:
Varuna Jayasiri
2021-01-20 10:14:21 +05:30
parent e3e321a5a9
commit 5d174f3a7c
3 changed files with 69 additions and 13 deletions

View File

@ -58,11 +58,17 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
class FeedForward(Module):
"""
<a id="FeedForward">
## Position-wise feed-forward network with hidden layer
## Position-wise feed-forward network (FFN) with hidden layer
</a>
"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
"""
* `d_model` is the number of features in a token embedding
* `d_ff` is the number of features in the hidden layer of the FFN
* `dropout` is dropout probability for the hidden layer
* `activation` is the activation function to apply on the hidden layer outputs
"""
super().__init__()
self.layer1 = nn.Linear(d_model, d_ff)
self.layer2 = nn.Linear(d_ff, d_model)

View File

@ -1,3 +1,41 @@
"""
---
title: Switch Transformer
summary: >
This is an annotated implementation/tutorial a miniature version of Switch Transformer in PyTorch.
---
# Switch Transformer
This is a miniature implementation of the paper
[Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961).
Our implementation only has a few million parameters and doesn't do model parallel distributed training.
It does single GPU training but we implement the concept of switching as described in the paper.
The Switch Transformer is uses different parameters for each tokens by switching among parameters,
based on the token. So only a fraction of parameters is chosen for each token, so you
can have more parameters but a less computational cost.
The switching happens at the Position-wise Feedforward network (FFN) of of each transformer block.
Position-wise feedforward network is a two sequential fully connected layers.
In switch transformer we have multiple FFNs (multiple experts) and
we chose which one to use based on a router.
The outputs a set of probabilities for picking a FFN,
and we pick the one with highest probability and only evaluates that.
So essentially the computational cost is same as having a single FFN.
In our implementation this doesn't parallelize well when you have many or large FFNs since it's all
happening on a single GPU.
In a distributed setup you would have each FFN (each very large) on a different device.
The paper introduces another loss term to balance load among the experts (FFNs) and
discusses dropping tokens when routing is not balanced.
Here's a notebook for training a switch transformer on Tiny Shakespeare dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
"""
import torch
from torch import nn
@ -9,27 +47,39 @@ from labml_nn.utils import clone_module_list
class SwitchFeedForward(Module):
"""
## Position-wise feed-forward network with hidden layer
## Routing among multiple FFNs
"""
def __init__(self, *,
capacity_factor: float,
drop_tokens: bool,
is_scale_prob: bool,
n_switches: int,
n_experts: int,
d_model: int,
d_ff: int,
dropout: float = 0.1):
"""
* `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
* `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than the capacity
* `is_scale_prob` specifies whether to multiply the input to the FFN by the routing probability
* `n_experts` is the number of experts
* `d_model` is the number of features in a token embedding
* `d_ff` is the number of features in the hidden layer of the FFN
* `dropout` is dropout probability in the FFN
"""
super().__init__()
self.capacity_factor = capacity_factor
self.is_scale_prob = is_scale_prob
self.units = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_switches)])
self.switch = nn.Linear(d_model, n_switches)
self.softmax = nn.Softmax(dim=-1)
self.n_switches = n_switches
self.n_switches = n_experts
self.drop_tokens = drop_tokens
# FFN modules for each expert
self.experts = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_experts)])
# Routing layer and softmax
self.switch = nn.Linear(d_model, n_experts)
self.softmax = nn.Softmax(dim=-1)
def __call__(self, x: torch.Tensor):
seq_len, bs, d_model = x.shape
x = x.view(-1, d_model)
@ -63,7 +113,7 @@ class SwitchFeedForward(Module):
dropped.append(indexes_list[i][capacity:])
indexes_list[i] = indexes_list[i][:capacity]
route_outputs = [self.units[i](x[indexes_list[i], :]) for i in range(self.n_switches)]
route_outputs = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_switches)]
# Assign to final output
for i in range(self.n_switches):

View File

@ -47,7 +47,7 @@ class Configs(NLPAutoRegressionConfigs):
dropout: float = 0.0
d_ff: int = 256
n_layers: int = 6
n_switches = 4
n_experts: int = 4
load_balancing_loss_ceof = 0.01
is_scale_prob: bool = True
drop_tokens: bool = False
@ -89,7 +89,7 @@ class Configs(NLPAutoRegressionConfigs):
tracker.add('route.std.', route_frac.std())
# for i in range(self.n_switches):
# tracker.add(f'route.{i}', route_frac[:, i].mean())
load_balancing_loss = self.n_switches * (route_frac * route_prob).sum()
load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
tracker.add("loss.", loss)
tracker.add("lb_loss.", loss)
loss = loss + self.load_balancing_loss_ceof * load_balancing_loss
@ -133,7 +133,7 @@ def switch_transformer(c: Configs):
feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
drop_tokens=c.drop_tokens,
is_scale_prob=c.is_scale_prob,
n_switches=c.n_switches,
n_experts=c.n_experts,
d_model=c.d_model,
d_ff=c.d_ff,
dropout=c.dropout),
@ -158,7 +158,7 @@ def main():
'transformer': 'switch_transformer',
'is_scale_prob': False,
'n_switches': 4,
'n_experts': 4,
'drop_tokens': True,
'capacity_factor': 1.2,