batch norm formulas

This commit is contained in:
Varuna Jayasiri
2021-02-02 10:07:06 +05:30
parent cee8a37a4a
commit 1eb9fe4dd6
5 changed files with 584 additions and 42 deletions

View File

@ -3,12 +3,12 @@
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementations/tutorials of batch normalization."/>
<meta name="description" content="A PyTorch implementation/tutorial of batch normalization."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Batch Normalization"/>
<meta name="twitter:description" content="A PyTorch implementations/tutorials of batch normalization."/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of batch normalization."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
@ -18,7 +18,7 @@
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Batch Normalization"/>
<meta property="og:description" content="A PyTorch implementations/tutorials of batch normalization."/>
<meta property="og:description" content="A PyTorch implementation/tutorial of batch normalization."/>
<title>Batch Normalization</title>
<link rel="shortcut icon" href="/icon.png"/>
@ -151,6 +151,25 @@ a CNN classifier that use batch normalization for MNIST dataset.</p>
<a href='#section-1'>#</a>
</div>
<h2>Batch Normalization Layer</h2>
<p>Batch normalization layer $\text{BN}$ normalizes the input $X$ as follows:</p>
<p>When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
<script type="math/tex; mode=display">\text{BN}(X) = \gamma
\frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{Var}[X] + \epsilon}}
+ \beta</script>
</p>
<p>When input $X \in \mathbb{R}^{B \times C}$ is a batch of vector embeddings,
where $B$ is the batch size and $C$ is the number of features.
<script type="math/tex; mode=display">\text{BN}(X) = \gamma
\frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{Var}[X] + \epsilon}}
+ \beta</script>
</p>
<p>When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings,
where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence.
<script type="math/tex; mode=display">\text{BN}(X) = \gamma
\frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{Var}[X] + \epsilon}}
+ \beta</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">102</span><span class="k">class</span> <span class="nc">BatchNorm</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
@ -171,9 +190,9 @@ a CNN classifier that use batch normalization for MNIST dataset.</p>
<p>We&rsquo;ve tried to use the same names for arguments as PyTorch <code>BatchNorm</code> implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">108</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="lineno">109</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">128</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="lineno">129</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -184,14 +203,14 @@ a CNN classifier that use batch normalization for MNIST dataset.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">120</span>
<span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">122</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
<span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span> <span class="o">=</span> <span class="n">track_running_stats</span></pre></div>
<div class="highlight"><pre><span class="lineno">139</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">140</span>
<span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">142</span>
<span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
<span class="lineno">145</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
<span class="lineno">146</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span> <span class="o">=</span> <span class="n">track_running_stats</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -202,9 +221,9 @@ a CNN classifier that use batch normalization for MNIST dataset.</p>
<p>Create parameters for $\gamma$ and $\beta$ for scale and shift</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">129</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">130</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">149</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -216,9 +235,9 @@ a CNN classifier that use batch normalization for MNIST dataset.</p>
mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">134</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_mean&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">135</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_var&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">154</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_mean&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">155</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_var&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -232,7 +251,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -243,7 +262,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<p>Keep the original shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -254,7 +273,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<p>Get the batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -265,7 +284,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<p>Sanity check to make sure the number of features is same</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">169</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -276,7 +295,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
<p>Reshape into <code>[batch_size, channels, n]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -288,7 +307,7 @@ mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
if we are in training mode or if we have not tracked exponential moving averages</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -300,7 +319,7 @@ if we are in training mode or if we have not tracked exponential moving averages
i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">159</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -312,7 +331,7 @@ i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -323,7 +342,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -334,9 +353,9 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Update exponential moving averages</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">168</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
<span class="lineno">169</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -347,9 +366,9 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Use exponential moving averages as estimates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">172</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span>
<span class="lineno">173</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span></pre></div>
<div class="highlight"><pre><span class="lineno">191</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">192</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span>
<span class="lineno">193</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -361,7 +380,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -373,8 +392,8 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">179</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">198</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">199</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -385,7 +404,7 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">202</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
</div>

View File

@ -0,0 +1,380 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of layer normalization."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Layer Normalization"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of layer normalization."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/normalization/layer_norm/index.html"/>
<meta property="og:title" content="Layer Normalization"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Layer Normalization"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of layer normalization."/>
<title>Layer Normalization</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/normalization/layer_norm/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">normalization</a>
<a class="parent" href="index.html">layer_norm</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/layer_norm/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slact"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Layer Normalization</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of
<a href="https://arxiv.org/abs/1607.06450">Layer Normalization</a>.</p>
<h3>Limitations of <a href="../batch_norm/index.html">Batch Normalization</a></h3>
<ul>
<li>You need to maintain running means.</li>
<li>Tricky for RNNs. Do you need different normalizations for each step?</li>
<li>Doesn&rsquo;t work with small batch sizes;
large NLP models are usually trained with small batch sizes.</li>
<li>Need to compute means and variances across devices in distributed training</li>
</ul>
<h2>Layer Normalization</h2>
<p>Layer normalization is a simpler normalization method that works
on a wider range of settings.
Layer normalization transformers the inputs to have zero mean and unit variance
across the features.
*Note that batch normalization, fixes the zero mean and unit variance for each vector.
Layer normalization does it for each batch across all elements.</p>
<p>Layer normalization is generally used for NLP tasks.</p>
<p>Here&rsquo;s <a href="mnist.html">the training code</a> and a notebook for training
a CNN classifier that use batch normalization for MNIST dataset.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Batch Normalization Layer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span><span class="k">class</span> <span class="nc">BatchNorm</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>channels</code> is the number of features in the input</li>
<li><code>eps</code> is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability</li>
<li><code>momentum</code> is the momentum in taking the exponential moving average</li>
<li><code>affine</code> is whether to scale and shift the normalized value</li>
<li><code>track_running_stats</code> is whether to calculate the moving averages or mean and variance</li>
</ul>
<p>We&rsquo;ve tried to use the same names for arguments as PyTorch <code>BatchNorm</code> implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">49</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="lineno">50</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">61</span>
<span class="lineno">62</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">63</span>
<span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
<span class="lineno">66</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
<span class="lineno">67</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span> <span class="o">=</span> <span class="n">track_running_stats</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Create parameters for $\gamma$ and $\beta$ for scale and shift</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">70</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">71</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Create buffers to store exponential moving averages of
mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_mean&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_var&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p><code>x</code> is a tensor of shape <code>[batch_size, channels, *]</code>.
<code>*</code> could be any (even *) dimensions.
For example, in an image (2D) convolution this will be
<code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Keep the original shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Get the batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Sanity check to make sure the number of features is same</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Reshape into <code>[batch_size, channels, n]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>We will calculate the mini-batch mean and variance
if we are in training mode or if we have not tracked exponential moving averages</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Calculate the mean across first and last dimension;
i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Calculate the squared mean across first and last dimension;
i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Update exponential moving averages</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
<span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Use exponential moving averages as estimates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">113</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span>
<span class="lineno">114</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Normalize <script type="math/tex; mode=display">\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Scale and shift <script type="math/tex; mode=display">y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">120</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>

View File

@ -99,7 +99,7 @@
<url>
<loc>https://nn.labml.ai/normalization/batch_norm/index.html</loc>
<lastmod>2021-02-01T16:30:00+00:00</lastmod>
<lastmod>2021-02-02T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

View File

@ -2,7 +2,7 @@
---
title: Batch Normalization
summary: >
A PyTorch implementations/tutorials of batch normalization.
A PyTorch implementation/tutorial of batch normalization.
---
# Batch Normalization
@ -100,8 +100,28 @@ from torch import nn
class BatchNorm(nn.Module):
"""
r"""
## Batch Normalization Layer
Batch normalization layer $\text{BN}$ normalizes the input $X$ as follows:
When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{Var}[X] + \epsilon}}
+ \beta$$
When input $X \in \mathbb{R}^{B \times C}$ is a batch of vector embeddings,
where $B$ is the batch size and $C$ is the number of features.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{Var}[X] + \epsilon}}
+ \beta$$
When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings,
where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{Var}[X] + \epsilon}}
+ \beta$$
"""
def __init__(self, channels: int, *,

View File

@ -0,0 +1,123 @@
"""
---
title: Layer Normalization
summary: >
A PyTorch implementation/tutorial of layer normalization.
---
# Layer Normalization
This is a [PyTorch](https://pytorch.org) implementation of
[Layer Normalization](https://arxiv.org/abs/1607.06450).
### Limitations of [Batch Normalization](../batch_norm/index.html)
* You need to maintain running means.
* Tricky for RNNs. Do you need different normalizations for each step?
* Doesn't work with small batch sizes;
large NLP models are usually trained with small batch sizes.
* Need to compute means and variances across devices in distributed training
## Layer Normalization
Layer normalization is a simpler normalization method that works
on a wider range of settings.
Layer normalization transformers the inputs to have zero mean and unit variance
across the features.
*Note that batch normalization, fixes the zero mean and unit variance for each vector.
Layer normalization does it for each batch across all elements.
Layer normalization is generally used for NLP tasks.
Here's [the training code](mnist.html) and a notebook for training
a CNN classifier that use batch normalization for MNIST dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002)
"""
import torch
from torch import nn
class BatchNorm(nn.Module):
"""
## Batch Normalization Layer
"""
def __init__(self, channels: int, *,
eps: float = 1e-5, momentum: float = 0.1,
affine: bool = True, track_running_stats: bool = True):
"""
* `channels` is the number of features in the input
* `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
* `momentum` is the momentum in taking the exponential moving average
* `affine` is whether to scale and shift the normalized value
* `track_running_stats` is whether to calculate the moving averages or mean and variance
We've tried to use the same names for arguments as PyTorch `BatchNorm` implementation.
"""
super().__init__()
self.channels = channels
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
# Create parameters for $\gamma$ and $\beta$ for scale and shift
if self.affine:
self.scale = nn.Parameter(torch.ones(channels))
self.shift = nn.Parameter(torch.zeros(channels))
# Create buffers to store exponential moving averages of
# mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$
if self.track_running_stats:
self.register_buffer('exp_mean', torch.zeros(channels))
self.register_buffer('exp_var', torch.ones(channels))
def forward(self, x: torch.Tensor):
"""
`x` is a tensor of shape `[batch_size, channels, *]`.
`*` could be any (even *) dimensions.
For example, in an image (2D) convolution this will be
`[batch_size, channels, height, width]`
"""
# Keep the original shape
x_shape = x.shape
# Get the batch size
batch_size = x_shape[0]
# Sanity check to make sure the number of features is same
assert self.channels == x.shape[1]
# Reshape into `[batch_size, channels, n]`
x = x.view(batch_size, self.channels, -1)
# We will calculate the mini-batch mean and variance
# if we are in training mode or if we have not tracked exponential moving averages
if self.training or not self.track_running_stats:
# Calculate the mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
mean = x.mean(dim=[0, 2])
# Calculate the squared mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
mean_x2 = (x ** 2).mean(dim=[0, 2])
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
var = mean_x2 - mean ** 2
# Update exponential moving averages
if self.training and self.track_running_stats:
self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
# Use exponential moving averages as estimates
else:
mean = self.exp_mean
var = self.exp_var
# Normalize $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$$
x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
# Scale and shift $$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
if self.affine:
x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
# Reshape to original and return
return x_norm.view(x_shape)