diff --git a/docs/normalization/group_norm/index.html b/docs/normalization/group_norm/index.html
index 0bd34d96..56e8fd46 100644
--- a/docs/normalization/group_norm/index.html
+++ b/docs/normalization/group_norm/index.html
@@ -73,12 +73,74 @@
#
Group Normalization
+This is a PyTorch implementation of
+the paper Group Normalization.
+Batch Normalization works well for sufficiently large batch sizes,
+but does not perform well for small batch sizes, because it normalizes across the batch.
+Training large models with large batch sizes is often not possible due to the memory capacity of
+the devices, which forces the use of small batch sizes.
+This paper introduces Group Normalization, which normalizes a set of features together as a group.
+This is based on the observation that classical features such as
+SIFT and
+HOG are group-wise features.
+The paper proposes dividing feature channels into groups and then separately normalizing
+all channels within each group.
+Formulation
+All normalization layers can be defined by the following computation.
+$$\hat{x}_i = \frac{1}{\sigma_i} (x_i - \mu_i)$$
+where $x$ is the tensor representing the batch,
+and $i$ is the index of a single value.
+For instance, for 2D images,
+$i = (i_N, i_C, i_H, i_W)$ is a 4-d vector indexing the
+image within the batch, the feature channel, and the vertical and horizontal coordinates.
+$\mu_i$ and $\sigma_i$ are the mean and standard deviation.
+\begin{align}
+\mu_i &= \frac{1}{m} \sum_{k \in \mathcal{S}_i} x_k \\
+\sigma_i &= \sqrt{\frac{1}{m} \sum_{k \in \mathcal{S}_i} (x_k - \mu_i)^2 + \epsilon}
+\end{align}
+$\mathcal{S}_i$ is the set of indexes across which the mean and standard deviation
+are calculated for index $i$.
+$m$ is the size of the set $\mathcal{S}_i$, which is the same for all $i$.
+The definition of $\mathcal{S}_i$ is different for
+Batch normalization,
+Layer normalization, and
+Instance normalization.
+Batch Normalization
+$$\mathcal{S}_i = \{k | k_C = i_C\}$$
+The values that share the same feature channel are normalized together.
+Layer Normalization
+$$\mathcal{S}_i = \{k | k_N = i_N\}$$
+The values from the same sample in the batch are normalized together.
+Instance Normalization
+$$\mathcal{S}_i = \{k | k_N = i_N, k_C = i_C\}$$
+The values from the same sample and same feature channel are normalized together.
+Group Normalization
+$$\mathcal{S}_i = \{k | k_N = i_N,
+ \bigg \lfloor \frac{k_C}{C/G} \bigg \rfloor = \bigg \lfloor \frac{i_C}{C/G} \bigg \rfloor\}$$
+where $G$ is the number of groups and $C$ is the number of channels.
+Group normalization normalizes values of the same sample and the same group of channels together.
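+As a rough illustration of the floor-based grouping above (not from the paper;
+C = 6 channels and G = 2 groups are arbitrary example sizes), the group index of
+each channel can be computed directly:
+
+```python
+# Channels that share the same group index are normalized together.
+C, G = 6, 2                          # number of channels and groups (example values)
+channels_per_group = C // G          # C / G in the formula above
+group_of = [c // channels_per_group for c in range(C)]
+print(group_of)                      # [0, 0, 0, 1, 1, 1]
+```
+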
+Here’s a CIFAR 10 classification model that uses group normalization.
+
+
+
-
12import torch
-13from torch import nn
-14
-15from labml_helpers.module import Module
+
87import torch
+88from torch import nn
+89
+90from labml_helpers.module import Module
@@ -89,7 +151,7 @@
Group Normalization Layer
-
18class GroupNorm(Module):
+
93class GroupNorm(Module):
@@ -105,8 +167,8 @@
-
23 def __init__(self, groups: int, channels: int, *,
-24 eps: float = 1e-5, affine: bool = True):
+
98 def __init__(self, groups: int, channels: int, *,
+99 eps: float = 1e-5, affine: bool = True):
@@ -117,14 +179,14 @@
-
31 super().__init__()
-32
-33 assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
-34 self.groups = groups
-35 self.channels = channels
-36
-37 self.eps = eps
-38 self.affine = affine
+
106 super().__init__()
+107
+108 assert channels % groups == 0, "Number of channels should be evenly divisible by the number of groups"
+109 self.groups = groups
+110 self.channels = channels
+111
+112 self.eps = eps
+113 self.affine = affine
@@ -135,9 +197,9 @@
Create parameters for $\gamma$ and $\beta$ for scale and shift
-
40 if self.affine:
-41 self.scale = nn.Parameter(torch.ones(channels))
-42 self.shift = nn.Parameter(torch.zeros(channels))
+
115 if self.affine:
+116 self.scale = nn.Parameter(torch.ones(channels))
+117 self.shift = nn.Parameter(torch.zeros(channels))
@@ -151,7 +213,7 @@
[batch_size, channels, height, width]
-
44 def forward(self, x: torch.Tensor):
+
119 def forward(self, x: torch.Tensor):
@@ -162,7 +224,7 @@
Keep the original shape
@@ -173,7 +235,7 @@
Get the batch size
-
54 batch_size = x_shape[0]
+
129 batch_size = x_shape[0]
@@ -184,7 +246,7 @@
Sanity check to make sure the number of features is the same
-
56 assert self.channels == x.shape[1]
+
131 assert self.channels == x.shape[1]
@@ -192,10 +254,10 @@
-
Reshape into [batch_size, channels, n]
+
Reshape into [batch_size, groups, n]
-
59 x = x.view(batch_size, self.groups, -1)
+
134 x = x.view(batch_size, self.groups, -1)
@@ -203,11 +265,11 @@
-
Calculate the mean across first and last dimension;
-i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
+
Calculate the mean across last dimension;
+i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$
-
63 mean = x.mean(dim=[2], keepdim=True)
+
138 mean = x.mean(dim=[-1], keepdim=True)
@@ -215,11 +277,11 @@ i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
-
Calculate the squared mean across first and last dimension;
-i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
+
Calculate the squared mean across last dimension;
+i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$
-
66 mean_x2 = (x ** 2).mean(dim=[2], keepdim=True)
+
141 mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
@@ -227,10 +289,11 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
-
Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
+
Variance for each sample and feature group
+$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
-
68 var = mean_x2 - mean ** 2
+
144 var = mean_x2 - mean ** 2
@@ -238,12 +301,13 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
-
Normalize
+
Normalize
+
-
71 x_norm = (x - mean) / torch.sqrt(var + self.eps)
-72 x_norm = x_norm.view(batch_size, self.channels, -1)
+
149 x_norm = (x - mean) / torch.sqrt(var + self.eps)
@@ -251,12 +315,14 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
-
Scale and shift
+
Scale and shift channel-wise
+
-
75 if self.affine:
-76 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+
153 if self.affine:
+154 x_norm = x_norm.view(batch_size, self.channels, -1)
+155 x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
@@ -267,7 +333,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
Reshape to original and return
-
79 return x_norm.view(x_shape)
+
158 return x_norm.view(x_shape)
@@ -278,7 +344,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
Simple test
@@ -289,14 +355,14 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
-
86 from labml.logger import inspect
-87
-88 x = torch.zeros([2, 6, 2, 4])
-89 inspect(x.shape)
-90 bn = GroupNorm(2, 6)
-91
-92 x = bn(x)
-93 inspect(x.shape)
+
165 from labml.logger import inspect
+166
+167 x = torch.zeros([2, 6, 2, 4])
+168 inspect(x.shape)
+169 bn = GroupNorm(2, 6)
+170
+171 x = bn(x)
+172 inspect(x.shape)
@@ -307,8 +373,8 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
-
97if __name__ == '__main__':
-98 _test()
+
176if __name__ == '__main__':
+177 _test()
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index f282d685..e7a2db5c 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -162,14 +162,21 @@
https://nn.labml.ai/normalization/instance_norm/index.html
- 2021-04-20T16:30:00+00:00
+ 2021-04-23T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/normalization/instance_norm/readme.html
+ 2021-04-23T16:30:00+00:00
1.00
https://nn.labml.ai/normalization/instance_norm/experiment.html
- 2021-04-20T16:30:00+00:00
+ 2021-04-23T16:30:00+00:00
1.00
diff --git a/labml_nn/normalization/group_norm/__init__.py b/labml_nn/normalization/group_norm/__init__.py
index 00ff3f47..2f4ce810 100644
--- a/labml_nn/normalization/group_norm/__init__.py
+++ b/labml_nn/normalization/group_norm/__init__.py
@@ -7,6 +7,81 @@ summary: >
# Group Normalization
+This is a [PyTorch](https://pytorch.org) implementation of
+the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
+
+[Batch Normalization](../batch_norm/index.html) works well for sufficiently large batch sizes,
+but does not perform well for small batch sizes, because it normalizes across the batch.
+Training large models with large batch sizes is often not possible due to the memory capacity of
+the devices, which forces the use of small batch sizes.
+
+This paper introduces Group Normalization, which normalizes a set of features together as a group.
+This is based on the observation that classical features such as
+[SIFT](https://en.wikipedia.org/wiki/Scale-invariant_feature_transform) and
+[HOG](https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients) are group-wise features.
+The paper proposes dividing feature channels into groups and then separately normalizing
+all channels within each group.
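+
+As a minimal usage sketch (not from the paper; the layer sizes and the choice of 32 groups
+are arbitrary illustrative values), `torch.nn.GroupNorm` — or the `GroupNorm` module defined
+below — can stand in for batch normalization so the statistics no longer depend on the batch size:
+
+```python
+import torch
+from torch import nn
+
+# A small convolutional block normalized per sample and per channel group,
+# so its statistics do not depend on how many samples are in the batch.
+block = nn.Sequential(
+    nn.Conv2d(3, 64, kernel_size=3, padding=1),
+    nn.GroupNorm(num_groups=32, num_channels=64),
+    nn.ReLU(),
+)
+
+x = torch.randn(2, 3, 32, 32)  # works the same even with a tiny batch
+print(block(x).shape)          # torch.Size([2, 64, 32, 32])
+```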
+
+## Formulation
+
+All normalization layers can be defined by the following computation.
+
+$$\hat{x}_i = \frac{1}{\sigma_i} (x_i - \mu_i)$$
+
+where $x$ is the tensor representing the batch,
+and $i$ is the index of a single value.
+For instance, for 2D images,
+$i = (i_N, i_C, i_H, i_W)$ is a 4-d vector indexing the
+image within the batch, the feature channel, and the vertical and horizontal coordinates.
+$\mu_i$ and $\sigma_i$ are the mean and standard deviation.
+
+\begin{align}
+\mu_i &= \frac{1}{m} \sum_{k \in \mathcal{S}_i} x_k \\
+\sigma_i &= \sqrt{\frac{1}{m} \sum_{k \in \mathcal{S}_i} (x_k - \mu_i)^2 + \epsilon}
+\end{align}
+
+$\mathcal{S}_i$ is the set of indexes across which the mean and standard deviation
+are calculated for index $i$.
+$m$ is the size of the set $\mathcal{S}_i$, which is the same for all $i$.
+
+The definition of $\mathcal{S}_i$ is different for
+[Batch normalization](../batch_norm/index.html),
+[Layer normalization](../layer_norm/index.html), and
+[Instance normalization](../instance_norm/index.html).
+
+### [Batch Normalization](../batch_norm/index.html)
+
+$$\mathcal{S}_i = \{k | k_C = i_C\}$$
+
+The values that share the same feature channel are normalized together.
+
+### [Layer Normalization](../layer_norm/index.html)
+
+$$\mathcal{S}_i = \{k | k_N = i_N\}$$
+
+The values from the same sample in the batch are normalized together.
+
+### [Instance Normalization](../instance_norm/index.html)
+
+$$\mathcal{S}_i = \{k | k_N = i_N, k_C = i_C\}$$
+
+The values from the same sample and same feature channel are normalized together.
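+
+As a concrete sketch of how these three index sets differ (illustrative only; the tensor
+sizes are arbitrary), each normalization reduces over a different set of axes of an
+`[N, C, H, W]` tensor:
+
+```python
+import torch
+
+x = torch.randn(4, 6, 8, 8)  # [N, C, H, W]
+
+# Batch norm: per channel, shared across the batch and spatial dimensions
+bn_mean = x.mean(dim=(0, 2, 3), keepdim=True)  # shape [1, C, 1, 1]
+# Layer norm: per sample, across all channels and spatial dimensions
+ln_mean = x.mean(dim=(1, 2, 3), keepdim=True)  # shape [N, 1, 1, 1]
+# Instance norm: per sample and per channel, across spatial dimensions only
+in_mean = x.mean(dim=(2, 3), keepdim=True)     # shape [N, C, 1, 1]
+```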
+
+### Group Normalization
+
+$$\mathcal{S}_i = \{k | k_N = i_N,
+ \bigg \lfloor \frac{k_C}{C/G} \bigg \rfloor = \bigg \lfloor \frac{i_C}{C/G} \bigg \rfloor\}$$
+
+where $G$ is the number of groups and $C$ is the number of channels.
+
+Group normalization normalizes values of the same sample and the same group of channels together.
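+
+A minimal sketch of this computation (assuming arbitrary example sizes `N=2`, `C=6`, `G=3`),
+checked against PyTorch's built-in `torch.nn.functional.group_norm`:
+
+```python
+import torch
+import torch.nn.functional as F
+
+N, C, G, H, W = 2, 6, 3, 4, 4
+x = torch.randn(N, C, H, W)
+
+# Each group of C // G consecutive channels shares one mean and variance per sample
+g = x.view(N, G, -1)
+mean = g.mean(dim=-1, keepdim=True)
+var = g.var(dim=-1, unbiased=False, keepdim=True)
+manual = ((g - mean) / torch.sqrt(var + 1e-5)).view(N, C, H, W)
+
+# Built-in functional version (no affine scale/shift here)
+builtin = F.group_norm(x, num_groups=G, eps=1e-5)
+print(torch.allclose(manual, builtin, atol=1e-5))  # True
+```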
+
+Here's a [CIFAR 10 classification model](experiment.html) that uses group normalization.
+
+[Open In Colab](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb)
+[View Run](https://app.labml.ai/run/011254fe647011ebbb8e0242ac1c0002)
+
"""
import torch
@@ -55,24 +130,28 @@ class GroupNorm(Module):
# Sanity check to make sure the number of features is the same
assert self.channels == x.shape[1]
- # Reshape into `[batch_size, channels, n]`
+ # Reshape into `[batch_size, groups, n]`
x = x.view(batch_size, self.groups, -1)
- # Calculate the mean across first and last dimension;
- # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
- mean = x.mean(dim=[2], keepdim=True)
- # Calculate the squared mean across first and last dimension;
- # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
- mean_x2 = (x ** 2).mean(dim=[2], keepdim=True)
- # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
+ # Calculate the mean across last dimension;
+ # i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$
+ mean = x.mean(dim=[-1], keepdim=True)
+ # Calculate the squared mean across last dimension;
+ # i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$
+ mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
+ # Variance for each sample and feature group
+ # $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
var = mean_x2 - mean ** 2
- # Normalize $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$$
+ # Normalize
+ # $$\hat{x}_{(i_N, i_G)} =
+ # \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$$
x_norm = (x - mean) / torch.sqrt(var + self.eps)
- x_norm = x_norm.view(batch_size, self.channels, -1)
- # Scale and shift $$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
+ # Scale and shift channel-wise
+ # $$y_{i_C} =\gamma_{i_C} \hat{x}_{i_C} + \beta_{i_C}$$
if self.affine:
+ x_norm = x_norm.view(batch_size, self.channels, -1)
x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
# Reshape to original and return
diff --git a/setup.py b/setup.py
index c3780c22..8914903f 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
- version='0.4.95',
+ version='0.4.96',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",