mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-11-04 14:29:43 +08:00)
	📚 switch transformer notes
@@ -58,11 +58,17 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
 class FeedForward(Module):
     """
     <a id="FeedForward">
-    ## Position-wise feed-forward network with hidden layer
+    ## Position-wise feed-forward network (FFN) with hidden layer
     </a>
     """
 
     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
+        """
+        * `d_model` is the number of features in a token embedding
+        * `d_ff` is the number of features in the hidden layer of the FFN
+        * `dropout` is dropout probability for the hidden layer
+        * `activation` is the activation function to apply on the hidden layer outputs
+        """
         super().__init__()
         self.layer1 = nn.Linear(d_model, d_ff)
         self.layer2 = nn.Linear(d_ff, d_model)
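As a side note, here is a minimal runnable sketch of how the attributes documented above usually compose in the forward pass. The `__call__` body is outside this hunk, so the exact composition (linear, activation, dropout, linear) and the class name `TinyFeedForward` are assumptions for illustration, not the repository's code.

import torch
from torch import nn

# Hypothetical name; a sketch of the standard position-wise FFN composition
class TinyFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation=nn.ReLU()):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)    # expand: d_model -> d_ff
        self.layer2 = nn.Linear(d_ff, d_model)    # project back: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # hidden layer -> activation -> dropout -> output layer
        return self.layer2(self.dropout(self.activation(self.layer1(x))))

x = torch.randn(10, 2, 16)                        # (seq_len, batch_size, d_model)
y = TinyFeedForward(d_model=16, d_ff=64)(x)       # same shape as x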
@@ -1,3 +1,41 @@
+"""
+---
+title: Switch Transformer
+summary: >
+  This is an annotated implementation/tutorial of a miniature version of the Switch Transformer in PyTorch.
+---
+
+# Switch Transformer
+
+This is a miniature implementation of the paper
+[Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961).
+Our implementation has only a few million parameters and doesn't do model-parallel distributed training.
+It does single-GPU training, but we implement the concept of switching as described in the paper.
+
+The Switch Transformer uses different parameters for each token by switching among parameters
+based on the token. Therefore, only a fraction of the parameters are chosen for each token, so you
+can have more parameters but less computational cost.
+
+The switching happens at the position-wise feedforward network (FFN) of each transformer block.
+A position-wise feedforward network is two sequential fully connected layers.
+In the Switch Transformer we have multiple FFNs (multiple experts), and
+we choose which one to use with a router.
+The router outputs a set of probabilities for picking an FFN,
+and we pick the one with the highest probability and evaluate only that expert.
+So essentially the computational cost is the same as having a single FFN.
+In our implementation this doesn't parallelize well when you have many or large FFNs, since it's all
+happening on a single GPU.
+In a distributed setup you would have each FFN (each very large) on a different device.
+
+The paper introduces another loss term to balance the load among the experts (FFNs) and
+discusses dropping tokens when routing is not balanced.
+
+Here's a notebook for training a Switch Transformer on the Tiny Shakespeare dataset.
+
+[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
+[](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
+"""
+
 import torch
 from torch import nn
 
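The routing described above can be illustrated with a small standalone sketch. This is not the repository's code: `toy_route` and the toy experts are made up for the example. A linear router gives each token a probability per expert, each token goes to its highest-probability expert, only that expert's FFN is evaluated, and the output is scaled by the routing probability so the router receives gradients (whether the real implementation applies such scaling is controlled by `is_scale_prob`).

import torch
from torch import nn

def toy_route(x: torch.Tensor, router: nn.Linear, experts: nn.ModuleList) -> torch.Tensor:
    # x: (n_tokens, d_model) -- already flattened tokens
    probs = torch.softmax(router(x), dim=-1)           # (n_tokens, n_experts)
    max_prob, chosen = probs.max(dim=-1)                # top-1 expert per token
    out = x.new_zeros(x.shape)
    for i, expert in enumerate(experts):
        idx = torch.nonzero(chosen == i).squeeze(-1)    # tokens routed to expert i
        if idx.numel():
            out[idx] = expert(x[idx])                   # evaluate only the chosen expert
    # scale by the routing probability so the router gets a gradient signal
    return out * max_prob.unsqueeze(-1)

d_model, n_experts = 16, 4
experts = nn.ModuleList([nn.Sequential(nn.Linear(d_model, 64), nn.ReLU(), nn.Linear(64, d_model))
                         for _ in range(n_experts)])
router = nn.Linear(d_model, n_experts)
y = toy_route(torch.randn(32, d_model), router, experts)    # (32, d_model)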
@@ -9,27 +47,39 @@ from labml_nn.utils import clone_module_list
 
 class SwitchFeedForward(Module):
     """
-    ## Position-wise feed-forward network with hidden layer
+    ## Routing among multiple FFNs
     """
 
     def __init__(self, *,
                  capacity_factor: float,
                  drop_tokens: bool,
                  is_scale_prob: bool,
-                 n_switches: int,
+                 n_experts: int,
                  d_model: int,
                  d_ff: int,
                  dropout: float = 0.1):
+        """
+        * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
+        * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than the capacity
+        * `is_scale_prob` specifies whether to multiply the input to the FFN by the routing probability
+        * `n_experts` is the number of experts
+        * `d_model` is the number of features in a token embedding
+        * `d_ff` is the number of features in the hidden layer of the FFN
+        * `dropout` is dropout probability in the FFN
+        """
         super().__init__()
+
         self.capacity_factor = capacity_factor
         self.is_scale_prob = is_scale_prob
-        self.units = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_switches)])
-        self.switch = nn.Linear(d_model, n_switches)
-        self.softmax = nn.Softmax(dim=-1)
-        self.n_switches = n_switches
+        self.n_switches = n_experts
         self.drop_tokens = drop_tokens
+
+        # FFN modules for each expert
+        self.experts = nn.ModuleList([FeedForward(d_model, d_ff, dropout) for _ in range(n_experts)])
+        # Routing layer and softmax
+        self.switch = nn.Linear(d_model, n_experts)
+        self.softmax = nn.Softmax(dim=-1)
 
     def __call__(self, x: torch.Tensor):
         seq_len, bs, d_model = x.shape
         x = x.view(-1, d_model)
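To make `capacity_factor` and `drop_tokens` concrete: with perfectly balanced routing each expert would receive n_tokens / n_experts tokens, and the capacity is that ideal share scaled by the factor; tokens routed to an expert beyond its capacity are dropped (in the paper, dropped tokens simply pass through the block via the residual connection). A small sketch under those assumptions, with illustrative variable names:

import torch

n_tokens, n_experts, capacity_factor = 512, 4, 1.2
# Ideal balanced share is n_tokens / n_experts; capacity scales that share
capacity = int(capacity_factor * n_tokens / n_experts)       # int(1.2 * 512 / 4) = 153

chosen = torch.randint(0, n_experts, (n_tokens,))             # stand-in for router decisions
dropped = []
for i in range(n_experts):
    idx = torch.nonzero(chosen == i).squeeze(-1)
    if len(idx) > capacity:                                   # this expert is over capacity
        dropped.append(idx[capacity:])                        # these tokens skip the FFN
        idx = idx[:capacity]
    # ... run expert i on x[idx] ...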
@@ -63,7 +113,7 @@ class SwitchFeedForward(Module):
                 dropped.append(indexes_list[i][capacity:])
                 indexes_list[i] = indexes_list[i][:capacity]
 
-        route_outputs = [self.units[i](x[indexes_list[i], :]) for i in range(self.n_switches)]
+        route_outputs = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_switches)]
 
         # Assign to final output
         for i in range(self.n_switches):
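The changed line above runs each expert only on its own slice of tokens; the loop that follows writes those per-expert outputs back to the original token positions. A hedged, standalone sketch of that scatter step (shapes, indices, and names are illustrative only):

import torch

d_model = 16
x = torch.randn(8, d_model)                                                # flattened tokens
indexes_list = [torch.tensor([0, 3, 5]), torch.tensor([1, 2, 4, 6, 7])]    # tokens per expert
route_outputs = [torch.randn(len(idx), d_model) for idx in indexes_list]   # stand-in expert outputs

final_output = x.new_zeros(x.shape)
for i, idx in enumerate(indexes_list):
    final_output[idx, :] = route_outputs[i]                                # put expert i's outputs back in place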
@@ -47,7 +47,7 @@ class Configs(NLPAutoRegressionConfigs):
     dropout: float = 0.0
     d_ff: int = 256
     n_layers: int = 6
-    n_switches = 4
+    n_experts: int = 4
     load_balancing_loss_ceof = 0.01
     is_scale_prob: bool = True
     drop_tokens: bool = False
@@ -89,7 +89,7 @@ class Configs(NLPAutoRegressionConfigs):
         tracker.add('route.std.', route_frac.std())
         # for i in range(self.n_switches):
         #     tracker.add(f'route.{i}', route_frac[:, i].mean())
-        load_balancing_loss = self.n_switches * (route_frac * route_prob).sum()
+        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
         tracker.add("loss.", loss)
         tracker.add("lb_loss.", loss)
         loss = loss + self.load_balancing_loss_ceof * load_balancing_loss
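The changed line here is the paper's auxiliary load balancing loss: with f_i the fraction of tokens dispatched to expert i and P_i the mean routing probability of expert i, the loss is n_experts * sum_i f_i * P_i, which is smallest at a uniform 1/n_experts split; `route_frac` and `route_prob` appear to play the roles of f and P, and `load_balancing_loss_ceof` is the scaling coefficient. A small numeric check of the formula (standalone, not the repository code):

import torch

n_experts = 4
# route_frac: fraction of tokens sent to each expert; route_prob: mean router probability per expert
balanced   = torch.full((n_experts,), 1.0 / n_experts)
unbalanced = torch.tensor([0.7, 0.1, 0.1, 0.1])

def lb_loss(route_frac, route_prob, n):
    return n * (route_frac * route_prob).sum()

print(lb_loss(balanced, balanced, n_experts))       # 1.00 at a perfectly balanced split (the minimum)
print(lb_loss(unbalanced, unbalanced, n_experts))   # 2.08 when routing collapses onto one expert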
@@ -133,7 +133,7 @@ def switch_transformer(c: Configs):
                                feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                               drop_tokens=c.drop_tokens,
                                                               is_scale_prob=c.is_scale_prob,
-                                                              n_switches=c.n_switches,
+                                                              n_experts=c.n_experts,
                                                               d_model=c.d_model,
                                                               d_ff=c.d_ff,
                                                               dropout=c.dropout),
@@ -158,7 +158,7 @@ def main():
 
                         'transformer': 'switch_transformer',
                         'is_scale_prob': False,
-                        'n_switches': 4,
+                        'n_experts': 4,
 
                         'drop_tokens': True,
                         'capacity_factor': 1.2,