Batch-Channel Normalization performs batch normalization followed
+by a channel normalization (similar to Group Normalization).
+When the batch size is small, a running mean and variance are used for
+batch normalization.
+
Here is the training code for training
+a VGG network that uses weight standardization to classify CIFAR-10 data.
This first performs a batch normalization - either normal batch norm
+or a batch norm with
+estimated mean and variance (exponential mean/variance over multiple batches).
+Then a channel normalization is performed.
The input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
+where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
+$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$ are learnable scale and shift parameters.
+
+$$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C
+\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$$
+
where,
+
+\begin{align}
+\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
+\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
+\end{align}
+
are the running mean and variances. $r$ is the momentum for calculating the exponential mean.
x is a tensor of shape [batch_size, channels, *].
+* denotes any number of (possibly 0) dimensions.
+ For example, in an image (2D) convolution this will be
+[batch_size, channels, height, width]
\ No newline at end of file
diff --git a/docs/normalization/group_norm/index.html b/docs/normalization/group_norm/index.html
index eefd7cb8..69b70cca 100644
--- a/docs/normalization/group_norm/index.html
+++ b/docs/normalization/group_norm/index.html
@@ -137,10 +137,10 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
@@ -179,14 +179,14 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
-106 super().__init__()
-107
-108 assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
-109 self.groups = groups
-110 self.channels = channels
-111
-112 self.eps = eps
-113 self.affine = affine
+105 super().__init__()
+106
+107 assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
+108 self.groups = groups
+109 self.channels = channels
+110
+111 self.eps = eps
+112 self.affine = affine
@@ -197,9 +197,9 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
Create parameters for $\gamma$ and $\beta$ for scale and shift
@@ -213,7 +213,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
[batch_size, channels, height, width]
-119 def forward(self, x: torch.Tensor):
+118 def forward(self, x: torch.Tensor):
@@ -224,7 +224,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
Keep the original shape
-127 x_shape = x.shape
+126 x_shape = x.shape
@@ -235,7 +235,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
Get the batch size
-129 batch_size = x_shape[0]
+128 batch_size = x_shape[0]
@@ -246,7 +246,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
Sanity check to make sure the number of features is the same
-131 assert self.channels == x.shape[1]
+130 assert self.channels == x.shape[1]
@@ -257,7 +257,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
Reshape into [batch_size, groups, n]
-134 x = x.view(batch_size, self.groups, -1)
+133 x = x.view(batch_size, self.groups, -1)
@@ -269,7 +269,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.
i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$
-138 mean = x.mean(dim=[-1], keepdim=True)
+137 mean = x.mean(dim=[-1], keepdim=True)
@@ -281,7 +281,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$
-141 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
+140 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
@@ -293,7 +293,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$<
$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
Batch normalization gives a smooth loss landscape and
+avoids elimination singularities.
+Elimination singularities are nodes of the network that become
+useless (e.g. a ReLU that gives 0 all the time).
+
However, batch normalization doesn’t work well when the batch size is too small,
+which happens when training large networks because of device memory limitations.
+The paper introduces Weight Standardization with Batch-Channel Normalization as
+a better alternative.
+
Weight Standardization:
+1. Normalizes the gradients
+2. Smoothes the landscape (reduced Lipschitz constant)
+3. Avoids elimination singularities
+
The Lipschitz constant is the maximum slope a function has between two points.
+That is, the Lipschitz constant $L$ is the smallest value that satisfies
+$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$,
+where $f: A \rightarrow \mathbb{R}^m$ and $A \subseteq \mathbb{R}^n$.
+
Elimination singularities are avoided because weight standardization keeps the statistics of the outputs similar to those of the
+inputs. So as long as the inputs are normally distributed, the outputs remain close to normal.
+This prevents the outputs of nodes from always falling outside the active range of the activation function
+(e.g. an always-negative input to a ReLU).
For a 2D-convolution layer, $O$ is the number of output channels ($O = C_{out}$)
+and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$).
\ No newline at end of file
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index b16eb32d..5da1d394 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -160,6 +160,27 @@
+
+ https://nn.labml.ai/normalization/weight_standardization/index.html
+ 2021-04-27T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/normalization/weight_standardization/experiment.html
+ 2021-04-27T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/normalization/weight_standardization/conv2d.html
+ 2021-04-26T16:30:00+00:00
+ 1.00
+
+
+
 https://nn.labml.ai/normalization/instance_norm/index.html
 2021-04-23T16:30:00+00:00
@@ -202,6 +223,13 @@
+
+ https://nn.labml.ai/normalization/batch_channel_norm/index.html
+ 2021-04-27T16:30:00+00:00
+ 1.00
+
+
+
 https://nn.labml.ai/normalization/group_norm/experiment.html
 2021-04-24T16:30:00+00:00
diff --git a/labml_nn/normalization/batch_channel_norm/__init__.py b/labml_nn/normalization/batch_channel_norm/__init__.py
new file mode 100644
index 00000000..7c3c44ea
--- /dev/null
+++ b/labml_nn/normalization/batch_channel_norm/__init__.py
@@ -0,0 +1,236 @@
+"""
+---
+title: Batch-Channel Normalization
+summary: >
+ A PyTorch implementation/tutorial of Batch-Channel Normalization.
+---
+
+# Batch-Channel Normalization
+
+This is a [PyTorch](https://pytorch.org) implementation of Batch-Channel Normalization from the paper
+ [Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://arxiv.org/abs/1903.10520).
+We also have an [annotated implementation of Weight Standardization](../weight_standardization/index.html).
+
+Batch-Channel Normalization performs batch normalization followed
+by a channel normalization (similar to [Group Normalization](../group_norm/index.html)).
+When the batch size is small, a running mean and variance are used for
+batch normalization.
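+
+As a minimal usage sketch (the tensor shape here is only an illustration; `channels` and
+`groups` correspond to the `BatchChannelNorm` constructor arguments below):
+
+```python
+import torch
+from labml_nn.normalization.batch_channel_norm import BatchChannelNorm
+
+# Normalize a batch of 64-channel feature maps, with 32 channel groups
+bcn = BatchChannelNorm(channels=64, groups=32)
+x = torch.randn(8, 64, 32, 32)  # [batch_size, channels, height, width]
+y = bcn(x)
+assert y.shape == x.shape
+```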
+
+Here is [the training code](../weight_standardization/experiment.html) for training
+a VGG network that uses weight standardization to classify CIFAR-10 data.
+
+[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb)
+[](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002)
+[](https://wandb.ai/vpj/cifar10/runs/3flr4k8w)
+"""
+
+import torch
+from torch import nn
+
+from labml_helpers.module import Module
+from labml_nn.normalization.batch_norm import BatchNorm
+
+
+class BatchChannelNorm(Module):
+ """
+ ## Batch-Channel Normalization
+
+ This first performs a batch normalization - either [normal batch norm](../batch_norm/index.html)
+ or a batch norm with
+ estimated mean and variance (exponential mean/variance over multiple batches).
+ Then a channel normalization is performed.
+ """
+
+ def __init__(self, channels: int, groups: int,
+ eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
+ """
+ * `channels` is the number of features in the input
+ * `groups` is the number of groups the features are divided into
+ * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
+ * `momentum` is the momentum in taking the exponential moving average
+ * `estimate` is whether to use running mean and variance for batch norm
+ """
+ super().__init__()
+
+ # Use estimated batch norm or normal batch norm.
+ if estimate:
+ self.batch_norm = EstimatedBatchNorm(channels,
+ eps=eps, momentum=momentum)
+ else:
+ self.batch_norm = BatchNorm(channels,
+ eps=eps, momentum=momentum)
+
+ # Channel normalization
+ self.channel_norm = ChannelNorm(channels, groups, eps)
+
+ def __call__(self, x):
+ x = self.batch_norm(x)
+ return self.channel_norm(x)
+
+
+class EstimatedBatchNorm(Module):
+ """
+ ## Estimated Batch Normalization
+
+ The input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
+ where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
+ $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$ are learnable scale and shift parameters.
+
+ $$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C
+ \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
+ + \beta_C$$
+
+ where,
+
+ \begin{align}
+ \hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
+ \hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
+ \end{align}
+
+ are the running mean and variances. $r$ is the momentum for calculating the exponential mean.
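+
+ For instance, a toy sketch of this update with scalar statistics (the numbers are made up;
+ `r` plays the role of the `momentum` argument below):
+
+ ```python
+ r = 0.1
+ exp_mean, exp_var = 0.0, 1.0          # running estimates
+ batch_mean, batch_var = 0.5, 2.0      # statistics of the current mini-batch
+ exp_mean = (1 - r) * exp_mean + r * batch_mean  # -> 0.05
+ exp_var = (1 - r) * exp_var + r * batch_var     # -> 1.1
+ ```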
+ """
+ def __init__(self, channels: int,
+ eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
+ """
+ * `channels` is the number of features in the input
+ * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
+ * `momentum` is the momentum in taking the exponential moving average
+ * `affine` is whether to scale and shift the normalized value
+ """
+ super().__init__()
+
+ self.eps = eps
+ self.momentum = momentum
+ self.affine = affine
+ self.channels = channels
+
+ # Channel wise transformation parameters
+ if self.affine:
+ self.scale = nn.Parameter(torch.ones(channels))
+ self.shift = nn.Parameter(torch.zeros(channels))
+
+ # Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
+ self.register_buffer('exp_mean', torch.zeros(channels))
+ self.register_buffer('exp_var', torch.ones(channels))
+
+ def __call__(self, x: torch.Tensor):
+ """
+ `x` is a tensor of shape `[batch_size, channels, *]`.
+ `*` denotes any number of (possibly 0) dimensions.
+ For example, in an image (2D) convolution this will be
+ `[batch_size, channels, height, width]`
+ """
+ # Keep old shape
+ x_shape = x.shape
+ # Get the batch size
+ batch_size = x_shape[0]
+
+ # Sanity check to make sure the number of features is correct
+ assert self.channels == x.shape[1]
+
+ # Reshape into `[batch_size, channels, n]`
+ x = x.view(batch_size, self.channels, -1)
+
+ # Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only
+ if self.training:
+ # No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
+ with torch.no_grad():
+ # Calculate the mean across first and last dimensions;
+ # $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$
+ mean = x.mean(dim=[0, 2])
+ # Calculate the squared mean across first and last dimensions;
+ # $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$
+ mean_x2 = (x ** 2).mean(dim=[0, 2])
+ # Variance for each feature $\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$
+ var = mean_x2 - mean ** 2
+
+ # Update exponential moving averages
+ # \begin{align}
+ # \hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
+ # \hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
+ # \end{align}
+ self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
+ self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
+
+ # Normalize
+ # $$\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}$$
+ x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
+ # Scale and shift
+ # $$ \gamma_C
+ # \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
+ # + \beta_C$$
+ if self.affine:
+ x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+
+ # Reshape to original and return
+ return x_norm.view(x_shape)
+
+
+class ChannelNorm(Module):
+ """
+ ## Channel Normalization
+
+ This is similar to [Group Normalization](../group_norm/index.html) but the affine transform is done group-wise.
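+
+ For example (the sizes are illustrative only, assuming `channels=64` and `groups=32`):
+
+ ```python
+ cn = ChannelNorm(channels=64, groups=32)
+ cn.scale.shape  # torch.Size([32]) -- one scale per group
+ cn.shift.shape  # torch.Size([32]) -- group norm keeps one per channel (64) instead
+ ```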
+ """
+
+ def __init__(self, channels, groups,
+ eps: float = 1e-5, affine: bool = True):
+ """
+ * `groups` is the number of groups the features are divided into
+ * `channels` is the number of features in the input
+ * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
+ * `affine` is whether to scale and shift the normalized value
+ """
+ super().__init__()
+ self.channels = channels
+ self.groups = groups
+ self.eps = eps
+ self.affine = affine
+ # Parameters for affine transformation.
+ #
+ # *Note that these transforms are per group, unlike in group norm where
+ # they are applied channel-wise.*
+ if self.affine:
+ self.scale = nn.Parameter(torch.ones(groups))
+ self.shift = nn.Parameter(torch.zeros(groups))
+
+ def __call__(self, x: torch.Tensor):
+ """
+ `x` is a tensor of shape `[batch_size, channels, *]`.
+ `*` denotes any number of (possibly 0) dimensions.
+ For example, in an image (2D) convolution this will be
+ `[batch_size, channels, height, width]`
+ """
+
+ # Keep the original shape
+ x_shape = x.shape
+ # Get the batch size
+ batch_size = x_shape[0]
+ # Sanity check to make sure the number of features is the same
+ assert self.channels == x.shape[1]
+
+ # Reshape into `[batch_size, groups, n]`
+ x = x.view(batch_size, self.groups, -1)
+
+ # Calculate the mean across last dimension;
+ # i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$
+ mean = x.mean(dim=[-1], keepdim=True)
+ # Calculate the squared mean across last dimension;
+ # i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$
+ mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
+ # Variance for each sample and feature group
+ # $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
+ var = mean_x2 - mean ** 2
+
+ # Normalize
+ # $$\hat{x}_{(i_N, i_G)} =
+ # \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$$
+ x_norm = (x - mean) / torch.sqrt(var + self.eps)
+
+ # Scale and shift group-wise
+ # $$y_{i_G} =\gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}$$
+ if self.affine:
+ x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+
+ # Reshape to original and return
+ return x_norm.view(x_shape)
diff --git a/labml_nn/normalization/group_norm/__init__.py b/labml_nn/normalization/group_norm/__init__.py
index a5e9c310..df81b851 100644
--- a/labml_nn/normalization/group_norm/__init__.py
+++ b/labml_nn/normalization/group_norm/__init__.py
@@ -81,7 +81,6 @@ Here's a [CIFAR 10 classification model](experiment.html) that uses instance nor
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb)
[](https://app.labml.ai/run/081d950aa4e011eb8f9f0242ac1c0002)
[](https://wandb.ai/vpj/cifar10/runs/310etthp)
-
"""
import torch
diff --git a/labml_nn/normalization/weight_standardization/__init__.py b/labml_nn/normalization/weight_standardization/__init__.py
new file mode 100644
index 00000000..753e78cd
--- /dev/null
+++ b/labml_nn/normalization/weight_standardization/__init__.py
@@ -0,0 +1,86 @@
+"""
+---
+title: Weight Standardization
+summary: >
+ A PyTorch implementation/tutorial of Weight Standardization.
+---
+
+# Weight Standardization
+
+This is a [PyTorch](https://pytorch.org) implementation of Weight Standardization from the paper
+ [Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://arxiv.org/abs/1903.10520).
+We also have an [annotated implementation of Batch-Channel Normalization](../batch_channel_norm/index.html).
+
+Batch normalization **gives a smooth loss landscape** and
+**avoids elimination singularities**.
+Elimination singularities are nodes of the network that become
+useless (e.g. a ReLU that gives 0 all the time).
+
+However, batch normalization doesn't work well when the batch size is too small,
+which happens when training large networks because of device memory limitations.
+The paper introduces Weight Standardization with Batch-Channel Normalization as
+a better alternative.
+
+Weight Standardization:
+1. Normalizes the gradients
+2. Smoothes the landscape (reduced Lipschitz constant)
+3. Avoids elimination singularities
+
+The Lipschitz constant is the maximum slope a function has between two points.
+That is, the Lipschitz constant $L$ is the smallest value that satisfies
+$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$,
+where $f: A \rightarrow \mathbb{R}^m$ and $A \subseteq \mathbb{R}^n$.
+
+Elimination singularities are avoided because weight standardization keeps the statistics of the outputs similar to
+those of the inputs. So as long as the inputs are normally distributed, the outputs remain close to normal.
+This prevents the outputs of nodes from always falling outside the active range of the activation function
+(e.g. an always-negative input to a ReLU).
+
+*[Refer to the paper for proofs](https://arxiv.org/abs/1903.10520)*.
+
+Here is [the training code](experiment.html) for training
+a VGG network that uses weight standardization to classify CIFAR-10 data.
+This uses a [2D-Convolution Layer with Weight Standardization](../conv2d.html).
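+
+As a minimal sketch of how that layer might be used (the channel sizes and input shape
+here are just an illustration):
+
+```python
+import torch
+from labml_nn.normalization.weight_standardization.conv2d import Conv2d
+
+# The weights are standardized on every forward pass, before the convolution
+conv = Conv2d(3, 64, kernel_size=3, padding=1)
+y = conv(torch.randn(8, 3, 32, 32))  # -> [8, 64, 32, 32]
+```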
+
+[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb)
+[](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002)
+[](https://wandb.ai/vpj/cifar10/runs/3flr4k8w)
+"""
+
+import torch
+
+
+def weight_standardization(weight: torch.Tensor, eps: float):
+ r"""
+ ## Weight Standardization
+
+ $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$
+
+ where,
+
+ \begin{align}
+ W &\in \mathbb{R}^{O \times I} \\
+ \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
+ \sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon} \\
+ \end{align}
+
+ For a 2D-convolution layer, $O$ is the number of output channels ($O = C_{out}$)
+ and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$).
+ """
+
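+ # For example (illustrative only): a 2D-convolution weight of shape
+ # `[64, 3, 3, 3]` is treated as $W \in \mathbb{R}^{O \times I}$ with $O = 64$ rows and
+ # $I = 3 \times 3 \times 3 = 27$ columns, and each row is normalized to zero mean and unit variance.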
+ # Get $C_{out}$, $C_{in}$ and kernel shape
+ c_out, c_in, *kernel_shape = weight.shape
+ # Reshape $W$ to $O \times I$
+ weight = weight.view(c_out, -1)
+ # Calculate
+ #
+ # \begin{align}
+ # \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
+ # \sigma^2_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}}
+ # \end{align}
+ var, mean = torch.var_mean(weight, dim=1, keepdim=True)
+ # Normalize
+ # $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$
+ weight = (weight - mean) / (torch.sqrt(var + eps))
+ # Change back to original shape and return
+ return weight.view(c_out, c_in, *kernel_shape)
diff --git a/labml_nn/normalization/weight_standardization/conv2d.py b/labml_nn/normalization/weight_standardization/conv2d.py
new file mode 100644
index 00000000..150c839b
--- /dev/null
+++ b/labml_nn/normalization/weight_standardization/conv2d.py
@@ -0,0 +1,60 @@
+"""
+---
+title: 2D Convolution Layer with Weight Standardization
+summary: >
+ A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization.
+---
+
+# 2D Convolution Layer with Weight Standardization
+
+This is an implementation of a 2 dimensional convolution layer with [Weight Standardization](./index.html)
+"""
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from labml_nn.normalization.weight_standardization import weight_standardization
+
+
+class Conv2d(nn.Conv2d):
+ """
+ ## 2D Convolution Layer
+
+ This extends the standard 2D Convolution layer and standardizes the weights before the convolution step.
+ """
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ eps: float = 1e-5):
+ super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ padding_mode=padding_mode)
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor):
+ return F.conv2d(x, weight_standardization(self.weight, self.eps), self.bias, self.stride,
+ self.padding, self.dilation, self.groups)
+
+
+def _test():
+ """
+ A simple test to verify the tensor sizes
+ """
+ conv2d = Conv2d(10, 20, 5)
+ from labml.logger import inspect
+ inspect(conv2d.weight)
+ import torch
+ inspect(conv2d(torch.zeros(10, 10, 100, 100)))
+
+
+if __name__ == '__main__':
+ _test()
diff --git a/labml_nn/normalization/weight_standardization/experiment.py b/labml_nn/normalization/weight_standardization/experiment.py
new file mode 100644
index 00000000..22e64c40
--- /dev/null
+++ b/labml_nn/normalization/weight_standardization/experiment.py
@@ -0,0 +1,76 @@
+"""
+---
+title: CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization
+summary: >
+ This trains a VGG net that uses weight standardization and batch-channel normalization
+ to classify CIFAR10 images.
+---
+
+# CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization
+"""
+
+import torch.nn as nn
+
+from labml import experiment
+from labml.configs import option
+from labml_helpers.module import Module
+from labml_nn.experiments.cifar10 import CIFAR10Configs
+from labml_nn.normalization.batch_channel_norm import BatchChannelNorm
+from labml_nn.normalization.weight_standardization.conv2d import Conv2d
+
+
+class Model(Module):
+ """
+ ### Model
+
+ A VGG model that uses [Weight Standardization](./index.html) and
+ [Batch-Channel Normalization](../batch_channel_norm/index.html).
+ """
+ def __init__(self):
+ super().__init__()
+ layers = []
+ in_channels = 3
+ for block in [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]:
+ for channels in block:
+ layers += [Conv2d(in_channels, channels, kernel_size=3, padding=1),
+ BatchChannelNorm(channels, 32),
+ nn.ReLU(inplace=True)]
+ in_channels = channels
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+ layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
+ self.layers = nn.Sequential(*layers)
+ self.fc = nn.Linear(512, 10)
+
+ def __call__(self, x):
+ x = self.layers(x)
+ x = x.view(x.shape[0], -1)
+ return self.fc(x)
+
+
+@option(CIFAR10Configs.model)
+def model(c: CIFAR10Configs):
+ """
+ ### Create model
+ """
+ return Model().to(c.device)
+
+
+def main():
+ # Create experiment
+ experiment.create(name='cifar10', comment='weight standardization')
+ # Create configurations
+ conf = CIFAR10Configs()
+ # Load configurations
+ experiment.configs(conf, {
+ 'optimizer.optimizer': 'Adam',
+ 'optimizer.learning_rate': 2.5e-4,
+ 'train_batch_size': 64,
+ })
+ # Start the experiment and run the training loop
+ with experiment.start():
+ conf.run()
+
+
+#
+if __name__ == '__main__':
+ main()
diff --git a/setup.py b/setup.py
index 8914903f..a0687608 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
- version='0.4.96',
+ version='0.4.97',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",