From 983286e216f3ede27cdb4daf3d30460fec2ce472 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Mon, 1 Feb 2021 14:43:11 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9A=20batch=20norm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Makefile | 1 + docs/normalization/batch_norm.html | 186 -------- docs/normalization/batch_norm/index.html | 420 ++++++++++++++++++ docs/normalization/batch_norm/mnist.html | 336 ++++++++++++++ docs/normalization/index.html | 28 +- docs/normalization/mnist.html | 278 ------------ docs/sitemap.xml | 23 +- labml_nn/experiments/mnist.py | 113 +++++ labml_nn/normalization/__init__.py | 11 + labml_nn/normalization/batch_norm.py | 47 -- labml_nn/normalization/batch_norm/__init__.py | 174 ++++++++ labml_nn/normalization/batch_norm/mnist.py | 82 ++++ labml_nn/normalization/mnist.py | 105 ----- 13 files changed, 1166 insertions(+), 638 deletions(-) delete mode 100644 docs/normalization/batch_norm.html create mode 100644 docs/normalization/batch_norm/index.html create mode 100644 docs/normalization/batch_norm/mnist.html delete mode 100644 docs/normalization/mnist.html create mode 100644 labml_nn/experiments/mnist.py delete mode 100644 labml_nn/normalization/batch_norm.py create mode 100644 labml_nn/normalization/batch_norm/__init__.py create mode 100644 labml_nn/normalization/batch_norm/mnist.py delete mode 100644 labml_nn/normalization/mnist.py diff --git a/Makefile b/Makefile index 505af47a..ab61b2d8 100644 --- a/Makefile +++ b/Makefile @@ -22,6 +22,7 @@ uninstall: ## Uninstall pip uninstall labml_nn docs: ## Render annotated HTML + find ./docs/ -name "*.html" -type f -delete python utils/sitemap.py cd labml_nn; pylit --remove_empty_sections --title_md -t ../../../pylit/templates/nn -d ../docs -w * diff --git a/docs/normalization/batch_norm.html b/docs/normalization/batch_norm.html deleted file mode 100644 index 1c937b5a..00000000 --- a/docs/normalization/batch_norm.html +++ /dev/null @@ -1,186 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - batch_norm.py - - - - - - - - -
-
-
-
-

- home - normalization -

-

- - - Github - - Join Slact - - Twitter -

-
-
-
-
- - -
-
-
1import torch
-2from torch import nn
-3
-4from labml_helpers.module import Module
-
-
-
-
- - -
-
-
7class BatchNorm(Module):
-
-
-
-
- - -
-
-
8    def __init__(self, channels: int, *,
-9                 eps: float = 1e-5, momentum: float = 0.1,
-10                 affine: bool = True, track_running_stats: bool = True):
-11        super().__init__()
-12
-13        self.channels = channels
-14
-15        self.eps = eps
-16        self.momentum = momentum
-17        self.affine = affine
-18        self.track_running_stats = track_running_stats
-19        if self.affine:
-20            self.weight = nn.Parameter(torch.ones(channels))
-21            self.bias = nn.Parameter(torch.zeros(channels))
-22        if self.track_running_stats:
-23            self.register_buffer('running_mean', torch.zeros(channels))
-24            self.register_buffer('running_var', torch.ones(channels))
-
-
-
-
- - -
-
-
26    def __call__(self, x: torch.Tensor):
-27        x_shape = x.shape
-28        batch_size = x_shape[0]
-29
-30        x = x.view(batch_size, self.channels, -1)
-31        if self.training or not self.track_running_stats:
-32            mean = x.mean(dim=[0, 2])
-33            mean_x2 = (x ** 2).mean(dim=[0, 2])
-34            var = mean_x2 - mean ** 2
-35
-36            if self.training and self.track_running_stats:
-37                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
-38                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
-39        else:
-40            mean = self.running_mean
-41            var = self.running_var
-42
-43        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
-44        if self.affine:
-45            x_norm = self.weight.view(1, -1, 1) * x_norm + self.bias.view(1, -1, 1)
-46
-47        return x_norm.view(x_shape)
-
-
-
- - - - - - \ No newline at end of file diff --git a/docs/normalization/batch_norm/index.html b/docs/normalization/batch_norm/index.html new file mode 100644 index 00000000..7df2d410 --- /dev/null +++ b/docs/normalization/batch_norm/index.html @@ -0,0 +1,420 @@ + + + + + + + + + + + + + + + + + + + + + + + Batch Normalization + + + + + + + + +
+
+
+
+

+ home + normalization + batch_norm +

+

+ + + Github + + Join Slack + + Twitter

+
+
+
+
+ +

Batch Normalization

+

This is a PyTorch implementation of Batch Normalization from the paper + Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

+

Internal Covariate Shift

+

The paper defines Internal Covariate Shift as the change in the +distribution of network activations due to the change in +network parameters during training. +For example, let’s say there are two layers $l_1$ and $l_2$. +At the beginning of training, $l_1$ outputs (inputs to $l_2$) +could be in distribution $\mathcal{N}(0.5, 1)$. +Then, after some training steps, it could move to $\mathcal{N}(0.6, 1.5)$. +This is internal covariate shift.

+

Internal covariate shift will adversely affect training speed because the later layers +($l_2$ in the above example) have to adapt to this shifted distribution.

+

By stabilizing the distribution, batch normalization minimizes the internal covariate shift.

+

Normalization

+

It is known that whitening improves training speed and convergence. +Whitening is linearly transforming inputs to have zero mean, unit variance +and be uncorrelated.

+

Normalizing outside gradient computation doesn’t work

+

Normalizing outside the gradient computation using pre-computed (detached) +means and variances doesn’t work. For instance (ignoring the variance), let +$$\hat{x} = x - \mathbb{E}[x]$$ +where $x = u + b$, $b$ is a trained bias, +and $\mathbb{E}[x]$ is computed outside the gradient computation (a pre-computed constant).

+

Note that $\hat{x}$ is not affected by $b$. +Therefore, +$b$ will increase or decrease based on +$\frac{\partial{\mathcal{L}}}{\partial x}$, +and keep growing indefinitely with each training update. +The paper notes that similar explosions happen with variances.
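To make the argument concrete, here is a minimal sketch of that failure mode (not from the paper; the toy loss, learning rate and sizes are made up for illustration). The detached mean absorbs every change in $b$, so the loss never improves while $b$ keeps growing:

```python
import torch

torch.manual_seed(0)
u = torch.randn(256)                    # fixed inputs to the layer
b = torch.zeros(1, requires_grad=True)  # trained bias
optimizer = torch.optim.SGD([b], lr=0.1)

for step in range(1_000):
    x = u + b
    # mean pre-computed outside the gradient computation (detached)
    x_hat = x - x.mean().detach()
    # a toy loss that wants the normalized output to shift upwards
    loss = ((x_hat - 1.0) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# b has grown large while the loss is unchanged, because x_hat never depends on b
print(b.item(), loss.item())
```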

+

Batch Normalization

+

Whitening is computationally expensive because you need to de-correlate the features, and +the gradients must flow through the full whitening calculation.

+

The paper introduces a simplified version which they call Batch Normalization. +The first simplification is that it normalizes each feature independently to have +zero mean and unit variance: +$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$ +where $x = (x^{(1)} \dots x^{(d)})$ is the $d$-dimensional input.

+

The second simplification is to use estimates of mean $\mathbb{E}[x^{(k)}]$ +and variance $Var[x^{(k)}]$ from the mini-batch +for normalization, instead of calculating the mean and variance across the whole dataset.

+

Normalizing each feature to zero mean and unit variance could affect what the layer +can represent. +As an example, the paper illustrates that if the inputs to a sigmoid are normalized, +most of them will be within the $[-1, 1]$ range where the sigmoid is linear. +To overcome this, each feature is scaled and shifted by two trained parameters +$\gamma^{(k)}$ and $\beta^{(k)}$: +$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$ +where $y^{(k)}$ is the output of the batch normalization layer.
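As a quick sanity check of these two formulas, here is a small sketch (toy shapes, with $\gamma$ initialized to ones and $\beta$ to zeros) that normalizes each feature over a mini-batch and then scales and shifts it:

```python
import torch

x = torch.randn(32, 4)                      # [batch_size, features]
gamma = torch.ones(4)                       # trained scale, one per feature
beta = torch.zeros(4)                       # trained shift, one per feature
eps = 1e-5

mean = x.mean(dim=0)                        # E[x^(k)] over the mini-batch
var = x.var(dim=0, unbiased=False)          # Var[x^(k)] over the mini-batch
x_hat = (x - mean) / torch.sqrt(var + eps)  # normalize each feature independently
y = gamma * x_hat + beta                    # scale and shift

# each output feature now has (approximately) zero mean and unit variance
print(y.mean(dim=0), y.var(dim=0, unbiased=False))
```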

+

Note that when applying batch normalization after a linear transform +like $Wu + b$, the bias parameter $b$ gets cancelled by the normalization. +So you can and should omit the bias parameter in linear transforms right before +batch normalization.
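To see why the bias cancels, note that batch normalization subtracts the mini-batch mean, and the mean of $Wu + b$ is $\mathbb{E}[Wu] + b$, so

$$(Wu + b) - \mathbb{E}[Wu + b] = Wu - \mathbb{E}[Wu]$$

and any role of $b$ is taken over by the shift parameter $\beta$.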

+

Inference

+

We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to +perform the normalization. +So during inference, you either need to go through the whole (or part of the) dataset +and find the mean and variance, or you can use an estimate calculated during training. +The usual practice is to calculate an exponential moving average of the +mean and variance during the training phase and use that for inference.
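Concretely, with momentum $\alpha$ (0.1 by default in the implementation below), the running estimates are updated after every training batch as

$$\hat{\mu} \leftarrow (1 - \alpha)\,\hat{\mu} + \alpha\,\mathbb{E}_{batch}[x^{(k)}], \qquad \hat{\sigma}^2 \leftarrow (1 - \alpha)\,\hat{\sigma}^2 + \alpha\,Var_{batch}[x^{(k)}]$$

and these are the exp_mean and exp_var buffers used in evaluation mode.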

+
+
+
89import torch
+90from torch import nn
+91
+92from labml_helpers.module import Module
+
+
+
+
+ +

Batch Normalization Layer

+
+
+
95class BatchNorm(Module):
+
+
+
+
+ +
    +
  • channels is the number of features in the input
  • +
  • eps is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
  • +
  • momentum is the momentum in taking the exponential moving average
  • +
  • affine is whether to scale and shift the normalized value
  • +
  • track_running_stats is whether to calculate the moving averages of mean and variance
  • +
+

We’ve tried to use the same names for arguments as the PyTorch BatchNorm implementation.

+
+
+
99    def __init__(self, channels: int, *,
+100                 eps: float = 1e-5, momentum: float = 0.1,
+101                 affine: bool = True, track_running_stats: bool = True):
+
+
+
+
+ + +
+
+
111        super().__init__()
+112
+113        self.channels = channels
+114
+115        self.eps = eps
+116        self.momentum = momentum
+117        self.affine = affine
+118        self.track_running_stats = track_running_stats
+
+
+
+
+ +

Create parameters for $\gamma$ and $\beta$ for scale and shift

+
+
+
120        if self.affine:
+121            self.scale = nn.Parameter(torch.ones(channels))
+122            self.shift = nn.Parameter(torch.zeros(channels))
+
+
+
+
+ +

Create buffers to store exponential moving averages of +mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$

+
+
+
125        if self.track_running_stats:
+126            self.register_buffer('exp_mean', torch.zeros(channels))
+127            self.register_buffer('exp_var', torch.ones(channels))
+
+
+
+
+ +

x is a tensor of shape [batch_size, channels, *]. +* denotes any number of (possibly zero) dimensions. + For example, in an image (2D) convolution this will be +[batch_size, channels, height, width]
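For instance, with an illustrative 4-dimensional input, the reshape used below lets the per-channel statistics be taken over the batch and all spatial positions at once:

```python
import torch

x = torch.randn(8, 20, 24, 24)                  # [batch_size, channels, height, width]
flat = x.view(8, 20, -1)                        # [batch_size, channels, 576]
mean = flat.mean(dim=[0, 2])                    # per-channel means, shape [20]
var = (flat ** 2).mean(dim=[0, 2]) - mean ** 2  # per-channel variances, shape [20]
```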

+
+
+
129    def __call__(self, x: torch.Tensor):
+
+
+
+
+ +

Keep the original shape

+
+
+
137        x_shape = x.shape
+
+
+
+
+ +

Get the batch size

+
+
+
139        batch_size = x_shape[0]
+
+
+
+
+ +

Sanity check to make sure the number of features is the same

+
+
+
141        assert self.channels == x.shape[1]
+
+
+
+
+ +

Reshape into [batch_size, channels, n]

+
+
+
144        x = x.view(batch_size, self.channels, -1)
+
+
+
+
+ +

We will calculate the mini-batch mean and variance +if we are in training mode or if we are not tracking exponential moving averages

+
+
+
148        if self.training or not self.track_running_stats:
+
+
+
+
+ +

Calculate the mean across first and last dimension; +i.e. the means for each feature $\mathbb{E}[x^{(k)}]$

+
+
+
151            mean = x.mean(dim=[0, 2])
+
+
+
+
+ +

Calculate the squared mean across first and last dimension; +i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$

+
+
+
154            mean_x2 = (x ** 2).mean(dim=[0, 2])
+
+
+
+
+ +

Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$

+
+
+
156            var = mean_x2 - mean ** 2
+
+
+
+
+ +

Update exponential moving averages

+
+
+
159            if self.training and self.track_running_stats:
+160                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
+161                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
+
+
+
+
+ +

Use exponential moving averages as estimates

+
+
+
163        else:
+164            mean = self.exp_mean
+165            var = self.exp_var
+
+
+
+
+ +

Normalize +$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$$

+
+
+
168        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
+
+
+
+
+ +

Scale and shift +$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$

+
+
+
170        if self.affine:
+171            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+
+
+
+
+ +

Reshape to original and return

+
+
+
174        return x_norm.view(x_shape)
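A short usage sketch of this layer (shapes are illustrative): in training mode it normalizes with mini-batch statistics and updates exp_mean and exp_var, while in evaluation mode it reuses those estimates, so it also works with a batch size of one.

```python
import torch

from labml_nn.normalization.batch_norm import BatchNorm

bn = BatchNorm(20)

# training mode: uses mini-batch statistics and updates exp_mean / exp_var
bn.train()
y = bn(torch.randn(8, 20, 24, 24))

# evaluation mode: uses the stored exponential moving averages
bn.eval()
y = bn(torch.randn(1, 20, 24, 24))
```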
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/normalization/batch_norm/mnist.html b/docs/normalization/batch_norm/mnist.html new file mode 100644 index 00000000..df5812df --- /dev/null +++ b/docs/normalization/batch_norm/mnist.html @@ -0,0 +1,336 @@ + + + + + + + + + + + + + + + + + + + + + + + MNIST Experiment to try Batch Normalization + + + + + + + + +
+
+
+
+

+ home + normalization + batch_norm +

+

+ + + Github + + Join Slack + + Twitter

+
+
+
+
+ +

MNIST Experiment for Batch Normalization

+
+
+
11import torch.nn as nn
+12import torch.nn.functional as F
+13import torch.utils.data
+14
+15from labml import experiment
+16from labml.configs import option
+17from labml_helpers.module import Module
+18from labml_nn.experiments.mnist import MNISTConfigs
+19from labml_nn.normalization.batch_norm import BatchNorm
+
+
+
+
+ +

Model definition

+
+
+
22class Model(Module):
+
+
+
+
+ + +
+
+
27    def __init__(self):
+28        super().__init__()
+
+
+
+
+ +

Note that we omit the bias parameter

+
+
+
30        self.conv1 = nn.Conv2d(1, 20, 5, 1, bias=False)
+
+
+
+
+ +

Batch normalization with 20 channels (output of convolution layer). +The input to this layer will have shape [batch_size, 20, height(24), width(24)]

+
+
+
33        self.bn1 = BatchNorm(20)
+
+
+
+
+ + +
+
+
35        self.conv2 = nn.Conv2d(20, 50, 5, 1, bias=False)
+
+
+
+
+ +

Batch normalization with 50 channels. +The input to this layer will have shape [batch_size, 50, height(8), width(8)]

+
+
+
38        self.bn2 = BatchNorm(50)
+
+
+
+
+ + +
+
+
40        self.fc1 = nn.Linear(4 * 4 * 50, 500, bias=False)
+
+
+
+
+ +

Batch normalization with 500 channels (output of fully connected layer). +The input to this layer will have shape [batch_size, 500]

+
+
+
43        self.bn3 = BatchNorm(500)
+
+
+
+
+ + +
+
+
45        self.fc2 = nn.Linear(500, 10)
+
+
+
+
+ + +
+
+
47    def __call__(self, x: torch.Tensor):
+48        x = F.relu(self.bn1(self.conv1(x)))
+49        x = F.max_pool2d(x, 2, 2)
+50        x = F.relu(self.bn2(self.conv2(x)))
+51        x = F.max_pool2d(x, 2, 2)
+52        x = x.view(-1, 4 * 4 * 50)
+53        x = F.relu(self.bn3(self.fc1(x)))
+54        return self.fc2(x)
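As a quick sanity check of the shapes (assuming MNIST's 28×28 single-channel inputs): conv1 gives 24×24, pooling gives 12×12, conv2 gives 8×8, pooling gives 4×4, hence the 4 * 4 * 50 features fed to fc1. A tiny sketch using the Model class above:

```python
import torch

model = Model()
model.eval()  # use the stored batch normalization statistics
out = model(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```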
+
+
+
+
+ +

Create model

+

We use MNISTConfigs configurations +and set a new function to calculate the model.

+
+
+
57@option(MNISTConfigs.model)
+58def model(c: MNISTConfigs):
+
+
+
+
+ + +
+
+
65    return Model().to(c.device)
+
+
+
+
+ + +
+
+
68def main():
+
+
+
+
+ +

Create experiment

+
+
+
70    experiment.create(name='mnist_batch_norm')
+
+
+
+
+ +

Create configurations

+
+
+
72    conf = MNISTConfigs()
+
+
+
+
+ +

Load configurations

+
+
+
74    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
+
+
+
+
+ +

Start the experiment and run the training loop

+
+
+
76    with experiment.start():
+77        conf.run()
+
+
+
+
+ + +
+
+
81if __name__ == '__main__':
+82    main()
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/normalization/index.html b/docs/normalization/index.html index 2b4f873b..5d24b334 100644 --- a/docs/normalization/index.html +++ b/docs/normalization/index.html @@ -3,24 +3,24 @@ - + - - + + - + - - + + - None + Normalization Layers @@ -66,6 +66,20 @@

+
+
+ +

Normalization Layers

+ +
+
+
+
+
- - - -
-
-
-
-

- home - normalization -

-

- - - Github - - Join Slact - - Twitter -

-
-
-
-
- - -
-
-
1import torch.nn as nn
-2import torch.nn.functional as F
-3import torch.utils.data
-4
-5from labml import experiment, tracker
-6from labml.configs import option
-7from labml_helpers.datasets.mnist import MNISTConfigs
-8from labml_helpers.device import DeviceConfigs
-9from labml_helpers.metrics.accuracy import Accuracy
-10from labml_helpers.module import Module
-11from labml_helpers.seed import SeedConfigs
-12from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
-13from labml_nn.normalization.batch_norm import BatchNorm
-
-
-
-
- - -
-
-
16class Net(Module):
-
-
-
-
- - -
-
-
17    def __init__(self):
-18        super().__init__()
-19        self.conv1 = nn.Conv2d(1, 20, 5, 1)
-20        self.bn1 = BatchNorm(20)
-21        self.conv2 = nn.Conv2d(20, 50, 5, 1)
-22        self.bn2 = BatchNorm(50)
-23        self.fc1 = nn.Linear(4 * 4 * 50, 500)
-24        self.bn3 = BatchNorm(500)
-25        self.fc2 = nn.Linear(500, 10)
-
-
-
-
- - -
-
-
27    def __call__(self, x: torch.Tensor):
-28        x = F.relu(self.bn1(self.conv1(x)))
-29        x = F.max_pool2d(x, 2, 2)
-30        x = F.relu(self.bn2(self.conv2(x)))
-31        x = F.max_pool2d(x, 2, 2)
-32        x = x.view(-1, 4 * 4 * 50)
-33        x = F.relu(self.bn3(self.fc1(x)))
-34        return self.fc2(x)
-
-
-
-
- - -
-
-
37class Configs(MNISTConfigs, TrainValidConfigs):
-38    optimizer: torch.optim.Adam
-39    model: nn.Module
-40    set_seed = SeedConfigs()
-41    device: torch.device = DeviceConfigs()
-42    epochs: int = 10
-43
-44    is_save_models = True
-45    model: nn.Module
-46    inner_iterations = 10
-47
-48    accuracy_func = Accuracy()
-49    loss_func = nn.CrossEntropyLoss()
-
-
-
-
- - -
-
-
51    def init(self):
-52        tracker.set_queue("loss.*", 20, True)
-53        tracker.set_scalar("accuracy.*", True)
-54        hook_model_outputs(self.mode, self.model, 'model')
-55        self.state_modules = [self.accuracy_func]
-
-
-
-
- - -
-
-
57    def step(self, batch: any, batch_idx: BatchIndex):
-58        data, target = batch[0].to(self.device), batch[1].to(self.device)
-59
-60        if self.mode.is_train:
-61            tracker.add_global_step(len(data))
-62
-63        with self.mode.update(is_log_activations=batch_idx.is_last):
-64            output = self.model(data)
-65
-66        loss = self.loss_func(output, target)
-67        self.accuracy_func(output, target)
-68        tracker.add("loss.", loss)
-69
-70        if self.mode.is_train:
-71            loss.backward()
-72
-73            self.optimizer.step()
-74            if batch_idx.is_last:
-75                tracker.add('model', self.model)
-76            self.optimizer.zero_grad()
-77
-78        tracker.save()
-
-
-
-
- - -
-
-
81@option(Configs.model)
-82def model(c: Configs):
-83    return Net().to(c.device)
-84
-85
-86@option(Configs.optimizer)
-87def _optimizer(c: Configs):
-88    from labml_helpers.optimizer import OptimizerConfigs
-89    opt_conf = OptimizerConfigs()
-90    opt_conf.parameters = c.model.parameters()
-91    return opt_conf
-92
-93
-94def main():
-95    conf = Configs()
-96    experiment.create(name='mnist_labml_helpers')
-97    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
-98    conf.set_seed.set()
-99    experiment.add_pytorch_models(dict(model=conf.model))
-100    with experiment.start():
-101        conf.run()
-102
-103
-104if __name__ == '__main__':
-105    main()
-
-
-
- - - - - - \ No newline at end of file diff --git a/docs/sitemap.xml b/docs/sitemap.xml index c9abf985..60c5ab5b 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -90,20 +90,6 @@ - - https://nn.labml.ai/normalization/batch_norm.html - 2021-02-01T16:30:00+00:00 - 1.00 - - - - - https://nn.labml.ai/normalization/mnist.html - 2021-02-01T16:30:00+00:00 - 1.00 - - - https://nn.labml.ai/experiments/nlp_autoregression.html 2021-01-25T16:30:00+00:00 @@ -281,7 +267,14 @@ https://nn.labml.ai/transformers/feedback/index.html - 2021-01-30T16:30:00+00:00 + 2021-02-01T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/transformers/feedback/README.html + 2021-02-01T16:30:00+00:00 1.00 diff --git a/labml_nn/experiments/mnist.py b/labml_nn/experiments/mnist.py new file mode 100644 index 00000000..d4bae7c4 --- /dev/null +++ b/labml_nn/experiments/mnist.py @@ -0,0 +1,113 @@ +""" +--- +title: MNIST Experiment +summary: > + This is a reusable trainer for MNIST dataset +--- + +# MNIST Experiment +""" + +import torch.nn as nn +import torch.utils.data +from labml_helpers.module import Module + +from labml import tracker +from labml.configs import option +from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs +from labml_helpers.device import DeviceConfigs +from labml_helpers.metrics.accuracy import Accuracy +from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs +from labml_nn.optimizers.configs import OptimizerConfigs + + +class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs): + """ + + ## Trainer configurations + + """ + + # Optimizer + optimizer: torch.optim.Adam + # Training device + device: torch.device = DeviceConfigs() + + # Classification model + model: Module + # Number of epochs to train for + epochs: int = 10 + + # Number of times to switch between training and validation within an epoch + inner_iterations = 10 + + # Accuracy function + accuracy = Accuracy() + # Loss function + loss_func = nn.CrossEntropyLoss() + + def init(self): + """ + ### Initialization + """ + # Set tracker configurations + tracker.set_scalar("loss.*", True) + tracker.set_scalar("accuracy.*", True) + # Add a hook to log module outputs + hook_model_outputs(self.mode, self.model, 'model') + # Add accuracy as a state module. + # The name is probably confusing, since it's meant to store + # states between training and validation for RNNs. + # This will keep the accuracy metric stats separate for training and validation. + self.state_modules = [self.accuracy] + + def step(self, batch: any, batch_idx: BatchIndex): + """ + ### Training or validation step + """ + + # Move data to the device + data, target = batch[0].to(self.device), batch[1].to(self.device) + + # Update global step (number of samples processed) when in training mode + if self.mode.is_train: + tracker.add_global_step(len(data)) + + # Whether to capture model outputs + with self.mode.update(is_log_activations=batch_idx.is_last): + # Get model outputs. 
+ output = self.model(data) + + # Calculate and log loss + loss = self.loss_func(output, target) + tracker.add("loss.", loss) + + # Calculate and log accuracy + self.accuracy(output, target) + self.accuracy.track() + + # Train the model + if self.mode.is_train: + # Calculate gradients + loss.backward() + # Take optimizer step + self.optimizer.step() + # Log the model parameters and gradients on last batch of every epoch + if batch_idx.is_last: + tracker.add('model', self.model) + # Clear the gradients + self.optimizer.zero_grad() + + # Save the tracked metrics + tracker.save() + + +@option(MNISTConfigs.optimizer) +def _optimizer(c: MNISTConfigs): + """ + ### Default optimizer configurations + """ + opt_conf = OptimizerConfigs() + opt_conf.parameters = c.model.parameters() + opt_conf.optimizer = 'Adam' + return opt_conf diff --git a/labml_nn/normalization/__init__.py b/labml_nn/normalization/__init__.py index e69de29b..725e0703 100644 --- a/labml_nn/normalization/__init__.py +++ b/labml_nn/normalization/__init__.py @@ -0,0 +1,11 @@ +""" +--- +title: Normalization Layers +summary: > + A set of PyTorch implementations/tutorials of normalization layers. +--- + +# Normalization Layers + +* [Batch Normalization](batch_norm/index.html) +""" \ No newline at end of file diff --git a/labml_nn/normalization/batch_norm.py b/labml_nn/normalization/batch_norm.py deleted file mode 100644 index 6a2d5daf..00000000 --- a/labml_nn/normalization/batch_norm.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -from torch import nn - -from labml_helpers.module import Module - - -class BatchNorm(Module): - def __init__(self, channels: int, *, - eps: float = 1e-5, momentum: float = 0.1, - affine: bool = True, track_running_stats: bool = True): - super().__init__() - - self.channels = channels - - self.eps = eps - self.momentum = momentum - self.affine = affine - self.track_running_stats = track_running_stats - if self.affine: - self.weight = nn.Parameter(torch.ones(channels)) - self.bias = nn.Parameter(torch.zeros(channels)) - if self.track_running_stats: - self.register_buffer('running_mean', torch.zeros(channels)) - self.register_buffer('running_var', torch.ones(channels)) - - def __call__(self, x: torch.Tensor): - x_shape = x.shape - batch_size = x_shape[0] - - x = x.view(batch_size, self.channels, -1) - if self.training or not self.track_running_stats: - mean = x.mean(dim=[0, 2]) - mean_x2 = (x ** 2).mean(dim=[0, 2]) - var = mean_x2 - mean ** 2 - - if self.training and self.track_running_stats: - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var - else: - mean = self.running_mean - var = self.running_var - - x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1) - if self.affine: - x_norm = self.weight.view(1, -1, 1) * x_norm + self.bias.view(1, -1, 1) - - return x_norm.view(x_shape) diff --git a/labml_nn/normalization/batch_norm/__init__.py b/labml_nn/normalization/batch_norm/__init__.py new file mode 100644 index 00000000..35d76abd --- /dev/null +++ b/labml_nn/normalization/batch_norm/__init__.py @@ -0,0 +1,174 @@ +""" +--- +title: Batch Normalization +summary: > + A PyTorch implementations/tutorials of batch normalization. +--- + +# Batch Normalization + +This is a [PyTorch](https://pytorch.org) implementation of Batch Normalization from paper + [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167). 
+ +### Internal Covariate Shift + +The paper defines *Internal Covariate Shift* as the change in the +distribution of network activations due to the change in +network parameters during training. +For example, let's say there are two layers $l_1$ and $l_2$. +At the beginning of training, $l_1$ outputs (inputs to $l_2$) +could be in distribution $\mathcal{N}(0.5, 1)$. +Then, after some training steps, it could move to $\mathcal{N}(0.6, 1.5)$. +This is *internal covariate shift*. + +Internal covariate shift will adversely affect training speed because the later layers +($l_2$ in the above example) have to adapt to this shifted distribution. + +By stabilizing the distribution, batch normalization minimizes the internal covariate shift. + +## Normalization + +It is known that whitening improves training speed and convergence. +*Whitening* is linearly transforming inputs to have zero mean, unit variance +and be uncorrelated. + +### Normalizing outside gradient computation doesn't work + +Normalizing outside the gradient computation using pre-computed (detached) +means and variances doesn't work. For instance (ignoring the variance), let +$$\hat{x} = x - \mathbb{E}[x]$$ +where $x = u + b$, $b$ is a trained bias, +and $\mathbb{E}[x]$ is computed outside the gradient computation (a pre-computed constant). + +Note that $\hat{x}$ is not affected by $b$. +Therefore, +$b$ will increase or decrease based on +$\frac{\partial{\mathcal{L}}}{\partial x}$, +and keep growing indefinitely with each training update. +The paper notes that similar explosions happen with variances. + +### Batch Normalization + +Whitening is computationally expensive because you need to de-correlate the features, and +the gradients must flow through the full whitening calculation. + +The paper introduces a simplified version which they call *Batch Normalization*. +The first simplification is that it normalizes each feature independently to have +zero mean and unit variance: +$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$ +where $x = (x^{(1)} ... x^{(d)})$ is the $d$-dimensional input. + +The second simplification is to use estimates of mean $\mathbb{E}[x^{(k)}]$ +and variance $Var[x^{(k)}]$ from the mini-batch +for normalization, instead of calculating the mean and variance across the whole dataset. + +Normalizing each feature to zero mean and unit variance could affect what the layer +can represent. +As an example, the paper illustrates that if the inputs to a sigmoid are normalized, +most of them will be within the $[-1, 1]$ range where the sigmoid is linear. +To overcome this, each feature is scaled and shifted by two trained parameters +$\gamma^{(k)}$ and $\beta^{(k)}$. +$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$ +where $y^{(k)}$ is the output of the batch normalization layer. + +Note that when applying batch normalization after a linear transform +like $Wu + b$, the bias parameter $b$ gets cancelled by the normalization. +So you can and should omit the bias parameter in linear transforms right before +batch normalization. + +## Inference + +We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to +perform the normalization. +So during inference, you either need to go through the whole (or part of the) dataset +and find the mean and variance, or you can use an estimate calculated during training. +The usual practice is to calculate an exponential moving average of the +mean and variance during the training phase and use that for inference. 
+""" + +import torch +from torch import nn + +from labml_helpers.module import Module + + +class BatchNorm(Module): + """ + ## Batch Normalization Layer + """ + def __init__(self, channels: int, *, + eps: float = 1e-5, momentum: float = 0.1, + affine: bool = True, track_running_stats: bool = True): + """ + * `channels` is the number of features in the input + * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability + * `momentum` is the momentum in taking the exponential moving average + * `affine` is whether to scale and shift the normalized value + * `track_running_stats` is whether to calculate the moving averages or mean and variance + + We've tried to use the same names for arguments as PyTorch `BatchNorm` implementation. + """ + super().__init__() + + self.channels = channels + + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + # Create parameters for $\gamma$ and $\beta$ for scale and shift + if self.affine: + self.scale = nn.Parameter(torch.ones(channels)) + self.shift = nn.Parameter(torch.zeros(channels)) + # Create buffers to store exponential moving averages of + # mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$ + if self.track_running_stats: + self.register_buffer('exp_mean', torch.zeros(channels)) + self.register_buffer('exp_var', torch.ones(channels)) + + def __call__(self, x: torch.Tensor): + """ + `x` is a tensor of shape `[batch_size, channels, *]`. + `*` could be any (even *) dimensions. + For example, in an image (2D) convolution this will be + `[batch_size, channels, height, width]` + """ + # Keep the original shape + x_shape = x.shape + # Get the batch size + batch_size = x_shape[0] + # Sanity check to make sure the number of features is same + assert self.channels == x.shape[1] + + # Reshape into `[batch_size, channels, n]` + x = x.view(batch_size, self.channels, -1) + + # We will calculate the mini-batch mean and variance + # if we are in training mode or if we have not tracked exponential moving averages + if self.training or not self.track_running_stats: + # Calculate the mean across first and last dimension; + # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$ + mean = x.mean(dim=[0, 2]) + # Calculate the squared mean across first and last dimension; + # i.e. 
the means for each feature $\mathbb{E}[(x^{(k)})^2]$ + mean_x2 = (x ** 2).mean(dim=[0, 2]) + # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$ + var = mean_x2 - mean ** 2 + + # Update exponential moving averages + if self.training and self.track_running_stats: + self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean + self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var + # Use exponential moving averages as estimates + else: + mean = self.exp_mean + var = self.exp_var + + # Normalize $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$$ + x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1) + # Scale and shift $$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$ + if self.affine: + x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1) + + # Reshape to original and return + return x_norm.view(x_shape) diff --git a/labml_nn/normalization/batch_norm/mnist.py b/labml_nn/normalization/batch_norm/mnist.py new file mode 100644 index 00000000..eeeaafb8 --- /dev/null +++ b/labml_nn/normalization/batch_norm/mnist.py @@ -0,0 +1,82 @@ +""" +--- +title: MNIST Experiment to try Batch Normalization +summary: > + This is a simple model for MNIST digit classification that uses batch normalization +--- + +# MNIST Experiment for Batch Normalization +""" + +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data + +from labml import experiment +from labml.configs import option +from labml_helpers.module import Module +from labml_nn.experiments.mnist import MNISTConfigs +from labml_nn.normalization.batch_norm import BatchNorm + + +class Model(Module): + """ + ### Model definition + """ + + def __init__(self): + super().__init__() + # Note that we omit the bias parameter + self.conv1 = nn.Conv2d(1, 20, 5, 1, bias=False) + # Batch normalization with 20 channels (output of convolution layer). + # The input to this layer will have shape `[batch_size, 20, height(24), width(24)]` + self.bn1 = BatchNorm(20) + # + self.conv2 = nn.Conv2d(20, 50, 5, 1, bias=False) + # Batch normalization with 50 channels. + # The input to this layer will have shape `[batch_size, 50, height(8), width(8)]` + self.bn2 = BatchNorm(50) + # + self.fc1 = nn.Linear(4 * 4 * 50, 500, bias=False) + # Batch normalization with 500 channels (output of fully connected layer). + # The input to this layer will have shape `[batch_size, 500]` + self.bn3 = BatchNorm(500) + # + self.fc2 = nn.Linear(500, 10) + + def __call__(self, x: torch.Tensor): + x = F.relu(self.bn1(self.conv1(x))) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.bn2(self.conv2(x))) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4 * 4 * 50) + x = F.relu(self.bn3(self.fc1(x))) + return self.fc2(x) + + +@option(MNISTConfigs.model) +def model(c: MNISTConfigs): + """ + ### Create model + + We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations + and set a new function to calculate the model. 
+ """ + return Model().to(c.device) + + +def main(): + # Create experiment + experiment.create(name='mnist_batch_norm') + # Create configurations + conf = MNISTConfigs() + # Load configurations + experiment.configs(conf, {'optimizer.optimizer': 'Adam'}) + # Start the experiment and run the training loop + with experiment.start(): + conf.run() + + +# +if __name__ == '__main__': + main() diff --git a/labml_nn/normalization/mnist.py b/labml_nn/normalization/mnist.py deleted file mode 100644 index cb98f57e..00000000 --- a/labml_nn/normalization/mnist.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.data - -from labml import experiment, tracker -from labml.configs import option -from labml_helpers.datasets.mnist import MNISTConfigs -from labml_helpers.device import DeviceConfigs -from labml_helpers.metrics.accuracy import Accuracy -from labml_helpers.module import Module -from labml_helpers.seed import SeedConfigs -from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs -from labml_nn.normalization.batch_norm import BatchNorm - - -class Net(Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 20, 5, 1) - self.bn1 = BatchNorm(20) - self.conv2 = nn.Conv2d(20, 50, 5, 1) - self.bn2 = BatchNorm(50) - self.fc1 = nn.Linear(4 * 4 * 50, 500) - self.bn3 = BatchNorm(500) - self.fc2 = nn.Linear(500, 10) - - def __call__(self, x: torch.Tensor): - x = F.relu(self.bn1(self.conv1(x))) - x = F.max_pool2d(x, 2, 2) - x = F.relu(self.bn2(self.conv2(x))) - x = F.max_pool2d(x, 2, 2) - x = x.view(-1, 4 * 4 * 50) - x = F.relu(self.bn3(self.fc1(x))) - return self.fc2(x) - - -class Configs(MNISTConfigs, TrainValidConfigs): - optimizer: torch.optim.Adam - model: nn.Module - set_seed = SeedConfigs() - device: torch.device = DeviceConfigs() - epochs: int = 10 - - is_save_models = True - model: nn.Module - inner_iterations = 10 - - accuracy_func = Accuracy() - loss_func = nn.CrossEntropyLoss() - - def init(self): - tracker.set_queue("loss.*", 20, True) - tracker.set_scalar("accuracy.*", True) - hook_model_outputs(self.mode, self.model, 'model') - self.state_modules = [self.accuracy_func] - - def step(self, batch: any, batch_idx: BatchIndex): - data, target = batch[0].to(self.device), batch[1].to(self.device) - - if self.mode.is_train: - tracker.add_global_step(len(data)) - - with self.mode.update(is_log_activations=batch_idx.is_last): - output = self.model(data) - - loss = self.loss_func(output, target) - self.accuracy_func(output, target) - tracker.add("loss.", loss) - - if self.mode.is_train: - loss.backward() - - self.optimizer.step() - if batch_idx.is_last: - tracker.add('model', self.model) - self.optimizer.zero_grad() - - tracker.save() - - -@option(Configs.model) -def model(c: Configs): - return Net().to(c.device) - - -@option(Configs.optimizer) -def _optimizer(c: Configs): - from labml_helpers.optimizer import OptimizerConfigs - opt_conf = OptimizerConfigs() - opt_conf.parameters = c.model.parameters() - return opt_conf - - -def main(): - conf = Configs() - experiment.create(name='mnist_labml_helpers') - experiment.configs(conf, {'optimizer.optimizer': 'Adam'}) - conf.set_seed.set() - experiment.add_pytorch_models(dict(model=conf.model)) - with experiment.start(): - conf.run() - - -if __name__ == '__main__': - main()