diff --git a/docs/normalization/batch_channel_norm/index.html b/docs/normalization/batch_channel_norm/index.html new file mode 100644 index 00000000..7332e8dc --- /dev/null +++ b/docs/normalization/batch_channel_norm/index.html @@ -0,0 +1,663 @@ + + + + + + + + + + + + + + + + + + + + + + + Batch-Channel Normalization + + + + + + + + +
+
+
+ +
+
+
+ +

Batch-Channel Normalization

+

This is a PyTorch implementation of Batch-Channel Normalization from the paper + Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. +We also have an annotated implementation of Weight Standardization.

+

Batch-Channel Normalization performs batch normalization followed +by a channel normalization (similar to Group Normalization). +When the batch size is small, a running mean and variance are used for +batch normalization.

+

Here is the training code for training +a VGG network that uses weight standardization to classify CIFAR-10 data.

+

Open In Colab +View Run +WandB

+
+
+
27import torch
+28from torch import nn
+29
+30from labml_helpers.module import Module
+31from labml_nn.normalization.batch_norm import BatchNorm
+
+
+
+
+ +

Batch-Channel Normalization

+

This first performs a batch normalization - either normal batch norm +or a batch norm with +estimated mean and variance (an exponential moving mean/variance over multiple batches). +Then a channel normalization is performed.

+
+
+
34class BatchChannelNorm(Module):
+
+
+
+
+ +
    +
  • channels is the number of features in the input
  • +
  • groups is the number of groups the features are divided into
  • +
  • eps is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
  • +
  • momentum is the momentum in taking the exponential moving average
  • +
  • estimate is whether to use running mean and variance for batch norm
  • +
+
+
+
44    def __init__(self, channels: int, groups: int,
+45                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
+
+
+
+
+ + +
+
+
53        super().__init__()
+
+
+
+
+ +

Use estimated batch norm or normal batch norm.

+
+
+
56        if estimate:
+57            self.batch_norm = EstimatedBatchNorm(channels,
+58                                                 eps=eps, momentum=momentum)
+59        else:
+60            self.batch_norm = BatchNorm(channels,
+61                                        eps=eps, momentum=momentum)
+
+
+
+
+ +

Channel normalization

+
+
+
64        self.channel_norm = ChannelNorm(channels, groups, eps)
+
+
+
+
+ + +
+
+
66    def __call__(self, x):
+67        x = self.batch_norm(x)
+68        return self.channel_norm(x)
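Below is a minimal usage sketch (the import path follows this repository's layout); it only checks that the layer preserves the input shape.

    import torch
    from labml_nn.normalization.batch_channel_norm import BatchChannelNorm

    # 64 channels normalized in 32 groups; estimate=True (the default) uses running statistics
    bcn = BatchChannelNorm(channels=64, groups=32)
    x = torch.randn(4, 64, 32, 32)  # [batch_size, channels, height, width]
    y = bcn(x)
    assert y.shape == x.shape  # normalization does not change the tensor shape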
+
+
+
+
+ +

Estimated Batch Normalization

+

Let the input $X \in \mathbb{R}^{B \times C \times H \times W}$ be a batch of image representations, +where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, +and let $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$ be the learned scale and shift parameters. The transformation is

+

+ +

+

$$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$$
where,
\begin{align}
\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
\end{align}

+

+ +

+

are the running estimates of the mean and variance. $r$ is the momentum of the exponential moving average.

+
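As a tiny numeric sketch of the update above (hypothetical values), with momentum $r = 0.1$ the running mean moves 10% of the way towards the current batch mean at every training step:

    r = 0.1
    exp_mean, batch_mean = 0.0, 5.0
    exp_mean = (1 - r) * exp_mean + r * batch_mean
    print(exp_mean)  # 0.5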
+
+
71class EstimatedBatchNorm(Module):
+
+
+
+
+ +
    +
  • channels is the number of features in the input
  • +
  • eps is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
  • +
  • momentum is the momentum in taking the exponential moving average
  • +
  • affine is whether to scale and shift the normalized value
  • +
+
+
+
92    def __init__(self, channels: int,
+93                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
+
+
+
+
+ + +
+
+
100        super().__init__()
+101
+102        self.eps = eps
+103        self.momentum = momentum
+104        self.affine = affine
+105        self.channels = channels
+
+
+
+
+ +

Channel wise transformation parameters

+
+
+
108        if self.affine:
+109            self.scale = nn.Parameter(torch.ones(channels))
+110            self.shift = nn.Parameter(torch.zeros(channels))
+
+
+
+
+ +

Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$

+
+
+
113        self.register_buffer('exp_mean', torch.zeros(channels))
+114        self.register_buffer('exp_var', torch.ones(channels))
+
+
+
+
+ +

x is a tensor of shape [batch_size, channels, *]. +* denotes any number of (possibly 0) dimensions. + For example, in an image (2D) convolution this will be +[batch_size, channels, height, width]

+
+
+
116    def __call__(self, x: torch.Tensor):
+
+
+
+
+ +

Keep old shape

+
+
+
124        x_shape = x.shape
+
+
+
+
+ +

Get the batch size

+
+
+
126        batch_size = x_shape[0]
+
+
+
+
+ +

Sanity check to make sure the number of features is correct

+
+
+
129        assert self.channels == x.shape[1]
+
+
+
+
+ +

Reshape into [batch_size, channels, n]

+
+
+
132        x = x.view(batch_size, self.channels, -1)
+
+
+
+
+ +

Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only

+
+
+
135        if self.training:
+
+
+
+
+ +

No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$

+
+
+
137            with torch.no_grad():
+
+
+
+
+ +

Calculate the mean across first and last dimensions; +$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$

+
+
+
140                mean = x.mean(dim=[0, 2])
+
+
+
+
+ +

Calculate the squared mean across first and last dimensions; +$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

+
+
+
143                mean_x2 = (x ** 2).mean(dim=[0, 2])
+
+
+
+
+ +

Variance for each feature: $\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$

+
+
+
145                var = mean_x2 - mean ** 2
+
+
+
+
+ +

Update exponential moving averages: $\hat{\mu}_C \longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$ and $\hat{\sigma}^2_C \longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$

+
+
+
152                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
+153                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
+
+
+
+
+ +

Normalize: $\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}$

+
+
+
157        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
+
+
+
+
+ +

Scale and shift: $\gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$

+
+
+
162        if self.affine:
+163            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+
+
+
+
+ +

Reshape to original and return

+
+
+
166        return x_norm.view(x_shape)
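A short sketch of the running-statistics behaviour (hypothetical values; the import path follows this repository's layout): in training mode the estimates are updated first and then used to normalize, while in evaluation mode they are used unchanged.

    import torch
    from labml_nn.normalization.batch_channel_norm import EstimatedBatchNorm

    ebn = EstimatedBatchNorm(channels=3)
    x = torch.randn(2, 3, 4, 4) * 5. + 10.  # a batch far from zero mean / unit variance

    ebn.train()
    _ = ebn(x)
    print(ebn.exp_mean)  # moved 10% of the way towards the batch mean (momentum = 0.1)

    ebn.eval()
    _ = ebn(x)  # running estimates are used but not updated in evaluation mode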
+
+
+
+
+ +

Channel Normalization

+

This is similar to Group Normalization, but the affine transform is applied group-wise.

+
+
+
169class ChannelNorm(Module):
+
+
+
+
+ +
    +
  • groups is the number of groups the features are divided into
  • +
  • channels is the number of features in the input
  • +
  • eps is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
  • +
  • affine is whether to scale and shift the normalized value
  • +
+
+
+
176    def __init__(self, channels, groups,
+177                 eps: float = 1e-5, affine: bool = True):
+
+
+
+
+ + +
+
+
184        super().__init__()
+185        self.channels = channels
+186        self.groups = groups
+187        self.eps = eps
+188        self.affine = affine
+
+
+
+
+ +

Parameters for affine transformation.

+

Note that these transforms are applied per group, unlike in group norm where +they are applied channel-wise.

+
+
+
193        if self.affine:
+194            self.scale = nn.Parameter(torch.ones(groups))
+195            self.shift = nn.Parameter(torch.zeros(groups))
+
+
+
+
+ +

x is a tensor of shape [batch_size, channels, *]. +* denotes any number of (possibly 0) dimensions. + For example, in an image (2D) convolution this will be +[batch_size, channels, height, width]

+
+
+
197    def __call__(self, x: torch.Tensor):
+
+
+
+
+ +

Keep the original shape

+
+
+
206        x_shape = x.shape
+
+
+
+
+ +

Get the batch size

+
+
+
208        batch_size = x_shape[0]
+
+
+
+
+ +

Sanity check to make sure the number of features is the same

+
+
+
210        assert self.channels == x.shape[1]
+
+
+
+
+ +

Reshape into [batch_size, groups, n]

+
+
+
213        x = x.view(batch_size, self.groups, -1)
+
+
+
+
+ +

Calculate the mean across last dimension; +i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$

+
+
+
217        mean = x.mean(dim=[-1], keepdim=True)
+
+
+
+
+ +

Calculate the squared mean across last dimension; +i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$

+
+
+
220        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
+
+
+
+
+ +

Variance for each sample and feature group +$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$

+
+
+
223        var = mean_x2 - mean ** 2
+
+
+
+
+ +

Normalize: $\hat{x}_{(i_N, i_G)} = \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$

+
+
+
228        x_norm = (x - mean) / torch.sqrt(var + self.eps)
+
+
+
+
+ +

Scale and shift group-wise: $y_{i_G} = \gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}$

+
+
+
232        if self.affine:
+233            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+
+
+
+
+ +

Reshape to original and return

+
+
+
236        return x_norm.view(x_shape)
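A small sketch of the difference (assuming 64 channels split into 32 groups; import paths follow this repository's layout): this ChannelNorm learns one scale/shift per group, whereas the Group Normalization implementation linked above learns them per channel.

    from labml_nn.normalization.batch_channel_norm import ChannelNorm
    from labml_nn.normalization.group_norm import GroupNorm

    cn = ChannelNorm(channels=64, groups=32)
    gn = GroupNorm(32, 64)
    print(cn.scale.shape)  # torch.Size([32]) - one parameter per group
    print(gn.scale.shape)  # torch.Size([64]) - one parameter per channel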
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/normalization/group_norm/index.html b/docs/normalization/group_norm/index.html index eefd7cb8..69b70cca 100644 --- a/docs/normalization/group_norm/index.html +++ b/docs/normalization/group_norm/index.html @@ -137,10 +137,10 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

WandB

-
87import torch
-88from torch import nn
-89
-90from labml_helpers.module import Module
+
86import torch
+87from torch import nn
+88
+89from labml_helpers.module import Module
@@ -151,7 +151,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

Group Normalization Layer

-
93class GroupNorm(Module):
+
92class GroupNorm(Module):
@@ -167,8 +167,8 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

-
98    def __init__(self, groups: int, channels: int, *,
-99                 eps: float = 1e-5, affine: bool = True):
+
97    def __init__(self, groups: int, channels: int, *,
+98                 eps: float = 1e-5, affine: bool = True):
@@ -179,14 +179,14 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

-
106        super().__init__()
-107
-108        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
-109        self.groups = groups
-110        self.channels = channels
-111
-112        self.eps = eps
-113        self.affine = affine
+
105        super().__init__()
+106
+107        assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
+108        self.groups = groups
+109        self.channels = channels
+110
+111        self.eps = eps
+112        self.affine = affine
@@ -197,9 +197,9 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

Create parameters for $\gamma$ and $\beta$ for scale and shift

-
115        if self.affine:
-116            self.scale = nn.Parameter(torch.ones(channels))
-117            self.shift = nn.Parameter(torch.zeros(channels))
+
114        if self.affine:
+115            self.scale = nn.Parameter(torch.ones(channels))
+116            self.shift = nn.Parameter(torch.zeros(channels))
@@ -213,7 +213,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

[batch_size, channels, height, width]

-
119    def forward(self, x: torch.Tensor):
+
118    def forward(self, x: torch.Tensor):
@@ -224,7 +224,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

Keep the original shape

-
127        x_shape = x.shape
+
126        x_shape = x.shape
@@ -235,7 +235,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

Get the batch size

-
129        batch_size = x_shape[0]
+
128        batch_size = x_shape[0]
@@ -246,7 +246,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

Sanity check to make sure the number of features is the same

-
131        assert self.channels == x.shape[1]
+
130        assert self.channels == x.shape[1]
@@ -257,7 +257,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

Reshape into [batch_size, groups, n]

-
134        x = x.view(batch_size, self.groups, -1)
+
133        x = x.view(batch_size, self.groups, -1)
@@ -269,7 +269,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.

i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$

-
138        mean = x.mean(dim=[-1], keepdim=True)
+
137        mean = x.mean(dim=[-1], keepdim=True)
@@ -281,7 +281,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$

-
141        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
+
140        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
@@ -293,7 +293,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$< $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$

-
144        var = mean_x2 - mean ** 2
+
143        var = mean_x2 - mean ** 2
@@ -307,7 +307,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]

-
149        x_norm = (x - mean) / torch.sqrt(var + self.eps)
+
148        x_norm = (x - mean) / torch.sqrt(var + self.eps)
@@ -320,9 +320,9 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]

-
153        if self.affine:
-154            x_norm = x_norm.view(batch_size, self.channels, -1)
-155            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+
152        if self.affine:
+153            x_norm = x_norm.view(batch_size, self.channels, -1)
+154            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
@@ -333,7 +333,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]

Reshape to original and return

-
158        return x_norm.view(x_shape)
+
157        return x_norm.view(x_shape)
@@ -344,7 +344,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]

Simple test

-
161def _test():
+
160def _test():
@@ -355,14 +355,14 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
-
165    from labml.logger import inspect
-166
-167    x = torch.zeros([2, 6, 2, 4])
-168    inspect(x.shape)
-169    bn = GroupNorm(2, 6)
-170
-171    x = bn(x)
-172    inspect(x.shape)
+
164    from labml.logger import inspect
+165
+166    x = torch.zeros([2, 6, 2, 4])
+167    inspect(x.shape)
+168    bn = GroupNorm(2, 6)
+169
+170    x = bn(x)
+171    inspect(x.shape)
@@ -373,8 +373,8 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
-
176if __name__ == '__main__':
-177    _test()
+
175if __name__ == '__main__':
+176    _test()
diff --git a/docs/normalization/weight_standardization/conv2d.html b/docs/normalization/weight_standardization/conv2d.html new file mode 100644 index 00000000..3f7504ae --- /dev/null +++ b/docs/normalization/weight_standardization/conv2d.html @@ -0,0 +1,200 @@ + + + + + + + + + + + + + + + + + + + + + + + 2D Convolution Layer with Weight Standardization + + + + + + + + +
+
+
+ +
+
+
+ +

2D Convolution Layer with Weight Standardization

+

This is an implementation of a two-dimensional convolution layer with Weight Standardization.

+
+
+
13import torch
+14import torch.nn as nn
+15from torch.nn import functional as F
+16
+17from labml_nn.normalization.weight_standardization import weight_standardization
+
+
+
+
+ +

2D Convolution Layer

+

This extends the standard 2D convolution layer and standardizes the weights before the convolution step.

+
+
+
20class Conv2d(nn.Conv2d):
+
+
+
+
+ + +
+
+
26    def __init__(self, in_channels, out_channels, kernel_size,
+27                 stride=1,
+28                 padding=0,
+29                 dilation=1,
+30                 groups: int = 1,
+31                 bias: bool = True,
+32                 padding_mode: str = 'zeros',
+33                 eps: float = 1e-5):
+34        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
+35                                     stride=stride,
+36                                     padding=padding,
+37                                     dilation=dilation,
+38                                     groups=groups,
+39                                     bias=bias,
+40                                     padding_mode=padding_mode)
+41        self.eps = eps
+
+
+
+
+ + +
+
+
43    def forward(self, x: torch.Tensor):
+44        return F.conv2d(x, weight_standardization(self.weight, self.eps), self.bias, self.stride,
+45                        self.padding, self.dilation, self.groups)
+
+
+
+
+ +

A simple test to verify the tensor sizes

+
+
+
48def _test():
+
+
+
+
+ + +
+
+
52    conv2d = Conv2d(10, 20, 5)
+53    from labml.logger import inspect
+54    inspect(conv2d.weight)
+55    import torch
+56    inspect(conv2d(torch.zeros(10, 10, 100, 100)))
+57
+58
+59if __name__ == '__main__':
+60    _test()
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/normalization/weight_standardization/experiment.html b/docs/normalization/weight_standardization/experiment.html new file mode 100644 index 00000000..be25927f --- /dev/null +++ b/docs/normalization/weight_standardization/experiment.html @@ -0,0 +1,267 @@ + + + + + + + + + + + + + + + + + + + + + + + CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization + + + + + + + + +
+
+
+ +
+
+
+ +

CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization

+
+
+
12import torch.nn as nn
+13
+14from labml import experiment
+15from labml.configs import option
+16from labml_helpers.module import Module
+17from labml_nn.experiments.cifar10 import CIFAR10Configs
+18from labml_nn.normalization.batch_channel_norm import BatchChannelNorm
+19from labml_nn.normalization.weight_standardization.conv2d import Conv2d
+
+
+
+
+ +

Model

+

A VGG model that uses Weight Standardization and + Batch-Channel Normalization.

+
+
+
22class Model(Module):
+
+
+
+
+ + +
+
+
29    def __init__(self):
+30        super().__init__()
+31        layers = []
+32        in_channels = 3
+33        for block in [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]:
+34            for channels in block:
+35                layers += [Conv2d(in_channels, channels, kernel_size=3, padding=1),
+36                           BatchChannelNorm(channels, 32),
+37                           nn.ReLU(inplace=True)]
+38                in_channels = channels
+39            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+40        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
+41        self.layers = nn.Sequential(*layers)
+42        self.fc = nn.Linear(512, 10)
+
+
+
+
+ + +
+
+
44    def __call__(self, x):
+45        x = self.layers(x)
+46        x = x.view(x.shape[0], -1)
+47        return self.fc(x)
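As a quick sanity check (a sketch, not part of the training code), the five 2×2 max-pool stages reduce 32×32 CIFAR-10 inputs to 1×1, so flattening leaves 512 features for the final linear layer:

    import torch

    model = Model()                # the Model class defined above
    x = torch.zeros(2, 3, 32, 32)  # a dummy CIFAR-10 sized batch
    print(model(x).shape)          # torch.Size([2, 10])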
+
+
+
+
+ +

Create model

+
+
+
50@option(CIFAR10Configs.model)
+51def model(c: CIFAR10Configs):
+
+
+
+
+ + +
+
+
55    return Model().to(c.device)
+
+
+
+
+ + +
+
+
58def main():
+
+
+
+
+ +

Create experiment

+
+
+
60    experiment.create(name='cifar10', comment='weight standardization')
+
+
+
+
+ +

Create configurations

+
+
+
62    conf = CIFAR10Configs()
+
+
+
+
+ +

Load configurations

+
+
+
64    experiment.configs(conf, {
+65        'optimizer.optimizer': 'Adam',
+66        'optimizer.learning_rate': 2.5e-4,
+67        'train_batch_size': 64,
+68    })
+
+
+
+
+ +

Start the experiment and run the training loop

+
+
+
70    with experiment.start():
+71        conf.run()
+
+
+
+
+ + +
+
+
75if __name__ == '__main__':
+76    main()
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/normalization/weight_standardization/index.html b/docs/normalization/weight_standardization/index.html new file mode 100644 index 00000000..1fd20786 --- /dev/null +++ b/docs/normalization/weight_standardization/index.html @@ -0,0 +1,231 @@ + + + + + + + + + + + + + + + + + + + + + + + Weight Standardization + + + + + + + + +
+
+
+ +
+
+
+ +

Weight Standardization

+

This is a PyTorch implementation of Weight Standardization from the paper + Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. +We also have an annotated implementation of Batch-Channel Normalization.

+

Batch normalization gives a smooth loss landscape and +avoids elimination singularities. +Elimination singularities are nodes of the network that become +useless (e.g. a ReLU that gives 0 all the time).

+

However, batch normalization doesn’t work well when the batch size is too small, +which happens when training large networks because of device memory limitations. +The paper introduces Weight Standardization with Batch-Channel Normalization as +a better alternative.

+

Weight Standardization: +1. Normalizes the gradients +2. Smooths the loss landscape (reduces the Lipschitz constant) +3. Avoids elimination singularities

+

The Lipschitz constant is the maximum slope a function has between two points. +That is, the Lipschitz constant $L$ is the smallest value that satisfies +$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$, +where $f: A \rightarrow \mathbb{R}^m$ and $A \subseteq \mathbb{R}^n$.

+

Elimination singularities are avoided because weight standardization keeps the statistics of the outputs similar to those of the +inputs. So as long as the inputs are normally distributed, the outputs remain close to normal. +This prevents the outputs of nodes from always falling outside the active range of the activation function +(e.g. always getting a negative input for a ReLU).

+

Refer to the paper for proofs.

+

Here is the training code for training +a VGG network that uses weight standardization to classify CIFAR-10 data. +This uses a 2D-Convolution Layer with Weight Standardization.

+

Open In Colab +View Run +WandB

+
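A minimal drop-in sketch (import path per this repository's layout): the weight-standardized Conv2d accepts the same arguments as torch.nn.Conv2d plus an eps, so it can replace it directly in a model definition.

    import torch
    from labml_nn.normalization.weight_standardization.conv2d import Conv2d

    conv = Conv2d(3, 16, kernel_size=3, padding=1)
    y = conv(torch.randn(1, 3, 32, 32))
    print(y.shape)  # torch.Size([1, 16, 32, 32])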
+
+
50import torch
+
+
+
+
+ +

Weight Standardization

+

+ +

+

$$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}}{\sigma_{W_{i,\cdot}}}$$
where,
\begin{align}
W &\in \mathbb{R}^{O \times I} \\
\mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
\sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon}
\end{align}

+

+ +

+

For a 2D-convolution layer, $O$ is the number of output channels ($O = C_{out}$) +and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$).

+
+
+
53def weight_standardization(weight: torch.Tensor, eps: float):
+
+
+
+
+ +

Get $C_{out}$, $C_{in}$ and kernel shape

+
+
+
72    c_out, c_in, *kernel_shape = weight.shape
+
+
+
+
+ +

Reshape $W$ to $O \times I$

+
+
+
74    weight = weight.view(c_out, -1)
+
+
+
+
+ +

Calculate $\mu_{W_{i,\cdot}} = \frac{1}{I} \sum_{j=1}^I W_{i,j}$ and $\sigma^2_{W_{i,\cdot}} = \frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}}$

+

+ +

+
+
+
81    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
+
+
+
+
+ +

Normalize: $\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}}{\sigma_{W_{i,\cdot}}}$

+
+
+
84    weight = (weight - mean) / (torch.sqrt(var + eps))
+
+
+
+
+ +

Change back to original shape and return

+
+
+
86    return weight.view(c_out, c_in, *kernel_shape)
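A quick check (a sketch): after standardization, each output-channel slice of the weight has approximately zero mean and unit standard deviation.

    import torch
    from labml_nn.normalization.weight_standardization import weight_standardization

    w = torch.randn(32, 16, 3, 3)                # O = 32, I = 16 * 3 * 3 = 144
    w_hat = weight_standardization(w, eps=1e-5)
    print(w_hat.mean(dim=[1, 2, 3]))             # ~0 for every output channel
    print(w_hat.std(dim=[1, 2, 3]))              # ~1 for every output channel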
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/sitemap.xml b/docs/sitemap.xml index b16eb32d..5da1d394 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -160,6 +160,27 @@ + + https://nn.labml.ai/normalization/weight_standardization/index.html + 2021-04-27T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/normalization/weight_standardization/experiment.html + 2021-04-27T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/normalization/weight_standardization/conv2d.html + 2021-04-26T16:30:00+00:00 + 1.00 + + + https://nn.labml.ai/normalization/instance_norm/index.html 2021-04-23T16:30:00+00:00 @@ -202,6 +223,13 @@ + + https://nn.labml.ai/normalization/batch_channel_norm/index.html + 2021-04-27T16:30:00+00:00 + 1.00 + + + https://nn.labml.ai/normalization/group_norm/experiment.html 2021-04-24T16:30:00+00:00 diff --git a/labml_nn/normalization/batch_channel_norm/__init__.py b/labml_nn/normalization/batch_channel_norm/__init__.py new file mode 100644 index 00000000..7c3c44ea --- /dev/null +++ b/labml_nn/normalization/batch_channel_norm/__init__.py @@ -0,0 +1,236 @@ +""" +--- +title: Batch-Channel Normalization +summary: > + A PyTorch implementation/tutorial of Batch-Channel Normalization. +--- + +# Batch-Channel Normalization + +This is a [PyTorch](https://pytorch.org) implementation of Batch-Channel Normalization from the paper + [Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://arxiv.org/abs/1903.10520). +We also have an [annotated implementation of Weight Standardization](../weight_standardization/index.html). + +Batch-Channel Normalization performs batch normalization followed +by a channel normalization (similar to a [Group Normalization](../group_norm/index.html). +When the batch size is small a running mean and variance is used for +batch normalization. + +Here is [the training code](../weight_standardization/experiment.html) for training +a VGG network that uses weight standardization to classify CIFAR-10 data. + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb) +[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002) +[![WandB](https://img.shields.io/badge/wandb-run-yellow)](https://wandb.ai/vpj/cifar10/runs/3flr4k8w) +""" + +import torch +from torch import nn + +from labml_helpers.module import Module +from labml_nn.normalization.batch_norm import BatchNorm + + +class BatchChannelNorm(Module): + """ + ## Batch-Channel Normalization + + This first performs a batch normalization - either [normal batch norm](../batch_norm/index.html) + or a batch norm with + estimated mean and variance (exponential mean/variance over multiple batches). + Then a channel normalization performed. + """ + + def __init__(self, channels: int, groups: int, + eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True): + """ + * `channels` is the number of features in the input + * `groups` is the number of groups the features are divided into + * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability + * `momentum` is the momentum in taking the exponential moving average + * `estimate` is whether to use running mean and variance for batch norm + """ + super().__init__() + + # Use estimated batch norm or normal batch norm. 
+ if estimate: + self.batch_norm = EstimatedBatchNorm(channels, + eps=eps, momentum=momentum) + else: + self.batch_norm = BatchNorm(channels, + eps=eps, momentum=momentum) + + # Channel normalization + self.channel_norm = ChannelNorm(channels, groups, eps) + + def __call__(self, x): + x = self.batch_norm(x) + return self.channel_norm(x) + + +class EstimatedBatchNorm(Module): + """ + ## Estimated Batch Normalization + + When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, + where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. + $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$. + + $$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C + \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + + \beta_C$$ + + where, + + \begin{align} + \hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\ + \hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2 + \end{align} + + are the running mean and variances. $r$ is the momentum for calculating the exponential mean. + """ + def __init__(self, channels: int, + eps: float = 1e-5, momentum: float = 0.1, affine: bool = True): + """ + * `channels` is the number of features in the input + * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability + * `momentum` is the momentum in taking the exponential moving average + * `estimate` is whether to use running mean and variance for batch norm + """ + super().__init__() + + self.eps = eps + self.momentum = momentum + self.affine = affine + self.channels = channels + + # Channel wise transformation parameters + if self.affine: + self.scale = nn.Parameter(torch.ones(channels)) + self.shift = nn.Parameter(torch.zeros(channels)) + + # Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ + self.register_buffer('exp_mean', torch.zeros(channels)) + self.register_buffer('exp_var', torch.ones(channels)) + + def __call__(self, x: torch.Tensor): + """ + `x` is a tensor of shape `[batch_size, channels, *]`. + `*` denotes any number of (possibly 0) dimensions. 
+ For example, in an image (2D) convolution this will be + `[batch_size, channels, height, width]` + """ + # Keep old shape + x_shape = x.shape + # Get the batch size + batch_size = x_shape[0] + + # Sanity check to make sure the number of features is correct + assert self.channels == x.shape[1] + + # Reshape into `[batch_size, channels, n]` + x = x.view(batch_size, self.channels, -1) + + # Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only + if self.training: + # No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ + with torch.no_grad(): + # Calculate the mean across first and last dimensions; + # $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$ + mean = x.mean(dim=[0, 2]) + # Calculate the squared mean across first and last dimensions; + # $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$ + mean_x2 = (x ** 2).mean(dim=[0, 2]) + # Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2 + var = mean_x2 - mean ** 2 + + # Update exponential moving averages + # \begin{align} + # \hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\ + # \hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2 + # \end{align} + self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean + self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var + + # Normalize + # $$\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}$$ + x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1) + # Scale and shift + # $$ \gamma_C + # \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + # + \beta_C$$ + if self.affine: + x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1) + + # Reshape to original and return + return x_norm.view(x_shape) + + +class ChannelNorm(Module): + """ + ## Channel Normalization + + This is similar to [Group Normalization](../group_norm/index.html) but affine transform is done group wise. + """ + + def __init__(self, channels, groups, + eps: float = 1e-5, affine: bool = True): + """ + * `groups` is the number of groups the features are divided into + * `channels` is the number of features in the input + * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability + * `affine` is whether to scale and shift the normalized value + """ + super().__init__() + self.channels = channels + self.groups = groups + self.eps = eps + self.affine = affine + # Parameters for affine transformation. + # + # *Note that these transforms are per group, unlike in group norm where + # they are transformed channel-wise.* + if self.affine: + self.scale = nn.Parameter(torch.ones(groups)) + self.shift = nn.Parameter(torch.zeros(groups)) + + def __call__(self, x: torch.Tensor): + """ + `x` is a tensor of shape `[batch_size, channels, *]`. + `*` denotes any number of (possibly 0) dimensions. + For example, in an image (2D) convolution this will be + `[batch_size, channels, height, width]` + """ + + # Keep the original shape + x_shape = x.shape + # Get the batch size + batch_size = x_shape[0] + # Sanity check to make sure the number of features is the same + assert self.channels == x.shape[1] + + # Reshape into `[batch_size, groups, n]` + x = x.view(batch_size, self.groups, -1) + + # Calculate the mean across last dimension; + # i.e. 
the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$ + mean = x.mean(dim=[-1], keepdim=True) + # Calculate the squared mean across last dimension; + # i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$ + mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True) + # Variance for each sample and feature group + # $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$ + var = mean_x2 - mean ** 2 + + # Normalize + # $$\hat{x}_{(i_N, i_G)} = + # \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$$ + x_norm = (x - mean) / torch.sqrt(var + self.eps) + + # Scale and shift group-wise + # $$y_{i_G} =\gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}$$ + if self.affine: + x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1) + + # Reshape to original and return + return x_norm.view(x_shape) diff --git a/labml_nn/normalization/group_norm/__init__.py b/labml_nn/normalization/group_norm/__init__.py index a5e9c310..df81b851 100644 --- a/labml_nn/normalization/group_norm/__init__.py +++ b/labml_nn/normalization/group_norm/__init__.py @@ -81,7 +81,6 @@ Here's a [CIFAR 10 classification model](experiment.html) that uses instance nor [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb) [![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/081d950aa4e011eb8f9f0242ac1c0002) [![WandB](https://img.shields.io/badge/wandb-run-yellow)](https://wandb.ai/vpj/cifar10/runs/310etthp) - """ import torch diff --git a/labml_nn/normalization/weight_standardization/__init__.py b/labml_nn/normalization/weight_standardization/__init__.py new file mode 100644 index 00000000..753e78cd --- /dev/null +++ b/labml_nn/normalization/weight_standardization/__init__.py @@ -0,0 +1,86 @@ +""" +--- +title: Weight Standardization +summary: > + A PyTorch implementation/tutorial of Weight Standardization. +--- + +# Weight Standardization + +This is a [PyTorch](https://pytorch.org) implementation of Weight Standardization from the paper + [Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://arxiv.org/abs/1903.10520). +We also have an [annotated implementation of Batch-Channel Normalization](../batch_channel_norm/index.html). + +Batch normalization **gives a smooth loss landscape** and +**avoids elimination singularities**. +Elimination singularities are nodes of the network that become +useless (e.g. a ReLU that gives 0 all the time). + +However, batch normalization doesn't work well when the batch size is too small, +which happens when training large networks because of device memory limitations. +The paper introduces Weight Standardization with Batch-Channel Normalization as +a better alternative. + +Weight Standardization: +1. Normalizes the gradients +2. Smoothes the landscape (reduced Lipschitz constant) +3. Avoids elimination singularities + +The Lipschitz constant is the maximum slope a function has between two points. +That is, $L$ is the Lipschitz constant where $L$ is the smallest value that satisfies, +$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$ +where $f: A \rightarrow \mathbb{R}^m, A \in \mathbb{R}^n$. + +Elimination singularities are avoided because it keeps the statistics of the outputs similar to the +inputs. 
So as long as the inputs are normally distributed the outputs remain close to normal. +This avoids outputs of nodes from always falling beyond the active range of the activation function +(e.g. always negative input for a ReLU). + +*[Refer to the paper for proofs](https://arxiv.org/abs/1903.10520)*. + +Here is [the training code](experiment.html) for training +a VGG network that uses weight standardization to classify CIFAR-10 data. +This uses a [2D-Convolution Layer with Weight Standardization](../conv2d.html). + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb) +[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002) +[![WandB](https://img.shields.io/badge/wandb-run-yellow)](https://wandb.ai/vpj/cifar10/runs/3flr4k8w) +""" + +import torch + + +def weight_standardization(weight: torch.Tensor, eps: float): + r""" + ## Weight Standardization + + $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$ + + where, + + \begin{align} + W &\in \mathbb{R}^{O \times I} \\ + \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\ + \sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon} \\ + \end{align} + + for a 2D-convolution layer $O$ is the number of output channels ($O = C_{out}$) + and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$) + """ + + # Get $C_{out}$, $C_{in}$ and kernel shape + c_out, c_in, *kernel_shape = weight.shape + # Reshape $W$ to $O \times I$ + weight = weight.view(c_out, -1) + # Calculate + # + # \begin{align} + # \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\ + # \sigma^2_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + # \end{align} + var, mean = torch.var_mean(weight, dim=1, keepdim=True) + # Normalize + # $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$ + weight = (weight - mean) / (torch.sqrt(var + eps)) + # Change back to original shape and return + return weight.view(c_out, c_in, *kernel_shape) diff --git a/labml_nn/normalization/weight_standardization/conv2d.py b/labml_nn/normalization/weight_standardization/conv2d.py new file mode 100644 index 00000000..150c839b --- /dev/null +++ b/labml_nn/normalization/weight_standardization/conv2d.py @@ -0,0 +1,60 @@ +""" +--- +title: 2D Convolution Layer with Weight Standardization +summary: > + A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization. +--- + +# 2D Convolution Layer with Weight Standardization + +This is an implementation of a 2 dimensional convolution layer with [Weight Standardization](./index.html) +""" + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from labml_nn.normalization.weight_standardization import weight_standardization + + +class Conv2d(nn.Conv2d): + """ + ## 2D Convolution Layer + + This extends the standard 2D Convolution layer and standardize the weights before the convolution step. 
+ """ + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + eps: float = 1e-5): + super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode) + self.eps = eps + + def forward(self, x: torch.Tensor): + return F.conv2d(x, weight_standardization(self.weight, self.eps), self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +def _test(): + """ + A simple test to verify the tensor sizes + """ + conv2d = Conv2d(10, 20, 5) + from labml.logger import inspect + inspect(conv2d.weight) + import torch + inspect(conv2d(torch.zeros(10, 10, 100, 100))) + + +if __name__ == '__main__': + _test() diff --git a/labml_nn/normalization/weight_standardization/experiment.py b/labml_nn/normalization/weight_standardization/experiment.py new file mode 100644 index 00000000..22e64c40 --- /dev/null +++ b/labml_nn/normalization/weight_standardization/experiment.py @@ -0,0 +1,76 @@ +""" +--- +title: CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization +summary: > + This trains is a VGG net that uses weight standardization and batch-channel normalization + to classify CIFAR10 images. +--- + +# CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization +""" + +import torch.nn as nn + +from labml import experiment +from labml.configs import option +from labml_helpers.module import Module +from labml_nn.experiments.cifar10 import CIFAR10Configs +from labml_nn.normalization.batch_channel_norm import BatchChannelNorm +from labml_nn.normalization.weight_standardization.conv2d import Conv2d + + +class Model(Module): + """ + ### Model + + A VGG model that use [Weight Standardization](./index.html) and + [Batch-Channel Normalization](../batch_channel_norm/index.html). 
+ """ + def __init__(self): + super().__init__() + layers = [] + in_channels = 3 + for block in [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]: + for channels in block: + layers += [Conv2d(in_channels, channels, kernel_size=3, padding=1), + BatchChannelNorm(channels, 32), + nn.ReLU(inplace=True)] + in_channels = channels + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + layers += [nn.AvgPool2d(kernel_size=1, stride=1)] + self.layers = nn.Sequential(*layers) + self.fc = nn.Linear(512, 10) + + def __call__(self, x): + x = self.layers(x) + x = x.view(x.shape[0], -1) + return self.fc(x) + + +@option(CIFAR10Configs.model) +def model(c: CIFAR10Configs): + """ + ### Create model + """ + return Model().to(c.device) + + +def main(): + # Create experiment + experiment.create(name='cifar10', comment='weight standardization') + # Create configurations + conf = CIFAR10Configs() + # Load configurations + experiment.configs(conf, { + 'optimizer.optimizer': 'Adam', + 'optimizer.learning_rate': 2.5e-4, + 'train_batch_size': 64, + }) + # Start the experiment and run the training loop + with experiment.start(): + conf.run() + + +# +if __name__ == '__main__': + main() diff --git a/setup.py b/setup.py index 8914903f..a0687608 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ with open("readme.md", "r") as f: setuptools.setup( name='labml-nn', - version='0.4.96', + version='0.4.97', author="Varuna Jayasiri, Nipun Wijerathne", author_email="vpjayasiri@gmail.com, hnipun@gmail.com", description="A collection of PyTorch implementations of neural network architectures and layers.",