mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-06 15:22:21 +08:00
📚 group norm
This commit is contained in:
@@ -73,12 +73,74 @@
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Group Normalization</h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of
|
||||
the paper <a href="https://arxiv.org/abs/1803.08494">Group Normalization</a>.</p>
|
||||
<p><a href="../batch_norm/index.html">Batch Normalization</a> works well for sufficiently large batch sizes,
|
||||
but does not perform well for small batch sizes, because it normalizes across the batch.
|
||||
Training large models with large batch sizes is often infeasible because of the limited memory
|
||||
capacity of the devices.</p>
|
||||
<p>This paper introduces Group Normalization, which normalizes a set of features together as a group.
|
||||
This is based on the observation that classical features such as
|
||||
<a href="https://en.wikipedia.org/wiki/Scale-invariant_feature_transform">SIFT</a> and
|
||||
<a href="https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients">HOG</a> are group-wise features.
|
||||
The paper proposes dividing feature channels into groups and then separately normalizing
|
||||
all channels within each group.</p>
|
||||
<h2>Formulation</h2>
|
||||
<p>All normalization layers can be defined by the following computation.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\hat{x}_i = \frac{1}{\sigma_i} (x_i - \mu_i)</script>
|
||||
</p>
|
||||
<p>where $x$ is the tensor representing the batch,
|
||||
and $i$ is the index of a single value.
|
||||
For instance, for 2D images,
|
||||
$i = (i_N, i_C, i_H, i_W)$ is a 4-d vector indexing the
|
||||
image within the batch, the feature channel, and the vertical and horizontal coordinates.
|
||||
$\mu_i$ and $\sigma_i$ are the mean and standard deviation.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\mu_i &= \frac{1}{m} \sum_{k \in \mathcal{S}_i} x_k \\
|
||||
\sigma_i &= \sqrt{\frac{1}{m} \sum_{k \in \mathcal{S}_i} (x_k - \mu_i)^2 + \epsilon}
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>$\mathcal{S}_i$ is the set of indexes across which the mean and standard deviation
|
||||
are calculated for index $i$.
|
||||
$m$ is the size of the set $\mathcal{S}_i$, which is the same for all $i$.</p>
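As a minimal sketch of this generic formula (our illustration, not code from the paper), here each value's set $\mathcal{S}_i$ is simply the last dimension of the tensor:

```python
import torch

# Generic normalization: subtract the mean and divide by the standard
# deviation computed over the set S_i. Here S_i is the last dimension.
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
eps = 1e-5

mu = x.mean(dim=-1, keepdim=True)                                       # mean over S_i
sigma = torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)   # std over S_i
x_hat = (x - mu) / sigma

print(x_hat)  # zero mean, (approximately) unit variance along the last dimension
```

The result has zero mean and near-unit variance along the normalized dimension; the small $\epsilon$ keeps the division numerically stable.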
|
||||
<p>The definition of $\mathcal{S}_i$ is different for
|
||||
<a href="../batch_norm/index.html">Batch normalization</a>,
|
||||
<a href="../layer_norm/index.html">Layer normalization</a>, and
|
||||
<a href="../instance_norm/index.html">Instance normalization</a>.</p>
|
||||
<h3><a href="../batch_norm/index.html">Batch Normalization</a></h3>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\mathcal{S}_i = \{k | k_C = i_C\}</script>
|
||||
</p>
|
||||
<p>The values that share the same feature channel are normalized together.</p>
|
||||
<h3><a href="../layer_norm/index.html">Layer Normalization</a></h3>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\mathcal{S}_i = \{k | k_N = i_N\}</script>
|
||||
</p>
|
||||
<p>The values from the same sample in the batch are normalized together.</p>
|
||||
<h3><a href="../instance_norm/index.html">Instance Normalization</a></h3>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\mathcal{S}_i = \{k | k_N = i_N, k_C = i_C\}</script>
|
||||
</p>
|
||||
<p>The values from the same sample and same feature channel are normalized together.</p>
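The three definitions of $\mathcal{S}_i$ above differ only in which dimensions the statistics are computed over. A short sketch (our own, with an arbitrary input shape) makes the contrast concrete:

```python
import torch

x = torch.randn(2, 6, 4, 4)  # [batch, channels, height, width]

bn_mean = x.mean(dim=(0, 2, 3), keepdim=True)  # batch norm: per feature channel
ln_mean = x.mean(dim=(1, 2, 3), keepdim=True)  # layer norm: per sample
in_mean = x.mean(dim=(2, 3), keepdim=True)     # instance norm: per sample and channel

print(bn_mean.shape)  # torch.Size([1, 6, 1, 1])
print(ln_mean.shape)  # torch.Size([2, 1, 1, 1])
print(in_mean.shape)  # torch.Size([2, 6, 1, 1])
```

The shapes of the resulting statistics show directly which indices each variant shares its mean and standard deviation across.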
|
||||
<h3>Group Normalization</h3>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\mathcal{S}_i = \{k | k_N = i_N,
|
||||
\bigg \lfloor \frac{k_C}{C/G} \bigg \rfloor = \bigg \lfloor \frac{i_C}{C/G} \bigg \rfloor\}</script>
|
||||
</p>
|
||||
<p>where $G$ is the number of groups and $C$ is the number of channels.</p>
|
||||
<p>Group normalization normalizes values of the same sample and the same group of channels together.</p>
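The grouping can be implemented with a single reshape, as the layer below does. A minimal sketch (shapes, group count, and $\epsilon$ are our choices), checked against PyTorch's built-in `torch.nn.GroupNorm`:

```python
import torch

N, C, H, W, G = 2, 6, 4, 4, 3  # batch, channels, height, width, groups
x = torch.randn(N, C, H, W)

# Reshape so each group's values share the last dimension, then normalize.
xg = x.view(N, G, -1)  # C/G channels (and all spatial positions) per group
mean = xg.mean(dim=-1, keepdim=True)
var = xg.var(dim=-1, unbiased=False, keepdim=True)
x_hat = ((xg - mean) / torch.sqrt(var + 1e-5)).view(N, C, H, W)

ref = torch.nn.GroupNorm(G, C, eps=1e-5, affine=False)(x)
print(torch.allclose(x_hat, ref, atol=1e-5))  # True
```

With $G = C$ this reduces to instance normalization, and with $G = 1$ to layer normalization, which is why the paper presents group normalization as sitting between the two.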
|
||||
<p>Here’s a <a href="experiment.html">CIFAR 10 classification model</a> that uses group normalization.</p>
|
||||
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
|
||||
<a href="https://app.labml.ai/run/011254fe647011ebbb8e0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a>
|
||||
<a href="https://app.labml.ai/run/011254fe647011ebbb8e0242ac1c0002"><img alt="WandB" src="https://img.shields.io/badge/wandb-run-yellow" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">14</span>
|
||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">87</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">88</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">89</span>
|
||||
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
@@ -89,7 +151,7 @@
|
||||
<h2>Group Normalization Layer</h2>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">GroupNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">93</span><span class="k">class</span> <span class="nc">GroupNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
@@ -105,8 +167,8 @@
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">23</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||
<span class="lineno">24</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">98</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||
<span class="lineno">99</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
@@ -117,14 +179,14 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">31</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">32</span>
|
||||
<span class="lineno">33</span> <span class="k">assert</span> <span class="n">channels</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"Number of channels should be evenly divisible by the number of groups"</span>
|
||||
<span class="lineno">34</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||||
<span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
|
||||
<span class="lineno">36</span>
|
||||
<span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||||
<span class="lineno">38</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">106</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">107</span>
|
||||
<span class="lineno">108</span> <span class="k">assert</span> <span class="n">channels</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"Number of channels should be evenly divisible by the number of groups"</span>
|
||||
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||||
<span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
|
||||
<span class="lineno">111</span>
|
||||
<span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||||
<span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
@@ -135,9 +197,9 @@
|
||||
<p>Create parameters for $\gamma$ and $\beta$ for scale and shift</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">40</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">115</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
@@ -151,7 +213,7 @@
|
||||
<code>[batch_size, channels, height, width]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
@@ -162,7 +224,7 @@
|
||||
<p>Keep the original shape</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
@@ -173,7 +235,7 @@
|
||||
<p>Get the batch size</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
@@ -184,7 +246,7 @@
|
||||
<p>Sanity check to make sure the number of features is the same</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">56</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
@@ -192,10 +254,10 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>Reshape into <code>[batch_size, channels, n]</code></p>
|
||||
<p>Reshape into <code>[batch_size, groups, n]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
@@ -203,11 +265,11 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Calculate the mean across first and last dimension;
|
||||
i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
|
||||
<p>Calculate the mean across last dimension;
|
||||
i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
@@ -215,11 +277,11 @@ i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Calculate the squared mean across first and last dimension;
|
||||
i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
<p>Calculate the squared mean across last dimension;
|
||||
i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
@@ -227,10 +289,11 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$</p>
|
||||
<p>Variance for each sample and feature group
|
||||
$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
@@ -238,12 +301,13 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Normalize <script type="math/tex; mode=display">\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}</script>
|
||||
<p>Normalize
|
||||
<script type="math/tex; mode=display">\hat{x}_{(i_N, i_G)} =
|
||||
\frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>
|
||||
<span class="lineno">72</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
@@ -251,12 +315,14 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Scale and shift <script type="math/tex; mode=display">y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}</script>
|
||||
<p>Scale and shift channel-wise
|
||||
<script type="math/tex; mode=display">y_{i_C} =\gamma_{i_C} \hat{x}_{i_C} + \beta_{i_C}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">76</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">154</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">155</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
@@ -267,7 +333,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
<p>Reshape to original and return</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">79</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
@@ -278,7 +344,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
<p>Simple test</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">82</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">161</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
@@ -289,14 +355,14 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">86</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
|
||||
<span class="lineno">87</span>
|
||||
<span class="lineno">88</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
|
||||
<span class="lineno">89</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||||
<span class="lineno">90</span> <span class="n">bn</span> <span class="o">=</span> <span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||||
<span class="lineno">91</span>
|
||||
<span class="lineno">92</span> <span class="n">x</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="lineno">93</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">165</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
|
||||
<span class="lineno">166</span>
|
||||
<span class="lineno">167</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
|
||||
<span class="lineno">168</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||||
<span class="lineno">169</span> <span class="n">bn</span> <span class="o">=</span> <span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||||
<span class="lineno">170</span>
|
||||
<span class="lineno">171</span> <span class="n">x</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="lineno">172</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
@@ -307,8 +373,8 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">97</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">98</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">176</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">177</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
@ -162,14 +162,21 @@

<url>
<loc>https://nn.labml.ai/normalization/instance_norm/index.html</loc>
<lastmod>2021-04-20T16:30:00+00:00</lastmod>
<lastmod>2021-04-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

<url>
<loc>https://nn.labml.ai/normalization/instance_norm/readme.html</loc>
<lastmod>2021-04-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

<url>
<loc>https://nn.labml.ai/normalization/instance_norm/experiment.html</loc>
<lastmod>2021-04-20T16:30:00+00:00</lastmod>
<lastmod>2021-04-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -7,6 +7,81 @@ summary: >

# Group Normalization

This is a [PyTorch](https://pytorch.org) implementation of
the paper [Group Normalization](https://arxiv.org/abs/1803.08494).

[Batch Normalization](../batch_norm/index.html) works well for sufficiently large batch sizes,
but performs poorly for small batch sizes because it normalizes across the batch.
Training large models with large batch sizes is often not possible given the memory
capacity of the devices, which forces small per-device batches.

This paper introduces Group Normalization, which normalizes a set of features together as a group.
This is based on the observation that classical features such as
[SIFT](https://en.wikipedia.org/wiki/Scale-invariant_feature_transform) and
[HOG](https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients) are group-wise features.
The paper proposes dividing the feature channels into groups and then separately normalizing
all channels within each group.

## Formulation

All normalization layers can be defined by the following computation:

$$\hat{x}_i = \frac{1}{\sigma_i} (x_i - \mu_i)$$

where $x$ is the tensor representing the batch
and $i$ is the index of a single value.
For instance, for 2D images,
$i = (i_N, i_C, i_H, i_W)$ is a 4-d vector indexing the
image within the batch, the feature channel, and the vertical and horizontal coordinates.
$\mu_i$ and $\sigma_i$ are the mean and standard deviation:

\begin{align}
\mu_i &= \frac{1}{m} \sum_{k \in \mathcal{S}_i} x_k \\
\sigma_i &= \sqrt{\frac{1}{m} \sum_{k \in \mathcal{S}_i} (x_k - \mu_i)^2 + \epsilon}
\end{align}

$\mathcal{S}_i$ is the set of indexes across which the mean and standard deviation
are calculated for index $i$,
and $m$ is the size of the set $\mathcal{S}_i$, which is the same for all $i$.
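As a quick numeric check (the values are chosen purely for illustration), these formulas can be evaluated for a single set $\mathcal{S}_i$ of $m = 4$ values:

```python
import torch

# One set S_i with m = 4 values
x = torch.tensor([1., 2., 3., 4.])
eps = 1e-5

mu = x.mean()                                   # (1 + 2 + 3 + 4) / 4 = 2.5
sigma = ((x - mu) ** 2).mean().add(eps).sqrt()  # sqrt(1.25 + eps) ≈ 1.118
x_hat = (x - mu) / sigma                        # normalized values

print(mu.item(), sigma.item())
print(x_hat)
```

Each normalized set then has zero mean and, up to the effect of $\epsilon$, unit standard deviation.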

The definition of $\mathcal{S}_i$ is different for
[Batch normalization](../batch_norm/index.html),
[Layer normalization](../layer_norm/index.html), and
[Instance normalization](../instance_norm/index.html).

### [Batch Normalization](../batch_norm/index.html)

$$\mathcal{S}_i = \{k | k_C = i_C\}$$

The values that share the same feature channel are normalized together.

### [Layer Normalization](../layer_norm/index.html)

$$\mathcal{S}_i = \{k | k_N = i_N\}$$

The values from the same sample in the batch are normalized together.

### [Instance Normalization](../instance_norm/index.html)

$$\mathcal{S}_i = \{k | k_N = i_N, k_C = i_C\}$$

The values from the same sample and same feature channel are normalized together.
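For an `[N, C, H, W]` tensor, these three choices of $\mathcal{S}_i$ correspond to averaging over different sets of dimensions; a minimal sketch (variable names are illustrative):

```python
import torch

x = torch.randn(2, 6, 4, 4)  # [N, C, H, W]

# Batch norm: S_i shares the channel, so average over batch and spatial dims
mu_bn = x.mean(dim=(0, 2, 3), keepdim=True)  # one mean per channel: [1, C, 1, 1]
# Layer norm: S_i shares the sample, so average over channel and spatial dims
mu_ln = x.mean(dim=(1, 2, 3), keepdim=True)  # one mean per sample: [N, 1, 1, 1]
# Instance norm: S_i shares sample and channel, so average over spatial dims only
mu_in = x.mean(dim=(2, 3), keepdim=True)     # one mean per sample-channel pair: [N, C, 1, 1]

print(mu_bn.shape, mu_ln.shape, mu_in.shape)
```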

### Group Normalization

$$\mathcal{S}_i = \{k | k_N = i_N,
\bigg \lfloor \frac{k_C}{C/G} \bigg \rfloor = \bigg \lfloor \frac{i_C}{C/G} \bigg \rfloor\}$$

where $G$ is the number of groups and $C$ is the number of channels.

Group normalization normalizes values of the same sample and the same group of channels together.
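Concretely, this amounts to reshaping the channel dimension into $G$ groups and normalizing each `(sample, group)` slice on its own. A minimal sketch, checked against PyTorch's built-in `torch.nn.GroupNorm` (with the affine transform disabled):

```python
import torch

N, C, H, W, G = 2, 6, 2, 4, 2
x = torch.randn(N, C, H, W)

# Reshape so that each (sample, group) slice is one set S_i
xg = x.view(N, G, -1)                        # [N, G, (C/G) * H * W]
mu = xg.mean(dim=-1, keepdim=True)           # mean per (sample, group)
var = xg.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as in the paper
x_hat = ((xg - mu) / torch.sqrt(var + 1e-5)).view(N, C, H, W)

# PyTorch's built-in layer computes the same statistics
ref = torch.nn.GroupNorm(G, C, eps=1e-5, affine=False)(x)
print(torch.allclose(x_hat, ref, atol=1e-5))
```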

Here's a [CIFAR 10 classification model](experiment.html) that uses group normalization.

[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb)
[](https://app.labml.ai/run/011254fe647011ebbb8e0242ac1c0002)

"""

import torch
@ -55,24 +130,28 @@ class GroupNorm(Module):

# Sanity check to make sure the number of features is the same
assert self.channels == x.shape[1]

# Reshape into `[batch_size, channels, n]`
# Reshape into `[batch_size, groups, n]`
x = x.view(batch_size, self.groups, -1)

# Calculate the mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
mean = x.mean(dim=[2], keepdim=True)
# Calculate the squared mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
mean_x2 = (x ** 2).mean(dim=[2], keepdim=True)
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
# Calculate the mean across the last dimension;
# i.e. the mean for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$
mean = x.mean(dim=[-1], keepdim=True)
# Calculate the squared mean across the last dimension;
# i.e. the mean for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$
mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
# Variance for each sample and feature group
# $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
var = mean_x2 - mean ** 2

# Normalize $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$$
# Normalize
# $$\hat{x}_{(i_N, i_G)} =
# \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$$
x_norm = (x - mean) / torch.sqrt(var + self.eps)
x_norm = x_norm.view(batch_size, self.channels, -1)

# Scale and shift $$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
# Scale and shift channel-wise
# $$y_{i_C} =\gamma_{i_C} \hat{x}_{i_C} + \beta_{i_C}$$
if self.affine:
x_norm = x_norm.view(batch_size, self.channels, -1)
x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

# Reshape to original and return
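The channel-wise scale-and-shift step can be seen in isolation; a minimal sketch with made-up parameter values (the real layer learns `scale` and `shift`):

```python
import torch

batch_size, channels, n = 2, 6, 8
x_norm = torch.randn(batch_size, channels, n)  # normalized values, [batch_size, channels, n]
scale = torch.full((channels,), 2.0)           # gamma: one value per channel
shift = torch.full((channels,), 0.5)           # beta: one value per channel

# view(1, -1, 1) broadcasts the per-channel parameters
# across the batch and the flattened spatial positions
y = scale.view(1, -1, 1) * x_norm + shift.view(1, -1, 1)
print(y.shape)
```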
setup.py
@ -5,7 +5,7 @@ with open("readme.md", "r") as f:

setuptools.setup(
name='labml-nn',
version='0.4.95',
version='0.4.96',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",