batch channel norm mathjax fix

This commit is contained in:
Varuna Jayasiri
2021-07-04 13:32:56 +05:30
parent f00d1d61f7
commit af6b99a5f0
2 changed files with 41 additions and 36 deletions

View File

@ -344,7 +344,8 @@ $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.</p>
<a href='#section-19'>#</a>
</div>
<p>Calculate the mean across first and last dimensions;
$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$</p>
<script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
@ -356,7 +357,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$</p>
<a href='#section-20'>#</a>
</div>
<p>Calculate the squared mean across first and last dimensions;
$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
<script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div>
@ -367,10 +369,12 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2</p>
<p>Variance for each feature
<script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
@ -386,8 +390,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
<span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
<div class="highlight"><pre><span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span>
<span class="lineno">154</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
@ -400,7 +404,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
@ -415,8 +419,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">163</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">164</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -427,7 +431,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">167</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -439,7 +443,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
<p>This is similar to <a href="../group_norm/index.html">Group Normalization</a> but affine transform is done group wise.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span><span class="k">class</span> <span class="nc">ChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">170</span><span class="k">class</span> <span class="nc">ChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -455,8 +459,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">177</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span>
<span class="lineno">178</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
@ -467,11 +471,11 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
<div class="highlight"><pre><span class="lineno">185</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
@ -484,9 +488,9 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p>
they are transformed channel-wise.</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">194</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span>
<span class="lineno">195</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">195</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span>
<span class="lineno">196</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
@ -500,7 +504,7 @@ they are transformed channel-wise.</em></p>
<code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">198</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
@ -511,7 +515,7 @@ they are transformed channel-wise.</em></p>
<p>Keep the original shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
@ -522,7 +526,7 @@ they are transformed channel-wise.</em></p>
<p>Get the batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">209</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
@ -533,7 +537,7 @@ they are transformed channel-wise.</em></p>
<p>Sanity check to make sure the number of features is the same</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
@ -544,7 +548,7 @@ they are transformed channel-wise.</em></p>
<p>Reshape into <code>[batch_size, groups, n]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">214</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
@ -556,7 +560,7 @@ they are transformed channel-wise.</em></p>
i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">218</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
@ -568,7 +572,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p
i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
@ -580,7 +584,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$<
$Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">223</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
@ -594,7 +598,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">229</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
@ -607,8 +611,8 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">233</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">233</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span>
<span class="lineno">234</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
@ -619,7 +623,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]
<p>Reshape to original and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">237</span> <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div>
</div>
</div>
</div>

View File

@ -136,12 +136,13 @@ class EstimatedBatchNorm(Module):
# No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
with torch.no_grad():
# Calculate the mean across first and last dimensions;
# $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$
# $$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$$
mean = x.mean(dim=[0, 2])
# Calculate the squared mean across first and last dimensions;
# $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$
# $$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$$
mean_x2 = (x ** 2).mean(dim=[0, 2])
# Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
# Variance for each feature
# $$\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$$
var = mean_x2 - mean ** 2
# Update exponential moving averages