📚 layer norm docs

Varuna Jayasiri
2021-02-02 11:51:22 +05:30
parent 5388e807e1
commit 3089e9131c
14 changed files with 663 additions and 141 deletions

View File

@ -60,6 +60,7 @@ and
#### ✨ [Normalization Layers](https://nn.labml.ai/normalization/index.html)
* [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html)
* [Layer Normalization](https://nn.labml.ai/normalization/layer_norm/index.html)
### Installation

View File

@ -8,10 +8,10 @@ summary: >
# Normalization Layers
* [Batch Normalization](batch_norm/index.html)
* [Layer Normalization](layer_norm/index.html)
*TODO*
* Instance Normalization
* Group Normalization
"""

View File

@ -109,18 +109,21 @@ class BatchNorm(Module):
When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{Var}[X] + \epsilon}}
+ \beta$$
When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings,
where $B$ is the batch size and $C$ is the number of features.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{Var}[X] + \epsilon}}
+ \beta$$
When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings,
where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{Var}[X] + \epsilon}}
+ \beta$$
@ -205,6 +208,9 @@ class BatchNorm(Module):
def _test():
"""
Simple test
"""
from labml.logger import inspect
x = torch.zeros([2, 3, 2, 4])
@ -216,5 +222,6 @@ def _test():
inspect(bn.exp_var.shape)
#
if __name__ == '__main__':
_test()

View File

@ -0,0 +1,88 @@
# [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html)
This is a [PyTorch](https://pytorch.org) implementation of Batch Normalization from the paper
[Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167).
### Internal Covariate Shift
The paper defines *Internal Covariate Shift* as the change in the
distribution of network activations due to the change in
network parameters during training.
For example, let's say there are two layers $l_1$ and $l_2$.
At the beginning of training, the outputs of $l_1$ (which are the inputs to $l_2$)
could follow the distribution $\mathcal{N}(0.5, 1)$.
Then, after some training steps, they could move to a different distribution, say $\mathcal{N}(0.6, 1.5)$.
This is *internal covariate shift*.
Internal covariate shift adversely affects training speed because the later layers
($l_2$ in the above example) have to adapt to this shifted distribution.
By stabilizing the distribution, batch normalization minimizes the internal covariate shift.
## Normalization
It is known that whitening improves training speed and convergence.
*Whitening* is linearly transforming inputs to have zero mean, unit variance,
and be uncorrelated.
### Normalizing outside gradient computation doesn't work
Normalizing outside the gradient computation using pre-computed (detached)
means and variances doesn't work. For instance (ignoring the variance), let
$$\hat{x} = x - \mathbb{E}[x]$$
where $x = u + b$, $b$ is a trained bias,
and $\mathbb{E}[x]$ is kept outside the gradient computation (a pre-computed constant).
Note that $b$ then has no effect on $\hat{x}$.
Therefore,
$b$ will increase or decrease based on
$\frac{\partial{\mathcal{L}}}{\partial x}$,
and keep on growing indefinitely with each training update.
The paper notes that similar explosions happen with variances.
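To make this concrete, here is a tiny PyTorch sketch of the failure mode (the loss and variable names are illustrative, not from the paper or this repository): because the mean is detached, the gradient still updates $b$, but the normalized output, and hence the loss, never changes, so $b$ drifts without bound.

```python
import torch

u = torch.randn(8)                      # a fixed pre-activation
b = torch.zeros(1, requires_grad=True)  # the trained bias
opt = torch.optim.SGD([b], lr=0.1)

for step in range(3):
    x = u + b
    x_hat = x - x.mean().detach()       # mean treated as a pre-computed constant
    loss = x_hat.sum()                  # any loss with a non-zero gradient w.r.t. x_hat
    opt.zero_grad()
    loss.backward()
    opt.step()
    # b changes every step, while x_hat and the loss stay exactly the same
    print(step, b.item(), loss.item())
```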
### Batch Normalization
Whitening is computationally expensive because you need to de-correlate the features and
the gradients must flow through the full whitening calculation.
The paper introduces a simplified version, which it calls *Batch Normalization*.
The first simplification is to normalize each feature independently to have
zero mean and unit variance:
$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$
where $x = (x^{(1)} ... x^{(d)})$ is the $d$-dimensional input.
The second simplification is to use estimates of the mean $\mathbb{E}[x^{(k)}]$
and variance $Var[x^{(k)}]$ from the mini-batch
for normalization, instead of calculating them across the whole dataset.
Normalizing each feature to zero mean and unit variance could affect what the layer
can represent.
As an example, the paper illustrates that if the inputs to a sigmoid are normalized,
most of them will fall within the $[-1, 1]$ range, where the sigmoid is roughly linear.
To overcome this, each feature is scaled and shifted by two trained parameters,
$\gamma^{(k)}$ and $\beta^{(k)}$.
$$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
where $y^{(k)}$ is the output of the batch normalization layer.
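As a rough sketch (the shapes and names here are assumed for illustration; this is not the repository's `BatchNorm` module), the two simplifications combine as follows for an input of shape `[batch_size, features]`:

```python
import torch

def batch_norm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5):
    mean = x.mean(dim=0)                        # mini-batch mean of each feature, E[x^(k)]
    var = x.var(dim=0, unbiased=False)          # mini-batch variance of each feature, Var[x^(k)]
    x_hat = (x - mean) / torch.sqrt(var + eps)  # normalize each feature independently
    return gamma * x_hat + beta                 # scale and shift with the trained parameters

x = torch.randn(32, 16)                         # a mini-batch of 32 samples with 16 features
y = batch_norm(x, gamma=torch.ones(16), beta=torch.zeros(16))
```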
Note that when applying batch normalization after a linear transform
like $Wu + b$, the bias parameter $b$ gets cancelled by the normalization.
So you can, and should, omit the bias parameter in linear transforms right before
batch normalization, as the quick check below illustrates.
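This quick check is a standalone sketch with assumed shapes, not part of the repository's implementation:

```python
import torch

def bn(x):
    # normalize each feature over the batch dimension
    return (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + 1e-5)

h = torch.randn(32, 16)   # outputs of a linear transform, W u
b = torch.randn(16)       # a per-feature bias the linear layer would add

# True (up to floating point error): the bias has no effect on the normalized output
print(torch.allclose(bn(h), bn(h + b), atol=1e-5))
```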
Batch normalization also makes backpropagation invariant to the scale of the weights.
Empirically, it improves generalization, so it has regularization effects too.
## Inference
We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to
perform the normalization.
So during inference, you either need to go through the whole (or part of the) dataset
to find the mean and variance, or you can use estimates calculated during training.
The usual practice is to calculate an exponential moving average of
mean and variance during the training phase and use that for inference.
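A rough sketch of that practice (the momentum value and function names here are assumptions, not the repository's code):

```python
import torch

momentum = 0.1                 # how quickly the running estimates track the batch statistics
running_mean = torch.zeros(16)
running_var = torch.ones(16)

def batch_norm_train(x: torch.Tensor) -> torch.Tensor:
    global running_mean, running_var
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    # exponential moving averages updated during training
    running_mean = (1 - momentum) * running_mean + momentum * mean
    running_var = (1 - momentum) * running_var + momentum * var
    return (x - mean) / torch.sqrt(var + 1e-5)

def batch_norm_eval(x: torch.Tensor) -> torch.Tensor:
    # at inference, use the estimates gathered during training instead of batch statistics
    return (x - running_mean) / torch.sqrt(running_var + 1e-5)
```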
Here's [the training code](https://nn.labml.ai/normalization/batch_norm/mnist.html) and a notebook for training
a CNN classifier that uses batch normalization for the MNIST dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002)

View File

@ -24,7 +24,7 @@ Layer normalization is a simpler normalization method that works
on a wider range of settings.
Layer normalization transforms the inputs to have zero mean and unit variance
across the features.
*Note that batch normalization fixes the zero mean and unit variance for each feature across the batch,
while layer normalization does it for each sample across all the features.*
Layer normalization is generally used for NLP tasks.
@ -41,18 +41,42 @@ from labml_helpers.module import Module
class LayerNorm(Module):
"""
r"""
## Layer Normalization
Layer normalization $\text{LN}$ normalizes the input $X$ as follows:
When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings,
where $B$ is the batch size and $C$ is the number of features.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{LN}(X) = \gamma
\frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}}
+ \beta$$
When input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of sequences of embeddings,
where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{LN}(X) = \gamma
\frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}}
+ \beta$$
When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
This is not a widely used scenario.
$\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$.
$$\text{LN}(X) = \gamma
\frac{X - \underset{C, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{C, H, W}{Var}[X] + \epsilon}}
+ \beta$$
"""
def __init__(self, normalized_shape: Union[int, List[int], Size], *,
eps: float = 1e-5,
elementwise_affine: bool = True):
"""
* `normalized_shape` $S$ is the shape of the elements (except the batch).
The input should then be
$X \in \mathbb{R}^{* \times S[0] \times S[1] \times ... \times S[n]}$
* `eps` is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability
* `elementwise_affine` is whether to scale and shift the normalized value
We've tried to use the same names for the arguments as the PyTorch `LayerNorm` implementation.
@ -74,34 +98,35 @@ class LayerNorm(Module):
For example, in an NLP task this will be
`[seq_len, batch_size, features]`
"""
# Sanity check to make sure the shapes match
assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
# The dimensions to calculate the mean and variance on
dims = [-(i + 1) for i in range(len(self.normalized_shape))]
# Calculate the mean of all elements;
# i.e. the means for each element $\mathbb{E}[X]$
mean = x.mean(dim=dims, keepdim=True)
# Calculate the squared mean of all elements;
# i.e. the means for each element $\mathbb{E}[X^2]$
mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
# Variance of all elements $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
var = mean_x2 - mean ** 2
# Normalize $$\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}}$$
x_norm = (x - mean) / torch.sqrt(var + self.eps)
# Scale and shift $$\text{LN}(x) = \gamma \hat{X} + \beta$$
if self.elementwise_affine:
x_norm = self.gain * x_norm + self.bias
#
return x_norm
def _test():
"""
Simple test
"""
from labml.logger import inspect
x = torch.zeros([2, 3, 2, 4])
@ -113,5 +138,6 @@ def _test():
inspect(ln.gain.shape)
#
if __name__ == '__main__':
_test()

View File

@ -0,0 +1,26 @@
# [Layer Normalization](https://nn.labml.ai/normalization/layer_norm/index.html)
This is a [PyTorch](https://pytorch.org) implementation of
[Layer Normalization](https://arxiv.org/abs/1607.06450).
### Limitations of [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html)
* You need to maintain running means.
* Tricky for RNNs. Do you need different normalizations for each step?
* Doesn't work with small batch sizes;
large NLP models are usually trained with small batch sizes.
* Need to compute means and variances across devices in distributed training
## Layer Normalization
Layer normalization is a simpler normalization method that works
on a wider range of settings.
Layer normalization transforms the inputs to have zero mean and unit variance
across the features.
*Note that batch normalization fixes the zero mean and unit variance for each feature across the batch,
while layer normalization does it for each sample across all the features.*
Layer normalization is generally used for NLP tasks.
We have used layer normalization in most of the
[transformer implementations](https://nn.labml.ai/transformers/gpt/index.html).
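For illustration, here is a minimal sketch (with assumed shapes, using PyTorch's built-in `nn.LayerNorm`) of normalizing a `[seq_len, batch_size, features]` tensor across the feature dimension, as is typical in transformers:

```python
import torch
from torch import nn

seq_len, batch_size, features = 10, 4, 64
x = torch.randn(seq_len, batch_size, features)

ln = nn.LayerNorm(features)  # normalizes over the last (feature) dimension
y = ln(x)

# every position in every sequence is normalized independently across its features
print(y.mean(dim=-1).abs().max())  # close to 0
print(y.std(dim=-1).mean())        # close to 1
```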