mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
Weight standardization (#47)
* 🚧 weight standardization * 🐛 small fixes * 📚🚧 weight standardization * 📚 weight standardization * 📚 weight standardization experiment * 📚 batch channel norm * ✍️ corrections * 📚 experiment links
This commit is contained in:
663
docs/normalization/batch_channel_norm/index.html
Normal file
663
docs/normalization/batch_channel_norm/index.html
Normal file
@ -0,0 +1,663 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="A PyTorch implementation/tutorial of Batch-Channel Normalization."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Batch-Channel Normalization"/>
|
||||
<meta name="twitter:description" content="A PyTorch implementation/tutorial of Batch-Channel Normalization."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/batch_channel_norm/index.html"/>
|
||||
<meta property="og:title" content="Batch-Channel Normalization"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Batch-Channel Normalization"/>
|
||||
<meta property="og:description" content="A PyTorch implementation/tutorial of Batch-Channel Normalization."/>
|
||||
|
||||
<title>Batch-Channel Normalization</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/batch_channel_norm/index.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">normalization</a>
|
||||
<a class="parent" href="index.html">batch_channel_norm</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/batch_channel_norm/__init__.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
||||
rel="nofollow">
|
||||
<img alt="Join Slack"
|
||||
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Batch-Channel Normalization</h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of Batch-Channel Normalization from the paper
|
||||
<a href="https://arxiv.org/abs/1903.10520">Micro-Batch Training with Batch-Channel Normalization and Weight Standardization</a>.
|
||||
We also have an <a href="../weight_standardization/index.html">annotated implementation of Weight Standardization</a>.</p>
|
||||
<p>Batch-Channel Normalization performs batch normalization followed
|
||||
by a channel normalization (similar to a <a href="../group_norm/index.html">Group Normalization</a>).
|
||||
When the batch size is small a running mean and variance is used for
|
||||
batch normalization.</p>
|
||||
<p>Here is <a href="../weight_standardization/experiment.html">the training code</a> for training
|
||||
a VGG network that uses weight standardization to classify CIFAR-10 data.</p>
|
||||
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
|
||||
<a href="https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a>
|
||||
<a href="https://wandb.ai/vpj/cifar10/runs/3flr4k8w"><img alt="WandB" src="https://img.shields.io/badge/wandb-run-yellow" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">27</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">29</span>
|
||||
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||||
<span class="lineno">31</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.batch_norm</span> <span class="kn">import</span> <span class="n">BatchNorm</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>Batch-Channel Normalization</h2>
|
||||
<p>This first performs a batch normalization - either <a href="../batch_norm/index.html">normal batch norm</a>
|
||||
or a batch norm with
|
||||
estimated mean and variance (exponential mean/variance over multiple batches).
|
||||
Then a channel normalization is performed.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">34</span><span class="k">class</span> <span class="nc">BatchChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>channels</code> is the number of features in the input</li>
|
||||
<li><code>groups</code> is the number of groups the features are divided into</li>
|
||||
<li><code>eps</code> is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability</li>
|
||||
<li><code>momentum</code> is the momentum in taking the exponential moving average</li>
|
||||
<li><code>estimate</code> is whether to use running mean and variance for batch norm</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||||
<span class="lineno">45</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">estimate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">53</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Use estimated batch norm or normal batch norm.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">56</span> <span class="k">if</span> <span class="n">estimate</span><span class="p">:</span>
|
||||
<span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_norm</span> <span class="o">=</span> <span class="n">EstimatedBatchNorm</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span>
|
||||
<span class="lineno">58</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="n">momentum</span><span class="p">)</span>
|
||||
<span class="lineno">59</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_norm</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span>
|
||||
<span class="lineno">61</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="n">momentum</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Channel normalization</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel_norm</span> <span class="o">=</span> <span class="n">ChannelNorm</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="lineno">67</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="lineno">68</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<h2>Estimated Batch Normalization</h2>
|
||||
<p>When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
|
||||
where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
|
||||
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C
|
||||
\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
|
||||
+ \beta_C</script>
|
||||
</p>
|
||||
<p>where,</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
|
||||
\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>are the running mean and variances. $r$ is the momentum for calculating the exponential mean.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">71</span><span class="k">class</span> <span class="nc">EstimatedBatchNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>channels</code> is the number of features in the input</li>
|
||||
<li><code>eps</code> is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability</li>
|
||||
<li><code>momentum</code> is the momentum in taking the exponential moving average</li>
|
||||
<li><code>estimate</code> is whether to use running mean and variance for batch norm</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">92</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||||
<span class="lineno">93</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">100</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">101</span>
|
||||
<span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||||
<span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
|
||||
<span class="lineno">104</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
|
||||
<span class="lineno">105</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Channel wise transformation parameters</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'exp_mean'</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'exp_var'</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p><code>x</code> is a tensor of shape <code>[batch_size, channels, *]</code>.
|
||||
<code>*</code> denotes any number of (possibly 0) dimensions.
|
||||
For example, in an image (2D) convolution this will be
|
||||
<code>[batch_size, channels, height, width]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Keep old shape</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Get the batch size</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Sanity check to make sure the number of features is correct</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Reshape into <code>[batch_size, channels, n]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">135</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">137</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Calculate the mean across first and last dimensions;
|
||||
$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Calculate the squared mean across first and last dimensions;
|
||||
$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Variance for each feature $\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>Update exponential moving averages
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
|
||||
\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
|
||||
\end{align}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
|
||||
<span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Normalize
|
||||
<script type="math/tex; mode=display">\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>Scale and shift
|
||||
<script type="math/tex; mode=display"> \gamma_C
|
||||
\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
|
||||
+ \beta_C</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">162</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">163</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Reshape to original and return</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">166</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-26'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<h2>Channel Normalization</h2>
|
||||
<p>This is similar to <a href="../group_norm/index.html">Group Normalization</a> but the affine transform is done group-wise.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">169</span><span class="k">class</span> <span class="nc">ChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-27'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-27'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>groups</code> is the number of groups the features are divided into</li>
|
||||
<li><code>channels</code> is the number of features in the input</li>
|
||||
<li><code>eps</code> is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability</li>
|
||||
<li><code>affine</code> is whether to scale and shift the normalized value</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span>
|
||||
<span class="lineno">177</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-28'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-28'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">184</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
|
||||
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||||
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||||
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-29'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-29'>#</a>
|
||||
</div>
|
||||
<p>Parameters for affine transformation.</p>
|
||||
<p><em>Note that these transforms are applied per group, unlike in group norm where
|
||||
the scale and shift are applied channel-wise.</em></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">194</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span>
|
||||
<span class="lineno">195</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-30'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-30'>#</a>
|
||||
</div>
|
||||
<p><code>x</code> is a tensor of shape <code>[batch_size, channels, *]</code>.
|
||||
<code>*</code> denotes any number of (possibly 0) dimensions.
|
||||
For example, in an image (2D) convolution this will be
|
||||
<code>[batch_size, channels, height, width]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-31'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-31'>#</a>
|
||||
</div>
|
||||
<p>Keep the original shape</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-32'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-32'>#</a>
|
||||
</div>
|
||||
<p>Get the batch size</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-33'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-33'>#</a>
|
||||
</div>
|
||||
<p>Sanity check to make sure the number of features in the input matches the configured number of channels</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">210</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-34'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-34'>#</a>
|
||||
</div>
|
||||
<p>Reshape into <code>[batch_size, groups, n]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-35'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-35'>#</a>
|
||||
</div>
|
||||
<p>Calculate the mean across last dimension;
|
||||
i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-36'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-36'>#</a>
|
||||
</div>
|
||||
<p>Calculate the squared mean across last dimension;
|
||||
i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">220</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-37'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-37'>#</a>
|
||||
</div>
|
||||
<p>Variance for each sample and feature group
|
||||
$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">223</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-38'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-38'>#</a>
|
||||
</div>
|
||||
<p>Normalize
|
||||
<script type="math/tex; mode=display">\hat{x}_{(i_N, i_G)} =
|
||||
\frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-39'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-39'>#</a>
|
||||
</div>
|
||||
<p>Scale and shift group-wise
|
||||
<script type="math/tex; mode=display">y_{i_G} =\gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">232</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">233</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-40'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-40'>#</a>
|
||||
</div>
|
||||
<p>Reshape to original and return</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -137,10 +137,10 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
<a href="https://wandb.ai/vpj/cifar10/runs/310etthp"><img alt="WandB" src="https://img.shields.io/badge/wandb-run-yellow" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">87</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">88</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">89</span>
|
||||
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">86</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">87</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">88</span>
|
||||
<span class="lineno">89</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
@ -151,7 +151,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
<h2>Group Normalization Layer</h2>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">93</span><span class="k">class</span> <span class="nc">GroupNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">92</span><span class="k">class</span> <span class="nc">GroupNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
@ -167,8 +167,8 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">98</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||
<span class="lineno">99</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">97</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||
<span class="lineno">98</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
@ -179,14 +179,14 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">106</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">107</span>
|
||||
<span class="lineno">108</span> <span class="k">assert</span> <span class="n">channels</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"Number of channels should be evenly divisible by the number of groups"</span>
|
||||
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||||
<span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
|
||||
<span class="lineno">111</span>
|
||||
<span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||||
<span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">105</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">106</span>
|
||||
<span class="lineno">107</span> <span class="k">assert</span> <span class="n">channels</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"Number of channels should be evenly divisible by the number of groups"</span>
|
||||
<span class="lineno">108</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||||
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
|
||||
<span class="lineno">110</span>
|
||||
<span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||||
<span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
@ -197,9 +197,9 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
<p>Create parameters for $\gamma$ and $\beta$ for scale and shift</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">115</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
|
||||
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
@ -213,7 +213,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
<code>[batch_size, channels, height, width]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">118</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
@ -224,7 +224,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
<p>Keep the original shape</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
@ -235,7 +235,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
<p>Get the batch size</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
@ -246,7 +246,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
<p>Sanity check to make sure the number of features is the same</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
@ -257,7 +257,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
<p>Reshape into <code>[batch_size, groups, n]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
@ -269,7 +269,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
|
||||
i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
@ -281,7 +281,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p
|
||||
i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
@ -293,7 +293,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$<
|
||||
$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
@ -307,7 +307,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
@ -320,9 +320,9 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">154</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">155</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">152</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
|
||||
<span class="lineno">153</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">154</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
@ -333,7 +333,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
|
||||
<p>Reshape to original and return</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
@ -344,7 +344,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
|
||||
<p>Simple test</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">161</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">160</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
@ -355,14 +355,14 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">165</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
|
||||
<span class="lineno">166</span>
|
||||
<span class="lineno">167</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
|
||||
<span class="lineno">168</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||||
<span class="lineno">169</span> <span class="n">bn</span> <span class="o">=</span> <span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||||
<span class="lineno">170</span>
|
||||
<span class="lineno">171</span> <span class="n">x</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="lineno">172</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">164</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
|
||||
<span class="lineno">165</span>
|
||||
<span class="lineno">166</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
|
||||
<span class="lineno">167</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||||
<span class="lineno">168</span> <span class="n">bn</span> <span class="o">=</span> <span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||||
<span class="lineno">169</span>
|
||||
<span class="lineno">170</span> <span class="n">x</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="lineno">171</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
@ -373,8 +373,8 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">176</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">177</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">175</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">176</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
200
docs/normalization/weight_standardization/conv2d.html
Normal file
200
docs/normalization/weight_standardization/conv2d.html
Normal file
@ -0,0 +1,200 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="2D Convolution Layer with Weight Standardization"/>
|
||||
<meta name="twitter:description" content="A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/weight_standardization/conv2d.html"/>
|
||||
<meta property="og:title" content="2D Convolution Layer with Weight Standardization"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="2D Convolution Layer with Weight Standardization"/>
|
||||
<meta property="og:description" content="A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization."/>
|
||||
|
||||
<title>2D Convolution Layer with Weight Standardization</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/weight_standardization/conv2d.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">normalization</a>
|
||||
<a class="parent" href="index.html">weight_standardization</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/weight_standardization/conv2d.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
||||
rel="nofollow">
|
||||
<img alt="Join Slack"
|
||||
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>2D Convolution Layer with Weight Standardization</h1>
|
||||
<p>This is an implementation of a 2 dimensional convolution layer with <a href="./index.html">Weight Standardization</a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">14</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
|
||||
<span class="lineno">16</span>
|
||||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.weight_standardization</span> <span class="kn">import</span> <span class="n">weight_standardization</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>2D Convolution Layer</h2>
|
||||
<p>This extends the standard 2D Convolution layer and standardizes the weights before the convolution step.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">20</span><span class="k">class</span> <span class="nc">Conv2d</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">26</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span>
|
||||
<span class="lineno">27</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||||
<span class="lineno">28</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||||
<span class="lineno">29</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||||
<span class="lineno">30</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||||
<span class="lineno">31</span> <span class="n">bias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||||
<span class="lineno">32</span> <span class="n">padding_mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'zeros'</span><span class="p">,</span>
|
||||
<span class="lineno">33</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">):</span>
|
||||
<span class="lineno">34</span> <span class="nb">super</span><span class="p">(</span><span class="n">Conv2d</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span>
|
||||
<span class="lineno">35</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span>
|
||||
<span class="lineno">36</span> <span class="n">padding</span><span class="o">=</span><span class="n">padding</span><span class="p">,</span>
|
||||
<span class="lineno">37</span> <span class="n">dilation</span><span class="o">=</span><span class="n">dilation</span><span class="p">,</span>
|
||||
<span class="lineno">38</span> <span class="n">groups</span><span class="o">=</span><span class="n">groups</span><span class="p">,</span>
|
||||
<span class="lineno">39</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span>
|
||||
<span class="lineno">40</span> <span class="n">padding_mode</span><span class="o">=</span><span class="n">padding_mode</span><span class="p">)</span>
|
||||
<span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">43</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||||
<span class="lineno">44</span> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">conv2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight_standardization</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span>
|
||||
<span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>A simple test to verify the tensor sizes</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">48</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">conv2d</span> <span class="o">=</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
|
||||
<span class="lineno">53</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
|
||||
<span class="lineno">54</span> <span class="n">inspect</span><span class="p">(</span><span class="n">conv2d</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
|
||||
<span class="lineno">55</span> <span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">56</span> <span class="n">inspect</span><span class="p">(</span><span class="n">conv2d</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">)))</span>
|
||||
<span class="lineno">57</span>
|
||||
<span class="lineno">58</span>
|
||||
<span class="lineno">59</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">60</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
267
docs/normalization/weight_standardization/experiment.html
Normal file
267
docs/normalization/weight_standardization/experiment.html
Normal file
@ -0,0 +1,267 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="This trains a VGG net that uses weight standardization and batch-channel normalization to classify CIFAR10 images."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization"/>
|
||||
<meta name="twitter:description" content="This trains a VGG net that uses weight standardization and batch-channel normalization to classify CIFAR10 images."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/weight_standardization/experiment.html"/>
|
||||
<meta property="og:title" content="CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization"/>
|
||||
<meta property="og:description" content="This trains a VGG net that uses weight standardization and batch-channel normalization to classify CIFAR10 images."/>
|
||||
|
||||
<title>CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/weight_standardization/experiment.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">normalization</a>
|
||||
<a class="parent" href="index.html">weight_standardization</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/weight_standardization/experiment.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
||||
rel="nofollow">
|
||||
<img alt="Join Slack"
|
||||
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization</h1>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="lineno">13</span>
|
||||
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
|
||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
||||
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.cifar10</span> <span class="kn">import</span> <span class="n">CIFAR10Configs</span>
|
||||
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.batch_channel_norm</span> <span class="kn">import</span> <span class="n">BatchChannelNorm</span>
|
||||
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.weight_standardization.conv2d</span> <span class="kn">import</span> <span class="n">Conv2d</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h3>Model</h3>
|
||||
<p>A VGG model that uses <a href="./index.html">Weight Standardization</a> and
|
||||
<a href="../batch_channel_norm/index.html">Batch-Channel Normalization</a>.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">22</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">29</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">30</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">31</span> <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="lineno">32</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="mi">3</span>
|
||||
<span class="lineno">33</span> <span class="k">for</span> <span class="n">block</span> <span class="ow">in</span> <span class="p">[[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">],</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">],</span> <span class="p">[</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">],</span> <span class="p">[</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">]]:</span>
|
||||
<span class="lineno">34</span> <span class="k">for</span> <span class="n">channels</span> <span class="ow">in</span> <span class="n">block</span><span class="p">:</span>
|
||||
<span class="lineno">35</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
|
||||
<span class="lineno">36</span> <span class="n">BatchChannelNorm</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="mi">32</span><span class="p">),</span>
|
||||
<span class="lineno">37</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)]</span>
|
||||
<span class="lineno">38</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">channels</span>
|
||||
<span class="lineno">39</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">)]</span>
|
||||
<span class="lineno">40</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">AvgPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
|
||||
<span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>
|
||||
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="lineno">45</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="lineno">46</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="lineno">47</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<h3>Create model</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">50</span><span class="nd">@option</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||||
<span class="lineno">51</span><span class="k">def</span> <span class="nf">model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">CIFAR10Configs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">55</span> <span class="k">return</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">58</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Create experiment</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'cifar10'</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">'weight standardization'</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Create configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">CIFAR10Configs</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>Load configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
|
||||
<span class="lineno">65</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||||
<span class="lineno">66</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
|
||||
<span class="lineno">67</span> <span class="s1">'train_batch_size'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span>
|
||||
<span class="lineno">68</span> <span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Start the experiment and run the training loop</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
|
||||
<span class="lineno">71</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">76</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
231
docs/normalization/weight_standardization/index.html
Normal file
231
docs/normalization/weight_standardization/index.html
Normal file
@ -0,0 +1,231 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="A PyTorch implementation/tutorial of Weight Standardization."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Weight Standardization"/>
|
||||
<meta name="twitter:description" content="A PyTorch implementation/tutorial of Weight Standardization."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/normalization/weight_standardization/index.html"/>
|
||||
<meta property="og:title" content="Weight Standardization"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Weight Standardization"/>
|
||||
<meta property="og:description" content="A PyTorch implementation/tutorial of Weight Standardization."/>
|
||||
|
||||
<title>Weight Standardization</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/normalization/weight_standardization/index.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">normalization</a>
|
||||
<a class="parent" href="index.html">weight_standardization</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/weight_standardization/__init__.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
||||
rel="nofollow">
|
||||
<img alt="Join Slack"
|
||||
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Weight Standardization</h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of Weight Standardization from the paper
|
||||
<a href="https://arxiv.org/abs/1903.10520">Micro-Batch Training with Batch-Channel Normalization and Weight Standardization</a>.
|
||||
We also have an <a href="../batch_channel_norm/index.html">annotated implementation of Batch-Channel Normalization</a>.</p>
|
||||
<p>Batch normalization <strong>gives a smooth loss landscape</strong> and
|
||||
<strong>avoids elimination singularities</strong>.
|
||||
Elimination singularities are nodes of the network that become
|
||||
useless (e.g. a ReLU that gives 0 all the time).</p>
|
||||
<p>However, batch normalization doesn’t work well when the batch size is too small,
|
||||
which happens when training large networks because of device memory limitations.
|
||||
The paper introduces Weight Standardization with Batch-Channel Normalization as
|
||||
a better alternative.</p>
|
||||
<p>Weight Standardization:
|
||||
1. Normalizes the gradients
|
||||
2. Smoothes the landscape (reduced Lipschitz constant)
|
||||
3. Avoids elimination singularities</p>
|
||||
<p>The Lipschitz constant is the maximum slope a function has between two points.
|
||||
That is, $L$ is the Lipschitz constant where $L$ is the smallest value that satisfies,
|
||||
$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$
|
||||
where $f: A \rightarrow \mathbb{R}^m, A \in \mathbb{R}^n$.</p>
|
||||
<p>Elimination singularities are avoided because it keeps the statistics of the outputs similar to the
|
||||
inputs. So as long as the inputs are normally distributed the outputs remain close to normal.
|
||||
This avoids outputs of nodes from always falling beyond the active range of the activation function
|
||||
(e.g. always negative input for a ReLU).</p>
|
||||
<p><em><a href="https://arxiv.org/abs/1903.10520">Refer to the paper for proofs</a></em>.</p>
|
||||
<p>Here is <a href="experiment.html">the training code</a> for training
|
||||
a VGG network that uses weight standardization to classify CIFAR-10 data.
|
||||
This uses a <a href="../conv2d.html">2D-Convolution Layer with Weight Standardization</a>.</p>
|
||||
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
|
||||
<a href="https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a>
|
||||
<a href="https://wandb.ai/vpj/cifar10/runs/3flr4k8w"><img alt="WandB" src="https://img.shields.io/badge/wandb-run-yellow" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">50</span><span></span><span class="kn">import</span> <span class="nn">torch</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>Weight Standardization</h2>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}</script>
|
||||
</p>
|
||||
<p>where,</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
W &\in \mathbb{R}^{O \times I} \\
|
||||
\mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
|
||||
\sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon} \\
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>for a 2D-convolution layer $O$ is the number of output channels ($O = C_{out}$)
|
||||
and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$)</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">53</span><span class="k">def</span> <span class="nf">weight_standardization</span><span class="p">(</span><span class="n">weight</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<p>Get $C_{out}$, $C_{in}$ and kernel shape</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">c_out</span><span class="p">,</span> <span class="n">c_in</span><span class="p">,</span> <span class="o">*</span><span class="n">kernel_shape</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>Reshape $W$ to $O \times I$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">c_out</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Calculate</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
|
||||
\sigma^2_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}}
|
||||
\end{align}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">81</span> <span class="n">var</span><span class="p">,</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">var_mean</span><span class="p">(</span><span class="n">weight</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Normalize
|
||||
<script type="math/tex; mode=display">\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">weight</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Change back to original shape and return</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">86</span> <span class="k">return</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">c_out</span><span class="p">,</span> <span class="n">c_in</span><span class="p">,</span> <span class="o">*</span><span class="n">kernel_shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -160,6 +160,27 @@
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/weight_standardization/index.html</loc>
|
||||
<lastmod>2021-04-27T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/weight_standardization/experiment.html</loc>
|
||||
<lastmod>2021-04-27T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/weight_standardization/conv2d.html</loc>
|
||||
<lastmod>2021-04-26T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/instance_norm/index.html</loc>
|
||||
<lastmod>2021-04-23T16:30:00+00:00</lastmod>
|
||||
@ -202,6 +223,13 @@
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/batch_channel_norm/index.html</loc>
|
||||
<lastmod>2021-04-27T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/group_norm/experiment.html</loc>
|
||||
<lastmod>2021-04-24T16:30:00+00:00</lastmod>
|
||||
|
236
labml_nn/normalization/batch_channel_norm/__init__.py
Normal file
236
labml_nn/normalization/batch_channel_norm/__init__.py
Normal file
@ -0,0 +1,236 @@
|
||||
"""
|
||||
---
|
||||
title: Batch-Channel Normalization
|
||||
summary: >
|
||||
A PyTorch implementation/tutorial of Batch-Channel Normalization.
|
||||
---
|
||||
|
||||
# Batch-Channel Normalization
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation of Batch-Channel Normalization from the paper
|
||||
[Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://arxiv.org/abs/1903.10520).
|
||||
We also have an [annotated implementation of Weight Standardization](../weight_standardization/index.html).
|
||||
|
||||
Batch-Channel Normalization performs batch normalization followed
|
||||
by a channel normalization (similar to a [Group Normalization](../group_norm/index.html)).
|
||||
When the batch size is small a running mean and variance is used for
|
||||
batch normalization.
|
||||
|
||||
Here is [the training code](../weight_standardization/experiment.html) for training
|
||||
a VGG network that uses weight standardization to classify CIFAR-10 data.
|
||||
|
||||
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb)
|
||||
[](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002)
|
||||
[](https://wandb.ai/vpj/cifar10/runs/3flr4k8w)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from labml_helpers.module import Module
|
||||
from labml_nn.normalization.batch_norm import BatchNorm
|
||||
|
||||
|
||||
class BatchChannelNorm(Module):
    """
    ## Batch-Channel Normalization

    This first performs a batch normalization - either [normal batch norm](../batch_norm/index.html)
    or a batch norm with
    estimated mean and variance (exponential mean/variance over multiple batches).
    Then a channel normalization is performed.
    """

    def __init__(self, channels: int, groups: int,
                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
        """
        * `channels` is the number of features in the input
        * `groups` is the number of groups the features are divided into
        * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
        * `momentum` is the momentum in taking the exponential moving average
        * `estimate` is whether to use running mean and variance for batch norm
        """
        super().__init__()

        # Use estimated batch norm (running statistics) or normal batch norm
        # (per-batch statistics). The estimated variant is intended for small
        # batch sizes where per-batch statistics are noisy.
        if estimate:
            self.batch_norm = EstimatedBatchNorm(channels,
                                                 eps=eps, momentum=momentum)
        else:
            self.batch_norm = BatchNorm(channels,
                                        eps=eps, momentum=momentum)

        # Channel normalization (group-norm style, over `groups` groups)
        self.channel_norm = ChannelNorm(channels, groups, eps)

    def __call__(self, x):
        # Batch normalization first, then channel normalization
        x = self.batch_norm(x)
        return self.channel_norm(x)
|
||||
|
||||
|
||||
class EstimatedBatchNorm(Module):
    """
    ## Estimated Batch Normalization

    When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
    where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
    $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.

    $$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C
    \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
    + \beta_C$$

    where,

    \begin{align}
    \hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
    \hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
    \end{align}

    are the running mean and variances. $r$ is the momentum for calculating the exponential mean.

    Unlike standard batch normalization, these running estimates (rather than the
    current batch statistics) are used for normalization even in training mode.
    """
    def __init__(self, channels: int,
                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
        """
        * `channels` is the number of features in the input
        * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
        * `momentum` is the momentum in taking the exponential moving average
        * `affine` is whether to scale and shift the normalized value
        """
        super().__init__()

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.channels = channels

        # Channel wise transformation parameters ($\gamma$ and $\beta$)
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

        # Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$;
        # buffers so they are saved with the state dict but not trained
        self.register_buffer('exp_mean', torch.zeros(channels))
        self.register_buffer('exp_var', torch.ones(channels))

    def __call__(self, x: torch.Tensor):
        """
        `x` is a tensor of shape `[batch_size, channels, *]`.
        `*` denotes any number of (possibly 0) dimensions.
        For example, in an image (2D) convolution this will be
        `[batch_size, channels, height, width]`
        """
        # Keep old shape so the output can be reshaped back
        x_shape = x.shape
        # Get the batch size
        batch_size = x_shape[0]

        # Sanity check to make sure the number of features is correct
        assert self.channels == x.shape[1]

        # Reshape into `[batch_size, channels, n]`
        x = x.view(batch_size, self.channels, -1)

        # Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only
        if self.training:
            # No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
            with torch.no_grad():
                # Calculate the mean across first and last dimensions;
                # $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$
                mean = x.mean(dim=[0, 2])
                # Calculate the squared mean across first and last dimensions;
                # $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$
                mean_x2 = (x ** 2).mean(dim=[0, 2])
                # Variance for each feature;
                # $\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$
                var = mean_x2 - mean ** 2

                # Update exponential moving averages
                # (assigning to a registered buffer name replaces the buffer)
                # \begin{align}
                # \hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
                # \hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
                # \end{align}
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var

        # Normalize with the running estimates
        # $$\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}$$
        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
        # Scale and shift
        # $$ \gamma_C
        # \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
        # + \beta_C$$
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

        # Reshape to original and return
        return x_norm.view(x_shape)
|
||||
|
||||
|
||||
class ChannelNorm(Module):
    """
    ## Channel Normalization

    This is similar to [Group Normalization](../group_norm/index.html),
    except that the affine transform is applied per group rather than per channel.
    """

    def __init__(self, channels, groups,
                 eps: float = 1e-5, affine: bool = True):
        """
        * `channels` is the number of features in the input
        * `groups` is the number of groups the features are divided into
        * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
        * `affine` is whether to scale and shift the normalized value
        """
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.eps = eps
        self.affine = affine
        # Parameters for the affine transformation.
        #
        # *Note that these transforms are per group, unlike in group norm where
        # they are per channel.*
        if self.affine:
            self.scale = nn.Parameter(torch.ones(groups))
            self.shift = nn.Parameter(torch.zeros(groups))

    def __call__(self, x: torch.Tensor):
        """
        `x` is a tensor of shape `[batch_size, channels, *]`.
        `*` denotes any number of (possibly 0) dimensions.
        For example, in an image (2D) convolution this will be
        `[batch_size, channels, height, width]`
        """
        # Remember the original shape so the result can be restored
        original_shape = x.shape
        # Sanity check on the channel dimension
        assert self.channels == x.shape[1]

        # Collect the features of each group: `[batch_size, groups, n]`
        grouped = x.view(original_shape[0], self.groups, -1)

        # Mean for each sample and channel group, $\mathbb{E}[x_{(i_N, i_G)}]$
        mean = grouped.mean(dim=[-1], keepdim=True)
        # Mean of squares for each sample and channel group, $\mathbb{E}[x^2_{(i_N, i_G)}]$
        mean_sq = (grouped ** 2).mean(dim=[-1], keepdim=True)
        # Variance per sample and feature group,
        # $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
        var = mean_sq - mean ** 2

        # Normalize
        # $$\hat{x}_{(i_N, i_G)} =
        # \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$$
        normalized = (grouped - mean) / torch.sqrt(var + self.eps)

        # Scale and shift group-wise
        # $$y_{i_G} =\gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}$$
        if self.affine:
            normalized = self.scale.view(1, -1, 1) * normalized + self.shift.view(1, -1, 1)

        # Restore the original shape and return
        return normalized.view(original_shape)
|
@ -81,7 +81,6 @@ Here's a [CIFAR 10 classification model](experiment.html) that uses instance nor
|
||||
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb)
|
||||
[](https://app.labml.ai/run/081d950aa4e011eb8f9f0242ac1c0002)
|
||||
[](https://wandb.ai/vpj/cifar10/runs/310etthp)
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
86
labml_nn/normalization/weight_standardization/__init__.py
Normal file
86
labml_nn/normalization/weight_standardization/__init__.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""
|
||||
---
|
||||
title: Weight Standardization
|
||||
summary: >
|
||||
A PyTorch implementation/tutorial of Weight Standardization.
|
||||
---
|
||||
|
||||
# Weight Standardization
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation of Weight Standardization from the paper
|
||||
[Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://arxiv.org/abs/1903.10520).
|
||||
We also have an [annotated implementation of Batch-Channel Normalization](../batch_channel_norm/index.html).
|
||||
|
||||
Batch normalization **gives a smooth loss landscape** and
|
||||
**avoids elimination singularities**.
|
||||
Elimination singularities are nodes of the network that become
|
||||
useless (e.g. a ReLU that gives 0 all the time).
|
||||
|
||||
However, batch normalization doesn't work well when the batch size is too small,
|
||||
which happens when training large networks because of device memory limitations.
|
||||
The paper introduces Weight Standardization with Batch-Channel Normalization as
|
||||
a better alternative.
|
||||
|
||||
Weight Standardization:
|
||||
1. Normalizes the gradients
|
||||
2. Smooths the loss landscape (reduces the Lipschitz constant)
|
||||
3. Avoids elimination singularities
|
||||
|
||||
The Lipschitz constant is the maximum slope a function has between two points.
|
||||
That is, $L$ is the Lipschitz constant where $L$ is the smallest value that satisfies,
|
||||
$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$
|
||||
where $f: A \rightarrow \mathbb{R}^m, A \in \mathbb{R}^n$.
|
||||
|
||||
Elimination singularities are avoided because it keeps the statistics of the outputs similar to the
|
||||
inputs. So as long as the inputs are normally distributed the outputs remain close to normal.
|
||||
This avoids outputs of nodes from always falling beyond the active range of the activation function
|
||||
(e.g. always negative input for a ReLU).
|
||||
|
||||
*[Refer to the paper for proofs](https://arxiv.org/abs/1903.10520)*.
|
||||
|
||||
Here is [the training code](experiment.html) for training
|
||||
a VGG network that uses weight standardization to classify CIFAR-10 data.
|
||||
This uses a [2D-Convolution Layer with Weight Standardization](../conv2d.html).
|
||||
|
||||
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb)
|
||||
[](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002)
|
||||
[](https://wandb.ai/vpj/cifar10/runs/3flr4k8w)
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def weight_standardization(weight: torch.Tensor, eps: float):
    r"""
    ## Weight Standardization

    $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$

    where,

    \begin{align}
    W &\in \mathbb{R}^{O \times I} \\
    \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
    \sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon} \\
    \end{align}

    for a 2D-convolution layer $O$ is the number of output channels ($O = C_{out}$)
    and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$)
    """

    # Original shape: $C_{out}$, $C_{in}$, and the kernel dimensions
    out_channels, in_channels, *kernel_shape = weight.shape
    # Flatten $W$ into an $O \times I$ matrix
    flat = weight.view(out_channels, -1)
    # Row-wise mean and variance
    #
    # NOTE(review): `torch.var_mean` defaults to the unbiased estimator
    # (divides by $I - 1$), while the formula above is the biased one
    # (divides by $I$) — confirm which is intended.
    var, mean = torch.var_mean(flat, dim=1, keepdim=True)
    # Standardize each row
    # $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$
    flat = (flat - mean) / (torch.sqrt(var + eps))
    # Restore the convolution-kernel shape and return
    return flat.view(out_channels, in_channels, *kernel_shape)
|
60
labml_nn/normalization/weight_standardization/conv2d.py
Normal file
60
labml_nn/normalization/weight_standardization/conv2d.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""
|
||||
---
|
||||
title: 2D Convolution Layer with Weight Standardization
|
||||
summary: >
|
||||
A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization.
|
||||
---
|
||||
|
||||
# 2D Convolution Layer with Weight Standardization
|
||||
|
||||
This is an implementation of a 2 dimensional convolution layer with [Weight Standardization](./index.html)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from labml_nn.normalization.weight_standardization import weight_standardization
|
||||
|
||||
|
||||
class Conv2d(nn.Conv2d):
    """
    ## 2D Convolution Layer

    A standard 2D convolution layer that applies
    [Weight Standardization](./index.html) to its weights before each convolution.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 eps: float = 1e-5):
        # Initialize the underlying convolution layer
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride,
                         padding=padding,
                         dilation=dilation,
                         groups=groups,
                         bias=bias,
                         padding_mode=padding_mode)
        # $\epsilon$ used when standardizing the weights
        self.eps = eps

    def forward(self, x: torch.Tensor):
        # Standardize the weights, then run the usual convolution.
        #
        # NOTE(review): `F.conv2d` with `self.padding` always zero-pads, so a
        # non-`'zeros'` `padding_mode` is not honored here — confirm intended.
        standardized = weight_standardization(self.weight, self.eps)
        return F.conv2d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
def _test():
    """
    A simple sanity check of the tensor sizes
    """
    from labml.logger import inspect

    layer = Conv2d(10, 20, 5)
    # Show the (standardized-at-forward-time) weight tensor
    inspect(layer.weight)

    import torch
    # Run a zero batch through the layer and show the output shape
    inspect(layer(torch.zeros(10, 10, 100, 100)))


#
if __name__ == '__main__':
    _test()
|
76
labml_nn/normalization/weight_standardization/experiment.py
Normal file
76
labml_nn/normalization/weight_standardization/experiment.py
Normal file
@ -0,0 +1,76 @@
|
||||
"""
|
||||
---
|
||||
title: CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization
|
||||
summary: >
|
||||
This trains a VGG net that uses weight standardization and batch-channel normalization
|
||||
to classify CIFAR10 images.
|
||||
---
|
||||
|
||||
# CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from labml import experiment
|
||||
from labml.configs import option
|
||||
from labml_helpers.module import Module
|
||||
from labml_nn.experiments.cifar10 import CIFAR10Configs
|
||||
from labml_nn.normalization.batch_channel_norm import BatchChannelNorm
|
||||
from labml_nn.normalization.weight_standardization.conv2d import Conv2d
|
||||
|
||||
|
||||
class Model(Module):
    """
    ### Model

    A VGG model that uses [Weight Standardization](./index.html) and
    [Batch-Channel Normalization](../batch_channel_norm/index.html).
    """
    def __init__(self):
        super().__init__()
        # VGG-style channel configuration
        blocks = [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]
        layers = []
        in_channels = 3
        for block in blocks:
            # Convolution, batch-channel norm, activation for each stage
            for channels in block:
                layers += [Conv2d(in_channels, channels, kernel_size=3, padding=1),
                           BatchChannelNorm(channels, 32),
                           nn.ReLU(inplace=True)]
                in_channels = channels
            # Down-sample after each block
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        self.layers = nn.Sequential(*layers)
        # Classification head
        self.fc = nn.Linear(512, 10)

    def __call__(self, x):
        # Extract features, flatten, and classify
        features = self.layers(x)
        return self.fc(features.view(features.shape[0], -1))
|
||||
|
||||
|
||||
@option(CIFAR10Configs.model)
def model(c: CIFAR10Configs):
    """
    ### Create model

    Builds the VGG model and moves it to the configured device.
    """
    return Model().to(c.device)
|
||||
|
||||
|
||||
def main():
    # Create the experiment
    experiment.create(name='cifar10', comment='weight standardization')
    # Create configurations
    configurations = CIFAR10Configs()
    # Load configuration overrides
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'train_batch_size': 64,
    }
    experiment.configs(configurations, overrides)
    # Start the experiment and run the training loop
    with experiment.start():
        configurations.run()


#
if __name__ == '__main__':
    main()
|
2
setup.py
2
setup.py
@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
|
||||
|
||||
setuptools.setup(
|
||||
name='labml-nn',
|
||||
version='0.4.96',
|
||||
version='0.4.97',
|
||||
author="Varuna Jayasiri, Nipun Wijerathne",
|
||||
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
|
||||
description="A collection of PyTorch implementations of neural network architectures and layers.",
|
||||
|
Reference in New Issue
Block a user