📚 batch norm notebook

This commit is contained in:
Varuna Jayasiri
2021-02-01 15:01:56 +05:30
parent 983286e216
commit 03197f4347
6 changed files with 1505 additions and 84 deletions

View File

@ -126,6 +126,8 @@ where $y^{(k)}$ is the output of of the batch normalization layer.</p>
like $Wu + b$ the bias parameter $b$ gets cancelled due to normalization.
So you can and should omit bias parameter in linear transforms right before the
batch normalization.</p>
<p>Batch normalization also makes the back propagation invariant to the scale of the weights.
And empirically it improves generalization, so it has regularization effects too.</p>
<h2>Inference</h2>
<p>We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to
perform the normalization.
@ -133,12 +135,16 @@ So during inference, you either need to go through the whole (or part of) datase
and find the mean and variance, or you can use an estimate calculated during training.
The usual practice is to calculate an exponential moving average of
mean and variance during training phase and use that for inference.</p>
<p>Here&rsquo;s <a href="mnist.html">the training code</a> and a notebook for training
a CNN classifier that use batch normalization for MNIST dataset.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">91</span>
<span class="lineno">92</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
<div class="highlight"><pre><span class="lineno">98</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">99</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">100</span>
<span class="lineno">101</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -149,7 +155,7 @@ mean and variance during training phase and use that for inference.</p>
<h2>Batch Normalization Layer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span><span class="k">class</span> <span class="nc">BatchNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">104</span><span class="k">class</span> <span class="nc">BatchNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -167,9 +173,9 @@ mean and variance during training phase and use that for inference.</p>
<p>We&rsquo;ve tried to use the same names for arguments as PyTorch <code>BatchNorm</code> implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">100</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="lineno">101</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">109</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="lineno">110</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -180,14 +186,14 @@ mean and variance during training phase and use that for inference.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">112</span>
<span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">114</span>
<span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
<span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span> <span class="o">=</span> <span class="n">track_running_stats</span></pre></div>
<div class="highlight"><pre><span class="lineno">120</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">121</span>
<span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">123</span>
<span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
<span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
<span class="lineno">127</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span> <span class="o">=</span> <span class="n">track_running_stats</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -198,9 +204,9 @@ mean and variance during training phase and use that for inference.</p>
<p>Create parameters for $\gamma$ and $\beta$ for scale and shift</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">130</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">131</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -212,9 +218,9 @@ mean and variance during training phase and use that for inference.</p>
mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_mean&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">127</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_var&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">134</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">135</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_mean&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">136</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_var&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -228,7 +234,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">138</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -239,7 +245,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<p>Keep the original shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -250,7 +256,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<p>Get the batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -261,7 +267,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<p>Sanity check to make sure the number of features is same</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">150</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -272,7 +278,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<p>Reshape into <code>[batch_size, channels, n]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -284,7 +290,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
if we are in training mode or if we have not tracked exponential moving averages</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -296,7 +302,7 @@ if we are in training mode or if we have not tracked exponential moving averages
i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -308,7 +314,7 @@ i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -319,7 +325,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -330,9 +336,9 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Update exponential moving averages</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">159</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
<span class="lineno">161</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
<div class="highlight"><pre><span class="lineno">168</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">169</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
<span class="lineno">170</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -343,9 +349,9 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Use exponential moving averages as estimates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">164</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span>
<span class="lineno">165</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span></pre></div>
<div class="highlight"><pre><span class="lineno">172</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">173</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span>
<span class="lineno">174</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -357,7 +363,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">177</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -369,8 +375,8 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">171</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">179</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">180</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -381,7 +387,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
</div>

View File

@ -3,12 +3,12 @@
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is a simple model for MNIST digit classification that uses batch normalization"/>
<meta name="description" content="This trains is a simple convolutional neural network that uses batch normalization to classify MNIST digits."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="MNIST Experiment to try Batch Normalization"/>
<meta name="twitter:description" content="This is a simple model for MNIST digit classification that uses batch normalization"/>
<meta name="twitter:description" content="This trains is a simple convolutional neural network that uses batch normalization to classify MNIST digits."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
@ -18,7 +18,7 @@
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="MNIST Experiment to try Batch Normalization"/>
<meta property="og:description" content="This is a simple model for MNIST digit classification that uses batch normalization"/>
<meta property="og:description" content="This trains is a simple convolutional neural network that uses batch normalization to classify MNIST digits."/>
<title>MNIST Experiment to try Batch Normalization</title>
<link rel="shortcut icon" href="/icon.png"/>
@ -75,15 +75,15 @@
<h1>MNIST Experiment for Batch Normalization</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="lineno">13</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">14</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.mnist</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.batch_norm</span> <span class="kn">import</span> <span class="n">BatchNorm</span></pre></div>
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">13</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="lineno">14</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.mnist</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.batch_norm</span> <span class="kn">import</span> <span class="n">BatchNorm</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -94,7 +94,7 @@
<h3>Model definition</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">23</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -105,8 +105,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">28</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">28</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">29</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -117,7 +117,7 @@
<p>Note that we omit the bias parameter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">31</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -129,7 +129,7 @@
The input to this layer will have shape <code>[batch_size, 20, height(24), width(24)]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">34</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -140,7 +140,7 @@ The input to this layer will have shape <code>[batch_size, 20, height(24), width
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">36</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -152,7 +152,7 @@ The input to this layer will have shape <code>[batch_size, 20, height(24), width
The input to this layer will have shape <code>[batch_size, 50, height(8), width(8)]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">50</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">39</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">50</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -163,7 +163,7 @@ The input to this layer will have shape <code>[batch_size, 50, height(8), width(
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -175,7 +175,7 @@ The input to this layer will have shape <code>[batch_size, 50, height(8), width(
The input to this layer will have shape <code>[batch_size, 500]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn3</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">500</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">44</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn3</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">500</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -186,7 +186,7 @@ The input to this layer will have shape <code>[batch_size, 500]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">46</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -197,14 +197,14 @@ The input to this layer will have shape <code>[batch_size, 500]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">48</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
<span class="lineno">49</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">50</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
<span class="lineno">51</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">52</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">50</span><span class="p">)</span>
<span class="lineno">53</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn3</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
<span class="lineno">54</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">48</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">49</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
<span class="lineno">50</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">51</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
<span class="lineno">52</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">53</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">50</span><span class="p">)</span>
<span class="lineno">54</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn3</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
<span class="lineno">55</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -217,8 +217,8 @@ The input to this layer will have shape <code>[batch_size, 500]</code></p>
and set a new function to calculate the model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span><span class="nd">@option</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">58</span><span class="k">def</span> <span class="nf">model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">MNISTConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">58</span><span class="nd">@option</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">59</span><span class="k">def</span> <span class="nf">model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">MNISTConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -229,7 +229,7 @@ and set a new function to calculate the model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="k">return</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">return</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -240,7 +240,7 @@ and set a new function to calculate the model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">69</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -251,7 +251,7 @@ and set a new function to calculate the model.</p>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;mnist_batch_norm&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;mnist_batch_norm&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -262,7 +262,7 @@ and set a new function to calculate the model.</p>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">MNISTConfigs</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">MNISTConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -273,7 +273,7 @@ and set a new function to calculate the model.</p>
<p>Load configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">})</span></pre></div>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -284,8 +284,8 @@ and set a new function to calculate the model.</p>
<p>Start the experiment and run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">77</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">78</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -296,8 +296,8 @@ and set a new function to calculate the model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">82</span> <span class="n">main</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">82</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">83</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
</div>

View File

@ -76,6 +76,9 @@ like $Wu + b$ the bias parameter $b$ gets cancelled due to normalization.
So you can and should omit bias parameter in linear transforms right before the
batch normalization.
Batch normalization also makes the back propagation invariant to the scale of the weights.
And empirically it improves generalization, so it has regularization effects too.
## Inference
We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to
@ -84,6 +87,12 @@ So during inference, you either need to go through the whole (or part of) datase
and find the mean and variance, or you can use an estimate calculated during training.
The usual practice is to calculate an exponential moving average of
mean and variance during training phase and use that for inference.
Here's [the training code](mnist.html) and a notebook for training
a CNN classifier that use batch normalization for MNIST dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002)
"""
import torch

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,8 @@
---
title: MNIST Experiment to try Batch Normalization
summary: >
This is a simple model for MNIST digit classification that uses batch normalization
This trains is a simple convolutional neural network that uses batch normalization
to classify MNIST digits.
---
# MNIST Experiment for Batch Normalization

View File

@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
version='0.4.84',
version='0.4.85',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",