mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
📚 batch norm
This commit is contained in:
1
Makefile
1
Makefile
@ -22,6 +22,7 @@ uninstall: ## Uninstall
|
||||
pip uninstall labml_nn
|
||||
|
||||
docs: ## Render annotated HTML
|
||||
find ./docs/ -name "*.html" -type f -delete
|
||||
python utils/sitemap.py
|
||||
cd labml_nn; pylit --remove_empty_sections --title_md -t ../../../pylit/templates/nn -d ../docs -w *
|
||||
|
||||
|
@ -1,186 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="batch_norm.py"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/batch_norm.html"/>
|
||||
<meta property="og:title" content="batch_norm.py"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="batch_norm.py"/>
|
||||
<meta property="og:description" content=""/>
|
||||
|
||||
<title>batch_norm.py</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/batch_norm.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="index.html">normalization</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/batch_norm.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
||||
rel="nofollow">
|
||||
<img alt="Join Slack"
|
||||
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">2</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">3</span>
|
||||
<span class="lineno">4</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">7</span><span class="k">class</span> <span class="nc">BatchNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">8</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||
<span class="lineno">9</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
|
||||
<span class="lineno">10</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
|
||||
<span class="lineno">11</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">12</span>
|
||||
<span class="lineno">13</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
|
||||
<span class="lineno">14</span>
|
||||
<span class="lineno">15</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||||
<span class="lineno">16</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
|
||||
<span class="lineno">17</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
|
||||
<span class="lineno">18</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span> <span class="o">=</span> <span class="n">track_running_stats</span>
|
||||
<span class="lineno">19</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">20</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">21</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">22</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
|
||||
<span class="lineno">23</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'running_mean'</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">24</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'running_var'</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">26</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||||
<span class="lineno">27</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="lineno">28</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="lineno">29</span>
|
||||
<span class="lineno">30</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">31</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
|
||||
<span class="lineno">32</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
|
||||
<span class="lineno">33</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
|
||||
<span class="lineno">34</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span>
|
||||
<span class="lineno">35</span>
|
||||
<span class="lineno">36</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
|
||||
<span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
|
||||
<span class="lineno">38</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span>
|
||||
<span class="lineno">39</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">40</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span>
|
||||
<span class="lineno">41</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span>
|
||||
<span class="lineno">42</span>
|
||||
<span class="lineno">43</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">44</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">45</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">46</span>
|
||||
<span class="lineno">47</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
420
docs/normalization/batch_norm/index.html
Normal file
420
docs/normalization/batch_norm/index.html
Normal file
@ -0,0 +1,420 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="A PyTorch implementations/tutorials of batch normalization."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Batch Normalization"/>
|
||||
<meta name="twitter:description" content="A PyTorch implementations/tutorials of batch normalization."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/batch_norm/index.html"/>
|
||||
<meta property="og:title" content="Batch Normalization"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Batch Normalization"/>
|
||||
<meta property="og:description" content="A PyTorch implementations/tutorials of batch normalization."/>
|
||||
|
||||
<title>Batch Normalization</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/batch_norm/index.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">normalization</a>
|
||||
<a class="parent" href="index.html">batch_norm</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/batch_norm/__init__.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
||||
rel="nofollow">
|
||||
<img alt="Join Slack"
|
||||
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Batch Normalization</h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of Batch Normalization from paper
|
||||
<a href="https://arxiv.org/abs/1502.03167">Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift</a>.</p>
|
||||
<h3>Internal Covariate Shift</h3>
|
||||
<p>The paper defines <em>Internal Covariate Shift</em> as the change in the
|
||||
distribution of network activations due to the change in
|
||||
network parameters during training.
|
||||
For example, let’s say there are two layers $l_1$ and $l_2$.
|
||||
During the beginning of the training $l_1$ outputs (inputs to $l_2$)
|
||||
could be in distribution $\mathcal{N}(0.5, 1)$.
|
||||
Then, after some training steps it could move to $\mathcal{N}(0.6, 1.5)$.
|
||||
This is <em>internal covariate shift</em>.</p>
|
||||
<p>Internal covariate shift will adversely affect training speed because the later layers
|
||||
($l_2$ in the above example) have to adapt to this shifted distribution.</p>
|
||||
<p>By stabilizing the distribution batch normalization minimizes the internal covariate shift.</p>
|
||||
<h2>Normalization</h2>
|
||||
<p>It is known that whitening improves training speed and convergence.
|
||||
<em>Whitening</em> is linearly transforming inputs to have zero mean, unit variance
|
||||
and be uncorrelated.</p>
|
||||
<h3>Normalizing outside gradient computation doesn’t work</h3>
|
||||
<p>Normalizing outside the gradient computation using pre-computed (detached)
|
||||
means and variances doesn’t work. For instance (ignoring variance), let
|
||||
<script type="math/tex; mode=display">\hat{x} = x - \mathbb{E}[x]</script>
|
||||
where $x = u + b$ and $b$ is a trained bias.
|
||||
and $\mathbb{E}[x]$ is outside gradient computation (pre-computed constant).</p>
|
||||
<p>Note that $\hat{x}$ has no effect of $b$.
|
||||
Therefore,
|
||||
$b$ will increase or decrease based
|
||||
$\frac{\partial{\mathcal{L}}}{\partial x}$,
|
||||
and keep on growing indefinitely in each training update.
|
||||
Paper notes that similar explosions happen with variances.</p>
|
||||
<h3>Batch Normalization</h3>
|
||||
<p>Whitening is computationally expensive because you need to de-correlate and
|
||||
the gradients must flow through the full whitening calculation.</p>
|
||||
<p>The paper introduces simplified version which they call <em>Batch Normalization</em>.
|
||||
First simplification is that it normalizes each feature independently to have
|
||||
zero mean and unit variance:
|
||||
<script type="math/tex; mode=display">\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}</script>
|
||||
where $x = (x^{(1)} … x^{(d)})$ is the $d$-dimensional input.</p>
|
||||
<p>The second simplification is to use estimates of mean $\mathbb{E}[x^{(k)}]$
|
||||
and variance $Var[x^{(k)}]$ from the mini-batch
|
||||
for normalization; instead of calculating the mean and variance across whole dataset.</p>
|
||||
<p>Normalizing each feature to zero mean and unit variance could affect what the layer
|
||||
can represent.
|
||||
As an example paper illustrates that, if the inputs to a sigmoid are normalized
|
||||
most of it will be within $[-1, 1]$ range where the sigmoid is linear.
|
||||
To overcome this each feature is scaled and shifted by two trained parameters
|
||||
$\gamma^{(k)}$ and $\beta^{(k)}$.
|
||||
<script type="math/tex; mode=display">y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}</script>
|
||||
where $y^{(k)}$ is the output of the batch normalization layer.</p>
|
||||
<p>Note that when applying batch normalization after a linear transform
|
||||
like $Wu + b$ the bias parameter $b$ gets cancelled due to normalization.
|
||||
So you can and should omit bias parameter in linear transforms right before the
|
||||
batch normalization.</p>
|
||||
<h2>Inference</h2>
|
||||
<p>We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to
|
||||
perform the normalization.
|
||||
So during inference, you either need to go through the whole (or part of) dataset
|
||||
and find the mean and variance, or you can use an estimate calculated during training.
|
||||
The usual practice is to calculate an exponential moving average of
|
||||
mean and variance during training phase and use that for inference.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">89</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">91</span>
|
||||
<span class="lineno">92</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>Batch Normalization Layer</h2>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">95</span><span class="k">class</span> <span class="nc">BatchNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>channels</code> is the number of features in the input</li>
|
||||
<li><code>eps</code> is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability</li>
|
||||
<li><code>momentum</code> is the momentum in taking the exponential moving average</li>
|
||||
<li><code>affine</code> is whether to scale and shift the normalized value</li>
|
||||
<li><code>track_running_stats</code> is whether to calculate the moving averages or mean and variance</li>
|
||||
</ul>
|
||||
<p>We’ve tried to use the same names for arguments as PyTorch <code>BatchNorm</code> implementation.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">99</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||
<span class="lineno">100</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
|
||||
<span class="lineno">101</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">111</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">112</span>
|
||||
<span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
|
||||
<span class="lineno">114</span>
|
||||
<span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||||
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
|
||||
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
|
||||
<span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span> <span class="o">=</span> <span class="n">track_running_stats</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Create parameters for $\gamma$ and $\beta$ for scale and shift</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Create buffers to store exponential moving averages of
|
||||
mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
|
||||
<span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'exp_mean'</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">127</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'exp_var'</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p><code>x</code> is a tensor of shape <code>[batch_size, channels, *]</code>.
|
||||
<code>*</code> could be any number of (even empty) dimensions.
|
||||
For example, in an image (2D) convolution this will be
|
||||
<code>[batch_size, channels, height, width]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Keep the original shape</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Get the batch size</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>Sanity check to make sure the number of features is same</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">141</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Reshape into <code>[batch_size, channels, n]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>We will calculate the mini-batch mean and variance
|
||||
if we are in training mode or if we have not tracked exponential moving averages</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Calculate the mean across first and last dimension;
|
||||
i.e. the means for each feature $\mathbb{E}[x^{(k)}]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Calculate the squared mean across first and last dimension;
|
||||
i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Update exponential moving averages</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">159</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">:</span>
|
||||
<span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
|
||||
<span class="lineno">161</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Use exponential moving averages as estimates</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">164</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span>
|
||||
<span class="lineno">165</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Normalize <script type="math/tex; mode=display">\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Scale and shift <script type="math/tex; mode=display">y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">170</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">171</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Reshape to original and return</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
336
docs/normalization/batch_norm/mnist.html
Normal file
336
docs/normalization/batch_norm/mnist.html
Normal file
@ -0,0 +1,336 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="This is a simple model for MNIST digit classification that uses batch normalization"/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="MNIST Experiment to try Batch Normalization"/>
|
||||
<meta name="twitter:description" content="This is a simple model for MNIST digit classification that uses batch normalization"/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/batch_norm/mnist.html"/>
|
||||
<meta property="og:title" content="MNIST Experiment to try Batch Normalization"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="MNIST Experiment to try Batch Normalization"/>
|
||||
<meta property="og:description" content="This is a simple model for MNIST digit classification that uses batch normalization"/>
|
||||
|
||||
<title>MNIST Experiment to try Batch Normalization</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/batch_norm/mnist.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">normalization</a>
|
||||
<a class="parent" href="index.html">batch_norm</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/batch_norm/mnist.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
||||
rel="nofollow">
|
||||
<img alt="Join Slact"
|
||||
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>MNIST Experiment for Batch Normalization</h1>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
|
||||
<span class="lineno">13</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
|
||||
<span class="lineno">14</span>
|
||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
|
||||
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
||||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||||
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.mnist</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
|
||||
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.batch_norm</span> <span class="kn">import</span> <span class="n">BatchNorm</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h3>Model definition</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">22</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">27</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">28</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>Note that we omit the bias parameter</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">30</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Batch normalization with 20 channels (output of convolution layer).
|
||||
The input to this layer will have shape <code>[batch_size, 20, height(24), width(24)]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Batch normalization with 50 channels.
|
||||
The input to this layer will have shape <code>[batch_size, 50, height(8), width(8)]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">38</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">50</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">40</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Batch normalization with 500 channels (output of fully connected layer).
|
||||
The input to this layer will have shape <code>[batch_size, 500]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn3</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">500</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">47</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||||
<span class="lineno">48</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
|
||||
<span class="lineno">49</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||||
<span class="lineno">50</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
|
||||
<span class="lineno">51</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||||
<span class="lineno">52</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">50</span><span class="p">)</span>
|
||||
<span class="lineno">53</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn3</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
|
||||
<span class="lineno">54</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<h3>Create model</h3>
|
||||
<p>We use <a href="../../experiments/mnist.html#MNISTConfigs"><code>MNISTConfigs</code></a> configurations
|
||||
and set a new function to calculate the model.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">57</span><span class="nd">@option</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||||
<span class="lineno">58</span><span class="k">def</span> <span class="nf">model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">MNISTConfigs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="k">return</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">68</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Create experiment</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'mnist_batch_norm'</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Create configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">MNISTConfigs</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Load configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Start the experiment and run the training loop</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
|
||||
<span class="lineno">77</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">81</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">82</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -3,24 +3,24 @@
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
<meta name="description" content="A set of PyTorch implementations/tutorials of normalization layers."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="None"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:title" content="Normalization Layers"/>
|
||||
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials of normalization layers."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/index.html"/>
|
||||
<meta property="og:title" content="None"/>
|
||||
<meta property="og:title" content="Normalization Layers"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="None"/>
|
||||
<meta property="og:description" content=""/>
|
||||
<meta property="og:title" content="Normalization Layers"/>
|
||||
<meta property="og:description" content="A set of PyTorch implementations/tutorials of normalization layers."/>
|
||||
|
||||
<title>None</title>
|
||||
<title>Normalization Layers</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/index.html"/>
|
||||
@ -66,6 +66,20 @@
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Normalization Layers</h1>
|
||||
<ul>
|
||||
<li><a href="batch_norm/index.html">Batch Normalization</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
|
@ -1,278 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="mnist.py"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/mnist.html"/>
|
||||
<meta property="og:title" content="mnist.py"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="mnist.py"/>
|
||||
<meta property="og:description" content=""/>
|
||||
|
||||
<title>mnist.py</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/mnist.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="index.html">normalization</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/mnist.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
||||
rel="nofollow">
|
||||
<img alt="Join Slact"
|
||||
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="lineno">2</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
|
||||
<span class="lineno">3</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
|
||||
<span class="lineno">4</span>
|
||||
<span class="lineno">5</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span>
|
||||
<span class="lineno">6</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
||||
<span class="lineno">7</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.mnist</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
|
||||
<span class="lineno">8</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
|
||||
<span class="lineno">9</span><span class="kn">from</span> <span class="nn">labml_helpers.metrics.accuracy</span> <span class="kn">import</span> <span class="n">Accuracy</span>
|
||||
<span class="lineno">10</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||||
<span class="lineno">11</span><span class="kn">from</span> <span class="nn">labml_helpers.seed</span> <span class="kn">import</span> <span class="n">SeedConfigs</span>
|
||||
<span class="lineno">12</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">TrainValidConfigs</span><span class="p">,</span> <span class="n">BatchIndex</span><span class="p">,</span> <span class="n">hook_model_outputs</span>
|
||||
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.batch_norm</span> <span class="kn">import</span> <span class="n">BatchNorm</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">16</span><span class="k">class</span> <span class="nc">Net</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">17</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">18</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">19</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">20</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span>
|
||||
<span class="lineno">21</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">22</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">50</span><span class="p">)</span>
|
||||
<span class="lineno">23</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span>
|
||||
<span class="lineno">24</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn3</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="mi">500</span><span class="p">)</span>
|
||||
<span class="lineno">25</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">27</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||||
<span class="lineno">28</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
|
||||
<span class="lineno">29</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||||
<span class="lineno">30</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
|
||||
<span class="lineno">31</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||||
<span class="lineno">32</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">50</span><span class="p">)</span>
|
||||
<span class="lineno">33</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn3</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
|
||||
<span class="lineno">34</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">37</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">,</span> <span class="n">TrainValidConfigs</span><span class="p">):</span>
|
||||
<span class="lineno">38</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
|
||||
<span class="lineno">39</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span>
|
||||
<span class="lineno">40</span> <span class="n">set_seed</span> <span class="o">=</span> <span class="n">SeedConfigs</span><span class="p">()</span>
|
||||
<span class="lineno">41</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span>
|
||||
<span class="lineno">42</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span>
|
||||
<span class="lineno">43</span>
|
||||
<span class="lineno">44</span> <span class="n">is_save_models</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="lineno">45</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span>
|
||||
<span class="lineno">46</span> <span class="n">inner_iterations</span> <span class="o">=</span> <span class="mi">10</span>
|
||||
<span class="lineno">47</span>
|
||||
<span class="lineno">48</span> <span class="n">accuracy_func</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">()</span>
|
||||
<span class="lineno">49</span> <span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">51</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">52</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_queue</span><span class="p">(</span><span class="s2">"loss.*"</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
||||
<span class="lineno">53</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"accuracy.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
||||
<span class="lineno">54</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)</span>
|
||||
<span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">accuracy_func</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span>
|
||||
<span class="lineno">58</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="lineno">59</span>
|
||||
<span class="lineno">60</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
||||
<span class="lineno">61</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span>
|
||||
<span class="lineno">62</span>
|
||||
<span class="lineno">63</span> <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">is_log_activations</span><span class="o">=</span><span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">):</span>
|
||||
<span class="lineno">64</span> <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
|
||||
<span class="lineno">65</span>
|
||||
<span class="lineno">66</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
<span class="lineno">67</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
<span class="lineno">68</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"loss."</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
|
||||
<span class="lineno">69</span>
|
||||
<span class="lineno">70</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
||||
<span class="lineno">71</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||||
<span class="lineno">72</span>
|
||||
<span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||||
<span class="lineno">74</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
|
||||
<span class="lineno">75</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'model'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||||
<span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||||
<span class="lineno">77</span>
|
||||
<span class="lineno">78</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">81</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||||
<span class="lineno">82</span><span class="k">def</span> <span class="nf">model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span>
|
||||
<span class="lineno">83</span> <span class="k">return</span> <span class="n">Net</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="lineno">84</span>
|
||||
<span class="lineno">85</span>
|
||||
<span class="lineno">86</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
|
||||
<span class="lineno">87</span><span class="k">def</span> <span class="nf">_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span>
|
||||
<span class="lineno">88</span> <span class="kn">from</span> <span class="nn">labml_helpers.optimizer</span> <span class="kn">import</span> <span class="n">OptimizerConfigs</span>
|
||||
<span class="lineno">89</span> <span class="n">opt_conf</span> <span class="o">=</span> <span class="n">OptimizerConfigs</span><span class="p">()</span>
|
||||
<span class="lineno">90</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
|
||||
<span class="lineno">91</span> <span class="k">return</span> <span class="n">opt_conf</span>
|
||||
<span class="lineno">92</span>
|
||||
<span class="lineno">93</span>
|
||||
<span class="lineno">94</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
|
||||
<span class="lineno">95</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span>
|
||||
<span class="lineno">96</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'mnist_labml_helpers'</span><span class="p">)</span>
|
||||
<span class="lineno">97</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">})</span>
|
||||
<span class="lineno">98</span> <span class="n">conf</span><span class="o">.</span><span class="n">set_seed</span><span class="o">.</span><span class="n">set</span><span class="p">()</span>
|
||||
<span class="lineno">99</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">))</span>
|
||||
<span class="lineno">100</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
|
||||
<span class="lineno">101</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
|
||||
<span class="lineno">102</span>
|
||||
<span class="lineno">103</span>
|
||||
<span class="lineno">104</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">105</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -90,20 +90,6 @@
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/batch_norm.html</loc>
|
||||
<lastmod>2021-02-01T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/mnist.html</loc>
|
||||
<lastmod>2021-02-01T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/experiments/nlp_autoregression.html</loc>
|
||||
<lastmod>2021-01-25T16:30:00+00:00</lastmod>
|
||||
@ -281,7 +267,14 @@
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/transformers/feedback/index.html</loc>
|
||||
<lastmod>2021-01-30T16:30:00+00:00</lastmod>
|
||||
<lastmod>2021-02-01T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/transformers/feedback/README.html</loc>
|
||||
<lastmod>2021-02-01T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
113
labml_nn/experiments/mnist.py
Normal file
113
labml_nn/experiments/mnist.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""
|
||||
---
|
||||
title: MNIST Experiment
|
||||
summary: >
|
||||
This is a reusable trainer for MNIST dataset
|
||||
---
|
||||
|
||||
# MNIST Experiment
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
from labml_helpers.module import Module
|
||||
|
||||
from labml import tracker
|
||||
from labml.configs import option
|
||||
from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_helpers.metrics.accuracy import Accuracy
|
||||
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
|
||||
from labml_nn.optimizers.configs import OptimizerConfigs
|
||||
|
||||
|
||||
class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
    """
    <a id="MNISTConfigs">
    ## Trainer configurations
    </a>

    A reusable trainer for MNIST classification experiments.
    Combines the MNIST dataset configurations with the generic
    train/validation loop configurations, so a concrete experiment only
    needs to supply a `model` (and optionally override the optimizer).
    """

    # Optimizer (the default is built by the `_optimizer` option below)
    optimizer: torch.optim.Adam
    # Training device, selected by `DeviceConfigs` (CPU or GPU)
    device: torch.device = DeviceConfigs()

    # Classification model; must be provided by the experiment
    model: Module
    # Number of epochs to train for
    epochs: int = 10

    # Number of times to switch between training and validation within an epoch
    inner_iterations = 10

    # Accuracy metric
    accuracy = Accuracy()
    # Loss function (cross entropy over the 10 digit classes)
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """
        ### Initialization

        Set up tracking and metric state before the training loop starts.
        """
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        `batch` is a `(data, target)` pair; `batch_idx` gives the position
        of the batch within the epoch.
        NOTE(review): `any` here is the builtin function used as an
        annotation — `typing.Any` was presumably intended; confirm before
        tightening.
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Whether to capture model outputs (only on the last batch)
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs.
            output = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
|
||||
|
||||
|
||||
@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    """
    ### Default optimizer configurations

    Builds an Adam optimizer over the model parameters; experiments can
    override any of the optimizer settings through configurations.
    """
    conf = OptimizerConfigs()
    conf.optimizer = 'Adam'
    conf.parameters = c.model.parameters()
    return conf
|
@ -0,0 +1,11 @@
|
||||
"""
|
||||
---
|
||||
title: Normalization Layers
|
||||
summary: >
|
||||
A set of PyTorch implementations/tutorials of normalization layers.
|
||||
---
|
||||
|
||||
# Normalization Layers
|
||||
|
||||
* [Batch Normalization](batch_norm/index.html)
|
||||
"""
|
@ -1,47 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from labml_helpers.module import Module
|
||||
|
||||
|
||||
class BatchNorm(Module):
    """
    Batch normalization layer.

    Normalizes each channel of the input over the batch and all trailing
    dimensions, optionally applies a learned per-channel scale and shift,
    and keeps exponential moving averages of the mean and variance for
    use at inference time.
    """

    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
        """
        * `channels` is the number of features in the input
        * `eps` is added to the variance for numerical stability
        * `momentum` is the weight of the current batch statistics in the
          exponential moving averages
        * `affine` is whether to learn a per-channel scale and shift
        * `track_running_stats` is whether to keep moving averages of the
          mean and variance
        """
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Learned per-channel scale (gamma) and shift (beta)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))
        # Buffers for the exponential moving averages used at inference
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(channels))
            self.register_buffer('running_var', torch.ones(channels))

    def __call__(self, x: torch.Tensor):
        """
        `x` is expected to have shape `[batch_size, channels, *]`, where
        `*` is any number of trailing dimensions (possibly none).
        """
        # Remember the original shape so the output can be restored to it
        x_shape = x.shape
        batch_size = x_shape[0]

        # Flatten all trailing dimensions into one
        x = x.view(batch_size, self.channels, -1)
        # Use batch statistics when training, or when no running stats are kept
        if self.training or not self.track_running_stats:
            # Per-channel mean over batch and trailing dimensions
            mean = x.mean(dim=[0, 2])
            # Per-channel mean of squares, used to derive the variance
            mean_x2 = (x ** 2).mean(dim=[0, 2])
            # Var[x] = E[x^2] - E[x]^2 (biased estimate over the mini-batch)
            var = mean_x2 - mean ** 2

            # Update the exponential moving averages during training.
            # NOTE(review): `mean`/`var` are not detached here, so the
            # buffers retain a reference to the autograd graph of each
            # training step — confirm whether this is intended.
            if self.training and self.track_running_stats:
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            # At inference, use the tracked moving averages
            mean = self.running_mean
            var = self.running_var

        # Normalize to zero mean and unit variance per channel
        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
        # Apply the learned scale and shift
        if self.affine:
            x_norm = self.weight.view(1, -1, 1) * x_norm + self.bias.view(1, -1, 1)

        # Restore the original shape
        return x_norm.view(x_shape)
|
174
labml_nn/normalization/batch_norm/__init__.py
Normal file
174
labml_nn/normalization/batch_norm/__init__.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""
|
||||
---
|
||||
title: Batch Normalization
|
||||
summary: >
|
||||
A PyTorch implementations/tutorials of batch normalization.
|
||||
---
|
||||
|
||||
# Batch Normalization
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation of Batch Normalization from paper
|
||||
[Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167).
|
||||
|
||||
### Internal Covariate Shift
|
||||
|
||||
The paper defines *Internal Covariate Shift* as the change in the
|
||||
distribution of network activations due to the change in
|
||||
network parameters during training.
|
||||
For example, let's say there are two layers $l_1$ and $l_2$.
|
||||
During the beginning of the training $l_1$ outputs (inputs to $l_2$)
|
||||
could be in distribution $\mathcal{N}(0.5, 1)$.
|
||||
Then, after some training steps it could move to $\mathcal{N}(0.6, 1.5)$.
|
||||
This is *internal covariate shift*.
|
||||
|
||||
Internal covariate shift will adversely affect training speed because the later layers
|
||||
($l_2$ in the above example) have to adapt to this shifted distribution.
|
||||
|
||||
By stabilizing the distribution batch normalization minimizes the internal covariate shift.
|
||||
|
||||
## Normalization
|
||||
|
||||
It is known that whitening improves training speed and convergence.
|
||||
*Whitening* is linearly transforming inputs to have zero mean, unit variance
|
||||
and be uncorrelated.
|
||||
|
||||
### Normalizing outside gradient computation doesn't work
|
||||
|
||||
Normalizing outside the gradient computation using pre-computed (detached)
|
||||
means and variances don't work. For instance (ignoring variance), let
|
||||
$$\hat{x} = x - \mathbb{E}[x]$$
|
||||
where $x = u + b$ and $b$ is a trained bias.
|
||||
and $\mathbb{E}[x]$ is outside gradient computation (pre-computed constant).
|
||||
|
||||
Note that $\hat{x}$ has no effect of $b$.
|
||||
Therefore,
|
||||
$b$ will increase or decrease based on
|
||||
$\frac{\partial{\mathcal{L}}}{\partial x}$,
|
||||
and keep on growing indefinitely in each training update.
|
||||
Paper notes that similar explosions happen with variances.
|
||||
|
||||
### Batch Normalization
|
||||
|
||||
Whitening is computationally expensive because you need to de-correlate and
|
||||
the gradients must flow through the full whitening calculation.
|
||||
|
||||
The paper introduces simplified version which they call *Batch Normalization*.
|
||||
First simplification is that it normalizes each feature independently to have
|
||||
zero mean and unit variance:
|
||||
$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$
|
||||
where $x = (x^{(1)} ... x^{(d)})$ is the $d$-dimensional input.
|
||||
|
||||
The second simplification is to use estimates of mean $\mathbb{E}[x^{(k)}]$
|
||||
and variance $Var[x^{(k)}]$ from the mini-batch
|
||||
for normalization; instead of calculating the mean and variance across whole dataset.
|
||||
|
||||
Normalizing each feature to zero mean and unit variance could affect what the layer
|
||||
can represent.
|
||||
As an example paper illustrates that, if the inputs to a sigmoid are normalized
|
||||
most of it will be within $[-1, 1]$ range where the sigmoid is linear.
|
||||
To overcome this each feature is scaled and shifted by two trained parameters
|
||||
$\gamma^{(k)}$ and $\beta^{(k)}$.
|
||||
$$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
|
||||
where $y^{(k)}$ is the output of the batch normalization layer.
|
||||
|
||||
Note that when applying batch normalization after a linear transform
|
||||
like $Wu + b$ the bias parameter $b$ gets cancelled due to normalization.
|
||||
So you can and should omit bias parameter in linear transforms right before the
|
||||
batch normalization.
|
||||
|
||||
## Inference
|
||||
|
||||
We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to
|
||||
perform the normalization.
|
||||
So during inference, you either need to go through the whole (or part of) dataset
|
||||
and find the mean and variance, or you can use an estimate calculated during training.
|
||||
The usual practice is to calculate an exponential moving average of
|
||||
mean and variance during training phase and use that for inference.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from labml_helpers.module import Module
|
||||
|
||||
|
||||
class BatchNorm(Module):
    r"""
    ## Batch Normalization Layer

    Normalizes each channel of the input over the batch and all trailing
    dimensions, then optionally applies a learned per-channel scale
    ($\gamma$) and shift ($\beta$).
    """

    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
        r"""
        * `channels` is the number of features in the input
        * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
        * `momentum` is the momentum in taking the exponential moving average
        * `affine` is whether to scale and shift the normalized value
        * `track_running_stats` is whether to calculate the moving averages of mean and variance

        We've tried to use the same names for arguments as PyTorch `BatchNorm` implementation.
        """
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Create parameters for $\gamma$ and $\beta$ for scale and shift
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
        # Create buffers to store exponential moving averages of
        # mean $\mathbb{E}[x^{(k)}]$ and variance $Var[x^{(k)}]$
        if self.track_running_stats:
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))

    def __call__(self, x: torch.Tensor):
        r"""
        `x` is a tensor of shape `[batch_size, channels, *]`.
        `*` could be any number of (even zero) trailing dimensions.
        For example, in an image (2D) convolution this will be
        `[batch_size, channels, height, width]`
        """
        # Keep the original shape
        x_shape = x.shape
        # Get the batch size
        batch_size = x_shape[0]
        # Sanity check to make sure the number of features is the same
        assert self.channels == x.shape[1]

        # Reshape into `[batch_size, channels, n]`
        x = x.view(batch_size, self.channels, -1)

        # We will calculate the mini-batch mean and variance
        # if we are in training mode or if we have not tracked exponential moving averages
        if self.training or not self.track_running_stats:
            # Calculate the mean across first and last dimensions;
            # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
            mean = x.mean(dim=[0, 2])
            # Calculate the squared mean across first and last dimensions;
            # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
            mean_x2 = (x ** 2).mean(dim=[0, 2])
            # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
            var = mean_x2 - mean ** 2

            # Update exponential moving averages.
            # The batch statistics are detached so the buffers do not keep a
            # reference to the autograd graph of every training step, which
            # would otherwise grow memory use across iterations.
            if self.training and self.track_running_stats:
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean.detach()
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var.detach()
        # Use exponential moving averages as estimates
        else:
            mean = self.exp_mean
            var = self.exp_var

        # Normalize $$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$$
        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
        # Scale and shift $$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

        # Reshape to original and return
        return x_norm.view(x_shape)
|
82
labml_nn/normalization/batch_norm/mnist.py
Normal file
82
labml_nn/normalization/batch_norm/mnist.py
Normal file
@ -0,0 +1,82 @@
|
||||
"""
|
||||
---
|
||||
title: MNIST Experiment to try Batch Normalization
|
||||
summary: >
|
||||
This is a simple model for MNIST digit classification that uses batch normalization
|
||||
---
|
||||
|
||||
# MNIST Experiment for Batch Normalization
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
|
||||
from labml import experiment
|
||||
from labml.configs import option
|
||||
from labml_helpers.module import Module
|
||||
from labml_nn.experiments.mnist import MNISTConfigs
|
||||
from labml_nn.normalization.batch_norm import BatchNorm
|
||||
|
||||
|
||||
class Model(Module):
    """
    ### Model definition

    A small CNN for MNIST: two convolution + batch-norm stages followed
    by two fully connected layers, with batch normalization before every
    ReLU activation.
    """

    def __init__(self):
        super().__init__()
        # Bias is omitted in layers followed by batch normalization,
        # since normalization cancels any constant shift.
        self.conv1 = nn.Conv2d(1, 20, 5, 1, bias=False)
        # Batch normalization over the 20 convolution channels.
        # Its input has shape `[batch_size, 20, height(24), width(24)]`
        self.bn1 = BatchNorm(20)
        #
        self.conv2 = nn.Conv2d(20, 50, 5, 1, bias=False)
        # Batch normalization over the 50 convolution channels.
        # Its input has shape `[batch_size, 50, height(8), width(8)]`
        self.bn2 = BatchNorm(50)
        #
        self.fc1 = nn.Linear(4 * 4 * 50, 500, bias=False)
        # Batch normalization over the 500 fully connected features.
        # Its input has shape `[batch_size, 500]`
        self.bn3 = BatchNorm(500)
        #
        self.fc2 = nn.Linear(500, 10)

    def __call__(self, x: torch.Tensor):
        # First stage: convolution -> batch norm -> ReLU -> 2x2 max-pool
        h = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2, 2)
        # Second stage, same pattern
        h = F.max_pool2d(F.relu(self.bn2(self.conv2(h))), 2, 2)
        # Flatten for the fully connected layers
        h = h.view(-1, 4 * 4 * 50)
        # Hidden fully connected stage with batch norm
        h = F.relu(self.bn3(self.fc1(h)))
        # Final logits for the 10 digit classes
        return self.fc2(h)
|
||||
|
||||
|
||||
@option(MNISTConfigs.model)
def model(c: MNISTConfigs):
    """
    ### Create model

    We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations
    and set a new function to calculate the model.
    """
    net = Model()
    return net.to(c.device)
|
||||
|
||||
|
||||
def main():
    """Run the MNIST batch-normalization experiment."""
    # Create the experiment
    experiment.create(name='mnist_batch_norm')
    # Build the configurations
    conf = MNISTConfigs()
    # Load configurations, overriding the optimizer choice
    overrides = {'optimizer.optimizer': 'Adam'}
    experiment.configs(conf, overrides)
    # Run the training loop inside the experiment context
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()
|
@ -1,105 +0,0 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
|
||||
from labml import experiment, tracker
|
||||
from labml.configs import option
|
||||
from labml_helpers.datasets.mnist import MNISTConfigs
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_helpers.metrics.accuracy import Accuracy
|
||||
from labml_helpers.module import Module
|
||||
from labml_helpers.seed import SeedConfigs
|
||||
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
|
||||
from labml_nn.normalization.batch_norm import BatchNorm
|
||||
|
||||
|
||||
class Net(Module):
    """
    A small CNN for MNIST digit classification using the custom
    `BatchNorm` layer after each convolution and the first fully
    connected layer.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 20 channels, 5x5 kernel, stride 1.
        # NOTE(review): biases are kept here although the following batch
        # normalization cancels them — they could be dropped with `bias=False`.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.bn1 = BatchNorm(20)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn2 = BatchNorm(50)
        # 4 * 4 * 50 is the flattened size after two conv + 2x2 max-pool stages
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.bn3 = BatchNorm(500)
        self.fc2 = nn.Linear(500, 10)

    def __call__(self, x: torch.Tensor):
        # Convolution stages: conv -> batch norm -> ReLU -> 2x2 max-pool
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        # Flatten for the fully connected layers
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.bn3(self.fc1(x)))
        # Final logits for the 10 digit classes
        return self.fc2(x)
|
||||
|
||||
|
||||
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Trainer configurations combining the MNIST dataset configurations
    with the generic train/validation loop configurations.
    """

    # Optimizer (built by the `_optimizer` option below)
    optimizer: torch.optim.Adam
    # Classification model
    model: nn.Module
    # Random seed configurations
    set_seed = SeedConfigs()
    # Training device (CPU/GPU selected by `DeviceConfigs`)
    device: torch.device = DeviceConfigs()
    # Number of epochs to train for
    epochs: int = 10

    # Whether to save model checkpoints
    is_save_models = True
    # NOTE(review): duplicate annotation — `model` is already declared above
    model: nn.Module
    # Number of times to switch between training and validation within an epoch
    inner_iterations = 10

    # Accuracy metric
    accuracy_func = Accuracy()
    # Loss function (cross entropy over the 10 digit classes)
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        # Track a running queue of the last 20 loss values
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Keep accuracy stats separate for training and validation
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        `batch` is a `(data, target)` pair; `batch_idx` gives the position
        of the batch within the epoch.
        """
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when training
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Capture model outputs only on the last batch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        # Calculate loss and accuracy, and log the loss
        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()

            # Take optimizer step
            self.optimizer.step()
            # Log parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
|
||||
|
||||
|
||||
@option(Configs.model)
def model(c: Configs):
    """Create the `Net` model on the configured device."""
    net = Net()
    return net.to(c.device)
|
||||
|
||||
|
||||
@option(Configs.optimizer)
def _optimizer(c: Configs):
    """Default optimizer configurations, built over the model parameters."""
    from labml_helpers.optimizer import OptimizerConfigs
    conf = OptimizerConfigs()
    conf.parameters = c.model.parameters()
    return conf
|
||||
|
||||
|
||||
def main():
    """Run the MNIST experiment with the labml helpers trainer."""
    # Build the configurations
    conf = Configs()
    # Create the experiment
    experiment.create(name='mnist_labml_helpers')
    # Load configurations, overriding the optimizer choice
    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
    # Fix the random seed for reproducibility
    conf.set_seed.set()
    # Register the model so checkpoints include it
    experiment.add_pytorch_models(dict(model=conf.model))
    # Run the training loop inside the experiment context
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
|
Reference in New Issue
Block a user