Weight standardization (#47)

* 🚧 weight standardization

* 🐛 small fixes

* 📚🚧 weight standardization

* 📚 weight standardization

* 📚 weight standardization experiment

* 📚 batch channel norm

* ✍️  corrections

* 📚 experiment links
This commit is contained in:
Varuna Jayasiri
2021-04-28 10:44:50 +05:30
committed by GitHub
parent a6790e956d
commit 8a4222c36b
12 changed files with 1890 additions and 44 deletions

View File

@ -0,0 +1,663 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of Batch-Channel Normalization."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Batch-Channel Normalization"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of Batch-Channel Normalization."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/normalization/batch_channel_norm/index.html"/>
<meta property="og:title" content="Batch-Channel Normalization"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Batch-Channel Normalization"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of Batch-Channel Normalization."/>
<title>Batch-Channel Normalization</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/normalization/batch_channel_norm/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">normalization</a>
<a class="parent" href="index.html">batch_channel_norm</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/batch_channel_norm/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slack"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Batch-Channel Normalization</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of Batch-Channel Normalization from the paper
<a href="https://arxiv.org/abs/1903.10520">Micro-Batch Training with Batch-Channel Normalization and Weight Standardization</a>.
We also have an <a href="../weight_standardization/index.html">annotated implementation of Weight Standardization</a>.</p>
<p>Batch-Channel Normalization performs batch normalization followed
by a channel normalization (similar to <a href="../group_norm/index.html">Group Normalization</a>).
When the batch size is small a running mean and variance is used for
batch normalization.</p>
<p>Here is <a href="../weight_standardization/experiment.html">the training code</a> for training
a VGG network that uses weight standardization to classify CIFAR-10 data.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a>
<a href="https://wandb.ai/vpj/cifar10/runs/3flr4k8w"><img alt="WandB" src="https://img.shields.io/badge/wandb-run-yellow" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">29</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">31</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.batch_norm</span> <span class="kn">import</span> <span class="n">BatchNorm</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Batch-Channel Normalization</h2>
<p>This first performs a batch normalization - either <a href="../batch_norm/index.html">normal batch norm</a>
or a batch norm with
estimated mean and variance (exponential mean/variance over multiple batches).
Then a channel normalization is performed.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span><span class="k">class</span> <span class="nc">BatchChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>channels</code> is the number of features in the input</li>
<li><code>groups</code> is the number of groups the features are divided into</li>
<li><code>eps</code> is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability</li>
<li><code>momentum</code> is the momentum in taking the exponential moving average</li>
<li><code>estimate</code> is whether to use running mean and variance for batch norm</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">45</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">estimate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Use estimated batch norm or normal batch norm.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="k">if</span> <span class="n">estimate</span><span class="p">:</span>
<span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_norm</span> <span class="o">=</span> <span class="n">EstimatedBatchNorm</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span>
<span class="lineno">58</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="n">momentum</span><span class="p">)</span>
<span class="lineno">59</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_norm</span> <span class="o">=</span> <span class="n">BatchNorm</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span>
<span class="lineno">61</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="n">momentum</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Channel normalization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel_norm</span> <span class="o">=</span> <span class="n">ChannelNorm</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">67</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">68</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel_norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<h2>Estimated Batch Normalization</h2>
<p>When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.</p>
<p>
<script type="math/tex; mode=display">\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C
\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
+ \beta_C</script>
</p>
<p>where,</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
\end{align}</script>
</p>
<p>are the running mean and variances. $r$ is the momentum for calculating the exponential mean.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span><span class="k">class</span> <span class="nc">EstimatedBatchNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<ul>
<li><code>channels</code> is the number of features in the input</li>
<li><code>eps</code> is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability</li>
<li><code>momentum</code> is the momentum in taking the exponential moving average</li>
<li><code>affine</code> is whether to scale and shift the normalized value</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">93</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">101</span>
<span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
<span class="lineno">104</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span>
<span class="lineno">105</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Channel wise transformation parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_mean&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;exp_var&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p><code>x</code> is a tensor of shape <code>[batch_size, channels, *]</code>.
<code>*</code> denotes any number of (possibly 0) dimensions.
For example, in an image (2D) convolution this will be
<code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Keep old shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Get the batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Sanity check to make sure the number of features is correct</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Reshape into <code>[batch_size, channels, n]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Calculate the mean across first and last dimensions;
$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Calculate the squared mean across first and last dimensions;
$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Variance for each feature, $\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Update exponential moving averages
<script type="math/tex; mode=display">\begin{align}
\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
\end{align}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
<span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Normalize
<script type="math/tex; mode=display">\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Scale and shift
<script type="math/tex; mode=display"> \gamma_C
\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
+ \beta_C</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">163</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<h2>Channel Normalization</h2>
<p>This is similar to <a href="../group_norm/index.html">Group Normalization</a> but affine transform is done group wise.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span><span class="k">class</span> <span class="nc">ChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<ul>
<li><code>groups</code> is the number of groups the features are divided into</li>
<li><code>channels</code> is the number of features in the input</li>
<li><code>eps</code> is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability</li>
<li><code>affine</code> is whether to scale and shift the normalized value</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Parameters for affine transformation.</p>
<p><em>Note that these transforms are per group, unlike in group norm where
they are transformed channel-wise.</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">194</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span>
<span class="lineno">195</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p><code>x</code> is a tensor of shape <code>[batch_size, channels, *]</code>.
<code>*</code> denotes any number of (possibly 0) dimensions.
For example, in an image (2D) convolution this will be
<code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Keep the original shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Get the batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Sanity check to make sure the number of features is the same</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Reshape into <code>[batch_size, groups, n]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Calculate the mean across last dimension;
i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Calculate the squared mean across last dimension;
i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Variance for each sample and channel group
$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">223</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Normalize
<script type="math/tex; mode=display">\hat{x}_{(i_N, i_G)} =
\frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Scale and shift group-wise
<script type="math/tex; mode=display">y_{i_G} =\gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">233</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>

View File

@ -137,10 +137,10 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
<a href="https://wandb.ai/vpj/cifar10/runs/310etthp"><img alt="WandB" src="https://img.shields.io/badge/wandb-run-yellow" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">88</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">89</span>
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
<div class="highlight"><pre><span class="lineno">86</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">87</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">88</span>
<span class="lineno">89</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -151,7 +151,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
<h2>Group Normalization Layer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span><span class="k">class</span> <span class="nc">GroupNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">92</span><span class="k">class</span> <span class="nc">GroupNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -167,8 +167,8 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">99</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">97</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">98</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -179,14 +179,14 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">107</span>
<span class="lineno">108</span> <span class="k">assert</span> <span class="n">channels</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;Number of channels should be evenly divisible by the number of groups&quot;</span>
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">111</span>
<span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
<div class="highlight"><pre><span class="lineno">105</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">106</span>
<span class="lineno">107</span> <span class="k">assert</span> <span class="n">channels</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;Number of channels should be evenly divisible by the number of groups&quot;</span>
<span class="lineno">108</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">110</span>
<span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -197,9 +197,9 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
<p>Create parameters for $\gamma$ and $\beta$ for scale and shift</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span>
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -213,7 +213,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
<code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">118</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -224,7 +224,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
<p>Keep the original shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -235,7 +235,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
<p>Get the batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -246,7 +246,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
<p>Sanity check to make sure the number of features is the same</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -257,7 +257,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
<p>Reshape into <code>[batch_size, groups, n]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -269,7 +269,7 @@ $m$ is the size of the set $\mathcal{S}_i$ which is the same for all $i$.</p>
i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -281,7 +281,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p
i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -293,7 +293,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$<
$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -307,7 +307,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -320,9 +320,9 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">154</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">155</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">152</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">153</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">154</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -333,7 +333,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -344,7 +344,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
<p>Simple test</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">160</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -355,14 +355,14 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
<span class="lineno">166</span>
<span class="lineno">167</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
<span class="lineno">168</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="lineno">169</span> <span class="n">bn</span> <span class="o">=</span> <span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="lineno">170</span>
<span class="lineno">171</span> <span class="n">x</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">172</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">164</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
<span class="lineno">165</span>
<span class="lineno">166</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
<span class="lineno">167</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="lineno">168</span> <span class="n">bn</span> <span class="o">=</span> <span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="lineno">169</span>
<span class="lineno">170</span> <span class="n">x</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">171</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -373,8 +373,8 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">177</span> <span class="n">_test</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">175</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">176</span> <span class="n">_test</span><span class="p">()</span></pre></div>
</div>
</div>
</div>

View File

@ -0,0 +1,200 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="2D Convolution Layer with Weight Standardization"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/normalization/weight_standardization/conv2d.html"/>
<meta property="og:title" content="2D Convolution Layer with Weight Standardization"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="2D Convolution Layer with Weight Standardization"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization."/>
<title>2D Convolution Layer with Weight Standardization</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/normalization/weight_standardization/conv2d.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">normalization</a>
<a class="parent" href="index.html">weight_standardization</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/weight_standardization/conv2d.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
                    <img alt="Join Slack"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>2D Convolution Layer with Weight Standardization</h1>
<p>This is an implementation of a 2 dimensional convolution layer with <a href="./index.html">Weight Standardization</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">14</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="lineno">16</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.weight_standardization</span> <span class="kn">import</span> <span class="n">weight_standardization</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>2D Convolution Layer</h2>
                    <p>This extends the standard 2D Convolution layer and standardizes the weights before the convolution step.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">20</span><span class="k">class</span> <span class="nc">Conv2d</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">26</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span>
<span class="lineno">27</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="lineno">28</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="lineno">29</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="lineno">30</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="lineno">31</span> <span class="n">bias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">32</span> <span class="n">padding_mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;zeros&#39;</span><span class="p">,</span>
<span class="lineno">33</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">):</span>
<span class="lineno">34</span> <span class="nb">super</span><span class="p">(</span><span class="n">Conv2d</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span>
<span class="lineno">35</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span>
<span class="lineno">36</span> <span class="n">padding</span><span class="o">=</span><span class="n">padding</span><span class="p">,</span>
<span class="lineno">37</span> <span class="n">dilation</span><span class="o">=</span><span class="n">dilation</span><span class="p">,</span>
<span class="lineno">38</span> <span class="n">groups</span><span class="o">=</span><span class="n">groups</span><span class="p">,</span>
<span class="lineno">39</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span>
<span class="lineno">40</span> <span class="n">padding_mode</span><span class="o">=</span><span class="n">padding_mode</span><span class="p">)</span>
<span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">44</span> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">conv2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight_standardization</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span>
<span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>A simple test to verify the tensor sizes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">conv2d</span> <span class="o">=</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="lineno">53</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
<span class="lineno">54</span> <span class="n">inspect</span><span class="p">(</span><span class="n">conv2d</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
<span class="lineno">55</span> <span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">56</span> <span class="n">inspect</span><span class="p">(</span><span class="n">conv2d</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">)))</span>
<span class="lineno">57</span>
<span class="lineno">58</span>
<span class="lineno">59</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">60</span> <span class="n">_test</span><span class="p">()</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>

View File

@ -0,0 +1,267 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
    <meta name="description" content="This trains a VGG net that uses weight standardization and batch-channel normalization to classify CIFAR10 images."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization"/>
    <meta name="twitter:description" content="This trains a VGG net that uses weight standardization and batch-channel normalization to classify CIFAR10 images."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/normalization/weight_standardization/experiment.html"/>
<meta property="og:title" content="CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization"/>
    <meta property="og:description" content="This trains a VGG net that uses weight standardization and batch-channel normalization to classify CIFAR10 images."/>
<title>CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/normalization/weight_standardization/experiment.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">normalization</a>
<a class="parent" href="index.html">weight_standardization</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/weight_standardization/experiment.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
                    <img alt="Join Slack"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">13</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.cifar10</span> <span class="kn">import</span> <span class="n">CIFAR10Configs</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.batch_channel_norm</span> <span class="kn">import</span> <span class="n">BatchChannelNorm</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.normalization.weight_standardization.conv2d</span> <span class="kn">import</span> <span class="n">Conv2d</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Model</h3>
                    <p>A VGG model that uses <a href="./index.html">Weight Standardization</a> and
<a href="../batch_channel_norm/index.html">Batch-Channel Normalization</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">30</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">31</span> <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">32</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="mi">3</span>
<span class="lineno">33</span> <span class="k">for</span> <span class="n">block</span> <span class="ow">in</span> <span class="p">[[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">],</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">],</span> <span class="p">[</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">],</span> <span class="p">[</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">]]:</span>
<span class="lineno">34</span> <span class="k">for</span> <span class="n">channels</span> <span class="ow">in</span> <span class="n">block</span><span class="p">:</span>
<span class="lineno">35</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
<span class="lineno">36</span> <span class="n">BatchChannelNorm</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="mi">32</span><span class="p">),</span>
<span class="lineno">37</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)]</span>
<span class="lineno">38</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">39</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">)]</span>
<span class="lineno">40</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">AvgPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
<span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">45</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">46</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">47</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<h3>Create model</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span><span class="nd">@option</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">51</span><span class="k">def</span> <span class="nf">model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">CIFAR10Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="k">return</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;cifar10&#39;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">&#39;weight standardization&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">CIFAR10Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Load configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">65</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">66</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
<span class="lineno">67</span> <span class="s1">&#39;train_batch_size&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span>
<span class="lineno">68</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Start the experiment and run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">71</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">76</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>

View File

@ -0,0 +1,231 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of Weight Standardization."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Weight Standardization"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of Weight Standardization."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/normalization/weight_standardization/index.html"/>
<meta property="og:title" content="Weight Standardization"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Weight Standardization"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of Weight Standardization."/>
<title>Weight Standardization</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/normalization/weight_standardization/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">normalization</a>
<a class="parent" href="index.html">weight_standardization</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/normalization/weight_standardization/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slack"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Weight Standardization</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of Weight Standardization from the paper
<a href="https://arxiv.org/abs/1903.10520">Micro-Batch Training with Batch-Channel Normalization and Weight Standardization</a>.
We also have an <a href="../batch_channel_norm/index.html">annotated implementation of Batch-Channel Normalization</a>.</p>
<p>Batch normalization <strong>gives a smooth loss landscape</strong> and
<strong>avoids elimination singularities</strong>.
Elimination singularities are nodes of the network that become
useless (e.g. a ReLU that gives 0 all the time).</p>
<p>However, batch normalization doesn&rsquo;t work well when the batch size is too small,
which happens when training large networks because of device memory limitations.
The paper introduces Weight Standardization with Batch-Channel Normalization as
a better alternative.</p>
<p>Weight Standardization:
1. Normalizes the gradients
2. Smoothes the landscape (reduced Lipschitz constant)
3. Avoids elimination singularities</p>
<p>The Lipschitz constant is the maximum slope a function has between two points.
That is, $L$ is the Lipschitz constant where $L$ is the smallest value that satisfies,
$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$
where $f: A \rightarrow \mathbb{R}^m, A \in \mathbb{R}^n$.</p>
<p>Elimination singularities are avoided because it keeps the statistics of the outputs similar to the
inputs. So as long as the inputs are normally distributed the outputs remain close to normal.
This avoids outputs of nodes from always falling beyond the active range of the activation function
(e.g. always negative input for a ReLU).</p>
<p><em><a href="https://arxiv.org/abs/1903.10520">Refer to the paper for proofs</a></em>.</p>
<p>Here is <a href="experiment.html">the training code</a> for training
a VGG network that uses weight standardization to classify CIFAR-10 data.
This uses a <a href="../conv2d.html">2D-Convolution Layer with Weight Standardization</a>.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a>
<a href="https://wandb.ai/vpj/cifar10/runs/3flr4k8w"><img alt="WandB" src="https://img.shields.io/badge/wandb-run-yellow" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span><span></span><span class="kn">import</span> <span class="nn">torch</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Weight Standardization</h2>
<p>
<script type="math/tex; mode=display">\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}</script>
</p>
<p>where,</p>
<p>
<script type="math/tex; mode=display">\begin{align}
W &\in \mathbb{R}^{O \times I} \\
\mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
\sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon} \\
\end{align}</script>
</p>
<p>for a 2D-convolution layer $O$ is the number of output channels ($O = C_{out}$)
and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span><span class="k">def</span> <span class="nf">weight_standardization</span><span class="p">(</span><span class="n">weight</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Get $C_{out}$, $C_{in}$ and kernel shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">c_out</span><span class="p">,</span> <span class="n">c_in</span><span class="p">,</span> <span class="o">*</span><span class="n">kernel_shape</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Reshape $W$ to $O \times I$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">c_out</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Calculate</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
\sigma^2_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}}
\end{align}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="n">var</span><span class="p">,</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">var_mean</span><span class="p">(</span><span class="n">weight</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Normalize
<script type="math/tex; mode=display">\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">weight</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Change back to original shape and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="k">return</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">c_out</span><span class="p">,</span> <span class="n">c_in</span><span class="p">,</span> <span class="o">*</span><span class="n">kernel_shape</span><span class="p">)</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>

View File

@ -160,6 +160,27 @@
</url>
<url>
<loc>https://nn.labml.ai/normalization/weight_standardization/index.html</loc>
<lastmod>2021-04-27T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/normalization/weight_standardization/experiment.html</loc>
<lastmod>2021-04-27T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/normalization/weight_standardization/conv2d.html</loc>
<lastmod>2021-04-26T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/normalization/instance_norm/index.html</loc>
<lastmod>2021-04-23T16:30:00+00:00</lastmod>
@ -202,6 +223,13 @@
</url>
<url>
<loc>https://nn.labml.ai/normalization/batch_channel_norm/index.html</loc>
<lastmod>2021-04-27T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/normalization/group_norm/experiment.html</loc>
<lastmod>2021-04-24T16:30:00+00:00</lastmod>

View File

@ -0,0 +1,236 @@
"""
---
title: Batch-Channel Normalization
summary: >
A PyTorch implementation/tutorial of Batch-Channel Normalization.
---
# Batch-Channel Normalization
This is a [PyTorch](https://pytorch.org) implementation of Batch-Channel Normalization from the paper
[Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://arxiv.org/abs/1903.10520).
We also have an [annotated implementation of Weight Standardization](../weight_standardization/index.html).
Batch-Channel Normalization performs batch normalization followed
by a channel normalization (similar to a [Group Normalization](../group_norm/index.html)).
When the batch size is small a running mean and variance is used for
batch normalization.
Here is [the training code](../weight_standardization/experiment.html) for training
a VGG network that uses weight standardization to classify CIFAR-10 data.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002)
[![WandB](https://img.shields.io/badge/wandb-run-yellow)](https://wandb.ai/vpj/cifar10/runs/3flr4k8w)
"""
import torch
from torch import nn
from labml_helpers.module import Module
from labml_nn.normalization.batch_norm import BatchNorm
class BatchChannelNorm(Module):
    """
    ## Batch-Channel Normalization

    Applies a batch normalization — either the standard
    [batch norm](../batch_norm/index.html) or one driven by running
    (exponentially averaged) estimates of the mean and variance —
    and then a channel normalization is performed.
    """
    def __init__(self, channels: int, groups: int,
                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
        """
        * `channels` is the number of features in the input
        * `groups` is the number of groups the features are divided into
        * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
        * `momentum` is the momentum in taking the exponential moving average
        * `estimate` is whether to use running mean and variance for batch norm
        """
        super().__init__()
        # Pick the batch normalization flavor: estimated (running statistics)
        # or the standard per-batch normalization.
        batch_norm_cls = EstimatedBatchNorm if estimate else BatchNorm
        self.batch_norm = batch_norm_cls(channels, eps=eps, momentum=momentum)
        # Channel normalization applied after the batch normalization
        self.channel_norm = ChannelNorm(channels, groups, eps)

    def __call__(self, x):
        # Batch normalization first, then channel normalization
        return self.channel_norm(self.batch_norm(x))
class EstimatedBatchNorm(Module):
    """
    ## Estimated Batch Normalization
    When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
    where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
    $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
    $$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C
    \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
    + \beta_C$$
    where,
    \begin{align}
    \hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
    \hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
    \end{align}
    are the running mean and variances. $r$ is the momentum for calculating the exponential mean.
    """
    def __init__(self, channels: int,
                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
        """
        * `channels` is the number of features in the input
        * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
        * `momentum` is the momentum in taking the exponential moving average
        * `affine` is whether to scale and shift the normalized value
        """
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.channels = channels
        # Channel wise transformation parameters
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
        # Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$;
        # registered as buffers so they are saved with the model but not trained
        self.register_buffer('exp_mean', torch.zeros(channels))
        self.register_buffer('exp_var', torch.ones(channels))
    def __call__(self, x: torch.Tensor):
        """
        `x` is a tensor of shape `[batch_size, channels, *]`.
        `*` denotes any number of (possibly 0) dimensions.
        For example, in an image (2D) convolution this will be
        `[batch_size, channels, height, width]`
        """
        # Keep old shape
        x_shape = x.shape
        # Get the batch size
        batch_size = x_shape[0]
        # Sanity check to make sure the number of features is correct
        assert self.channels == x.shape[1]
        # Reshape into `[batch_size, channels, n]`
        x = x.view(batch_size, self.channels, -1)
        # Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only
        if self.training:
            # No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
            with torch.no_grad():
                # Calculate the mean across first and last dimensions;
                # $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$
                mean = x.mean(dim=[0, 2])
                # Calculate the squared mean across first and last dimensions;
                # $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$
                mean_x2 = (x ** 2).mean(dim=[0, 2])
                # Variance for each feature, computed as $E[X^2] - E[X]^2$;
                # $\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$
                var = mean_x2 - mean ** 2
                # Update exponential moving averages
                # \begin{align}
                # \hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w} \\
                # \hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
                # \end{align}
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
        # Normalize with the running estimates even in training mode
        # (this is what distinguishes it from standard batch norm)
        # $$\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}$$
        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
        # Scale and shift
        # $$ \gamma_C
        # \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}
        # + \beta_C$$
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
        # Reshape to original and return
        return x_norm.view(x_shape)
class ChannelNorm(Module):
    """
    ## Channel Normalization

    Similar to [Group Normalization](../group_norm/index.html), except that the
    learned affine transform is applied per group instead of per channel.
    """
    def __init__(self, channels, groups,
                 eps: float = 1e-5, affine: bool = True):
        """
        * `groups` is the number of groups the features are divided into
        * `channels` is the number of features in the input
        * `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
        * `affine` is whether to scale and shift the normalized value
        """
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.eps = eps
        self.affine = affine
        # Parameters for the affine transformation.
        #
        # *Note that these are per group, unlike group norm where
        # they are per channel.*
        if self.affine:
            self.scale = nn.Parameter(torch.ones(groups))
            self.shift = nn.Parameter(torch.zeros(groups))

    def __call__(self, x: torch.Tensor):
        """
        `x` is a tensor of shape `[batch_size, channels, *]`.
        `*` denotes any number of (possibly 0) dimensions.
        For example, in an image (2D) convolution this will be
        `[batch_size, channels, height, width]`
        """
        # Remember the original shape so we can restore it at the end
        original_shape = x.shape
        batch_size = original_shape[0]
        # The channel dimension must match the configured number of features
        assert self.channels == x.shape[1]
        # Collect the channels into groups: `[batch_size, groups, n]`
        grouped = x.view(batch_size, self.groups, -1)
        # Mean over each sample-and-group; $\mathbb{E}[x_{(i_N, i_G)}]$
        mean = grouped.mean(dim=[-1], keepdim=True)
        # Mean of squares over each sample-and-group; $\mathbb{E}[x^2_{(i_N, i_G)}]$
        mean_sq = (grouped ** 2).mean(dim=[-1], keepdim=True)
        # Variance per sample and group;
        # $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
        var = mean_sq - mean ** 2
        # Normalize
        # $$\hat{x}_{(i_N, i_G)} =
        # \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{Var[x_{(i_N, i_G)}] + \epsilon}}$$
        normalized = (grouped - mean) / torch.sqrt(var + self.eps)
        # Group-wise scale and shift
        # $$y_{i_G} =\gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}$$
        if self.affine:
            normalized = self.scale.view(1, -1, 1) * normalized + self.shift.view(1, -1, 1)
        # Restore the original shape and return
        return normalized.view(original_shape)

View File

@ -81,7 +81,6 @@ Here's a [CIFAR 10 classification model](experiment.html) that uses instance nor
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/group_norm/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/081d950aa4e011eb8f9f0242ac1c0002)
[![WandB](https://img.shields.io/badge/wandb-run-yellow)](https://wandb.ai/vpj/cifar10/runs/310etthp)
"""
import torch

View File

@ -0,0 +1,86 @@
"""
---
title: Weight Standardization
summary: >
A PyTorch implementation/tutorial of Weight Standardization.
---
# Weight Standardization
This is a [PyTorch](https://pytorch.org) implementation of Weight Standardization from the paper
[Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://arxiv.org/abs/1903.10520).
We also have an [annotated implementation of Batch-Channel Normalization](../batch_channel_norm/index.html).
Batch normalization **gives a smooth loss landscape** and
**avoids elimination singularities**.
Elimination singularities are nodes of the network that become
useless (e.g. a ReLU that gives 0 all the time).
However, batch normalization doesn't work well when the batch size is too small,
which happens when training large networks because of device memory limitations.
The paper introduces Weight Standardization with Batch-Channel Normalization as
a better alternative.
Weight Standardization:
1. Normalizes the gradients
2. Smoothes the landscape (reduced Lipschitz constant)
3. Avoids elimination singularities
The Lipschitz constant is the maximum slope a function has between two points.
That is, $L$ is the Lipschitz constant where $L$ is the smallest value that satisfies,
$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$
where $f: A \rightarrow \mathbb{R}^m, A \in \mathbb{R}^n$.
Elimination singularities are avoided because it keeps the statistics of the outputs similar to the
inputs. So as long as the inputs are normally distributed the outputs remain close to normal.
This avoids outputs of nodes from always falling beyond the active range of the activation function
(e.g. always negative input for a ReLU).
*[Refer to the paper for proofs](https://arxiv.org/abs/1903.10520)*.
Here is [the training code](experiment.html) for training
a VGG network that uses weight standardization to classify CIFAR-10 data.
This uses a [2D-Convolution Layer with Weight Standardization](../conv2d.html).
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002)
[![WandB](https://img.shields.io/badge/wandb-run-yellow)](https://wandb.ai/vpj/cifar10/runs/3flr4k8w)
"""
import torch
def weight_standardization(weight: torch.Tensor, eps: float):
    r"""
    ## Weight Standardization
    $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$
    where,
    \begin{align}
    W &\in \mathbb{R}^{O \times I} \\
    \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
    \sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon} \\
    \end{align}
    for a 2D-convolution layer $O$ is the number of output channels ($O = C_{out}$)
    and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$)
    """
    # Get $C_{out}$, $C_{in}$ and kernel shape
    c_out, c_in, *kernel_shape = weight.shape
    # Reshape $W$ to $O \times I$
    weight = weight.view(c_out, -1)
    # Calculate
    #
    # \begin{align}
    # \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
    # \sigma^2_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}}
    # \end{align}
    #
    # `unbiased=False` makes `torch.var_mean` divide by $I$, matching the
    # $\frac{1}{I}$ in the formula above (the default would divide by $I - 1$).
    var, mean = torch.var_mean(weight, dim=1, keepdim=True, unbiased=False)
    # Normalize
    # $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$
    weight = (weight - mean) / (torch.sqrt(var + eps))
    # Change back to original shape and return
    return weight.view(c_out, c_in, *kernel_shape)

View File

@ -0,0 +1,60 @@
"""
---
title: 2D Convolution Layer with Weight Standardization
summary: >
A PyTorch implementation/tutorial of a 2D Convolution Layer with Weight Standardization.
---
# 2D Convolution Layer with Weight Standardization
This is an implementation of a 2 dimensional convolution layer with [Weight Standardization](./index.html)
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
from labml_nn.normalization.weight_standardization import weight_standardization
class Conv2d(nn.Conv2d):
    """
    ## 2D Convolution Layer

    A drop-in replacement for `nn.Conv2d` that standardizes the weights
    with [Weight Standardization](./index.html) before every convolution.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 eps: float = 1e-5):
        # Everything except `eps` is forwarded to the standard convolution layer
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation,
                                     groups=groups,
                                     bias=bias,
                                     padding_mode=padding_mode)
        # $\epsilon$ for numerical stability when standardizing the weights
        self.eps = eps

    def forward(self, x: torch.Tensor):
        # Standardize the weights first, then run the usual functional convolution
        standardized = weight_standardization(self.weight, self.eps)
        return F.conv2d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
def _test():
    """
    A simple test to verify the tensor sizes.

    Builds a weight-standardized convolution and inspects its weight and the
    output for a dummy batch. Uses `labml.logger.inspect` for pretty printing.
    """
    conv2d = Conv2d(10, 20, 5)
    from labml.logger import inspect
    inspect(conv2d.weight)
    # `torch` is already imported at module scope, so no local import is needed
    inspect(conv2d(torch.zeros(10, 10, 100, 100)))


# Run the sanity check when executed as a script
if __name__ == '__main__':
    _test()

View File

@ -0,0 +1,76 @@
"""
---
title: CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization
summary: >
    This trains a VGG net that uses weight standardization and batch-channel normalization
to classify CIFAR10 images.
---
# CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization
"""
import torch.nn as nn
from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.normalization.batch_channel_norm import BatchChannelNorm
from labml_nn.normalization.weight_standardization.conv2d import Conv2d
class Model(Module):
    """
    ### Model

    A VGG-style network built from [Weight Standardization](./index.html)
    convolutions, each followed by
    [Batch-Channel Normalization](../batch_channel_norm/index.html) and ReLU.
    """
    def __init__(self):
        super().__init__()
        # Channel counts for each VGG block; a max-pool follows every block
        blocks = [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]
        modules = []
        prev_channels = 3
        for block in blocks:
            for out_channels in block:
                modules.append(Conv2d(prev_channels, out_channels, kernel_size=3, padding=1))
                modules.append(BatchChannelNorm(out_channels, 32))
                modules.append(nn.ReLU(inplace=True))
                prev_channels = out_channels
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        self.layers = nn.Sequential(*modules)
        # Final classifier over the 10 CIFAR-10 classes
        self.fc = nn.Linear(512, 10)

    def __call__(self, x):
        # Convolutional stack, then flatten per sample, then classify
        features = self.layers(x)
        flat = features.view(features.shape[0], -1)
        return self.fc(flat)
@option(CIFAR10Configs.model)
def model(c: CIFAR10Configs):
    """
    ### Create model

    Registered via the `@option` decorator as the `model` option of
    `CIFAR10Configs`; builds the VGG model above on the configured device.
    """
    return Model().to(c.device)
def main():
    """
    Create and run the CIFAR-10 weight-standardization experiment.
    """
    # Register the experiment
    experiment.create(name='cifar10', comment='weight standardization')
    # Start from the default CIFAR-10 configurations
    conf = CIFAR10Configs()
    # Override the optimizer and batch-size settings
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'train_batch_size': 64,
    }
    experiment.configs(conf, overrides)
    # Run the training loop inside the experiment context
    with experiment.start():
        conf.run()


# Run the experiment when executed as a script
if __name__ == '__main__':
    main()

View File

@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
version='0.4.96',
version='0.4.97',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",