mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 18:58:43 +08:00 
			
		
		
		
	batch channel norm mathjax fix
This commit is contained in:
		| @ -344,7 +344,8 @@ $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.</p> | ||||
|                     <a href='#section-19'>#</a> | ||||
|                 </div> | ||||
|                 <p>Calculate the mean across first and last dimensions; | ||||
| $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$</p> | ||||
| <script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}</script> | ||||
| </p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">140</span>                <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div> | ||||
| @ -356,7 +357,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$</p> | ||||
|                     <a href='#section-20'>#</a> | ||||
|                 </div> | ||||
|                 <p>Calculate the mean of the squares across first and last dimensions; | ||||
| $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
| <script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}</script> | ||||
| </p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">143</span>                <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div> | ||||
| @ -367,10 +369,12 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
|                 <div class='section-link'> | ||||
|                     <a href='#section-21'>#</a> | ||||
|                 </div> | ||||
|                 <p>Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2</p> | ||||
|                 <p>Variance for each feature | ||||
| <script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2</script> | ||||
| </p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">145</span>                <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">146</span>                <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-22'> | ||||
| @ -386,8 +390,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
| </p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">152</span>                <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span> | ||||
| <span class="lineno">153</span>                <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">153</span>                <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span> | ||||
| <span class="lineno">154</span>                <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-23'> | ||||
| @ -400,7 +404,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
| </p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">157</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">158</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-24'> | ||||
| @ -415,8 +419,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
| </p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">162</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||
| <span class="lineno">163</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">163</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||
| <span class="lineno">164</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-25'> | ||||
| @ -427,7 +431,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
|                 <p>Reshape to original and return</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">166</span>        <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">167</span>        <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-26'> | ||||
| @ -439,7 +443,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
| <p>This is similar to <a href="../group_norm/index.html">Group Normalization</a> but the affine transform is done group-wise.</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">169</span><span class="k">class</span> <span class="nc">ChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">170</span><span class="k">class</span> <span class="nc">ChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-27'> | ||||
| @ -455,8 +459,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
| </ul> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">176</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> | ||||
| <span class="lineno">177</span>                 <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">177</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> | ||||
| <span class="lineno">178</span>                 <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-28'> | ||||
| @ -467,11 +471,11 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
|                  | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">184</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | ||||
| <span class="lineno">185</span>        <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span> | ||||
| <span class="lineno">186</span>        <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span> | ||||
| <span class="lineno">187</span>        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span> | ||||
| <span class="lineno">188</span>        <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">185</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | ||||
| <span class="lineno">186</span>        <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span> | ||||
| <span class="lineno">187</span>        <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span> | ||||
| <span class="lineno">188</span>        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span> | ||||
| <span class="lineno">189</span>        <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-29'> | ||||
| @ -484,9 +488,9 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | ||||
| they are transformed channel-wise.</em></p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">193</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||
| <span class="lineno">194</span>            <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span> | ||||
| <span class="lineno">195</span>            <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">194</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||
| <span class="lineno">195</span>            <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span> | ||||
| <span class="lineno">196</span>            <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-30'> | ||||
| @ -500,7 +504,7 @@ they are transformed channel-wise.</em></p> | ||||
| <code>[batch_size, channels, height, width]</code></p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">197</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">198</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-31'> | ||||
| @ -511,7 +515,7 @@ they are transformed channel-wise.</em></p> | ||||
|                 <p>Keep the original shape</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">206</span>        <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">207</span>        <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-32'> | ||||
| @ -522,7 +526,7 @@ they are transformed channel-wise.</em></p> | ||||
|                 <p>Get the batch size</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">208</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">209</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-33'> | ||||
| @ -533,7 +537,7 @@ they are transformed channel-wise.</em></p> | ||||
|                 <p>Sanity check to make sure the number of features is the same</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">210</span>        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">211</span>        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-34'> | ||||
| @ -544,7 +548,7 @@ they are transformed channel-wise.</em></p> | ||||
|                 <p>Reshape into <code>[batch_size, groups, n]</code></p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">213</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">214</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-35'> | ||||
| @ -556,7 +560,7 @@ they are transformed channel-wise.</em></p> | ||||
| i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">217</span>        <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">218</span>        <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-36'> | ||||
| @ -568,7 +572,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p | ||||
| i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">220</span>        <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">221</span>        <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-37'> | ||||
| @ -580,7 +584,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$< | ||||
| $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">223</span>        <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">224</span>        <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-38'> | ||||
| @ -594,7 +598,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}] | ||||
| </p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">228</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">229</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-39'> | ||||
| @ -607,8 +611,8 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}] | ||||
| </p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">232</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||
| <span class="lineno">233</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">233</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||
| <span class="lineno">234</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     <div class='section' id='section-40'> | ||||
| @ -619,7 +623,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}] | ||||
|                 <p>Reshape to original and return</p> | ||||
|             </div> | ||||
|             <div class='code'> | ||||
|                 <div class="highlight"><pre><span class="lineno">236</span>        <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div> | ||||
|                 <div class="highlight"><pre><span class="lineno">237</span>        <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div> | ||||
|             </div> | ||||
|         </div> | ||||
|     </div> | ||||
|  | ||||
| @ -136,12 +136,13 @@ class EstimatedBatchNorm(Module): | ||||
|             # No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ | ||||
|             with torch.no_grad(): | ||||
|                 # Calculate the mean across first and last dimensions; | ||||
|                 # $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$ | ||||
|                 # $$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$$ | ||||
|                 mean = x.mean(dim=[0, 2]) | ||||
|                 # Calculate the mean of the squares across first and last dimensions; | ||||
|                 # $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$ | ||||
|                 # $$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$$ | ||||
|                 mean_x2 = (x ** 2).mean(dim=[0, 2]) | ||||
|                 # Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2 | ||||
|                 # Variance for each feature | ||||
|                 # $$\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$$ | ||||
|                 var = mean_x2 - mean ** 2 | ||||
|  | ||||
|                 # Update exponential moving averages | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri