diff --git a/docs/normalization/batch_norm.html b/docs/normalization/batch_norm.html
new file mode 100644
index 00000000..1c937b5a
--- /dev/null
+++ b/docs/normalization/batch_norm.html
@@ -0,0 +1,186 @@
+import torch
+from torch import nn
+
+from labml_helpers.module import Module
+
+
+class BatchNorm(Module):
+    def __init__(self, channels: int, *,
+                 eps: float = 1e-5, momentum: float = 0.1,
+                 affine: bool = True, track_running_stats: bool = True):
+        super().__init__()
+
+        self.channels = channels
+
+        self.eps = eps
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        # Learnable scale and shift parameters, created only if `affine` is set
+        if self.affine:
+            self.weight = nn.Parameter(torch.ones(channels))
+            self.bias = nn.Parameter(torch.zeros(channels))
+        # Exponential moving averages of mean and variance, used at inference time
+        if self.track_running_stats:
+            self.register_buffer('running_mean', torch.zeros(channels))
+            self.register_buffer('running_var', torch.ones(channels))
+
+    def __call__(self, x: torch.Tensor):
+        # Remember the original shape so it can be restored at the end
+        x_shape = x.shape
+        batch_size = x_shape[0]
+
+        # Reshape to `[batch_size, channels, n]`
+        x = x.view(batch_size, self.channels, -1)
+        # Use batch statistics when training, or when running statistics are not tracked
+        if self.training or not self.track_running_stats:
+            # Mean and variance per channel, across the batch and spatial dimensions
+            mean = x.mean(dim=[0, 2])
+            mean_x2 = (x ** 2).mean(dim=[0, 2])
+            var = mean_x2 - mean ** 2
+
+            # Update the exponential moving averages
+            if self.training and self.track_running_stats:
+                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
+                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
+        else:
+            # Use the tracked running statistics at inference time
+            mean = self.running_mean
+            var = self.running_var
+
+        # Normalize
+        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
+        # Scale and shift
+        if self.affine:
+            x_norm = self.weight.view(1, -1, 1) * x_norm + self.bias.view(1, -1, 1)
+
+        # Restore the original shape
+        return x_norm.view(x_shape)
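For reference, the __call__ method above implements the standard batch normalization transform, computed per channel over the batch and spatial dimensions:

$$\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2,$$

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma \hat{x}_i + \beta,$$

where gamma and beta are the weight and bias parameters when affine=True. The code obtains the variance as E[x^2] - E[x]^2, which is the same (biased) estimate.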
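A minimal sanity-check sketch (not part of this diff), assuming the class above is importable as labml_nn.normalization.batch_norm.BatchNorm, the path the MNIST experiment below uses: in training mode each channel of the output should have roughly zero mean and unit variance, and the result should agree with PyTorch's built-in nn.BatchNorm2d at its default settings.

import torch
from torch import nn

# Import path taken from the MNIST experiment below (assumes the package is installed)
from labml_nn.normalization.batch_norm import BatchNorm

bn = BatchNorm(3)        # the module defined above
ref = nn.BatchNorm2d(3)  # PyTorch reference with matching defaults (eps=1e-5, momentum=0.1)

x = torch.randn(8, 3, 16, 16)
y = bn(x)                # training mode: normalizes with batch statistics

# Per-channel statistics of the normalized output
flat = y.transpose(0, 1).reshape(3, -1)
print(flat.mean(dim=1))                 # each entry should be ~0
print(flat.var(dim=1, unbiased=False))  # each entry should be ~1

# Should match the built-in implementation up to floating-point error
print((y - ref(x)).abs().max())         # expected to be on the order of 1e-6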
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+
+from labml import experiment, tracker
+from labml.configs import option
+from labml_helpers.datasets.mnist import MNISTConfigs
+from labml_helpers.device import DeviceConfigs
+from labml_helpers.metrics.accuracy import Accuracy
+from labml_helpers.module import Module
+from labml_helpers.seed import SeedConfigs
+from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.normalization.batch_norm import BatchNorm
+
+
+class Net(Module):
+    def __init__(self):
+        super().__init__()
+        # Two convolutional layers and two fully connected layers,
+        # each followed by batch normalization (except the output layer)
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.bn1 = BatchNorm(20)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.bn2 = BatchNorm(50)
+        # 28x28 inputs: conv1 -> 24x24, pool -> 12x12, conv2 -> 8x8, pool -> 4x4,
+        # so the flattened feature size is 4 * 4 * 50
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.bn3 = BatchNorm(500)
+        self.fc2 = nn.Linear(500, 10)
+
+    def __call__(self, x: torch.Tensor):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.max_pool2d(x, 2, 2)
+        x = x.view(-1, 4 * 4 * 50)
+        x = F.relu(self.bn3(self.fc1(x)))
+        return self.fc2(x)
+
+
+class Configs(MNISTConfigs, TrainValidConfigs):
+    optimizer: torch.optim.Adam
+    model: nn.Module
+    set_seed = SeedConfigs()
+    device: torch.device = DeviceConfigs()
+    epochs: int = 10
+
+    is_save_models = True
+    inner_iterations = 10
+
+    accuracy_func = Accuracy()
+    loss_func = nn.CrossEntropyLoss()
+
+    def init(self):
+        tracker.set_queue("loss.*", 20, True)
+        tracker.set_scalar("accuracy.*", True)
+        hook_model_outputs(self.mode, self.model, 'model')
+        self.state_modules = [self.accuracy_func]
+
+    def step(self, batch: any, batch_idx: BatchIndex):
+        # Move the data to the configured device
+        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+        # Increment the global step (number of samples processed) during training
+        if self.mode.is_train:
+            tracker.add_global_step(len(data))
+
+        # Log model activations on the last batch of each epoch
+        with self.mode.update(is_log_activations=batch_idx.is_last):
+            output = self.model(data)
+
+        loss = self.loss_func(output, target)
+        self.accuracy_func(output, target)
+        tracker.add("loss.", loss)
+
+        if self.mode.is_train:
+            loss.backward()
+
+            self.optimizer.step()
+            if batch_idx.is_last:
+                tracker.add('model', self.model)
+            self.optimizer.zero_grad()
+
+        tracker.save()
+
+
+@option(Configs.model)
+def model(c: Configs):
+    return Net().to(c.device)
+
+
+@option(Configs.optimizer)
+def _optimizer(c: Configs):
+    from labml_helpers.optimizer import OptimizerConfigs
+    opt_conf = OptimizerConfigs()
+    opt_conf.parameters = c.model.parameters()
+    return opt_conf
+
+
+def main():
+    conf = Configs()
+    experiment.create(name='mnist_labml_helpers')
+    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
+    conf.set_seed.set()
+    experiment.add_pytorch_models(dict(model=conf.model))
+    with experiment.start():
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
This is a miniature PyTorch implementation of the paper Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Our implementation only has a few million parameters and doesn’t do model parallel distributed training.
-It does single GPU training but we implement the concept of switching as described in the paper.
+It does single GPU training, but we implement the concept of switching as described in the paper.
The Switch Transformer uses different parameters for each token by switching among parameters, based on the token. So only a fraction of parameters is chosen for each token, so you can have more parameters but less computational cost.
-The switching happens at the Position-wise Feedforward network (FFN) of of each transformer block.
+The switching happens at the Position-wise Feedforward network (FFN) of each transformer block.
Position-wise feedforward network is a two sequentially fully connected layers.
-In switch transformer we have multiple FFNs (multiple experts) and
-we chose which one to use based on a router.
+In switch transformer we have multiple FFNs (multiple experts),
+and we chose which one to use based on a router.
The outputs a set of probabilities for picking a FFN,
-and we pick the one with highest probability and only evaluates that.
+and we pick the one with the highest probability and only evaluates that.
So essentially the computational cost is same as having a single FFN. In our implementation this doesn’t parallelize well when you have many or large FFNs since it’s all happening on a single GPU.
diff --git a/docs/transformers/switch/readme.html b/docs/transformers/switch/readme.html
new file mode 100644
index 00000000..142d022f
--- /dev/null
+++ b/docs/transformers/switch/readme.html
@@ -0,0 +1,136 @@
+This is a miniature PyTorch implementation of the paper
+Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.
+Our implementation only has a few million parameters and doesn’t do model parallel distributed training.
+It does single GPU training, but we implement the concept of switching as described in the paper.
+The Switch Transformer uses different parameters for each token, switching among a set of parameters based on the token.
+So only a fraction of the parameters is chosen for each token, which means you can have more parameters at a lower computational cost.
+The switching happens at the Position-wise Feedforward network (FFN) of each transformer block.
+The position-wise feedforward network consists of two fully connected layers applied sequentially.
+In the switch transformer we have multiple FFNs (multiple experts),
+and we choose which one to use based on a router.
+The router outputs a set of probabilities for picking an FFN,
+and we pick the one with the highest probability and evaluate only that one.
+So essentially the computational cost is the same as having a single FFN.
+In our implementation this doesn’t parallelize well when you have many or large FFNs, since it’s all
+happening on a single GPU.
+In a distributed setup you would have each FFN (each very large) on a different device.
+The paper introduces another loss term to balance load among the experts (FFNs) and
+discusses dropping tokens when routing is not balanced.
+Here’s the training code and a notebook for training a switch transformer on the Tiny Shakespeare dataset.
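To make the switching concrete, below is a small self-contained sketch of a top-1 switch feed-forward layer with the auxiliary load-balancing term described above. This is an illustrative sketch rather than the repository's implementation, and the names SwitchFFN, d_model, d_ff and n_experts are chosen for this example only.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwitchFFN(nn.Module):
    """Top-1 switching among several position-wise FFNs (illustrative sketch)."""

    def __init__(self, d_model: int, d_ff: int, n_experts: int):
        super().__init__()
        # The router produces one logit per expert for every token
        self.router = nn.Linear(d_model, n_experts)
        # The experts: independent two-layer position-wise feedforward networks
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        ])
        self.n_experts = n_experts

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size, d_model] -> flatten into a list of tokens
        seq_len, batch_size, d_model = x.shape
        tokens = x.reshape(-1, d_model)

        # Routing probabilities and the top-1 expert for each token
        probs = F.softmax(self.router(tokens), dim=-1)   # [n_tokens, n_experts]
        gate, expert_idx = probs.max(dim=-1)             # both [n_tokens]

        # Evaluate only the selected expert for each token
        out = torch.zeros_like(tokens)
        for i in range(self.n_experts):
            mask = expert_idx == i
            if mask.any():
                out[mask] = self.experts[i](tokens[mask])
        # Scale by the gate value so the router receives gradients
        out = out * gate.unsqueeze(-1)

        # Auxiliary load-balancing loss:
        # n_experts * sum_i (fraction of tokens routed to expert i) * (mean routing prob of expert i)
        fraction = F.one_hot(expert_idx, self.n_experts).float().mean(dim=0)
        route_prob = probs.mean(dim=0)
        load_balancing_loss = self.n_experts * (fraction * route_prob).sum()

        return out.reshape(seq_len, batch_size, d_model), load_balancing_loss


# Example: the load-balancing term would be added (scaled) to the main training loss
ffn = SwitchFFN(d_model=64, d_ff=256, n_experts=4)
y, lb_loss = ffn(torch.randn(10, 2, 64))
print(y.shape, lb_loss.item())

During training the returned load-balancing term is added to the main loss, scaled by a small coefficient, so the router learns to spread tokens across the experts; the token dropping that the paper discusses for unbalanced routing is not implemented in this sketch.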