📚 group norm

This commit is contained in:
Varuna Jayasiri
2021-04-24 14:44:38 +05:30
parent 21dc7d6302
commit e7e817ce20
4 changed files with 216 additions and 64 deletions

View File

@@ -73,12 +73,74 @@
<a href='#section-0'>#</a>
</div>
<h1>Group Normalization</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of
the paper <a href="https://arxiv.org/abs/1803.08494">Group Normalization</a>.</p>
<p><a href="../batch_norm/index.html">Batch Normalization</a> works well for sufficiently large batch sizes
but performs poorly for small batch sizes, because it normalizes across the batch.
Training large models with large batch sizes is often not feasible because of device memory limits,
so small batch sizes are common in practice.</p>
<p>This paper introduces Group Normalization, which normalizes a set of features together as a group.
This is based on the observation that classical features such as
<a href="https://en.wikipedia.org/wiki/Scale-invariant_feature_transform">SIFT</a> and
<a href="https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients">HOG</a> are group-wise features.
The paper proposes dividing feature channels into groups and then separately normalizing
all channels within each group.</p>
<h2>Formulation</h2>
<p>All normalization layers can be defined by the following computation.</p>
<p>
<script type="math/tex; mode=display">\hat{x}_i = \frac{1}{\sigma_i} (x_i - \mu_i)</script>
</p>
<p>where $x$ is the tensor representing the batch,
and $i$ is the index of a single value.
For instance, for 2D images,
$i = (i_N, i_C, i_H, i_W)$ is a 4-d vector indexing the
image within the batch, the feature channel, and the vertical and horizontal coordinates.
$\mu_i$ and $\sigma_i$ are the mean and standard deviation.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\mu_i &= \frac{1}{m} \sum_{k \in \mathcal{S}_i} x_k \\
\sigma_i &= \sqrt{\frac{1}{m} \sum_{k \in \mathcal{S}_i} (x_k - \mu_i)^2 + \epsilon}
\end{align}</script>
</p>
<p>$\mathcal{S}_i$ is the set of indexes across which the mean and standard deviation
are calculated for index $i$.
$m$ is the size of the set $\mathcal{S}_i$, which is the same for all $i$.</p>
<p>The definition of $\mathcal{S}_i$ is different for
<a href="../batch_norm/index.html">Batch normalization</a>,
<a href="../layer_norm/index.html">Layer normalization</a>, and
<a href="../instance_norm/index.html">Instance normalization</a>.</p>
<h3><a href="../batch_norm/index.html">Batch Normalization</a></h3>
<p>
<script type="math/tex; mode=display">\mathcal{S}_i = \{k | k_C = i_C\}</script>
</p>
<p>The values that share the same feature channel are normalized together.</p>
<h3><a href="../layer_norm/index.html">Layer Normalization</a></h3>
<p>
<script type="math/tex; mode=display">\mathcal{S}_i = \{k | k_N = i_N\}</script>
</p>
<p>The values from the same sample in the batch are normalized together.</p>
<h3><a href="../instance_norm/index.html">Instance Normalization</a></h3>
<p>
<script type="math/tex; mode=display">\mathcal{S}_i = \{k | k_N = i_N, k_C = i_C\}</script>
</p>
<p>The values from the same sample and same feature channel are normalized together.</p>
<h3>Group Normalization</h3>
<p>
<script type="math/tex; mode=display">\mathcal{S}_i = \{k | k_N = i_N,
\bigg \lfloor \frac{k_C}{C/G} \bigg \rfloor = \bigg \lfloor \frac{i_C}{C/G} \bigg \rfloor\}</script>
</p>
<p>where $G$ is the number of groups and $C$ is the number of channels.</p>
<p>Group normalization normalizes values of the same sample and the same group of channels together.</p>
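As a quick illustration of the set definition above, channel $k_C$ belongs to group $\lfloor k_C / (C/G) \rfloor$. A minimal plain-Python sketch (not part of the implementation below):

```python
# Group index for each channel: floor(c / (C / G)).
# With C = 6 channels and G = 2 groups, channels 0-2 form group 0
# and channels 3-5 form group 1.
C, G = 6, 2
channels_per_group = C // G  # assumes C is divisible by G

group_of = [c // channels_per_group for c in range(C)]
print(group_of)  # [0, 0, 0, 1, 1, 1]
```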
<p>Here&rsquo;s a <a href="experiment.html">CIFAR 10 classification model</a> that uses group normalization.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/011254fe647011ebbb8e0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a>
<a href="https://app.labml.ai/run/011254fe647011ebbb8e0242ac1c0002"><img alt="WandB" src="https://img.shields.io/badge/wandb-run-yellow" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">14</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
<div class="highlight"><pre><span class="lineno">87</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">88</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">89</span>
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@@ -89,7 +151,7 @@
<h2>Group Normalization Layer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">GroupNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">93</span><span class="k">class</span> <span class="nc">GroupNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@@ -105,8 +167,8 @@
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">23</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">24</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">98</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">99</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@@ -117,14 +179,14 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">32</span>
<span class="lineno">33</span> <span class="k">assert</span> <span class="n">channels</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;Number of channels should be evenly divisible by the number of groups&quot;</span>
<span class="lineno">34</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">36</span>
<span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">38</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
<div class="highlight"><pre><span class="lineno">106</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">107</span>
<span class="lineno">108</span> <span class="k">assert</span> <span class="n">channels</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;Number of channels should be evenly divisible by the number of groups&quot;</span>
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">111</span>
<span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@@ -135,9 +197,9 @@
<p>Create parameters for $\gamma$ and $\beta$ for scale and shift</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">115</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@@ -151,7 +213,7 @@
<code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@@ -162,7 +224,7 @@
<p>Keep the original shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@@ -173,7 +235,7 @@
<p>Get the batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@@ -184,7 +246,7 @@
<p>Sanity check to make sure the number of features is the same</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@@ -192,10 +254,10 @@
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Reshape into <code>[batch_size, channels, n]</code></p>
<p>Reshape into <code>[batch_size, groups, n]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@@ -203,11 +265,11 @@
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Calculate the mean across first and last dimension;
i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
<p>Calculate the mean across last dimension;
i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@@ -215,11 +277,11 @@ i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Calculate the squared mean across first and last dimension;
i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Calculate the squared mean across last dimension;
i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@@ -227,10 +289,11 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$</p>
<p>Variance for each sample and feature group
$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@@ -238,12 +301,13 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Normalize <script type="math/tex; mode=display">\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}</script>
<p>Normalize
<script type="math/tex; mode=display">\hat{x}_{(i_N, i_G)} =
\frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>
<span class="lineno">72</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
</div>
</div>
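The statistics steps above (reshape into groups, mean, the $Var[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2$ identity, and normalization) can be re-derived in NumPy as a check. This is a hedged sketch, assuming NumPy is available; it is not part of the implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, channels, height, width = 2, 6, 2, 4
groups = 2

x = rng.standard_normal((batch_size, channels, height, width))

# Reshape so each (sample, group) pair is one row of values.
g = x.reshape(batch_size, groups, -1)

# Mean and E[x^2] across the last dimension, keeping dims for broadcasting.
mean = g.mean(axis=-1, keepdims=True)
mean_x2 = (g ** 2).mean(axis=-1, keepdims=True)

# Var[x] = E[x^2] - E[x]^2, then normalize with a small epsilon.
var = mean_x2 - mean ** 2
eps = 1e-5
g_norm = (g - mean) / np.sqrt(var + eps)

# Each (sample, group) slice now has ~zero mean and ~unit variance.
print(np.abs(g_norm.mean(axis=-1)).max())  # close to 0
```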
<div class='section' id='section-14'>
@@ -251,12 +315,14 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Scale and shift <script type="math/tex; mode=display">y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}</script>
<p>Scale and shift channel-wise
<script type="math/tex; mode=display">y_{i_C} =\gamma_{i_C} \hat{x}_{i_C} + \beta_{i_C}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">76</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">154</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">155</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
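The channel-wise scale and shift works by broadcasting a <code>[channels]</code> parameter against a <code>[batch_size, channels, n]</code> tensor via a <code>(1, -1, 1)</code> view. A small NumPy analogue of that broadcasting step (a sketch, not the library code):

```python
import numpy as np

batch_size, channels, n = 2, 6, 8
x_norm = np.zeros((batch_size, channels, n))

scale = np.arange(1.0, channels + 1)  # gamma, one value per channel
shift = np.arange(0.0, channels)      # beta, one value per channel

# Reshape to (1, channels, 1) so the parameters broadcast over
# the batch and spatial dimensions.
y = scale.reshape(1, -1, 1) * x_norm + shift.reshape(1, -1, 1)

# With x_norm = 0, each channel of the output is just its beta value.
print(y[0, :, 0])  # [0. 1. 2. 3. 4. 5.]
```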
<div class='section' id='section-15'>
@@ -267,7 +333,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@@ -278,7 +344,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Simple test</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">161</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@@ -289,14 +355,14 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
<span class="lineno">87</span>
<span class="lineno">88</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
<span class="lineno">89</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="lineno">90</span> <span class="n">bn</span> <span class="o">=</span> <span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="lineno">91</span>
<span class="lineno">92</span> <span class="n">x</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">93</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">165</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
<span class="lineno">166</span>
<span class="lineno">167</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
<span class="lineno">168</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="lineno">169</span> <span class="n">bn</span> <span class="o">=</span> <span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="lineno">170</span>
<span class="lineno">171</span> <span class="n">x</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">172</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@@ -307,8 +373,8 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">98</span> <span class="n">_test</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">176</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">177</span> <span class="n">_test</span><span class="p">()</span></pre></div>
</div>
</div>
</div>

View File

@@ -162,14 +162,21 @@
<url>
<loc>https://nn.labml.ai/normalization/instance_norm/index.html</loc>
<lastmod>2021-04-20T16:30:00+00:00</lastmod>
<lastmod>2021-04-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/normalization/instance_norm/readme.html</loc>
<lastmod>2021-04-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/normalization/instance_norm/experiment.html</loc>
<lastmod>2021-04-20T16:30:00+00:00</lastmod>
<lastmod>2021-04-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

View File

@@ -7,6 +7,81 @@ summary: >
# Group Normalization
This is a [PyTorch](https://pytorch.org) implementation of
the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
[Batch Normalization](../batch_norm/index.html) works well for sufficiently large batch sizes
but performs poorly for small batch sizes, because it normalizes across the batch.
Training large models with large batch sizes is often not feasible because of device memory limits,
so small batch sizes are common in practice.
This paper introduces Group Normalization, which normalizes a set of features together as a group.
This is based on the observation that classical features such as
[SIFT](https://en.wikipedia.org/wiki/Scale-invariant_feature_transform) and
[HOG](https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients) are group-wise features.
The paper proposes dividing feature channels into groups and then separately normalizing
all channels within each group.
## Formulation
All normalization layers can be defined by the following computation.
$$\hat{x}_i = \frac{1}{\sigma_i} (x_i - \mu_i)$$
where $x$ is the tensor representing the batch,
and $i$ is the index of a single value.
For instance, for 2D images,
$i = (i_N, i_C, i_H, i_W)$ is a 4-d index over
image within batch, feature channel, vertical coordinate, and horizontal coordinate.
$\mu_i$ and $\sigma_i$ are mean and standard deviation.
\begin{align}
\mu_i &= \frac{1}{m} \sum_{k \in \mathcal{S}_i} x_k \\
\sigma_i &= \sqrt{\frac{1}{m} \sum_{k \in \mathcal{S}_i} (x_k - \mu_i)^2 + \epsilon}
\end{align}
$\mathcal{S}_i$ is the set of indexes across which the mean and standard deviation
are calculated for index $i$.
$m$ is the size of the set $\mathcal{S}_i$, which is the same for all $i$.
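To make the formulation concrete, here is the computation for a single toy set $\mathcal{S}_i$ with $m = 4$; the values and the names `s`, `mu`, `sigma`, and `x_hat` are chosen only for illustration and are not part of the paper or this repository:

```python
import math

s = [1.0, 2.0, 3.0, 4.0]  # values indexed by S_i, so m = 4
eps = 1e-5

# Mean over the set: mu_i = (1/m) * sum of x_k for k in S_i
mu = sum(s) / len(s)

# Standard deviation with the epsilon inside the square root
sigma = math.sqrt(sum((v - mu) ** 2 for v in s) / len(s) + eps)

# Normalized values: x_hat_i = (x_i - mu_i) / sigma_i
x_hat = [(v - mu) / sigma for v in s]
```

The normalized values have (approximately) zero mean and unit variance, which is what every layer below achieves; they differ only in how $\mathcal{S}_i$ is chosen.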
The definition of $\mathcal{S}_i$ is different for
[Batch normalization](../batch_norm/index.html),
[Layer normalization](../layer_norm/index.html), and
[Instance normalization](../instance_norm/index.html).
### [Batch Normalization](../batch_norm/index.html)
$$\mathcal{S}_i = \{k | k_C = i_C\}$$
The values that share the same feature channel are normalized together.
### [Layer Normalization](../layer_norm/index.html)
$$\mathcal{S}_i = \{k | k_N = i_N\}$$
The values from the same sample in the batch are normalized together.
### [Instance Normalization](../instance_norm/index.html)
$$\mathcal{S}_i = \{k | k_N = i_N, k_C = i_C\}$$
The values from the same sample and same feature channel are normalized together.
### Group Normalization
$$\mathcal{S}_i = \{k | k_N = i_N,
\bigg \lfloor \frac{k_C}{C/G} \bigg \rfloor = \bigg \lfloor \frac{i_C}{C/G} \bigg \rfloor\}$$
where $G$ is the number of groups and $C$ is the number of channels.
Group normalization normalizes values of the same sample and the same group of channels together.
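As a plain-Python sketch of the definition above (not the module implemented below), group normalization of a `[N, C, L]` tensor groups every $C/G$ consecutive channels and normalizes each (sample, group) slice with its own mean and standard deviation; the function name, toy tensor, and `eps` value here are illustrative assumptions:

```python
import math

def group_norm(x, groups, eps=1e-5):
    """Normalize a nested-list tensor x of shape [N][C][L] per (sample, channel group)."""
    n, c = len(x), len(x[0])
    assert c % groups == 0
    size = c // groups  # channels per group, C / G
    out = [[[0.0] * len(ch) for ch in sample] for sample in x]
    for i in range(n):
        for g in range(groups):
            # All values sharing sample i and group g form the set S_i
            vals = [v for ch in x[i][g * size:(g + 1) * size] for v in ch]
            mu = sum(vals) / len(vals)
            var = sum((v - mu) ** 2 for v in vals) / len(vals)
            sigma = math.sqrt(var + eps)
            for k in range(g * size, (g + 1) * size):
                for j, v in enumerate(x[i][k]):
                    out[i][k][j] = (v - mu) / sigma
    return out

# One sample, 4 channels of 2 values each, split into 2 groups of 2 channels
x = [[[0.0, 2.0], [0.0, 2.0], [1.0, 3.0], [5.0, 7.0]]]
y = group_norm(x, groups=2)
```

With $G = C$ this reduces to instance normalization, and with $G = 1$ to layer normalization, which is why the paper presents group normalization as the intermediate case.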
Here's a [CIFAR 10 classification model](experiment.html) that uses group normalization.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/011254fe647011ebbb8e0242ac1c0002)
[![WandB](https://img.shields.io/badge/wandb-run-yellow)](https://app.labml.ai/run/011254fe647011ebbb8e0242ac1c0002)
"""
import torch
@@ -55,24 +130,28 @@ class GroupNorm(Module):
# Sanity check to make sure the number of features is the same
assert self.channels == x.shape[1]
# Reshape into `[batch_size, channels, n]`
# Reshape into `[batch_size, groups, n]`
x = x.view(batch_size, self.groups, -1)
# Calculate the mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
mean = x.mean(dim=[2], keepdim=True)
# Calculate the squared mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
mean_x2 = (x ** 2).mean(dim=[2], keepdim=True)
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
# Calculate the mean across last dimension;
# i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$
mean = x.mean(dim=[-1], keepdim=True)
# Calculate the squared mean across last dimension;
# i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$
mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
# Variance for each sample and feature group
# $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
var = mean_x2 - mean ** 2
# Normalize $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$$
# Normalize
# $$\hat{x}_{(i_N, i_G)} =
# \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$$
x_norm = (x - mean) / torch.sqrt(var + self.eps)
x_norm = x_norm.view(batch_size, self.channels, -1)
# Scale and shift $$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
# Scale and shift channel-wise
# $$y_{i_C} =\gamma_{i_C} \hat{x}_{i_C} + \beta_{i_C}$$
if self.affine:
x_norm = x_norm.view(batch_size, self.channels, -1)
x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
# Reshape to original and return
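The forward pass above computes the variance with the identity $Var[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2$, which needs only one pass over the data. A small stand-alone check of that identity (the variable names here are illustrative, not from the module):

```python
import random

random.seed(0)
xs = [random.gauss(0.0, 1.0) for _ in range(1000)]
m = len(xs)

mean = sum(xs) / m
mean_x2 = sum(v * v for v in xs) / m

# One-pass form used in the forward pass: E[x^2] - E[x]^2
var_identity = mean_x2 - mean ** 2
# Two-pass textbook definition of variance
var_direct = sum((v - mean) ** 2 for v in xs) / m
```

In double precision the two agree to within rounding error, though the one-pass form can lose precision when the mean is large relative to the spread of the values.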

View File

@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
version='0.4.95',
version='0.4.96',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",