📚 layer norm docs
@@ -60,6 +60,7 @@ and

#### ✨ [Normalization Layers](https://nn.labml.ai/normalization/index.html)

* [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html)
* [Layer Normalization](https://nn.labml.ai/normalization/layer_norm/index.html)

### Installation
@@ -8,10 +8,10 @@ summary: >

# Normalization Layers

* [Batch Normalization](batch_norm/index.html)
* [Layer Normalization](layer_norm/index.html)

*TODO*

* Layer Normalization
* Instance Normalization
* Group Normalization
"""
@@ -109,18 +109,21 @@ class BatchNorm(Module):

    When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
    where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
    $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
    $$\text{BN}(X) = \gamma
    \frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{Var}[X] + \epsilon}}
    + \beta$$

    When input $X \in \mathbb{R}^{B \times C}$ is a batch of vector embeddings,
    When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings,
    where $B$ is the batch size and $C$ is the number of features.
    $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
    $$\text{BN}(X) = \gamma
    \frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{Var}[X] + \epsilon}}
    + \beta$$

    When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings,
    When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of a sequence of embeddings,
    where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence.
    $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
    $$\text{BN}(X) = \gamma
    \frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{Var}[X] + \epsilon}}
    + \beta$$
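To make the three cases in this docstring concrete, here is a small standalone sketch (my own illustration, not part of this commit) that computes the same per-channel statistics with plain `torch` reductions; the tensor shapes are arbitrary examples.

```python
import torch

eps = 1e-5

def bn_stats(x: torch.Tensor, dims):
    """Mean and (biased) variance over `dims`, keeping dims so broadcasting works."""
    mean = x.mean(dim=dims, keepdim=True)
    var = (x ** 2).mean(dim=dims, keepdim=True) - mean ** 2
    return mean, var

# Images: X in R^{B x C x H x W} -> statistics over (B, H, W), one per channel
x = torch.randn(32, 16, 8, 8)
mean, var = bn_stats(x, dims=(0, 2, 3))          # shapes [1, 16, 1, 1]
x_hat = (x - mean) / torch.sqrt(var + eps)

# Embeddings: X in R^{B x C} -> statistics over the batch dimension only
x = torch.randn(32, 16)
mean, var = bn_stats(x, dims=(0,))               # shapes [1, 16]
x_hat = (x - mean) / torch.sqrt(var + eps)

# Sequences: X in R^{B x C x L} -> statistics over (B, L), one per channel
x = torch.randn(32, 16, 256)
mean, var = bn_stats(x, dims=(0, 2))             # shapes [1, 16, 1]
x_hat = (x - mean) / torch.sqrt(var + eps)
```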
@@ -205,6 +208,9 @@ class BatchNorm(Module):


def _test():
    """
    Simple test
    """
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])

@@ -216,5 +222,6 @@ def _test():

    inspect(bn.exp_var.shape)


#
if __name__ == '__main__':
    _test()
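For reference, a minimal usage sketch of the `BatchNorm` module from this file, mirroring the shapes used in `_test`; the `exp_mean`/`exp_var` attribute names are taken from the fields inspected above, and this is illustrative rather than part of the commit.

```python
import torch
from labml_nn.normalization.batch_norm import BatchNorm

# A batch of 2 image-like tensors with 3 channels of 2x4 pixels, as in `_test`
x = torch.zeros([2, 3, 2, 4])
bn = BatchNorm(3)   # one set of statistics (and gamma/beta) per channel
y = bn(x)           # same shape as x
print(y.shape, bn.exp_mean.shape, bn.exp_var.shape)
```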
labml_nn/normalization/batch_norm/readme.md (new file, 88 lines)
@@ -0,0 +1,88 @@

# [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html)

This is a [PyTorch](https://pytorch.org) implementation of Batch Normalization from the paper
[Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167).

### Internal Covariate Shift

The paper defines *Internal Covariate Shift* as the change in the
distribution of network activations due to the change in
network parameters during training.
For example, let's say there are two layers $l_1$ and $l_2$.
At the beginning of training, the outputs of $l_1$ (the inputs to $l_2$)
could be in distribution $\mathcal{N}(0.5, 1)$.
Then, after some training steps, it could move to $\mathcal{N}(0.6, 1.5)$.
This is *internal covariate shift*.

Internal covariate shift adversely affects training speed because the later layers
($l_2$ in the above example) have to adapt to this shifted distribution.

By stabilizing the distribution, batch normalization minimizes the internal covariate shift.

## Normalization

It is known that whitening improves training speed and convergence.
*Whitening* is linearly transforming inputs to have zero mean, unit variance,
and be uncorrelated.
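For intuition, here is a rough sketch of full (ZCA) whitening on a `[batch, features]` matrix. It is my own illustration, not code from the repository, but it shows the covariance matrix and eigendecomposition that make whitening expensive to compute and to backpropagate through.

```python
import torch

def whiten(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """ZCA-whiten `x` of shape `[batch_size, features]`."""
    x = x - x.mean(dim=0, keepdim=True)
    # Feature covariance matrix: [features, features]
    cov = (x.T @ x) / (x.shape[0] - 1)
    # Eigendecomposition of the symmetric covariance matrix
    eigvals, eigvecs = torch.linalg.eigh(cov)
    # W = U diag(1/sqrt(lambda)) U^T removes correlations and scales to unit variance
    w = eigvecs @ torch.diag(1.0 / torch.sqrt(eigvals + eps)) @ eigvecs.T
    return x @ w
```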

### Normalizing outside gradient computation doesn't work

Normalizing outside the gradient computation using pre-computed (detached)
means and variances doesn't work. For instance (ignoring variance), let
$$\hat{x} = x - \mathbb{E}[x]$$
where $x = u + b$ and $b$ is a trained bias,
and $\mathbb{E}[x]$ is outside the gradient computation (a pre-computed constant).

Note that $\hat{x}$ is not affected by $b$.
Therefore,
$b$ will increase or decrease based on
$\frac{\partial{\mathcal{L}}}{\partial x}$,
and keep on growing indefinitely in each training update.
The paper notes that similar explosions happen with variances.
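The following toy loop (mine, not from the paper or this repo) makes the argument concrete: because the detached mean is a constant to autograd, each gradient step keeps pushing $b$ in the same direction, while $\hat{x}$, and hence the loss, never changes.

```python
import torch

u = torch.randn(16, 4)                     # fixed layer input for the illustration
b = torch.zeros(4, requires_grad=True)     # trained bias

for step in range(5):
    x = u + b
    # Mean computed outside the gradient computation (detached constant)
    x_hat = x - x.mean(dim=0).detach()
    # A loss that would like x_hat to be larger
    loss = -x_hat.sum()
    loss.backward()
    with torch.no_grad():
        b -= 0.1 * b.grad                  # plain gradient descent step
        b.grad.zero_()
    # x_hat (and the loss) never change, but b keeps growing every step
    print(step, loss.item(), b.mean().item())
```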

### Batch Normalization

Whitening is computationally expensive because you need to de-correlate the inputs, and
the gradients must flow through the full whitening calculation.

The paper introduces a simplified version which they call *Batch Normalization*.
The first simplification is that it normalizes each feature independently to have
zero mean and unit variance:
$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$
where $x = (x^{(1)} \dots x^{(d)})$ is the $d$-dimensional input.

The second simplification is to use estimates of the mean $\mathbb{E}[x^{(k)}]$
and variance $Var[x^{(k)}]$ from the mini-batch
for normalization, instead of calculating the mean and variance across the whole dataset.

Normalizing each feature to zero mean and unit variance could affect what the layer
can represent.
As an example, the paper illustrates that if the inputs to a sigmoid are normalized,
most of them will be within the $[-1, 1]$ range, where the sigmoid is close to linear.
To overcome this, each feature is scaled and shifted by two trained parameters,
$\gamma^{(k)}$ and $\beta^{(k)}$.
$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
where $y^{(k)}$ is the output of the batch normalization layer.
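In code, the two simplifications boil down to a few lines. The sketch below is illustrative only (it ignores running statistics and uses made-up shapes): it normalizes each feature with mini-batch statistics and then applies the learned $\gamma$ and $\beta$.

```python
import torch

x = torch.randn(32, 10)                      # mini-batch of 10-dimensional inputs
gamma = torch.ones(10, requires_grad=True)   # trained scale, one per feature
beta = torch.zeros(10, requires_grad=True)   # trained shift, one per feature
eps = 1e-5

# Normalize each feature independently using mini-batch statistics
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mean) / torch.sqrt(var + eps)

# Scale and shift so the layer can still represent any mean and variance it needs
y = gamma * x_hat + beta
```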

Note that when applying batch normalization after a linear transform
like $Wu + b$, the bias parameter $b$ gets cancelled by the normalization.
So you can and should omit the bias parameter in linear transforms right before
batch normalization.

Batch normalization also makes backpropagation invariant to the scale of the weights.
And, empirically, it improves generalization, so it has regularization effects too.

## Inference

We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to
perform the normalization.
So during inference, you either need to go through the whole dataset (or a part of it)
and find the mean and variance, or you can use an estimate calculated during training.
The usual practice is to calculate an exponential moving average of the
mean and variance during the training phase and use that for inference.
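One common way to keep those estimates is an exponential moving average, roughly as in the sketch below. The `momentum` value and the helper function are illustrative assumptions, not this repository's implementation, and the scale/shift step is omitted.

```python
import torch

momentum = 0.1                 # how quickly the running estimates track the batch statistics
running_mean = torch.zeros(10)
running_var = torch.ones(10)

def batch_norm(x: torch.Tensor, training: bool, eps: float = 1e-5) -> torch.Tensor:
    global running_mean, running_var
    if training:
        mean = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)
        # Exponential moving averages, updated only during training
        running_mean = (1 - momentum) * running_mean + momentum * mean.detach()
        running_var = (1 - momentum) * running_var + momentum * var.detach()
    else:
        # At inference time, reuse the estimates collected during training
        mean, var = running_mean, running_var
    return (x - mean) / torch.sqrt(var + eps)
```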

Here's [the training code](https://nn.labml.ai/normalization/batch_norm/mnist.html) and a notebook for training
a CNN classifier that uses batch normalization for the MNIST dataset.

[Open In Colab](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb)
[View Run](https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002)
@@ -24,7 +24,7 @@ Layer normalization is a simpler normalization method that works
on a wider range of settings.
Layer normalization transforms the inputs to have zero mean and unit variance
across the features.
*Note that batch normalization, fixes the zero mean and unit variance for each vector.
*Note that batch normalization fixes the zero mean and unit variance for each element.*
Layer normalization does it for each sample in the batch, across all the elements.

Layer normalization is generally used for NLP tasks.
@@ -41,18 +41,42 @@ from labml_helpers.module import Module


class LayerNorm(Module):
    """
    r"""
    ## Layer Normalization

    Layer normalization $\text{LN}$ normalizes the input $X$ as follows:

    When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings,
    where $B$ is the batch size and $C$ is the number of features.
    $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
    $$\text{LN}(X) = \gamma
    \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}}
    + \beta$$

    When input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of a sequence of embeddings,
    where $B$ is the batch size, $C$ is the number of channels, and $L$ is the length of the sequence.
    $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
    $$\text{LN}(X) = \gamma
    \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}}
    + \beta$$

    When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
    where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
    This is not a widely used scenario.
    $\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$.
    $$\text{LN}(X) = \gamma
    \frac{X - \underset{C, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{C, H, W}{Var}[X] + \epsilon}}
    + \beta$$
    """

    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
        """
        * `normalized_shape` $S$ is shape of the elements (except the batch).
        * `normalized_shape` $S$ is the shape of the elements (except the batch).
          The input should then be
          $X \in \mathbb{R}^{* \times S[0] \times S[1] \times ... \times S[n]}$
        * `eps` is $\epsilon$, used in $\sqrt{Var[X}] + \epsilon}$ for numerical stability
        * `eps` is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability
        * `elementwise_affine` is whether to scale and shift the normalized value

        We've tried to use the same names for arguments as the PyTorch `LayerNorm` implementation.
@@ -74,34 +98,35 @@ class LayerNorm(Module):

        For example, in an NLP task this will be
        `[seq_len, batch_size, features]`
        """
        # Keep the original shape
        x_shape = x.shape
        # Sanity check to make sure the shapes match
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

        # Reshape into `[M, S[0], S[1], ..., S[n]]`
        x = x.view(-1, *self.normalized_shape)
        # The dimensions to calculate the mean and variance on
        dims = [-(i + 1) for i in range(len(self.normalized_shape))]

        # Calculate the mean across first dimension;
        # i.e. the means for each element $\mathbb{E}[X}]$
        mean = x.mean(dim=0)
        # Calculate the squared mean across first dimension;
        # Calculate the mean of all elements;
        # i.e. the means for each element $\mathbb{E}[X]$
        mean = x.mean(dim=dims, keepdims=True)
        # Calculate the squared mean of all elements;
        # i.e. the means for each element $\mathbb{E}[X^2]$
        mean_x2 = (x ** 2).mean(dim=0)
        # Variance for each element $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
        mean_x2 = (x ** 2).mean(dim=dims, keepdims=True)
        # Variance of all elements $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
        var = mean_x2 - mean ** 2

        # Normalize $$\hat{X} = \frac{X} - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}}$$
        # Normalize $$\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}}$$
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        # Scale and shift $$\text{LN}(x) = \gamma \hat{X} + \beta$$
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias

        # Reshape to original and return
        return x_norm.view(x_shape)
        #
        return x_norm


def _test():
    """
    Simple test
    """
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])

@@ -113,5 +138,6 @@ def _test():

    inspect(ln.gain.shape)


#
if __name__ == '__main__':
    _test()
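For completeness, a small usage sketch of the `LayerNorm` class from this file on an NLP-shaped input; the shapes are arbitrary, and this is illustrative rather than part of the commit (see `_test` above for the shapes the repo itself inspects).

```python
import torch
from labml_nn.normalization.layer_norm import LayerNorm

# `[seq_len, batch_size, features]`, as described in the forward() docstring
x = torch.randn(100, 32, 512)
ln = LayerNorm(512)   # normalize over the last (features) dimension
y = ln(x)             # same shape as x; zero mean and unit variance per position
```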

labml_nn/normalization/layer_norm/readme.md (new file, 26 lines)
@@ -0,0 +1,26 @@

# [Layer Normalization](https://nn.labml.ai/normalization/layer_norm/index.html)

This is a [PyTorch](https://pytorch.org) implementation of
[Layer Normalization](https://arxiv.org/abs/1607.06450).

### Limitations of [Batch Normalization](https://nn.labml.ai/normalization/batch_norm/index.html)

* You need to maintain running means.
* Tricky for RNNs. Do you need different normalizations for each step?
* Doesn't work with small batch sizes;
  large NLP models are usually trained with small batch sizes.
* Need to compute means and variances across devices in distributed training.

## Layer Normalization

Layer normalization is a simpler normalization method that works
on a wider range of settings.
Layer normalization transforms the inputs to have zero mean and unit variance
across the features.
*Note that batch normalization fixes the zero mean and unit variance for each element.*
Layer normalization does it for each sample in the batch, across all the elements.

Layer normalization is generally used for NLP tasks.
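A minimal, self-contained sketch of the idea (my own, not the repository's implementation): statistics are computed per embedding, across the feature dimension only, and a learned gain and bias are applied afterwards.

```python
import torch

x = torch.randn(100, 32, 512)     # [seq_len, batch_size, features]
gain = torch.ones(512)            # gamma, a trained parameter in practice
bias = torch.zeros(512)           # beta, a trained parameter in practice
eps = 1e-5

# Mean and variance per token, across the feature dimension only
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
y = gain * (x - mean) / torch.sqrt(var + eps) + bias
```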

We have used layer normalization in most of the
[transformer implementations](https://nn.labml.ai/transformers/gpt/index.html).