📚 switch transformer notes
@@ -58,11 +58,17 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):

class FeedForward(Module):
    """
    <a id="FeedForward">
    ## Position-wise feed-forward network with hidden layer
    ## Position-wise feed-forward network (FFN) with hidden layer
    </a>
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
        """
        * `d_model` is the number of features in a token embedding
        * `d_ff` is the number of features in the hidden layer of the FFN
        * `dropout` is dropout probability for the hidden layer
        * `activation` is the activation function to apply on the hidden layer outputs
        """
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
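Since this hunk only shows the two linear layers, here is a minimal self-contained sketch of how such a position-wise FFN is typically applied. The `forward` composition and the `FeedForwardSketch` name are illustrative assumptions, not code from this commit:

```python
import torch
from torch import nn


class FeedForwardSketch(nn.Module):
    """A stand-alone sketch of the position-wise FFN described above."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)   # d_model -> d_ff
        self.layer2 = nn.Linear(d_ff, d_model)   # d_ff -> d_model
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expand, apply the non-linearity, regularize, then project back
        return self.layer2(self.dropout(self.activation(self.layer1(x))))


# Each position (token) is transformed independently
x = torch.randn(10, 4, 512)          # [seq_len, batch_size, d_model]
ffn = FeedForwardSketch(512, 2048)
assert ffn(x).shape == (10, 4, 512)
```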
@@ -1,3 +1,41 @@
"""
---
title: Switch Transformer
summary: >
  This is an annotated implementation/tutorial of a miniature version of the Switch Transformer in PyTorch.
---

# Switch Transformer

This is a miniature implementation of the paper
[Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961).
Our implementation only has a few million parameters and doesn't do model parallel distributed training.
It does single-GPU training, but we implement the concept of switching as described in the paper.

The Switch Transformer uses different parameters for each token by switching among parameters
based on the token. So only a fraction of the parameters is chosen for each token, and you
can have more parameters at a lower computational cost.

The switching happens at the position-wise feed-forward network (FFN) of each transformer block.
The position-wise feed-forward network consists of two sequential fully connected layers.
In the Switch Transformer we have multiple FFNs (multiple experts), and
we choose which one to use based on a router.
The router outputs a set of probabilities for picking an FFN,
and we pick the one with the highest probability and only evaluate that.
So essentially the computational cost is the same as having a single FFN.
In our implementation this doesn't parallelize well when you have many or large FFNs, since it's all
happening on a single GPU.
In a distributed setup you would have each FFN (each very large) on a different device.

The paper introduces another loss term to balance load among the experts (FFNs) and
discusses dropping tokens when routing is not balanced.

Here's a notebook for training a switch transformer on the Tiny Shakespeare dataset.

[Open In Colab](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
[View Run](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
"""

import torch
from torch import nn
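To make the routing description above concrete, here is a minimal self-contained sketch of top-1 switch routing. The `SwitchRoutingSketch` class and its internals are illustrative; the commit's actual `SwitchFeedForward` appears in the hunks below:

```python
import torch
from torch import nn


class SwitchRoutingSketch(nn.Module):
    """Route each token to exactly one expert FFN (top-1 routing)."""

    def __init__(self, n_experts: int, d_model: int, d_ff: int):
        super().__init__()
        # One FFN per expert
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        ])
        # Router: a linear layer followed by a softmax over experts
        self.switch = nn.Linear(d_model, n_experts)
        self.n_experts = n_experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len, bs, d_model = x.shape
        x = x.view(-1, d_model)                          # flatten to [tokens, d_model]
        route_prob = torch.softmax(self.switch(x), dim=-1)
        prob, routes = torch.max(route_prob, dim=-1)     # top-1 expert per token
        out = x.new_zeros(x.shape)
        for i in range(self.n_experts):
            idx = torch.eq(routes, i).nonzero(as_tuple=True)[0]
            if len(idx):
                # Only the selected expert is evaluated for these tokens
                out[idx] = self.experts[i](x[idx])
        # Scale by the routing probability so the router receives gradients
        out = out * prob.unsqueeze(-1)
        return out.view(seq_len, bs, d_model)
```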
@@ -9,27 +47,39 @@ from labml_nn.utils import clone_module_list

class SwitchFeedForward(Module):
    """
    ## Position-wise feed-forward network with hidden layer
    ## Routing among multiple FFNs
    """

    def __init__(self, *,
                 capacity_factor: float,
                 drop_tokens: bool,
                 is_scale_prob: bool,
                 n_switches: int,
                 n_experts: int,
                 d_model: int,
                 d_ff: int,
                 dropout: float = 0.1):
        """
        * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
        * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than the capacity
        * `is_scale_prob` specifies whether to multiply the input to the FFN by the routing probability
        * `n_experts` is the number of experts
        * `d_model` is the number of features in a token embedding
        * `d_ff` is the number of features in the hidden layer of the FFN
        * `dropout` is dropout probability in the FFN
        """
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.units = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_switches)])
        self.switch = nn.Linear(d_model, n_switches)
        self.softmax = nn.Softmax(dim=-1)
        self.n_switches = n_switches
        self.n_switches = n_experts
        self.drop_tokens = drop_tokens

        # FFN modules for each expert
        self.experts = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_experts)])
        # Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, x: torch.Tensor):
        seq_len, bs, d_model = x.shape
        x = x.view(-1, d_model)
@@ -63,7 +113,7 @@ class SwitchFeedForward(Module):
                dropped.append(indexes_list[i][capacity:])
                indexes_list[i] = indexes_list[i][:capacity]

        route_outputs = [self.units[i](x[indexes_list[i], :]) for i in range(self.n_switches)]
        route_outputs = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_switches)]

        # Assign to final output
        for i in range(self.n_switches):
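As a rough illustration of the capacity computation behind this token-dropping logic, here is a self-contained sketch; `route_with_capacity` is a hypothetical helper, and the formula assumes the paper's definition of expert capacity as `capacity_factor * tokens / n_experts`:

```python
import torch


def route_with_capacity(route_prob: torch.Tensor, capacity_factor: float, n_experts: int):
    """Group token indexes by expert and drop those beyond each expert's capacity."""
    n_tokens = route_prob.shape[0]
    # Ideally balanced load is n_tokens / n_experts; capacity_factor relaxes it
    capacity = int(capacity_factor * n_tokens / n_experts)

    _, routes = torch.max(route_prob, dim=-1)            # top-1 expert per token
    indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(n_experts)]

    dropped = []
    for i in range(n_experts):
        if len(indexes_list[i]) > capacity:
            # Tokens beyond capacity are dropped (not processed by this expert)
            dropped.append(indexes_list[i][capacity:])
            indexes_list[i] = indexes_list[i][:capacity]
    return indexes_list, dropped, capacity


# Example: 64 tokens, 4 experts, capacity_factor 1.2 -> capacity = int(1.2 * 64 / 4) = 19
probs = torch.softmax(torch.randn(64, 4), dim=-1)
indexes, dropped, capacity = route_with_capacity(probs, capacity_factor=1.2, n_experts=4)
```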
@@ -47,7 +47,7 @@ class Configs(NLPAutoRegressionConfigs):
    dropout: float = 0.0
    d_ff: int = 256
    n_layers: int = 6
    n_switches = 4
    n_experts: int = 4
    load_balancing_loss_ceof = 0.01
    is_scale_prob: bool = True
    drop_tokens: bool = False
@@ -89,7 +89,7 @@ class Configs(NLPAutoRegressionConfigs):
        tracker.add('route.std.', route_frac.std())
        # for i in range(self.n_switches):
        #     tracker.add(f'route.{i}', route_frac[:, i].mean())
        load_balancing_loss = self.n_switches * (route_frac * route_prob).sum()
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
        tracker.add("loss.", loss)
        tracker.add("lb_loss.", load_balancing_loss)
        loss = loss + self.load_balancing_loss_ceof * load_balancing_loss
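For reference, a small worked sketch of the load balancing term computed above, assuming `route_frac` holds the fraction of tokens routed to each expert and `route_prob` the mean routing probability per expert; the numbers below are made up for illustration:

```python
import torch

n_experts = 4
# counts[i]: number of tokens routed to expert i; route_prob_sum[i]: sum of router
# probabilities assigned to expert i over all tokens (hypothetical values)
counts = torch.tensor([40., 10., 8., 6.])
route_prob_sum = torch.tensor([35.2, 12.1, 9.4, 7.3])
n_tokens = counts.sum()

route_frac = counts / n_tokens              # f_i: fraction of tokens per expert
route_prob = route_prob_sum / n_tokens      # P_i: mean routing probability per expert

# Paper's auxiliary loss: N * sum_i f_i * P_i  (minimized when routing is uniform)
load_balancing_loss = n_experts * (route_frac * route_prob).sum()

# Added to the main loss with a small coefficient, as in the config above
load_balancing_loss_ceof = 0.01
# total_loss = cross_entropy_loss + load_balancing_loss_ceof * load_balancing_loss
```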
@@ -133,7 +133,7 @@ def switch_transformer(c: Configs):
        feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                       drop_tokens=c.drop_tokens,
                                       is_scale_prob=c.is_scale_prob,
                                       n_switches=c.n_switches,
                                       n_experts=c.n_experts,
                                       d_model=c.d_model,
                                       d_ff=c.d_ff,
                                       dropout=c.dropout),
@@ -158,7 +158,7 @@ def main():

                        'transformer': 'switch_transformer',
                        'is_scale_prob': False,
                        'n_switches': 4,
                        'n_experts': 4,

                        'drop_tokens': True,
                        'capacity_factor': 1.2,
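For context, a hedged sketch of how overrides like these are typically passed to a labml experiment; the exact body of `main()` is not shown in this hunk, so the structure below is an assumption built on labml's `experiment.create`/`experiment.configs`/`experiment.start` API, with `Configs` referring to the class configured above:

```python
from labml import experiment


def main():
    # Create the experiment and the configurations object
    conf = Configs()
    experiment.create(name="switch_transformer")
    # Override defaults, including the routing-related options from this hunk
    # (plus the usual dataset / optimizer settings used by NLPAutoRegressionConfigs)
    experiment.configs(conf, {
        'transformer': 'switch_transformer',
        'is_scale_prob': False,
        'n_experts': 4,
        'drop_tokens': True,
        'capacity_factor': 1.2,
    })
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
```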